
feat(pt): add FSDP & ZeRO1 (Zero Redundancy Optimizer) support #5222

Open
OutisLi wants to merge 2 commits into deepmodeling:master from OutisLi:pr/fsdp

Conversation

@OutisLi
Collaborator

OutisLi commented Feb 13, 2026

Summary by CodeRabbit

  • New Features

    • Configurable ZeRO/FSDP memory-optimization stages (zero_stage 0–3) for distributed training, with proper weight/optimizer state sharding and rank-0 checkpoint writes.
    • Runtime option added to select zero_stage.
  • Bug Fixes

    • Added validations and enforced constraints to prevent incompatible training combos when memory-optimization is enabled (single-task, LKF, bias-change restrictions).
  • Documentation

    • Expanded PyTorch parallel-training docs with stage guidance, example config, constraints, and launch notes.

Copilot AI review requested due to automatic review settings February 13, 2026 08:01
@gemini-code-assist
Contributor

Summary of Changes

Hello @OutisLi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the PyTorch backend's distributed training capabilities by integrating ZeRO (Zero Redundancy Optimizer) stages, including FSDP2. This allows for more efficient memory utilization on GPUs, which is crucial for training large models. The changes introduce a configurable zero_stage parameter, refactor core training components to support these optimizations, and include detailed documentation to guide users on its application and trade-offs.
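
For orientation, the new key lives under the training section of the input file. Below is a minimal, hypothetical sketch of enabling it programmatically; only training.zero_stage comes from this PR, while the file name input_torch.json and the surrounding layout are assumptions based on a typical PyTorch-backend input.

# Hypothetical example: enable ZeRO stage 2 in an existing DeePMD-kit
# PyTorch input file. Only "training.zero_stage" is introduced by this PR;
# the file name and surrounding keys are assumptions shown for context.
import json

with open("input_torch.json") as f:
    config = json.load(f)

# 0 = plain DDP (default), 1 = shard optimizer state (ZeRO-1),
# 2 = also shard gradients (FSDP2), 3 = also shard parameters (FSDP2)
config["training"]["zero_stage"] = 2

with open("input_torch.json", "w") as f:
    json.dump(config, f, indent=2)

The training run itself would still go through a distributed launcher (e.g. torchrun) as described in the parallel-training documentation.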

Highlights

  • Distributed Training Memory Optimization: Introduced support for ZeRO (Zero Redundancy Optimizer) stages 1, 2, and 3, including FSDP2, to significantly reduce GPU memory usage during distributed training in the PyTorch backend.
  • Configurable zero_stage Parameter: Added a new zero_stage parameter (0-3) to control the level of ZeRO optimization, allowing users to choose between standard DDP, sharding optimizer states, sharding gradients, or sharding parameters.
  • Refactored Optimizer and Model Saving: Implemented new helper methods _create_optimizer and _load_optimizer_state to abstract the logic for optimizer instantiation and state loading, accommodating ZeRO-specific behaviors. Model saving logic was also updated to correctly handle sharded states.
  • Validation and Constraints: Added validation checks for the zero_stage parameter, ensuring it is used with distributed launch, single-task training, and is incompatible with the 'LKF' optimizer or change_bias_after_training.
  • Documentation Update: Provided comprehensive documentation for the new zero_stage feature, detailing each stage's strategy, communication overhead, memory savings, and usage constraints.
Changelog
  • deepmd/pt/train/training.py
    • Imported necessary modules for FSDP and ZeroRedundancyOptimizer.
    • Refactored distributed environment checks into is_distributed attribute.
    • Added zero_stage parameter with associated validation logic.
    • Modified DDP initialization to conditionally apply fully_shard for FSDP2.
    • Introduced _create_optimizer method to wrap optimizers with ZeroRedundancyOptimizer for ZeRO Stage 1.
    • Implemented _load_optimizer_state to handle optimizer state loading for different ZeRO stages.
    • Updated gradient clipping to manually check for non-finite gradients, compatible with FSDP2.
    • Adjusted model saving logic to correctly collect and save state dictionaries based on the active zero_stage.
  • deepmd/utils/argcheck.py
    • Added a new argument zero_stage to the training configuration with detailed documentation on its functionality and implications.
  • doc/train/parallel-training.md
    • Added a new section titled 'Optional ZeRO memory optimization' explaining the different ZeRO stages, their benefits, and constraints.

dosubot bot added the new feature label Feb 13, 2026
Contributor

gemini-code-assist bot left a comment


Code Review

This pull request introduces support for ZeRO stages 1, 2, and 3 in the PyTorch backend, enabling more memory-efficient distributed training. The implementation correctly uses ZeroRedundancyOptimizer for stage 1 and FSDP for stages 2 and 3. The logic for creating optimizers and handling model/optimizer state dicts during loading and saving has been refactored to support these new distributed strategies. The changes are well-structured and include excellent updates to both code comments and user documentation. My main feedback concerns a potential regression in the DDP path that could affect existing single-task models.
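
For readers new to these APIs, here is a minimal sketch (not the PR's actual code) of how the three paths described above can be wired together, assuming torch.distributed is already initialized and model is a plain nn.Module:

# Illustrative only: rough shape of a zero_stage dispatch.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.fsdp import fully_shard  # FSDP2, PyTorch >= 2.4

def wrap_model_and_optimizer(model, zero_stage, lr):
    if zero_stage >= 2:
        # FSDP2: shard parameters and gradients; stage 3 additionally
        # reshards parameters after each forward pass.
        fully_shard(model, reshard_after_forward=(zero_stage >= 3))
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif zero_stage == 1:
        # ZeRO-1: gradients still go through DDP, optimizer state is sharded.
        model = DDP(model)
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(), optimizer_class=torch.optim.Adam, lr=lr
        )
    else:
        # Stage 0: plain DDP with the full optimizer state on every rank.
        model = DDP(model)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer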

Contributor

Copilot AI left a comment


Pull request overview

This PR adds ZeRO (Zero Redundancy Optimizer) and FSDP2 support to the PyTorch backend, enabling memory-efficient distributed training through optimizer state and parameter sharding.

Changes:

  • Added zero_stage configuration parameter (0-3) to control memory optimization strategy
  • Implemented FSDP2 integration for stages 2-3 and ZeroRedundancyOptimizer for stage 1
  • Updated checkpoint saving/loading to handle distributed state collection

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

File Description
doc/train/parallel-training.md Adds documentation explaining ZeRO stages, memory savings, communication patterns, and usage constraints
deepmd/utils/argcheck.py Adds zero_stage parameter definition with detailed documentation
deepmd/pt/train/training.py Implements ZeRO/FSDP2 logic including optimizer wrapping, state dict handling, and distributed checkpoint operations


for b in self.wrapper.buffers():
    dist.broadcast(b.data, src=0)
reshard = self.zero_stage >= 3
fully_shard(self.wrapper, reshard_after_forward=reshard)

Copilot AI Feb 13, 2026


The fully_shard function modifies self.wrapper in-place but this is not reflected in the variable assignment. Consider either documenting this behavior with a comment or assigning the return value to make the in-place modification explicit: self.wrapper = fully_shard(self.wrapper, reshard_after_forward=reshard).

Suggested change
fully_shard(self.wrapper, reshard_after_forward=reshard)
self.wrapper = fully_shard(self.wrapper, reshard_after_forward=reshard)

)
or (display_step_id) == self.num_steps
) and (self.rank == 0 or dist.get_rank() == 0):
) and (self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0):

Copilot AI Feb 13, 2026


The condition self.zero_stage > 0 or self.rank == 0 or dist.get_rank() == 0 is redundant since self.rank already equals dist.get_rank() when distributed is initialized. Simplify to self.zero_stage > 0 or self.rank == 0.

symlink_prefix_files(self.latest_model.stem, self.save_ckpt)
with open("checkpoint", "w") as f:
    f.write(str(self.latest_model))
if self.rank == 0 or dist.get_rank() == 0:

Copilot AI Feb 13, 2026


The condition self.rank == 0 or dist.get_rank() == 0 is redundant since self.rank is set to dist.get_rank() when distributed. Simplify to if self.rank == 0:.

self.save_model(self.latest_model, lr=0, step=0)

if (
    self.rank == 0 or dist.get_rank() == 0

Copilot AI Feb 13, 2026


The condition self.rank == 0 or dist.get_rank() == 0 is redundant. Simplify to self.rank == 0.

Comment on lines +988 to +990
# Avoid error_if_nonfinite=True: FSDP2 sharded
# DTensor gradients may not support it. Manual
# isfinite check achieves the same fail-fast behavior.

Copilot AI Feb 13, 2026


The comment spans three lines but could be more concise. Consider condensing to: # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.

Suggested change
# Avoid error_if_nonfinite=True: FSDP2 sharded
# DTensor gradients may not support it. Manual
# isfinite check achieves the same fail-fast behavior.
# FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead.

@coderabbitai
Contributor

coderabbitai bot commented Feb 13, 2026

📝 Walkthrough

Walkthrough

Adds configurable ZeRO/FSDP support to PyTorch training: new zero_stage argument, distributed initialization guards, conditional model/optimizer wrapping for ZeRO stages (0/1/≥2), sharded checkpoint save/load behavior, and validations preventing incompatible options.
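
To make the "sharded checkpoint save / rank-0 write" behaviour concrete, here is a rough sketch under the assumption that torch.distributed is initialized; it is not the PR's exact implementation:

# Gather a full state dict from a sharded model and write it on rank 0 only.
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
)

def save_full_checkpoint(model, optimizer, path):
    # full_state_dict=True all-gathers the sharded tensors; cpu_offload=True
    # keeps the consolidated copy in host memory while it is assembled.
    opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
    model_sd = get_model_state_dict(model, options=opts)
    optim_sd = get_optimizer_state_dict(model, optimizer, options=opts)
    if dist.get_rank() == 0:
        torch.save({"model": model_sd, "optimizer": optim_sd}, path)
    dist.barrier()  # keep other ranks from racing ahead of the write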

Changes

Cohort / File(s) Summary
Distributed Training Implementation
deepmd/pt/train/training.py
Add zero_stage-driven logic: distributed guards (is_distributed), imports for StateDictOptions/FSDP/ZeRO utilities, conditional wrapping for FSDP2 (zero≥2), ZeRO-1 optimizer sharding, DDP fallback, new Trainer helpers _create_optimizer, _get_inner_module, _load_optimizer_state, adjusted gradient clipping, and rank-0 checkpoint semantics. Enforce constraints (no multi-task/LKF/change_bias_after_training with ZeRO).
Configuration Schema
deepmd/utils/argcheck.py
Add zero_stage training argument (int, default 0) with documentation describing ZeRO/FSDP stages and constraints (PyTorch-only, requires distributed launch).
Documentation
doc/train/parallel-training.md
Add "Optional ZeRO memory optimization" section describing zero_stage 0–3, selection guidance, example config snippet, and usage constraints for PyTorch distributed runs.

Sequence Diagram

sequenceDiagram
    participant User
    participant Trainer
    participant Model as ModelWrapper
    participant Optim as Optimizer
    participant Store as Checkpoint

    User->>Trainer: Init with config (zero_stage)
    Trainer->>Trainer: Validate zero_stage & is_distributed
    Trainer->>Trainer: Set rank/world_size defaults

    alt zero_stage >= 2 (FSDP2)
        Trainer->>Model: Broadcast params from rank 0
        Trainer->>Model: fully_shard / FSDP2 wrap
    else zero_stage == 1 (ZeRO-1)
        Trainer->>Model: Wrap with DDP
        Trainer->>Optim: Wrap with ZeroRedundancyOptimizer
    else zero_stage == 0 (Standard)
        Trainer->>Model: Wrap with DDP
    end

    User->>Model: Forward
    Model->>Trainer: Backward
    Trainer->>Optim: optimizer.step (with device/context handling)
    Trainer->>Trainer: Clip grads (check finite)

    alt Save checkpoint
        Trainer->>Trainer: Gather model/optimizer state (sharded or consolidated)
        Trainer->>Store: Rank-0 writes checkpoint
    end

    Trainer->>User: Training complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • njzjz
  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately captures the main feature additions: FSDP (Fully Sharded Data Parallel) and ZeRO1 (Zero Redundancy Optimizer) support for the PyTorch implementation. This matches the substantial changes across the training module, argument checker, and documentation.
Merge Conflict Detection ✅ Passed ✅ No merge conflicts detected when merging into master




Contributor

coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
deepmd/utils/argcheck.py (1)

3276-3421: ⚠️ Potential issue | 🟡 Minor

Validate zero_stage range (0–3).

Docs define stages 0–3, but the config accepts any integer. A simple range check will prevent silent fall-through to unintended behavior.

✅ Suggested validation
 def training_args(
     multi_task: bool = False,
 ) -> list[Argument]:
@@
     doc_zero_stage = (
         "ZeRO optimization stage for distributed training memory reduction. "
@@
         "Currently supports single-task training; does not support LKF or change_bias_after_training."
     )
+    def _check_zero_stage(val: int) -> bool:
+        if val not in (0, 1, 2, 3):
+            raise ValueError("training.zero_stage must be one of {0,1,2,3}.")
+        return True
@@
         Argument(
             "zero_stage",
             int,
             optional=True,
             default=0,
             doc=doc_only_pt_supported + doc_zero_stage,
+            extra_check=_check_zero_stage,
         ),
deepmd/pt/train/training.py (1)

771-783: ⚠️ Potential issue | 🟠 Major

AdaMuon path misses scheduler and resume state load.

step() uses self.scheduler for AdaMuon, but the AdaMuon branch doesn’t initialize it; resume also won’t load optimizer state. This will raise at runtime and break restart training. Add the same scheduler/state-load block used by Adam/HybridMuon.

🐛 Proposed fix
         elif self.opt_type == "AdaMuon":
             self.optimizer = self._create_optimizer(
                 AdaMuonOptimizer,
                 lr=self.lr_exp.start_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
                 adam_betas=(
                     float(self.opt_param["adam_beta1"]),
                     float(self.opt_param["adam_beta2"]),
                 ),
                 lr_adjust=float(self.opt_param["lr_adjust"]),
                 lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
             )
+            self._load_optimizer_state(optimizer_state_dict)
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer,
+                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+            )
🧹 Nitpick comments (1)
doc/train/parallel-training.md (1)

101-145: Clarify backend activation and the 2Ψ/3Ψ notation in the ZeRO section.

The constraints mention “PyTorch backend” but don’t remind users how to enable it; also “2Ψ/3Ψ” is undefined and may be opaque to new readers. Consider a short clarification to reduce confusion.

✍️ Suggested doc tweak
-| `zero_stage` | Strategy                      | Communication | Memory saving                                 |
+| `zero_stage` | Strategy                      | Communication (Ψ = model size) | Memory saving                                 |
 | ------------ | ----------------------------- | ------------- | --------------------------------------------- |
 ...
+*Ψ denotes the model size in bytes; communication volumes are relative to it.*
 ...
 Constraints:
 
-- Works only in PyTorch backend.
+- Works only in PyTorch backend (e.g., `dp --pt` or `export DP_BACKEND=pytorch`; use `input_torch.json`).

Based on learnings: Use PyTorch backend with --pt flag or export DP_BACKEND=pytorch; use input_torch.json for configuration typically.

Contributor

coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)

771-782: ⚠️ Potential issue | 🔴 Critical

Bug: AdaMuon optimizer state is never restored on restart.

_load_optimizer_state is called for Adam/AdamW (line 761) and HybridMuon (line 798), but not for AdaMuon. This means restarting training with AdaMuon will silently discard the saved optimizer state (momentum buffers, adaptive learning rates, etc.), causing training regression.

🐛 Proposed fix
             lr_adjust=float(self.opt_param["lr_adjust"]),
             lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
         )
+        self._load_optimizer_state(optimizer_state_dict)
+        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+            self.optimizer,
+            lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+        )
     elif self.opt_type == "HybridMuon":

Note: If AdaMuon also needs a scheduler (like Adam and HybridMuon), it should be added here as well. Please verify the intended behavior.

🤖 Fix all issues with AI agents
In `@deepmd/pt/train/training.py`:
- Line 164: The assignment to self.zero_stage using
int(training_params.get("zero_stage", 0)) lacks upper-bound validation; after
parsing zero_stage, validate that self.zero_stage is between 0 and 3 inclusive
(e.g., if not (0 <= self.zero_stage <= 3) raise a ValueError with a clear
message including the invalid value), or alternatively clamp/normalize it if
that behavior is preferred; update the code around the self.zero_stage
assignment (use the same training_params.get and self.zero_stage symbols) to
enforce this check and surface an explicit error when out-of-range values are
provided.
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)

715-733: FSDP2 applies only at root level — submodule sharding would improve memory savings.

fully_shard at line 725 is applied only to self.wrapper (the root module). For Stage 2/3 to achieve meaningful memory savings on large models, you'd typically want to call fully_shard on individual submodules (e.g., transformer layers) before the root. The current approach is functionally correct but may not deliver the expected memory reduction for large models.
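
A sketch of the submodule-first pattern suggested here; the attribute path wrapper.model.layers is hypothetical, and the real DeePMD-kit module hierarchy may differ:

# Illustrative submodule-first sharding; the names below are placeholders.
from torch.distributed.fsdp import fully_shard

def shard_wrapper(wrapper, reshard_after_forward):
    # Shard the large blocks first so that only one block's full parameters
    # need to be materialized at a time during forward/backward...
    for block in wrapper.model.layers:  # hypothetical attribute path
        fully_shard(block, reshard_after_forward=reshard_after_forward)
    # ...then shard the root to cover the remaining top-level parameters.
    fully_shard(wrapper, reshard_after_forward=reshard_after_forward)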

self.change_bias_after_training = training_params.get(
    "change_bias_after_training", False
)
self.zero_stage = int(training_params.get("zero_stage", 0))
Contributor


⚠️ Potential issue | 🟡 Minor

Missing upper-bound validation for zero_stage.

Values above 3 (e.g., zero_stage=4) are silently accepted and treated as Stage 3. Consider validating the range:

Proposed fix
 self.zero_stage = int(training_params.get("zero_stage", 0))
+if self.zero_stage not in (0, 1, 2, 3):
+    raise ValueError(
+        f"training.zero_stage must be 0, 1, 2, or 3, got {self.zero_stage}"
+    )
 if self.zero_stage > 0 and not self.is_distributed:
📝 Committable suggestion


Suggested change
self.zero_stage = int(training_params.get("zero_stage", 0))
self.zero_stage = int(training_params.get("zero_stage", 0))
if self.zero_stage not in (0, 1, 2, 3):
    raise ValueError(
        f"training.zero_stage must be 0, 1, 2, or 3, got {self.zero_stage}"
    )
if self.zero_stage > 0 and not self.is_distributed:

@codecov

codecov bot commented Feb 13, 2026

Codecov Report

❌ Patch coverage is 55.69620% with 35 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.04%. Comparing base (4f182bc) to head (aefa5e0).

Files with missing lines Patch % Lines
deepmd/pt/train/training.py 55.12% 35 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5222      +/-   ##
==========================================
- Coverage   82.07%   82.04%   -0.03%     
==========================================
  Files         732      732              
  Lines       73974    74031      +57     
  Branches     3615     3615              
==========================================
+ Hits        60711    60736      +25     
- Misses      12100    12130      +30     
- Partials     1163     1165       +2     


OutisLi and others added 2 commits February 14, 2026 12:32
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: OutisLi <[email protected]>
Contributor

coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/pt/train/training.py (1)

770-782: ⚠️ Potential issue | 🟠 Major

AdaMuon is missing _load_optimizer_state call and scheduler setup — restart will lose optimizer state.

The AdaMuon optimizer (lines 770-782) does not call self._load_optimizer_state(optimizer_state_dict) after creation, unlike both Adam/AdamW (line 761) and HybridMuon (line 798). It also lacks the learning rate scheduler setup. On training restart, the saved optimizer state (momentum, adaptive learning rates) will be silently discarded, and the learning rate schedule will not be initialized.

Proposed fix
             lr_adjust=float(self.opt_param["lr_adjust"]),
             lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
         )
+        self._load_optimizer_state(optimizer_state_dict)
+        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+            self.optimizer,
+            lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+        )
     elif self.opt_type == "HybridMuon":
🤖 Fix all issues with AI agents
In `@deepmd/pt/train/training.py`:
- Around line 84-95: The unconditional imports of fully_shard and the state_dict
utilities (StateDictOptions, get_model_state_dict, get_optimizer_state_dict,
set_optimizer_state_dict) will raise ImportError on PyTorch <2.4; guard these
imports with the same PyTorch version check used elsewhere (e.g., the existing
if torch.__version__.startswith("2") block) or alternatively update the module
documentation to state PyTorch 2.4+ is required; locate the import block
importing fully_shard and the state_dict symbols in training.py and either wrap
them in the version check or add a clear top-level requirement comment
mentioning PyTorch>=2.4.
🧹 Nitpick comments (1)
deepmd/pt/train/training.py (1)

715-733: Consider sharding submodules individually for better FSDP2 memory efficiency.

Currently, fully_shard is applied only to the root self.wrapper. For large models, applying fully_shard to individual submodules (e.g., each layer) before the root provides finer-grained sharding and better peak memory reduction. The current approach is correct but may not yield the full memory savings FSDP2 is designed for.

Comment on lines +84 to +95
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.fsdp import (
    fully_shard,
)
from torch.distributed.optim import (
    ZeroRedundancyOptimizer,
)
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n deepmd/pt/train/training.py | sed -n '75,100p'

Repository: deepmodeling/deepmd-kit

Length of output: 807


🌐 Web query:

PyTorch fully_shard FSDP2 minimum version requirement introduced

💡 Result:

torch.distributed.fsdp.fully_shard (the FSDP2 “fully_shard” frontend) was introduced as part of PyTorch 2.4’s prototype FSDP2 release, so the minimum PyTorch version requirement is 2.4.0. [1], [2], [3]

Sources: PyTorch 2.4 release blog announcing prototype FSDP2 [1]; PyTorch blog noting “FSDP v2 … released as part of PyTorch 2.4” [2]; downstream projects documenting “PyTorch version >= 2.4.0 is required” for the FSDP2 API [3].


🌐 Web query:

torch.distributed.checkpoint.state_dict minimum PyTorch version

💡 Result:

torch.distributed.checkpoint.state_dict (the helper APIs like get_state_dict, set_state_dict, StateDictOptions, get_model_state_dict, etc.) is available in PyTorch >= 2.4 (it’s present in the v2.4.0 codebase and discussed as part of “new” DCP state-dict/flattening behavior in 2.4). [2]

For context: Distributed Checkpoint (DCP) itself exists earlier (e.g., PyTorch 2.0 has torch.distributed.checkpoint.save_state_dict / load_state_dict), but that older API surface is different from torch.distributed.checkpoint.state_dict.*. [1]

Sources: PyTorch 2.0 DCP docs (no state_dict.* helpers) [1]; PyTorch issue referencing “new … introduced in torch 2.4” behavior and v2.4.0 DCP internals [2]; current DCP docs showing torch.distributed.checkpoint.state_dict.* APIs [3].

[1] https://pytorch.cadn.net.cn/docs_en/2.0/distributed.checkpoint.html
[2] pytorch/pytorch#137327
[3] https://docs.pytorch.org/docs/stable/distributed.checkpoint.html


🏁 Script executed:

rg "torch.__version__" deepmd/pt/train/training.py -A 2 -B 2

Repository: deepmodeling/deepmd-kit

Length of output: 133


🏁 Script executed:

rg "zero_stage" deepmd/pt/train/training.py | head -20

Repository: deepmodeling/deepmd-kit

Length of output: 1302


Guard these imports behind a PyTorch 2.4+ version check, or document the new minimum version requirement.

fully_shard (FSDP2) and the torch.distributed.checkpoint.state_dict utilities (get_model_state_dict, get_optimizer_state_dict, StateDictOptions, set_optimizer_state_dict) were introduced in PyTorch 2.4 and are imported unconditionally at lines 84–95. This will cause an ImportError at module load time for anyone using PyTorch < 2.4, including users who don't use zero_stage > 0 (where the default is 0).

Either guard these imports like the existing check on line 80 (if torch.__version__.startswith("2")), or update documentation to declare PyTorch 2.4+ as a new global requirement for this module.
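
One possible shape of such a guard, shown purely as an illustration (the PR currently imports these names unconditionally):

# Sketch: make the FSDP2 / DCP state-dict imports optional on older PyTorch.
try:
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
        set_optimizer_state_dict,
    )
    from torch.distributed.fsdp import fully_shard  # PyTorch >= 2.4
except ImportError:
    # Older PyTorch: keep sentinels so the trainer can raise a clear error
    # if the corresponding zero_stage features are requested.
    StateDictOptions = None
    get_model_state_dict = None
    get_optimizer_state_dict = None
    set_optimizer_state_dict = None
    fully_shard = None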

