
Conversation

@freeliuzc
Collaborator

…ed and PD-split modes (#5738)

  • fix attn_mask_offset in mtp with multi-step and pd-split-mode

  • fix xpu operator register

  • update pmtp multi-step mtp strategy in pd-split mode

  • add note

  • fix xpu register

Motivation

💡 If this PR is a cherry-pick, the PR title needs to follow the format: add the [Cherry-Pick] label at the very beginning and append the original PR ID at the end, for example [Cherry-Pick][CI] Add check trigger and logic (#5191).


Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests, or state the reason in this PR if none are added.
  • Provide accuracy results.
  • If the current PR targets a release branch, make sure it has already been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings December 26, 2025 07:43
@paddle-bot

paddle-bot bot commented Dec 26, 2025

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This is a cherry-pick of #5738. It fixes attention mask offset handling for multi-step MTP (Multi-Token Prediction) in mixed and PD-split (Prefill-Decode split) modes. The fix introduces two new parameters, mask_rollback and recompute_token_num, to properly track and manage token recomputation during speculative decoding.

Key Changes:

  • Added mask_rollback and recompute_token_num parameters to track token recomputation state
  • Updated CUDA kernel logic to dynamically compute rollback amounts instead of using static values
  • Improved code formatting and consistency in CUDA kernel macros
  • Enhanced test class to properly inherit from unittest.TestCase
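
For orientation, here is a minimal sketch of how the two new model-input tensors might be initialized, assuming the [max_num_seqs, 1] shapes and the num_model_steps - 1 default discussed in the review comments below; the concrete values of max_num_seqs and num_model_steps are illustrative placeholders, not the PR's actual configuration.

import paddle

# Illustrative values; the real ones come from the engine configuration.
max_num_seqs = 8
num_model_steps = 3

model_inputs = {}
# Per-sequence rollback applied to the attention mask offset.
model_inputs["mask_rollback"] = paddle.zeros([max_num_seqs, 1], dtype="int32")
# Per-sequence count of accepted tokens whose KV cache is recomputed
# using the target model's hidden states; defaults to num_model_steps - 1.
model_inputs["recompute_token_num"] = paddle.full(
    [max_num_seqs, 1], num_model_steps - 1, dtype="int32"
)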

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.

Summary per file:
  • tests/operators/test_draft_model_preprocess.py: Added new test parameters, fixed test class inheritance, simplified test assertions
  • fastdeploy/spec_decode/mtp.py: Initialized new model input tensors for mask rollback and token recomputation tracking
  • custom_ops/xpu_ops/src/ops/pybind/pybind.cc: Added new parameters to the XPU Python binding function signature
  • custom_ops/xpu_ops/src/ops/mtp/draft_model_preprocess.cc: Updated XPU operator registration with new input parameters
  • custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu: Implemented dynamic recomputation logic in the CUDA kernel, improved code formatting
  • custom_ops/gpu_ops/cpp_extensions.cc: Updated GPU extension function signature with new parameters

Comment on lines +496 to +497
# NOTE(liuzichang): In speculative decoding, accepted tokens' KV cache is recomputed
# using the target model's hidden states.
Copilot AI Dec 26, 2025

The "NOTE" comments should follow a consistent format throughout the codebase. Consider a more descriptive multi-line format that clearly explains what recompute_token_num represents and why it is initialized to self.num_model_steps - 1.

Suggested change
# NOTE(liuzichang): In speculative decoding, accepted tokens' KV cache is recomputed
# using the target model's hidden states.
# NOTE(liuzichang):
# `recompute_token_num` stores, for each sequence in the batch, the number of tokens
# whose KV cache needs to be recomputed by the target model after speculative decoding.
# We initialize it to `self.num_model_steps - 1` so that, by default, all tokens
# produced in the draft steps (except the final target step) are marked for
# recomputation using the target model's hidden states.

# P-D split need rollback one step
self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
# NOTE(liuzichang):
# extra 1 : P-D split need rollback one step
Copilot AI Dec 26, 2025

The comment "extra 1 : P-D split need rollback one step" is unclear. Consider rephrasing to: "Extra rollback by 1 for P-D split mode: when Prefill-Decode splitting is enabled, we need to roll back one additional step." This would better explain both what "extra 1" refers to and why the rollback is necessary.

Suggested change
# extra 1 : P-D split need rollback one step
# Extra rollback by 1 for P-D split mode:
# when Prefill-Decode splitting is enabled, we need to roll back one additional step.

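
A rough sketch of how this extra rollback might look on the Python side, based only on the quoted snippet; splitwise_role, idx, and num_model_steps are assumed names, and the branch structure is a reading of the comment, not the PR's exact code.

if splitwise_role == "decode":  # assumed flag for P-D split decode instances
    # The prefill instance already wrote the first step's KV cache, so
    # roll the mask back one extra step and skip recomputation here.
    self.model_inputs["mask_rollback"][idx : idx + 1] = (num_model_steps - 1) + 1
    self.model_inputs["recompute_token_num"][idx : idx + 1] = 0
else:
    self.model_inputs["mask_rollback"][idx : idx + 1] = num_model_steps - 1
    self.model_inputs["recompute_token_num"][idx : idx + 1] = num_model_steps - 1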
Comment on lines +238 to +239
// NOTE(liuzichang): Used for PD-split mode and future dynamic
// strategies.
Copilot AI Dec 26, 2025

The comment "Used for PD-split mode and future dynamic strategies" could be more descriptive. Consider explaining what aspect of PD-split mode this variable addresses (the dynamic recomputation count for accepted tokens) and what "future dynamic strategies" might entail. This would help future maintainers understand the design intent.

Suggested change
// NOTE(liuzichang): Used for PD-split mode and future dynamic
// strategies.
// NOTE(liuzichang): Track the dynamic recomputation window for this
// sequence. In PD-split mode, this value controls how many accepted
// tokens from the base model will be recomputed on the next step, and
// it is also intended to support future dynamic scheduling/rollback
// strategies that may adjust the recomputation span at runtime.

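
To make the design intent concrete, here is a small Python analogue of the per-sequence computation the kernel performs with the new input; the name and surrounding logic are assumptions drawn from this review, not the kernel's actual code.

def rollback_for_sequence(recompute_token_num: int) -> int:
    # Before this PR the rollback was the static value num_model_steps - 1;
    # now each sequence carries its own count, so PD-split decode instances
    # (which set the count to 0) and future dynamic strategies can vary it
    # at runtime without changing the kernel.
    return recompute_token_num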
return inputs, inputs_clone

def test_draft_model_preprocess(self):
results1, results2 = self._run_tests()
Copilot AI Dec 26, 2025

The test validation has been simplified from checking specific output indices to a generic loop checking all 12 outputs. While this is more concise, it loses the clarity of what each assertion is checking. Consider adding comments to document which indices correspond to which outputs (e.g., 0=draft_tokens, 1=input_ids, etc.), or using named constants for the indices to maintain test readability.

Suggested change
results1, results2 = self._run_tests()
results1, results2 = self._run_tests()
# The first 12 outputs correspond to:
# 0: draft_tokens
# 1: input_ids
# 2: stop_flags
# 3: seq_lens_this_time
# 4: seq_lens_encoder
# 5: seq_lens_decoder
# 6: step_idx
# 7: not_need_stop
# 8: is_block_step
# 9: batch_drop
# 10: pre_ids
# 11: mask_rollback

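
One way to realize this suggestion is with named constants for the output indices; a sketch follows, where the index order is taken from the comment above and is an assumption, not verified against the operator's actual output order.

from enum import IntEnum

class Out(IntEnum):
    DRAFT_TOKENS = 0
    INPUT_IDS = 1
    STOP_FLAGS = 2
    SEQ_LENS_THIS_TIME = 3
    SEQ_LENS_ENCODER = 4
    SEQ_LENS_DECODER = 5
    STEP_IDX = 6
    NOT_NEED_STOP = 7
    IS_BLOCK_STEP = 8
    BATCH_DROP = 9
    PRE_IDS = 10
    MASK_ROLLBACK = 11

# e.g. assert (results1[Out.DRAFT_TOKENS] == results2[Out.DRAFT_TOKENS]).all()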
is_block_step = paddle.zeros([bsz], dtype="bool")
batch_drop = paddle.zeros([bsz], dtype="bool")
mask_rollback = paddle.zeros([bsz], dtype="int32")
recompute_token_num = paddle.zeros([bsz], dtype="int32")
Copilot AI Dec 26, 2025

The test initializes recompute_token_num with zeros, but in the actual implementation in mtp.py it's initialized with self.num_model_steps - 1. This mismatch between test initialization and production initialization could lead to tests not catching edge cases. Consider initializing it to match the production default value to ensure the test is validating the actual behavior.

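
A sketch of the suggested fix, initializing the test tensor with the production default rather than zeros; bsz and num_model_steps are placeholder values and should match whatever the test configures.

import paddle

bsz = 4
num_model_steps = 3

# Mirror the production default (num_model_steps - 1) instead of zeros.
recompute_token_num = paddle.full([bsz], num_model_steps - 1, dtype="int32")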
Comment on lines +359 to +360
mask_rollback = paddle.zeros([bsz], dtype="int32")
recompute_token_num = paddle.zeros([bsz], dtype="int32")
Copilot AI Dec 26, 2025

Shape inconsistency: In production code (mtp.py), mask_rollback and recompute_token_num are initialized with shape [self.max_num_seqs, 1], but in the test they use shape [bsz]. While this may work due to array flattening, it's better to maintain consistency. Consider using the same shape in tests as in production code to ensure the test accurately reflects the actual usage pattern.

Suggested change
mask_rollback = paddle.zeros([bsz], dtype="int32")
recompute_token_num = paddle.zeros([bsz], dtype="int32")
mask_rollback = paddle.zeros([bsz, 1], dtype="int32")
recompute_token_num = paddle.zeros([bsz, 1], dtype="int32")

@codecov-commenter

Codecov Report

❌ Patch coverage is 50.00000% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.4@9a8e215).

Files with missing lines | Patch % | Lines
fastdeploy/spec_decode/mtp.py | 50.00% | 1 Missing ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.4    #5793   +/-   ##
==============================================
  Coverage               ?   59.06%           
==============================================
  Files                  ?      328           
  Lines                  ?    40732           
  Branches               ?     6204           
==============================================
  Hits                   ?    24057           
  Misses                 ?    14795           
  Partials               ?     1880           
Flag | Coverage Δ
GPU | 59.06% <50.00%> (?)

Flags with carried forward coverage won't be shown.

