[Cherry-Pick][CI] Fix attn_mask_offset for multi-step MTP in mixed and PD-split modes (#5738) #5793
base: release/2.4
Conversation
…ed and PD-split modes (PaddlePaddle#5738)
* fix attn_mask_offset in mtp with multi-step and pd-split-mode
* fix xpu operator register
* update pmtp multi-step mtp strategy in pd-split mode
* add note
* fix xpu register
Thanks for your contribution!
Pull request overview
This is a cherry-pick PR from #5738 that fixes attention mask offset handling for multi-step MTP (Multi-Token Prediction) in mixed and PD-split (Prefill-Decode split) modes. The fix introduces two new parameters (mask_rollback and recompute_token_num) to properly track and manage token recomputation during speculative decoding.
Key Changes:
- Added `mask_rollback` and `recompute_token_num` parameters to track token recomputation state (see the sketch after this list)
- Updated CUDA kernel logic to dynamically compute rollback amounts instead of using static values
- Improved code formatting and consistency in CUDA kernel macros
- Enhanced test class to properly inherit from `unittest.TestCase`
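To make the new bookkeeping concrete, below is a minimal illustrative sketch of the rollback arithmetic. The helper name `next_mask_offset` and all values are hypothetical; only the parameter names `mask_rollback` and `recompute_token_num` come from this PR, and the actual kernels operate on per-sequence tensors rather than scalars.

```python
# Illustrative sketch only, not FastDeploy code: after a verification step
# accepts `num_accepted` draft tokens, the attention-mask offset advances by
# the accepted tokens and is then rewound by a dynamic per-sequence rollback
# (previously a static value, which is the bug this PR fixes).
def next_mask_offset(cur_offset: int, num_accepted: int,
                     mask_rollback: int, recompute_token_num: int) -> int:
    return cur_offset + num_accepted - mask_rollback - recompute_token_num

# Example: 3 tokens accepted, a 1-step mask rollback, 2 tokens to recompute.
print(next_mask_offset(cur_offset=10, num_accepted=3,
                       mask_rollback=1, recompute_token_num=2))  # -> 10
```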
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| tests/operators/test_draft_model_preprocess.py | Added new test parameters, fixed test class inheritance, simplified test assertions |
| fastdeploy/spec_decode/mtp.py | Initialized new model input tensors for mask rollback and token recomputation tracking |
| custom_ops/xpu_ops/src/ops/pybind/pybind.cc | Added new parameters to XPU Python binding function signature |
| custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess.cc | Updated XPU operator registration with new input parameters |
| custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu | Implemented dynamic recomputation logic in CUDA kernel, improved code formatting |
| custom_ops/gpu_ops/cpp_extensions.cc | Updated GPU extension function signature with new parameters |
```python
# NOTE(liuzichang): In speculative decoding, accepted tokens' KV cache is recomputed
# using the target model's hidden states.
```
Copilot (AI), Dec 26, 2025
The comment formatting should be improved for consistency. The "NOTE" comments should follow a consistent format throughout the codebase. Consider using a more descriptive multi-line comment format that clearly explains what recompute_token_num represents and why it's initialized to self.num_model_steps - 1.
Suggested change:
```python
# NOTE(liuzichang):
# `recompute_token_num` stores, for each sequence in the batch, the number of tokens
# whose KV cache needs to be recomputed by the target model after speculative decoding.
# We initialize it to `self.num_model_steps - 1` so that, by default, all tokens
# produced in the draft steps (except the final target step) are marked for
# recomputation using the target model's hidden states.
```
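For reference, a runnable sketch of the initialization this suggestion documents. The shape `[max_num_seqs, 1]` and the default `num_model_steps - 1` are taken from comments elsewhere in this review; the standalone variable names are assumptions, since the production code assigns into `self.model_inputs`.

```python
import paddle

max_num_seqs, num_model_steps = 8, 3  # assumed example values
# Default: mark all draft-step tokens (all but the final step) for
# recomputation with the target model's hidden states.
recompute_token_num = paddle.full(
    shape=[max_num_seqs, 1],
    fill_value=num_model_steps - 1,
    dtype="int32",
)
print(recompute_token_num.shape)  # [8, 1]
```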
```python
# P-D split need rollback one step
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
# NOTE(liuzichang):
# extra 1 : P-D split need rollback one step
```
Copilot (AI), Dec 26, 2025
The comment "extra 1 : P-D split need rollback one step" is unclear. Consider rephrasing to: "Extra rollback by 1 for P-D split mode: when Prefill-Decode splitting is enabled, we need to roll back one additional step." This would better explain both what "extra 1" refers to and why the rollback is necessary.
Suggested change:
```python
# Extra rollback by 1 for P-D split mode:
# when Prefill-Decode splitting is enabled, we need to roll back one additional step.
```
```cpp
// NOTE(liuzichang): Used for PD-split mode and future dynamic
// strategies.
```
Copilot (AI), Dec 26, 2025
The comment "Used for PD-split mode and future dynamic strategies" could be more descriptive. Consider explaining what aspect of PD-split mode this variable addresses (the dynamic recomputation count for accepted tokens) and what "future dynamic strategies" might entail. This would help future maintainers understand the design intent.
Suggested change:
```cpp
// NOTE(liuzichang): Track the dynamic recomputation window for this
// sequence. In PD-split mode, this value controls how many accepted
// tokens from the base model will be recomputed on the next step, and
// it is also intended to support future dynamic scheduling/rollback
// strategies that may adjust the recomputation span at runtime.
```
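The kernel-level change described here, replacing a static rollback with a per-sequence dynamic count, can be illustrated host-side in Paddle; all names and values below are assumed examples, not the CUDA kernel itself.

```python
import paddle

bsz, num_model_steps = 4, 3  # assumed example values
# Old behavior: one static rollback applied to every sequence.
static_rollback = paddle.full([bsz], num_model_steps - 1, dtype="int32")
# New behavior: a per-sequence dynamic count carried in recompute_token_num.
recompute_token_num = paddle.to_tensor([2, 0, 2, 1], dtype="int32")

seq_lens_decoder = paddle.to_tensor([17, 9, 23, 12], dtype="int32")
# The dynamic count replaces the static constant when rewinding each sequence.
seq_lens_decoder = seq_lens_decoder - recompute_token_num
print(seq_lens_decoder.numpy())  # [15  9 21 11]
```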
```python
        return inputs, inputs_clone

    def test_draft_model_preprocess(self):
        results1, results2 = self._run_tests()
```
Copilot (AI), Dec 26, 2025
The test validation has been simplified from checking specific output indices to a generic loop checking all 12 outputs. While this is more concise, it loses the clarity of what each assertion is checking. Consider adding comments to document which indices correspond to which outputs (e.g., 0=draft_tokens, 1=input_ids, etc.), or using named constants for the indices to maintain test readability.
Suggested change:
```python
results1, results2 = self._run_tests()
# The first 12 outputs correspond to:
# 0: draft_tokens
# 1: input_ids
# 2: stop_flags
# 3: seq_lens_this_time
# 4: seq_lens_encoder
# 5: seq_lens_decoder
# 6: step_idx
# 7: not_need_stop
# 8: is_block_step
# 9: batch_drop
# 10: pre_ids
# 11: mask_rollback
```
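An alternative to bare index comments, in the spirit of the named-constants idea above: the output names are taken from the suggested comment, while the comparison helper itself is an assumed sketch, not the actual test code.

```python
import numpy as np

OUTPUT_NAMES = [
    "draft_tokens", "input_ids", "stop_flags", "seq_lens_this_time",
    "seq_lens_encoder", "seq_lens_decoder", "step_idx", "not_need_stop",
    "is_block_step", "batch_drop", "pre_ids", "mask_rollback",
]

def compare_outputs(results1, results2):
    for i, name in enumerate(OUTPUT_NAMES):
        # A named err_msg makes any failing index immediately identifiable.
        np.testing.assert_allclose(
            np.asarray(results1[i]),
            np.asarray(results2[i]),
            err_msg=f"mismatch in output {i} ({name})",
        )
```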
```python
is_block_step = paddle.zeros([bsz], dtype="bool")
batch_drop = paddle.zeros([bsz], dtype="bool")
mask_rollback = paddle.zeros([bsz], dtype="int32")
recompute_token_num = paddle.zeros([bsz], dtype="int32")
```
Copilot (AI), Dec 26, 2025
The test initializes recompute_token_num with zeros, but in the actual implementation in mtp.py it's initialized with self.num_model_steps - 1. This mismatch between test initialization and production initialization could lead to tests not catching edge cases. Consider initializing it to match the production default value to ensure the test is validating the actual behavior.
```python
mask_rollback = paddle.zeros([bsz], dtype="int32")
recompute_token_num = paddle.zeros([bsz], dtype="int32")
```
Copilot (AI), Dec 26, 2025
Shape inconsistency: In production code (mtp.py), mask_rollback and recompute_token_num are initialized with shape [self.max_num_seqs, 1], but in the test they use shape [bsz]. While this may work due to array flattening, it's better to maintain consistency. Consider using the same shape in tests as in production code to ensure the test accurately reflects the actual usage pattern.
Suggested change:
```python
mask_rollback = paddle.zeros([bsz, 1], dtype="int32")
recompute_token_num = paddle.zeros([bsz, 1], dtype="int32")
```
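The shape mismatch flagged here is not purely cosmetic. A quick runnable illustration (example size assumed) of how `[bsz]` and `[bsz, 1]` tensors broadcast differently:

```python
import paddle

bsz = 4  # assumed example value
flat = paddle.zeros([bsz], dtype="int32")       # shape used in the test
column = paddle.zeros([bsz, 1], dtype="int32")  # shape used in production

# Elementwise ops broadcast [4] against [4, 1] into [4, 4], which can
# silently change semantics; matching shapes avoids this class of bug.
print((flat + column).shape)  # [4, 4]
```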
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           release/2.4    #5793   +/- ##
==============================================
  Coverage          ?       59.06%
==============================================
  Files             ?          328
  Lines             ?        40732
  Branches          ?         6204
==============================================
  Hits              ?        24057
  Misses            ?        14795
  Partials          ?         1880
```
Motivation
Modifications
Usage or Command
Accuracy Tests
Checklist
- PR title tag (choose at least one): [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run `pre-commit` before commit.
- For a PR to a `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.