fix(pt): remove compile in HybridMuon for FSDP2 compatibility, use triton instead #5221

OutisLi wants to merge 4 commits into deepmodeling:master
Conversation
Summary of Changes

Hello @OutisLi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.
📝 Walkthrough

Adds a Triton-accelerated "flash" Newton–Schulz orthogonalization path with lazy per-(M, device) buffers, refactors per-parameter routing to per-entry processing, and exposes a new `flash_muon` option.
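For readers skimming the walkthrough, a minimal sketch of what a lazy per-(M, device) buffer cache behind `get_or_alloc_buffers` could look like is shown below; the exact signature, dtype, and storage used in `hybrid_muon.py` are assumptions here.

```python
import torch

# Scratch matrices for the flash Newton-Schulz path, keyed by (dimension, device)
# and allocated lazily so no memory is reserved for sizes that never occur.
_NS_BUFFERS: dict[tuple[int, torch.device], tuple[torch.Tensor, torch.Tensor]] = {}


def get_or_alloc_buffers(M: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return two reusable (M, M) scratch buffers for a given size and device."""
    key = (M, device)
    if key not in _NS_BUFFERS:
        _NS_BUFFERS[key] = (
            torch.empty(M, M, device=device, dtype=torch.bfloat16),
            torch.empty(M, M, device=device, dtype=torch.bfloat16),
        )
    return _NS_BUFFERS[key]
```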
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Trainer as Training loop
participant Optimizer as HybridMuon optimizer
participant Buffers as NS Buffers (per-(M,device))
participant Triton as Triton matmul kernel
participant PyTorch as torch matmul ops
Trainer->>Optimizer: step()
Optimizer->>Optimizer: iterate params / determine routing
Optimizer->>Buffers: get_or_alloc_buffers(M, device)
alt flash_muon enabled & TRITON_AVAILABLE & M >= FLASH_MIN_DIM
Optimizer->>Triton: _flash_newton_schulz_orth(G, buf1, buf2)
Triton-->>Optimizer: orthogonalized result
Optimizer->>Buffers: reuse/update bufs
else fallback
Optimizer->>PyTorch: _newton_schulz_orth(G) using torch ops
PyTorch-->>Optimizer: orthogonalized result
end
Optimizer-->>Trainer: updated parameters
```
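As a rough Python rendering of the routing in the diagram above (names such as `FLASH_MIN_DIM`, `TRITON_AVAILABLE`, and the helper functions are taken from this page; the exact condition and the use of `min(G.shape)` are assumptions):

```python
import torch


def orthogonalize_update(G: torch.Tensor, flash_muon: bool) -> torch.Tensor:
    # The scratch size is assumed to track the smaller matrix dimension.
    M = min(G.shape)
    if flash_muon and TRITON_AVAILABLE and G.is_cuda and M >= FLASH_MIN_DIM:
        # Triton-accelerated path reusing cached (M, M) scratch buffers.
        buf1, buf2 = get_or_alloc_buffers(M, G.device)
        return _flash_newton_schulz_orth(G, buf1, buf2)
    # Fallback: plain torch ops, used on CPU, without Triton, or for small matrices.
    return _newton_schulz_orth(G)
```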
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (2 warnings)
No actionable comments were generated in the recent review. 🎉
Code Review
This pull request aims to improve FSDP2 compatibility for the HybridMuonOptimizer by removing features that can be problematic with sharded tensors. The changes include removing torch.compile, replacing batched processing with individual processing in a loop, and substituting optimized torch._foreach kernels with standard Python loops. These modifications are logical and correctly implemented to achieve the stated goal, though they may come at the cost of performance in non-FSDP environments. The tests have been updated accordingly to reflect these changes. I have one suggestion to add a code comment to improve maintainability by explaining the rationale behind replacing _foreach operations.
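To make the suggested comment concrete, here is a hedged sketch of a per-tensor momentum loop carrying such a rationale note; the function name, `exp_avg` buffer, and update form are illustrative assumptions, not the actual `hybrid_muon.py` code.

```python
import torch


@torch.no_grad()
def apply_momentum_updates(
    params: list[torch.Tensor],
    updates: list[torch.Tensor],
    state: dict,
    lr: float,
    momentum: float,
) -> None:
    # FSDP2 note: torch.compile and torch._foreach_* kernels can misbehave on
    # DTensor/sharded parameters, so each tensor is updated in a plain Python
    # loop even though fused _foreach ops would be faster on dense local tensors.
    for param, update in zip(params, updates):
        exp_avg = state[param]["exp_avg"]   # per-parameter momentum buffer
        exp_avg.mul_(momentum).add_(update)  # accumulate the raw update
        param.add_(exp_avg, alpha=-lr)       # apply the step tensor by tensor
```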
Pull request overview
Updates HybridMuon’s Newton–Schulz orthogonalization path to avoid torch.compile and stacked/batched processing, improving FSDP2 compatibility.
Changes:
- Replace the public `zeropower_via_newtonschulz5` + compiled kernels with a single `_newton_schulz_orth` implementation (see the sketch after this list).
- Update the optimizer step to orthogonalize each matrix update individually (no batched 3D kernel / stacking).
- Adjust unit tests to call the new function and relax invalid-input exception expectations.
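For orientation, a minimal per-matrix sketch in the spirit of `_newton_schulz_orth` is given below; the quintic coefficients and five-step count follow the widely used Muon recipe and are assumptions about this particular implementation.

```python
import torch


def newton_schulz_orth(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize one 2-D update matrix via Newton-Schulz iteration."""
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315   # quintic iteration coefficients (assumed)
    X = G.to(torch.bfloat16)
    transposed = X.shape[0] > X.shape[1]
    if transposed:                       # work in the wide orientation so X @ X.T stays small
        X = X.mT
    X = X / (X.norm() + eps)             # keep the spectral norm <= 1 so the iteration converges
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)
```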
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| source/tests/pt/test_hybrid_muon.py | Updates tests to use _newton_schulz_orth, removes 3D-shape coverage, and loosens invalid-input exception assertions. |
| deepmd/pt/optimizer/hybrid_muon.py | Removes torch.compile wrapper and batched NS kernels; switches various foreach ops and batched NS to per-tensor loops for FSDP2 compatibility. |
Codecov Report

❌ Patch coverage is below the target for this pull request.
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5221      +/-   ##
==========================================
- Coverage   82.07%   82.06%   -0.01%
==========================================
  Files         732      736       +4
  Lines       73974    74286     +312
  Branches     3615     3616       +1
==========================================
+ Hits        60711    60962     +251
- Misses      12100    12161      +61
  Partials     1163     1163
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around lines 111-138: Remove the unused `# noqa` annotations in this file: delete the `# noqa: ANN202` on the `_get_autotune_config` definition and the `# noqa: ANN001` markers on the `_mmt_kernel` parameter lines (x, y, M, K, stride_xm, stride_xk, stride_ym, stride_yn), as they are not enabled for Ruff; leave the function, decorators (`@triton.autotune`, `@triton.jit`), and parameter list intact and only remove the extraneous `# noqa` comments.
- Around lines 157-176: The bug is that offs_xm/offs_xn use modulo wrapping, which pulls rows from the top of the matrix into tail blocks. Remove the "% M" wrap and compute the raw row indices (e.g., offs_row_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M), and similarly for offs_row_n), then build a_ptrs/b_ptrs from those raw indices and ensure tl.load masks out any rows >= M (combine the existing K-mask with a row mask like (offs_row[:, None] < M) so loads read 0 for out-of-range rows). Keep the existing store mask (c_mask) for the write-back, but do not use modulo on offs_xm/offs_xn, so the accumulator never sees wrapped rows (a sketch of this masking pattern follows this item).
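A hedged illustration of the masking pattern described above (a standalone helper fragment, not the actual `_mmt_kernel`; the block and stride names are assumptions) might look like:

```python
import triton
import triton.language as tl


@triton.jit
def _load_tail_safe_block(x_ptr, M, K, stride_xm, stride_xk, pid_m,
                          BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    # Raw row indices with no "% M" wrap, so tail blocks never alias rows from the top.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    # Row mask combined with the K mask: out-of-range rows load 0 instead of wrapped data.
    mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    return tl.load(ptrs, mask=mask, other=0.0)
```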
🧹 Nitpick comments (1)
source/tests/pt/test_hybrid_muon.py (1)
259-291: Consider forcing the flash path in the consistency test when Triton+CUDA are available.

With `FLASH_MIN_DIM = 1024`, the current 32×64 shapes won't ever exercise the flash path on GPU, so this test may only compare the PyTorch fallback against itself. You could bump the dim when Triton+CUDA are present to actually validate flash-vs-non-flash consistency.

🔧 Suggested tweak

```diff
-        model1 = torch.nn.Linear(32, 64, device=self.device)
-        model2 = torch.nn.Linear(32, 64, device=self.device)
+        dim = 32
+        if TRITON_AVAILABLE and self.device.type == "cuda":
+            from deepmd.pt.optimizer.hybrid_muon import FLASH_MIN_DIM
+            dim = max(dim, FLASH_MIN_DIM)
+        model1 = torch.nn.Linear(dim, dim * 2, device=self.device)
+        model2 = torch.nn.Linear(dim, dim * 2, device=self.device)
         model2.load_state_dict(model1.state_dict())
@@
-        x = torch.randn(4, 32, device=self.device)
+        x = torch.randn(4, dim, device=self.device)
```
```python
@triton.autotune(configs=_get_autotune_config(), key=["M", "K"])
@triton.jit
```
Will it be expensive at the import time?
```python
grid = lambda META: (  # noqa: E731
    triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(M, META["BLOCK_SIZE_M"]),
)
with torch.cuda.device(d_in.device.index):
```
What is the situation for the CPU?
It seems that triton is not installed in the CI.
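For context, the usual optional-import guard behind a `TRITON_AVAILABLE` flag (a common pattern, assumed here rather than copied from the diff) keeps the pure PyTorch fallback usable on CPU-only CI runners where Triton is not installed:

```python
try:
    import triton
    import triton.language as tl

    TRITON_AVAILABLE = True
except ImportError:
    # CPU-only installs and CI runners without Triton use the torch fallback path.
    TRITON_AVAILABLE = False
```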
Summary by CodeRabbit

- Refactor
- New Features
- Tests
- Documentation