optimize get_sorted_idx in moe #4529

Open
grimoire wants to merge 2 commits into InternLM:main from grimoire:optimize-moe-expert-map

Conversation

Collaborator

@grimoire grimoire commented Apr 15, 2026

Optimize _get_sorted_idx for a large number of experts; the new implementation uses less memory.

Warning

The order of the sorted idx is not stable: tokens routed to the same expert may appear in arbitrary order.

@grimoire grimoire marked this pull request as ready for review April 16, 2026 04:18
Copilot AI review requested due to automatic review settings April 16, 2026 04:18
Contributor

Copilot AI left a comment


Pull request overview

This PR updates the CUDA fused MoE routing index generation to reduce memory usage for large expert counts, and adjusts a PyTorch engine default related to prefill sizing.

Changes:

  • Replaces the previous mask/cumsum-based _get_sorted_idx implementation with a 2-phase Triton approach (atomic histogram + scatter).
  • Wires the new Triton implementation as the default _get_sorted_idx used by fused MoE kernels.
  • Increases PytorchEngineConfig.max_prefill_token_num default from 4096 to 8192.
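The 2-phase approach amounts to a counting sort over expert ids. A minimal pure-Python sketch of that structure (function and variable names are illustrative, not the kernel's API; the real Triton kernels run these loops in parallel, with phase 1 using atomic adds):

```python
def get_sorted_idx_sketch(topk_ids, num_experts):
    """Group token indices by expert id via histogram + scatter."""
    # Phase 1: histogram of expert ids; record each token's
    # local position within its expert (atomic_add in the kernel).
    counts = [0] * num_experts
    local_pos = [0] * len(topk_ids)
    for i, e in enumerate(topk_ids):
        local_pos[i] = counts[e]
        counts[e] += 1
    # Phase 2: prefix-sum the counts, then scatter each token index
    # into its expert's slot range.
    exp_end, total = [], 0
    for c in counts:
        total += c
        exp_end.append(total)            # inclusive prefix sum
    exp_start = [e - c for e, c in zip(exp_end, counts)]
    sorted_idx = [0] * len(topk_ids)
    for i, e in enumerate(topk_ids):
        sorted_idx[exp_start[e] + local_pos[i]] = i
    return sorted_idx, exp_start, exp_end
```

In the GPU version, phase 1's increment is an atomic add, so local_pos reflects thread arrival order rather than token order; that is why the resulting sort is not stable, as the PR description warns.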

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

  • lmdeploy/pytorch/kernels/cuda/fused_moe.py: Introduces new Triton kernels and replaces _get_sorted_idx with a 2-phase atomic/scatter approach.
  • lmdeploy/messages.py: Updates the default PyTorch engine prefill token budget.


return out_mask, out_k

def _get_sorted_idx_triton(topk_ids: torch.Tensor, num_experts: int):
"""Get sorted idx with 2-phase Triton kernels (4 kernel launches total)."""
Comment thread lmdeploy/pytorch/kernels/cuda/fused_moe.py
Comment on lines +341 to +343
counts = torch.zeros(num_experts, dtype=topk_ids.dtype, device=topk_ids.device)
local_pos = torch.empty(N, dtype=topk_ids.dtype, device=topk_ids.device)
_sorted_idx_phase1_kernel[grid](topk_ids, counts, local_pos, N, BLOCK_SIZE=BLOCK_SIZE)
Collaborator Author


Use int64 to avoid overflow when indexing the expert weights.
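A rough back-of-the-envelope check shows why int32 indices can overflow here. The hidden and intermediate sizes below are illustrative assumptions, not taken from the PR; only the expert count matches the one mentioned in this thread:

```python
# Illustrative sizes (assumptions, not from the PR): a flattened offset
# into the stacked expert weights can exceed the int32 range.
num_experts = 10240   # expert count mentioned in this thread
hidden_dim = 4096     # assumed hidden size
inter_dim = 1024      # assumed intermediate size
max_offset = num_experts * hidden_dim * inter_dim - 1
print(max_offset)             # ~4.3e10, well above int32 max (2147483647)
```

With offsets of that magnitude, int32 arithmetic would silently wrap, so the sorted indices are kept in int64.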

Comment thread lmdeploy/pytorch/kernels/cuda/fused_moe.py
Comment on lines +148 to +155
# Compute exp_start = exp_end - counts (only first block writes it)
if pid == 0:
e_offs = tl.arange(0, BLOCK_E)
e_mask = e_offs < num_experts
end_val = tl.load(ExpEnd + e_offs, mask=e_mask)
cnt_val = tl.load(Counts + e_offs, mask=e_mask)
tl.store(ExpStart + e_offs, end_val - cnt_val, mask=e_mask)
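The pid == 0 branch above recovers each expert's exclusive start offset by subtracting its count from its inclusive prefix sum, so no separate exclusive-scan pass is needed. The same arithmetic in plain Python (example data only):

```python
from itertools import accumulate

counts = [3, 0, 2, 5]                  # per-expert token counts (example data)
exp_end = list(accumulate(counts))     # inclusive prefix sum: [3, 3, 5, 10]
exp_start = [e - c for e, c in zip(exp_end, counts)]  # exclusive: [0, 3, 3, 5]
```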

Collaborator Author


Tested on H800 with 10240 experts.
