Pull request overview
This PR updates the CUDA fused MoE routing index generation to reduce memory usage for large expert counts, and adjusts a PyTorch engine default related to prefill sizing.
Changes:
- Replaces the previous mask/cumsum-based `_get_sorted_idx` implementation with a 2-phase Triton approach (atomic histogram + scatter).
- Wires the new Triton implementation in as the default `_get_sorted_idx` used by the fused MoE kernels.
- Increases the `PytorchEngineConfig.max_prefill_token_num` default from 4096 to 8192.
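For context, the routing index the new kernels produce can be sketched with a pure-Python reference (function and variable names here are illustrative, not the actual kernel API): phase 1 builds a per-expert histogram while recording each element's local position in its bucket, and phase 2 scatters each element to its expert's slot range.

```python
def sorted_idx_reference(topk_ids, num_experts):
    """Pure-Python sketch of the 2-phase sorted-idx computation.

    Phase 1: histogram the expert assignments (the Triton kernel does this
    with atomics, also recording each element's local position in its bucket).
    Phase 2: an exclusive prefix sum over the counts gives each expert's
    start offset; each element lands at start[expert] + local position.
    """
    counts = [0] * num_experts
    local_pos = []
    for e in topk_ids:
        local_pos.append(counts[e])  # position within this expert's bucket
        counts[e] += 1
    # exclusive prefix sum over counts
    starts, acc = [], 0
    for c in counts:
        starts.append(acc)
        acc += c
    sorted_idx = [0] * len(topk_ids)
    for i, e in enumerate(topk_ids):
        sorted_idx[starts[e] + local_pos[i]] = i
    return sorted_idx, counts

# Example: 4 routing slots over 3 experts.
idx, counts = sorted_idx_reference([1, 0, 0, 1], num_experts=3)
# Indices routed to expert 0 come first, then expert 1's.
```

Unlike the mask/cumsum approach, this never materializes a tokens-by-experts mask, which is why memory drops for large expert counts.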
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `lmdeploy/pytorch/kernels/cuda/fused_moe.py` | Introduces new Triton kernels and replaces `_get_sorted_idx` with a 2-phase atomic/scatter approach. |
| `lmdeploy/messages.py` | Updates the default PyTorch engine prefill token budget. |
```python
    return out_mask, out_k


def _get_sorted_idx_triton(topk_ids: torch.Tensor, num_experts: int):
    """Get sorted idx with 2-phase Triton kernels (4 kernel launches total)."""
```
Comment on lines +341 to +343

```python
    counts = torch.zeros(num_experts, dtype=topk_ids.dtype, device=topk_ids.device)
    local_pos = torch.empty(N, dtype=topk_ids.dtype, device=topk_ids.device)
    _sorted_idx_phase1_kernel[grid](topk_ids, counts, local_pos, N, BLOCK_SIZE=BLOCK_SIZE)
```
Collaborator (Author)

Use int64 to avoid overflow when indexing the expert weights.
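To illustrate the overflow the comment warns about (the row count and stride below are made-up example numbers, not figures from the PR): when an index is held in 32 bits, `row * stride` wraps past `2**31 - 1` and becomes a negative offset.

```python
def as_int32(x):
    """Simulate 32-bit two's-complement wraparound of a Python int."""
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x >= 0x80000000 else x

rows, stride = 300_000, 8_192   # hypothetical row count and row stride
offset64 = rows * stride        # 2_457_600_000, fits comfortably in int64
offset32 = as_int32(offset64)   # wraps to a negative value in int32
```

With many experts, the flattened weight tensor easily exceeds 2^31 elements, so the offsets fed to `tl.load` need 64-bit arithmetic.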
Comment on lines +148 to +155

```python
    # Compute exp_start = exp_end - counts (only first block writes it)
    if pid == 0:
        e_offs = tl.arange(0, BLOCK_E)
        e_mask = e_offs < num_experts
        end_val = tl.load(ExpEnd + e_offs, mask=e_mask)
        cnt_val = tl.load(Counts + e_offs, mask=e_mask)
        tl.store(ExpStart + e_offs, end_val - cnt_val, mask=e_mask)
```
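The `exp_start = exp_end - counts` step is just converting an inclusive prefix sum into an exclusive one; a minimal sketch with made-up counts:

```python
counts = [3, 0, 2, 1]

# Inclusive prefix sum: what ExpEnd holds after the cumsum phase.
exp_end, acc = [], 0
for c in counts:
    acc += c
    exp_end.append(acc)          # [3, 3, 5, 6]

# Subtracting counts yields the exclusive prefix sum, i.e. each
# expert's start offset into the sorted index buffer (ExpStart).
exp_start = [e - c for e, c in zip(exp_end, counts)]  # [0, 3, 3, 5]
```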
Collaborator (Author)
Tested on H800 with 10240 experts.
Optimize for a large number of experts: less memory usage.

> **Warning**
> The order of the sorted idx is not stable.
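Because atomics make the order within each expert bucket nondeterministic, tests should check the grouping invariant rather than compare against one exact output. A hypothetical checker (not part of the PR):

```python
def check_grouped(sorted_idx, topk_ids, counts):
    """Verify each expert's bucket holds exactly its own elements,
    regardless of the (unstable) order inside the bucket."""
    start = 0
    for expert, c in enumerate(counts):
        bucket = sorted_idx[start:start + c]
        if any(topk_ids[i] != expert for i in bucket):
            return False
        start += c
    return start == len(sorted_idx)
```

Two runs that permute indices within a bucket both pass this check, while an index placed in the wrong bucket fails it.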