
Add the Skip softmax for diffusion#1166

Open
jingyu-ml wants to merge 14 commits into main from jingyux/diffusion-skip-softmax

Conversation

@jingyu-ml
Contributor

@jingyu-ml jingyu-ml commented Apr 2, 2026

What does this PR do?

Type of change: new feature, new example

Summary

  • Add skip-softmax sparse attention support for diffusion models (LTX-2, Wan 2.2) using flash_skip_softmax with exponential model calibration (scale_factor = a * exp(b * sparsity))
  • Add diffusers/LTX kernel backends so that eager attention (with F.softmax patching) works on diffusion models that normally use scaled_dot_product_attention
  • Fix calibration to skip RULER dataset generation when user provides their own forward_loop (required for non-LLM models)
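
The exponential calibration model above maps a target sparsity ratio to a softmax scale factor. As a rough illustration of how such a model could be fitted (a log-linear least-squares sketch; the function names here are illustrative, not taken from this PR):

```python
import math

def fit_exp_model(sparsities, scale_factors):
    """Fit scale_factor = a * exp(b * sparsity) via least squares in log space.

    Illustrative sketch only; the PR's actual calibration code may differ.
    """
    n = len(sparsities)
    xs = sparsities
    ys = [math.log(s) for s in scale_factors]  # linearize: log(scale) = log(a) + b * sparsity
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    b = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    a = math.exp(mean_y - b * mean_x)
    return a, b

def predict_scale(a, b, sparsity):
    """Evaluate the calibrated model: scale_factor = a * exp(b * sparsity)."""
    return a * math.exp(b * sparsity)
```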

Changes

  • New kernel backends: diffusers_triton_attention.py, diffusers_eager_attention.py, ltx_triton_attention.py, ltx_eager_attention.py — route diffusers/LTX attention through explicit F.softmax for calibration
  • kernels/__init__.py: Thread-local context management, lazy imports for diffusers/LTX backends
  • conversion.py: Auto-register diffusers backends on sparsify(), updated export config and summary
  • calibrate.py: Skip RULER dataset when forward_loop is provided (enables diffusion model calibration)
  • flash_skip_softmax.py: Enhanced context manager activates diffusers eager backend
  • plugins/huggingface.py: Support diffusers ModelMixin in model detection
  • Example scripts: ltx2_skip_softmax.py, wan22_skip_softmax.py
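
The thread-local context management mentioned for kernels/__init__.py can be approximated as follows (a minimal sketch; the real helpers may carry more state than a single flag):

```python
import threading

# Per-thread storage so concurrent calibration/inference threads
# cannot observe each other's skip-softmax flag.
_context = threading.local()

def set_skip_softmax_context(enabled: bool) -> None:
    """Enable or disable skip-softmax for the current thread."""
    _context.skip_softmax = enabled

def get_skip_softmax_context() -> bool:
    """Return the current thread's flag; defaults to False if never set."""
    return getattr(_context, "skip_softmax", False)
```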

Usage

import modelopt.torch.sparsity.attention_sparsity as mtsa

# 1. Build your diffusion pipeline and get the transformer
transformer = pipeline.transformer  # or pipeline.stage_1_model_ledger.transformer()

# 2. Define sparse config
config = {
    "sparse_cfg": {
        "calibration": {
            "target_sparse_ratio": {"prefill": 0.25},
            "threshold_trials": [1e-6, 5e-6, 1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3,
                                 1e-2, 2e-2, 5e-2, 1e-1, 2e-1, 3e-1, 5e-1, 7e-1],
        },
        "*.attn1": {
            "method": "flash_skip_softmax",
            "thresholds": {"prefill": [1e-3]},
            "br": 128, "bc": 128,
            "backend": "pytorch",
            "is_causal": False,
            "collect_stats": True,
            "enable": True,
        },
        "*.attn2": {"enable": False},      # skip cross-attention
        "default": {"enable": False},
    },
}

# 3. Define a calibration forward loop (runs the diffusion pipeline)
def forward_loop(model):
    pipeline(prompt="a cat", num_frames=81, num_inference_steps=40, ...)

# 4. Sparsify + calibrate
mtsa.sparsify(transformer, config, forward_loop=forward_loop)

# 5. Generate as usual — sparsity is applied automatically
output = pipeline(prompt="a dog on the beach", ...)

Example scripts

# LTX-2 with 25% sparsity, skip first/last 3 layers
python examples/diffusers/sparsity/ltx2_skip_softmax.py \
    --prompt "A cat playing piano" --output out.mp4 \
    --calibrate --target-sparsity 0.25 --skip-first-last 3

# Wan 2.2 with 25% sparsity
python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --prompt "A sunset over mountains" --output out.mp4 \
    --calibrate --target-sparsity 0.25

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added skip-softmax sparse attention support for Diffusers models, enabling efficient video generation
    • Added support for both eager and Triton attention backends for sparse attention
    • Added new example script for Wan 2.2 text-to-video generation with sparse attention optimization
  • Documentation

    • Updated documentation with sparse attention configuration guide and usage examples
  • Tests

    • Added comprehensive unit tests for kernel backend registration and skip-softmax functionality

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested review from a team as code owners April 2, 2026 06:02
@jingyu-ml jingyu-ml requested a review from kaix-nv April 2, 2026 06:02
@jingyu-ml jingyu-ml marked this pull request as draft April 2, 2026 06:02
@coderabbitai
Contributor

coderabbitai bot commented Apr 2, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

This PR adds skip-softmax sparse attention support for Diffusers and LTX-2 models, including new eager and Triton kernel backends, calibration refinements, example scripts, and comprehensive tests for framework-specific attention implementations.

Changes

  • Example Script & Documentation (examples/diffusers/sparsity/wan22_skip_softmax.py, examples/diffusers/README.md): New executable example for WAN 2.2 video generation using skip-softmax sparse attention, with CLI argument parsing, calibration support, and sparsity summary reporting. README updated with a sparse attention section and example script instructions.
  • Calibration & Conversion Core (modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py, modelopt/torch/sparsity/attention_sparsity/conversion.py): Modified calibrate_sparse_attention() to defer tokenizer/dataset generation when forward_loop is provided. Added _register_diffusers_backends_if_needed() to conditionally register Diffusers backends and patch LTX modules. Updated print_sparse_attention_summary() to skip disabled modules when computing sparsity counts.
  • Diffusers Kernel Backends (modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py, modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py): New eager attention backend implementing scaled dot-product with a softmax interception point. New Triton backend reshaping the Diffusers layout to varlen format with optional skip-softmax threshold support. Both include idempotent backend registration and context managers.
  • LTX Kernel Backends (modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py, modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py): New eager wrapper for LTX-2 attention with skip-softmax context detection. New Triton backend reshaping the LTX fused-head layout to varlen format with skip-softmax threshold support. Both include thread-local configuration and idempotent module wrapping.
  • Kernel Infrastructure (modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py): Added thread-local skip-softmax context helpers (set_skip_softmax_context, get_skip_softmax_context) and optional backend registration symbols with conditional imports for Diffusers and LTX backends.
  • Sparse Attention Methods (modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py, modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py): Updated flash_skip_softmax get_sparse_context() to toggle the skip-softmax context flag and conditionally enter the Diffusers eager backend context. Added calculate_sparsity() and apply_sparsity() to TritonSkipSoftmaxMethod with an explicit NotImplementedError for Python-path sparsity.
  • Plugin System & Infrastructure (modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py, modelopt/torch/sparsity/attention_sparsity/stats_manager.py): Deferred the transformers import in _is_supported_model() and added Diffusers ModelMixin detection. Modified stats collection to conditionally extend sample stats with "normalized_gaps" when present in incoming statistics.
  • Unit Tests (tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py): New comprehensive test module covering thread-local context behavior, eager/Triton backend registration idempotence, shape validation across attention scenarios (basic, cross-attention, causal, GQA), and diffusers backend registration via mocked dependencies.
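
The eager backends exist so that attention probabilities pass through an explicit F.softmax call, which calibration can patch; the fused F.scaled_dot_product_attention kernel offers no such hook. A minimal sketch of that pattern (not the PR's actual backend code):

```python
import torch
import torch.nn.functional as F

def eager_attention(query, key, value, scale=None):
    """Scaled dot-product attention with an explicit softmax.

    Inputs are [batch, heads, seq, head_dim]. Because F.softmax is a
    separate op here, replacing torch.nn.functional.softmax during
    calibration lets the skip-softmax method observe (and sparsify)
    the attention probabilities.
    """
    if scale is None:
        scale = query.shape[-1] ** -0.5  # same default as the fused kernel
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    probs = F.softmax(scores, dim=-1)  # interception point
    return torch.matmul(probs, value)
```

Numerically this matches the fused kernel on the same inputs, which is why swapping it in for calibration is safe.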

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant Pipeline as WAN Pipeline
    participant Sparsify as mtsa.sparsify()
    participant Transformer as Transformer<br/>Modules
    participant CalibLoop as Calibration<br/>Forward Loop
    participant Config as Sparse<br/>Config

    User->>Pipeline: build_pipeline(model_path)
    activate Pipeline
    Pipeline-->>User: pipeline ready
    deactivate Pipeline

    User->>Config: build_sparse_config(args)
    activate Config
    Config-->>User: sparse config dict
    deactivate Config

    alt Calibration Mode
        User->>CalibLoop: build_calibration_forward_loop()
        activate CalibLoop
        CalibLoop-->>User: forward_loop callable
        deactivate CalibLoop
        
        User->>Sparsify: sparsify(transformer, config,<br/>forward_loop=...)
        activate Sparsify
        Sparsify->>CalibLoop: invoke forward_loop<br/>(multiple prompts)
        CalibLoop->>Transformer: collect attention stats
        Transformer-->>CalibLoop: attention outputs
        Sparsify->>Transformer: apply sparse config
        Transformer-->>Sparsify: sparse attention modules
        Sparsify-->>User: sparsified transformer
        deactivate Sparsify
    else No Calibration
        User->>Sparsify: sparsify(transformer, config)
        activate Sparsify
        Sparsify->>Transformer: apply sparse config
        Transformer-->>Sparsify: sparse attention modules
        Sparsify-->>User: sparsified transformer
        deactivate Sparsify
    end

    User->>Pipeline: generate(prompt, ...)
    activate Pipeline
    Pipeline->>Transformer: forward with sparse<br/>attention
    Transformer-->>Pipeline: output frames
    deactivate Pipeline
    
    User->>User: print_sparsity_summary(model)
    activate User
    User->>Transformer: enumerate SparseAttentionModule
    Transformer-->>User: module configs
    deactivate User
sequenceDiagram
    participant Model as Model<br/>(Diffusers/LTX)
    participant ConversionFn as convert_to_sparse_<br/>attention_model()
    participant RegFn as _register_diffusers_<br/>backends_if_needed()
    participant DiffusersBackends as Diffusers<br/>Backends
    participant LTXBackends as LTX<br/>Backends
    participant Method as Sparse<br/>Method

    Model->>ConversionFn: convert_to_sparse_attention_model(model, ...)
    activate ConversionFn
    ConversionFn->>RegFn: _register_diffusers_backends_if_needed(model)
    activate RegFn
    
    alt Is Diffusers ModelMixin
        RegFn->>DiffusersBackends: register_diffusers_eager_attention()
        RegFn->>DiffusersBackends: register_diffusers_triton_attention()
        DiffusersBackends-->>RegFn: backends registered
    end
    
    alt Has LTX modules
        RegFn->>LTXBackends: patch ltx attention<br/>modules
        LTXBackends-->>RegFn: wrappers installed
    end
    
    RegFn-->>ConversionFn: registration complete
    deactivate RegFn
    
    ConversionFn->>ConversionFn: _set_attn_implementation()
    ConversionFn->>Method: apply sparse config
    Method-->>ConversionFn: sparse model
    ConversionFn-->>Model: sparse model ready
    deactivate ConversionFn

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

  • PR #1078: Adds and wires up Triton-based N:M sparse softmax support affecting the Triton flash-attention path and sparse-attention kernel integration.

Suggested reviewers

  • Edwardf0t1
  • cjluo-nv
🚥 Pre-merge checks | ✅ 2 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 61.11%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Title check: ❓ Inconclusive. The title 'Add the Skip softmax for diffusion' is vague and only partially related to the main changeset. It mentions skip-softmax for diffusion but lacks specificity about what is being added (kernel backends, calibration support, example scripts). Consider revising to something more specific, such as 'Add skip-softmax sparse attention support for diffusion models' or 'Implement skip-softmax sparse attention with diffusers/LTX backends and calibration'.

✅ Passed checks (2 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Security Anti-Patterns: ✅ Passed. Comprehensive security analysis confirms no critical security anti-patterns in the pull request changes.


Comment @coderabbitai help to get the list of available commands and usage tips.

@copy-pr-bot

copy-pr-bot bot commented Apr 2, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@jingyu-ml jingyu-ml self-assigned this Apr 2, 2026
@github-actions
Contributor

github-actions bot commented Apr 2, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1166/

Built to branch gh-pages at 2026-04-07 20:24 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov bot commented Apr 2, 2026

Codecov Report

❌ Patch coverage is 43.80342% with 263 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.51%. Comparing base (df80a0f) to head (fbeabcf).
⚠️ Report is 10 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/kernels/triton_fa.py 5.40% 70 Missing ⚠️
...attention_sparsity/kernels/ltx_triton_attention.py 5.71% 66 Missing ⚠️
.../attention_sparsity/methods/triton_skip_softmax.py 53.75% 37 Missing ⚠️
...ion_sparsity/kernels/diffusers_triton_attention.py 50.00% 35 Missing ⚠️
.../attention_sparsity/kernels/ltx_eager_attention.py 11.11% 32 Missing ⚠️
...pt/torch/sparsity/attention_sparsity/conversion.py 64.00% 9 Missing ⚠️
...arsity/attention_sparsity/calibration/calibrate.py 36.36% 7 Missing ⚠️
...tion_sparsity/kernels/diffusers_eager_attention.py 93.61% 3 Missing ⚠️
...rsity/attention_sparsity/calibration/calibrator.py 75.00% 1 Missing ⚠️
...ch/sparsity/attention_sparsity/methods/registry.py 50.00% 1 Missing ⚠️
... and 2 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1166      +/-   ##
==========================================
- Coverage   74.77%   73.51%   -1.26%     
==========================================
  Files         351      356       +5     
  Lines       40072    43723    +3651     
==========================================
+ Hits        29964    32145    +2181     
- Misses      10108    11578    +1470     
Flag Coverage Δ
examples 39.62% <4.05%> (-0.60%) ⬇️
gpu 56.63% <29.70%> (-0.48%) ⬇️
unit 54.81% <34.18%> (+0.06%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml force-pushed the jingyux/diffusion-skip-softmax branch from 8151232 to 5873652 Compare April 2, 2026 08:38
jingyu-ml and others added 2 commits April 2, 2026 21:29
@jingyu-ml jingyu-ml marked this pull request as ready for review April 3, 2026 06:15
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (2)

306-316: ⚠️ Potential issue | 🟠 Major

Decode calibration fails when forward_loop is provided.

When a user supplies forward_loop, lines 264-265 skip building tokenizer and calibration_data. However, decode calibration (lines 312-316) unconditionally requires both, raising RuntimeError even though the user intended to use their own loop.

This creates an inconsistency: prefill calibration supports user-provided forward_loop, but decode calibration does not. The docstring (line 227) also states forward_loop is "Only used for prefill", but this limitation should either be enforced earlier or decode should also accept a custom loop.

💡 Suggested approach

Either:

  1. Skip decode calibration when forward_loop is provided and calibration_data is None, with a warning
  2. Accept a separate decode_forward_loop parameter
  3. Document and enforce that decode calibration requires RULER dataset

 # Run decode calibration if enabled
 if calibrate_decode:
+    if calibration_data is None or tokenizer is None:
+        warnings.warn(
+            "Decode calibration requires RULER dataset. Skipping decode calibration "
+            "because a custom forward_loop was provided without calibration_data."
+        )
+    else:
         print("\n" + "=" * 60)
         # ... rest of decode calibration
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 306 - 316, The decode calibration block (calibrate_decode) currently
raises RuntimeError if calibration_data or tokenizer are missing even when the
user supplied forward_loop; change the logic in the calibrate_decode section to
detect when forward_loop is provided and calibration_data is None and skip
decode calibration with a warning instead of raising, i.e., only call
create_decode_calibration_forward_loop when calibration_data and tokenizer exist
(use create_decode_calibration_forward_loop(calibration_data, tokenizer, ...)),
otherwise log/warn that decode calibration is skipped due to missing
calibration_data/tokenizer while a custom forward_loop was supplied; update any
related docstring or comment near the calibrate_decode and forward_loop mentions
to reflect this behavior.

24-24: ⚠️ Potential issue | 🔴 Critical

Unconditional transformers import causes pipeline failure.

The module-level import of transformers.AutoTokenizer fails when transformers is not installed. This should be deferred to usage sites (inside _load_tokenizer or guarded) to allow the module to be imported when only diffusers-based workflows are used.

🐛 Proposed fix: defer import to usage site
-from transformers import AutoTokenizer

Then update _load_tokenizer:

 def _load_tokenizer(tokenizer_name_or_path: str) -> "AutoTokenizer":
     """Load tokenizer and ensure pad_token is set."""
+    from transformers import AutoTokenizer
+
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` at line
24, The file currently imports transformers.AutoTokenizer at module scope which
raises ImportError when transformers isn't installed; move the import into the
tokenizer-loading code path so the module can be imported without transformers.
Specifically, remove the top-level "from transformers import AutoTokenizer" and
instead import AutoTokenizer inside the _load_tokenizer function (or guard the
import with a try/except that raises a clear error when the function is called),
ensuring _load_tokenizer handles the absence of transformers and only then
attempts to create the tokenizer.
♻️ Duplicate comments (1)
modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py (1)

30-36: ⚠️ Potential issue | 🟠 Major

Same top-level import issue as the eager backend.

Both diffusers and modelopt.torch.kernels are imported unconditionally at the top level. This will cause import failures for users who don't have diffusers installed or don't have CUDA/Triton available.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`
around lines 30 - 36, Top-level unconditional imports of diffusers symbols
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) and
modelopt.torch.kernels.attention can fail for users without diffusers or
CUDA/Triton; move these imports into the function or class that actually uses
them (e.g., inside the registration function or the backend implementation in
diffusers_triton_attention.py) and guard with try/except ImportError to raise a
clear error only when the backend is instantiated; ensure you reference the same
symbols (AttentionBackendName, _AttentionBackendRegistry, attention_backend, and
attention) after relocating the imports so registration only occurs when
dependencies are present.
🧹 Nitpick comments (4)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)

124-125: Overly broad exception handling may hide bugs.

except (ImportError, Exception) catches all exceptions including programming errors (e.g., TypeError, AttributeError). Consider narrowing to specific expected exceptions.

♻️ Suggested fix
-    except (ImportError, Exception):
+    except (ImportError, RuntimeError):
         pass

Or log unexpected exceptions for debugging:

-    except (ImportError, Exception):
-        pass
+    except ImportError:
+        pass
+    except Exception as e:
+        import logging
+        logging.debug(f"Diffusers backend registration failed: {e}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 124 -
125, The current broad except clause "except (ImportError, Exception)" in
conversion.py swallows all errors and can hide bugs; change it to only catch
ImportError (e.g., "except ImportError as e") for the import-failure path, and
if you must catch other runtime issues around the same block, catch specific
exceptions or log unexpected exceptions (use a logger.exception or re-raise
after logging) so that programming errors like TypeError/AttributeError are not
silently ignored; update the except block that follows the import attempt in
conversion.py accordingly.
examples/diffusers/sparsity/ltx2_skip_softmax.py (2)

66-81: Hardcoded user-specific paths should be placeholders.

The default paths contain user-specific paths (/home/scratch.omniml_data_2/jingyux/...) that won't exist on other systems. Consider using empty strings or raising a clear error when environment variables are not set.

Proposed fix: Require explicit configuration
-CHECKPOINT_PATH = os.environ.get(
-    "LTX2_CHECKPOINT",
-    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-dev.safetensors",
-)
+CHECKPOINT_PATH = os.environ.get("LTX2_CHECKPOINT", "")
+DISTILLED_LORA_PATH = os.environ.get("LTX2_DISTILLED_LORA", "")
+SPATIAL_UPSAMPLER_PATH = os.environ.get("LTX2_SPATIAL_UPSAMPLER", "")
+GEMMA_ROOT = os.environ.get("LTX2_GEMMA_ROOT", "")

Then in build_pipeline():

def build_pipeline() -> TI2VidTwoStagesPipeline:
    if not all([CHECKPOINT_PATH, DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT]):
        raise ValueError(
            "Missing required environment variables. Set: "
            "LTX2_CHECKPOINT, LTX2_DISTILLED_LORA, LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT"
        )
    ...
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/ltx2_skip_softmax.py` around lines 66 - 81, The
file defines CHECKPOINT_PATH, DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, and
GEMMA_ROOT with hardcoded user-specific default paths; remove those user paths
and default to empty string or None when reading the env vars
(os.environ.get(..., "") or None), and add a validation at the start of
build_pipeline() that checks these constants (CHECKPOINT_PATH,
DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT) and raises a clear
ValueError listing the required env names (LTX2_CHECKPOINT, LTX2_DISTILLED_LORA,
LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT) if any are missing so callers must
explicitly configure them.

260-267: load_dataset call may download data unexpectedly.

The load_dataset("nkp37/OpenVid-1M") call will download the dataset on first run, which could be surprising for users. Consider adding a note in the docstring or CLI help, or making this behavior opt-in.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/ltx2_skip_softmax.py` around lines 260 - 267, The
load_calib_prompts function calls load_dataset("nkp37/OpenVid-1M") which may
trigger a large download unexpectedly; update load_calib_prompts to make dataset
download explicit or opt-in (e.g., add a parameter like download: bool = False
or a dataset_path argument) and update the docstring to warn that calling this
function will download the OpenVid-1M dataset unless an existing local dataset
path is provided; ensure the function checks the opt-in flag or uses the
provided path before invoking load_dataset to avoid surprising downloads.
tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py (1)

130-158: Test mock doesn't match actual kernel call signature.

The mock mk.attention = lambda q, k, v, **kw: q returns q directly, but per the context snippet, the actual kernel receives varlen metadata and returns output with shape [B*S, H, D]. The mock should return a tensor with the correct output shape to avoid masking reshape bugs in the code under test.

Proposed fix: Return correctly shaped tensor
-        mk.attention = lambda q, k, v, **kw: q
+        def mock_attention(q, k, v, **kw):
+            # Return tensor with same shape as q (correct for varlen format)
+            return torch.zeros_like(q)
+        mk.attention = mock_attention
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py` around
lines 130 - 158, The mock attention implementation in
TestDiffusersTritonAttention._setup (mk.attention = lambda q, k, v, **kw: q)
does not match the real kernel signature/behavior and returns the wrong shape,
which hides reshape/masking bugs; update the mock in _setup (mk.attention) to
accept the same args including varlen metadata (keep **kw) and return a tensor
with shape [B*S, H, D] derived from the input q/k/v shapes (e.g., compute B, S,
H, D from q and construct/return a tensor of that shape instead of returning q
directly) so the code paths in _diffusers_triton_attention,
set_triton_skip_softmax_config, clear_triton_skip_softmax_config,
register_diffusers_triton_attention and get_triton_attention_backend are
exercised with realistic output shapes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`:
- Around line 31-35: The top-level import of diffusers internals
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) should be
guarded to avoid ImportError for users without diffusers; wrap the import in a
try/except and set a module-level flag (e.g., _DIFFUSERS_AVAILABLE =
False/True). Update any registration or accessor functions that reference
AttentionBackendName, _AttentionBackendRegistry or attention_backend to check
_DIFFUSERS_AVAILABLE before using them and return/do nothing or raise a clear
runtime error when diffusers is unavailable. Ensure all places that previously
assumed the imports (registration functions) consult _DIFFUSERS_AVAILABLE so the
module can be imported without diffusers installed.
- Around line 120-132: The code unconditionally manipulates private diffusers
internals (AttentionBackendName, _AttentionBackendRegistry, etc.) that exist
only in diffusers >= 0.36.0; add a runtime/version guard before creating
new_member and registering _diffusers_eager_attention: check
diffusers.__version__ (or use the same utility used in
modelopt/torch/quantization/plugins/diffusion/diffusers.py) or probe for the
presence of attributes like AttentionBackendName._member_map_ and
_AttentionBackendRegistry._backends, and only perform the enum extension and
registry assignments when those APIs exist; also update pyproject.toml to
require diffusers>=0.36.0 or make the registration conditional so the code
no-ops on older diffusers versions.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 61-70: The _diffusers_triton_attention function currently accepts
attn_mask and enable_gqa but ignores them; update the function so it either
implements GQA and mask handling consistent with the eager backend or explicitly
fails fast: if enable_gqa is True or attn_mask is not None, raise a clear
NotImplementedError mentioning "_diffusers_triton_attention does not support
enable_gqa/attn_mask yet" (or implement the same GQA reshaping/aggregation logic
used in diffusers_eager_attention for query/key/value before calling the Triton
kernel) so callers won't silently get incorrect results.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`:
- Around line 60-67: The _ltx_triton_attention function currently accepts a mask
parameter but never uses it; update the implementation to handle masks: either
pass the mask into the Triton kernel via the attn_mask argument when invoking
the kernel (ensure shapes/dtypes match and add logic to convert/expand the mask
to the kernel's expected form), or if kernel masking isn't supported yet,
explicitly reject masks by raising a clear error (e.g., raise
NotImplementedError("mask not supported by _ltx_triton_attention") when mask is
not None) so callers won't silently get incorrect results. Ensure the change is
applied inside _ltx_triton_attention and that any conversion/validation of mask
is performed before the kernel call.
- Line 29: The module unconditionally imports Attention from ltx_core which will
raise ImportError for users without LTX-2; change the top-level import to a
guarded import (try/except ImportError) or defer importing until registration,
set Attention = None on failure, and update register_ltx_triton_attention to
check if Attention is None and raise a clear ImportError like "ltx_core is
required for LTX-2 Triton attention" before proceeding; reference the symbols
Attention and register_ltx_triton_attention in ltx_triton_attention.py to locate
where to apply the guard.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 389-401: The code sets the thread-wide flag via
set_skip_softmax_context(True) immediately which can leak if an exception occurs
before the returned ExitStack is entered; instead create a small context manager
(e.g., using contextlib.contextmanager or a tiny class) that calls
set_skip_softmax_context(True) on __enter__/enter and
set_skip_softmax_context(False) on __exit__/exit, and then register that context
with stack.enter_context rather than calling set_skip_softmax_context and
stack.callback directly; update the function that builds the stack (the block
using ExitStack, get_skip_softmax_attention_backend,
replace_function(torch.nn.functional, "softmax", sparse_softmax)) to enter the
new flag-context via stack.enter_context so the flag is only set when the stack
is actually entered and always cleaned up on exit.

In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py`:
- Around line 56-72: The tests import
modelopt.torch.sparsity.attention_sparsity.kernels which transitively imports
transformers and breaks CI; update TestSkipSoftmaxContext to skip when optional
dependency missing by using pytest.importorskip('transformers') or catching
ImportError before importing get_skip_softmax_context/set_skip_softmax_context
(or call pytest.skip) so the test cleanly skips in environments without
transformers; ensure the changes are applied around the imports used in
TestSkipSoftmaxContext (references: get_skip_softmax_context,
set_skip_softmax_context, TestSkipSoftmaxContext).
- Around line 178-207: The test fails because patching targets under
"modelopt.torch.sparsity.attention_sparsity.kernels" occurs before that
submodule is loaded, causing a Module attribute error; to fix, ensure the module
is imported before patching or patch the symbols at the location they are looked
up by _register_diffusers_backends_if_needed: import
modelopt.torch.sparsity.attention_sparsity.conversion (or the parent package)
first, then patch the call targets register_diffusers_eager_attention and
register_diffusers_triton_attention as used by that conversion module (i.e.,
patch where _register_diffusers_backends_if_needed resolves them) so the
MagicMock replacement applies correctly.

---

Outside diff comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 306-316: The decode calibration block (calibrate_decode) currently
raises RuntimeError if calibration_data or tokenizer are missing even when the
user supplied forward_loop; change the logic in the calibrate_decode section to
detect when forward_loop is provided and calibration_data is None and skip
decode calibration with a warning instead of raising, i.e., only call
create_decode_calibration_forward_loop when calibration_data and tokenizer exist
(use create_decode_calibration_forward_loop(calibration_data, tokenizer, ...)),
otherwise log/warn that decode calibration is skipped due to missing
calibration_data/tokenizer while a custom forward_loop was supplied; update any
related docstring or comment near the calibrate_decode and forward_loop mentions
to reflect this behavior.
- Line 24: The file currently imports transformers.AutoTokenizer at module scope
which raises ImportError when transformers isn't installed; move the import into
the tokenizer-loading code path so the module can be imported without
transformers. Specifically, remove the top-level "from transformers import
AutoTokenizer" and instead import AutoTokenizer inside the _load_tokenizer
function (or guard the import with a try/except that raises a clear error when
the function is called), ensuring _load_tokenizer handles the absence of
transformers and only then attempts to create the tokenizer.

---

Duplicate comments:
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 30-36: Top-level unconditional imports of diffusers symbols
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) and
modelopt.torch.kernels.attention can fail for users without diffusers or
CUDA/Triton; move these imports into the function or class that actually uses
them (e.g., inside the registration function or the backend implementation in
diffusers_triton_attention.py) and guard with try/except ImportError to raise a
clear error only when the backend is instantiated; ensure you reference the same
symbols (AttentionBackendName, _AttentionBackendRegistry, attention_backend, and
attention) after relocating the imports so registration only occurs when
dependencies are present.

---

Nitpick comments:
In `@examples/diffusers/sparsity/ltx2_skip_softmax.py`:
- Around line 66-81: The file defines CHECKPOINT_PATH, DISTILLED_LORA_PATH,
SPATIAL_UPSAMPLER_PATH, and GEMMA_ROOT with hardcoded user-specific default
paths; remove those user paths and default to empty string or None when reading
the env vars (os.environ.get(..., "") or None), and add a validation at the
start of build_pipeline() that checks these constants (CHECKPOINT_PATH,
DISTILLED_LORA_PATH, SPATIAL_UPSAMPLER_PATH, GEMMA_ROOT) and raises a clear
ValueError listing the required env names (LTX2_CHECKPOINT, LTX2_DISTILLED_LORA,
LTX2_SPATIAL_UPSAMPLER, LTX2_GEMMA_ROOT) if any are missing so callers must
explicitly configure them.
- Around line 260-267: The load_calib_prompts function calls
load_dataset("nkp37/OpenVid-1M") which may trigger a large download
unexpectedly; update load_calib_prompts to make dataset download explicit or
opt-in (e.g., add a parameter like download: bool = False or a dataset_path
argument) and update the docstring to warn that calling this function will
download the OpenVid-1M dataset unless an existing local dataset path is
provided; ensure the function checks the opt-in flag or uses the provided path
before invoking load_dataset to avoid surprising downloads.
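A combined sketch of both suggestions (env var names come from the comments above; the function shapes and error wording are hypothetical):

```python
import os

REQUIRED_ENV = (
    "LTX2_CHECKPOINT",
    "LTX2_DISTILLED_LORA",
    "LTX2_SPATIAL_UPSAMPLER",
    "LTX2_GEMMA_ROOT",
)

def validate_env(env=None) -> None:
    """Raise early if any required path env var is unset, instead of hardcoded defaults."""
    env = os.environ if env is None else env
    missing = [name for name in REQUIRED_ENV if not env.get(name)]
    if missing:
        raise ValueError(f"Set these environment variables: {', '.join(missing)}")

def load_calib_prompts(num_prompts: int, dataset_path=None, download: bool = False):
    """Make the OpenVid-1M download opt-in; prefer a local dataset path."""
    if dataset_path is not None:
        # Placeholder for reading prompts from a local copy of the dataset.
        return [f"<prompt from {dataset_path}>"] * num_prompts
    if not download:
        raise ValueError(
            "Pass download=True to fetch nkp37/OpenVid-1M (a large download), "
            "or supply dataset_path for a local copy."
        )
    # Real code would call datasets.load_dataset("nkp37/OpenVid-1M") here.
    raise NotImplementedError("dataset download path not shown in this sketch")
```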

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 124-125: The current broad except clause "except (ImportError,
Exception)" in conversion.py swallows all errors and can hide bugs; change it to
only catch ImportError (e.g., "except ImportError as e") for the import-failure
path, and if you must catch other runtime issues around the same block, catch
specific exceptions or log unexpected exceptions (use a logger.exception or
re-raise after logging) so that programming errors like TypeError/AttributeError
are not silently ignored; update the except block that follows the import
attempt in conversion.py accordingly.
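The narrowed handler could look like this (logger name and return contract are illustrative, not the PR's actual code):

```python
import logging

logger = logging.getLogger(__name__)

def try_register_backends(register) -> bool:
    """Run an optional backend registration; only a missing dependency is tolerated."""
    try:
        register()
        return True
    except ImportError as e:
        # Expected when diffusers (or another optional dep) is absent.
        logger.info("Optional backend registration skipped: %s", e)
        return False
    # TypeError/AttributeError etc. propagate: those are bugs, not missing deps.
```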

In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py`:
- Around line 130-158: The mock attention implementation in
TestDiffusersTritonAttention._setup (mk.attention = lambda q, k, v, **kw: q)
does not match the real kernel signature/behavior and returns the wrong shape,
which hides reshape/masking bugs; update the mock in _setup (mk.attention) to
accept the same args including varlen metadata (keep **kw) and return a tensor
with shape [B*S, H, D] derived from the input q/k/v shapes (e.g., compute B, S,
H, D from q and construct/return a tensor of that shape instead of returning q
directly) so the code paths in _diffusers_triton_attention,
set_triton_skip_softmax_config, clear_triton_skip_softmax_config,
register_diffusers_triton_attention and get_triton_attention_backend are
exercised with realistic output shapes.
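A shape-faithful mock in the spirit of that suggestion might look like this (NumPy stand-in for torch; the varlen [total_tokens, heads, head_dim] layout is assumed from the comment, not verified against the kernel):

```python
import numpy as np

def fake_varlen_attention(q, k, v, **kw):
    """Mock kernel: ignores kwargs (varlen metadata) but returns the right shape."""
    total_tokens, num_heads, head_dim = q.shape
    return np.zeros((total_tokens, num_heads, head_dim), dtype=q.dtype)

q = np.ones((8, 4, 16), dtype=np.float32)  # e.g. B=2, S=4 flattened to 8 tokens
out = fake_varlen_attention(q, q, q, cu_seqlens=None, max_seqlen=4)
```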

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 578e7893-06c2-4586-93e1-4726205a2f84

📥 Commits

Reviewing files that changed from the base of the PR and between 87ea8ba and 2c323df.

📒 Files selected for processing (14)
  • examples/diffusers/sparsity/ltx2_skip_softmax.py
  • examples/diffusers/sparsity/wan22_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/ltx_eager_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py

Comment on lines +31 to +35
from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _AttentionBackendRegistry,
    attention_backend,
)

⚠️ Potential issue | 🟠 Major

Unconditional import of diffusers internals will cause ImportError for non-diffusers users.

Similar to the ltx_core issue, this module imports diffusers at the top level. Since attention_sparsity may be used without diffusers, consider guarding this import.

Proposed fix: guard the import
+try:
     from diffusers.models.attention_dispatch import (
         AttentionBackendName,
         _AttentionBackendRegistry,
         attention_backend,
     )
+    _DIFFUSERS_AVAILABLE = True
+except ImportError:
+    _DIFFUSERS_AVAILABLE = False
+    AttentionBackendName = None  # type: ignore
+    _AttentionBackendRegistry = None  # type: ignore
+    attention_backend = None  # type: ignore

Then check _DIFFUSERS_AVAILABLE in registration functions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`
around lines 31 - 35, The top-level import of diffusers internals
(AttentionBackendName, _AttentionBackendRegistry, attention_backend) should be
guarded to avoid ImportError for users without diffusers; wrap the import in a
try/except and set a module-level flag (e.g., _DIFFUSERS_AVAILABLE =
False/True). Update any registration or accessor functions that reference
AttentionBackendName, _AttentionBackendRegistry or attention_backend to check
_DIFFUSERS_AVAILABLE before using them and return/do nothing or raise a clear
runtime error when diffusers is unavailable. Ensure all places that previously
assumed the imports (registration functions) consult _DIFFUSERS_AVAILABLE so the
module can be imported without diffusers installed.

Comment on lines +120 to +132
# Extend the AttentionBackendName enum with our custom value
new_member = str.__new__(AttentionBackendName, _BACKEND_NAME)
new_member._name_ = "MODELOPT_SKIP_SOFTMAX"
new_member._value_ = _BACKEND_NAME
AttentionBackendName._member_map_["MODELOPT_SKIP_SOFTMAX"] = new_member
AttentionBackendName._value2member_map_[_BACKEND_NAME] = new_member

# Register the backend function
_AttentionBackendRegistry._backends[new_member] = _diffusers_eager_attention
_AttentionBackendRegistry._constraints[new_member] = []
_AttentionBackendRegistry._supported_arg_names[new_member] = set(
    inspect.signature(_diffusers_eager_attention).parameters.keys()
)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

What version of diffusers introduced _AttentionBackendRegistry and attention_backend context manager?

💡 Result:

The _AttentionBackendRegistry and attention_backend context manager were introduced in diffusers version 0.36.0.



🏁 Script executed:

# Find pyproject.toml and requirements files to check diffusers version constraints
find . -type f \( -name "pyproject.toml" -o -name "requirements*.txt" -o -name "setup.py" \) | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 945


🏁 Script executed:

# Check the specific file to see its full context and any version checks
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py | head -150

Repository: NVIDIA/Model-Optimizer

Length of output: 6161


🏁 Script executed:

# Search for any version checks related to diffusers in the codebase
rg "diffusers" --type py -A 2 | grep -i "version\|__version__" | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 2357


🏁 Script executed:

# Check if there are any documented version requirements or constraints
rg "0\.36\|diffusers.*version" --type md --type txt --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check conversion.py to see how registration is called and if version checks exist
cat -n modelopt/torch/sparsity/attention_sparsity/conversion.py | grep -A 15 -B 5 "_register_diffusers_backends_if_needed"

Repository: NVIDIA/Model-Optimizer

Length of output: 1992


🏁 Script executed:

# Find and check pyproject.toml for diffusers version constraints
cat pyproject.toml | grep -A 10 -B 2 "diffusers"

Repository: NVIDIA/Model-Optimizer

Length of output: 690


🏁 Script executed:

# Check if there are any import guards around the diffusers imports
rg "from diffusers.*attention_dispatch" -B 5 -A 5

Repository: NVIDIA/Model-Optimizer

Length of output: 3478


Add version check for diffusers >= 0.36.0 and document private API reliance.

The code imports and manipulates private diffusers internals (_AttentionBackendRegistry, _member_map_, _value2member_map_, _backends, _constraints, _supported_arg_names) that were introduced in diffusers 0.36.0. However, pyproject.toml specifies only diffusers>=0.32.2, creating a compatibility gap. Add a version guard on module import (following the pattern in modelopt/torch/quantization/plugins/diffusion/diffusers.py), update the minimum version requirement in pyproject.toml, or conditionally register the backend only when the required APIs are available.
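One way to express the guard is an attribute probe instead of version parsing (a sketch; the real fix might instead bump pyproject.toml or reuse the existing diffusers version utility):

```python
def diffusers_backend_api_available() -> bool:
    """True only if the private registry APIs (diffusers >= 0.36.0) are present."""
    try:
        from diffusers.models.attention_dispatch import (
            AttentionBackendName,
            _AttentionBackendRegistry,
        )
    except ImportError:
        return False
    # Probe the private attributes the registration code mutates.
    return hasattr(AttentionBackendName, "_value2member_map_") and hasattr(
        _AttentionBackendRegistry, "_backends"
    )
```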

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_eager_attention.py`
around lines 120 - 132, The code unconditionally manipulates private diffusers
internals (AttentionBackendName, _AttentionBackendRegistry, etc.) that exist
only in diffusers >= 0.36.0; add a runtime/version guard before creating
new_member and registering _diffusers_eager_attention: check
diffusers.__version__ (or use the same utility used in
modelopt/torch/quantization/plugins/diffusion/diffusers.py) or probe for the
presence of attributes like AttentionBackendName._member_map_ and
_AttentionBackendRegistry._backends, and only perform the enum extension and
registry assignments when those APIs exist; also update pyproject.toml to
require diffusers>=0.36.0 or make the registration conditional so the code
no-ops on older diffusers versions.

Comment on lines +61 to +70
def _diffusers_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> torch.Tensor:

⚠️ Potential issue | 🟡 Minor

Parameters attn_mask and enable_gqa are accepted but silently ignored.

Unlike the eager backend which implements GQA support (Lines 66-71 in diffusers_eager_attention.py), the Triton backend ignores enable_gqa. This inconsistency could cause silent correctness issues when switching backends.

Proposed fix: Add GQA support or raise error
 def _diffusers_triton_attention(
     ...
     enable_gqa: bool = False,
 ) -> torch.Tensor:
+    if enable_gqa and query.shape[2] != key.shape[2]:
+        raise NotImplementedError(
+            "GQA not yet supported in Triton backend. Use eager backend for calibration."
+        )
+    if attn_mask is not None:
+        raise NotImplementedError(
+            "attn_mask not yet supported in Triton backend."
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`
around lines 61 - 70, The _diffusers_triton_attention function currently accepts
attn_mask and enable_gqa but ignores them; update the function so it either
implements GQA and mask handling consistent with the eager backend or explicitly
fails fast: if enable_gqa is True or attn_mask is not None, raise a clear
NotImplementedError mentioning "_diffusers_triton_attention does not support
enable_gqa/attn_mask yet" (or implement the same GQA reshaping/aggregation logic
used in diffusers_eager_attention for query/key/value before calling the Triton
kernel) so callers won't silently get incorrect results.

import threading

import torch
from ltx_core.model.transformer.attention import Attention

⚠️ Potential issue | 🟠 Major

Unconditional import of ltx_core will cause ImportError for users without LTX-2.

This module is loaded at import time, but ltx_core is an optional dependency only available for LTX-2 users. Consider wrapping in a try/except or deferring the import to registration time.

Proposed fix: guard the import
-from ltx_core.model.transformer.attention import Attention
+try:
+    from ltx_core.model.transformer.attention import Attention
+except ImportError:
+    Attention = None  # type: ignore[misc,assignment]

Then in register_ltx_triton_attention:

def register_ltx_triton_attention(model: torch.nn.Module) -> None:
    if Attention is None:
        raise ImportError("ltx_core is required for LTX-2 Triton attention")
    ...
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`
at line 29, The module unconditionally imports Attention from ltx_core which
will raise ImportError for users without LTX-2; change the top-level import to a
guarded import (try/except ImportError) or defer importing until registration,
set Attention = None on failure, and update register_ltx_triton_attention to
check if Attention is None and raise a clear ImportError like "ltx_core is
required for LTX-2 Triton attention" before proceeding; reference the symbols
Attention and register_ltx_triton_attention in ltx_triton_attention.py to locate
where to apply the guard.

Comment on lines +60 to +67
def _ltx_triton_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    heads: int,
    mask: torch.Tensor | None = None,
    threshold: float | None = None,
) -> torch.Tensor:

⚠️ Potential issue | 🟡 Minor

mask parameter is accepted but silently ignored.

The mask parameter is passed through the wrapper (Line 136) but never used in the Triton attention implementation. This could lead to silent correctness issues if the original attention function relied on masking.

Consider either:

  1. Implementing mask support via attn_mask in the kernel call
  2. Raising an error/warning if a mask is provided but not supported
  3. Documenting this limitation clearly
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`
around lines 60 - 67, The _ltx_triton_attention function currently accepts a
mask parameter but never uses it; update the implementation to handle masks:
either pass the mask into the Triton kernel via the attn_mask argument when
invoking the kernel (ensure shapes/dtypes match and add logic to convert/expand
the mask to the kernel's expected form), or if kernel masking isn't supported
yet, explicitly reject masks by raising a clear error (e.g., raise
NotImplementedError("mask not supported by _ltx_triton_attention") when mask is
not None) so callers won't silently get incorrect results. Ensure the change is
applied inside _ltx_triton_attention and that any conversion/validation of mask
is performed before the kernel call.

Comment on lines +389 to +401
stack = ExitStack()
set_skip_softmax_context(True)
stack.callback(set_skip_softmax_context, False)

try:
    from ..kernels.diffusers_eager_attention import get_skip_softmax_attention_backend

    stack.enter_context(get_skip_softmax_attention_backend())
except (ImportError, RuntimeError):
    pass

stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax))
return stack

⚠️ Potential issue | 🟠 Major

Context flag set before stack is entered — cleanup may not run.

set_skip_softmax_context(True) is called immediately at line 390, but the cleanup callback only executes when the ExitStack exits. If:

  1. An exception occurs before the stack is returned, or
  2. The caller never enters the returned stack as a context manager

...the flag will remain True, potentially affecting subsequent attention computations in the same thread.

🔧 Proposed fix: defer flag-setting to stack entry
         from ..kernels import set_skip_softmax_context
 
         stack = ExitStack()
-        set_skip_softmax_context(True)
-        stack.callback(set_skip_softmax_context, False)
+
+        @contextmanager
+        def _skip_softmax_flag():
+            set_skip_softmax_context(True)
+            try:
+                yield
+            finally:
+                set_skip_softmax_context(False)
+
+        stack.enter_context(_skip_softmax_flag())
 
         try:
             from ..kernels.diffusers_eager_attention import get_skip_softmax_attention_backend

Or simpler, wrap the flag toggle in a small context manager helper.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 389 - 401, The code sets the thread-wide flag via
set_skip_softmax_context(True) immediately which can leak if an exception occurs
before the returned ExitStack is entered; instead create a small context manager
(e.g., using contextlib.contextmanager or a tiny class) that calls
set_skip_softmax_context(True) on __enter__/enter and
set_skip_softmax_context(False) on __exit__/exit, and then register that context
with stack.enter_context rather than calling set_skip_softmax_context and
stack.callback directly; update the function that builds the stack (the block
using ExitStack, get_skip_softmax_attention_backend,
replace_function(torch.nn.functional, "softmax", sparse_softmax)) to enter the
new flag-context via stack.enter_context so the flag is only set when the stack
is actually entered and always cleaned up on exit.

Comment on lines +56 to +72
class TestSkipSoftmaxContext:
    def test_default_is_false(self):
        from modelopt.torch.sparsity.attention_sparsity.kernels import get_skip_softmax_context

        assert get_skip_softmax_context() is False

    def test_set_and_get(self):
        from modelopt.torch.sparsity.attention_sparsity.kernels import (
            get_skip_softmax_context,
            set_skip_softmax_context,
        )

        set_skip_softmax_context(True)
        assert get_skip_softmax_context() is True
        set_skip_softmax_context(False)
        assert get_skip_softmax_context() is False


⚠️ Potential issue | 🔴 Critical

Pipeline failure: transformers import triggered during test execution.

The pipeline failures show that importing from modelopt.torch.sparsity.attention_sparsity.kernels transitively imports transformers, which isn't available in the unit test environment. Unit tests should run on CPU without optional dependencies.

Proposed fix: Skip tests when dependencies unavailable
 class TestSkipSoftmaxContext:
+    @pytest.fixture(autouse=True)
+    def _skip_if_unavailable(self):
+        pytest.importorskip("transformers", reason="transformers required for attention_sparsity")
+
     def test_default_is_false(self):

Or mock the entire import chain more thoroughly to avoid triggering the transitive import.

🧰 Tools
🪛 GitHub Actions: Unit tests

[error] 58-58: Failed import while executing test_default_is_false due to missing dependency: ModuleNotFoundError: No module named 'transformers'


[error] 63-63: Failed import while executing test_set_and_get due to missing dependency: ModuleNotFoundError: No module named 'transformers'

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py` around
lines 56 - 72, The tests import
modelopt.torch.sparsity.attention_sparsity.kernels which transitively imports
transformers and breaks CI; update TestSkipSoftmaxContext to skip when optional
dependency missing by using pytest.importorskip('transformers') or catching
ImportError before importing get_skip_softmax_context/set_skip_softmax_context
(or call pytest.skip) so the test cleanly skips in environments without
transformers; ensure the changes are applied around the imports used in
TestSkipSoftmaxContext (references: get_skip_softmax_context,
set_skip_softmax_context, TestSkipSoftmaxContext).

Comment on lines +178 to +207
class TestRegisterDiffusersBackends:
    def test_no_diffusers_no_error(self):
        from modelopt.torch.sparsity.attention_sparsity.conversion import (
            _register_diffusers_backends_if_needed,
        )

        _register_diffusers_backends_if_needed(nn.Linear(10, 10))

    def test_with_diffusers_model(self):
        from modelopt.torch.sparsity.attention_sparsity.conversion import (
            _register_diffusers_backends_if_needed,
        )

        mock_mixin = type("ModelMixin", (nn.Module,), {})
        mock_utils = types.ModuleType("diffusers.models.modeling_utils")
        mock_utils.ModelMixin = mock_mixin

        with (
            patch.dict(sys.modules, {"diffusers.models.modeling_utils": mock_utils}),
            patch(
                "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_eager_attention",
                MagicMock(),
            ) as mock_eager,
            patch(
                "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_triton_attention",
                MagicMock(),
            ) as mock_triton,
        ):
            _register_diffusers_backends_if_needed(mock_mixin())
            mock_eager.assert_called_once()

⚠️ Potential issue | 🔴 Critical

Pipeline failure: Module attribute error during patching.

The error module 'modelopt.torch.sparsity' has no attribute 'attention_sparsity' suggests the module isn't fully loaded when patching. The patch target path may need adjustment.

Proposed fix: Patch at the correct import location

The patch targets should reference where the functions are used, not where they're defined:

-            patch(
-                "modelopt.torch.sparsity.attention_sparsity.kernels.register_diffusers_eager_attention",
+            patch(
+                "modelopt.torch.sparsity.attention_sparsity.conversion.register_diffusers_eager_attention",

Or ensure the module is imported before patching:

import modelopt.torch.sparsity.attention_sparsity.conversion  # Force load
🧰 Tools
🪛 GitHub Actions: Unit tests

[error] 197-197: AttributeError: module 'modelopt.torch.sparsity' has no attribute 'attention_sparsity' (failure in test_with_diffusers_model while patching kernel registration functions)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py` around
lines 178 - 207, The test fails because patching targets under
"modelopt.torch.sparsity.attention_sparsity.kernels" occurs before that
submodule is loaded, causing a Module attribute error; to fix, ensure the module
is imported before patching or patch the symbols at the location they are looked
up by _register_diffusers_backends_if_needed: import
modelopt.torch.sparsity.attention_sparsity.conversion (or the parent package)
first, then patch the call targets register_diffusers_eager_attention and
register_diffusers_triton_attention as used by that conversion module (i.e.,
patch where _register_diffusers_backends_if_needed resolves them) so the
MagicMock replacement applies correctly.

jingyu-ml and others added 2 commits April 6, 2026 17:20

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)

308-317: ⚠️ Potential issue | 🟡 Minor

Improve error message and consider early validation when custom forward_loop conflicts with decode calibration.

When a user provides a custom forward_loop (e.g., for diffusion models) but also configures calibrate_decode=True, the current RuntimeError message "calibration_data and tokenizer must be built before decode" is confusing—it implies a bug rather than an unsupported configuration.

Consider either:

  1. Early validation (preferred): Check at the start of calibration if forward_loop is not None and calibrate_decode and raise with a clear message, or
  2. Improve the error message to explicitly state the limitation.
Option 1: Add early validation near line 246
     # Skip if both phases are disabled
     if not calibrate_prefill and not calibrate_decode:
         print("Both prefill and decode target sparsity are 0.0, skipping calibration")
         return {}

+    # Decode calibration requires RULER dataset, which is incompatible with custom forward_loop
+    if forward_loop is not None and calibrate_decode:
+        raise ValueError(
+            "Decode calibration is not supported when a custom forward_loop is provided. "
+            "Either set decode target_sparse_ratio to 0.0 or remove the forward_loop argument "
+            "to use auto-generated RULER dataset calibration."
+        )
+
     # Get sparse attention modules
Option 2: Improve error message at lines 313-314
         if calibration_data is None or tokenizer is None:
-            raise RuntimeError("calibration_data and tokenizer must be built before decode")
+            raise RuntimeError(
+                "Decode calibration requires tokenizer and RULER dataset, which are not available "
+                "when a custom forward_loop is provided. Set decode target_sparse_ratio to 0.0 "
+                "or remove the forward_loop argument to enable decode calibration."
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 308 - 317, Add an early validation that explicitly rejects using a custom
forward_loop with decode calibration: if forward_loop is not None and
calibrate_decode is True, raise a clear RuntimeError stating that decode
calibration is incompatible with a custom forward_loop and instruct the user to
disable calibrate_decode or remove the custom forward_loop; place this check
near the start of the calibration flow (before
create_decode_calibration_forward_loop is called). Alternatively, if you prefer
the minimal change, improve the existing RuntimeError in the block that checks
calibration_data and tokenizer to a message that either explains missing
calibration_data/tokenizer or that decode calibration is unsupported when a
custom forward_loop is provided (reference forward_loop, calibrate_decode,
calibration_data, tokenizer, and create_decode_calibration_forward_loop).
🧹 Nitpick comments (1)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)

34-41: Consider restoring the return type annotation.

The lazy import is a good practice per project guidelines. However, removing the return type annotation reduces type safety. Consider adding it back using a string literal or TYPE_CHECKING import to maintain mypy compatibility:

Suggested fix
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizerBase
+
-def _load_tokenizer(tokenizer_name_or_path: str):
+def _load_tokenizer(tokenizer_name_or_path: str) -> "PreTrainedTokenizerBase":
     """Load tokenizer and ensure pad_token is set."""
     from transformers import AutoTokenizer
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 34 - 41, Restore the return type annotation for _load_tokenizer to
preserve type safety; update the signature to annotate the return as
"PreTrainedTokenizerBase" (a string literal) or import PreTrainedTokenizerBase
under TYPE_CHECKING from transformers and use it as the return type, keeping the
lazy import of AutoTokenizer inside the function and leaving runtime behavior
unchanged; ensure the chosen symbol (PreTrainedTokenizerBase) is referenced in
the annotation for mypy compatibility while avoiding importing transformers at
module import time.
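The `TYPE_CHECKING` pattern can be exercised with a stdlib module in place of transformers; `Decimal` here is just a lightweight stand-in for `PreTrainedTokenizerBase`:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; no import cost at runtime.
    from decimal import Decimal

def parse_amount(text: str) -> "Decimal":
    """The annotation stays mypy-checkable while the import stays lazy."""
    from decimal import Decimal  # runtime import deferred until first call
    return Decimal(text)

assert str(parse_amount("1.50")) == "1.50"
assert TYPE_CHECKING is False  # the guard is False outside type checkers
```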
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 405-443: Guard against empty per-phase payloads by selecting the
first non-empty calibration dict instead of blindly using
next(iter(calibration_params.values())). Replace sample_params =
next(iter(calibration_params.values())) with logic that finds the first
non-empty value: sample_params = next((v for v in calibration_params.values() if
v), None); set is_percentile only when sample_params is not None, and if
sample_params is None skip building threshold_config / threshold_scale_factor
and the per-phase loops (so export_config stays without threshold entries),
referencing the existing names calibration_params, sample_params, is_percentile,
export_config, threshold_config and threshold_scale_factor.
- Around line 125-139: The code currently suppresses all exceptions around
diffusers backend registration (the block checking isinstance(model, ModelMixin)
and the uses of contextlib.suppress(Exception)), which can hide real
registration errors; change the try/except to only catch ImportError when
importing ModelMixin and only suppress ImportError when calling
register_diffusers_eager_attention and register_diffusers_triton_attention so
that any other Exception raised by those registration functions bubbles up (or
is re-raised) instead of being swallowed; update the other places using
contextlib.suppress(Exception) to suppress ImportError only and ensure failures
in register_diffusers_* or in _set_attn_implementation are not silenced so the
error surfaces during conversion.

---

Outside diff comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 308-317: Add an early validation that explicitly rejects using a
custom forward_loop with decode calibration: if forward_loop is not None and
calibrate_decode is True, raise a clear RuntimeError stating that decode
calibration is incompatible with a custom forward_loop and instruct the user to
disable calibrate_decode or remove the custom forward_loop; place this check
near the start of the calibration flow (before
create_decode_calibration_forward_loop is called). Alternatively, if you prefer
the minimal change, improve the existing RuntimeError in the block that checks
calibration_data and tokenizer to a message that either explains missing
calibration_data/tokenizer or that decode calibration is unsupported when a
custom forward_loop is provided (reference forward_loop, calibrate_decode,
calibration_data, tokenizer, and create_decode_calibration_forward_loop).

---

Nitpick comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 34-41: Restore the return type annotation for _load_tokenizer to
preserve type safety; update the signature to annotate the return as
"PreTrainedTokenizerBase" (a string literal) or import PreTrainedTokenizerBase
under TYPE_CHECKING from transformers and use it as the return type, keeping the
lazy import of AutoTokenizer inside the function and leaving runtime behavior
unchanged; ensure the chosen symbol (PreTrainedTokenizerBase) is referenced in
the annotation for mypy compatibility while avoiding importing transformers at
module import time.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 728d5b26-043d-44d2-a39e-cc945d944a89

📥 Commits

Reviewing files that changed from the base of the PR and between 2c323df and bbe2123.

📒 Files selected for processing (2)
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py

Comment on lines +125 to +139
try:
from diffusers.models.modeling_utils import ModelMixin

if isinstance(model, ModelMixin):
from .kernels import (
register_diffusers_eager_attention,
register_diffusers_triton_attention,
)

if register_diffusers_eager_attention is not None:
register_diffusers_eager_attention()
if register_diffusers_triton_attention is not None:
register_diffusers_triton_attention()
except (ImportError, Exception):
pass
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the file structure and the specific lines mentioned
cd /repo && head -150 modelopt/torch/sparsity/attention_sparsity/conversion.py | tail -40

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

# Get a fuller view of the section around lines 118-155
cd /repo && sed -n '115,160p' modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

# Also check the other segments mentioned in the scratchpad (405-444, 516-525, 553-554)
cd /repo && sed -n '400,450p' modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

# Check what exceptions could be raised during registration
cd /repo && sed -n '1,50p' modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

# Find the file in the repository
git ls-files | grep "conversion.py"

Repository: NVIDIA/Model-Optimizer

Length of output: 544


🏁 Script executed:

# Get file size and initial content
git ls-files | grep "conversion.py" | xargs wc -l

Repository: NVIDIA/Model-Optimizer

Length of output: 627


🏁 Script executed:

# Read the specific sections mentioned in the review
git ls-files | grep "conversion.py" | xargs sed -n '115,160p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1540


🏁 Script executed:

# Check the export format section (405-444)
git ls-files | grep "conversion.py" | xargs sed -n '400,450p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1726


🏁 Script executed:

# Read lines 115-160 from the attention_sparsity conversion.py file
sed -n '115,160p' modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1820


🏁 Script executed:

# Also get lines 405-450 for the export format logic
sed -n '405,450p' modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1717


🏁 Script executed:

# Check lines 144-154 which is also mentioned
sed -n '140,160p' modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 775


🏁 Script executed:

# Get the full function context starting from line 118
sed -n '118,160p' modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1730


🏁 Script executed:

# Verify where _register_diffusers_backends_if_needed is called
sed -n '160,180p' modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 701


🏁 Script executed:

# Check what _set_attn_implementation does
grep -n "_set_attn_implementation" modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 203


🏁 Script executed:

# See how the export_sparse_attention_config is called
grep -n "export_sparse_attention_config" modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 148


🏁 Script executed:

# Check what _set_attn_implementation does and what it expects
sed -n '35,70p' modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2039


🏁 Script executed:

# Understand the broader context - what happens if backends aren't available
sed -n '35,100p' modelopt/torch/sparsity/attention_sparsity/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 3306


Don't swallow unexpected backend registration failures.

Lines 125-139 and 144-154 use overly broad exception suppression. Line 138 catches (ImportError, Exception), and lines 150 & 152 use contextlib.suppress(Exception). This masks registration bugs—if a backend registration fails unexpectedly, conversion still proceeds and _set_attn_implementation() will route the model to use eager or modelopt_triton with the backend never actually registered. Only suppress ImportError (optional dependency missing) and surface unexpected registration failures so they don't cause silent runtime failures.
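A small illustration of the difference; `register_backend` is a hypothetical stand-in that fails with a non-import error:

```python
import contextlib

def register_backend():
    # Simulates a genuine registration bug, not a missing dependency.
    raise RuntimeError("backend registration failed")

# suppress(Exception) swallows the bug; conversion would proceed silently.
with contextlib.suppress(Exception):
    register_backend()

# suppress(ImportError) only masks missing optional deps; the bug surfaces.
surfaced = False
try:
    with contextlib.suppress(ImportError):
        register_backend()
except RuntimeError:
    surfaced = True
assert surfaced
```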

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 125 -
139, The code currently suppresses all exceptions around diffusers backend
registration (the block checking isinstance(model, ModelMixin) and the uses of
contextlib.suppress(Exception)), which can hide real registration errors; change
the try/except to only catch ImportError when importing ModelMixin and only
suppress ImportError when calling register_diffusers_eager_attention and
register_diffusers_triton_attention so that any other Exception raised by those
registration functions bubbles up (or is re-raised) instead of being swallowed;
update the other places using contextlib.suppress(Exception) to suppress
ImportError only and ensure failures in register_diffusers_* or in
_set_attn_implementation are not silenced so the error surfaces during
conversion.

Comment on lines +405 to +443
# Detect calibration type from params
sample_params = next(iter(calibration_params.values()))
is_percentile = "threshold" in sample_params

# Build the export config
export_config: dict[str, Any] = {
"config_groups": {
"group_0": {
"sparse_algo": "softmax_skip",
"sparse_algo": "softmax_skip_diffusion" if is_percentile else "softmax_skip",
"targets": sorted(target_classes) if target_classes else ["Attention"],
}
},
"threshold_scale_factor": threshold_scale_factor,
"producer": {
"name": "modelopt",
"version": mo_version,
},
}

if is_percentile:
threshold_config: dict[str, Any] = {
"formula": "skip if gap >= threshold * log(seq_k)",
}
for phase in ["prefill", "decode"]:
if phase in calibration_params:
threshold_config[phase] = {
"threshold": calibration_params[phase]["threshold"],
}
export_config["threshold_config"] = threshold_config
else:
threshold_scale_factor: dict[str, Any] = {
"formula": "a * exp(b * target_sparsity)",
}
for phase in ["prefill", "decode"]:
if phase in calibration_params:
threshold_scale_factor[phase] = {
"a": calibration_params[phase]["a"],
"b": calibration_params[phase]["b"],
}
export_config["threshold_scale_factor"] = threshold_scale_factor
Contributor


⚠️ Potential issue | 🟠 Major

Guard empty phase payloads before inferring the export format.

Line 406 assumes calibration_params already contains a populated per-phase dict. If the first module contributes {} or a placeholder phase entry, this now raises StopIteration or falls into the wrong branch and then hits KeyError on Lines 430/440 instead of honoring the documented None fallback for “no calibrated sparse attention modules”.
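A minimal repro of the failure mode (the payloads are hypothetical; real entries carry either a `threshold` key or `a`/`b` keys):

```python
calibration_params = {"prefill": {}, "decode": {"threshold": 0.1}}

# Blindly taking the first value picks the empty prefill payload,
# so the "threshold" membership probe misclassifies the export format.
first = next(iter(calibration_params.values()))
assert first == {} and "threshold" not in first

# Selecting the first non-empty payload is robust to empty phases.
sample_params = next((v for v in calibration_params.values() if v), None)
assert sample_params == {"threshold": 0.1}

# An all-empty mapping falls through to the documented None fallback.
assert next((v for v in {"prefill": {}}.values() if v), None) is None
```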

💡 Proposed fix
-    sample_params = next(iter(calibration_params.values()))
+    sample_params = next((params for params in calibration_params.values() if params), None)
+    if sample_params is None:
+        return None
     is_percentile = "threshold" in sample_params
@@
         for phase in ["prefill", "decode"]:
-            if phase in calibration_params:
+            phase_params = calibration_params.get(phase)
+            if phase_params and "threshold" in phase_params:
                 threshold_config[phase] = {
-                    "threshold": calibration_params[phase]["threshold"],
+                    "threshold": phase_params["threshold"],
                 }
@@
         for phase in ["prefill", "decode"]:
-            if phase in calibration_params:
+            phase_params = calibration_params.get(phase)
+            if phase_params and all(k in phase_params for k in ("a", "b")):
                 threshold_scale_factor[phase] = {
-                    "a": calibration_params[phase]["a"],
-                    "b": calibration_params[phase]["b"],
+                    "a": phase_params["a"],
+                    "b": phase_params["b"],
                 }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 405 -
443, Guard against empty per-phase payloads by selecting the first non-empty
calibration dict instead of blindly using
next(iter(calibration_params.values())). Replace sample_params =
next(iter(calibration_params.values())) with logic that finds the first
non-empty value: sample_params = next((v for v in calibration_params.values() if
v), None); set is_percentile only when sample_params is not None, and if
sample_params is None skip building threshold_config / threshold_scale_factor
and the per-phase loops (so export_config stays without threshold entries),
referencing the existing names calibration_params, sample_params, is_percentile,
export_config, threshold_config and threshold_scale_factor.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)

125-154: ⚠️ Potential issue | 🟠 Major

Don’t suppress all backend registration errors.

Line 138 and Lines 150-154 currently swallow arbitrary exceptions. That can hide real backend registration bugs and leave conversion proceeding with a partially configured runtime.

Proposed fix
-    except (ImportError, Exception):
+    except ImportError:
         pass
@@
-    except (ImportError, RuntimeError):
+    except ImportError:
         return
@@
-    if register_ltx_eager_attention is not None:
-        with contextlib.suppress(Exception):
-            register_ltx_eager_attention(model)
-    if register_ltx_triton_attention is not None:
-        with contextlib.suppress(Exception):
-            register_ltx_triton_attention(model)
+    if register_ltx_eager_attention is not None:
+        with contextlib.suppress(ImportError):
+            register_ltx_eager_attention(model)
+    if register_ltx_triton_attention is not None:
+        with contextlib.suppress(ImportError):
+            register_ltx_triton_attention(model)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 125 -
154, The code currently suppresses all exceptions when importing/calling backend
registration functions (the try/except around ModelMixin and the
contextlib.suppress usage for
register_ltx_eager_attention/register_ltx_triton_attention), which hides real
errors; change these to only catch expected import/runtime errors and surface or
log unexpected exceptions: when importing from .kernels and checking
isinstance(model, ModelMixin), catch only ImportError and RuntimeError, and when
invoking register_diffusers_eager_attention,
register_diffusers_triton_attention, register_ltx_eager_attention, and
register_ltx_triton_attention, wrap each call in a try/except that logs the full
exception (including stack trace) via the module logger and re-raises or returns
a clear error instead of silently swallowing it so backend-registration failures
are visible and conversion does not proceed silently with a misconfigured
runtime.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 125-154: The code currently suppresses all exceptions when
importing/calling backend registration functions (the try/except around
ModelMixin and the contextlib.suppress usage for
register_ltx_eager_attention/register_ltx_triton_attention), which hides real
errors; change these to only catch expected import/runtime errors and surface or
log unexpected exceptions: when importing from .kernels and checking
isinstance(model, ModelMixin), catch only ImportError and RuntimeError, and when
invoking register_diffusers_eager_attention,
register_diffusers_triton_attention, register_ltx_eager_attention, and
register_ltx_triton_attention, wrap each call in a try/except that logs the full
exception (including stack trace) via the module logger and re-raises or returns
a clear error instead of silently swallowing it so backend-registration failures
are visible and conversion does not proceed silently with a misconfigured
runtime.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 09ecbcc7-b977-4605-a5e2-b4e013bcaee8

📥 Commits

Reviewing files that changed from the base of the PR and between bbe2123 and 70099a5.

📒 Files selected for processing (1)
  • modelopt/torch/sparsity/attention_sparsity/conversion.py

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from Edwardf0t1 April 7, 2026 02:16
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 198-206: The calibration loop in build_calibration_forward_loop
currently hardcodes guidance_scale=5.0 so calibration activations ignore the CLI
--guidance-scale; add a guidance_scale parameter to
build_calibration_forward_loop (and the other calibration-related functions
mentioned) and pass that guidance_scale through to any calls that currently use
guidance_scale=5.0 (look for explicit guidance_scale=5.0 in the function body
and replace with the new guidance_scale parameter), ensuring the same parameter
is threaded into the calibration runs so activation collection honors the
user-specified guidance scale.
- Around line 177-183: The calibration block is being added at the top-level
sparse_cfg (making "calibration" a selector) instead of being nested under the
self-attention selector; update the code so that when args.calibrate is true you
insert the calibration dict (using args.target_sparsity,
DEFAULT_THRESHOLD_TRIALS and samples:1) into the "*.attn1*" entry of sparse_cfg
(i.e., sparse_cfg["*.attn1*"]["calibration"] = {...}) so the attn1 selector
receives the calibration config rather than creating a new selector at the top
level.
- Around line 77-135: The CLI allows invalid values that later break
calibration; update parse_args() to validate arguments after
parser.parse_args(): ensure args.num_frames and args.calib_frames satisfy (value
- 1) % 4 == 0 and are > 1 (i.e., 4k+1), ensure args.target_sparsity is between
0.0 and 1.0 inclusive, and if a check fails call parser.error(...) with a clear
message referencing the offending flag (e.g., --num-frames, --calib-frames,
--target-sparsity) so users get immediate, actionable feedback; implement these
checks in parse_args() using the parsed args variables.
- Around line 192-193: The code currently loads the entire "caption" column then
slices it; change the dataset load to request only the needed rows by using the
HuggingFace split range: replace load_dataset("nkp37/OpenVid-1M", split="train")
with load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]") and then
build prompts from the returned small dataset (e.g., prompts =
list(dataset["caption"])). This avoids materializing the full caption column for
the tiny calibration sample.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a396d14c-ebef-49a3-aeb9-c3d398add26a

📥 Commits

Reviewing files that changed from the base of the PR and between 6cc96a4 and 4de0d3b.

📒 Files selected for processing (2)
  • examples/diffusers/README.md
  • examples/diffusers/sparsity/wan22_skip_softmax.py
✅ Files skipped from review due to trivial changes (1)
  • examples/diffusers/README.md

Comment on lines +77 to +135
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Wan 2.2 video generation with skip-softmax sparse attention"
)
parser.add_argument("--prompt", type=str, required=True, help="Text prompt for generation")
parser.add_argument("--output", type=str, default="output.mp4", help="Output video path")
parser.add_argument(
"--model-path", type=str, default=DEFAULT_MODEL_PATH, help="Wan 2.2 model path or HF ID"
)
parser.add_argument(
"--num-frames", type=int, default=81, help="Number of frames (must be 4k+1)"
)
parser.add_argument("--height", type=int, default=480, help="Video height")
parser.add_argument("--width", type=int, default=832, help="Video width")
parser.add_argument("--num-steps", type=int, default=50, help="Number of inference steps")
parser.add_argument(
"--guidance-scale", type=float, default=5.0, help="Classifier-free guidance scale"
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")

# Sparse attention options
parser.add_argument(
"--skip-first-last",
type=int,
default=2,
help="Number of first/last transformer layers to keep dense (default: 2)",
)

# Calibration options
parser.add_argument(
"--calibrate",
action="store_true",
help="Calibrate threshold via exponential model (recommended)",
)
parser.add_argument(
"--target-sparsity",
type=float,
default=0.5,
help="Target sparsity ratio for calibration (0.0-1.0)",
)
parser.add_argument(
"--calib-steps",
type=int,
default=40,
help="Inference steps for calibration",
)
parser.add_argument(
"--calib-frames",
type=int,
default=151,
help="Number of frames for calibration",
)
parser.add_argument(
"--calib-size",
type=int,
default=4,
help="Number of calibration prompts from OpenVid-1M dataset",
)
return parser.parse_args()
Contributor


⚠️ Potential issue | 🟡 Minor

Validate the documented CLI constraints in parse_args().

The help text already says --num-frames / --calib-frames must be 4k+1 and --target-sparsity must be in [0, 1], but invalid values are accepted and will fail later inside diffusers/calibration with much less actionable errors.

Proposed fix
-    return parser.parse_args()
+    args = parser.parse_args()
+    if args.num_frames < 1 or args.num_frames % 4 != 1:
+        parser.error("--num-frames must be of the form 4k+1")
+    if args.skip_first_last < 0:
+        parser.error("--skip-first-last must be >= 0")
+    if args.calibrate:
+        if args.calib_frames < 1 or args.calib_frames % 4 != 1:
+            parser.error("--calib-frames must be of the form 4k+1")
+        if not 0.0 <= args.target_sparsity <= 1.0:
+            parser.error("--target-sparsity must be between 0.0 and 1.0")
+        if args.calib_size < 1:
+            parser.error("--calib-size must be >= 1")
+    return args
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 77 - 135, The
CLI allows invalid values that later break calibration; update parse_args() to
validate arguments after parser.parse_args(): ensure args.num_frames and
args.calib_frames satisfy (value - 1) % 4 == 0 and are > 1 (i.e., 4k+1), ensure
args.target_sparsity is between 0.0 and 1.0 inclusive, and if a check fails call
parser.error(...) with a clear message referencing the offending flag (e.g.,
--num-frames, --calib-frames, --target-sparsity) so users get immediate,
actionable feedback; implement these checks in parse_args() using the parsed
args variables.
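The suggested checks fit in a small helper; flag names mirror the script, and `validate` itself is illustrative:

```python
import argparse

def validate(args: argparse.Namespace, parser: argparse.ArgumentParser) -> None:
    # Frame counts must be of the form 4k + 1 (e.g. 81, 85, ...).
    if args.num_frames < 1 or (args.num_frames - 1) % 4 != 0:
        parser.error("--num-frames must be of the form 4k+1")
    if not 0.0 <= args.target_sparsity <= 1.0:
        parser.error("--target-sparsity must be between 0.0 and 1.0")

parser = argparse.ArgumentParser()
parser.add_argument("--num-frames", type=int, default=81)
parser.add_argument("--target-sparsity", type=float, default=0.5)

validate(parser.parse_args([]), parser)  # defaults (81, 0.5) pass

rejected = False
try:
    validate(parser.parse_args(["--num-frames", "80"]), parser)
except SystemExit:  # parser.error prints the message and exits
    rejected = True
assert rejected
```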

Comment on lines +177 to +183
# Add calibration config with threshold trials
if args.calibrate:
sparse_cfg["calibration"] = {
"target_sparse_ratio": {"prefill": args.target_sparsity},
"samples": 1,
"threshold_trials": DEFAULT_THRESHOLD_TRIALS,
}
Contributor


⚠️ Potential issue | 🟠 Major

Attach calibration to the *.attn1* config, not to sparse_cfg.

modelopt/torch/sparsity/attention_sparsity/model_sparsify.py describes sparse_cfg as a selector → attributes map. Putting "calibration" beside "*.attn1*" makes it another selector, so the self-attention entries never receive the calibration block.

Proposed fix
     # Add calibration config with threshold trials
     if args.calibrate:
-        sparse_cfg["calibration"] = {
+        attn_cfg["calibration"] = {
             "target_sparse_ratio": {"prefill": args.target_sparsity},
             "samples": 1,
             "threshold_trials": DEFAULT_THRESHOLD_TRIALS,
         }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 177 - 183,
The calibration block is being added at the top-level sparse_cfg (making
"calibration" a selector) instead of being nested under the self-attention
selector; update the code so that when args.calibrate is true you insert the
calibration dict (using args.target_sparsity, DEFAULT_THRESHOLD_TRIALS and
samples:1) into the "*.attn1*" entry of sparse_cfg (i.e.,
sparse_cfg["*.attn1*"]["calibration"] = {...}) so the attn1 selector receives
the calibration config rather than creating a new selector at the top level.
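The difference in shape, with illustrative keys (this mimics the selector → attributes layout, not the real modelopt config schema):

```python
attn_cfg = {"method": "flash_skip_softmax"}  # hypothetical attribute dict
sparse_cfg = {"*.attn1*": attn_cfg}

# Wrong: "calibration" becomes a sibling selector that matches no module,
# so the attn1 entries never see the calibration block.
wrong = {**sparse_cfg, "calibration": {"samples": 1}}
assert "calibration" not in wrong["*.attn1*"]

# Right: nest calibration under the self-attention selector's attributes.
sparse_cfg["*.attn1*"]["calibration"] = {"samples": 1}
assert sparse_cfg["*.attn1*"]["calibration"]["samples"] == 1
```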

Comment on lines +192 to +193
dataset = load_dataset("nkp37/OpenVid-1M", split="train")
prompts = list(dataset["caption"][:calib_size])
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, find and examine the file
fd -t f "wan22_skip_softmax.py" && wc -l examples/diffusers/sparsity/wan22_skip_softmax.py

Repository: NVIDIA/Model-Optimizer

Length of output: 169


🏁 Script executed:

# Read the specific lines around 192-193
sed -n '185,200p' examples/diffusers/sparsity/wan22_skip_softmax.py

Repository: NVIDIA/Model-Optimizer

Length of output: 529


🌐 Web query:

huggingface datasets load_dataset indexing behavior [:n] vs split parameter

💡 Result:

  • load_dataset(..., split="train[:N]") (or "train[10:20]", "train[:10%]", etc.) uses the library’s split slicing / ReadInstruction mechanism at load time and returns a datasets.Dataset that contains only that requested subsplit. Examples of this string API are documented (e.g. split="train[10:20]", split="train[:10%]"). [1]

  • ds[:N] after you already loaded a split is Python-style indexing via Dataset.__getitem__. For a slice like [:100], it returns a batch as a plain Python dict of columns, not a new Dataset object—this commonly surprises users. The recommended way to get a smaller Dataset is ds.select(range(N)). [2]

  • Streaming caveat: when streaming=True, the "train[:N]" split-slicing syntax is (historically) not supported. [3]

Rule of thumb

  • Want a Dataset object subset: use split="train[:N]" at load time or ds.select(range(N)) after loading.
  • Want an in-memory batch of examples: use ds[:N].

Sources: [1] https://huggingface.co/docs/datasets/v1.10.2/splits.html • [2] https://discuss.huggingface.co/t/how-to-slice-an-already-loaded-dataset/27946 • [3] huggingface/datasets#5458


Avoid reading the entire OpenVid caption column for a tiny calibration sample.

dataset["caption"][:calib_size] loads the full caption column before slicing, which for OpenVid-1M means reading all captions to extract just 4 prompts. Use the split syntax instead to load only the needed rows at dataset load time.

Proposed fix
-    dataset = load_dataset("nkp37/OpenVid-1M", split="train")
-    prompts = list(dataset["caption"][:calib_size])
+    dataset = load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]")
+    prompts = list(dataset["caption"])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 192-193,
the code currently loads the entire "caption" column and then slices it. Change
the dataset load to request only the needed rows by using the HuggingFace split
range: replace load_dataset("nkp37/OpenVid-1M", split="train") with
load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]"), then build
prompts from the returned small dataset (e.g., prompts =
list(dataset["caption"])). This avoids materializing the full caption column for
a tiny calibration sample.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from a team as a code owner April 7, 2026 19:55
@jingyu-ml jingyu-ml requested a review from kevalmorabia97 April 7, 2026 19:55
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Collaborator

can we have a tiny wan2.2 model tested in tests/examples/diffusers for this file?

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>