Conversation
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings, or use the following commands to manage reviews.
Use the checkboxes below for quick actions:
📝 Walkthrough
This PR adds skip-softmax sparse attention support for Diffusers and LTX-2 models, including new eager and Triton kernel backends, calibration refinements, example scripts, and comprehensive tests for framework-specific attention implementations.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant Pipeline as WAN Pipeline
    participant Sparsify as mtsa.sparsify()
    participant Transformer as Transformer<br/>Modules
    participant CalibLoop as Calibration<br/>Forward Loop
    participant Config as Sparse<br/>Config
    User->>Pipeline: build_pipeline(model_path)
    activate Pipeline
    Pipeline-->>User: pipeline ready
    deactivate Pipeline
    User->>Config: build_sparse_config(args)
    activate Config
    Config-->>User: sparse config dict
    deactivate Config
    alt Calibration Mode
        User->>CalibLoop: build_calibration_forward_loop()
        activate CalibLoop
        CalibLoop-->>User: forward_loop callable
        deactivate CalibLoop
        User->>Sparsify: sparsify(transformer, config,<br/>forward_loop=...)
        activate Sparsify
        Sparsify->>CalibLoop: invoke forward_loop<br/>(multiple prompts)
        CalibLoop->>Transformer: collect attention stats
        Transformer-->>CalibLoop: attention outputs
        Sparsify->>Transformer: apply sparse config
        Transformer-->>Sparsify: sparse attention modules
        Sparsify-->>User: sparsified transformer
        deactivate Sparsify
    else No Calibration
        User->>Sparsify: sparsify(transformer, config)
        activate Sparsify
        Sparsify->>Transformer: apply sparse config
        Transformer-->>Sparsify: sparse attention modules
        Sparsify-->>User: sparsified transformer
        deactivate Sparsify
    end
    User->>Pipeline: generate(prompt, ...)
    activate Pipeline
    Pipeline->>Transformer: forward with sparse<br/>attention
    Transformer-->>Pipeline: output frames
    deactivate Pipeline
    User->>User: print_sparsity_summary(model)
    activate User
    User->>Transformer: enumerate SparseAttentionModule
    Transformer-->>User: module configs
    deactivate User
```

```mermaid
sequenceDiagram
    participant Model as Model<br/>(Diffusers/LTX)
    participant ConversionFn as convert_to_sparse_<br/>attention_model()
    participant RegFn as _register_diffusers_<br/>backends_if_needed()
    participant DiffusersBackends as Diffusers<br/>Backends
    participant LTXBackends as LTX<br/>Backends
    participant Method as Sparse<br/>Method
    Model->>ConversionFn: convert_to_sparse_attention_model(model, ...)
    activate ConversionFn
    ConversionFn->>RegFn: _register_diffusers_backends_if_needed(model)
    activate RegFn
    alt Is Diffusers ModelMixin
        RegFn->>DiffusersBackends: register_diffusers_eager_attention()
        RegFn->>DiffusersBackends: register_diffusers_triton_attention()
        DiffusersBackends-->>RegFn: backends registered
    end
    alt Has LTX modules
        RegFn->>LTXBackends: patch ltx attention<br/>modules
        LTXBackends-->>RegFn: wrappers installed
    end
    RegFn-->>ConversionFn: registration complete
    deactivate RegFn
    ConversionFn->>ConversionFn: _set_attn_implementation()
    ConversionFn->>Method: apply sparse config
    Method-->>ConversionFn: sparse model
    ConversionFn-->>Model: sparse model ready
    deactivate ConversionFn
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing Touches
🧪 Generate unit tests (beta)
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@           Coverage Diff           @@
##             main    #1166   +/-   ##
==========================================
- Coverage   74.77%   73.51%   -1.26%
==========================================
  Files         351      356       +5
  Lines       40072    43723    +3651
==========================================
+ Hits        29964    32145    +2181
- Misses      10108    11578    +1470
```

Flags with carried forward coverage won't be shown. Click here to find out more.
☔ View full report in Codecov by Sentry.
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Force-pushed from 8151232 to 5873652 (Compare)
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (2)
306-316: ⚠️ Potential issue | 🟠 Major

Decode calibration fails when `forward_loop` is provided.

When a user supplies `forward_loop`, lines 264-265 skip building `tokenizer` and `calibration_data`. However, decode calibration (lines 312-316) unconditionally requires both, raising `RuntimeError` even though the user intended to use their own loop.

This creates an inconsistency: prefill calibration supports a user-provided `forward_loop`, but decode calibration does not. The docstring (line 227) also states `forward_loop` is "Only used for prefill", but this limitation should either be enforced earlier or decode should also accept a custom loop.

💡 Suggested approach
Either:

- Skip decode calibration when `forward_loop` is provided and `calibration_data` is `None`, with a warning
- Accept a separate `decode_forward_loop` parameter
- Document and enforce that decode calibration requires RULER dataset

```diff
 # Run decode calibration if enabled
 if calibrate_decode:
+    if calibration_data is None or tokenizer is None:
+        warnings.warn(
+            "Decode calibration requires RULER dataset. Skipping decode calibration "
+            "because a custom forward_loop was provided without calibration_data."
+        )
+    else:
         print("\n" + "=" * 60)
         # ... rest of decode calibration
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around lines 306 - 316, The decode calibration block (calibrate_decode) currently raises RuntimeError if calibration_data or tokenizer are missing even when the user supplied forward_loop; change the logic in the calibrate_decode section to detect when forward_loop is provided and calibration_data is None and skip decode calibration with a warning instead of raising, i.e., only call create_decode_calibration_forward_loop when calibration_data and tokenizer exist (use create_decode_calibration_forward_loop(calibration_data, tokenizer, ...)), otherwise log/warn that decode calibration is skipped due to missing calibration_data/tokenizer while a custom forward_loop was supplied; update any related docstring or comment near the calibrate_decode and forward_loop mentions to reflect this behavior.
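The warn-and-skip behavior suggested above can be sketched as a small standalone helper. The function name and the `run_decode` callback are illustrative stand-ins, not the module's actual API:

```python
import warnings


def maybe_run_decode_calibration(calibration_data, tokenizer, run_decode):
    """Run decode calibration only when its inputs exist; otherwise warn and skip.

    `run_decode` stands in for the real decode-calibration entry point.
    """
    if calibration_data is None or tokenizer is None:
        # A custom forward_loop skipped building tokenizer/calibration_data,
        # so warn instead of raising RuntimeError.
        warnings.warn(
            "Decode calibration requires RULER dataset. Skipping decode calibration "
            "because a custom forward_loop was provided without calibration_data."
        )
        return None
    return run_decode(calibration_data, tokenizer)
```

This keeps prefill-only workflows working while still surfacing a diagnostic for users who expected decode calibration to run.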
24-24: ⚠️ Potential issue | 🔴 Critical

Unconditional `transformers` import causes pipeline failure.

The module-level import of `transformers.AutoTokenizer` fails when `transformers` is not installed. This should be deferred to usage sites (inside `_load_tokenizer` or guarded) to allow the module to be imported when only diffusers-based workflows are used.

🐛 Proposed fix: defer import to usage site
```diff
-from transformers import AutoTokenizer
```

Then update `_load_tokenizer`:

```diff
 def _load_tokenizer(tokenizer_name_or_path: str) -> "AutoTokenizer":
     """Load tokenizer and ensure pad_token is set."""
+    from transformers import AutoTokenizer
+
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` at line 24, The file currently imports transformers.AutoTokenizer at module scope which raises ImportError when transformers isn't installed; move the import into the tokenizer-loading code path so the module can be imported without transformers. Specifically, remove the top-level "from transformers import AutoTokenizer" and instead import AutoTokenizer inside the _load_tokenizer function (or guard the import with a try/except that raises a clear error when the function is called), ensuring _load_tokenizer handles the absence of transformers and only then attempts to create the tokenizer.
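The deferred-import pattern above generalizes beyond `AutoTokenizer`. A minimal sketch (the helper name is made up for illustration) that keeps a module importable while failing loudly only when the optional dependency is actually needed:

```python
import importlib


def lazy_import(module_name: str, needed_for: str):
    """Import an optional dependency at call time, with a clear error if absent."""
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        # Raise a message that names the missing package and the feature
        # that needs it, instead of failing at module import time.
        raise ImportError(
            f"{module_name} is required for {needed_for}; install it or avoid "
            f"code paths that need it."
        ) from e
```

With this, `_load_tokenizer`-style functions can call `lazy_import("transformers", "tokenizer-based calibration")` and diffusers-only workflows never touch the import.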
♻️ Duplicate comments (1)
modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py (1)
30-36: ⚠️ Potential issue | 🟠 Major

Same top-level import issue as the eager backend.

Both `diffusers` and `modelopt.torch.kernels` are imported unconditionally at the top level. This will cause import failures for users who don't have diffusers installed or don't have CUDA/Triton available.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py` around lines 30 - 36, Top-level unconditional imports of diffusers symbols (AttentionBackendName, _AttentionBackendRegistry, attention_backend) and modelopt.torch.kernels.attention can fail for users without diffusers or CUDA/Triton; move these imports into the function or class that actually uses them (e.g., inside the registration function or the backend implementation in diffusers_triton_attention.py) and guard with try/except ImportError to raise a clear error only when the backend is instantiated; ensure you reference the same symbols (AttentionBackendName, _AttentionBackendRegistry, attention_backend, and attention) after relocating the imports so registration only occurs when dependencies are present.
🧹 Nitpick comments (4)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
124-125: Overly broad exception handling may hide bugs.

`except (ImportError, Exception)` catches all exceptions including programming errors (e.g., `TypeError`, `AttributeError`). Consider narrowing to specific expected exceptions.

♻️ Suggested fix

```diff
-    except (ImportError, Exception):
+    except (ImportError, RuntimeError):
         pass
```

Or log unexpected exceptions for debugging:

```diff
-    except (ImportError, Exception):
-        pass
+    except ImportError:
+        pass
+    except Exception as e:
+        import logging
+        logging.debug(f"Diffusers backend registration failed: {e}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 124 - 125, The current broad except clause "except (ImportError, Exception)" in conversion.py swallows all errors and can hide bugs; change it to only catch ImportError (e.g., "except ImportError as e") for the import-failure path, and if you must catch other runtime issues around the same block, catch specific exceptions or log unexpected exceptions (use a logger.exception or re-raise after logging) so that programming errors like TypeError/AttributeError are not silently ignored; update the except block that follows the import attempt in conversion.py accordingly.

examples/diffusers/sparsity/ltx2_skip_softmax.py (2)
66-81: Hardcoded user-specific paths should be placeholders.

The default paths contain user-specific paths (`/home/scratch.omniml_data_2/jingyux/...`) that won't exist on other systems. Consider using empty strings or raising a clear error when environment variables are not set.

Proposed fix: Require explicit configuration

```diff
-CHECKPOINT_PATH = os.environ.get(
-    "LTX2_CHECKPOINT",
-    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-dev.safetensors",
-)
+CHECKPOINT_PATH = os.environ.get("LTX2_CHECKPOINT", "")
+DISTILLED_LORA_PATH = os.environ.get("LTX2_DISTILLED_LORA", "")
+SPATIAL_UPSAMPLER_PATH = os.environ.get("LTX2_SPATIAL_UPSAMPLER", "")
+GEMMA_ROOT = os.environ.get("LTX2_GEMMA_ROOT", "")
```

Then in `build_pipeline()`:

```python
def build_pipeline() -> TI2VidTwoStagesPipeline:
    if not all([CHECKPOINT_PATH, DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT]):
        raise ValueError(
            "Missing required environment variables. Set: "
            "LTX2_CHECKPOINT, LTX2_DISTILLED_LORA, LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT"
        )
    ...
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/sparsity/ltx2_skip_softmax.py` around lines 66 - 81, The file defines CHECKPOINT_PATH, DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, and GEMMA_ROOT with hardcoded user-specific default paths; remove those user paths and default to empty string or None when reading the env vars (os.environ.get(..., "") or None), and add a validation at the start of build_pipeline() that checks these constants (CHECKPOINT_PATH, DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT) and raises a clear ValueError listing the required env names (LTX2_CHECKPOINT, LTX2_DISTILLED_LORA, LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT) if any are missing so callers must explicitly configure them.
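The fail-fast validation suggested above can be sketched as a self-contained helper. The function name and the injected-environment parameter are illustrative, not the example script's actual code:

```python
import os

REQUIRED_ENV_VARS = (
    "LTX2_CHECKPOINT",
    "LTX2_DISTILLED_LORA",
    "LTX2_SPATIAL_UPSAMPLER",
    "LTX2_GEMMA_ROOT",
)


def resolve_required_paths(env=None) -> dict:
    """Read required paths from the environment, raising if any is unset.

    `env` defaults to os.environ; a mapping can be injected for testing.
    """
    env = os.environ if env is None else env
    paths = {name: env.get(name, "") for name in REQUIRED_ENV_VARS}
    missing = [name for name, value in paths.items() if not value]
    if missing:
        # Name every missing variable so the user can fix all of them at once.
        raise ValueError(
            "Missing required environment variables: " + ", ".join(missing)
        )
    return paths
```

Calling this at the top of `build_pipeline()` turns a confusing mid-pipeline file-not-found error into an immediate, actionable message.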
260-267: `load_dataset` call may download data unexpectedly.

The `load_dataset("nkp37/OpenVid-1M")` call will download the dataset on first run, which could be surprising for users. Consider adding a note in the docstring or CLI help, or making this behavior opt-in.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/sparsity/ltx2_skip_softmax.py` around lines 260 - 267, The load_calib_prompts function calls load_dataset("nkp37/OpenVid-1M") which may trigger a large download unexpectedly; update load_calib_prompts to make dataset download explicit or opt-in (e.g., add a parameter like download: bool = False or a dataset_path argument) and update the docstring to warn that calling this function will download the OpenVid-1M dataset unless an existing local dataset path is provided; ensure the function checks the opt-in flag or uses the provided path before invoking load_dataset to avoid surprising downloads.

tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py (1)
130-158: Test mock doesn't match actual kernel call signature.

The mock `mk.attention = lambda q, k, v, **kw: q` returns `q` directly, but per the context snippet, the actual kernel receives varlen metadata and returns output with shape `[B*S, H, D]`. The mock should return a tensor with the correct output shape to avoid masking reshape bugs in the code under test.

Proposed fix: Return correctly shaped tensor

```diff
-        mk.attention = lambda q, k, v, **kw: q
+        def mock_attention(q, k, v, **kw):
+            # Return tensor with same shape as q (correct for varlen format)
+            return torch.zeros_like(q)
+        mk.attention = mock_attention
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py` around lines 130 - 158, The mock attention implementation in TestDiffusersTritonAttention._setup (mk.attention = lambda q, k, v, **kw: q) does not match the real kernel signature/behavior and returns the wrong shape, which hides reshape/masking bugs; update the mock in _setup (mk.attention) to accept the same args including varlen metadata (keep **kw) and return a tensor with shape [B*S, H, D] derived from the input q/k/v shapes (e.g., compute B, S, H, D from q and construct/return a tensor of that shape instead of returning q directly) so the code paths in _diffusers_triton_attention, set_triton_skip_softmax_config, clear_triton_skip_softmax_config, register_diffusers_triton_attention and get_triton_attention_backend are exercised with realistic output shapes.
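The shape contract behind this comment can be illustrated without torch: a varlen attention kernel consumes packed tensors of shape [B*S, H, D] and returns output in the same layout, so a faithful mock must echo the query's shape. `FakeTensor` below is a made-up stand-in for a real tensor, used only to demonstrate the contract:

```python
class FakeTensor:
    """Minimal tensor stand-in carrying only a shape, for illustration."""

    def __init__(self, shape):
        self.shape = tuple(shape)


def mock_varlen_attention(q, k, v, **kwargs):
    """Mock kernel: varlen attention output shares the packed query layout [B*S, H, D].

    A real torch-based mock would return torch.zeros_like(q) instead.
    """
    # Ignore varlen metadata in kwargs (cu_seqlens etc.); only the shape matters here.
    return FakeTensor(q.shape)
```

Returning a tensor of the correct shape means any reshape bug in the code under test surfaces as a shape mismatch rather than passing silently.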
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`:
- Around line 31-35: The top-level import of diffusers internals
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) should be
guarded to avoid ImportError for users without diffusers; wrap the import in a
try/except and set a module-level flag (e.g., _DIFFUSERS_AVAILABLE =
False/True). Update any registration or accessor functions that reference
AttentionBackendName, _AttentionBackendRegistry or attention_backend to check
_DIFFUSERS_AVAILABLE before using them and return/do nothing or raise a clear
runtime error when diffusers is unavailable. Ensure all places that previously
assumed the imports (registration functions) consult _DIFFUSERS_AVAILABLE so the
module can be imported without diffusers installed.
- Around line 120-132: The code unconditionally manipulates private diffusers
internals (AttentionBackendName, _AttentionBackendRegistry, etc.) that exist
only in diffusers >= 0.36.0; add a runtime/version guard before creating
new_member and registering _diffusers_eager_attention: check
diffusers.__version__ (or use the same utility used in
modelopt/torch/quantization/plugins/diffusion/diffusers.py) or probe for the
presence of attributes like AttentionBackendName._member_map_ and
_AttentionBackendRegistry._backends, and only perform the enum extension and
registry assignments when those APIs exist; also update pyproject.toml to
require diffusers>=0.36.0 or make the registration conditional so the code
no-ops on older diffusers versions.
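A dependency-free version probe for such a guard might look like the sketch below. This is one possible implementation, not the repository's actual utility; real code could instead use `packaging.version` or probe directly for the private attributes:

```python
import re


def version_at_least(version_str: str, minimum=(0, 36, 0)) -> bool:
    """Compare the leading numeric components of a version string to a minimum tuple."""
    # Extract up to three numeric components; tolerates suffixes like "0.36.0.dev0".
    numbers = tuple(int(n) for n in re.findall(r"\d+", version_str)[:3])
    # Pad short versions like "1.0" so tuple comparison is well-defined.
    numbers += (0,) * (3 - len(numbers))
    return numbers >= minimum
```

Registration would then be wrapped as `if version_at_least(diffusers.__version__): register(...)`, making the backend a no-op on diffusers releases that predate `_AttentionBackendRegistry`.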
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 61-70: The _diffusers_triton_attention function currently accepts
attn_mask and enable_gqa but ignores them; update the function so it either
implements GQA and mask handling consistent with the eager backend or explicitly
fails fast: if enable_gqa is True or attn_mask is not None, raise a clear
NotImplementedError mentioning "_diffusers_triton_attention does not support
enable_gqa/attn_mask yet" (or implement the same GQA reshaping/aggregation logic
used in diffusers_eager_attention for query/key/value before calling the Triton
kernel) so callers won't silently get incorrect results.
In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`:
- Around line 60-67: The _ltx_triton_attention function currently accepts a mask
parameter but never uses it; update the implementation to handle masks: either
pass the mask into the Triton kernel via the attn_mask argument when invoking
the kernel (ensure shapes/dtypes match and add logic to convert/expand the mask
to the kernel's expected form), or if kernel masking isn't supported yet,
explicitly reject masks by raising a clear error (e.g., raise
NotImplementedError("mask not supported by _ltx_triton_attention") when mask is
not None) so callers won't silently get incorrect results. Ensure the change is
applied inside _ltx_triton_attention and that any conversion/validation of mask
is performed before the kernel call.
- Line 29: The module unconditionally imports Attention from ltx_core which will
raise ImportError for users without LTX-2; change the top-level import to a
guarded import (try/except ImportError) or defer importing until registration,
set Attention = None on failure, and update register_ltx_triton_attention to
check if Attention is None and raise a clear ImportError like "ltx_core is
required for LTX-2 Triton attention" before proceeding; reference the symbols
Attention and register_ltx_triton_attention in ltx_triton_attention.py to locate
where to apply the guard.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 389-401: The code sets the thread-wide flag via
set_skip_softmax_context(True) immediately which can leak if an exception occurs
before the returned ExitStack is entered; instead create a small context manager
(e.g., using contextlib.contextmanager or a tiny class) that calls
set_skip_softmax_context(True) on __enter__/enter and
set_skip_softmax_context(False) on __exit__/exit, and then register that context
with stack.enter_context rather than calling set_skip_softmax_context and
stack.callback directly; update the function that builds the stack (the block
using ExitStack, get_skip_softmax_attention_backend,
replace_function(torch.nn.functional, "softmax", sparse_softmax)) to enter the
new flag-context via stack.enter_context so the flag is only set when the stack
is actually entered and always cleaned up on exit.
In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py`:
- Around line 56-72: The tests import
modelopt.torch.sparsity.attention_sparsity.kernels which transitively imports
transformers and breaks CI; update TestSkipSoftmaxContext to skip when optional
dependency missing by using pytest.importorskip('transformers') or catching
ImportError before importing get_skip_softmax_context/set_skip_softmax_context
(or call pytest.skip) so the test cleanly skips in environments without
transformers; ensure the changes are applied around the imports used in
TestSkipSoftmaxContext (references: get_skip_softmax_context,
set_skip_softmax_context, TestSkipSoftmaxContext).
- Around line 178-207: The test fails because patching targets under
"modelopt.torch.sparsity.attention_sparsity.kernels" occurs before that
submodule is loaded, causing a Module attribute error; to fix, ensure the module
is imported before patching or patch the symbols at the location they are looked
up by _register_diffusers_backends_if_needed: import
modelopt.torch.sparsity.attention_sparsity.conversion (or the parent package)
first, then patch the call targets register_diffusers_eager_attention and
register_diffusers_triton_attention as used by that conversion module (i.e.,
patch where _register_diffusers_backends_if_needed resolves them) so the
MagicMock replacement applies correctly.
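The skip pattern the first bullet relies on is pytest's `importorskip`. A minimal module-level guard looks like this; `json` stands in for the optional `transformers` dependency so the sketch runs anywhere:

```python
import pytest

# Skip this whole test module cleanly when the optional dependency is absent.
# "json" is a placeholder for "transformers" in the real test file.
optional_dep = pytest.importorskip("json")


def test_optional_dep_usable():
    # Tests below this guard can assume the dependency imported successfully.
    assert optional_dep.loads("[1]") == [1]
```

When the named module is missing, `importorskip` raises pytest's skip exception at collection time, so the suite reports a skip instead of an ImportError.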
---
Outside diff comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 306-316: The decode calibration block (calibrate_decode) currently
raises RuntimeError if calibration_data or tokenizer are missing even when the
user supplied forward_loop; change the logic in the calibrate_decode section to
detect when forward_loop is provided and calibration_data is None and skip
decode calibration with a warning instead of raising, i.e., only call
create_decode_calibration_forward_loop when calibration_data and tokenizer exist
(use create_decode_calibration_forward_loop(calibration_data, tokenizer, ...)),
otherwise log/warn that decode calibration is skipped due to missing
calibration_data/tokenizer while a custom forward_loop was supplied; update any
related docstring or comment near the calibrate_decode and forward_loop mentions
to reflect this behavior.
- Line 24: The file currently imports transformers.AutoTokenizer at module scope
which raises ImportError when transformers isn't installed; move the import into
the tokenizer-loading code path so the module can be imported without
transformers. Specifically, remove the top-level "from transformers import
AutoTokenizer" and instead import AutoTokenizer inside the _load_tokenizer
function (or guard the import with a try/except that raises a clear error when
the function is called), ensuring _load_tokenizer handles the absence of
transformers and only then attempts to create the tokenizer.
---
Duplicate comments:
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 30-36: Top-level unconditional imports of diffusers symbols
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) and
modelopt.torch.kernels.attention can fail for users without diffusers or
CUDA/Triton; move these imports into the function or class that actually uses
them (e.g., inside the registration function or the backend implementation in
diffusers_triton_attention.py) and guard with try/except ImportError to raise a
clear error only when the backend is instantiated; ensure you reference the same
symbols (AttentionBackendName, _AttentionBackendRegistry, attention_backend, and
attention) after relocating the imports so registration only occurs when
dependencies are present.
---
Nitpick comments:
In `@examples/diffusers/sparsity/ltx2_skip_softmax.py`:
- Around line 66-81: The file defines CHECKPOINT_PATH, DISTILLED_LORA_PATH,
SPATIAL_UPSAMPLER_PATH, and GEMMA_ROOT with hardcoded user-specific default
paths; remove those user paths and default to empty string or None when reading
the env vars (os.environ.get(..., "") or None), and add a validation at the
start of build_pipeline() that checks these constants (CHECKPOINT_PATH,
DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT) and raises a clear
ValueError listing the required env names (LTX2_CHECKPOINT, LTX2_DISTILLED_LORA,
LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT) if any are missing so callers must
explicitly configure them.
- Around line 260-267: The load_calib_prompts function calls
load_dataset("nkp37/OpenVid-1M") which may trigger a large download
unexpectedly; update load_calib_prompts to make dataset download explicit or
opt-in (e.g., add a parameter like download: bool = False or a dataset_path
argument) and update the docstring to warn that calling this function will
download the OpenVid-1M dataset unless an existing local dataset path is
provided; ensure the function checks the opt-in flag or uses the provided path
before invoking load_dataset to avoid surprising downloads.
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 124-125: The current broad except clause "except (ImportError,
Exception)" in conversion.py swallows all errors and can hide bugs; change it to
only catch ImportError (e.g., "except ImportError as e") for the import-failure
path, and if you must catch other runtime issues around the same block, catch
specific exceptions or log unexpected exceptions (use a logger.exception or
re-raise after logging) so that programming errors like TypeError/AttributeError
are not silently ignored; update the except block that follows the import
attempt in conversion.py accordingly.
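The narrowing suggested for conversion.py can be sketched with a small wrapper (a hypothetical helper, not the file's real code): missing optional dependencies stay silent, while unexpected errors are logged rather than swallowed invisibly.

```python
import logging

logger = logging.getLogger(__name__)


def try_register_backend(register_fn) -> str:
    """Attempt optional backend registration, distinguishing failure modes."""
    try:
        register_fn()
    except ImportError:
        return "skipped"  # optional dependency not installed: expected, stay quiet
    except Exception as e:
        # Programming errors are not hidden: they leave a trace for debugging.
        logger.debug("Backend registration failed: %s", e)
        return "failed"
    return "registered"
```

The return value is only for illustration; the point is that `ImportError` and genuine bugs take different branches instead of one catch-all `except (ImportError, Exception)`.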
In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py`:
- Around line 130-158: The mock attention implementation in
TestDiffusersTritonAttention._setup (mk.attention = lambda q, k, v, **kw: q)
does not match the real kernel signature/behavior and returns the wrong shape,
which hides reshape/masking bugs; update the mock in _setup (mk.attention) to
accept the same args including varlen metadata (keep **kw) and return a tensor
with shape [B*S, H, D] derived from the input q/k/v shapes (e.g., compute B, S,
H, D from q and construct/return a tensor of that shape instead of returning q
directly) so the code paths in _diffusers_triton_attention,
set_triton_skip_softmax_config, clear_triton_skip_softmax_config,
register_diffusers_triton_attention and get_triton_attention_backend are
exercised with realistic output shapes.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 578e7893-06c2-4586-93e1-4726205a2f84
📒 Files selected for processing (14)
- examples/diffusers/sparsity/ltx2_skip_softmax.py
- examples/diffusers/sparsity/wan22_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
- modelopt/torch/sparsity/attention_sparsity/conversion.py
- modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
- modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py
- modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py
- modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py
```python
from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _AttentionBackendRegistry,
    attention_backend,
)
```
Unconditional import of diffusers internals will cause ImportError for non-diffusers users.
Similar to the ltx_core issue, this module imports diffusers at the top level. Since attention_sparsity may be used without diffusers, consider guarding this import.
Proposed fix: guard the import
```diff
+try:
 from diffusers.models.attention_dispatch import (
     AttentionBackendName,
     _AttentionBackendRegistry,
     attention_backend,
 )
+    _DIFFUSERS_AVAILABLE = True
+except ImportError:
+    _DIFFUSERS_AVAILABLE = False
+    AttentionBackendName = None  # type: ignore
+    _AttentionBackendRegistry = None  # type: ignore
+    attention_backend = None  # type: ignore
```

Then check `_DIFFUSERS_AVAILABLE` in registration functions.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`
around lines 31 - 35, The top-level import of diffusers internals
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) should be
guarded to avoid ImportError for users without diffusers; wrap the import in a
try/except and set a module-level flag (e.g., _DIFFUSERS_AVAILABLE =
False/True). Update any registration or accessor functions that reference
AttentionBackendName, _AttentionBackendRegistry or attention_backend to check
_DIFFUSERS_AVAILABLE before using them and return/do nothing or raise a clear
runtime error when diffusers is unavailable. Ensure all places that previously
assumed the imports (registration functions) consult _DIFFUSERS_AVAILABLE so the
module can be imported without diffusers installed.
```python
# Extend the AttentionBackendName enum with our custom value
new_member = str.__new__(AttentionBackendName, _BACKEND_NAME)
new_member._name_ = "MODELOPT_SKIP_SOFTMAX"
new_member._value_ = _BACKEND_NAME
AttentionBackendName._member_map_["MODELOPT_SKIP_SOFTMAX"] = new_member
AttentionBackendName._value2member_map_[_BACKEND_NAME] = new_member

# Register the backend function
_AttentionBackendRegistry._backends[new_member] = _diffusers_eager_attention
_AttentionBackendRegistry._constraints[new_member] = []
_AttentionBackendRegistry._supported_arg_names[new_member] = set(
    inspect.signature(_diffusers_eager_attention).parameters.keys()
)
```
🧩 Analysis chain
🌐 Web query:
What version of diffusers introduced _AttentionBackendRegistry and attention_backend context manager?
💡 Result:
The _AttentionBackendRegistry and attention_backend context manager were introduced in diffusers version 0.36.0.
Citations:
- 1: https://huggingface.co/docs/diffusers/optimization/attention_backends
- 2: https://huggingface.co/docs/diffusers/main/optimization/attention_backends
- 3: https://github.com/huggingface/diffusers/blob/main/docs/source/en/optimization/attention_backends.md
- 4: https://newreleases.io/project/github/huggingface/diffusers/release/v0.36.0
🏁 Script executed:

```shell
# Find pyproject.toml and requirements files to check diffusers version constraints
find . -type f \( -name "pyproject.toml" -o -name "requirements*.txt" -o -name "setup.py" \) | head -20
```

Repository: NVIDIA/Model-Optimizer
Length of output: 945
🏁 Script executed:

```shell
# Check the specific file to see its full context and any version checks
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | head -150
```

Repository: NVIDIA/Model-Optimizer
Length of output: 6161
🏁 Script executed:

```shell
# Search for any version checks related to diffusers in the codebase
rg "diffusers" --type py -A 2 | grep -i "version\|__version__" | head -20
```

Repository: NVIDIA/Model-Optimizer
Length of output: 2357
🏁 Script executed:

```shell
# Check if there are any documented version requirements or constraints
rg "0\.36\|diffusers.*version" --type md --type txt --type py | head -20
```

Repository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
# Check conversion.py to see how registration is called and if version checks exist
cat -n modelopt/torch/sparsity/attention_sparsity/conversion.py | grep -A 15 -B 5 "_register_diffusers_backends_if_needed"Repository: NVIDIA/Model-Optimizer
Length of output: 1992
🏁 Script executed:
# Find and check pyproject.toml for diffusers version constraints
cat pyproject.toml | grep -A 10 -B 2 "diffusers"Repository: NVIDIA/Model-Optimizer
Length of output: 690
🏁 Script executed:
# Check if there are any import guards around the diffusers imports
rg "from diffusers.*attention_dispatch" -B 5 -A 5Repository: NVIDIA/Model-Optimizer
Length of output: 3478
Add version check for diffusers >= 0.36.0 and document private API reliance.
The code imports and manipulates private diffusers internals (_AttentionBackendRegistry, _member_map_, _value2member_map_, _backends, _constraints, _supported_arg_names) that were introduced in diffusers 0.36.0. However, pyproject.toml specifies only diffusers>=0.32.2, creating a compatibility gap. Add a version guard on module import (following the pattern in modelopt/torch/quantization/plugins/diffusion/diffusers.py), update the minimum version requirement in pyproject.toml, or conditionally register the backend only when the required APIs are available.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`
around lines 120 - 132, The code unconditionally manipulates private diffusers
internals (AttentionBackendName, _AttentionBackendRegistry, etc.) that exist
only in diffusers >= 0.36.0; add a runtime/version guard before creating
new_member and registering _diffusers_eager_attention: check
diffusers.__version__ (or use the same utility used in
modelopt/torch/quantization/plugins/diffusion/diffusers.py) or probe for the
presence of attributes like AttentionBackendName._member_map_ and
_AttentionBackendRegistry._backends, and only perform the enum extension and
registry assignments when those APIs exist; also update pyproject.toml to
require diffusers>=0.36.0 or make the registration conditional so the code
no-ops on older diffusers versions.
```python
def _diffusers_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
```
Parameters attn_mask and enable_gqa are accepted but silently ignored.
Unlike the eager backend which implements GQA support (Lines 66-71 in diffusers_eager_attention.py), the Triton backend ignores enable_gqa. This inconsistency could cause silent correctness issues when switching backends.
Proposed fix: Add GQA support or raise error

```diff
 def _diffusers_triton_attention(
     ...
     enable_gqa: bool = False,
 ) -> torch.Tensor:
+    if enable_gqa and query.shape[2] != key.shape[2]:
+        raise NotImplementedError(
+            "GQA not yet supported in Triton backend. Use eager backend for calibration."
+        )
+    if attn_mask is not None:
+        raise NotImplementedError(
+            "attn_mask not yet supported in Triton backend."
+        )
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def _diffusers_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    if enable_gqa and query.shape[2] != key.shape[2]:
        raise NotImplementedError(
            "GQA not yet supported in Triton backend. Use eager backend for calibration."
        )
    if attn_mask is not None:
        raise NotImplementedError(
            "attn_mask not yet supported in Triton backend."
        )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`
around lines 61 - 70, The _diffusers_triton_attention function currently accepts
attn_mask and enable_gqa but ignores them; update the function so it either
implements GQA and mask handling consistent with the eager backend or explicitly
fails fast: if enable_gqa is True or attn_mask is not None, raise a clear
NotImplementedError mentioning "_diffusers_triton_attention does not support
enable_gqa/attn_mask yet" (or implement the same GQA reshaping/aggregation logic
used in diffusers_eager_attention for query/key/value before calling the Triton
kernel) so callers won't silently get incorrect results.
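For the "implement the same GQA reshaping" option, the head expansion can be sketched as below. This is a hedged sketch: it assumes the `(batch, seq, heads, head_dim)` layout implied by the `query.shape[2] != key.shape[2]` check above, and `expand_kv_for_gqa` is an illustrative helper name, not part of the module:

```python
import torch


def expand_kv_for_gqa(
    key: torch.Tensor, value: torch.Tensor, num_query_heads: int
) -> tuple[torch.Tensor, torch.Tensor]:
    # key/value: (batch, seq, kv_heads, head_dim); duplicate each kv head so the
    # head count matches the query before calling a kernel without native GQA.
    kv_heads = key.shape[2]
    if num_query_heads % kv_heads != 0:
        raise ValueError("query heads must be a multiple of kv heads")
    repeats = num_query_heads // kv_heads
    return (
        key.repeat_interleave(repeats, dim=2),
        value.repeat_interleave(repeats, dim=2),
    )
```

Expanding before the kernel call trades memory for correctness; raising `NotImplementedError` is the cheaper fail-fast alternative when the expansion cost is unacceptable.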
```python
import threading

import torch
from ltx_core.model.transformer.attention import Attention
```
Unconditional import of ltx_core will cause ImportError for users without LTX-2.
This module is loaded at import time, but ltx_core is an optional dependency only available for LTX-2 users. Consider wrapping in a try/except or deferring the import to registration time.
Proposed fix: guard the import

```diff
-from ltx_core.model.transformer.attention import Attention
+try:
+    from ltx_core.model.transformer.attention import Attention
+except ImportError:
+    Attention = None  # type: ignore[misc,assignment]
```

Then in register_ltx_triton_attention:

```python
def register_ltx_triton_attention(model: torch.nn.Module) -> None:
    if Attention is None:
        raise ImportError("ltx_core is required for LTX-2 Triton attention")
    ...
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`
at line 29, The module unconditionally imports Attention from ltx_core which
will raise ImportError for users without LTX-2; change the top-level import to a
guarded import (try/except ImportError) or defer importing until registration,
set Attention = None on failure, and update register_ltx_triton_attention to
check if Attention is None and raise a clear ImportError like "ltx_core is
required for LTX-2 Triton attention" before proceeding; reference the symbols
Attention and register_ltx_triton_attention in ltx_triton_attention.py to locate
where to apply the guard.
```python
def _ltx_triton_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    heads: int,
    mask: torch.Tensor | None = None,
    threshold: float | None = None,
) -> torch.Tensor:
```
mask parameter is accepted but silently ignored.
The mask parameter is passed through the wrapper (Line 136) but never used in the Triton attention implementation. This could lead to silent correctness issues if the original attention function relied on masking.
Consider either:
- Implementing mask support via `attn_mask` in the kernel call
- Raising an error/warning if a mask is provided but not supported
- Documenting this limitation clearly
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`
around lines 60 - 67, The _ltx_triton_attention function currently accepts a
mask parameter but never uses it; update the implementation to handle masks:
either pass the mask into the Triton kernel via the attn_mask argument when
invoking the kernel (ensure shapes/dtypes match and add logic to convert/expand
the mask to the kernel's expected form), or if kernel masking isn't supported
yet, explicitly reject masks by raising a clear error (e.g., raise
NotImplementedError("mask not supported by _ltx_triton_attention") when mask is
not None) so callers won't silently get incorrect results. Ensure the change is
applied inside _ltx_triton_attention and that any conversion/validation of mask
is performed before the kernel call.
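If the kernel-side route is taken, the "conversion/validation of mask" step usually means turning a boolean keep-mask into the additive form attention kernels expect. A hedged sketch of that conversion — `to_additive_mask` is an illustrative helper, not an existing function in the module:

```python
import torch


def to_additive_mask(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Boolean masks mark positions to keep; kernels typically want 0.0 where
    # attention is allowed and -inf where it is masked out, added to the logits.
    if mask.dtype == torch.bool:
        zeros = torch.zeros_like(mask, dtype=dtype)
        neg_inf = torch.full_like(zeros, float("-inf"))
        return torch.where(mask, zeros, neg_inf)
    return mask.to(dtype)  # already additive: just match the kernel dtype
```

The shape still has to be broadcastable to the kernel's `(batch, heads, seq_q, seq_k)` score layout before the call.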
```python
stack = ExitStack()
set_skip_softmax_context(True)
stack.callback(set_skip_softmax_context, False)

try:
    from ..kernels.diffusers_eager_attention import get_skip_softmax_attention_backend

    stack.enter_context(get_skip_softmax_attention_backend())
except (ImportError, RuntimeError):
    pass

stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax))
return stack
```
Context flag set before stack is entered — cleanup may not run.
set_skip_softmax_context(True) is called immediately at line 390, but the cleanup callback only executes when the ExitStack exits. If:
- An exception occurs before the stack is returned, or
- The caller never enters the returned stack as a context manager
...the flag will remain True, potentially affecting subsequent attention computations in the same thread.
🔧 Proposed fix: defer flag-setting to stack entry

```diff
     from ..kernels import set_skip_softmax_context

     stack = ExitStack()
-    set_skip_softmax_context(True)
-    stack.callback(set_skip_softmax_context, False)
+
+    @contextmanager
+    def _skip_softmax_flag():
+        set_skip_softmax_context(True)
+        try:
+            yield
+        finally:
+            set_skip_softmax_context(False)
+
+    stack.enter_context(_skip_softmax_flag())

     try:
         from ..kernels.diffusers_eager_attention import get_skip_softmax_attention_backend
```

Or simpler, wrap the flag toggle in a small context manager helper.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 389 - 401, The code sets the thread-wide flag via
set_skip_softmax_context(True) immediately which can leak if an exception occurs
before the returned ExitStack is entered; instead create a small context manager
(e.g., using contextlib.contextmanager or a tiny class) that calls
set_skip_softmax_context(True) on __enter__/enter and
set_skip_softmax_context(False) on __exit__/exit, and then register that context
with stack.enter_context rather than calling set_skip_softmax_context and
stack.callback directly; update the function that builds the stack (the block
using ExitStack, get_skip_softmax_attention_backend,
replace_function(torch.nn.functional, "softmax", sparse_softmax)) to enter the
new flag-context via stack.enter_context so the flag is only set when the stack
is actually entered and always cleaned up on exit.
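The suggested helper can be exercised end to end. This sketch substitutes a plain module-level flag for the real thread-local `set_skip_softmax_context`/`get_skip_softmax_context`, so it only illustrates the cleanup guarantee, not the actual module:

```python
from contextlib import ExitStack, contextmanager

_FLAG = {"skip_softmax": False}  # stand-in for the real thread-local state


def set_skip_softmax_context(value: bool) -> None:
    _FLAG["skip_softmax"] = value


def get_skip_softmax_context() -> bool:
    return _FLAG["skip_softmax"]


@contextmanager
def _skip_softmax_flag():
    set_skip_softmax_context(True)
    try:
        yield
    finally:
        # runs on any exit path, so the flag never leaks past the stack
        set_skip_softmax_context(False)


def build_sparse_softmax_stack() -> ExitStack:
    stack = ExitStack()
    stack.enter_context(_skip_softmax_flag())
    return stack
```

One caveat worth noting: `ExitStack.enter_context` enters the child context immediately, so the flag is still set as soon as the stack is built; the improvement is that set and reset now live in one `try/finally` tied to the stack's exit (or `close()`), which is what prevents the leak.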
```python
class TestSkipSoftmaxContext:
    def test_default_is_false(self):
        from modelopt.torch.sparsity.attention_sparsity.kernels import get_skip_softmax_context

        assert get_skip_softmax_context() is False

    def test_set_and_get(self):
        from modelopt.torch.sparsity.attention_sparsity.kernels import (
            get_skip_softmax_context,
            set_skip_softmax_context,
        )

        set_skip_softmax_context(True)
        assert get_skip_softmax_context() is True
        set_skip_softmax_context(False)
        assert get_skip_softmax_context() is False
```
Pipeline failure: transformers import triggered during test execution.
The pipeline failures show that importing from modelopt.torch.sparsity.attention_sparsity.kernels transitively imports transformers, which isn't available in the unit test environment. Unit tests should run on CPU without optional dependencies.
Proposed fix: Skip tests when dependencies unavailable

```diff
 class TestSkipSoftmaxContext:
+    @pytest.fixture(autouse=True)
+    def _skip_if_unavailable(self):
+        pytest.importorskip("transformers", reason="transformers required for attention_sparsity")
+
     def test_default_is_false(self):
```

Or mock the entire import chain more thoroughly to avoid triggering the transitive import.
🧰 Tools
🪛 GitHub Actions: Unit tests
[error] 58-58: Failed import while executing test_default_is_false due to missing dependency: ModuleNotFoundError: No module named 'transformers'
[error] 63-63: Failed import while executing test_set_and_get due to missing dependency: ModuleNotFoundError: No module named 'transformers'
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py` around
lines 56 - 72, The tests import
modelopt.torch.sparsity.attention_sparsity.kernels which transitively imports
transformers and breaks CI; update TestSkipSoftmaxContext to skip when optional
dependency missing by using pytest.importorskip('transformers') or catching
ImportError before importing get_skip_softmax_context/set_skip_softmax_context
(or call pytest.skip) so the test cleanly skips in environments without
transformers; ensure the changes are applied around the imports used in
TestSkipSoftmaxContext (references: get_skip_softmax_context,
set_skip_softmax_context, TestSkipSoftmaxContext).
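For reference, `pytest.importorskip` both imports and returns the module, so the same guard can also sit at module scope instead of in a fixture. A tiny sketch — the stdlib module `math` stands in for `transformers` here so the snippet runs anywhere:

```python
import pytest

# Raises pytest's Skipped exception when the module is missing, and returns
# the imported module when it is available.
math_mod = pytest.importorskip("math")
```

Placed at the top of a test file, this skips every test in the module in environments without the optional dependency.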
```python
class TestRegisterDiffusersBackends:
    def test_no_diffusers_no_error(self):
        from modelopt.torch.sparsity.attention_sparsity.conversion import (
            _register_diffusers_backends_if_needed,
        )

        _register_diffusers_backends_if_needed(nn.Linear(10, 10))

    def test_with_diffusers_model(self):
        from modelopt.torch.sparsity.attention_sparsity.conversion import (
            _register_diffusers_backends_if_needed,
        )

        mock_mixin = type("ModelMixin", (nn.Module,), {})
        mock_utils = types.ModuleType("diffusers.models.modeling_utils")
        mock_utils.ModelMixin = mock_mixin

        with (
            patch.dict(sys.modules, {"diffusers.models.modeling_utils": mock_utils}),
            patch(
                "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_eager_attention",
                MagicMock(),
            ) as mock_eager,
            patch(
                "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_triton_attention",
                MagicMock(),
            ) as mock_triton,
        ):
            _register_diffusers_backends_if_needed(mock_mixin())
            mock_eager.assert_called_once()
```
Pipeline failure: Module attribute error during patching.
The error module 'modelopt.torch.sparsity' has no attribute 'attention_sparsity' suggests the module isn't fully loaded when patching. The patch target path may need adjustment.
Proposed fix: Patch at the correct import location

The patch targets should reference where the functions are used, not where they're defined:

```diff
-            patch(
-                "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_eager_attention",
+            patch(
+                "modelopt.torch.sparsity.attention_sparsity.conversion.register_diffusers_eager_attention",
```

Or ensure the module is imported before patching:

```python
import modelopt.torch.sparsity.attention_sparsity.conversion  # Force load
```

🧰 Tools
import modelopt.torch.sparsity.attention_sparsity.conversion # Force load🧰 Tools
🪛 GitHub Actions: Unit tests
[error] 197-197: AttributeError: module 'modelopt.torch.sparsity' has no attribute 'attention_sparsity' (failure in test_with_diffusers_model while patching kernel registration functions)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py` around
lines 178 - 207, The test fails because patching targets under
"modelopt.torch.sparsity.attention_sparsity.kernels" occurs before that
submodule is loaded, causing a Module attribute error; to fix, ensure the module
is imported before patching or patch the symbols at the location they are looked
up by _register_diffusers_backends_if_needed: import
modelopt.torch.sparsity.attention_sparsity.conversion (or the parent package)
first, then patch the call targets register_diffusers_eager_attention and
register_diffusers_triton_attention as used by that conversion module (i.e.,
patch where _register_diffusers_backends_if_needed resolves them) so the
MagicMock replacement applies correctly.
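The "patch where it is looked up" rule can be demonstrated without the real packages; the `pkg_a`/`pkg_b` module names below are made up purely for the illustration:

```python
import sys
import types
from unittest.mock import patch

# pkg_b does `from pkg_a import f`, binding f into its own namespace at import time.
pkg_a = types.ModuleType("pkg_a")
pkg_a.f = lambda: "real"
sys.modules["pkg_a"] = pkg_a

pkg_b = types.ModuleType("pkg_b")
exec("from pkg_a import f\ndef call():\n    return f()", pkg_b.__dict__)
sys.modules["pkg_b"] = pkg_b

# Patching the name where it is *used* takes effect...
with patch("pkg_b.f", return_value="mocked"):
    assert pkg_b.call() == "mocked"

# ...while patching where it was *defined* does not reach pkg_b's bound copy.
with patch("pkg_a.f", return_value="mocked"):
    assert pkg_b.call() == "real"
```

The same reasoning applies to the test above: if `conversion.py` binds the registration functions via `from .kernels import ...`, the patch must target the `conversion` module's names.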
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)
308-317: ⚠️ Potential issue | 🟡 Minor

Improve the error message and consider early validation when a custom `forward_loop` conflicts with decode calibration.

When a user provides a custom `forward_loop` (e.g., for diffusion models) but also configures `calibrate_decode=True`, the current RuntimeError message "calibration_data and tokenizer must be built before decode" is confusing; it implies a bug rather than an unsupported configuration.

Consider either:
- Early validation (preferred): check at the start of calibration if `forward_loop is not None and calibrate_decode` and raise with a clear message, or
- Improve the error message to explicitly state the limitation.

Option 1: Add early validation near line 246

```diff
 # Skip if both phases are disabled
 if not calibrate_prefill and not calibrate_decode:
     print("Both prefill and decode target sparsity are 0.0, skipping calibration")
     return {}

+# Decode calibration requires RULER dataset, which is incompatible with custom forward_loop
+if forward_loop is not None and calibrate_decode:
+    raise ValueError(
+        "Decode calibration is not supported when a custom forward_loop is provided. "
+        "Either set decode target_sparse_ratio to 0.0 or remove the forward_loop argument "
+        "to use auto-generated RULER dataset calibration."
+    )
+
 # Get sparse attention modules
```

Option 2: Improve error message at lines 313-314

```diff
 if calibration_data is None or tokenizer is None:
-    raise RuntimeError("calibration_data and tokenizer must be built before decode")
+    raise RuntimeError(
+        "Decode calibration requires tokenizer and RULER dataset, which are not available "
+        "when a custom forward_loop is provided. Set decode target_sparse_ratio to 0.0 "
+        "or remove the forward_loop argument to enable decode calibration."
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around lines 308 - 317, Add an early validation that explicitly rejects using a custom forward_loop with decode calibration: if forward_loop is not None and calibrate_decode is True, raise a clear RuntimeError stating that decode calibration is incompatible with a custom forward_loop and instruct the user to disable calibrate_decode or remove the custom forward_loop; place this check near the start of the calibration flow (before create_decode_calibration_forward_loop is called). Alternatively, if you prefer the minimal change, improve the existing RuntimeError in the block that checks calibration_data and tokenizer to a message that either explains missing calibration_data/tokenizer or that decode calibration is unsupported when a custom forward_loop is provided (reference forward_loop, calibrate_decode, calibration_data, tokenizer, and create_decode_calibration_forward_loop).
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)
34-41: Consider restoring the return type annotation.

The lazy import is a good practice per project guidelines. However, removing the return type annotation reduces type safety. Consider adding it back using a string literal or a `TYPE_CHECKING` import to maintain mypy compatibility.

Suggested fix

```diff
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
-def _load_tokenizer(tokenizer_name_or_path: str):
+def _load_tokenizer(tokenizer_name_or_path: str) -> "PreTrainedTokenizerBase":
     """Load tokenizer and ensure pad_token is set."""
     from transformers import AutoTokenizer
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around lines 34 - 41, Restore the return type annotation for _load_tokenizer to preserve type safety; update the signature to annotate the return as "PreTrainedTokenizerBase" (a string literal) or import PreTrainedTokenizerBase under TYPE_CHECKING from transformers and use it as the return type, keeping the lazy import of AutoTokenizer inside the function and leaving runtime behavior unchanged; ensure the chosen symbol (PreTrainedTokenizerBase) is referenced in the annotation for mypy compatibility while avoiding importing transformers at module import time.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 405-443: Guard against empty per-phase payloads by selecting the
first non-empty calibration dict instead of blindly using
next(iter(calibration_params.values())). Replace sample_params =
next(iter(calibration_params.values())) with logic that finds the first
non-empty value: sample_params = next((v for v in calibration_params.values() if
v), None); set is_percentile only when sample_params is not None, and if
sample_params is None skip building threshold_config / threshold_scale_factor
and the per-phase loops (so export_config stays without threshold entries),
referencing the existing names calibration_params, sample_params, is_percentile,
export_config, threshold_config and threshold_scale_factor.
- Around line 125-139: The code currently suppresses all exceptions around
diffusers backend registration (the block checking isinstance(model, ModelMixin)
and the uses of contextlib.suppress(Exception)), which can hide real
registration errors; change the try/except to only catch ImportError when
importing ModelMixin and only suppress ImportError when calling
register_diffusers_eager_attention and register_diffusers_triton_attention so
that any other Exception raised by those registration functions bubbles up (or
is re-raised) instead of being swallowed; update the other places using
contextlib.suppress(Exception) to suppress ImportError only and ensure failures
in register_diffusers_* or in _set_attn_implementation are not silenced so the
error surfaces during conversion.
---
Outside diff comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 308-317: Add an early validation that explicitly rejects using a
custom forward_loop with decode calibration: if forward_loop is not None and
calibrate_decode is True, raise a clear RuntimeError stating that decode
calibration is incompatible with a custom forward_loop and instruct the user to
disable calibrate_decode or remove the custom forward_loop; place this check
near the start of the calibration flow (before
create_decode_calibration_forward_loop is called). Alternatively, if you prefer
the minimal change, improve the existing RuntimeError in the block that checks
calibration_data and tokenizer to a message that either explains missing
calibration_data/tokenizer or that decode calibration is unsupported when a
custom forward_loop is provided (reference forward_loop, calibrate_decode,
calibration_data, tokenizer, and create_decode_calibration_forward_loop).
---
Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 34-41: Restore the return type annotation for _load_tokenizer to
preserve type safety; update the signature to annotate the return as
"PreTrainedTokenizerBase" (a string literal) or import PreTrainedTokenizerBase
under TYPE_CHECKING from transformers and use it as the return type, keeping the
lazy import of AutoTokenizer inside the function and leaving runtime behavior
unchanged; ensure the chosen symbol (PreTrainedTokenizerBase) is referenced in
the annotation for mypy compatibility while avoiding importing transformers at
module import time.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 728d5b26-043d-44d2-a39e-cc945d944a89
📒 Files selected for processing (2)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
- modelopt/torch/sparsity/attention_sparsity/conversion.py
```python
try:
    from diffusers.models.modeling_utils import ModelMixin

    if isinstance(model, ModelMixin):
        from .kernels import (
            register_diffusers_eager_attention,
            register_diffusers_triton_attention,
        )

        if register_diffusers_eager_attention is not None:
            register_diffusers_eager_attention()
        if register_diffusers_triton_attention is not None:
            register_diffusers_triton_attention()
except (ImportError, Exception):
    pass
```
🧩 Analysis chain
🏁 Script executed:

```shell
# First, let's examine the file structure and the specific lines mentioned
cd /repo && head -150 modelopt/torch/sparsity/attention_sparsity/conversion.py | tail -40
```

Repository: NVIDIA/Model-Optimizer
Length of output: 121

🏁 Script executed:

```shell
# Get a fuller view of the section around lines 115-160
cd /repo && sed -n '115,160p' modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 121

🏁 Script executed:

```shell
# Also check the other segments mentioned in the scratchpad (405-444, 516-525, 553-554)
cd /repo && sed -n '400,450p' modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 121

🏁 Script executed:

```shell
# Check what exceptions could be raised during registration
cd /repo && sed -n '1,50p' modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 121

🏁 Script executed:

```shell
# Find the file in the repository
git ls-files | grep "conversion.py"
```

Repository: NVIDIA/Model-Optimizer
Length of output: 544

🏁 Script executed:

```shell
# Get file size and initial content
git ls-files | grep "conversion.py" | xargs wc -l
```

Repository: NVIDIA/Model-Optimizer
Length of output: 627

🏁 Script executed:

```shell
# Read the specific sections mentioned in the review
git ls-files | grep "conversion.py" | xargs sed -n '115,160p'
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1540

🏁 Script executed:

```shell
# Check the export format section (405-444)
git ls-files | grep "conversion.py" | xargs sed -n '400,450p'
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1726

🏁 Script executed:

```shell
# Read lines 115-160 from the attention_sparsity conversion.py file
sed -n '115,160p' modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1820

🏁 Script executed:

```shell
# Also get lines 405-450 for the export format logic
sed -n '405,450p' modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1717

🏁 Script executed:

```shell
# Check lines 144-154 which is also mentioned
sed -n '140,160p' modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 775

🏁 Script executed:

```shell
# Get the full function context starting from line 118
sed -n '118,160p' modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1730

🏁 Script executed:

```shell
# Verify where _register_diffusers_backends_if_needed is called
sed -n '160,180p' modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 701

🏁 Script executed:

```shell
# Check what _set_attn_implementation does
grep -n "_set_attn_implementation" modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 203

🏁 Script executed:

```shell
# See how the export_sparse_attention_config is called
grep -n "export_sparse_attention_config" modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 148

🏁 Script executed:

```shell
# Check what _set_attn_implementation does and what it expects
sed -n '35,70p' modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 2039

🏁 Script executed:

```shell
# Understand the broader context - what happens if backends aren't available
sed -n '35,100p' modelopt/torch/sparsity/attention_sparsity/conversion.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 3306
Don't swallow unexpected backend registration failures.
Lines 125-139 and 144-154 use overly broad exception suppression. Line 138 catches (ImportError, Exception), and lines 150 & 152 use contextlib.suppress(Exception). This masks registration bugs—if a backend registration fails unexpectedly, conversion still proceeds and _set_attn_implementation() will route the model to use eager or modelopt_triton with the backend never actually registered. Only suppress ImportError (optional dependency missing) and surface unexpected registration failures so they don't cause silent runtime failures.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 125 -
139, The code currently suppresses all exceptions around diffusers backend
registration (the block checking isinstance(model, ModelMixin) and the uses of
contextlib.suppress(Exception)), which can hide real registration errors; change
the try/except to only catch ImportError when importing ModelMixin and only
suppress ImportError when calling register_diffusers_eager_attention and
register_diffusers_triton_attention so that any other Exception raised by those
registration functions bubbles up (or is re-raised) instead of being swallowed;
update the other places using contextlib.suppress(Exception) to suppress
ImportError only and ensure failures in register_diffusers_* or in
_set_attn_implementation are not silenced so the error surfaces during
conversion.
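The narrowing this prompt asks for boils down to treating only a missing optional dependency as benign. A generic, hedged sketch of that shape — the helper name is illustrative, not code from the module:

```python
def try_register_backend(register_fn) -> bool:
    """Call a backend registration, suppressing only missing-dependency errors."""
    try:
        register_fn()
    except ImportError:
        return False  # optional dependency absent: skipping is expected
    # Any other exception is a real registration bug and must propagate.
    return True
```

With this shape, a `ValueError` raised inside `register_diffusers_eager_attention` surfaces during conversion instead of being silently swallowed, while a plain missing-diffusers install still no-ops.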
```diff
 # Detect calibration type from params
 sample_params = next(iter(calibration_params.values()))
 is_percentile = "threshold" in sample_params

 # Build the export config
 export_config: dict[str, Any] = {
     "config_groups": {
         "group_0": {
-            "sparse_algo": "softmax_skip",
+            "sparse_algo": "softmax_skip_diffusion" if is_percentile else "softmax_skip",
             "targets": sorted(target_classes) if target_classes else ["Attention"],
         }
     },
-    "threshold_scale_factor": threshold_scale_factor,
     "producer": {
         "name": "modelopt",
         "version": mo_version,
     },
 }

 if is_percentile:
     threshold_config: dict[str, Any] = {
         "formula": "skip if gap >= threshold * log(seq_k)",
     }
     for phase in ["prefill", "decode"]:
         if phase in calibration_params:
             threshold_config[phase] = {
                 "threshold": calibration_params[phase]["threshold"],
             }
     export_config["threshold_config"] = threshold_config
 else:
     threshold_scale_factor: dict[str, Any] = {
         "formula": "a * exp(b * target_sparsity)",
     }
     for phase in ["prefill", "decode"]:
         if phase in calibration_params:
             threshold_scale_factor[phase] = {
                 "a": calibration_params[phase]["a"],
                 "b": calibration_params[phase]["b"],
             }
     export_config["threshold_scale_factor"] = threshold_scale_factor
```
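As a side note on the exported formula `a * exp(b * target_sparsity)`: once `a` and `b` are calibrated per phase, resolving the runtime threshold scale factor is a one-liner. A minimal sketch; the coefficient values below are made up for illustration, not calibrated numbers:

```python
import math

def threshold_scale_factor(a: float, b: float, target_sparsity: float) -> float:
    """Evaluate the exported exponential model: a * exp(b * target_sparsity)."""
    return a * math.exp(b * target_sparsity)

# Hypothetical calibrated coefficients for the "prefill" phase.
prefill = {"a": 0.02, "b": 3.5}
scale = threshold_scale_factor(prefill["a"], prefill["b"], target_sparsity=0.5)
```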
Guard empty phase payloads before inferring the export format.
Line 406 assumes calibration_params already contains a populated per-phase dict. If the first module contributes {} or a placeholder phase entry, this now raises StopIteration or falls into the wrong branch and then hits KeyError on Lines 430/440 instead of honoring the documented None fallback for “no calibrated sparse attention modules”.
💡 Proposed fix
- sample_params = next(iter(calibration_params.values()))
+ sample_params = next((params for params in calibration_params.values() if params), None)
+ if sample_params is None:
+ return None
is_percentile = "threshold" in sample_params
@@
for phase in ["prefill", "decode"]:
- if phase in calibration_params:
+ phase_params = calibration_params.get(phase)
+ if phase_params and "threshold" in phase_params:
threshold_config[phase] = {
- "threshold": calibration_params[phase]["threshold"],
+ "threshold": phase_params["threshold"],
}
@@
for phase in ["prefill", "decode"]:
- if phase in calibration_params:
+ phase_params = calibration_params.get(phase)
+ if phase_params and all(k in phase_params for k in ("a", "b")):
threshold_scale_factor[phase] = {
- "a": calibration_params[phase]["a"],
- "b": calibration_params[phase]["b"],
+ "a": phase_params["a"],
+ "b": phase_params["b"],
+            }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 405 -
443, Guard against empty per-phase payloads by selecting the first non-empty
calibration dict instead of blindly using
next(iter(calibration_params.values())). Replace sample_params =
next(iter(calibration_params.values())) with logic that finds the first
non-empty value: sample_params = next((v for v in calibration_params.values() if
v), None); set is_percentile only when sample_params is not None, and if
sample_params is None skip building threshold_config / threshold_scale_factor
and the per-phase loops (so export_config stays without threshold entries),
referencing the existing names calibration_params, sample_params, is_percentile,
export_config, threshold_config and threshold_scale_factor.
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
♻️ Duplicate comments (1)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
125-154: ⚠️ Potential issue | 🟠 Major: Don't suppress all backend registration errors.
Line 138 and Lines 150-154 currently swallow arbitrary exceptions. That can hide real backend registration bugs and leave conversion proceeding with a partially configured runtime.
Proposed fix
```diff
-    except (ImportError, Exception):
+    except ImportError:
         pass
@@
-    except (ImportError, RuntimeError):
+    except ImportError:
         return
@@
-    if register_ltx_eager_attention is not None:
-        with contextlib.suppress(Exception):
-            register_ltx_eager_attention(model)
-    if register_ltx_triton_attention is not None:
-        with contextlib.suppress(Exception):
-            register_ltx_triton_attention(model)
+    if register_ltx_eager_attention is not None:
+        with contextlib.suppress(ImportError):
+            register_ltx_eager_attention(model)
+    if register_ltx_triton_attention is not None:
+        with contextlib.suppress(ImportError):
+            register_ltx_triton_attention(model)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 125 - 154, The code currently suppresses all exceptions when importing/calling backend registration functions (the try/except around ModelMixin and the contextlib.suppress usage for register_ltx_eager_attention/register_ltx_triton_attention), which hides real errors; change these to only catch expected import/runtime errors and surface or log unexpected exceptions: when importing from .kernels and checking isinstance(model, ModelMixin), catch only ImportError and RuntimeError, and when invoking register_diffusers_eager_attention, register_diffusers_triton_attention, register_ltx_eager_attention, and register_ltx_triton_attention, wrap each call in a try/except that logs the full exception (including stack trace) via the module logger and re-raises or returns a clear error instead of silently swallowing it so backend-registration failures are visible and conversion does not proceed silently with a misconfigured runtime.
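To see why narrowing the suppression matters: `contextlib.suppress(Exception)` hides every failure inside the block, while `contextlib.suppress(ImportError)` lets genuine registration bugs propagate. A toy demonstration; the function name is illustrative, not one of the real backend hooks:

```python
import contextlib

def broken_register(model: object) -> None:
    # Stands in for a backend registration hook with a real bug in it.
    raise RuntimeError("registration bug")

model = object()

# Too broad: the RuntimeError is silently swallowed.
with contextlib.suppress(Exception):
    broken_register(model)

# Narrow: only a missing optional dependency would be tolerated,
# so the actual bug surfaces instead of being hidden.
try:
    with contextlib.suppress(ImportError):
        broken_register(model)
except RuntimeError as err:
    surfaced = str(err)

assert surfaced == "registration bug"
```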
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 125-154: The code currently suppresses all exceptions when
importing/calling backend registration functions (the try/except around
ModelMixin and the contextlib.suppress usage for
register_ltx_eager_attention/register_ltx_triton_attention), which hides real
errors; change these to only catch expected import/runtime errors and surface or
log unexpected exceptions: when importing from .kernels and checking
isinstance(model, ModelMixin), catch only ImportError and RuntimeError, and when
invoking register_diffusers_eager_attention,
register_diffusers_triton_attention, register_ltx_eager_attention, and
register_ltx_triton_attention, wrap each call in a try/except that logs the full
exception (including stack trace) via the module logger and re-raises or returns
a clear error instead of silently swallowing it so backend-registration failures
are visible and conversion does not proceed silently with a misconfigured
runtime.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 09ecbcc7-b977-4605-a5e2-b4e013bcaee8
📒 Files selected for processing (1)
modelopt/torch/sparsity/attention_sparsity/conversion.py
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 198-206: The calibration loop in build_calibration_forward_loop
currently hardcodes guidance_scale=5.0 so calibration activations ignore the CLI
--guidance-scale; add a guidance_scale parameter to
build_calibration_forward_loop (and the other calibration-related functions
mentioned) and pass that guidance_scale through to any calls that currently use
guidance_scale=5.0 (look for explicit guidance_scale=5.0 in the function body
and replace with the new guidance_scale parameter), ensuring the same parameter
is threaded into the calibration runs so activation collection honors the
user-specified guidance scale.
- Around line 177-183: The calibration block is being added at the top-level
sparse_cfg (making "calibration" a selector) instead of being nested under the
self-attention selector; update the code so that when args.calibrate is true you
insert the calibration dict (using args.target_sparsity,
DEFAULT_THRESHOLD_TRIALS and samples:1) into the "*.attn1*" entry of sparse_cfg
(i.e., sparse_cfg["*.attn1*"]["calibration"] = {...}) so the attn1 selector
receives the calibration config rather than creating a new selector at the top
level.
- Around line 77-135: The CLI allows invalid values that later break
calibration; update parse_args() to validate arguments after
parser.parse_args(): ensure args.num_frames and args.calib_frames satisfy (value
- 1) % 4 == 0 and are > 1 (i.e., 4k+1), ensure args.target_sparsity is between
0.0 and 1.0 inclusive, and if a check fails call parser.error(...) with a clear
message referencing the offending flag (e.g., --num-frames, --calib-frames,
--target-sparsity) so users get immediate, actionable feedback; implement these
checks in parse_args() using the parsed args variables.
- Around line 192-193: The code currently loads the entire "caption" column then
slices it; change the dataset load to request only the needed rows by using the
HuggingFace split range: replace load_dataset("nkp37/OpenVid-1M", split="train")
with load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]") and then
build prompts from the returned small dataset (e.g., prompts =
list(dataset["caption"])). This avoids materializing the full caption column for
the tiny calibration sample.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a396d14c-ebef-49a3-aeb9-c3d398add26a
📒 Files selected for processing (2)
examples/diffusers/README.md
examples/diffusers/sparsity/wan22_skip_softmax.py
✅ Files skipped from review due to trivial changes (1)
- examples/diffusers/README.md
```python
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Wan 2.2 video generation with skip-softmax sparse attention"
    )
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
    parser.add_argument("--output", type=str, default="output.mp4", help="Output video path")
    parser.add_argument(
        "--model-path", type=str, default=DEFAULT_MODEL_PATH, help="Wan 2.2 model path or HF ID"
    )
    parser.add_argument(
        "--num-frames", type=int, default=81, help="Number of frames (must be 4k+1)"
    )
    parser.add_argument("--height", type=int, default=480, help="Video height")
    parser.add_argument("--width", type=int, default=832, help="Video width")
    parser.add_argument("--num-steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument(
        "--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # Sparse attention options
    parser.add_argument(
        "--skip-first-last",
        type=int,
        default=2,
        help="Number of first/last transformer layers to keep dense (default: 2)",
    )

    # Calibration options
    parser.add_argument(
        "--calibrate",
        action="store_true",
        help="Calibrate threshold via exponential model (recommended)",
    )
    parser.add_argument(
        "--target-sparsity",
        type=float,
        default=0.5,
        help="Target sparsity ratio for calibration (0.0-1.0)",
    )
    parser.add_argument(
        "--calib-steps",
        type=int,
        default=40,
        help="Inference steps for calibration",
    )
    parser.add_argument(
        "--calib-frames",
        type=int,
        default=151,
        help="Number of frames for calibration",
    )
    parser.add_argument(
        "--calib-size",
        type=int,
        default=4,
        help="Number of calibration prompts from OpenVid-1M dataset",
    )
    return parser.parse_args()
```
Validate the documented CLI constraints in parse_args().
The help text already says --num-frames / --calib-frames must be 4k+1 and --target-sparsity must be in [0, 1], but invalid values are accepted and will fail later inside diffusers/calibration with much less actionable errors.
Proposed fix
- return parser.parse_args()
+ args = parser.parse_args()
+ if args.num_frames < 1 or args.num_frames % 4 != 1:
+ parser.error("--num-frames must be of the form 4k+1")
+ if args.skip_first_last < 0:
+ parser.error("--skip-first-last must be >= 0")
+ if args.calibrate:
+ if args.calib_frames < 1 or args.calib_frames % 4 != 1:
+ parser.error("--calib-frames must be of the form 4k+1")
+ if not 0.0 <= args.target_sparsity <= 1.0:
+ parser.error("--target-sparsity must be between 0.0 and 1.0")
+ if args.calib_size < 1:
+ parser.error("--calib-size must be >= 1")
+ return args🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 77 - 135, The
CLI allows invalid values that later break calibration; update parse_args() to
validate arguments after parser.parse_args(): ensure args.num_frames and
args.calib_frames satisfy (value - 1) % 4 == 0 and are > 1 (i.e., 4k+1), ensure
args.target_sparsity is between 0.0 and 1.0 inclusive, and if a check fails call
parser.error(...) with a clear message referencing the offending flag (e.g.,
--num-frames, --calib-frames, --target-sparsity) so users get immediate,
actionable feedback; implement these checks in parse_args() using the parsed
args variables.
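The proposed checks can be exercised without spawning a process: `parser.error()` raises `SystemExit`, so the 4k+1 rule is easy to verify directly. A minimal sketch; the flag names mirror the script, but the trimmed-down parser itself is illustrative:

```python
import argparse

def parse_args(argv: list[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-frames", type=int, default=81)
    parser.add_argument("--target-sparsity", type=float, default=0.5)
    args = parser.parse_args(argv)
    # Validate after parsing; parser.error() prints a usage message and exits with code 2.
    if args.num_frames < 1 or args.num_frames % 4 != 1:
        parser.error("--num-frames must be of the form 4k+1")
    if not 0.0 <= args.target_sparsity <= 1.0:
        parser.error("--target-sparsity must be between 0.0 and 1.0")
    return args

assert parse_args(["--num-frames", "81"]).num_frames == 81
try:
    parse_args(["--num-frames", "80"])  # not 4k+1, should be rejected
    rejected = False
except SystemExit:
    rejected = True
assert rejected
```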
```python
# Add calibration config with threshold trials
if args.calibrate:
    sparse_cfg["calibration"] = {
        "target_sparse_ratio": {"prefill": args.target_sparsity},
        "samples": 1,
        "threshold_trials": DEFAULT_THRESHOLD_TRIALS,
    }
```
Attach calibration to the *.attn1* config, not to sparse_cfg.
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py describes sparse_cfg as a selector → attributes map. Putting "calibration" beside "*.attn1*" makes it another selector, so the self-attention entries never receive the calibration block.
Proposed fix
# Add calibration config with threshold trials
if args.calibrate:
- sparse_cfg["calibration"] = {
+ attn_cfg["calibration"] = {
"target_sparse_ratio": {"prefill": args.target_sparsity},
"samples": 1,
"threshold_trials": DEFAULT_THRESHOLD_TRIALS,
     }
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
 # Add calibration config with threshold trials
 if args.calibrate:
-    sparse_cfg["calibration"] = {
+    attn_cfg["calibration"] = {
         "target_sparse_ratio": {"prefill": args.target_sparsity},
         "samples": 1,
         "threshold_trials": DEFAULT_THRESHOLD_TRIALS,
     }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 177 - 183,
The calibration block is being added at the top-level sparse_cfg (making
"calibration" a selector) instead of being nested under the self-attention
selector; update the code so that when args.calibrate is true you insert the
calibration dict (using args.target_sparsity, DEFAULT_THRESHOLD_TRIALS and
samples:1) into the "*.attn1*" entry of sparse_cfg (i.e.,
sparse_cfg["*.attn1*"]["calibration"] = {...}) so the attn1 selector receives
the calibration config rather than creating a new selector at the top level.
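The distinction between the two placements is easy to see on plain dicts. A toy sketch of the selector-to-attributes map; the attribute values are placeholders, only the nesting matters:

```python
DEFAULT_THRESHOLD_TRIALS = [0.01, 0.02, 0.05]  # placeholder trial values
target_sparsity = 0.5

attn_cfg = {"method": "flash_skip_softmax"}  # attributes for the attn1 selector
sparse_cfg = {"*.attn1*": attn_cfg}

# Wrong placement would be sparse_cfg["calibration"] = {...}: "calibration"
# then reads as a sibling selector pattern that matches no module.

# Right: nest calibration under the self-attention selector's attributes.
attn_cfg["calibration"] = {
    "target_sparse_ratio": {"prefill": target_sparsity},
    "samples": 1,
    "threshold_trials": DEFAULT_THRESHOLD_TRIALS,
}

assert "calibration" in sparse_cfg["*.attn1*"]
assert "calibration" not in sparse_cfg  # no stray top-level selector
```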
```python
dataset = load_dataset("nkp37/OpenVid-1M", split="train")
prompts = list(dataset["caption"][:calib_size])
```
🧩 Analysis chain
🏁 Script executed:
```shell
# First, find and examine the file
fd -t f "wan22_skip_softmax.py" && wc -l examples/diffusers/sparsity/wan22_skip_softmax.py
```
Repository: NVIDIA/Model-Optimizer
Length of output: 169
🏁 Script executed:
```shell
# Read the specific lines around 192-193
sed -n '185,200p' examples/diffusers/sparsity/wan22_skip_softmax.py
```
Repository: NVIDIA/Model-Optimizer
Length of output: 529
🌐 Web query:
huggingface datasets load_dataset indexing behavior [:n] vs split parameter
💡 Result:
- `load_dataset(..., split="train[:N]")` (or `"train[10:20]"`, `"train[:10%]"`, etc.) uses the library's split slicing / `ReadInstruction` mechanism at load time and returns a `datasets.Dataset` that contains only the requested subsplit. Examples of this string API are documented (e.g. `split="train[10:20]"`, `split="train[:10%]"`). [1]
- `ds[:N]` after you already loaded a split is Python-style indexing via `Dataset.__getitem__`. For a slice like `[:100]`, it returns a batch as a plain Python `dict` of columns, not a new `Dataset` object; this commonly surprises users. The recommended way to get a smaller `Dataset` is `ds.select(range(N))`. [2]
- Streaming caveat: when `streaming=True`, the `"train[:N]"` split-slicing syntax is (historically) not supported. [3]

Rule of thumb
- Want a `Dataset` object subset: use `split="train[:N]"` at load time or `ds.select(range(N))` after loading.
- Want an in-memory batch of examples: use `ds[:N]`.
Sources: [1] https://huggingface.co/docs/datasets/v1.10.2/splits.html • [2] https://discuss.huggingface.co/t/how-to-slice-an-already-loaded-dataset/27946 • [3] huggingface/datasets#5458
Avoid reading the entire OpenVid caption column for a tiny calibration sample.
dataset["caption"][:calib_size] loads the full caption column before slicing, which for OpenVid-1M means reading all captions to extract just 4 prompts. Use the split syntax instead to load only the needed rows at dataset load time.
Proposed fix
- dataset = load_dataset("nkp37/OpenVid-1M", split="train")
- prompts = list(dataset["caption"][:calib_size])
+ dataset = load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]")
+ prompts = list(dataset["caption"])
📝 Committable suggestion
```diff
-dataset = load_dataset("nkp37/OpenVid-1M", split="train")
-prompts = list(dataset["caption"][:calib_size])
+dataset = load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]")
+prompts = list(dataset["caption"])
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 192 - 193,
The code currently loads the entire "caption" column then slices it; change the
dataset load to request only the needed rows by using the HuggingFace split
range: replace load_dataset("nkp37/OpenVid-1M", split="train") with
load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]") and then build
prompts from the returned small dataset (e.g., prompts =
list(dataset["caption"])). This avoids materializing the full caption column for
the tiny calibration sample.
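The cost difference is purely about when the slicing happens. A dependency-free toy that mimics a column read; this is a stand-in model of the behavior described above, not the real `datasets` API:

```python
class ToyDataset:
    """Mimics a dataset whose column access materializes every row."""

    def __init__(self, rows: list[str]):
        self._rows = rows
        self.rows_read = 0

    def column(self) -> list[str]:
        self.rows_read += len(self._rows)  # full column materialized on access
        return list(self._rows)

    def select(self, n: int) -> "ToyDataset":
        return ToyDataset(self._rows[:n])  # subset chosen before any read

captions = [f"caption {i}" for i in range(1000)]

full = ToyDataset(captions)
_ = full.column()[:4]          # slices AFTER reading every row
assert full.rows_read == 1000

small = ToyDataset(captions).select(4)
prompts = small.column()       # reads only the 4 selected rows
assert small.rows_read == 4
assert prompts == ["caption 0", "caption 1", "caption 2", "caption 3"]
```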
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
can we have a tiny wan2.2 model tested in tests/examples/diffusers for this file?
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
What does this PR do?
Type of change: new feature, new example
Summary
- `flash_skip_softmax` with exponential model calibration (`scale_factor = a * exp(b * sparsity)`)
- Eager backend (`F.softmax` patching) works on diffusion models that normally use `scaled_dot_product_attention`
- Calibration driven by a user-provided `forward_loop` (required for non-LLM models)
Changes
- `diffusers_triton_attention.py`, `diffusers_eager_attention.py`, `ltx_triton_attention.py`, `ltx_eager_attention.py`: route diffusers/LTX attention through explicit `F.softmax` for calibration
- `kernels/__init__.py`: Thread-local context management, lazy imports for diffusers/LTX backends
- `conversion.py`: Auto-register diffusers backends on `sparsify()`, updated export config and summary
- `calibrate.py`: Skip RULER dataset when `forward_loop` is provided (enables diffusion model calibration)
- `flash_skip_softmax.py`: Enhanced context manager activates diffusers eager backend
- `plugins/huggingface.py`: Support diffusers `ModelMixin` in model detection
- New examples: `ltx2_skip_softmax.py`, `wan22_skip_softmax.py`
Usage
Example scripts
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- CONTRIBUTING.md: ✅ / ❌ / N/A
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Documentation
Tests