
Skip Softmax diffusion export #1269

Open
jingyu-ml wants to merge 43 commits into main from
jingyux/diffusion-skip-softmax-2

Conversation

@jingyu-ml
Contributor

@jingyu-ml jingyu-ml commented Apr 15, 2026

What does this PR do?

Type of change: New Feature

Adds HuggingFace checkpoint export for diffusion pipelines calibrated with skip-softmax, on top of the base skip-softmax MR (jingyux/diffusion-skip-softmax). Concretely:

  • _export_diffusers_checkpoint now walks every nn.Module component of a diffusers pipeline, calls export_sparse_attention_config, injects the result into that component's config.json as sparse_attention_config, and additionally writes a unified top-level sparse.yaml keyed by pipeline component (transformer, transformer_2, …). The existing LLM export_hf_checkpoint path also gains a sibling sparse.yaml dump.
  • export_sparse_attention_config is generalized: per-group nesting (group_0.threshold_scale_factor, group_0.raw_threshold, group_0.disabled_layers) so future sparse methods can coexist, plus per-layer disabled_layers reporting and a raw_threshold-only path for uncalibrated use.
  • Log-space calibration export: the calibrator now propagates log_a / fit_logspace through its result dict, and the exporter emits the matching formula: "log_a + b * target_sparsity" for diffusion (linear-space a * exp(b * S) is still used for LLMs).
  • Example wiring: examples/diffusers/sparsity/wan22_skip_softmax.py gets an --export-dir flag that calls export_hf_checkpoint(pipe, export_dir=...) after calibration.
  • Updated CHANGELOG.rst to note diffusion coverage for skip-softmax.
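For illustration, the per-group nesting described above can be sketched as follows. This is a hypothetical helper: the real export_sparse_attention_config reads this information off the calibrated modules, and its signature is not shown in this PR description.

```python
def build_sparse_attention_config(groups, producer_version="0.37.0"):
    """Assemble the nested config_groups layout (illustrative sketch).

    Each group carries either a calibrated threshold_scale_factor or a
    raw_threshold (the uncalibrated path), plus per-layer disabled_layers.
    """
    config_groups = {}
    for i, group in enumerate(groups):
        entry = {
            "sparse_algo": group["sparse_algo"],
            "targets": group["targets"],
            "disabled_layers": group.get("disabled_layers", []),
        }
        if "threshold_scale_factor" in group:
            entry["threshold_scale_factor"] = group["threshold_scale_factor"]
        else:
            entry["raw_threshold"] = group["raw_threshold"]
        config_groups[f"group_{i}"] = entry
    return {
        "config_groups": config_groups,
        "producer": {"name": "modelopt", "version": producer_version},
    }


cfg = build_sparse_attention_config(
    [{"sparse_algo": "softmax_skip", "targets": ["WanAttention"], "raw_threshold": -5.0}]
)
```

Because future sparse methods get their own `group_N` key, a single `sparse_attention_config` can describe several coexisting methods.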

Usage

python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \
    --calibrate \
    --target-sparsity 0.5 \
    --export-dir ./wan22_skip_softmax_ckpt

Equivalent Python:

from diffusers import WanPipeline
from modelopt.torch.export import export_hf_checkpoint
import modelopt.torch.sparsity.attention_sparsity as mtsa

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-TI2V-5B-Diffusers")
mtsa.calibrate(pipe.transformer, ...)          # from the base skip-softmax MR
export_hf_checkpoint(pipe, export_dir="./wan22_skip_softmax_ckpt")

Resulting layout:

wan22_skip_softmax_ckpt/
├── sparse.yaml                      # unified, keyed by component
├── transformer/
│   └── config.json                  # carries sparse_attention_config
├── transformer_2/
│   └── config.json
├── vae/ …
└── scheduler/ …
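Reading the exported artifacts back is straightforward. The sketch below assumes the layout above; the function name and return shape are illustrative, not a modelopt API.

```python
import json
import tempfile
from pathlib import Path

import yaml  # PyYAML, the same library the exporter uses for sparse.yaml


def load_sparse_configs(ckpt_dir):
    """Collect sparse-attention configs from an exported checkpoint.

    Returns the unified top-level sparse.yaml plus the per-component
    sparse_attention_config entries found in each config.json.
    """
    root = Path(ckpt_dir)
    unified = yaml.safe_load((root / "sparse.yaml").read_text())
    per_component = {}
    for path in root.glob("*/config.json"):
        cfg = json.loads(path.read_text())
        if "sparse_attention_config" in cfg:
            per_component[path.parent.name] = cfg["sparse_attention_config"]
    return {"unified": unified, "per_component": per_component}


# Round-trip demo on a synthetic checkpoint directory
with tempfile.TemporaryDirectory() as d:
    (Path(d) / "transformer").mkdir()
    (Path(d) / "sparse.yaml").write_text(
        yaml.dump({"transformer": {"sparse_algo": "softmax_skip"}})
    )
    (Path(d) / "transformer" / "config.json").write_text(
        json.dumps({"sparse_attention_config": {"sparse_algo": "softmax_skip"}})
    )
    configs = load_sparse_configs(d)
```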

A representative config.json entry for a diffusion component:

"sparse_attention_config": {
  "config_groups": {
    "group_0": {
      "sparse_algo": "softmax_skip",
      "targets": ["WanAttention"],
      "threshold_scale_factor": {
        "formula": "log_a + b * target_sparsity",
        "prefill": {"log_a": 0.21, "b": 3.45}
      },
      "disabled_layers": ["blocks.0.attn1", "blocks.39.attn1"]
    }
  },
  "producer": {"name": "modelopt", "version": "0.37.0"}
}
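The two formula variants can be evaluated as below. `resolve_scale_factor` is a hypothetical helper, not part of modelopt; the numeric parameters match the example entry above.

```python
import math


def resolve_scale_factor(entry, target_sparsity):
    """Evaluate a threshold_scale_factor entry for a given sparsity target.

    Covers the two formula variants: log-space (diffusion) and
    linear-space (LLM).
    """
    params = entry["prefill"]
    if entry["formula"] == "log_a + b * target_sparsity":
        return math.exp(params["log_a"] + params["b"] * target_sparsity)
    if entry["formula"] == "a * exp(b * target_sparsity)":
        return params["a"] * math.exp(params["b"] * target_sparsity)
    raise ValueError(f"unknown formula: {entry['formula']}")


# Values from the example entry above
entry = {"formula": "log_a + b * target_sparsity", "prefill": {"log_a": 0.21, "b": 3.45}}
scale = resolve_scale_factor(entry, target_sparsity=0.5)  # exp(0.21 + 3.45 * 0.5)
```

Note that both variants describe the same exponential relationship; the log-space form just stores log(a) instead of a.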

Testing

  • Extended tests/examples/diffusers/test_sparsity.py with a calibrate → export → reload round-trip on a small diffusion pipeline, asserting the presence and shape of sparse_attention_config in each component's config.json and the unified top-level sparse.yaml.
  • Manually verified on Wan2.2-T2V-14B: sparse.yaml and transformer{,_2}/config.json contain the expected log-space threshold_scale_factor, any disabled layers, and producer metadata; a freshly loaded pipeline from the exported checkpoint reproduces the calibrated sparsity target end-to-end.
  • LLM export path regression-checked by re-running existing test_sparsity tests — the new sparse.yaml sibling is emitted without changing the existing config.json patching behavior.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added skip-softmax sparse attention support for diffusion models (Wan 2.2, with LTX-2 coming soon), enabling both calibration-based and fixed-threshold sparsification modes.
    • Runtime sparsity measurement capability for analyzing actual tile skipping during inference.
    • Sparse attention configuration export to model checkpoints.
  • Documentation

    • Added comprehensive guides and example scripts for skip-softmax sparse attention with Wan 2.2 text-to-video models.
  • Chores

    • Added license compliance warning for LTX-2 package exports.

jingyu-ml and others added 30 commits April 2, 2026 06:02
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
jingyu-ml and others added 3 commits April 15, 2026 05:30
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested review from a team as code owners April 15, 2026 22:00
@jingyu-ml jingyu-ml requested review from ajrasane and kaix-nv April 15, 2026 22:00
@copy-pr-bot

copy-pr-bot bot commented Apr 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@jingyu-ml jingyu-ml requested review from ChenhanYu and realAsma April 15, 2026 22:00
@jingyu-ml jingyu-ml marked this pull request as draft April 15, 2026 22:00
@coderabbitai
Contributor

coderabbitai bot commented Apr 15, 2026

📝 Walkthrough

This PR implements skip-softmax sparse attention for diffusion models (WAN 2.2 and planned LTX-2 support), introducing a Triton kernel calibration API for multi-threshold sparsity measurement, configurable threshold modes (raw or calibrated), diffusers and LTX-2 backend integration, comprehensive documentation, and an end-to-end example script with inference and calibration workflows.

Changes

Cohort / File(s) Summary
Changelog & Documentation
CHANGELOG.rst, examples/diffusers/README.md, examples/diffusers/sparsity/README.md
Documented skip-softmax sparse attention feature covering both language models and video diffusion. Added quickstart guides with CLI modes, threshold workflows (fixed raw vs. calibrated), and model coverage details.
Example Script
examples/diffusers/sparsity/wan22_skip_softmax.py
Comprehensive WAN 2.2 inference example with CLI modes for dense baseline, fixed/calibrated thresholds, calibration via OpenVid-1M captions, optional checkpoint export, and runtime sparsity reporting.
Triton Kernel Extension
modelopt/torch/kernels/triton_fa.py, modelopt/torch/kernels/__init__.py
Added calibration kernel attention_calibrate() for multi-threshold sparsity statistics, runtime sparsity measurement via atomic counters, raw threshold mode, and normalized denominator clamping to prevent NaN when all KV tiles are skipped.
Calibration System
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py, modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
Introduced log-space exponential model fitting, lazy tokenizer loading, calibration parameter extraction (log_a, fit_logspace), observed sparsity bounds tracking, and per-phase threshold trial support.
Configuration & Registry
modelopt/torch/sparsity/attention_sparsity/config.py, modelopt/torch/sparsity/attention_sparsity/methods/registry.py
Added skip_softmax_raw_threshold, fit_logspace config fields, and calibration mode toggle in sparse method base class.
Sparse Attention Methods
modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py, modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
Extended TritonSkipSoftmaxMethod with dual calibration/inference context modes, threshold priority logic (raw > scale_factor > static), sparsity measurement APIs, and backend configuration helpers. Updated flash context to enable skip-softmax routing via thread-local flag.
Backend Integrations
modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py, modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py, modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
Implemented Triton attention backends for diffusers and LTX-2 with tensor reshaping, varlen metadata construction, calibration/inference dispatch, thread-local config APIs, and register_*_triton_attention() functions. Added skip-softmax context thread-local mechanism.
Conversion & Export
modelopt/torch/sparsity/attention_sparsity/conversion.py, modelopt/torch/export/unified_export_hf.py, modelopt/torch/export/diffusers_utils.py
Added backend auto-registration during conversion, sparse attention config export to config.json and sparse.yaml with nested config_groups, raw threshold export, formula-mode detection (log-space vs. linear-space), disabled layer tracking, and LTX-2 license disclaimer warning.
Utilities & Plugin Updates
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py, modelopt/torch/sparsity/attention_sparsity/stats_manager.py
Made model type checks lazy (diffusers ModelMixin support), added optional sparse_blocks field handling and conditional normalized_gaps in stats aggregation.
Test Utilities & Suite
tests/_test_utils/torch/diffusers_models.py, tests/examples/diffusers/test_sparsity.py, tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py
Added tiny WAN 2.2 pipeline fixtures and builders. Integration tests for baseline, triton-baseline, raw-threshold, and calibrated export modes with artifact validation. Unit tests for skip-softmax context API and backend registration behavior.
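The threshold priority mentioned for TritonSkipSoftmaxMethod (raw > scale_factor > static) amounts to something like the following. Names and return values are hypothetical; the actual modelopt API may differ.

```python
def resolve_threshold(raw_threshold=None, scale_factor=None, static_threshold=None):
    """Pick the effective skip-softmax threshold source by priority.

    A raw threshold always wins, then a calibration-derived scale factor,
    then a static fallback.
    """
    if raw_threshold is not None:
        return "raw", raw_threshold
    if scale_factor is not None:
        return "dynamic_calibrated", scale_factor
    if static_threshold is not None:
        return "static", static_threshold
    raise ValueError("no skip-softmax threshold configured")
```

For example, passing `--raw-threshold -5.0` would shadow any calibrated scale factor stored on the module.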

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant CalibScript as Calibration Script
    participant Pipeline as Diffusers Pipeline
    participant TritonKernel as Triton attention_calibrate
    participant Calibrator as DynamicThresholdCalibrator
    
    User->>CalibScript: invoke with --calibrate --target-sparsity 0.5
    CalibScript->>Pipeline: build WAN 2.2 model
    CalibScript->>Pipeline: convert_to_sparse_attention_model
    CalibScript->>Pipeline: set calibration mode & threshold trials
    
    loop calibration forward passes (OpenVid-1M captions)
        CalibScript->>Pipeline: forward(sample_prompts)
        Pipeline->>TritonKernel: attention_calibrate with threshold_trials
        TritonKernel-->>Pipeline: (output, sparsity_counters[num_thresholds, 2])
        Pipeline->>Calibrator: collect per-threshold sparsity stats
    end
    
    Calibrator->>Calibrator: fit log(scale_factor) = log_a + b*sparsity
    Calibrator-->>Pipeline: return calibration_params {log_a, b, fit_logspace}
    Pipeline->>Pipeline: store threshold formula in module config
    CalibScript->>CalibScript: free GPU memory
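The log-space fit in the diagram above is an ordinary least-squares line in (sparsity, log scale-factor) space. A minimal self-contained sketch, assuming (sparsity, scale_factor) pairs as input; the calibrator's real API differs:

```python
import math


def fit_logspace(samples):
    """Fit log(scale_factor) = log_a + b * sparsity by least squares.

    `samples` is a list of (observed_sparsity, threshold_scale_factor)
    pairs collected during calibration.
    """
    xs = [s for s, _ in samples]
    ys = [math.log(f) for _, f in samples]
    n = len(samples)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    b = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / sum(
        (x - mean_x) ** 2 for x in xs
    )
    log_a = mean_y - b * mean_x
    return {"log_a": log_a, "b": b, "fit_logspace": True}


# Synthetic, exactly log-linear data: scale = exp(0.2 + 3.0 * sparsity)
data = [(s, math.exp(0.2 + 3.0 * s)) for s in (0.1, 0.3, 0.5, 0.7)]
params = fit_logspace(data)
```

Fitting in log space keeps the regression linear even though the scale factor itself grows exponentially with the sparsity target.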
sequenceDiagram
    participant User
    participant InferenceScript as Inference Script
    participant DiffusersBackend as Diffusers Triton Backend
    participant Pipeline as WAN 2.2 Pipeline
    participant TritonKernel as Triton attention (inference)
    
    User->>InferenceScript: invoke with --raw-threshold -5.0
    InferenceScript->>Pipeline: load sparsified model
    InferenceScript->>DiffusersBackend: set_triton_skip_softmax_config(raw_threshold=-5.0)
    
    InferenceScript->>Pipeline: pipeline(prompt, num_inference_steps=50)
    
    loop diffusion steps
        Pipeline->>DiffusersBackend: _diffusers_triton_attention (forward)
        DiffusersBackend->>DiffusersBackend: reshape [B, S, H, D] → [B*S, H, D]
        DiffusersBackend->>TritonKernel: attention(query, key, value, skip_softmax_raw_threshold=-5.0)
        TritonKernel->>TritonKernel: tile-wise skip_softmax: skip tile if log2(score_max) < -5.0
        TritonKernel-->>DiffusersBackend: (output, optional sparsity counters)
        DiffusersBackend->>DiffusersBackend: reshape output back [B, S, H, D]
        DiffusersBackend-->>Pipeline: attention output
    end
    
    Pipeline-->>InferenceScript: generated video
    InferenceScript->>InferenceScript: print per-module threshold info & runtime sparsity

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Description Check — ✅ Passed: check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check — ✅ Passed: the title 'Skip Softmax diffusion export' directly describes the main feature added: exporting skip-softmax sparse attention configuration for diffusion model pipelines.
  • Docstring Coverage — ✅ Passed: coverage is 82.54%, above the required 80.00% threshold.
  • Security Anti-Patterns — ✅ Passed: no insecure patterns detected in modified or added Python code.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@jingyu-ml jingyu-ml self-assigned this Apr 15, 2026
@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Apr 15, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1269/

Built to branch gh-pages at 2026-04-17 00:02 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov
Copy link
Copy Markdown

codecov bot commented Apr 15, 2026

Codecov Report

❌ Patch coverage is 25.38330% with 438 lines in your changes missing coverage. Please review.
✅ Project coverage is 55.31%. Comparing base (04fcf24) to head (b77d098).

Files with missing lines Patch % Lines
modelopt/torch/kernels/triton_fa.py 0.00% 108 Missing ⚠️
.../attention_sparsity/methods/triton_skip_softmax.py 16.37% 97 Missing ⚠️
...attention_sparsity/kernels/ltx_triton_attention.py 19.54% 70 Missing ⚠️
...ion_sparsity/kernels/diffusers_triton_attention.py 48.51% 52 Missing ⚠️
...rsity/attention_sparsity/calibration/calibrator.py 9.75% 37 Missing ⚠️
...pt/torch/sparsity/attention_sparsity/conversion.py 40.42% 28 Missing ⚠️
modelopt/torch/export/unified_export_hf.py 26.08% 17 Missing ⚠️
...arsity/attention_sparsity/calibration/calibrate.py 19.04% 17 Missing ⚠️
...y/attention_sparsity/methods/flash_skip_softmax.py 14.28% 6 Missing ⚠️
modelopt/torch/kernels/__init__.py 33.33% 2 Missing ⚠️
... and 4 more
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1269       +/-   ##
===========================================
- Coverage   75.58%   55.31%   -20.28%     
===========================================
  Files         459      460        +1     
  Lines       48612    49345      +733     
===========================================
- Hits        36745    27295     -9450     
- Misses      11867    22050    +10183     
Flag Coverage Δ
unit 51.95% <25.38%> (-0.27%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

jingyu-ml and others added 5 commits April 15, 2026 22:14
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml changed the title from Jingyux/diffusion skip softmax 2 to Skip Softmax diffusion export Apr 16, 2026
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml marked this pull request as ready for review April 17, 2026 00:03
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (5)
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)

387-394: Minor: Consider wrapping set_skip_softmax_context(True) in the ExitStack for exception safety.

If an exception occurs after set_skip_softmax_context(True) but before the callback is registered (e.g., during stack.enter_context), the skip-softmax context would remain enabled without cleanup.

♻️ Safer ordering
         from ..kernels import set_skip_softmax_context

         stack = ExitStack()
+        stack.callback(set_skip_softmax_context, False)
         set_skip_softmax_context(True)
-        stack.callback(set_skip_softmax_context, False)

         stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax))
         return stack

This ensures the cleanup callback is registered before the state is modified, so any subsequent exception during setup will still trigger cleanup when the stack is garbage collected or explicitly closed.

modelopt/torch/export/diffusers_utils.py (1)

49-59: Consider making this a one-time warning or moving it to actual usage.

This warning fires at module import time whenever ltx_pipelines is installed, which may be noisy for users who import diffusers_utils but don't use LTX-2 features. Additionally, stacklevel=2 at module load time may not point to a meaningful location.

Consider using warnings.warn(..., stacklevel=2) with filterwarnings to show once, or deferring the warning to actual LTX-2 usage (similar to line 395-404 in _ltx2_dummy_forward).

♻️ One-time warning option
import warnings

# At the top of the module
_LTX2_LICENSE_WARNING_SHOWN = False

# Inside the try block
try:
    from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline as _TI2VidTwoStagesPipeline
    
    def _show_ltx2_license_warning():
        global _LTX2_LICENSE_WARNING_SHOWN
        if not _LTX2_LICENSE_WARNING_SHOWN:
            warnings.warn(
                "LTX-2 packages ... (license text)",
                UserWarning,
                stacklevel=3,
            )
            _LTX2_LICENSE_WARNING_SHOWN = True
    
    TI2VidTwoStagesPipeline = _TI2VidTwoStagesPipeline
except Exception:
    TI2VidTwoStagesPipeline = None

Then call _show_ltx2_license_warning() in functions that actually use LTX-2.

modelopt/torch/export/unified_export_hf.py (1)

1261-1266: Remove redundant import.

The yaml module is already imported at line 31. This inline import is unnecessary.

♻️ Proposed fix
                 config_data["sparse_attention_config"] = sparse_attn_config

                 # Also save as standalone YAML for easy inspection and reuse
-                import yaml
-
                 yaml_path = Path(export_dir) / "sparse.yaml"
                 with open(yaml_path, "w") as file:
                     yaml.dump(sparse_attn_config, file, default_flow_style=False, sort_keys=False)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)

314-318: Consider documenting this asymmetry.

The decode phase always requires calibration_data and tokenizer (RULER-based), even when a custom forward_loop was provided for prefill. This asymmetry between prefill (supports custom forward_loop) and decode (always requires RULER) could be confusing.

Consider adding a docstring note or raising a more descriptive error message explaining why decode requires the RULER dataset.

examples/diffusers/sparsity/wan22_skip_softmax.py (1)

56-63: Lazy-load the optional datasets/diffusers dependencies.

This example hard-imports optional integrations at module load time, which makes simple imports fail unless the full diffusers stack is already installed. Moving these imports into build_pipeline(), load_calib_prompts(), and main() keeps the example gated behind the right extras and avoids breaking unrelated tooling that imports example modules. As per coding guidelines, "Gate optional features by install extras ([onnx], [hf], [all]); avoid hard imports of optional dependencies at module level" and "Use the plugin system with import_plugin() for lazy loading of optional integrations (HuggingFace, Megatron, etc.)".

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 271-276: The function load_calib_prompts currently loads the
entire "train" split which is wasteful; change it to only load the first
calib_size examples by using a sliced split (e.g. "train[:{calib_size}]" or
equivalent) or streaming so only the needed items are materialized, then collect
the captions and return them; update load_calib_prompts to call load_dataset
with the sliced split string based on the calib_size parameter and build the
prompts list from that smaller dataset.

In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 1241-1246: The program index calculation uses a per-program tile
count computed from tl.load(b_seq_len + 0) which collides when batches have
variable lengths; replace the local num_q_tiles = tl.cdiv(tl.load(b_seq_len +
0), BLOCK_M) with a tile count derived from the same max_input_len used when
allocating counters (e.g. num_q_tiles = tl.cdiv(max_input_len, BLOCK_M)), ensure
the kernel signature accepts that max_input_len scalar (or otherwise pass the
same allocation param) and update uses of num_q_tiles/prog_idx/base so each
(batch, head, tile) maps to a unique counter slot matching
attention_calibrate()'s allocation.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 127-136: The current broad try/except hides real errors from
register_diffusers_triton_attention() so integration failures are swallowed;
change the logic to only suppress ImportError when importing ModelMixin but
allow exceptions from register_diffusers_triton_attention() to surface: first
try importing ModelMixin and on ImportError simply return/skip, then if
isinstance(model, ModelMixin) import register_diffusers_triton_attention and if
it's not None call register_diffusers_triton_attention() without a broad except
so any runtime errors propagate (or re-raise after logging) — refer to
ModelMixin and register_diffusers_triton_attention to locate and update the
conversion.py block.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 123-213: The function _diffusers_triton_attention currently
accepts attn_mask, dropout_p, and enable_gqa but ignores them; either remove
these params from the signature and ensure the backend's _supported_arg_names
(where supported args are derived) no longer lists them, or implement their
semantics: detect attn_mask != None and convert it into the Triton/kernel
expected mask metadata (or pass it via kw as "attn_mask" / appropriate key),
handle dropout_p > 0 by passing a "dropout_p" kw and ensuring training mode
semantics are respected, and honor enable_gqa by adjusting q/k/v shapes/scale
(grouped-query attention layout changes) and passing an "enable_gqa" or
equivalent flag into kw; update callers/registration so _supported_arg_names
matches the final signature/behavior.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`:
- Around line 164-168: The current __call__ wrapper incorrectly routes masked
attention through _ltx_triton_attention which ignores mask and forces
is_causal=False; change the conditional to only use the Triton path when active
and mask is None (e.g., if active and mask is None: return
_ltx_triton_attention(...)); otherwise always call and return
self._original_fn(q, k, v, heads, mask) so mask semantics are preserved. Ensure
you reference _get_ltx_triton_context, _ltx_triton_attention, and
self._original_fn when making the change.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`:
- Around line 255-269: get_threshold_info() currently always reports the static
lambda threshold even when _triton_inference_context() supplies a
skip_softmax_raw_threshold used by the kernel; update get_threshold_info to
check the triton inference context (via _triton_inference_context()) and, if a
skip_softmax_raw_threshold is present, return a "raw" type with that raw
threshold value (and note any related keys like skip_softmax_threshold or
target_sparse_ratio as applicable); otherwise preserve the existing
dynamic_calibrated/static return paths so the reported info matches the actual
value used by the kernel.

---

Nitpick comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 56-63: The module currently hard-imports optional packages
(datasets, diffusers, diffusers.utils.export_to_video and diffusers classes
AutoencoderKLWan/WanPipeline) at top-level; move those imports into the
functions that actually use them (e.g., build_pipeline(), load_calib_prompts(),
and main()) so the example can be imported without the optional extras
installed, and prefer the plugin loader where available (e.g., call
import_plugin or similar before importing HuggingFace/diffusers modules) to gate
the integrations; update references to SparseAttentionModule and any diffusers
types after the local imports so runtime code uses the lazily-loaded symbols.

In `@modelopt/torch/export/diffusers_utils.py`:
- Around line 49-59: The module currently emits a loud license UserWarning at
import via the warnings.warn call; change this to a one-time or deferred
warning: wrap the import of ltx_pipelines and the current warnings.warn
invocation behind a module-level flag (e.g., _LTX2_LICENSE_WARNING_SHOWN) or
remove the warn from import and instead call a small helper like
_show_ltx2_license_warning() from actual LTX-2 entrypoints (for example inside
TI2VidTwoStagesPipeline usage code or _ltx2_dummy_forward) so the warning is
emitted at first use only; keep stacklevel at an appropriate value (e.g., 3)
when invoking warnings.warn to point to user code and ensure
TI2VidTwoStagesPipeline remains set to None on import failure.

In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 1261-1266: Remove the redundant inline "import yaml" inside the
export block; the module is already imported at the top of the file, so delete
the inline import and keep the yaml usage (yaml_path = Path(export_dir) /
"sparse.yaml" and yaml.dump(sparse_attn_config, ...)) as-is to avoid duplicate
imports and retain the YAML file write behavior.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 314-318: Update the documentation and error to make the
prefill/decode asymmetry explicit: add a docstring note on the top-level
function in calibrate.py (where decode_forward_loop is created) explaining that
prefill accepts a custom forward_loop but decode always requires a RULER-style
calibration_data and tokenizer because decode uses
create_decode_calibration_forward_loop; and replace the RuntimeError raised when
calibration_data or tokenizer is missing with a more descriptive message that
states "decode requires a RULER-style calibration_data and tokenizer (used by
create_decode_calibration_forward_loop) even if a custom prefill forward_loop
was provided." Reference the symbols calibration_data, tokenizer,
create_decode_calibration_forward_loop, and decode_forward_loop in the
docstring/error.
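A minimal guard with the more descriptive error reads like this; the factory name and loop body are hypothetical, standing in for the real `create_decode_calibration_forward_loop` wiring:

```python
def make_decode_forward_loop(calibration_data, tokenizer):
    """Hypothetical guard: prefill may use a custom forward_loop, but decode
    always builds its own loop, so both inputs are mandatory here."""
    if calibration_data is None or tokenizer is None:
        raise RuntimeError(
            "decode requires a RULER-style calibration_data and tokenizer "
            "(used by create_decode_calibration_forward_loop) even if a "
            "custom prefill forward_loop was provided."
        )

    def decode_forward_loop(model):
        # Illustrative stand-in for the real decode calibration loop.
        for sample in calibration_data:
            model(tokenizer(sample))

    return decode_forward_loop
```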

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 387-394: Register the cleanup callback on the ExitStack before
enabling the skip-softmax flag to ensure exception safety: call
stack.callback(set_skip_softmax_context, False) first, then call
set_skip_softmax_context(True), and only after that perform
stack.enter_context(replace_function(torch.nn.functional, "softmax",
sparse_softmax)); this guarantees that if an exception occurs during
enter_context the skip-softmax state (managed by set_skip_softmax_context) will
still be cleaned up.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: b4db3917-e462-4d17-b588-79e5f63acc4d

📥 Commits

Reviewing files that changed from the base of the PR and between 04fcf24 and b77d098.

📒 Files selected for processing (23)
  • CHANGELOG.rst
  • examples/diffusers/README.md
  • examples/diffusers/sparsity/README.md
  • examples/diffusers/sparsity/wan22_skip_softmax.py
  • modelopt/torch/export/diffusers_utils.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/kernels/__init__.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • tests/_test_utils/torch/diffusers_models.py
  • tests/examples/diffusers/test_sparsity.py
  • tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py

Comment on lines +271 to +276
def load_calib_prompts(calib_size: int) -> list[str]:
"""Load calibration prompts from OpenVid-1M dataset."""
dataset = load_dataset("nkp37/OpenVid-1M", split="train")
prompts = list(dataset["caption"][:calib_size])
print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M")
return prompts
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don’t load the full OpenVid-1M split just to grab a few prompts.

load_dataset(..., split="train") materializes the whole split metadata, and then this code only uses the first calib_size captions. For the default calib_size=4, that is a lot of unnecessary I/O and makes calibration much slower and more fragile than it needs to be.

🔧 Suggested fix
 def load_calib_prompts(calib_size: int) -> list[str]:
     """Load calibration prompts from OpenVid-1M dataset."""
-    dataset = load_dataset("nkp37/OpenVid-1M", split="train")
-    prompts = list(dataset["caption"][:calib_size])
+    dataset = load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]")
+    prompts = list(dataset["caption"])
     print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M")
     return prompts
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 271 - 276,
The function load_calib_prompts currently loads the entire "train" split which
is wasteful; change it to only load the first calib_size examples by using a
sliced split (e.g. "train[:{calib_size}]" or equivalent) or streaming so only
the needed items are materialized, then collect the captions and return them;
update load_calib_prompts to call load_dataset with the sliced split string
based on the calib_size parameter and build the prompts list from that smaller
dataset.
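The sliced-split version can be sketched like this; the helper is illustrative, and the `load_dataset` call still requires the optional `datasets` package and network access:

```python
def make_calib_split(calib_size):
    """Build a sliced split spec so only the first calib_size rows are
    fetched, instead of materializing the whole train split."""
    return f"train[:{calib_size}]"


def load_calib_prompts(calib_size):
    from datasets import load_dataset  # optional dependency

    dataset = load_dataset("nkp37/OpenVid-1M", split=make_calib_split(calib_size))
    prompts = list(dataset["caption"])
    print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M")
    return prompts
```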

Comment on lines +1241 to +1246
# --- Write per-program counters (no atomics, just stores) ---
# Compute unique flat program index for this (batch, head, q_tile)
num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M) # conservative upper bound
num_heads = tl.num_programs(1)
prog_idx = batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q
base = prog_idx * NUM_THRESHOLDS

⚠️ Potential issue | 🟠 Major

Program indexing collides for variable-length batches.

attention_calibrate() allocates one counter row per launched program using triton.cdiv(max_input_len, BLOCK_M), but the kernel flattens (batch, head, tile) with tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M). If batch element 0 is shorter than another sequence, different programs write into the same counter slots and the exported calibration stats are wrong.

🔧 Suggested fix
-    num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M)  # conservative upper bound
+    num_q_tiles = tl.num_programs(2)
     num_heads = tl.num_programs(1)
     prog_idx = batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 1241 - 1246, The program
index calculation uses a per-program tile count computed from tl.load(b_seq_len
+ 0) which collides when batches have variable lengths; replace the local
num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M) with a tile count derived
from the same max_input_len used when allocating counters (e.g. num_q_tiles =
tl.cdiv(max_input_len, BLOCK_M)), ensure the kernel signature accepts that
max_input_len scalar (or otherwise pass the same allocation param) and update
uses of num_q_tiles/prog_idx/base so each (batch, head, tile) maps to a unique
counter slot matching attention_calibrate()'s allocation.
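The collision-free indexing can be checked host-side in plain Python: deriving the tile count from the same `max_input_len` used for allocation guarantees a unique slot per `(batch, head, tile)`. This is a sketch of the indexing math only, not the Triton kernel itself:

```python
def cdiv(a, b):
    """Ceiling division, mirroring triton.cdiv / tl.cdiv."""
    return -(-a // b)


def flat_counter_index(batch_idx, head_idx, tile_q, num_heads, max_input_len, block_m):
    """Flat (batch, head, tile) index using the max_input_len-derived tile
    count, matching the host-side counter allocation so no two programs
    share a slot even when sequence lengths vary across the batch."""
    num_q_tiles = cdiv(max_input_len, block_m)
    return batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q
```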

Comment on lines +127 to +136
try:
from diffusers.models.modeling_utils import ModelMixin

if isinstance(model, ModelMixin):
from .kernels import register_diffusers_triton_attention

if register_diffusers_triton_attention is not None:
register_diffusers_triton_attention()
except (ImportError, Exception):
pass

⚠️ Potential issue | 🟠 Major

Don’t swallow diffusers backend registration failures.

If register_diffusers_triton_attention() raises here, conversion still continues and later paths can still select "modelopt_triton". That turns a real integration failure into a later runtime error or a silent fallback. Only missing optional deps should be suppressed here; actual registration errors should surface.

🔧 Suggested fix
-    try:
-        from diffusers.models.modeling_utils import ModelMixin
-
-        if isinstance(model, ModelMixin):
-            from .kernels import register_diffusers_triton_attention
-
-            if register_diffusers_triton_attention is not None:
-                register_diffusers_triton_attention()
-    except (ImportError, Exception):
-        pass
+    try:
+        from diffusers.models.modeling_utils import ModelMixin
+    except ImportError:
+        ModelMixin = None
+
+    if ModelMixin is not None and isinstance(model, ModelMixin):
+        from .kernels import register_diffusers_triton_attention
+
+        if register_diffusers_triton_attention is not None:
+            register_diffusers_triton_attention()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 127 -
136, The current broad try/except hides real errors from
register_diffusers_triton_attention() so integration failures are swallowed;
change the logic to only suppress ImportError when importing ModelMixin but
allow exceptions from register_diffusers_triton_attention() to surface: first
try importing ModelMixin and on ImportError simply return/skip, then if
isinstance(model, ModelMixin) import register_diffusers_triton_attention and if
it's not None call register_diffusers_triton_attention() without a broad except
so any runtime errors propagate (or re-raise after logging) — refer to
ModelMixin and register_diffusers_triton_attention to locate and update the
conversion.py block.
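The general shape of the fix, stripped of the diffusers specifics (the two callables here are hypothetical stand-ins for the optional import and for `register_diffusers_triton_attention`):

```python
def register_optional_backend(import_backend, register):
    """Suppress only the missing-optional-dependency case; let errors raised
    by the registration call itself propagate as real bugs."""
    try:
        backend = import_backend()
    except ImportError:
        return False  # optional dependency absent: silently skip
    register(backend)  # a failure here must surface, not be swallowed
    return True
```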

Comment on lines +123 to +213
def _diffusers_triton_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
) -> torch.Tensor:
"""Compute attention via Triton FA kernel on diffusers layout ``[B, S, H, D]``."""
batch, seq_q, num_heads_q, head_dim = query.shape
seq_k = key.shape[1]
device = query.device

# Reshape from diffusers [B, S, H, D] -> flat [B*S, H, D]
q = query.reshape(batch * seq_q, num_heads_q, head_dim).contiguous()
k = key.reshape(batch * seq_k, key.shape[2], head_dim).contiguous()
v = value.reshape(batch * seq_k, value.shape[2], head_dim).contiguous()

# Build varlen metadata
b_start_loc_q = torch.arange(batch, device=device, dtype=torch.int32) * seq_q
b_seq_len_q = torch.full((batch,), seq_q, device=device, dtype=torch.int32)

if scale is None:
scale = 1.0 / math.sqrt(head_dim)

kw: dict = {
"b_start_loc": b_start_loc_q,
"b_seq_len": b_seq_len_q,
"max_input_len": seq_q,
"is_causal": is_causal,
"softmax_scale": scale,
}

if seq_q != seq_k:
b_start_loc_k = torch.arange(batch, device=device, dtype=torch.int32) * seq_k
b_seq_len_k = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
kw["b_start_loc_k"] = b_start_loc_k
kw["b_seq_len_k"] = b_seq_len_k
kw["max_input_len_k"] = seq_k

# --- Calibration mode: collect multi-threshold stats ---
calib_mode = getattr(_thread_local, "calibration_mode", False)
if calib_mode:
trials = getattr(_thread_local, "threshold_trials", None)
if trials and attention_calibrate is not None:
o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)

# Accumulate counters across all attention calls in this forward pass
prev = getattr(_thread_local, "calibration_counters", None)
if prev is None:
_thread_local.calibration_counters = counters
else:
_thread_local.calibration_counters = prev + counters

# Store actual KV sequence length for calibration stats
_thread_local.calibration_seq_k = seq_k

return o.view(batch, seq_q, num_heads_q, head_dim)

# --- Inference mode: skip-softmax with raw, dynamic, or static threshold ---
raw_thresh = getattr(_thread_local, "raw_threshold", None)
if raw_thresh is not None:
# Raw threshold: passed directly to kernel as skip_threshold_log2
kw["skip_softmax_raw_threshold"] = raw_thresh
else:
scale_factor = getattr(_thread_local, "scale_factor", None)
if scale_factor is not None and scale_factor > 0.0:
# Dynamic threshold: adapt to actual sequence length
kw["skip_softmax_threshold"] = scale_factor / seq_k
else:
threshold = getattr(_thread_local, "skip_threshold", None)
if threshold is not None and threshold > 0.0:
kw["skip_softmax_threshold"] = threshold

assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)"
do_measure = getattr(_thread_local, "measure_sparsity", False)
if do_measure:
kw["measure_sparsity"] = True
o = attention(q, k, v, **kw)

# Accumulate runtime sparsity counters from the kernel output
if do_measure and hasattr(o, "_sparsity_total"):
prev_total = getattr(_thread_local, "sparsity_total", 0)
prev_skipped = getattr(_thread_local, "sparsity_skipped", 0)
_thread_local.sparsity_total = prev_total + o._sparsity_total
_thread_local.sparsity_skipped = prev_skipped + o._sparsity_skipped

return o.view(batch, seq_q, num_heads_q, head_dim)


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In diffusers 0.32.2 custom attention backends, can attn_mask, dropout_p, or enable_gqa be non-default for registered backends, and are custom backends expected to preserve those semantics?

💡 Result:

In diffusers 0.32.2, custom attention backends registered via _AttentionBackendRegistry can have non-default support for attn_mask, dropout_p, and enable_gqa. The dispatch_attention_fn passes these parameters (along with query, key, value, is_causal, scale, attention_kwargs) to the registered backend function. Only parameters matching the backend's supported_arg_names are passed; unsupported ones are filtered out without error (though checks may raise if enabled). Custom backends are expected to handle these parameters if listed in their supported args during registration, or ignore/filter them as appropriate to preserve semantics when non-None/non-default. Some registered backends do not support attn_mask, leading to silent ignoring unless checks are enabled (DIFFUSERS_ATTN_CHECKS=yes), which validates shapes and raises errors for mismatches.


🏁 Script executed:

# Find the file and look for registration/decorator code
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | head -150

Repository: NVIDIA/Model-Optimizer

Length of output: 6967


🏁 Script executed:

# Continue reading to find registration/decorator code
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | tail -100

Repository: NVIDIA/Model-Optimizer

Length of output: 4972


🏁 Script executed:

# Double-check the signature one more time to ensure accuracy
python3 << 'EOF'
import inspect
sig = """def _diffusers_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> torch.Tensor:"""

# Extract parameter names
params = ['query', 'key', 'value', 'attn_mask', 'dropout_p', 'is_causal', 'scale', 'enable_gqa']
print("Function parameters:", params)
print("\nParameters with defaults that could be non-default:")
print("- attn_mask (default None, but can be passed as Tensor)")
print("- dropout_p (default 0.0, but can be non-zero)")
print("- enable_gqa (default False, but can be True)")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 376


Remove unsupported parameters from the function signature or implement their semantics.

This backend is registered with attn_mask, dropout_p, and enable_gqa in its _supported_arg_names (lines 237–239 derive supported args from the function signature), which means diffusers will pass these arguments to the implementation. However, lines 123–212 silently ignore all three, computing plain attention regardless. This breaks correctness if a caller provides a non-None attention mask, non-zero dropout, or enables GQA.

Either:

  • Remove these parameters from the signature and handle the filtered-out case explicitly, or
  • Implement the masking, dropout, and GQA semantics in the Triton path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`
around lines 123 - 213, The function _diffusers_triton_attention currently
accepts attn_mask, dropout_p, and enable_gqa but ignores them; either remove
these params from the signature and ensure the backend's _supported_arg_names
(where supported args are derived) no longer lists them, or implement their
semantics: detect attn_mask != None and convert it into the Triton/kernel
expected mask metadata (or pass it via kw as "attn_mask" / appropriate key),
handle dropout_p > 0 by passing a "dropout_p" kw and ensuring training mode
semantics are respected, and honor enable_gqa by adjusting q/k/v shapes/scale
(grouped-query attention layout changes) and passing an "enable_gqa" or
equivalent flag into kw; update callers/registration so _supported_arg_names
matches the final signature/behavior.
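If the fallback route is chosen (rather than implementing the semantics in the Triton path), the dispatch reduces to a default-value check. A minimal sketch, with the path labels purely illustrative:

```python
def choose_attention_path(attn_mask=None, dropout_p=0.0, enable_gqa=False):
    """Route to the Triton kernel only when every argument it ignores is at
    its default; any non-default value falls back so mask/dropout/GQA
    semantics are preserved."""
    if attn_mask is not None or dropout_p != 0.0 or enable_gqa:
        return "sdpa_fallback"
    return "triton"
```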

Comment on lines +164 to +168
def __call__(self, q, k, v, heads, mask=None):
active, threshold, _scale_factor = _get_ltx_triton_context()
if active:
return _ltx_triton_attention(q, k, v, heads, mask, threshold)
return self._original_fn(q, k, v, heads, mask)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In ltx_core's Attention.attention_function API, can the mask argument be non-None during normal inference or training, and what semantics must a custom backend preserve for it?

💡 Result:

Yes, the mask argument in ltx_core's Attention.attention_function API can be non-None during both normal inference and training. The API signature for AttentionCallable.call is def call(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None) -> torch.Tensor:. The Attention module's forward method passes a potentially non-None mask from the Modality.attention_mask, which is described as "Optional 2-D self-attention mask, shape (B, T, T). Values in [0, 1] where 1 = full attention and 0 = no attention. None means unrestricted (full) attention between all tokens. Built incrementally by conditioning items." Custom backends must preserve the semantics of applying the mask as an additive attention bias in scaled dot-product attention, matching PyTorch's torch.nn.functional.scaled_dot_product_attention behavior where attn_mask=mask is passed with dropout_p=0.0 and is_causal=False. The PytorchAttention implementation reshapes the mask if ndim==2 by unsqueezing to add batch dim, then passes directly. XFormers and FlashAttention implementations also accept and use the mask parameter. During inference, dropout_p=0.0 is used (no dropout), and the same API applies to training where masks from conditioning (e.g., attention_strength_wrapper) can be non-None. Perturbations skip entire attention computations but do not alter mask handling when attention runs.


Fall back when a mask is provided, until the Triton path supports masking.

The _ltx_triton_attention() function ignores the mask parameter and always sets is_causal=False. Since ltx_core passes non-None masks during normal inference and training, and custom backends must preserve mask semantics as an additive attention bias, this wrapper changes attention behavior for any masked case. Only take the Triton path when mask is None; otherwise, defer to self._original_fn(...).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`
around lines 164 - 168, The current __call__ wrapper incorrectly routes masked
attention through _ltx_triton_attention which ignores mask and forces
is_causal=False; change the conditional to only use the Triton path when active
and mask is None (e.g., if active and mask is None: return
_ltx_triton_attention(...)); otherwise always call and return
self._original_fn(q, k, v, heads, mask) so mask semantics are preserved. Ensure
you reference _get_ltx_triton_context, _ltx_triton_attention, and
self._original_fn when making the change.
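The corrected dispatch can be sketched as a standalone wrapper; the class name and constructor are illustrative, but the `__call__` condition is the one the comment recommends:

```python
class SkipSoftmaxAttentionWrapper:
    """Take the Triton path only when the backend is active *and* no mask
    was supplied; any masked call defers to the original backend so mask
    semantics are preserved."""

    def __init__(self, original_fn, triton_fn, is_active):
        self._original_fn = original_fn
        self._triton_fn = triton_fn
        self._is_active = is_active

    def __call__(self, q, k, v, heads, mask=None):
        if self._is_active() and mask is None:
            return self._triton_fn(q, k, v, heads)
        return self._original_fn(q, k, v, heads, mask)
```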

Comment on lines +255 to +269
def get_threshold_info(self) -> dict:
"""Get threshold information for debugging/display."""
scale_factor = self._get_scale_factor()
if scale_factor is not None:
return {
"type": "dynamic_calibrated",
"formula": "threshold = scale_factor / seq_k (computed at runtime)",
"scale_factor": scale_factor,
"calibration_params": self.calibration_params,
"target_sparse_ratio": self.target_sparse_ratio,
}
return {
"type": "static",
"value": self.skip_softmax_threshold,
}

⚠️ Potential issue | 🟡 Minor

Report raw-threshold mode in get_threshold_info().

_triton_inference_context() gives skip_softmax_raw_threshold highest priority, but get_threshold_info() still reports the static lambda threshold. In raw-threshold runs the printed summary is therefore misleading even though the kernel is using a different value.

🔧 Suggested fix
     def get_threshold_info(self) -> dict:
         """Get threshold information for debugging/display."""
+        if self.skip_softmax_raw_threshold is not None:
+            return {
+                "type": "raw",
+                "value": self.skip_softmax_raw_threshold,
+            }
         scale_factor = self._get_scale_factor()
         if scale_factor is not None:
             return {
                 "type": "dynamic_calibrated",
                 "formula": "threshold = scale_factor / seq_k (computed at runtime)",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 255 - 269, get_threshold_info() currently always reports the static
lambda threshold even when _triton_inference_context() supplies a
skip_softmax_raw_threshold used by the kernel; update get_threshold_info to
check the triton inference context (via _triton_inference_context()) and, if a
skip_softmax_raw_threshold is present, return a "raw" type with that raw
threshold value (and note any related keys like skip_softmax_threshold or
target_sparse_ratio as applicable); otherwise preserve the existing
dynamic_calibrated/static return paths so the reported info matches the actual
value used by the kernel.
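The priority order can be expressed as a standalone function (a sketch of the suggested fix, not the real method, with the thresholds passed in rather than read from `self`):

```python
def get_threshold_info(raw_threshold=None, scale_factor=None, static_threshold=None):
    """Mirror the kernel's priority (raw > dynamic > static) so the printed
    summary always matches the value the kernel actually uses."""
    if raw_threshold is not None:
        return {"type": "raw", "value": raw_threshold}
    if scale_factor is not None:
        return {
            "type": "dynamic_calibrated",
            "formula": "threshold = scale_factor / seq_k (computed at runtime)",
            "scale_factor": scale_factor,
        }
    return {"type": "static", "value": static_threshold}
```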
