feat(pt): add FSDP & ZeRO1 (Zero Redundancy Optimizer) support #5222
OutisLi wants to merge 2 commits into deepmodeling:master
Conversation
Summary of Changes
Hello @OutisLi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the PyTorch backend's distributed training capabilities by integrating ZeRO (Zero Redundancy Optimizer) stages, including FSDP2. This allows for more efficient memory utilization on GPUs, which is crucial for training large models. The changes introduce a configurable `zero_stage` parameter.

Highlights
Changelog
Code Review
This pull request introduces support for ZeRO stages 1, 2, and 3 in the PyTorch backend, enabling more memory-efficient distributed training. The implementation correctly uses ZeroRedundancyOptimizer for stage 1 and FSDP for stages 2 and 3. The logic for creating optimizers and handling model/optimizer state dicts during loading and saving has been refactored to support these new distributed strategies. The changes are well-structured and include excellent updates to both code comments and user documentation. My main feedback concerns a potential regression in the DDP path that could affect existing single-task models.
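For orientation, here is a minimal sketch (not the PR's actual code; the helper name `build_zero1` is made up for illustration) of how ZeRO stage 1 is typically wired up in PyTorch: the model stays under DDP while the optimizer state is sharded across ranks via `ZeroRedundancyOptimizer`.

```python
# Sketch only: plain DDP for the model, ZeRO-1 sharding for optimizer state.
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP


def build_zero1(model: torch.nn.Module, lr: float):
    assert dist.is_initialized(), "requires torch.distributed to be initialized"
    model = DDP(model)
    # Each rank stores only its shard of the Adam state (exp_avg, exp_avg_sq),
    # which is where the ZeRO-1 memory saving comes from.
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.Adam,
        lr=lr,
    )
    return model, optimizer
```

Stages 2 and 3 additionally shard gradients and parameters, which is why the PR switches to FSDP2 (`fully_shard`) for those stages.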
Pull request overview
This PR adds ZeRO (Zero Redundancy Optimizer) and FSDP2 support to the PyTorch backend, enabling memory-efficient distributed training through optimizer state and parameter sharding.
Changes:
- Added `zero_stage` configuration parameter (0-3) to control the memory optimization strategy (see the configuration sketch after this list)
- Implemented FSDP2 integration for stages 2-3 and ZeroRedundancyOptimizer for stage 1
- Updated checkpoint saving/loading to handle distributed state collection
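As a rough illustration of where the new knob lives (only the `zero_stage` key is taken from this PR; the surrounding keys and values are hypothetical), the `training` section of the input script gains a `zero_stage` entry. Shown here as a Python dict mirroring the JSON structure:

```python
# Hypothetical training section; "zero_stage" is the key added by this PR.
training_params = {
    "numb_steps": 1_000_000,  # illustrative value
    "save_freq": 10_000,      # illustrative value
    "zero_stage": 1,          # 0 = plain DDP, 1 = ZeRO-1, 2/3 = FSDP2 sharding
}

# The trainer reads it with a default of 0 (see the quoted code later in this thread).
zero_stage = int(training_params.get("zero_stage", 0))
```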
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| doc/train/parallel-training.md | Adds documentation explaining ZeRO stages, memory savings, communication patterns, and usage constraints |
| deepmd/utils/argcheck.py | Adds zero_stage parameter definition with detailed documentation |
| deepmd/pt/train/training.py | Implements ZeRO/FSDP2 logic including optimizer wrapping, state dict handling, and distributed checkpoint operations |
```python
for b in self.wrapper.buffers():
    dist.broadcast(b.data, src=0)
reshard = self.zero_stage >= 3
fully_shard(self.wrapper, reshard_after_forward=reshard)
```
The fully_shard function modifies self.wrapper in-place but this is not reflected in the variable assignment. Consider either documenting this behavior with a comment or assigning the return value to make the in-place modification explicit: self.wrapper = fully_shard(self.wrapper, reshard_after_forward=reshard).
```diff
-fully_shard(self.wrapper, reshard_after_forward=reshard)
+self.wrapper = fully_shard(self.wrapper, reshard_after_forward=reshard)
```
```diff
     )
     or (display_step_id) == self.num_steps
-) and (self.rank == 0 or dist.get_rank() == 0):
+) and (self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0):
```
The condition self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0 is redundant since self.rank already equals dist.get_rank() when distributed is initialized. Simplify to self.zero_stage > 0 or self.rank == 0.
```python
    symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
    with open("checkpoint", "w") as f:
        f.write(str(self.latest_model))
if self.rank == 0 or dist.get_rank() == 0:
```
The condition self.rank == 0 or dist.get_rank() == 0 is redundant since self.rank is set to dist.get_rank() when distributed. Simplify to if self.rank == 0:.
```python
self.save_model(self.latest_model, lr=0, step=0)

if (
    self.rank == 0 or dist.get_rank() == 0
```
The condition self.rank == 0 or dist.get_rank() == 0 is redundant. Simplify to self.rank == 0.
```python
# Avoid error_if_nonfinite=True: FSDP2 sharded
# DTensor gradients may not support it. Manual
# isfinite check achieves the same fail-fast behavior.
```
The comment spans three lines but could be more concise. Consider condensing to: # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
```diff
-# Avoid error_if_nonfinite=True: FSDP2 sharded
-# DTensor gradients may not support it. Manual
-# isfinite check achieves the same fail-fast behavior.
+# FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.
```
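For context, the pattern both versions of the comment describe could look roughly like the sketch below (assumptions only; the trainer's actual code may differ in details such as the exception type used):

```python
import torch


def clip_and_check(parameters, max_norm: float) -> torch.Tensor:
    # clip_grad_norm_ returns the total gradient norm computed before clipping;
    # with error_if_nonfinite=False it will not raise on inf/nan by itself.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        parameters, max_norm, error_if_nonfinite=False
    )
    # Manual fail-fast check, intended to give the same behavior for sharded grads.
    if not torch.isfinite(grad_norm):
        raise ValueError(f"gradient norm is not finite: {grad_norm}")
    return grad_norm
```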
📝 Walkthrough
Adds configurable ZeRO/FSDP support to PyTorch training: a new `zero_stage` training parameter, ZeroRedundancyOptimizer wrapping for stage 1, FSDP2 sharding for stages 2-3, and distributed handling of model/optimizer state dicts during checkpointing.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
participant User
participant Trainer
participant Model as ModelWrapper
participant Optim as Optimizer
participant Store as Checkpoint
User->>Trainer: Init with config (zero_stage)
Trainer->>Trainer: Validate zero_stage & is_distributed
Trainer->>Trainer: Set rank/world_size defaults
alt zero_stage >= 2 (FSDP2)
Trainer->>Model: Broadcast params from rank 0
Trainer->>Model: fully_shard / FSDP2 wrap
else zero_stage == 1 (ZeRO-1)
Trainer->>Model: Wrap with DDP
Trainer->>Optim: Wrap with ZeroRedundancyOptimizer
else zero_stage == 0 (Standard)
Trainer->>Model: Wrap with DDP
end
User->>Model: Forward
Model->>Trainer: Backward
Trainer->>Optim: optimizer.step (with device/context handling)
Trainer->>Trainer: Clip grads (check finite)
alt Save checkpoint
Trainer->>Trainer: Gather model/optimizer state (sharded or consolidated)
Trainer->>Store: Rank-0 writes checkpoint
end
    Trainer->>User: Training complete
```
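The "gather model/optimizer state, rank-0 writes checkpoint" step in the diagram above typically relies on the `torch.distributed.checkpoint.state_dict` helpers that this PR imports. A minimal sketch assuming an FSDP2-wrapped model (this is not the PR's actual save routine):

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
)


def save_full_checkpoint(model, optimizer, path: str) -> None:
    # Gather full (unsharded) state dicts; all ranks must take part in the
    # collective, and with these options the full copy typically ends up
    # offloaded to CPU on rank 0.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    model_sd = get_model_state_dict(model, options=options)
    optim_sd = get_optimizer_state_dict(model, optimizer, options=options)
    if dist.get_rank() == 0:
        torch.save({"model": model_sd, "optimizer": optim_sd}, path)
```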
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/utils/argcheck.py (1)
3276-3421: ⚠️ Potential issue | 🟡 Minor

Validate `zero_stage` range (0–3). Docs define stages 0–3, but the config accepts any integer. A simple range check will prevent silent fall-through to unintended behavior.
✅ Suggested validation
```diff
 def training_args(
     multi_task: bool = False,
 ) -> list[Argument]:
@@
     doc_zero_stage = (
         "ZeRO optimization stage for distributed training memory reduction. "
@@
         "Currently supports single-task training; does not support LKF or change_bias_after_training."
     )
+    def _check_zero_stage(val: int) -> bool:
+        if val not in (0, 1, 2, 3):
+            raise ValueError("training.zero_stage must be one of {0,1,2,3}.")
+        return True
@@
         Argument(
             "zero_stage",
             int,
             optional=True,
             default=0,
             doc=doc_only_pt_supported + doc_zero_stage,
+            extra_check=_check_zero_stage,
         ),
```

deepmd/pt/train/training.py (1)
771-783: ⚠️ Potential issue | 🟠 Major

AdaMuon path misses scheduler and resume state load.

`step()` uses `self.scheduler` for AdaMuon, but the AdaMuon branch doesn't initialize it; resume also won't load optimizer state. This will raise at runtime and break restart training. Add the same scheduler/state-load block used by Adam/HybridMuon.

🐛 Proposed fix
```diff
 elif self.opt_type == "AdaMuon":
     self.optimizer = self._create_optimizer(
         AdaMuonOptimizer,
         lr=self.lr_exp.start_lr,
         momentum=float(self.opt_param["momentum"]),
         weight_decay=float(self.opt_param["weight_decay"]),
         adam_betas=(
             float(self.opt_param["adam_beta1"]),
             float(self.opt_param["adam_beta2"]),
         ),
         lr_adjust=float(self.opt_param["lr_adjust"]),
         lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
     )
+    self._load_optimizer_state(optimizer_state_dict)
+    self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+        self.optimizer,
+        lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+    )
```
🧹 Nitpick comments (1)
doc/train/parallel-training.md (1)
101-145: Clarify backend activation and the 2Ψ/3Ψ notation in the ZeRO section.

The constraints mention "PyTorch backend" but don't remind users how to enable it; also "2Ψ/3Ψ" is undefined and may be opaque to new readers. Consider a short clarification to reduce confusion.
✍️ Suggested doc tweak
```diff
-| `zero_stage` | Strategy | Communication | Memory saving |
+| `zero_stage` | Strategy | Communication (Ψ = model size) | Memory saving |
 | ------------ | ----------------------------- | ------------- | --------------------------------------------- |
 ...
+*Ψ denotes the model size in bytes; communication volumes are relative to it.*
 ...
 Constraints:
-- Works only in PyTorch backend.
+- Works only in PyTorch backend (e.g., `dp --pt` or `export DP_BACKEND=pytorch`; use `input_torch.json`).
```

Based on learnings: Use PyTorch backend with the `--pt` flag or `export DP_BACKEND=pytorch`; use `input_torch.json` for configuration typically.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)
771-782:⚠️ Potential issue | 🔴 CriticalBug: AdaMuon optimizer state is never restored on restart.
_load_optimizer_stateis called for Adam/AdamW (line 761) and HybridMuon (line 798), but not for AdaMuon. This means restarting training with AdaMuon will silently discard the saved optimizer state (momentum buffers, adaptive learning rates, etc.), causing training regression.🐛 Proposed fix
```diff
         lr_adjust=float(self.opt_param["lr_adjust"]),
         lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
     )
+    self._load_optimizer_state(optimizer_state_dict)
+    self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+        self.optimizer,
+        lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+    )
 elif self.opt_type == "HybridMuon":
```

Note: If AdaMuon also needs a scheduler (like Adam and HybridMuon), it should be added here as well. Please verify the intended behavior.
🤖 Fix all issues with AI agents
In `@deepmd/pt/train/training.py`:
- Line 164: The assignment to self.zero_stage using
int(training_params.get("zero_stage", 0)) lacks upper-bound validation; after
parsing zero_stage, validate that self.zero_stage is between 0 and 3 inclusive
(e.g., if not (0 <= self.zero_stage <= 3) raise a ValueError with a clear
message including the invalid value), or alternatively clamp/normalize it if
that behavior is preferred; update the code around the self.zero_stage
assignment (use the same training_params.get and self.zero_stage symbols) to
enforce this check and surface an explicit error when out-of-range values are
provided.
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)
715-733: FSDP2 applies only at root level — submodule sharding would improve memory savings.

`fully_shard` at line 725 is applied only to `self.wrapper` (the root module). For Stage 2/3 to achieve meaningful memory savings on large models, you'd typically want to call `fully_shard` on individual submodules (e.g., transformer layers) before the root. The current approach is functionally correct but may not deliver the expected memory reduction for large models.
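A hypothetical sketch of the finer-grained approach this comment suggests (the attribute name `layers` is illustrative; the real module structure in deepmd differs):

```python
from torch.distributed.fsdp import fully_shard


def shard_wrapper(wrapper, reshard_after_forward: bool) -> None:
    # Shard large submodules first so that, at any point in forward/backward,
    # only one submodule's parameters need to be materialized (all-gathered).
    for layer in getattr(wrapper, "layers", []):
        fully_shard(layer, reshard_after_forward=reshard_after_forward)
    # Shard the root last; it then owns whatever parameters remain unsharded.
    fully_shard(wrapper, reshard_after_forward=reshard_after_forward)
```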
```python
self.change_bias_after_training = training_params.get(
    "change_bias_after_training", False
)
self.zero_stage = int(training_params.get("zero_stage", 0))
```
Missing upper-bound validation for zero_stage.
Values above 3 (e.g., zero_stage=4) are silently accepted and treated as Stage 3. Consider validating the range:
Proposed fix
```diff
 self.zero_stage = int(training_params.get("zero_stage", 0))
+if self.zero_stage not in (0, 1, 2, 3):
+    raise ValueError(
+        f"training.zero_stage must be 0, 1, 2, or 3, got {self.zero_stage}"
+    )
 if self.zero_stage > 0 and not self.is_distributed:
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
self.zero_stage = int(training_params.get("zero_stage", 0))
if self.zero_stage not in (0, 1, 2, 3):
    raise ValueError(
        f"training.zero_stage must be 0, 1, 2, or 3, got {self.zero_stage}"
    )
if self.zero_stage > 0 and not self.is_distributed:
```
🤖 Prompt for AI Agents
In `@deepmd/pt/train/training.py` at line 164, The assignment to self.zero_stage
using int(training_params.get("zero_stage", 0)) lacks upper-bound validation;
after parsing zero_stage, validate that self.zero_stage is between 0 and 3
inclusive (e.g., if not (0 <= self.zero_stage <= 3) raise a ValueError with a
clear message including the invalid value), or alternatively clamp/normalize it
if that behavior is preferred; update the code around the self.zero_stage
assignment (use the same training_params.get and self.zero_stage symbols) to
enforce this check and surface an explicit error when out-of-range values are
provided.
Codecov Report

❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5222      +/-   ##
==========================================
- Coverage   82.07%   82.04%    -0.03%
==========================================
  Files         732      732
  Lines       73974    74031       +57
  Branches     3615     3615
==========================================
+ Hits        60711    60736       +25
- Misses      12100    12130       +30
- Partials     1163     1165        +2
```

☔ View full report in Codecov by Sentry.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: OutisLi <[email protected]>
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)
770-782: ⚠️ Potential issue | 🟠 Major

AdaMuon is missing the `_load_optimizer_state` call and scheduler setup — restart will lose optimizer state.

The `AdaMuon` optimizer (lines 770-782) does not call `self._load_optimizer_state(optimizer_state_dict)` after creation, unlike both `Adam`/`AdamW` (line 761) and `HybridMuon` (line 798). It also lacks the learning rate scheduler setup. On training restart, the saved optimizer state (momentum, adaptive learning rates) will be silently discarded, and the learning rate schedule will not be initialized.

Proposed fix
```diff
         lr_adjust=float(self.opt_param["lr_adjust"]),
         lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
     )
+    self._load_optimizer_state(optimizer_state_dict)
+    self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+        self.optimizer,
+        lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+    )
 elif self.opt_type == "HybridMuon":
```
🤖 Fix all issues with AI agents
In `@deepmd/pt/train/training.py`:
- Around line 84-95: The unconditional imports of fully_shard and the state_dict
utilities (StateDictOptions, get_model_state_dict, get_optimizer_state_dict,
set_optimizer_state_dict) will raise ImportError on PyTorch <2.4; guard these
imports with the same PyTorch version check used elsewhere (e.g., the existing
if torch.__version__.startswith("2") block) or alternatively update the module
documentation to state PyTorch 2.4+ is required; locate the import block
importing fully_shard and the state_dict symbols in training.py and either wrap
them in the version check or add a clear top-level requirement comment
mentioning PyTorch>=2.4.
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)
715-733: Consider sharding submodules individually for better FSDP2 memory efficiency.

Currently, `fully_shard` is applied only to the root `self.wrapper`. For large models, applying `fully_shard` to individual submodules (e.g., each layer) before the root provides finer-grained sharding and better peak memory reduction. The current approach is correct but may not yield the full memory savings FSDP2 is designed for.
```python
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.fsdp import (
    fully_shard,
)
from torch.distributed.optim import (
    ZeroRedundancyOptimizer,
)
```
🧩 Analysis chain
🏁 Script executed:
```sh
cat -n deepmd/pt/train/training.py | sed -n '75,100p'
```

Repository: deepmodeling/deepmd-kit
Length of output: 807
🌐 Web query:
PyTorch fully_shard FSDP2 minimum version requirement introduced
💡 Result:
torch.distributed.fsdp.fully_shard (the FSDP2 “fully_shard” frontend) was introduced as part of PyTorch 2.4’s prototype FSDP2 release, so the minimum PyTorch version requirement is 2.4.0. [1], [2], [3]
Sources: PyTorch 2.4 release blog announcing prototype FSDP2 [1]; PyTorch blog noting “FSDP v2 … released as part of PyTorch 2.4” [2]; downstream projects documenting “PyTorch version >= 2.4.0 is required” for the FSDP2 API [3].
🌐 Web query:
torch.distributed.checkpoint.state_dict minimum PyTorch version
💡 Result:
torch.distributed.checkpoint.state_dict (the helper APIs like get_state_dict, set_state_dict, StateDictOptions, get_model_state_dict, etc.) is available in PyTorch >= 2.4 (it’s present in the v2.4.0 codebase and discussed as part of “new” DCP state-dict/flattening behavior in 2.4). [2]
For context: Distributed Checkpoint (DCP) itself exists earlier (e.g., PyTorch 2.0 has torch.distributed.checkpoint.save_state_dict / load_state_dict), but that older API surface is different from torch.distributed.checkpoint.state_dict.*. [1]
Sources: PyTorch 2.0 DCP docs (no state_dict.* helpers) [1]; PyTorch issue referencing “new … introduced in torch 2.4” behavior and v2.4.0 DCP internals [2]; current DCP docs showing torch.distributed.checkpoint.state_dict.* APIs [3].
[1] https://pytorch.cadn.net.cn/docs_en/2.0/distributed.checkpoint.html
[2] pytorch/pytorch#137327
[3] https://docs.pytorch.org/docs/stable/distributed.checkpoint.html
🏁 Script executed:
rg "torch.__version__" deepmd/pt/train/training.py -A 2 -B 2Repository: deepmodeling/deepmd-kit
Length of output: 133
🏁 Script executed:
rg "zero_stage" deepmd/pt/train/training.py | head -20Repository: deepmodeling/deepmd-kit
Length of output: 1302
Guard these imports behind a PyTorch 2.4+ version check, or document the new minimum version requirement.
fully_shard (FSDP2) and the torch.distributed.checkpoint.state_dict utilities (get_model_state_dict, get_optimizer_state_dict, StateDictOptions, set_optimizer_state_dict) were introduced in PyTorch 2.4 and are imported unconditionally at lines 84–95. This will cause an ImportError at module load time for anyone using PyTorch < 2.4, including users who don't use zero_stage > 0 (where the default is 0).
Either guard these imports like the existing check on line 80 (if torch.__version__.startswith("2")), or update documentation to declare PyTorch 2.4+ as a new global requirement for this module.
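One possible shape of such a guard, sketched here under the assumption that the `packaging` library is available (this is not the repository's actual code):

```python
import torch
from packaging.version import Version

# Strip any local-version suffix such as "+cu121" before comparing.
_HAS_FSDP2 = Version(torch.__version__.split("+")[0]) >= Version("2.4")

if _HAS_FSDP2:
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
        set_optimizer_state_dict,
    )
    from torch.distributed.fsdp import fully_shard
else:
    # Leave the names defined so the module still imports with zero_stage == 0;
    # code paths that actually need them should raise a clear error instead.
    StateDictOptions = None
    get_model_state_dict = None
    get_optimizer_state_dict = None
    set_optimizer_state_dict = None
    fully_shard = None
```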
🤖 Prompt for AI Agents
In `@deepmd/pt/train/training.py` around lines 84 - 95, The unconditional imports
of fully_shard and the state_dict utilities (StateDictOptions,
get_model_state_dict, get_optimizer_state_dict, set_optimizer_state_dict) will
raise ImportError on PyTorch <2.4; guard these imports with the same PyTorch
version check used elsewhere (e.g., the existing if
torch.__version__.startswith("2") block) or alternatively update the module
documentation to state PyTorch 2.4+ is required; locate the import block
importing fully_shard and the state_dict symbols in training.py and either wrap
them in the version check or add a clear top-level requirement comment
mentioning PyTorch>=2.4.
Summary by CodeRabbit
New Features
Bug Fixes
Documentation