add: DFlash block diffusion speculative decoding #1128
Conversation
Note: Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in your review settings.
📝 Walkthrough

Adds DFlash speculative decoding: a new DFlash config and defaults, conversion/restore utilities and a registry, a DFlash base model class, and a HuggingFace plugin implementing the training and generation logic; integrates a new "dflash" mode with example and training adjustments.
Sequence Diagram(s)

sequenceDiagram
actor User
participant BaseModel as Base LLM
participant HiddenCollector as Hidden State<br/>Collector
participant Fusion as Feature Fusion<br/>(FC + RMSNorm)
participant DraftModule as DFlash Draft<br/>Module (decoder stack)
participant LMHead as Logit Head
participant Loss as Loss & Accuracy
User->>BaseModel: input_ids + attention_mask
BaseModel->>HiddenCollector: forward -> hidden states
HiddenCollector->>Fusion: collect target-layer states
Fusion->>DraftModule: fused targets + noise embeddings
DraftModule->>LMHead: draft hidden -> logits
LMHead->>Loss: compute CE / accuracy
Loss-->>User: loss + accuracy
rect rgba(100,150,200,0.5)
Note over BaseModel,Loss: DFlash training forward pass
end
sequenceDiagram
actor User
participant BaseModel as Base LLM
participant DFlashMod as DFlash Module
participant BlockBuilder as Block Builder
participant DraftDecoder as Draft Decoder
participant TokenSel as Token Selector
User->>BaseModel: input_ids (context)
loop for each generation step
BaseModel->>DFlashMod: base next-token anchor + hidden states
DFlashMod->>DFlashMod: fuse target layers
loop for each block position
BlockBuilder->>DraftDecoder: build noise (anchor + mask), run draft step
DraftDecoder->>TokenSel: draft logits -> argmax
end
DraftDecoder-->>User: base token + draft block
end
rect rgba(150,200,100,0.5)
Note over BaseModel,TokenSel: DFlash pseudo-speculative generation
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Comment
Actionable comments posted: 2
🧹 Nitpick comments (4)
modelopt/torch/speculative/dflash/dflash_model.py (1)
27-34: Add type hint for `config` parameter.

The `config` parameter lacks a type annotation. Per project standards, type hints should be provided for static type checking with mypy.

♻️ Proposed fix

+from ..config import DFlashConfig
+
+
 class DFlashModel(DynamicModule):
     """Base DFlash Model."""

     def _setup(self):
         self._register_temp_attribute("dflash_module", None)

-    def modify(self, config):
+    def modify(self, config: DFlashConfig):
         """Base DFlash Model modify function. Child class should implement the details."""

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/dflash/dflash_model.py` around lines 27 - 34, The modify method's config parameter is missing a type annotation; update def modify(self, config) to include the appropriate config type (e.g., def modify(self, config: DFlashConfig)) and import that type at the top of the module from wherever the project's config dataclass/typing lives; if a concrete config type isn't available yet, annotate with typing.Any as a temporary fallback and add the proper import for Any. Ensure you update imports (from typing import Any or from <module> import DFlashConfig) and keep the existing attribute assignments in modify unchanged.

modelopt/torch/speculative/plugins/hf_dflash.py (3)
320-327: Silent exception handling may hide initialization errors.

The broad `except Exception: continue` pattern silently ignores all errors when locating base model parts, potentially masking genuine issues like attribute errors or type mismatches.

♻️ Proposed improvement to catch only expected exceptions

 for path in paths:
     try:
         submodule = self.get_submodule(path)
         assert isinstance(submodule, torch.nn.Module)
         setattr(self, name, path)
         break
-    except Exception:
+    except (AttributeError, AssertionError):
         continue

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 320 - 327, The loop that tries to locate submodules uses a broad "except Exception: continue", which can hide real errors; change it to only catch expected exceptions (e.g., except (AttributeError, AssertionError, TypeError): continue) so intentional lookup failures are ignored but other unexpected exceptions surface (or are logged/re-raised); keep the calls to self.get_submodule and setattr(self, name, path) intact but ensure you handle and/or log unexpected exceptions rather than swallowing them silently.
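To see why the narrow `except` matters, here is a small self-contained sketch (plain Python; `find_submodule_path` and the dict-based lookup are hypothetical stand-ins for the `get_submodule` loop): lookup misses are skipped, but a genuine bug still surfaces.

```python
def find_submodule_path(root: dict, paths: list) -> str:
    """Return the first path that resolves in `root`, skipping only lookup misses."""
    for path in paths:
        try:
            obj = root
            for part in path.split("."):
                obj = obj[part]  # raises KeyError when the path does not exist
            return path
        except KeyError:
            # Expected failure mode: this candidate path is absent, try the next one.
            continue
    raise ValueError("no candidate path resolved")


model = {"model": {"layers": {}}}

# Lookup misses are skipped; the first valid path wins.
assert find_submodule_path(model, ["transformer.h", "model.layers"]) == "model.layers"

# A genuine bug (wrong element type) is NOT swallowed by the narrow except.
try:
    find_submodule_path(model, [None])  # None.split(...) -> AttributeError
    surfaced = False
except AttributeError:
    surfaced = True
assert surfaced
```

With a bare `except Exception`, the `AttributeError` above would be silently skipped and the function would raise the misleading `ValueError` instead, which is exactly the failure mode the comment warns about.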
471-474: Complex anchor sampling logic is hard to follow.

The nested `max()`, `min()`, and `range()` calls make this expression difficult to reason about. Consider breaking it into intermediate variables for clarity and easier debugging.

♻️ Suggested refactor for readability

-    num_blocks = max(1, max_anchor // block_size)
-    # Sample anchor positions uniformly
-    anchors = sorted(
-        random.sample(range(1, max(2, max_anchor)), min(num_blocks, max(1, max_anchor - 1)))
-    )
+    num_blocks = max(1, max_anchor // block_size)
+    # Sample anchor positions uniformly
+    sample_range_end = max(2, max_anchor)
+    sample_count = min(num_blocks, max(1, max_anchor - 1))
+    anchors = sorted(random.sample(range(1, sample_range_end), sample_count))

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 471 - 474, The anchor sampling expression assigned to anchors is hard to read; break it into named intermediate variables (e.g., compute max_anchor_bound = max(2, max_anchor), sample_upper = max(1, max_anchor - 1), num_to_sample = min(num_blocks, sample_upper), and the range_to_sample = range(1, max_anchor_bound)) and then call random.sample(range_to_sample, num_to_sample) and sort the result; update the code in hf_dflash.py where anchors is defined (the anchors variable in the anchor sampling block) to use these intermediate names for clarity and easier debugging.
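The refactor is behavior-preserving: with the same RNG state, the one-liner and the named-intermediate version draw identical anchors. A standalone check (the two wrapper functions are hypothetical; only the sampling expressions mirror the diff above):

```python
import random


def sample_anchors_original(max_anchor: int, block_size: int) -> list:
    num_blocks = max(1, max_anchor // block_size)
    return sorted(
        random.sample(range(1, max(2, max_anchor)), min(num_blocks, max(1, max_anchor - 1)))
    )


def sample_anchors_refactored(max_anchor: int, block_size: int) -> list:
    num_blocks = max(1, max_anchor // block_size)
    sample_range_end = max(2, max_anchor)  # exclusive upper bound of candidate positions
    sample_count = min(num_blocks, max(1, max_anchor - 1))  # never exceeds available positions
    return sorted(random.sample(range(1, sample_range_end), sample_count))


random.seed(0)
a = sample_anchors_original(100, 16)
random.seed(0)
b = sample_anchors_refactored(100, 16)
assert a == b and len(a) == 6  # 100 // 16 = 6 blocks sampled
```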
346-347: Accessing private `_attn_implementation` attribute is fragile.

`_attn_implementation` is a private attribute of `PretrainedConfig` that may change between transformers versions without notice.

♻️ Suggested defensive check

-    if self.dflash_config._attn_implementation is None:
+    if getattr(self.dflash_config, "_attn_implementation", None) is None:
         self.dflash_config._attn_implementation = "sdpa"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 346 - 347, Replace the fragile direct access to the private attribute self.dflash_config._attn_implementation with a defensive check using getattr/hasattr and set via setattr (e.g., current = getattr(self.dflash_config, "attn_implementation", None) or fallback = getattr(self.dflash_config, "_attn_implementation", None); if current is None: setattr(self.dflash_config, "attn_implementation", "sdpa") ), so the code uses public names when present and only falls back to the underscore name if necessary, ensuring compatibility across transformers versions.
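The `getattr` guard also degrades gracefully on config objects that lack the private field entirely, where direct attribute access would raise. A minimal sketch with a stand-in config class (`DummyConfig` and `ensure_attn_implementation` are hypothetical, not transformers APIs):

```python
class DummyConfig:
    """Stand-in for a config object that may or may not define the private field."""


def ensure_attn_implementation(cfg, default: str = "sdpa") -> str:
    # Safe even when the attribute is missing entirely, not just None.
    if getattr(cfg, "_attn_implementation", None) is None:
        cfg._attn_implementation = default
    return cfg._attn_implementation


cfg = DummyConfig()  # attribute absent: cfg._attn_implementation would raise AttributeError
assert ensure_attn_implementation(cfg) == "sdpa"

cfg2 = DummyConfig()
cfg2._attn_implementation = "flash_attention_2"
assert ensure_attn_implementation(cfg2) == "flash_attention_2"  # existing value is kept
```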
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 368-375: The modify() method currently registers forward hooks
each time without removing prior hooks, causing duplicated collections; fix by
tracking hook handles and removing old hooks before registering new ones: add a
persistent attribute (e.g., self._registered_forward_hooks = []) initialized in
the class (constructor), at the start of modify() iterate over
self._registered_forward_hooks calling handle.remove() and then clear the list,
then when registering forward hooks on layers (where you call
layer.register_forward_hook(self._collect_hidden_hook)) capture each returned
handle and append it to self._registered_forward_hooks; also ensure you reset
self._target_hidden_states (and optionally self._cached_masks) when re-modifying
to avoid stale state.
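The hook-lifecycle fix described for `modify()` (track handles, remove stale ones before re-registering) can be sketched without torch. `TinyLayer`, `Handle`, and `Patcher` below are hypothetical stand-ins, but the lifecycle mirrors `register_forward_hook`, which returns a removable handle:

```python
class Handle:
    """Removable registration handle, analogous to torch's RemovableHandle."""

    def __init__(self, hooks: list, fn):
        self._hooks, self._fn = hooks, fn

    def remove(self):
        self._hooks.remove(self._fn)


class TinyLayer:
    def __init__(self):
        self._hooks = []

    def register_forward_hook(self, fn) -> Handle:
        self._hooks.append(fn)
        return Handle(self._hooks, fn)


class Patcher:
    def __init__(self, layers):
        self.layers = layers
        self._registered_forward_hooks = []

    def modify(self, hook):
        # Remove stale hooks from any previous modify() call first.
        for h in self._registered_forward_hooks:
            h.remove()
        self._registered_forward_hooks.clear()
        for layer in self.layers:
            self._registered_forward_hooks.append(layer.register_forward_hook(hook))


layers = [TinyLayer(), TinyLayer()]
p = Patcher(layers)
p.modify(lambda *_: None)
p.modify(lambda *_: None)  # re-modify: no duplicate hooks accumulate
assert all(len(layer._hooks) == 1 for layer in layers)
```

Without the removal step, each `modify()` call would leave one more hook per layer, so hidden states would be collected multiple times per forward pass.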
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 99f4f5f4-7d81-4e34-b82f-d958548f8d6d
📒 Files selected for processing (9)
- examples/speculative_decoding/main.py
- modelopt/torch/speculative/config.py
- modelopt/torch/speculative/dflash/__init__.py
- modelopt/torch/speculative/dflash/conversion.py
- modelopt/torch/speculative/dflash/default_config.py
- modelopt/torch/speculative/dflash/dflash_model.py
- modelopt/torch/speculative/mode.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/plugins/hf_dflash.py
def build_target_layer_ids(num_target_layers, num_sample_layers):
    """Select layers uniformly from the target model for feature extraction."""
    if num_sample_layers == 1:
        return [num_target_layers // 2]
    start = 1
    end = num_target_layers - 3
    span = end - start
    return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)]
Edge case: small num_target_layers values produce degenerate sampling.
When num_target_layers <= 4, end <= start (e.g., num_target_layers=4 → start=1, end=1, span=0). All returned layer indices collapse to the same value, defeating uniform sampling. Consider adding a guard or adjusting the formula for shallow target models.
🛡️ Proposed fix to handle edge case
def build_target_layer_ids(num_target_layers, num_sample_layers):
"""Select layers uniformly from the target model for feature extraction."""
if num_sample_layers == 1:
return [num_target_layers // 2]
+    if num_target_layers <= 4:
+        # For very shallow models, sample from all available layers
+        start, end = 0, num_target_layers - 1
+    else:
+        start = 1
+        end = num_target_layers - 3
-    start = 1
-    end = num_target_layers - 3
     span = end - start
+    if span <= 0:
+        return [start] * num_sample_layers
     return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 55 - 62, The
build_target_layer_ids function collapses to the same index when
num_target_layers <= 4 because start=1 and end=num_target_layers-3 produce span
<= 0; change the logic to handle shallow models: if num_target_layers <= 4 (or
if num_sample_layers >= num_target_layers) return a set of valid unique layer
indices (e.g., evenly spaced or simply range(0, num_target_layers) trimmed to
num_sample_layers) and otherwise compute evenly spaced indices across [0,
num_target_layers-1]; update build_target_layer_ids to clamp/adjust start/end
and to limit num_sample_layers to at most num_target_layers so indices remain
unique and within bounds.
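The collapse is easy to reproduce with the function exactly as posted in the hunk above; the deep-model call shows the intended spread:

```python
def build_target_layer_ids(num_target_layers, num_sample_layers):
    """Select layers uniformly from the target model for feature extraction."""
    if num_sample_layers == 1:
        return [num_target_layers // 2]
    start = 1
    end = num_target_layers - 3
    span = end - start
    return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)]


# Deep model: indices spread across the stack as intended.
assert build_target_layer_ids(32, 4) == [1, 10, 20, 29]

# Shallow model: start=1, end=1, span=0 -> every index collapses to 1.
assert build_target_layer_ids(4, 3) == [1, 1, 1]
```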
Codecov Report

❌ Patch coverage is
Additional details and impacted files

@@           Coverage Diff            @@
## main #1128 +/- ##
===========================================
- Coverage 70.18% 54.59% -15.59%
===========================================
Files 230 355 +125
Lines 26080 40579 +14499
===========================================
+ Hits 18304 22155 +3851
- Misses 7776 18424 +10648
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Comment
Actionable comments posted: 1
🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/hf_dflash.py (4)
307-309: Device placement assumes `_base_model.layers` exists and is non-empty.

If the base model has a different structure (e.g., no `layers` attribute or empty layers), this will raise an `AttributeError` or `IndexError`.

🛡️ More robust device detection

 self.dflash_module = DFlashModule(self.dflash_config)
-self.dflash_module.to(self._base_model.dtype).to(
-    next(self._base_model.layers[-1].parameters()).device
-)
+# Get device from any base model parameter
+base_device = next(self._base_model.parameters()).device
+base_dtype = next(self._base_model.parameters()).dtype
+self.dflash_module.to(base_dtype).to(base_device)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 307 - 309, The placement code for self.dflash_module assumes self._base_model.layers exists and has elements; instead, detect device robustly by checking for a non-empty self._base_model.layers and using next(self._base_model.layers[-1].parameters()).device only when present, otherwise fall back to using next(self._base_model.parameters()).device (or a CPU default if parameters are absent); apply the chosen device and the target dtype (self._base_model.dtype) to self.dflash_module so device/dtype setting works for models without a layers attribute or with empty layers.
82-82: Unused attribute `is_causal`.

The `is_causal` attribute is set to `False` but never referenced. The value is hardcoded directly in the `scaled_dot_product_attention` call at line 129.

♻️ Suggested cleanup

 self.num_kv_heads = config.num_key_value_heads
 self.scaling = self.head_dim**-0.5
-self.is_causal = False

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 82, The instance attribute self.is_causal is defined but never used; either remove its assignment or wire it into the attention call — replace the hardcoded False in the scaled_dot_product_attention invocation with self.is_causal (or delete the self.is_causal assignment if you prefer no flag). Update the attribute in the class initializer where self.is_causal is set and the scaled_dot_product_attention call at the line that currently passes False so they are consistent.
412-414: Zero loss tensor with `requires_grad=True` won't propagate gradients.

When `active_logits.numel() == 0`, the returned loss tensor is a constant `0.0` with `requires_grad=True`. While this avoids errors during the backward pass, it produces a gradient of zero for all parameters. Consider logging a warning when this edge case occurs.

💡 Add warning for visibility

+import logging
+logger = logging.getLogger(__name__)
+
 if active_logits.numel() > 0:
     loss = F.cross_entropy(active_logits, active_labels)
     with torch.no_grad():
         preds = active_logits.argmax(dim=-1)
         accuracy = (preds == active_labels).float().mean().item()
 else:
+    logger.warning("No active positions for loss computation")
     loss = torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)
     accuracy = 0.0

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 412 - 414, When active_logits.numel() == 0 the code returns loss = torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True) which yields zero gradients; update the branch handling this case to also emit a clear warning so the condition is visible in logs (e.g., log that active_logits is empty and loss is a zero tensor) and keep returning the tensor and accuracy as currently done. Locate the branch using active_logits.numel() == 0 (the else that sets loss and accuracy), add a warning via the module/logger used in this file (e.g., logger.warning or process_logger) mentioning the function/context (hf_dflash loss computation) and the tensor/device/dtype, and ensure no change to the returned tensor shape or requires_grad behavior.
285-286: Accessing private HuggingFace attribute `_attn_implementation`.

`_attn_implementation` is a private/internal attribute in HuggingFace configs that may change without notice. Consider using the public API or documenting this dependency.

♻️ More defensive approach

-    if self.dflash_config._attn_implementation is None:
-        self.dflash_config._attn_implementation = "eager"
+    if getattr(self.dflash_config, "_attn_implementation", None) is None:
+        self.dflash_config._attn_implementation = "eager"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 285 - 286, The code currently writes directly to the private HuggingFace attribute self.dflash_config._attn_implementation; instead, make this defensive by checking for the attribute's existence with getattr/hasattr and only set it when present, otherwise fall back to a documented public config or internal default and log/warn about relying on a private attribute; update the assignment around self.dflash_config._attn_implementation to use a safe check (getattr(self.dflash_config, "_attn_implementation", None)) and a clear fallback path so the plugin won't break if the private field is removed in future.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 3311c2cf-da29-4e8f-bbc0-c97039583962
📒 Files selected for processing (1)
modelopt/torch/speculative/plugins/hf_dflash.py
def build_target_layer_ids(num_target_layers, num_draft_layers):
    """Select layers uniformly from the target model for feature extraction."""
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    start = 1
    end = num_target_layers - 3
    span = end - start
    return [
        int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers)
    ]
Fix linting error and handle edge case for shallow target models.
Two issues:

- Pipeline failure (Line 55): `round()` already returns an `int` in Python 3, making the `int()` wrapper redundant.
- Edge case: When `num_target_layers <= 4`, `span` becomes ≤ 0 (e.g., `num_target_layers=4` → `start=1`, `end=1`, `span=0`), causing all returned indices to collapse to the same value.
🐛 Proposed fix
def build_target_layer_ids(num_target_layers, num_draft_layers):
"""Select layers uniformly from the target model for feature extraction."""
if num_draft_layers == 1:
return [num_target_layers // 2]
-    start = 1
-    end = num_target_layers - 3
+    # For shallow models, use full range; otherwise skip first and last few layers
+    if num_target_layers <= 4:
+        start, end = 0, num_target_layers - 1
+    else:
+        start, end = 1, num_target_layers - 3
     span = end - start
+    if span <= 0 or num_draft_layers > num_target_layers:
+        # Fallback: return middle layer repeated
+        return [num_target_layers // 2] * num_draft_layers
     return [
-        int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers)
+        round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)
     ]

🧰 Tools
🪛 GitHub Actions: Code Quality
[error] 55-55: ruff check failed (RUF046). Value being cast to int is already an integer. Line suggests int(round(...)); help: remove unnecessary int call.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 47 - 56, The
function build_target_layer_ids should avoid the redundant int(round(...)) and
handle shallow models where span <= 0; change the rounding to just round(...)
(or remove the int() wrapper) and add a guard when num_target_layers <= 4 (or
when end <= start) to produce sensible unique indices: compute end = max(1,
num_target_layers - 3), compute span = max(1, end - start), and if
num_draft_layers > (end - start + 1) clamp num_draft_layers to that maximum so
the list comprehension in build_target_layer_ids returns evenly spaced,
non-duplicated layer indices (use the existing variables num_target_layers,
num_draft_layers, start, end, span in the fix).
Comment
Actionable comments posted: 1
🧹 Nitpick comments (3)
modelopt/torch/speculative/plugins/hf_dflash.py (3)
364-364: Replace `print()` with logging.

Production code should use the logging module for better control over output verbosity.

♻️ Proposed fix

Add import at top of file:

import logging

logger = logging.getLogger(__name__)

Then replace the print:

-print(f"DFlash: using {original_cls.__name__}.forward as base forward")
+logger.info(f"DFlash: using {original_cls.__name__}.forward as base forward")

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 364, Replace the stray print call with structured logging: add an import for the logging module and initialize a module-level logger (e.g., logger = logging.getLogger(__name__)), then replace the print(f"DFlash: using {original_cls.__name__}.forward as base forward") in hf_dflash.py with an appropriate logger call (logger.info or logger.debug) referencing original_cls.__name__ to retain the same message content; ensure the logger is defined at top of the file so the statement in the speculative plugin uses it.
260-269: Overly broad exception handling.

Catching `Exception` may mask unexpected errors (e.g., `AttributeError`, `TypeError`). Consider narrowing to the specific exceptions expected from `get_submodule`.

♻️ Proposed fix

 for path in paths:
     try:
         submodule = self.get_submodule(path)
         assert isinstance(submodule, torch.nn.Module)
         setattr(self, name, path)
         break
-    except Exception:
+    except (AttributeError, AssertionError):
         continue

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 260 - 269, Narrow the broad except in the for-loop that calls get_submodule: replace "except Exception" with a targeted exception tuple such as "except (AttributeError, KeyError):" (these are the likely errors when a submodule path is not found) so only lookup-related failures are silenced while letting other errors (TypeError, AssertionError, etc.) surface; keep the rest of the logic in the loop (get_submodule, isinstance check, setattr(self, name, path) and the final ValueError) unchanged.
82-82: Unused attribute `is_causal`.

`self.is_causal` is assigned but never referenced. Line 129 hardcodes `is_causal=False` directly.

♻️ Proposed fix

-self.is_causal = False

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` at line 82, The attribute self.is_causal is assigned but never used; replace the hardcoded is_causal=False occurrence in this module with the instance attribute so the class-level flag is honored: initialize self.is_causal in the class constructor (as currently present) and change the call/site that currently passes is_causal=False to pass is_causal=self.is_causal instead; alternatively, if causality is never intended to be configurable, remove the unused self.is_causal assignment and keep the hardcoded False—prefer the first option to preserve configurability.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 510-521: The debug print block guarded by self._psg_debug should
be removed or replaced with proper logging and a safe batch-aware access; locate
the block around _psg_debug that inspects base_outputs.hidden_states,
base_token, mask_token_id, dflash_block_size and calls
self._base_model_embeddings and either delete it or change prints to
logger.debug(...) and replace the unsafe base_token.item() with a batch-safe
access such as base_token[0,0].item() so it won't fail when bsz > 1; keep the
one-time flag behavior if you want this to run only once.
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Line 364: Replace the stray print call with structured logging: add an import
for the logging module and initialize a module-level logger (e.g., logger =
logging.getLogger(__name__)), then replace the print(f"DFlash: using
{original_cls.__name__}.forward as base forward") in hf_dflash.py with an
appropriate logger call (logger.info or logger.debug) referencing
original_cls.__name__ to retain the same message content; ensure the logger is
defined at top of the file so the statement in the speculative plugin uses it.
- Around line 260-269: Narrow the broad except in the for-loop that calls
get_submodule: replace "except Exception" with a targeted exception tuple such
as "except (AttributeError, KeyError):" (these are the likely errors when a
submodule path is not found) so only lookup-related failures are silenced while
letting other errors (TypeError, AssertionError, etc.) surface; keep the rest of
the logic in the loop (get_submodule, isinstance check, setattr(self, name,
path) and the final ValueError) unchanged.
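The narrowed exception handling can be sketched against a stand-in model (all names here are hypothetical, not the plugin's actual ones):

```python
class TinyModel:
    """Stand-in exposing the get_submodule contract used in the loop."""

    def __init__(self):
        self.embed_tokens = "embedding-module"

    def get_submodule(self, path):
        node = self
        for part in path.split("."):
            node = getattr(node, part)  # AttributeError when the path is missing
        return node


def resolve_submodule(model, candidate_paths):
    for path in candidate_paths:
        try:
            return path, model.get_submodule(path)
        except (AttributeError, KeyError):  # silence only lookup failures
            continue
    raise ValueError(f"no submodule found among {candidate_paths}")
```

A TypeError or AssertionError raised inside `get_submodule` now propagates instead of being swallowed.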
- Line 82: The attribute self.is_causal is assigned but never used; replace the
hardcoded is_causal=False occurrence in this module with the instance attribute
so the class-level flag is honored: initialize self.is_causal in the class
constructor (as currently present) and change the call/site that currently
passes is_causal=False to pass is_causal=self.is_causal instead; alternatively,
if causality is never intended to be configurable, remove the unused
self.is_causal assignment and keep the hardcoded False—prefer the first option
to preserve configurability.
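A minimal sketch of the preferred option, with a stub attention function standing in for the real kernel:

```python
class DFlashAttentionSketch:
    """Honor the instance-level causality flag at the attention call site."""

    def __init__(self, is_causal=False):
        self.is_causal = is_causal

    def forward(self, q, k, v, attn_fn):
        # previously hardcoded is_causal=False; now the constructor flag is used
        return attn_fn(q, k, v, is_causal=self.is_causal)
```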
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 90e791a0-d745-4cf9-8a9d-1600e8d1ad34
📒 Files selected for processing (3)
- examples/speculative_decoding/main.py
- modelopt/torch/speculative/dflash/default_config.py
- modelopt/torch/speculative/plugins/hf_dflash.py
✅ Files skipped from review due to trivial changes (1)
- modelopt/torch/speculative/dflash/default_config.py
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/speculative_decoding/main.py
```python
if not hasattr(self, '_psg_debug'):
    self._psg_debug = True
    sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids]
    th = torch.cat(sel, dim=-1)
    print(f"[psg] hidden_states layers: {len(base_outputs.hidden_states)}, target_hidden norm: {th.norm().item():.2f}, shape: {th.shape}")
    print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
    print(f"[psg] block_ids: {[self.mask_token_id]*self.dflash_block_size}")
    bi = torch.full((1, self.dflash_block_size), self.mask_token_id, dtype=torch.long, device=input_ids.device)
    bi[0, 0] = base_token[0, 0]
    ne = self._base_model_embeddings(bi)
    print(f"[psg] noise_emb norm: {ne.norm().item():.2f}, shape: {ne.shape}")
    print(f"[psg] pos_ids will be: ctx=[0..{input_ids.shape[1]-1}], blk=[{input_ids.shape[1]}..{input_ids.shape[1]+self.dflash_block_size-1}]")
```
Remove debug print statements from production code.
These debug prints are guarded by _psg_debug but will still execute once per model instance. Additionally, line 515 uses base_token.item() which will fail when bsz > 1.
🐛 Proposed fix
Either remove the debug block entirely, or convert to proper logging:
```diff
+import logging
+logger = logging.getLogger(__name__)
+
 ...
-        if not hasattr(self, '_psg_debug'):
-            self._psg_debug = True
-            sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids]
-            th = torch.cat(sel, dim=-1)
-            print(f"[psg] hidden_states layers: {len(base_outputs.hidden_states)}, target_hidden norm: {th.norm().item():.2f}, shape: {th.shape}")
-            print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
-            print(f"[psg] block_ids: {[self.mask_token_id]*self.dflash_block_size}")
-            bi = torch.full((1, self.dflash_block_size), self.mask_token_id, dtype=torch.long, device=input_ids.device)
-            bi[0, 0] = base_token[0, 0]
-            ne = self._base_model_embeddings(bi)
-            print(f"[psg] noise_emb norm: {ne.norm().item():.2f}, shape: {ne.shape}")
-            print(f"[psg] pos_ids will be: ctx=[0..{input_ids.shape[1]-1}], blk=[{input_ids.shape[1]}..{input_ids.shape[1]+self.dflash_block_size-1}]")
```

If debug logging is needed, use `logger.debug()` and fix the batch-size issue:

```python
logger.debug(f"[psg] base_token: {base_token[0, 0].item()}, mask_token_id: {self.mask_token_id}")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if not hasattr(self, '_psg_debug'):
    self._psg_debug = True
    sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids]
    th = torch.cat(sel, dim=-1)
    print(f"[psg] hidden_states layers: {len(base_outputs.hidden_states)}, target_hidden norm: {th.norm().item():.2f}, shape: {th.shape}")
    print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}")
    print(f"[psg] block_ids: {[self.mask_token_id]*self.dflash_block_size}")
    bi = torch.full((1, self.dflash_block_size), self.mask_token_id, dtype=torch.long, device=input_ids.device)
    bi[0, 0] = base_token[0, 0]
    ne = self._base_model_embeddings(bi)
    print(f"[psg] noise_emb norm: {ne.norm().item():.2f}, shape: {ne.shape}")
    print(f"[psg] pos_ids will be: ctx=[0..{input_ids.shape[1]-1}], blk=[{input_ids.shape[1]}..{input_ids.shape[1]+self.dflash_block_size-1}]")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 510 - 521, The
debug print block guarded by self._psg_debug should be removed or replaced with
proper logging and a safe batch-aware access; locate the block around _psg_debug
that inspects base_outputs.hidden_states, base_token, mask_token_id,
dflash_block_size and calls self._base_model_embeddings and either delete it or
change prints to logger.debug(...) and replace the unsafe base_token.item() with
a batch-safe access such as base_token[0,0].item() so it won't fail when bsz >
1; keep the one-time flag behavior if you want this to run only once.
Actionable comments posted: 1
🧹 Nitpick comments (1)
examples/speculative_decoding/eagle_utils.py (1)
241-249: Cache the MT-Bench dataset outside the synchronized section.
load_dataset(...)now sits on the global stall path, so every worker waits for rank 0 to redo the hub/cache lookup before training resumes. Reusing a cached dataset here would cut that idle time.♻️ Proposed refactor
class EagleTrainingPlot(TrainerCallback): def __init__(self, ar_validate_steps: int = 1000, estimate_ar: bool = False): self.ar_validate_steps = ar_validate_steps if wandb and is_master(): wandb.init() self.estimate_ar = estimate_ar + self._ar_validation_ds = None @@ if is_master(): print_rank_0("Running AR validation...") try: + if self._ar_validation_ds is None: + self._ar_validation_ds = load_dataset("HuggingFaceH4/mt_bench_prompts")[ + "train" + ] ars = validate_ar( model=kwargs["model"], tokenizer=kwargs["processing_class"], - ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], + ds=self._ar_validation_ds, device=kwargs["model"].device, )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/eagle_utils.py` around lines 241 - 249, The call to load_dataset("HuggingFaceH4/mt_bench_prompts") is happening inside the synchronized/critical section causing all ranks to wait; move the dataset load/cache out of that section and reuse a single cached dataset reference when calling validate_ar (keep calling validate_ar(model=kwargs["model"], tokenizer=kwargs["processing_class"], ds=cached_mt_bench_ds, device=kwargs["model"].device)), ensuring cached_mt_bench_ds is initialized once (e.g., at module import or rank-0 setup) and shared by workers so only the heavy hub/cache lookup happens once while still preserving use of state.global_step, ars, and wandb logging.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 238-254: Load the validation dataset before the master-only block
so non-master ranks don't waste time; specifically, when checking the AR
validation trigger (state.global_step % self.ar_validate_steps == 0 and
state.global_step > 0) call
load_dataset("HuggingFaceH4/mt_bench_prompts")["train"] into a local variable
(e.g., ds) before the is_master() check, then inside the is_master() block call
validate_ar(model=kwargs["model"], tokenizer=kwargs["processing_class"], ds=ds,
device=kwargs["model"].device) as before; keep print_rank_0, the try/except
around validate_ar, and the torch.distributed.barrier() after the block to
preserve synchronization.
---
Nitpick comments:
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 241-249: The call to
load_dataset("HuggingFaceH4/mt_bench_prompts") is happening inside the
synchronized/critical section causing all ranks to wait; move the dataset
load/cache out of that section and reuse a single cached dataset reference when
calling validate_ar (keep calling validate_ar(model=kwargs["model"],
tokenizer=kwargs["processing_class"], ds=cached_mt_bench_ds,
device=kwargs["model"].device)), ensuring cached_mt_bench_ds is initialized once
(e.g., at module import or rank-0 setup) and shared by workers so only the heavy
hub/cache lookup happens once while still preserving use of state.global_step,
ars, and wandb logging.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 43b9f242-1583-40b4-8db9-7a3aeb9b07cd
📒 Files selected for processing (1)
examples/speculative_decoding/eagle_utils.py
LGTM. Shall we also add some unit/example tests in this PR?
Implement DFlash (Block Diffusion for Flash Speculative Decoding) as a new mode in ModelOpt's speculative decoding framework.

Key architecture:
- Feature Fusion: extract hidden states from uniformly sampled target model layers, project via FC layer
- KV Injection: fused target features injected as K/V entries in every draft decoder layer's attention (not just first layer input)
- Parallel Drafting: all tokens in a block predicted simultaneously using learnable mask embeddings and bidirectional within-block attention

Files:
- dflash/ module: DFlashModel, DFlashConfig, conversion, default config
- plugins/hf_dflash.py: HFDFlashModel with DFlashAttention (KV injection), DFlashModule (feature fusion + decoder), training forward pass with random anchor sampling and exponential position decay loss
- main.py: --mode dflash support in training script

Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
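The "uniformly sampled target model layers" step in feature fusion can be sketched as follows; the plugin's actual sampling rule may differ:

```python
def uniform_target_layer_ids(num_hidden_layers, num_target_layers):
    # spread the draft model's conditioning taps evenly across the base
    # model's depth, always including the last layer
    if num_target_layers == 1:
        return [num_hidden_layers - 1]
    step = (num_hidden_layers - 1) / (num_target_layers - 1)
    return [round(i * step) for i in range(num_target_layers)]
```

The hidden states from these layers would then be concatenated and projected through the FC fusion layer.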
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Key fixes:
- mask_token_id now read from dflash_architecture_config (e.g., 248070 for Qwen3) instead of defaulting to pad/eos token. Wrong mask_token_id caused garbage draft output despite correct weights.
- Inherit model config from base model only as defaults; allow draft to have different num_heads/intermediate_size (needed for z-lab checkpoint)
- Clean default_dflash_config to only contain DFlash-specific settings
- pseudo_speculative_generate returns single block of tokens
- Add dflash_mask_token_id CLI argument to main.py

Validated: z-lab/Qwen3.5-4B-DFlash checkpoint produces AR=7.28 (expected ~6.08)

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Resolution order:
1. Explicit in dflash_architecture_config (user override)
2. Auto-detect from model vocabulary:
   - Qwen3/3.5: built-in [MASK] token (e.g., 248070)
   - Llama3: reserved_special_token_0 (128002)
   - Others: pad_token_id fallback
3. CLI override via --dflash_mask_token_id

Based on z-lab checkpoints:
- z-lab/Qwen3.5-4B-DFlash: mask=248070
- z-lab/LLaMA3.1-8B-Instruct-DFlash: mask=128002
- z-lab/gpt-oss-20b-DFlash: mask=200000

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
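The resolution order above can be sketched as (helper and argument names are hypothetical):

```python
def resolve_mask_token_id(arch_cfg, vocab, pad_token_id):
    # 1. explicit user override in dflash_architecture_config
    explicit = arch_cfg.get("mask_token_id")
    if explicit is not None:
        return explicit
    # 2. auto-detect from the model vocabulary
    #    (Qwen3 built-in [MASK], Llama3 reserved special token)
    for probe in ("[MASK]", "<|reserved_special_token_0|>"):
        if probe in vocab:
            return vocab[probe]
    # 3. pad-token fallback; a --dflash_mask_token_id CLI flag can still override
    return pad_token_id
```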
AR validation runs pseudo_speculative_generate which does unsynchronized model forward passes. In multi-GPU DDP training, this caused NCCL timeout because other ranks were waiting at gradient sync. Fix: only run validate_ar on rank 0 (is_master()), add torch.distributed.barrier() after to synchronize all ranks. Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
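The rank-0-plus-barrier pattern can be sketched with injected callables, so it runs without an initialized process group (in the real callback these would be `is_master()` and `torch.distributed.barrier()`):

```python
def ar_validate_rank0(rank, validate_fn, barrier_fn):
    # only rank 0 runs the unsynchronized generation loop;
    # every rank then meets at the barrier before gradient sync resumes
    result = validate_fn() if rank == 0 else None
    barrier_fn()
    return result
```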
super().forward() from HFDFlashModel goes through DynamicModule which dispatches back to HFDFlashModel.forward(), causing infinite recursion → stack overflow → NCCL timeout in multi-GPU training. Fix: use self._base_model() directly (same as pseudo_speculative_generate) for both eval-mode and training base model forward passes. Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The DynamicModule MRO correctly dispatches super().forward() to the original model class (e.g., Qwen3_5ForCausalLM.forward()) without looping — same pattern EAGLE uses successfully. The previous self._base_model() approach bypassed DDP, causing NCCL timeout because DDP's gradient sync couldn't track the forward pass. Keep pseudo_speculative_generate using self._base_model() since that runs outside DDP (single GPU AR validation). Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
When a rank's batch has no valid loss positions (e.g., all tokens in Block 0, which is excluded), the loss was a detached zero tensor with no connection to dflash_module parameters. DDP waited forever for gradient sync on those parameters, causing an NCCL ALLREDUCE timeout.

Fix: use logits.sum() * 0.0 as the zero loss, which maintains the computation graph through dflash_module parameters so DDP can sync zero gradients properly. Also revert to super().forward() for training (matching the EAGLE pattern) and add --ddp_find_unused_parameters True, --ddp_timeout 300.

Root cause analysis: rank 4 completed ALLREDUCE #272 and proceeded to ALLGATHER #273, while other ranks were stuck at ALLREDUCE #272. This indicated rank 4 had a different backward graph (no gradients for dflash_module on that rank).

Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
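A small torch sketch of the fix, showing that `logits.sum() * 0.0` still produces (zero) gradients for DDP to sync:

```python
import torch


def zero_loss_keeping_graph(logits):
    # numerically zero, but still attached to the autograd graph, so every
    # parameter upstream of `logits` receives a (zero) gradient in backward
    return logits.sum() * 0.0


weight = torch.randn(4, 8, requires_grad=True)
logits = weight * 2.0  # stand-in for a draft-module forward pass
zero_loss_keeping_graph(logits).backward()
```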
Add --dflash_use_logit_distillation flag that switches from hard CE loss (predict ground truth tokens) to logit distillation (learn from target model's output distribution). Hard CE only works when training data is synthesized by the target model itself. Logit distillation works with any data because it learns from the target model's actual predictions, not the ground truth. Usage: python main.py --mode dflash --dflash_use_logit_distillation ... Config: dflash_self_logit_distillation (default=True in config, toggled via CLI flag) Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
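The two loss modes can be sketched as follows; this is a simplified stand-in for the plugin's actual loss, with token and position weighting omitted:

```python
import torch
import torch.nn.functional as F


def draft_loss(draft_logits, labels=None, target_logits=None):
    if target_logits is not None:
        # logit distillation: match the target model's output distribution,
        # so any training data works
        return F.kl_div(
            F.log_softmax(draft_logits, dim=-1),
            F.softmax(target_logits, dim=-1),
            reduction="batchmean",
        )
    # hard CE: predict ground-truth tokens (assumes target-synthesized data)
    return F.cross_entropy(draft_logits, labels)
```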
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Pass answer_only_loss=True to LanguageDataCollator for DFlash mode. This makes the tokenizer return assistant_masks via apply_chat_template with return_assistant_tokens_mask=True. HFDFlashModel.forward() now checks for assistant_masks in kwargs and uses it as loss_mask instead of attention_mask. This matches SpecForge's behavior of only computing loss on response tokens. SpecForge-trained checkpoint (response-only mask): AR=1.95 ModelOpt-trained checkpoint (all tokens mask): AR=1.15 Both with 30-35% training accuracy on same data. Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
When answer_only_loss=True, set labels=-100 for non-assistant tokens using the assistant_masks from tokenizer.apply_chat_template. This ensures DFlash forward() can derive response-only loss mask from labels != -100, without relying on HF Trainer to pass assistant_masks. Also revert hf_dflash.py to use labels-based loss mask instead of kwargs-based assistant_masks (Trainer strips unknown keys). Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
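The labeling rule can be sketched as follows, with IGNORE_INDEX matching the Hugging Face -100 convention (function name hypothetical):

```python
IGNORE_INDEX = -100


def apply_answer_only_labels(input_ids, assistant_mask):
    # keep labels only where the token belongs to an assistant span;
    # forward() later derives the loss mask from labels != IGNORE_INDEX
    return [
        tok if is_assistant else IGNORE_INDEX
        for tok, is_assistant in zip(input_ids, assistant_mask)
    ]
```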
When answer_only_loss=True and the tokenizer's return_assistant_tokens_mask returns empty/unsupported results, fall back to regex-based detection of assistant spans in the formatted text (similar to SpecForge's approach). Supports Qwen/ChatML, Llama3, Llama2, and generic assistant patterns. Uses tokenizer offset_mapping to map character spans to token positions. DFlash forward uses labels != -100 to derive the response-only loss mask. Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Documents DFlash architecture, training usage, mask_token_id auto-detection, and current status including the known AR gap from data pipeline differences. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Instead of hardcoding Llama components (LlamaMLP, LlamaRMSNorm, LlamaRotaryEmbedding), dynamically resolve them from the base model's transformers module (e.g., Qwen3MLP for Qwen3 models). Falls back to Llama components for unknown model types. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Two bugs prevented response-only masking from working: 1. main.py never passed answer_only_loss=True to the data collator for DFlash mode, so all tokens had labels (511/512 instead of response-only). 2. HFDFlashModel.forward() used attention_mask (padding mask) for loss masking instead of labels. When answer_only_loss is enabled, the response-only information is in labels (where -100 = ignore), but this was completely ignored. Now uses labels when available. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
- Add common/dflash/online_training.sh for launcher - Add examples/Qwen/Qwen3-8B/hf_online_dflash.yaml - Add --mode dflash support to launch_train.sh with DFlash-specific args (block_size, num_layers, mask_token_id, config) - DFlash uses DDP instead of FSDP for training Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Previously kept all tokens as valid when regex couldn't find assistant spans, causing training on system/user tokens and inflated per-token accuracy (70% vs SpecForge's 15%). Now masks everything, so these samples contribute zero loss. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Replace complex chat templates with simplified versions that include
{% generation %} tags when answer_only_loss=True. Supports ChatML
(Qwen, Phi) and Llama3 template styles.
This fixes the inflated per-token accuracy (70% vs SpecForge's 15%)
caused by the regex fallback silently training on system/user tokens
when {% generation %} tags were missing.
The simplified templates correctly:
- Mark only assistant content for loss (including <think> blocks)
- Support multi-turn conversations
- Mask system and user tokens
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Qwen3's original template auto-injects <think>\n\n</think>\n\n before assistant content. Match this in our simplified template by adding the think wrapper when content doesn't already start with <think>. Minor difference from original: we add it to all assistant turns, while Qwen3 only adds to the last turn. This doesn't affect training. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The original Qwen3 template adds <think>\n\n</think>\n\n only to the last assistant turn, not all turns. Rather than replicating this complex logic, keep the simplified template clean — just output message content as-is. Training data already contains <think> blocks when present. Llama3 template has no think logic at all. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Qwen3's original template adds <think>\n\n</think>\n\n to the last assistant turn when content doesn't start with <think>. Detect this by checking if '<think>' appears in the original template and use the chatml_think variant which replicates this behavior exactly. Models without think logic (Llama3, basic ChatML) use the plain chatml template. All three samples now match the original tokenization. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Clearly document what the simplified chat templates do, what is preserved vs dropped, and limitations for tool-use and multi-step reasoning data. Also document how to use a custom template. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The early-exit zero-loss path used target_hidden (computed under no_grad) which has no gradient graph, causing 'does not require grad' error. Use dflash_module.fc.weight instead to keep DDP happy. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The Speculative-Decoding dataset has both 'messages' (prompt only) and 'conversations' (prompt + response) fields. The collator took 'messages' first, missing the assistant response entirely. Now checks if 'messages' has an assistant turn, otherwise falls back to 'conversations'. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
… collator

- Remove debug prints in hf_dflash.py (_get_attn_fn and modify) while keeping informational prints (mask_token_id and base forward)
- Add "Legacy: used for inference only" comment on create_dflash_attention_mask and create_dflash_loss_mask
- Remove _apply_answer_only_labels regex fallback in transformers_dataset.py; raise ValueError when assistant_masks is missing/empty
- Add validation for missing assistant turns in __call__
- Make _ensure_generation_tags warnings more prominent with === WARNING === prefix

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
The AR validation callback ran model.forward() on rank 0 only, but DDP model forward triggers collective ops that require all ranks. Now unwraps the DDP model (model.module) before validation, so forward runs without collective hooks. Other ranks wait at barrier. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
During distributed training, a single bad sample would crash all ranks. Now warns and skips samples without assistant turns. Also handle the case where all assistant content is truncated by masking all labels instead of raising ValueError. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
AR validation with DDP is fundamentally incompatible — pseudo_speculative_generate runs inference on rank 0 while other ranks deadlock on collective ops. Now detects world_size > 1 and skips with a one-time warning. AR validation still works for single-GPU and post-training (online_training.sh). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
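The skip logic can be sketched as (function names hypothetical):

```python
import warnings

_warned_skip_ar = False


def maybe_validate_ar(world_size, validate_fn):
    global _warned_skip_ar
    if world_size > 1:
        # pseudo_speculative_generate on one rank would deadlock the others
        if not _warned_skip_ar:
            warnings.warn(
                "Skipping AR validation under multi-GPU DDP; "
                "run it single-GPU or after training instead."
            )
            _warned_skip_ar = True
        return None
    return validate_fn()
```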
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
- Fix resume: add device_map='cpu' to checkpoint loading path to avoid meta tensor errors
- Add export step to online_training.sh: after training, export the DFlash checkpoint to z-lab HF format, then validate AR on the exported checkpoint
- AR validation prefers the exported checkpoint (no prefix) over the training checkpoint (with prefix)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Try scontrol first, then parse SLURM_JOB_NODELIST directly and resolve via getent hosts. Works both inside and outside containers. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Try scontrol, SLURM_LAUNCH_NODE_IPADDR, Python socket resolution, and hostname -I as fallbacks. Should work inside containers where scontrol is not available. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
DFlash uses DDP, not FSDP. The default dp_shard_size=TOTAL_GPU caused FSDP-style sharding. Force to 1 for pure DDP replication. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
DFlash defaults to DDP (dp_shard_size=1). Pass --fsdp True to use FSDP with full_shard instead. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
load_vlm_or_llm uses meta tensors internally. Use AutoModelForCausalLM.from_pretrained with low_cpu_mem_usage=False to avoid meta tensor errors during export. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Same script used by EAGLE3 export. Avoids custom loading logic that caused meta tensor errors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
Loading from checkpoint subdirectory (e.g., checkpoint-12500/) causes meta tensor errors with transformers 5.x. Load from output_dir (top-level save) instead, which works. The checkpoint path is still passed to trainer.train(resume_from_checkpoint=...) for optimizer and step count resume. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: Chenhan Yu <chenhany@nvidia.com>
```python
"""Forward matching SpecForge DFlashDraftModel.forward."""
hidden_states = noise_embedding
target_hidden = self.hidden_norm(self.fc(target_hidden))
position_embeddings = self.rotary_emb(hidden_states, position_ids)
```
```diff
-position_embeddings = self.rotary_emb(hidden_states, position_ids)
+# lazy init rope
+if not hasattr(self, "rotary_emb"):
+    self.rotary_emb = _ROTARY_CLS(config=config, device=hidden_states.device)
+position_embeddings = self.rotary_emb(hidden_states, position_ids)
```
```python
    [DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = _ROTARY_CLS(config=config)
```
```diff
-self.rotary_emb = _ROTARY_CLS(config=config)
```
Lazy init as below to avoid meta tensor copy error
Implement DFlash (Block Diffusion for Flash Speculative Decoding) as a new mode in ModelOpt's speculative decoding framework.
Key architecture:
Files:
Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036)
What does this PR do?
Type of change: ?
Usage
# Add a code snippet demonstrating how to use this

Testing
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
CONTRIBUTING.md: ✅ / ❌ / N/A
Additional Information
Summary by CodeRabbit
New Features
Bug Fixes