
add: DFlash block diffusion speculative decoding #1128

Open
ChenhanYu wants to merge 72 commits into main from chenhany/dflash

Conversation


@ChenhanYu ChenhanYu commented Mar 27, 2026

Implement DFlash (Block Diffusion for Flash Speculative Decoding) as a new mode in ModelOpt's speculative decoding framework.

Key architecture:

  • Feature Fusion: extract hidden states from uniformly sampled target model layers, project via FC layer
  • KV Injection: fused target features injected as K/V entries in every draft decoder layer's attention (not just first layer input)
  • Parallel Drafting: all tokens in a block predicted simultaneously using learnable mask embeddings and bidirectional within-block attention
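The bidirectional within-block attention pattern can be sketched as a plain boolean mask builder. This is an illustrative reconstruction, not the PR's actual `create_dflash_attention_mask` implementation, and the layout (causal context prefix followed by a single draft block) is an assumption:

```python
def block_attention_mask(ctx_len: int, block_size: int) -> list[list[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    Context tokens attend causally among themselves; draft-block tokens
    attend to the full context and bidirectionally within their block.
    """
    n = ctx_len + block_size
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i < ctx_len:
                mask[i][j] = j <= i  # causal over the context prefix
            else:
                mask[i][j] = True  # full context + bidirectional block
    return mask

# e.g. with 3 context tokens and a block of 2: context position 1 cannot
# see position 2, but every block position sees all positions.
```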

Files:

  • dflash/ module: DFlashModel, DFlashConfig, conversion, default config
  • plugins/hf_dflash.py: HFDFlashModel with DFlashAttention (KV injection), DFlashModule (feature fusion + decoder), training forward pass with random anchor sampling and exponential position decay loss
  • main.py: --mode dflash support in training script
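The "exponential position decay loss" mentioned above can be sketched as follows; the decay factor and the exact weighting scheme are illustrative assumptions, not the PR's actual hyperparameters:

```python
def position_decay_weights(block_size: int, decay: float = 0.8) -> list[float]:
    """Weight per-position losses so early in-block positions count more.

    Later draft positions are harder to predict, so an exponential decay
    (weight = decay**position) downweights them. The default decay value
    here is illustrative only.
    """
    return [decay**i for i in range(block_size)]


def weighted_block_loss(per_position_losses: list[float], decay: float = 0.8) -> float:
    """Combine per-position CE losses with exponentially decaying weights."""
    weights = position_decay_weights(len(per_position_losses), decay)
    return sum(w * l for w, l in zip(weights, per_position_losses)) / sum(weights)
```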

Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)

What does this PR do?

Type of change: New feature

Usage

# Add a code snippet demonstrating how to use this
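A hedged example of what usage might look like, based on the `--mode dflash` flag added to the example training script; the other flags and paths are illustrative placeholders and may not match the final PR:

```shell
# Hypothetical invocation -- only --mode dflash is confirmed by this PR;
# the model/data/output flags below are illustrative placeholders.
python examples/speculative_decoding/main.py \
    --mode dflash \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --data train_conversations.jsonl \
    --output_dir ckpts/dflash-draft
```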

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added DFlash speculative-decoding mode with end-user config options and sensible defaults (block size, decoder layers, architecture JSON, torch.compile toggle, mask token).
    • CLI/training accepts and runs in DFlash mode; dataset flow supports DFlash.
  • HuggingFace DFlash plugin with draft-module training, draft forward, and pseudo-speculative generation helpers.
  • Bug Fixes

    • Validation/logging during training gated to rank‑0 with proper synchronization.

@ChenhanYu ChenhanYu requested a review from a team as a code owner March 27, 2026 23:27
@ChenhanYu ChenhanYu requested a review from yeyu-nvidia March 27, 2026 23:27

coderabbitai bot commented Mar 27, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds DFlash speculative decoding: a new DFlash config with defaults, conversion/restore utilities and a registry, a DFlash base model class, a HuggingFace plugin implementing the training/generation logic, and integration of a new "dflash" mode into the example training script.

Changes

Cohort / File(s) Summary
Training Example
examples/speculative_decoding/main.py
Added "dflash" to TrainingArguments.mode; introduced DFlashArguments; parse/bind dflash_args; optionally load JSON config; handle model conversion for mode "dflash"; include "dflash" in dataset construction selection.
Core Speculative Config
modelopt/torch/speculative/config.py
Added _get_dflash_default_config() and DFLASH_DEFAULT_CFG; introduced DFlashConfig with fields controlling block size, num layers, architecture config path, torch.compile toggle, mask token id, freeze/distillation/loss-decay/report flags.
DFlash Package Framework
modelopt/torch/speculative/dflash/__init__.py, modelopt/torch/speculative/dflash/conversion.py, modelopt/torch/speculative/dflash/default_config.py, modelopt/torch/speculative/dflash/dflash_model.py
New package: re-exports submodules; adds default_dflash_config; adds DFlashDMRegistry; conversion/restore entrypoints (convert_to_dflash_model, restore_dflash_model) and merging of defaults; defines DFlashModel base with _setup and modify to apply DFlash config attributes.
Mode Registry Integration
modelopt/torch/speculative/mode.py
Registered DFlashModeDescriptor in SpeculativeDecodingModeRegistry with name="dflash", config_class=DFlashConfig, and wired convert/restore to the DFlash entrypoints.
Speculative Plugins
modelopt/torch/speculative/plugins/__init__.py, modelopt/torch/speculative/plugins/hf_dflash.py
Added conditional plugin import for hf_dflash; new hf_dflash plugin implements HFDFlashModel, registers into DFlashDMRegistry, provides utilities (build_target_layer_ids, create_dflash_attention_mask, create_dflash_loss_mask), full draft module (attention, decoder layers, fusion), training forward (block-wise drafting, masking, CE loss, train accuracy), and pseudo_speculative_generate.
Training Utilities
examples/speculative_decoding/eagle_utils.py
AR validation in EagleTrainingPlot.on_step_end now runs only on master/rank 0 with a post-validation distributed barrier; wandb logging gated by availability within the rank-0 block.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant BaseModel as Base LLM
    participant HiddenCollector as Hidden State<br/>Collector
    participant Fusion as Feature Fusion<br/>(FC + RMSNorm)
    participant DraftModule as DFlash Draft<br/>Module (decoder stack)
    participant LMHead as Logit Head
    participant Loss as Loss & Accuracy

    User->>BaseModel: input_ids + attention_mask
    BaseModel->>HiddenCollector: forward -> hidden states
    HiddenCollector->>Fusion: collect target-layer states
    Fusion->>DraftModule: fused targets + noise embeddings
    DraftModule->>LMHead: draft hidden -> logits
    LMHead->>Loss: compute CE / accuracy
    Loss-->>User: loss + accuracy

    rect rgba(100,150,200,0.5)
    Note over BaseModel,Loss: DFlash training forward pass
    end
sequenceDiagram
    actor User
    participant BaseModel as Base LLM
    participant DFlashMod as DFlash Module
    participant BlockBuilder as Block Builder
    participant DraftDecoder as Draft Decoder
    participant TokenSel as Token Selector

    User->>BaseModel: input_ids (context)
    loop for each generation step
        BaseModel->>DFlashMod: base next-token anchor + hidden states
        DFlashMod->>DFlashMod: fuse target layers
        loop for each block position
            BlockBuilder->>DraftDecoder: build noise (anchor + mask), run draft step
            DraftDecoder->>TokenSel: draft logits -> argmax
        end
        DraftDecoder-->>User: base token + draft block
    end

    rect rgba(150,200,100,0.5)
    Note over BaseModel,TokenSel: DFlash pseudo-speculative generation
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 60.61%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check — ✅ Passed: the title clearly and concisely describes the main change: adding DFlash block diffusion speculative decoding as a new feature.
  • Security Anti-Patterns — ✅ Passed: no security anti-patterns (unsafe torch.load, numpy.load, trust_remote_code, eval/exec) detected in modified files.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.


github-actions bot commented Mar 27, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1128/

Built to branch gh-pages at 2026-04-04 20:12 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@ChenhanYu ChenhanYu requested a review from h-guo18 March 27, 2026 23:32

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (4)
modelopt/torch/speculative/dflash/dflash_model.py (1)

27-34: Add type hint for config parameter.

The config parameter lacks a type annotation. Per project standards, type hints should be provided for static type checking with mypy.

♻️ Proposed fix
+from ..config import DFlashConfig
+
+
 class DFlashModel(DynamicModule):
     """Base DFlash Model."""

     def _setup(self):
         self._register_temp_attribute("dflash_module", None)

-    def modify(self, config):
+    def modify(self, config: DFlashConfig):
         """Base DFlash Model modify function. Child class should implement the details."""
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/dflash/dflash_model.py` around lines 27 - 34, The
modify method's config parameter is missing a type annotation; update def
modify(self, config) to include the appropriate config type (e.g., def
modify(self, config: DFlashConfig)) and import that type at the top of the
module from wherever the project's config dataclass/typing lives; if a concrete
config type isn't available yet, annotate with typing.Any as a temporary
fallback and add the proper import for Any. Ensure you update imports (from
typing import Any or from <module> import DFlashConfig) and keep the existing
attribute assignments in modify unchanged.
modelopt/torch/speculative/plugins/hf_dflash.py (3)

320-327: Silent exception handling may hide initialization errors.

The broad except Exception: continue pattern silently ignores all errors when locating base model parts, potentially masking genuine issues like attribute errors or type mismatches.

♻️ Proposed improvement to catch only expected exceptions
             for path in paths:
                 try:
                     submodule = self.get_submodule(path)
                     assert isinstance(submodule, torch.nn.Module)
                     setattr(self, name, path)
                     break
-                except Exception:
+                except (AttributeError, AssertionError):
                     continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 320 - 327, The
loop that tries to locate submodules uses a broad "except Exception: continue",
which can hide real errors; change it to only catch expected exceptions (e.g.,
except (AttributeError, AssertionError, TypeError): continue) so intentional
lookup failures are ignored but other unexpected exceptions surface (or are
logged/re-raised); keep the calls to self.get_submodule and setattr(self, name,
path) intact but ensure you handle and/or log unexpected exceptions rather than
swallowing them silently.

471-474: Complex anchor sampling logic is hard to follow.

The nested max(), min(), and range() calls make this expression difficult to reason about. Consider breaking it into intermediate variables for clarity and easier debugging.

♻️ Suggested refactor for readability
-        num_blocks = max(1, max_anchor // block_size)
-        # Sample anchor positions uniformly
-        anchors = sorted(
-            random.sample(range(1, max(2, max_anchor)), min(num_blocks, max(1, max_anchor - 1)))
-        )
+        num_blocks = max(1, max_anchor // block_size)
+        # Sample anchor positions uniformly
+        sample_range_end = max(2, max_anchor)
+        sample_count = min(num_blocks, max(1, max_anchor - 1))
+        anchors = sorted(random.sample(range(1, sample_range_end), sample_count))
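As a self-contained check that the refactor preserves behavior, the suggested intermediate variables can be wrapped in a helper (`sample_anchors` is a name introduced here for illustration; the surrounding training state is assumed):

```python
import random


def sample_anchors(max_anchor: int, block_size: int) -> list[int]:
    """Refactored anchor sampling, equivalent to the original one-liner."""
    num_blocks = max(1, max_anchor // block_size)
    sample_range_end = max(2, max_anchor)
    sample_count = min(num_blocks, max(1, max_anchor - 1))
    # Sample unique anchor positions uniformly, then sort ascending.
    return sorted(random.sample(range(1, sample_range_end), sample_count))
```

Note the guards: `sample_count` can never exceed the size of the sampled range, so `random.sample` cannot raise `ValueError` even for tiny `max_anchor`.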
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 471 - 474, The
anchor sampling expression assigned to anchors is hard to read; break it into
named intermediate variables (e.g., compute max_anchor_bound = max(2,
max_anchor), sample_upper = max(1, max_anchor - 1), num_to_sample =
min(num_blocks, sample_upper), and the range_to_sample = range(1,
max_anchor_bound)) and then call random.sample(range_to_sample, num_to_sample)
and sort the result; update the code in hf_dflash.py where anchors is defined
(the anchors variable in the anchor sampling block) to use these intermediate
names for clarity and easier debugging.

346-347: Accessing private _attn_implementation attribute is fragile.

_attn_implementation is a private attribute of PretrainedConfig that may change between transformers versions without notice.

♻️ Suggested defensive check
-        if self.dflash_config._attn_implementation is None:
+        if getattr(self.dflash_config, "_attn_implementation", None) is None:
             self.dflash_config._attn_implementation = "sdpa"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 346 - 347,
Replace the fragile direct access to the private attribute
self.dflash_config._attn_implementation with a defensive check using
getattr/hasattr and set via setattr (e.g., current = getattr(self.dflash_config,
"attn_implementation", None) or fallback = getattr(self.dflash_config,
"_attn_implementation", None); if current is None: setattr(self.dflash_config,
"attn_implementation", "sdpa") ), so the code uses public names when present and
only falls back to the underscore name if necessary, ensuring compatibility
across transformers versions.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 368-375: The modify() method currently registers forward hooks
each time without removing prior hooks, causing duplicated collections; fix by
tracking hook handles and removing old hooks before registering new ones: add a
persistent attribute (e.g., self._registered_forward_hooks = []) initialized in
the class (constructor), at the start of modify() iterate over
self._registered_forward_hooks calling handle.remove() and then clear the list,
then when registering forward hooks on layers (where you call
layer.register_forward_hook(self._collect_hidden_hook)) capture each returned
handle and append it to self._registered_forward_hooks; also ensure you reset
self._target_hidden_states (and optionally self._cached_masks) when re-modifying
to avoid stale state.
- Around line 55-62: The build_target_layer_ids function collapses to the same
index when num_target_layers <= 4 because start=1 and end=num_target_layers-3
produce span <= 0; change the logic to handle shallow models: if
num_target_layers <= 4 (or if num_sample_layers >= num_target_layers) return a
set of valid unique layer indices (e.g., evenly spaced or simply range(0,
num_target_layers) trimmed to num_sample_layers) and otherwise compute evenly
spaced indices across [0, num_target_layers-1]; update build_target_layer_ids to
clamp/adjust start/end and to limit num_sample_layers to at most
num_target_layers so indices remain unique and within bounds.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 99f4f5f4-7d81-4e34-b82f-d958548f8d6d

📥 Commits

Reviewing files that changed from the base of the PR and between 2bad66c and c08fb9c.

📒 Files selected for processing (9)
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/dflash/__init__.py
  • modelopt/torch/speculative/dflash/conversion.py
  • modelopt/torch/speculative/dflash/default_config.py
  • modelopt/torch/speculative/dflash/dflash_model.py
  • modelopt/torch/speculative/mode.py
  • modelopt/torch/speculative/plugins/__init__.py
  • modelopt/torch/speculative/plugins/hf_dflash.py

Comment on lines +55 to +62
def build_target_layer_ids(num_target_layers, num_sample_layers):
    """Select layers uniformly from the target model for feature extraction."""
    if num_sample_layers == 1:
        return [num_target_layers // 2]
    start = 1
    end = num_target_layers - 3
    span = end - start
    return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)]

⚠️ Potential issue | 🟡 Minor

Edge case: small num_target_layers values produce degenerate sampling.

When num_target_layers <= 4, end <= start (e.g., num_target_layers=4 gives start=1, end=1, span=0). All returned layer indices collapse to the same value, defeating uniform sampling. Consider adding a guard or adjusting the formula for shallow target models.
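The collapse is easy to reproduce by running the quoted helper standalone (copied here so the snippet is self-contained):

```python
def build_target_layer_ids(num_target_layers, num_sample_layers):
    """Quoted helper: select layers uniformly for feature extraction."""
    if num_sample_layers == 1:
        return [num_target_layers // 2]
    start = 1
    end = num_target_layers - 3
    span = end - start
    return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)]

# With a 4-layer target, span = 0, so every sampled index collapses to 1:
print(build_target_layer_ids(4, 3))  # → [1, 1, 1]
# A deeper target spreads out as intended:
print(build_target_layer_ids(32, 4))  # → [1, 10, 20, 29]
```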

🛡️ Proposed fix to handle edge case
 def build_target_layer_ids(num_target_layers, num_sample_layers):
     """Select layers uniformly from the target model for feature extraction."""
     if num_sample_layers == 1:
         return [num_target_layers // 2]
+    if num_target_layers <= 4:
+        # For very shallow models, sample from all available layers
+        start, end = 0, num_target_layers - 1
+    else:
+        start = 1
+        end = num_target_layers - 3
-    start = 1
-    end = num_target_layers - 3
     span = end - start
+    if span <= 0:
+        return [start] * num_sample_layers
     return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 55 - 62, The
build_target_layer_ids function collapses to the same index when
num_target_layers <= 4 because start=1 and end=num_target_layers-3 produce span
<= 0; change the logic to handle shallow models: if num_target_layers <= 4 (or
if num_sample_layers >= num_target_layers) return a set of valid unique layer
indices (e.g., evenly spaced or simply range(0, num_target_layers) trimmed to
num_sample_layers) and otherwise compute evenly spaced indices across [0,
num_target_layers-1]; update build_target_layer_ids to clamp/adjust start/end
and to limit num_sample_layers to at most num_target_layers so indices remain
unique and within bounds.


codecov bot commented Mar 27, 2026

Codecov Report

❌ Patch coverage is 43.64261% with 328 lines in your changes missing coverage. Please review.
✅ Project coverage is 54.59%. Comparing base (ada1e26) to head (3a8ff9c).
⚠️ Report is 21 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/speculative/plugins/hf_dflash.py 42.01% 247 Missing ⚠️
...delopt/torch/utils/plugins/transformers_dataset.py 0.00% 51 Missing ⚠️
modelopt/torch/export/plugins/hf_spec_export.py 17.64% 28 Missing ⚠️
modelopt/torch/speculative/config.py 84.61% 2 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (ada1e26) and HEAD (3a8ff9c): HEAD has 2 fewer uploads than BASE (BASE: 2, HEAD: 0).
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1128       +/-   ##
===========================================
- Coverage   70.18%   54.59%   -15.59%     
===========================================
  Files         230      355      +125     
  Lines       26080    40579    +14499     
===========================================
+ Hits        18304    22155     +3851     
- Misses       7776    18424    +10648     
Flag Coverage Δ
unit 54.59% <43.64%> (?)

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/hf_dflash.py (4)

307-309: Device placement assumes _base_model.layers exists and is non-empty.

If the base model has a different structure (e.g., no layers attribute or empty layers), this will raise an AttributeError or IndexError.

🛡️ More robust device detection
         self.dflash_module = DFlashModule(self.dflash_config)
-        self.dflash_module.to(self._base_model.dtype).to(
-            next(self._base_model.layers[-1].parameters()).device
-        )
+        # Get device from any base model parameter
+        base_device = next(self._base_model.parameters()).device
+        base_dtype = next(self._base_model.parameters()).dtype
+        self.dflash_module.to(base_dtype).to(base_device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 307 - 309, The
placement code for self.dflash_module assumes self._base_model.layers exists and
has elements; instead, detect device robustly by checking for a non-empty
self._base_model.layers and using
next(self._base_model.layers[-1].parameters()).device only when present,
otherwise fall back to using next(self._base_model.parameters()).device (or a
CPU default if parameters are absent); apply the chosen device and the target
dtype (self._base_model.dtype) to self.dflash_module so device/dtype setting
works for models without a layers attribute or with empty layers.

82-82: Unused attribute is_causal.

The is_causal attribute is set to False but never referenced. The value is hardcoded directly in the scaled_dot_product_attention call at line 129.

♻️ Suggested cleanup
         self.num_kv_heads = config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
-        self.is_causal = False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 82, The instance
attribute self.is_causal is defined but never used; either remove its assignment
or wire it into the attention call — replace the hardcoded False in the
scaled_dot_product_attention invocation with self.is_causal (or delete the
self.is_causal assignment if you prefer no flag). Update the attribute in the
class initializer where self.is_causal is set and the
scaled_dot_product_attention call at the line that currently passes False so
they are consistent.

412-414: Zero loss tensor with requires_grad=True won't propagate gradients.

When active_logits.numel() == 0, the returned loss tensor is a constant 0.0 with requires_grad=True. While this avoids errors during the backward pass, it produces a gradient of zero for all parameters. Consider logging a warning when this edge case occurs.

💡 Add warning for visibility
+        import logging
+        logger = logging.getLogger(__name__)
+
         if active_logits.numel() > 0:
             loss = F.cross_entropy(active_logits, active_labels)
             with torch.no_grad():
                 preds = active_logits.argmax(dim=-1)
                 accuracy = (preds == active_labels).float().mean().item()
         else:
+            logger.warning("No active positions for loss computation")
             loss = torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)
             accuracy = 0.0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 412 - 414, When
active_logits.numel() == 0 the code returns loss = torch.tensor(0.0,
device=device, dtype=dtype, requires_grad=True) which yields zero gradients;
update the branch handling this case to also emit a clear warning so the
condition is visible in logs (e.g., log that active_logits is empty and loss is
a zero tensor) and keep returning the tensor and accuracy as currently done.
Locate the branch using active_logits.numel() == 0 (the else that sets loss and
accuracy), add a warning via the module/logger used in this file (e.g.,
logger.warning or process_logger) mentioning the function/context (hf_dflash
loss computation) and the tensor/device/dtype, and ensure no change to the
returned tensor shape or requires_grad behavior.

285-286: Accessing private HuggingFace attribute _attn_implementation.

_attn_implementation is a private/internal attribute in HuggingFace configs that may change without notice. Consider using the public API or documenting this dependency.

♻️ More defensive approach
-        if self.dflash_config._attn_implementation is None:
-            self.dflash_config._attn_implementation = "eager"
+        if getattr(self.dflash_config, "_attn_implementation", None) is None:
+            self.dflash_config._attn_implementation = "eager"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 285 - 286, The
code currently writes directly to the private HuggingFace attribute
self.dflash_config._attn_implementation; instead, make this defensive by
checking for the attribute's existence with getattr/hasattr and only set it when
present, otherwise fall back to a documented public config or internal default
and log/warn about relying on a private attribute; update the assignment around
self.dflash_config._attn_implementation to use a safe check
(getattr(self.dflash_config, "_attn_implementation", None)) and a clear fallback
path so the plugin won't break if the private field is removed in future.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 47-56: The function build_target_layer_ids should avoid the
redundant int(round(...)) and handle shallow models where span <= 0; change the
rounding to just round(...) (or remove the int() wrapper) and add a guard when
num_target_layers <= 4 (or when end <= start) to produce sensible unique
indices: compute end = max(1, num_target_layers - 3), compute span = max(1, end
- start), and if num_draft_layers > (end - start + 1) clamp num_draft_layers to
that maximum so the list comprehension in build_target_layer_ids returns evenly
spaced, non-duplicated layer indices (use the existing variables
num_target_layers, num_draft_layers, start, end, span in the fix).

---

Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 307-309: The placement code for self.dflash_module assumes
self._base_model.layers exists and has elements; instead, detect device robustly
by checking for a non-empty self._base_model.layers and using
next(self._base_model.layers[-1].parameters()).device only when present,
otherwise fall back to using next(self._base_model.parameters()).device (or a
CPU default if parameters are absent); apply the chosen device and the target
dtype (self._base_model.dtype) to self.dflash_module so device/dtype setting
works for models without a layers attribute or with empty layers.
- Line 82: The instance attribute self.is_causal is defined but never used;
either remove its assignment or wire it into the attention call — replace the
hardcoded False in the scaled_dot_product_attention invocation with
self.is_causal (or delete the self.is_causal assignment if you prefer no flag).
Update the attribute in the class initializer where self.is_causal is set and
the scaled_dot_product_attention call at the line that currently passes False so
they are consistent.
- Around line 412-414: When active_logits.numel() == 0 the code returns loss =
torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True) which yields
zero gradients; update the branch handling this case to also emit a clear
warning so the condition is visible in logs (e.g., log that active_logits is
empty and loss is a zero tensor) and keep returning the tensor and accuracy as
currently done. Locate the branch using active_logits.numel() == 0 (the else
that sets loss and accuracy), add a warning via the module/logger used in this
file (e.g., logger.warning or process_logger) mentioning the function/context
(hf_dflash loss computation) and the tensor/device/dtype, and ensure no change
to the returned tensor shape or requires_grad behavior.
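A sketch of the warned zero-loss branch described above (function and variable names are illustrative, not the actual hf_dflash code):

```python
import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)

def masked_ce_loss(active_logits: torch.Tensor, active_labels: torch.Tensor,
                   device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Cross-entropy over active positions, warning on the empty case.

    When no positions survive masking, the returned zero tensor yields
    zero gradients; the warning makes that condition visible in logs
    while the returned shape and requires_grad behavior stay unchanged.
    """
    if active_logits.numel() == 0:
        logger.warning(
            "hf_dflash loss computation: active_logits is empty on device=%s "
            "dtype=%s; returning a zero loss (zero gradients this step).",
            device, dtype,
        )
        return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)
    return F.cross_entropy(active_logits, active_labels)
```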
- Around line 285-286: The code currently writes directly to the private
HuggingFace attribute self.dflash_config._attn_implementation; instead, make
this defensive by checking for the attribute's existence with getattr/hasattr
and only set it when present, otherwise fall back to a documented public config
or internal default and log/warn about relying on a private attribute; update
the assignment around self.dflash_config._attn_implementation to use a safe
check (getattr(self.dflash_config, "_attn_implementation", None)) and a clear
fallback path so the plugin won't break if the private field is removed in
future.
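The defensive pattern for the private attribute can be sketched like this (an illustration of the suggestion, not the plugin's code):

```python
import logging
import types

logger = logging.getLogger(__name__)

def set_attn_implementation(config, impl: str = "eager") -> bool:
    """Defensively set transformers' private ``_attn_implementation`` field.

    Only writes the private attribute when the config already exposes
    it, and warns (instead of crashing) if a future transformers
    release removes the field. Returns whether the attribute was set.
    """
    if hasattr(config, "_attn_implementation"):
        config._attn_implementation = impl
        return True
    logger.warning(
        "config has no private _attn_implementation attribute; "
        "falling back to the library default attention backend."
    )
    return False
```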
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 3311c2cf-da29-4e8f-bbc0-c97039583962

📥 Commits

Reviewing files that changed from the base of the PR and between c08fb9c and 528f2bf.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/hf_dflash.py

Comment on lines +47 to +56
def build_target_layer_ids(num_target_layers, num_draft_layers):
    """Select layers uniformly from the target model for feature extraction."""
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    start = 1
    end = num_target_layers - 3
    span = end - start
    return [
        int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers)
    ]

⚠️ Potential issue | 🟡 Minor

Fix linting error and handle edge case for shallow target models.

Two issues:

  1. Pipeline failure (Line 55): round() already returns an int in Python 3, making the int() wrapper redundant.

  2. Edge case: When num_target_layers <= 4, span becomes ≤ 0 (e.g., num_target_layers=4 → start=1, end=1, span=0), causing all returned indices to collapse to the same value.

🐛 Proposed fix
 def build_target_layer_ids(num_target_layers, num_draft_layers):
     """Select layers uniformly from the target model for feature extraction."""
     if num_draft_layers == 1:
         return [num_target_layers // 2]
-    start = 1
-    end = num_target_layers - 3
+    # For shallow models, use full range; otherwise skip first and last few layers
+    if num_target_layers <= 4:
+        start, end = 0, num_target_layers - 1
+    else:
+        start, end = 1, num_target_layers - 3
     span = end - start
+    if span <= 0 or num_draft_layers > num_target_layers:
+        # Fallback: return middle layer repeated
+        return [num_target_layers // 2] * num_draft_layers
     return [
-        int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers)
+        round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)
     ]
🧰 Tools
🪛 GitHub Actions: Code Quality

[error] 55-55: ruff check failed (RUF046). Value being cast to int is already an integer. Line suggests int(round(...)); help: remove unnecessary int call.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 47 - 56, The
function build_target_layer_ids should avoid the redundant int(round(...)) and
handle shallow models where span <= 0; change the rounding to just round(...)
(or remove the int() wrapper) and add a guard when num_target_layers <= 4 (or
when end <= start) to produce sensible unique indices: compute end = max(1,
num_target_layers - 3), compute span = max(1, end - start), and if
num_draft_layers > (end - start + 1) clamp num_draft_layers to that maximum so
the list comprehension in build_target_layer_ids returns evenly spaced,
non-duplicated layer indices (use the existing variables num_target_layers,
num_draft_layers, start, end, span in the fix).
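Putting the suggested guard together, the fixed selection could look like this (renamed to make clear it is a sketch of the proposal, not the committed fix):

```python
def build_target_layer_ids_safe(num_target_layers: int, num_draft_layers: int) -> list[int]:
    """Uniform layer selection with the shallow-model guard proposed above.

    Shallow models (<= 4 layers) use the full layer range, and a
    non-positive span falls back to repeating the middle layer so the
    function never emits collapsed duplicate indices by accident.
    """
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    # For shallow models use the full range; otherwise skip boundary layers.
    if num_target_layers <= 4:
        start, end = 0, num_target_layers - 1
    else:
        start, end = 1, num_target_layers - 3
    span = end - start
    if span <= 0:
        return [num_target_layers // 2] * num_draft_layers
    # round() already returns an int in Python 3, so no int() wrapper (RUF046).
    return [
        round(start + (i * span) / (num_draft_layers - 1))
        for i in range(num_draft_layers)
    ]
```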

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
modelopt/torch/speculative/plugins/hf_dflash.py (3)

364-364: Replace print() with logging.

Production code should use the logging module for better control over output verbosity.

♻️ Proposed fix

Add import at top of file:

import logging

logger = logging.getLogger(__name__)

Then replace the print:

-        print(f"DFlash: using {original_cls.__name__}.forward as base forward")
+        logger.info(f"DFlash: using {original_cls.__name__}.forward as base forward")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 364, Replace the
stray print call with structured logging: add an import for the logging module
and initialize a module-level logger (e.g., logger =
logging.getLogger(__name__)), then replace the print(f"DFlash: using
{original_cls.__name__}.forward as base forward") in hf_dflash.py with an
appropriate logger call (logger.info or logger.debug) referencing
original_cls.__name__ to retain the same message content; ensure the logger is
defined at top of the file so the statement in the speculative plugin uses it.

260-269: Overly broad exception handling.

Catching Exception may mask unexpected errors (e.g., AttributeError, TypeError). Consider narrowing to the specific exceptions expected from get_submodule.

♻️ Proposed fix
             for path in paths:
                 try:
                     submodule = self.get_submodule(path)
                     assert isinstance(submodule, torch.nn.Module)
                     setattr(self, name, path)
                     break
-                except Exception:
+                except (AttributeError, AssertionError):
                     continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 260 - 269,
Narrow the broad except in the for-loop that calls get_submodule: replace
"except Exception" with a targeted exception tuple such as "except
(AttributeError, KeyError):" (these are the likely errors when a submodule path
is not found) so only lookup-related failures are silenced while letting other
errors (TypeError, AssertionError, etc.) surface; keep the rest of the logic in
the loop (get_submodule, isinstance check, setattr(self, name, path) and the
final ValueError) unchanged.

82-82: Unused attribute is_causal.

self.is_causal is assigned but never referenced. Line 129 hardcodes is_causal=False directly.

♻️ Proposed fix
-        self.is_causal = False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 82, The attribute
self.is_causal is assigned but never used; replace the hardcoded is_causal=False
occurrence in this module with the instance attribute so the class-level flag is
honored: initialize self.is_causal in the class constructor (as currently
present) and change the call/site that currently passes is_causal=False to pass
is_causal=self.is_causal instead; alternatively, if causality is never intended
to be configurable, remove the unused self.is_causal assignment and keep the
hardcoded False—prefer the first option to preserve configurability.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 510-521: The debug print block guarded by self._psg_debug should
be removed or replaced with proper logging and a safe batch-aware access; locate
the block around _psg_debug that inspects base_outputs.hidden_states,
base_token, mask_token_id, dflash_block_size and calls
self._base_model_embeddings and either delete it or change prints to
logger.debug(...) and replace the unsafe base_token.item() with a batch-safe
access such as base_token[0,0].item() so it won't fail when bsz > 1; keep the
one-time flag behavior if you want this to run only once.

---

Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Line 364: Replace the stray print call with structured logging: add an import
for the logging module and initialize a module-level logger (e.g., logger =
logging.getLogger(__name__)), then replace the print(f"DFlash: using
{original_cls.__name__}.forward as base forward") in hf_dflash.py with an
appropriate logger call (logger.info or logger.debug) referencing
original_cls.__name__ to retain the same message content; ensure the logger is
defined at top of the file so the statement in the speculative plugin uses it.
- Around line 260-269: Narrow the broad except in the for-loop that calls
get_submodule: replace "except Exception" with a targeted exception tuple such
as "except (AttributeError, KeyError):" (these are the likely errors when a
submodule path is not found) so only lookup-related failures are silenced while
letting other errors (TypeError, AssertionError, etc.) surface; keep the rest of
the logic in the loop (get_submodule, isinstance check, setattr(self, name,
path) and the final ValueError) unchanged.
- Line 82: The attribute self.is_causal is assigned but never used; replace the
hardcoded is_causal=False occurrence in this module with the instance attribute
so the class-level flag is honored: initialize self.is_causal in the class
constructor (as currently present) and change the call/site that currently
passes is_causal=False to pass is_causal=self.is_causal instead; alternatively,
if causality is never intended to be configurable, remove the unused
self.is_causal assignment and keep the hardcoded False—prefer the first option
to preserve configurability.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 90e791a0-d745-4cf9-8a9d-1600e8d1ad34

📥 Commits

Reviewing files that changed from the base of the PR and between 528f2bf and 78ef707.

📒 Files selected for processing (3)
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/dflash/default_config.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/speculative/dflash/default_config.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/speculative_decoding/main.py

Comment on lines +510 to +521
if not hasattr(self, '_psg_debug'):
    self._psg_debug = True
    sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids]
    th = torch.cat(sel, dim=-1)
    print(f"[psg] hidden_states layers: {len(base_outputs.hidden_states)}, target_hidden norm: {th.norm().item():.2f}, shape: {th.shape}")
    print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
    print(f"[psg] block_ids: {[self.mask_token_id]*self.dflash_block_size}")
    bi = torch.full((1, self.dflash_block_size), self.mask_token_id, dtype=torch.long, device=input_ids.device)
    bi[0, 0] = base_token[0, 0]
    ne = self._base_model_embeddings(bi)
    print(f"[psg] noise_emb norm: {ne.norm().item():.2f}, shape: {ne.shape}")
    print(f"[psg] pos_ids will be: ctx=[0..{input_ids.shape[1]-1}], blk=[{input_ids.shape[1]}..{input_ids.shape[1]+self.dflash_block_size-1}]")

⚠️ Potential issue | 🟡 Minor

Remove debug print statements from production code.

These debug prints are guarded by _psg_debug but will still execute once per model instance. Additionally, line 515 uses base_token.item() which will fail when bsz > 1.

🐛 Proposed fix

Either remove the debug block entirely, or convert to proper logging:

+import logging
+logger = logging.getLogger(__name__)
+
 ...
 
-        if not hasattr(self, '_psg_debug'):
-            self._psg_debug = True
-            sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids]
-            th = torch.cat(sel, dim=-1)
-            print(f"[psg] hidden_states layers: {len(base_outputs.hidden_states)}, target_hidden norm: {th.norm().item():.2f}, shape: {th.shape}")
-            print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
-            print(f"[psg] block_ids: {[self.mask_token_id]*self.dflash_block_size}")
-            bi = torch.full((1, self.dflash_block_size), self.mask_token_id, dtype=torch.long, device=input_ids.device)
-            bi[0, 0] = base_token[0, 0]
-            ne = self._base_model_embeddings(bi)
-            print(f"[psg] noise_emb norm: {ne.norm().item():.2f}, shape: {ne.shape}")
-            print(f"[psg] pos_ids will be: ctx=[0..{input_ids.shape[1]-1}], blk=[{input_ids.shape[1]}..{input_ids.shape[1]+self.dflash_block_size-1}]")

If debug logging is needed, use logger.debug() and fix the batch-size issue:

logger.debug(f"[psg] base_token: {base_token[0, 0].item()}, mask_token_id: {self.mask_token_id}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 510 - 521, The
debug print block guarded by self._psg_debug should be removed or replaced with
proper logging and a safe batch-aware access; locate the block around _psg_debug
that inspects base_outputs.hidden_states, base_token, mask_token_id,
dflash_block_size and calls self._base_model_embeddings and either delete it or
change prints to logger.debug(...) and replace the unsafe base_token.item() with
a batch-safe access such as base_token[0,0].item() so it won't fail when bsz >
1; keep the one-time flag behavior if you want this to run only once.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
examples/speculative_decoding/eagle_utils.py (1)

241-249: Cache the MT-Bench dataset outside the synchronized section.

load_dataset(...) now sits on the global stall path, so every worker waits for rank 0 to redo the hub/cache lookup before training resumes. Reusing a cached dataset here would cut that idle time.

♻️ Proposed refactor
 class EagleTrainingPlot(TrainerCallback):
     def __init__(self, ar_validate_steps: int = 1000, estimate_ar: bool = False):
         self.ar_validate_steps = ar_validate_steps
         if wandb and is_master():
             wandb.init()
         self.estimate_ar = estimate_ar
+        self._ar_validation_ds = None
@@
             if is_master():
                 print_rank_0("Running AR validation...")
                 try:
+                    if self._ar_validation_ds is None:
+                        self._ar_validation_ds = load_dataset("HuggingFaceH4/mt_bench_prompts")[
+                            "train"
+                        ]
                     ars = validate_ar(
                         model=kwargs["model"],
                         tokenizer=kwargs["processing_class"],
-                        ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
+                        ds=self._ar_validation_ds,
                         device=kwargs["model"].device,
                     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/eagle_utils.py` around lines 241 - 249, The
call to load_dataset("HuggingFaceH4/mt_bench_prompts") is happening inside the
synchronized/critical section causing all ranks to wait; move the dataset
load/cache out of that section and reuse a single cached dataset reference when
calling validate_ar (keep calling validate_ar(model=kwargs["model"],
tokenizer=kwargs["processing_class"], ds=cached_mt_bench_ds,
device=kwargs["model"].device)), ensuring cached_mt_bench_ds is initialized once
(e.g., at module import or rank-0 setup) and shared by workers so only the heavy
hub/cache lookup happens once while still preserving use of state.global_step,
ars, and wandb logging.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 238-254: Load the validation dataset before the master-only block
so non-master ranks don't waste time; specifically, when checking the AR
validation trigger (state.global_step % self.ar_validate_steps == 0 and
state.global_step > 0) call
load_dataset("HuggingFaceH4/mt_bench_prompts")["train"] into a local variable
(e.g., ds) before the is_master() check, then inside the is_master() block call
validate_ar(model=kwargs["model"], tokenizer=kwargs["processing_class"], ds=ds,
device=kwargs["model"].device) as before; keep print_rank_0, the try/except
around validate_ar, and the torch.distributed.barrier() after the block to
preserve synchronization.

---

Nitpick comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 241-249: The call to
load_dataset("HuggingFaceH4/mt_bench_prompts") is happening inside the
synchronized/critical section causing all ranks to wait; move the dataset
load/cache out of that section and reuse a single cached dataset reference when
calling validate_ar (keep calling validate_ar(model=kwargs["model"],
tokenizer=kwargs["processing_class"], ds=cached_mt_bench_ds,
device=kwargs["model"].device)), ensuring cached_mt_bench_ds is initialized once
(e.g., at module import or rank-0 setup) and shared by workers so only the heavy
hub/cache lookup happens once while still preserving use of state.global_step,
ars, and wandb logging.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 43b9f242-1583-40b4-8db9-7a3aeb9b07cd

📥 Commits

Reviewing files that changed from the base of the PR and between 07f4312 and f23a54a.

📒 Files selected for processing (1)
  • examples/speculative_decoding/eagle_utils.py

@ChenhanYu ChenhanYu requested a review from a team as a code owner March 31, 2026 02:57
@ChenhanYu ChenhanYu requested a review from shengliangxu March 31, 2026 02:57
@h-guo18

h-guo18 commented Mar 31, 2026

LGTM. Shall we also add some unit/example tests in this PR?

ChenhanYu and others added 17 commits March 31, 2026 18:39
Implement DFlash (Block Diffusion for Flash Speculative Decoding) as a
new mode in ModelOpt's speculative decoding framework.

Key architecture:
- Feature Fusion: extract hidden states from uniformly sampled target
  model layers, project via FC layer
- KV Injection: fused target features injected as K/V entries in every
  draft decoder layer's attention (not just first layer input)
- Parallel Drafting: all tokens in a block predicted simultaneously
  using learnable mask embeddings and bidirectional within-block attention

Files:
- dflash/ module: DFlashModel, DFlashConfig, conversion, default config
- plugins/hf_dflash.py: HFDFlashModel with DFlashAttention (KV injection),
  DFlashModule (feature fusion + decoder), training forward pass with
  random anchor sampling and exponential position decay loss
- main.py: --mode dflash support in training script

Reference: "DFlash: Block Diffusion for Flash Speculative Decoding"
(arXiv:2602.06036)

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Key fixes:
- mask_token_id now read from dflash_architecture_config (e.g., 248070
  for Qwen3) instead of defaulting to pad/eos token. Wrong mask_token_id
  caused garbage draft output despite correct weights.
- Inherit model config from base model only as defaults; allow draft to
  have different num_heads/intermediate_size (needed for z-lab checkpoint)
- Clean default_dflash_config to only contain DFlash-specific settings
- pseudo_speculative_generate returns single block of tokens
- Add dflash_mask_token_id CLI argument to main.py

Validated: z-lab/Qwen3.5-4B-DFlash checkpoint produces AR=7.28 (expected ~6.08)

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Resolution order:
1. Explicit in dflash_architecture_config (user override)
2. Auto-detect from model vocabulary:
   - Qwen3/3.5: built-in [MASK] token (e.g., 248070)
   - Llama3: reserved_special_token_0 (128002)
   - Others: pad_token_id fallback
3. CLI override via --dflash_mask_token_id

Based on z-lab checkpoints:
- z-lab/Qwen3.5-4B-DFlash: mask=248070
- z-lab/LLaMA3.1-8B-Instruct-DFlash: mask=128002
- z-lab/gpt-oss-20b-DFlash: mask=200000

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
AR validation runs pseudo_speculative_generate which does unsynchronized
model forward passes. In multi-GPU DDP training, this caused NCCL
timeout because other ranks were waiting at gradient sync.

Fix: only run validate_ar on rank 0 (is_master()), add
torch.distributed.barrier() after to synchronize all ranks.

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
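The rank-0-plus-barrier pattern from this fix can be sketched as (``validate_fn`` is an illustrative stand-in for validate_ar, not the PR's code):

```python
import torch.distributed as dist

def validate_on_master_only(model, validate_fn):
    """Run AR validation on rank 0 only, then synchronize all ranks.

    pseudo_speculative_generate does unsynchronized forward passes, so
    only the master rank validates while the other ranks wait at a
    barrier instead of hanging in DDP's gradient-sync collectives.
    """
    is_dist = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if is_dist else 0
    result = None
    if rank == 0:
        result = validate_fn(model)
    if is_dist:
        dist.barrier()  # re-align all ranks before training resumes
    return result
```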
super().forward() from HFDFlashModel goes through DynamicModule which
dispatches back to HFDFlashModel.forward(), causing infinite recursion
→ stack overflow → NCCL timeout in multi-GPU training.

Fix: use self._base_model() directly (same as pseudo_speculative_generate)
for both eval-mode and training base model forward passes.

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The DynamicModule MRO correctly dispatches super().forward() to the
original model class (e.g., Qwen3_5ForCausalLM.forward()) without
looping — same pattern EAGLE uses successfully.

The previous self._base_model() approach bypassed DDP, causing NCCL
timeout because DDP's gradient sync couldn't track the forward pass.

Keep pseudo_speculative_generate using self._base_model() since that
runs outside DDP (single GPU AR validation).

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
When a rank's batch has no valid loss positions (e.g., all tokens in
Block 0 which is excluded), the loss was a detached zero tensor with
no connection to dflash_module parameters. DDP waited forever for
gradient sync on those parameters → NCCL ALLREDUCE timeout.

Fix: use logits.sum() * 0.0 as zero loss, which maintains the
computation graph through dflash_module parameters so DDP can sync
zero gradients properly.

Also revert to super().forward() for training (matching EAGLE pattern)
and add --ddp_find_unused_parameters True, --ddp_timeout 300.

Root cause analysis: rank 4 completed ALLREDUCE #272 and proceeded to
ALLGATHER #273, while other ranks were stuck at ALLREDUCE #272. This
indicated rank 4 had a different backward graph (no gradients for
dflash_module on that rank).

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
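The graph-connected zero loss described in this commit can be sketched as:

```python
import torch

def zero_loss_with_graph(logits: torch.Tensor) -> torch.Tensor:
    """Zero loss that stays attached to the autograd graph.

    A detached torch.tensor(0.0, requires_grad=True) produces no
    gradients for the draft module, so DDP ranks disagree about the
    backward graph and stall in ALLREDUCE. Multiplying the live logits
    by 0.0 keeps every upstream parameter in the graph while
    contributing exactly zero gradient.
    """
    return logits.sum() * 0.0
```

After backward, every parameter that produced `logits` receives a zero gradient tensor rather than `None`, so DDP's allreduce buckets match across ranks.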
Add --dflash_use_logit_distillation flag that switches from hard CE
loss (predict ground truth tokens) to logit distillation (learn from
target model's output distribution).

Hard CE only works when training data is synthesized by the target
model itself. Logit distillation works with any data because it learns
from the target model's actual predictions, not the ground truth.

Usage:
  python main.py --mode dflash --dflash_use_logit_distillation ...

Config: dflash_self_logit_distillation (default=True in config,
toggled via CLI flag)

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Pass answer_only_loss=True to LanguageDataCollator for DFlash mode.
This makes the tokenizer return assistant_masks via apply_chat_template
with return_assistant_tokens_mask=True.

HFDFlashModel.forward() now checks for assistant_masks in kwargs and
uses it as loss_mask instead of attention_mask. This matches SpecForge's
behavior of only computing loss on response tokens.

SpecForge-trained checkpoint (response-only mask): AR=1.95
ModelOpt-trained checkpoint (all tokens mask): AR=1.15
Both with 30-35% training accuracy on same data.

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
When answer_only_loss=True, set labels=-100 for non-assistant tokens
using the assistant_masks from tokenizer.apply_chat_template. This
ensures DFlash forward() can derive response-only loss mask from
labels != -100, without relying on HF Trainer to pass assistant_masks.

Also revert hf_dflash.py to use labels-based loss mask instead of
kwargs-based assistant_masks (Trainer strips unknown keys).

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
When answer_only_loss=True and the tokenizer's return_assistant_tokens_mask
returns empty/unsupported results, fall back to regex-based detection of
assistant spans in the formatted text (similar to SpecForge's approach).

Supports Qwen/ChatML, Llama3, Llama2, and generic assistant patterns.
Uses tokenizer offset_mapping to map character spans to token positions.

DFlash forward uses labels != -100 to derive the response-only loss mask.

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
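The labels-based mask derivation from this commit amounts to a one-liner (a sketch, not the PR's exact code):

```python
import torch

IGNORE_INDEX = -100  # Hugging Face convention for masked label positions

def loss_mask_from_labels(labels: torch.Tensor) -> torch.Tensor:
    """Derive the response-only loss mask from labels.

    With answer_only_loss, non-assistant tokens were set to -100 by the
    collator, so the forward pass recovers the mask as labels != -100
    without relying on an assistant_masks kwarg (which the HF Trainer
    strips as an unknown key).
    """
    return labels != IGNORE_INDEX
```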
Documents DFlash architecture, training usage, mask_token_id auto-detection,
and current status including the known AR gap from data pipeline differences.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Instead of hardcoding Llama components (LlamaMLP, LlamaRMSNorm,
LlamaRotaryEmbedding), dynamically resolve them from the base model's
transformers module (e.g., Qwen3MLP for Qwen3 models). Falls back
to Llama components for unknown model types.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Two bugs prevented response-only masking from working:

1. main.py never passed answer_only_loss=True to the data collator
   for DFlash mode, so all tokens had labels (511/512 instead of
   response-only).

2. HFDFlashModel.forward() used attention_mask (padding mask) for
   loss masking instead of labels. When answer_only_loss is enabled,
   the response-only information is in labels (where -100 = ignore),
   but this was completely ignored. Now uses labels when available.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
- Add common/dflash/online_training.sh for launcher
- Add examples/Qwen/Qwen3-8B/hf_online_dflash.yaml
- Add --mode dflash support to launch_train.sh with DFlash-specific
  args (block_size, num_layers, mask_token_id, config)
- DFlash uses DDP instead of FSDP for training

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
ChenhanYu and others added 16 commits April 3, 2026 09:43
Previously kept all tokens as valid when regex couldn't find
assistant spans, causing training on system/user tokens and
inflated per-token accuracy (70% vs SpecForge's 15%). Now masks
everything, so these samples contribute zero loss.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Replace complex chat templates with simplified versions that include
{% generation %} tags when answer_only_loss=True. Supports ChatML
(Qwen, Phi) and Llama3 template styles.

This fixes the inflated per-token accuracy (70% vs SpecForge's 15%)
caused by the regex fallback silently training on system/user tokens
when {% generation %} tags were missing.

The simplified templates correctly:
- Mark only assistant content for loss (including <think> blocks)
- Support multi-turn conversations
- Mask system and user tokens

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Qwen3's original template auto-injects <think>\n\n</think>\n\n before
assistant content. Match this in our simplified template by adding
the think wrapper when content doesn't already start with <think>.

Minor difference from original: we add it to all assistant turns,
while Qwen3 only adds to the last turn. This doesn't affect training.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The original Qwen3 template adds <think>\n\n</think>\n\n only to the
last assistant turn, not all turns. Rather than replicating this complex
logic, keep the simplified template clean — just output message content
as-is. Training data already contains <think> blocks when present.
Llama3 template has no think logic at all.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Qwen3's original template adds <think>\n\n</think>\n\n to the last
assistant turn when content doesn't start with <think>. Detect this
by checking if '<think>' appears in the original template and use
the chatml_think variant which replicates this behavior exactly.

Models without think logic (Llama3, basic ChatML) use the plain
chatml template. All three samples now match the original tokenization.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Clearly document what the simplified chat templates do, what is
preserved vs dropped, and limitations for tool-use and multi-step
reasoning data. Also document how to use a custom template.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The early-exit zero-loss path used target_hidden (computed under
no_grad) which has no gradient graph, causing 'does not require
grad' error. Use dflash_module.fc.weight instead to keep DDP happy.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The Speculative-Decoding dataset has both 'messages' (prompt only)
and 'conversations' (prompt + response) fields. The collator took
'messages' first, missing the assistant response entirely. Now
checks if 'messages' has an assistant turn, otherwise falls back
to 'conversations'.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
… collator

- Remove debug prints in hf_dflash.py (_get_attn_fn and modify) while
  keeping informational prints (mask_token_id and base forward)
- Add "Legacy: used for inference only" comment on create_dflash_attention_mask
  and create_dflash_loss_mask
- Remove _apply_answer_only_labels regex fallback in transformers_dataset.py;
  raise ValueError when assistant_masks is missing/empty
- Add validation for missing assistant turns in __call__
- Make _ensure_generation_tags warnings more prominent with === WARNING === prefix

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The AR validation callback ran model.forward() on rank 0 only, but
DDP model forward triggers collective ops that require all ranks.
Now unwraps the DDP model (model.module) before validation, so
forward runs without collective hooks. Other ranks wait at barrier.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
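The unwrap step can be sketched as below; the helper name is illustrative, but `model.module` is the standard DistributedDataParallel attribute for the wrapped module.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap_ddp(model: nn.Module) -> nn.Module:
    """Return the underlying module when wrapped in DDP, so rank-0-only
    validation forwards run without triggering collective hooks."""
    return model.module if isinstance(model, DDP) else model
```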
During distributed training, a single bad sample would crash all
ranks. Now warns and skips samples without assistant turns. Also
handle the case where all assistant content is truncated by masking
all labels instead of raising ValueError.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
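The warn-and-skip behavior can be sketched like this (function name and return convention are hypothetical; returning None stands in for dropping the sample):

```python
import warnings

def safe_prepare(sample: dict):
    """Skip samples without an assistant turn instead of raising, so one
    bad sample cannot crash all ranks during distributed training."""
    turns = sample.get("messages") or sample.get("conversations") or []
    if not any(t.get("role") == "assistant" for t in turns):
        warnings.warn("Skipping sample without assistant turn")
        return None  # caller drops the sample
    return turns
```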
AR validation with DDP is fundamentally incompatible:
pseudo_speculative_generate runs inference on rank 0 while other
ranks deadlock on collective ops. Now detects world_size > 1 and
skips with a one-time warning. AR validation still works for
single-GPU and post-training (online_training.sh).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
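The detect-and-skip guard can be sketched as a small wrapper with a module-level one-time-warning flag (all names here are illustrative, not the actual callback code):

```python
import warnings

_AR_SKIP_WARNED = False  # ensures the warning is emitted only once

def maybe_run_ar_validation(world_size: int, run_fn):
    """Run AR validation only in single-process settings; under DDP
    (world_size > 1) skip it, warning once, to avoid deadlocking ranks."""
    global _AR_SKIP_WARNED
    if world_size > 1:
        if not _AR_SKIP_WARNED:
            warnings.warn(
                "AR validation skipped: incompatible with DDP (world_size > 1)"
            )
            _AR_SKIP_WARNED = True
        return None
    return run_fn()
```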
@ChenhanYu ChenhanYu requested a review from a team as a code owner April 4, 2026 03:14
@ChenhanYu ChenhanYu requested a review from sugunav14 April 4, 2026 03:14
ChenhanYu and others added 10 commits April 3, 2026 20:18
- Fix resume: add device_map='cpu' to checkpoint loading path to
  avoid meta tensor errors
- Add export step to online_training.sh: after training, export
  DFlash checkpoint to z-lab HF format, then validate AR on the
  exported checkpoint
- AR validation prefers exported checkpoint (no prefix) over
  training checkpoint (with prefix)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Try scontrol first, then parse SLURM_JOB_NODELIST directly and
resolve via getent hosts. Works both inside and outside containers.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Try scontrol, SLURM_LAUNCH_NODE_IPADDR, Python socket resolution,
and hostname -I as fallbacks. Should work inside containers where
scontrol is not available.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
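The Python-side portion of that fallback chain might look like the sketch below. SLURM_LAUNCH_NODE_IPADDR and SLURM_JOB_NODELIST are real SLURM environment variables; the function name and the single-node-only nodelist handling are simplifications (the actual script also tries scontrol and `hostname -I`).

```python
import os
import socket

def resolve_master_addr() -> str:
    """Resolve the rendezvous address inside containers where scontrol
    may be unavailable, falling back through env vars to local DNS."""
    addr = os.environ.get("SLURM_LAUNCH_NODE_IPADDR")
    if addr:
        return addr
    nodelist = os.environ.get("SLURM_JOB_NODELIST")
    if nodelist and "[" not in nodelist:  # skip compressed ranges, for brevity
        try:
            return socket.gethostbyname(nodelist.split(",")[0])
        except OSError:
            pass
    return socket.gethostbyname(socket.gethostname())  # last resort
```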
DFlash uses DDP, not FSDP. The default dp_shard_size=TOTAL_GPU
caused FSDP-style sharding. Force to 1 for pure DDP replication.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
DFlash defaults to DDP (dp_shard_size=1). Pass --fsdp True to use
FSDP with full_shard instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
load_vlm_or_llm uses meta tensors internally. Use
AutoModelForCausalLM.from_pretrained with low_cpu_mem_usage=False
to avoid meta tensor errors during export.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Same script used by EAGLE3 export. Avoids custom loading logic
that caused meta tensor errors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Loading from checkpoint subdirectory (e.g., checkpoint-12500/) causes
meta tensor errors with transformers 5.x. Load from output_dir
(top-level save) instead, which works. The checkpoint path is still
passed to trainer.train(resume_from_checkpoint=...) for optimizer
and step count resume.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
"""Forward matching SpecForge DFlashDraftModel.forward."""
hidden_states = noise_embedding
target_hidden = self.hidden_norm(self.fc(target_hidden))
position_embeddings = self.rotary_emb(hidden_states, position_ids)
Suggested change:
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        # lazy init rope
+        if not hasattr(self, "rotary_emb"):
+            self.rotary_emb = _ROTARY_CLS(config=config, device=hidden_states.device)
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)

[DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = _ROTARY_CLS(config=config)
Suggested change (delete the eager init):
-        self.rotary_emb = _ROTARY_CLS(config=config)

Lazy init as below to avoid the meta tensor copy error.
