fix(pt): remove compile in HybridMuon for FSDP2 compatibility use triton instead #5221

Open

OutisLi wants to merge 4 commits into deepmodeling:master from OutisLi:pr/fixMuon

Conversation

@OutisLi (Collaborator) commented Feb 13, 2026

Summary by CodeRabbit

  • Refactor

    • Simplified orthogonalization routing to process matrix parameters per-entry and added a faster GPU path when available.
  • New Features

    • Added optimizer option "flash_muon" (default: true) to enable an accelerated orthogonalization path with automatic fallback.
  • Tests

    • Expanded tests to cover both accelerated and fallback paths and adjusted shape/validation cases.
  • Documentation

    • Updated optimizer argument docs to describe flash_muon and clarified 1D/>=2D parameter handling.

Copilot AI review requested due to automatic review settings February 13, 2026 07:58
@gemini-code-assist (Contributor)

Summary of Changes

Hello @OutisLi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the HybridMuonOptimizer to enhance its compatibility with PyTorch's Fully Sharded Data Parallel (FSDP2). The core changes involve removing torch.compile usage, simplifying the Newton-Schulz orthogonalization process to operate on individual tensors rather than batches, and replacing torch._foreach_ tensor operations with explicit loops. These modifications are crucial for ensuring that the optimizer functions correctly and efficiently within an FSDP2 environment, preventing potential issues with sharding and distributed training.
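For context, the step all of these paths implement is the standard quintic Newton-Schulz iteration that Muon-style optimizers use to approximately orthogonalize a 2D update. The sketch below follows the widely used reference formulation (coefficients 3.4445, -4.7750, 2.0315); it is illustrative only and is not copied from the PR's _newton_schulz_orth.

import torch


def newton_schulz_orth_sketch(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Quintic Newton-Schulz iteration: drives the singular values of G toward 1,
    # i.e. approximately orthogonalizes the update matrix.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.to(torch.bfloat16)
    X = X / (X.norm() + eps)  # normalize so the iteration converges
    transposed = G.shape[0] > G.shape[1]
    if transposed:            # work on the wide orientation
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)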

Highlights

  • FSDP2 Compatibility: Removed the _maybe_compile function and its usage to ensure compatibility with PyTorch's Fully Sharded Data Parallel (FSDP2), as torch.compile can interfere with FSDP2's sharding propagation.
  • Orthogonalization Refactoring: Consolidated the 2D and 3D Newton-Schulz orthogonalization functions into a single _newton_schulz_orth function, and removed the zeropower_via_newtonschulz5 dispatcher. The step function now processes orthogonalization for each entry individually, which is more compatible with FSDP2's sharding.
  • Replacement of _foreach_ Operations: Replaced torch._foreach_lerp_, torch._foreach_mul, and torch._foreach_norm calls with explicit Python loops and individual tensor operations (e.g., ea.lerp_(g, ...), g * g, p.norm()). This change addresses known issues where _foreach_ operations might not correctly propagate sharding information under FSDP2. A minimal before/after sketch follows this list.
  • Test Updates: Updated the unit tests in test_hybrid_muon.py to reflect the changes in function names and the removal of 3D batched orthogonalization, ensuring test coverage for the refactored logic.
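A minimal before/after sketch of the _foreach_ replacement mentioned above (tensor names are illustrative, not copied from hybrid_muon.py):

import torch

beta1 = 0.9
exp_avgs = [torch.zeros(3, 3), torch.zeros(4, 4)]
grads = [torch.randn(3, 3), torch.randn(4, 4)]

# Before (removed in this PR): a fused multi-tensor op, which can fail to
# propagate sharding information under FSDP2.
#   torch._foreach_lerp_(exp_avgs, grads, 1 - beta1)

# After: an explicit per-tensor loop, so every op sees a single (possibly
# sharded) tensor.
for ea, g in zip(exp_avgs, grads):
    ea.lerp_(g, 1 - beta1)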
Changelog
  • deepmd/pt/optimizer/hybrid_muon.py
    • Removed the _maybe_compile function and all its call sites.
    • Renamed _zeropower_via_newtonschulz5_2d to _newton_schulz_orth.
    • Deleted _zeropower_via_newtonschulz5_3d and zeropower_via_newtonschulz5 functions.
    • Replaced torch._foreach_lerp_ with explicit for loops and lerp_ method calls.
    • Replaced torch._foreach_mul with list comprehensions for element-wise multiplication.
    • Replaced torch._foreach_norm with list comprehensions calling the .norm() method.
    • Modified the step function to iterate and process bucket_entries individually for orthogonalization, removing the batched 3D processing path.
  • source/tests/pt/test_hybrid_muon.py
    • Updated imports to reference the new _newton_schulz_orth function.
    • Modified test_orthogonalization and test_shape_and_dtype to use _newton_schulz_orth.
    • Removed 3D shape testing from test_shape_and_dtype.
    • Adjusted test_invalid_input to expect a broader range of errors for 1D input to _newton_schulz_orth.

@OutisLi changed the title from "fix: remove compile in HybridMuon for FSDP2 compatibility" to "fix(pt): remove compile in HybridMuon for FSDP2 compatibility" on Feb 13, 2026
@coderabbitai bot (Contributor) commented Feb 13, 2026

📝 Walkthrough

Adds a Triton-accelerated "flash" Newton–Schulz orthogonalization path with lazy per-(M,device) buffers, refactors per-parameter routing to per-entry processing, and exposes a new flash_muon option wired through optimizer setup and CLI args.
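Two pieces of that wiring are easy to sketch. Below, the Triton availability check and the lazy per-(M, device) buffer cache are written out as a hedged approximation: the names TRITON_AVAILABLE, FLASH_MIN_DIM, _ns_buffers, and _get_ns_buffers come from this PR, while the try/except detection, the buffer dtype, and the exact shapes are assumptions.

import torch

# Optional Triton dependency: fall back to pure PyTorch when it is absent.
try:
    import triton  # presence check only
    TRITON_AVAILABLE = True
except ImportError:
    TRITON_AVAILABLE = False

FLASH_MIN_DIM = 1024  # value quoted in the test nitpick further down


class _NSBufferCacheSketch:
    """Minimal stand-in for the optimizer's scratch-buffer bookkeeping."""

    def __init__(self) -> None:
        self._ns_buffers: dict[tuple[int, torch.device], tuple[torch.Tensor, torch.Tensor]] = {}

    def _get_ns_buffers(self, M: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        # Lazily allocate and reuse two M x M scratch buffers per (M, device)
        # pair so repeated Newton-Schulz calls do not reallocate every step.
        key = (M, device)
        if key not in self._ns_buffers:
            buf1 = torch.empty((M, M), device=device)  # dtype and shape assumed
            buf2 = torch.empty((M, M), device=device)
            self._ns_buffers[key] = (buf1, buf2)
        return self._ns_buffers[key]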

Changes

  • Hybrid Muon optimizer (deepmd/pt/optimizer/hybrid_muon.py): Add TRITON_AVAILABLE detection and FLASH_MIN_DIM; add _flash_newton_schulz_orth(G, buf1, buf2) and _get_ns_buffers(self, M, device); remove _maybe_compile and the previous 2D/3D zeropower helpers; switch to per-entry Newton–Schulz routing with lazy buffer reuse; add the new flash_muon constructor flag.
  • Tests (source/tests/pt/test_hybrid_muon.py): Expose TRITON_AVAILABLE; replace zeropower_via_newtonschulz5 calls with _newton_schulz_orth; add a TestFlashMuon suite validating flash vs. fallback behavior and internal state (_use_flash, _ns_buffers) when Triton and CUDA are present; reduce shapes to 2D-focused cases and broaden exception checks.
  • Training wiring (deepmd/pt/train/training.py): Add flash_muon to the optimizer config (get_opt_param) and pass flash_muon=bool(self.opt_param["flash_muon"]) into the HybridMuon instantiation.
  • CLI / argcheck (deepmd/utils/argcheck.py): Add a flash_muon boolean argcheck entry (default True) documenting Triton-accelerated NS orthogonalization with a PyTorch fallback.

Sequence Diagram(s)

sequenceDiagram
  participant Trainer as Training loop
  participant Optimizer as HybridMuon optimizer
  participant Buffers as NS Buffers (per-(M,device))
  participant Triton as Triton matmul kernel
  participant PyTorch as torch matmul ops

  Trainer->>Optimizer: step()
  Optimizer->>Optimizer: iterate params / determine routing
  Optimizer->>Buffers: get_or_alloc_buffers(M, device)
  alt flash_muon enabled & TRITON_AVAILABLE & M >= FLASH_MIN_DIM
    Optimizer->>Triton: _flash_newton_schulz_orth(G, buf1, buf2)
    Triton-->>Optimizer: orthogonalized result
    Optimizer->>Buffers: reuse/update bufs
  else fallback
    Optimizer->>PyTorch: _newton_schulz_orth(G) using torch ops
    PyTorch-->>Optimizer: orthogonalized result
  end
  Optimizer-->>Trainer: updated parameters
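In plain Python, the routing in the diagram looks roughly like the sketch below. The names TRITON_AVAILABLE, FLASH_MIN_DIM, _flash_newton_schulz_orth, _newton_schulz_orth, and _get_ns_buffers are from this PR; the exact gating condition (in particular how M is derived from the update's shape) is an assumption.

import torch

TRITON_AVAILABLE = False  # stand-in; the real module detects triton at import time
FLASH_MIN_DIM = 1024      # value quoted in the review comments below


def _fallback_ns(g: torch.Tensor) -> torch.Tensor:
    # Placeholder for the pure-PyTorch _newton_schulz_orth path in this sketch.
    return g / (g.norm() + 1e-7)


def route_orthogonalization(g: torch.Tensor, use_flash: bool) -> torch.Tensor:
    # Per-entry routing: take the Triton "flash" path only when it is enabled,
    # Triton is importable, the update lives on CUDA, and the matrix is large
    # enough; otherwise use the pure-PyTorch iteration.
    M = min(g.shape)  # assumption: the size gate keys on the smaller dimension
    if use_flash and TRITON_AVAILABLE and g.is_cuda and M >= FLASH_MIN_DIM:
        # Real code would fetch the cached (M, M) scratch buffers and call the
        # Triton kernel: _flash_newton_schulz_orth(g, *self._get_ns_buffers(M, g.device))
        raise NotImplementedError("Triton path elided in this sketch")
    return _fallback_ns(g)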

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • njzjz
  • iProzd
  • wanghan-iapcm
🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 77.27%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
  • Merge Conflict Detection ⚠️ Warning: Merge conflicts detected (7 files):

⚔️ deepmd/dpmodel/atomic_model/base_atomic_model.py (content)
⚔️ deepmd/dpmodel/common.py (content)
⚔️ deepmd/dpmodel/utils/type_embed.py (content)
⚔️ deepmd/pt/optimizer/hybrid_muon.py (content)
⚔️ deepmd/pt/train/training.py (content)
⚔️ deepmd/utils/argcheck.py (content)
⚔️ source/tests/pt/test_hybrid_muon.py (content)

These conflicts must be resolved before merging into master.
Resolve conflicts locally and push changes to this branch.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately describes the main change: removing compile() from HybridMuon for FSDP2 compatibility and introducing Triton-accelerated Newton–Schulz orthogonalization as the alternative approach.



No actionable comments were generated in the recent review. 🎉

🧹 Recent nitpick comments
deepmd/pt/optimizer/hybrid_muon.py (1)

627-631: Consider adding strict=True to zip() calls.

Multiple zip() calls throughout step() (lines 627, 630, 686, 689, 748, 751, 816, 821) lack an explicit strict= parameter. While lengths are guaranteed to match here (lists built in the same loop), adding strict=True would catch bugs if the code is refactored later and satisfy Ruff B905.
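A minimal illustration of the suggestion (the tensors here are made up, not from the PR):

import torch

params = [torch.zeros(2), torch.zeros(2), torch.zeros(2)]
updates = [torch.ones(2), torch.ones(2)]  # one update accidentally missing

# Plain zip() silently truncates to the shorter list, so the third parameter
# would never be updated; strict=True turns the mismatch into a ValueError.
for p, u in zip(params, updates, strict=True):
    p.add_(u)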



@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request aims to improve FSDP2 compatibility for the HybridMuonOptimizer by removing features that can be problematic with sharded tensors. The changes include removing torch.compile, replacing batched processing with individual processing in a loop, and substituting optimized torch._foreach kernels with standard Python loops. These modifications are logical and correctly implemented to achieve the stated goal, though they may come at the cost of performance in non-FSDP environments. The tests have been updated accordingly to reflect these changes. I have one suggestion to add a code comment to improve maintainability by explaining the rationale behind replacing _foreach operations.

Copilot AI (Contributor) left a comment

Pull request overview

Updates HybridMuon’s Newton–Schulz orthogonalization path to avoid torch.compile and stacked/batched processing, improving FSDP2 compatibility.

Changes:

  • Replace public zeropower_via_newtonschulz5 + compiled kernels with a single _newton_schulz_orth implementation.
  • Update optimizer step to orthogonalize each matrix update individually (no batched 3D kernel / stacking).
  • Adjust unit tests to call the new function and relax invalid-input exception expectations.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

  • source/tests/pt/test_hybrid_muon.py: Updates tests to use _newton_schulz_orth, removes 3D-shape coverage, and loosens invalid-input exception assertions.
  • deepmd/pt/optimizer/hybrid_muon.py: Removes the torch.compile wrapper and batched NS kernels; switches various foreach ops and batched NS to per-tensor loops for FSDP2 compatibility.


@codecov bot commented Feb 13, 2026

Codecov Report

❌ Patch coverage is 40.18692% with 64 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.06%. Comparing base (4f182bc) to head (aa32989).
⚠️ Report is 1 commit behind head on master.

Files with missing lines | Patch % | Lines missing
deepmd/pt/optimizer/hybrid_muon.py | 40.18% | 64 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5221      +/-   ##
==========================================
- Coverage   82.07%   82.06%   -0.01%     
==========================================
  Files         732      736       +4     
  Lines       73974    74286     +312     
  Branches     3615     3616       +1     
==========================================
+ Hits        60711    60962     +251     
- Misses      12100    12161      +61     
  Partials     1163     1163              


@OutisLi changed the title from "fix(pt): remove compile in HybridMuon for FSDP2 compatibility" to "fix(pt): remove compile in HybridMuon for FSDP2 compatibility use triton instead" on Feb 14, 2026
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `deepmd/pt/optimizer/hybrid_muon.py`:
- Around lines 111-138: Remove the unused `# noqa` annotations in this file: delete the `# noqa: ANN202` on the `_get_autotune_config` definition and the `# noqa: ANN001` markers on the `_mmt_kernel` parameter lines (x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn), since those rules are not enabled for Ruff. Leave the function, the decorators (`@triton.autotune`, `@triton.jit`), and the parameter list intact; only remove the extraneous `# noqa` comments.
- Around lines 157-176: offs_xm/offs_xn are computed with a `% M` wrap, which pulls rows from the top of the matrix into the tail blocks. Remove the modulo and use the raw row indices (e.g., offs_row_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M), and similarly for offs_row_n), build a_ptrs/b_ptrs from those raw indices, and have tl.load mask out any rows >= M by combining the existing K-mask with a row mask such as (offs_row[:, None] < M), so out-of-range rows load as 0. Keep the existing store mask (c_mask) for the write-back, but never apply a modulo to offs_xm/offs_xn, so the accumulator never sees wrapped rows. (A hedged kernel fragment illustrating this masking pattern follows this list.)
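For illustration, here is a hedged, minimal Triton kernel computing Y = X @ X.T with the masked-load pattern described above. The parameter names (x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn) and block-size names follow the review text; everything else (block sizes, dtypes, the host wrapper mmt) is assumed and is not copied from _mmt_kernel in hybrid_muon.py. It requires triton and a CUDA device.

import torch
import triton
import triton.language as tl


@triton.jit
def _mmt_masked_kernel(
    x_ptr, y_ptr, M, K,
    stride_xm, stride_xk, stride_ym, stride_yn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    pid_m = pid // num_pid_m
    pid_n = pid % num_pid_m

    # Raw row indices: no "% M" wrap, so tail blocks never re-read rows from the top.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    b_ptrs = x_ptr + offs_n[:, None] * stride_xm + offs_k[None, :] * stride_xk

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_M), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_mask = offs_k[None, :] < K - k * BLOCK_SIZE_K
        # Combine the K-mask with a row mask so rows >= M load as 0.
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
        b = tl.load(b_ptrs, mask=(offs_n[:, None] < M) & k_mask, other=0.0)
        acc += tl.dot(a, tl.trans(b))
        a_ptrs += BLOCK_SIZE_K * stride_xk
        b_ptrs += BLOCK_SIZE_K * stride_xk

    y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < M)
    tl.store(y_ptrs, acc, mask=c_mask)  # store mask is kept, as in the original


def mmt(x: torch.Tensor) -> torch.Tensor:
    # Illustrative host-side wrapper: y = x @ x.T for a 2D CUDA tensor.
    M, K = x.shape
    y = torch.empty((M, M), device=x.device, dtype=torch.float32)
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]) ** 2,)
    _mmt_masked_kernel[grid](
        x, y, M, K,
        x.stride(0), x.stride(1), y.stride(0), y.stride(1),
        BLOCK_SIZE_M=32, BLOCK_SIZE_K=32,
    )
    return y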
🧹 Nitpick comments (1)
source/tests/pt/test_hybrid_muon.py (1)

259-291: Consider forcing the flash path in the consistency test when Triton+CUDA are available.

With FLASH_MIN_DIM = 1024, the current 32×64 shapes won’t ever exercise the flash path on GPU, so this test may only compare the PyTorch fallback against itself. You could bump the dim when Triton+CUDA are present to actually validate flash-vs-non-flash consistency.

🔧 Suggested tweak
-        model1 = torch.nn.Linear(32, 64, device=self.device)
-        model2 = torch.nn.Linear(32, 64, device=self.device)
+        dim = 32
+        if TRITON_AVAILABLE and self.device.type == "cuda":
+            from deepmd.pt.optimizer.hybrid_muon import FLASH_MIN_DIM
+            dim = max(dim, FLASH_MIN_DIM)
+        model1 = torch.nn.Linear(dim, dim * 2, device=self.device)
+        model2 = torch.nn.Linear(dim, dim * 2, device=self.device)
         model2.load_state_dict(model1.state_dict())
@@
-        x = torch.randn(4, 32, device=self.device)
+        x = torch.randn(4, dim, device=self.device)

Member left a review comment on these lines:

]

@triton.autotune(configs=_get_autotune_config(), key=["M", "K"])
@triton.jit

Will it be expensive at the import time?

Member left a review comment on these lines:

grid = lambda META: (  # noqa: E731
    triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(M, META["BLOCK_SIZE_M"]),
)
with torch.cuda.device(d_in.device.index):

What is the situation for the CPU?

Member left a comment:

It seems that triton is not installed in the CI.
