Conversation
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
📝 Walkthrough

This PR implements skip-softmax sparse attention for diffusion models (WAN 2.2 and planned LTX-2 support), introducing a Triton kernel calibration API for multi-threshold sparsity measurement, configurable threshold modes (raw or calibrated), diffusers and LTX-2 backend integration, comprehensive documentation, and an end-to-end example script with inference and calibration workflows.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant CalibScript as Calibration Script
    participant Pipeline as Diffusers Pipeline
    participant TritonKernel as Triton attention_calibrate
    participant Calibrator as DynamicThresholdCalibrator
    User->>CalibScript: invoke with --calibrate --target-sparsity 0.5
    CalibScript->>Pipeline: build WAN 2.2 model
    CalibScript->>Pipeline: convert_to_sparse_attention_model
    CalibScript->>Pipeline: set calibration mode & threshold trials
    loop calibration forward passes (OpenVid-1M captions)
        CalibScript->>Pipeline: forward(sample_prompts)
        Pipeline->>TritonKernel: attention_calibrate with threshold_trials
        TritonKernel-->>Pipeline: (output, sparsity_counters[num_thresholds, 2])
        Pipeline->>Calibrator: collect per-threshold sparsity stats
    end
    Calibrator->>Calibrator: fit log(scale_factor) = log_a + b*sparsity
    Calibrator-->>Pipeline: return calibration_params {log_a, b, fit_logspace}
    Pipeline->>Pipeline: store threshold formula in module config
    CalibScript->>CalibScript: free GPU memory
```
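The calibrator's fit step (`log(scale_factor) = log_a + b * sparsity`) can be sketched as a two-parameter least-squares fit in log space. The function name and data below are illustrative, not the actual `DynamicThresholdCalibrator` API:

```python
import math


def fit_logspace(sparsities: list[float], scale_factors: list[float]) -> tuple[float, float]:
    """Least-squares fit of log(scale_factor) = log_a + b * sparsity."""
    xs = sparsities
    ys = [math.log(s) for s in scale_factors]
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    # Closed-form simple linear regression in log space
    b = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sum((x - mx) ** 2 for x in xs)
    log_a = my - b * mx
    return log_a, b
```

Given per-threshold sparsity measurements, the recovered `(log_a, b)` pair lets inference pick a threshold for any target sparsity.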
```mermaid
sequenceDiagram
    participant User
    participant InferenceScript as Inference Script
    participant DiffusersBackend as Diffusers Triton Backend
    participant Pipeline as WAN 2.2 Pipeline
    participant TritonKernel as Triton attention (inference)
    User->>InferenceScript: invoke with --raw-threshold -5.0
    InferenceScript->>Pipeline: load sparsified model
    InferenceScript->>DiffusersBackend: set_triton_skip_softmax_config(raw_threshold=-5.0)
    InferenceScript->>Pipeline: pipeline(prompt, num_inference_steps=50)
    loop diffusion steps
        Pipeline->>DiffusersBackend: _diffusers_triton_attention (forward)
        DiffusersBackend->>DiffusersBackend: reshape [B, S, H, D] → [B*S, H, D]
        DiffusersBackend->>TritonKernel: attention(query, key, value, skip_softmax_raw_threshold=-5.0)
        TritonKernel->>TritonKernel: tile-wise skip_softmax: skip tile if log2(score_max) < -5.0
        TritonKernel-->>DiffusersBackend: (output, optional sparsity counters)
        DiffusersBackend->>DiffusersBackend: reshape output back [B, S, H, D]
        DiffusersBackend-->>Pipeline: attention output
    end
    Pipeline-->>InferenceScript: generated video
    InferenceScript->>InferenceScript: print per-module threshold info & runtime sparsity
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ✅ Passed checks (4 passed)
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #1269      +/-   ##
===========================================
- Coverage   75.58%   55.31%   -20.28%
===========================================
  Files         459      460        +1
  Lines       48612    49345      +733
===========================================
- Hits        36745    27295     -9450
- Misses      11867    22050    +10183
===========================================
```
Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.
Actionable comments posted: 6
🧹 Nitpick comments (5)
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)
387-394: Minor: Consider wrapping `set_skip_softmax_context(True)` in the ExitStack for exception safety.

If an exception occurs after `set_skip_softmax_context(True)` but before the callback is registered (e.g., during `stack.enter_context`), the skip-softmax context would remain enabled without cleanup.

♻️ Safer ordering

```diff
 from ..kernels import set_skip_softmax_context

 stack = ExitStack()
+stack.callback(set_skip_softmax_context, False)
 set_skip_softmax_context(True)
-stack.callback(set_skip_softmax_context, False)
 stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax))
 return stack
```

This ensures the cleanup callback is registered before the state is modified, so any subsequent exception during setup will still trigger cleanup when the stack is garbage collected or explicitly closed.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py` around lines 387 - 394, Register the cleanup callback on the ExitStack before enabling the skip-softmax flag to ensure exception safety: call stack.callback(set_skip_softmax_context, False) first, then call set_skip_softmax_context(True), and only after that perform stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax)); this guarantees that if an exception occurs during enter_context the skip-softmax state (managed by set_skip_softmax_context) will still be cleaned up.

modelopt/torch/export/diffusers_utils.py (1)
49-59: Consider making this a one-time warning or moving it to actual usage.

This warning fires at module import time whenever `ltx_pipelines` is installed, which may be noisy for users who import `diffusers_utils` but don't use LTX-2 features. Additionally, `stacklevel=2` at module load time may not point to a meaningful location.

Consider using `warnings.warn(..., stacklevel=2)` with `filterwarnings` to show once, or deferring the warning to actual LTX-2 usage (similar to lines 395-404 in `_ltx2_dummy_forward`).

♻️ One-time warning option
```python
import warnings

# At the top of the module
_LTX2_LICENSE_WARNING_SHOWN = False

# Inside the try block
try:
    from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline as _TI2VidTwoStagesPipeline

    def _show_ltx2_license_warning():
        global _LTX2_LICENSE_WARNING_SHOWN
        if not _LTX2_LICENSE_WARNING_SHOWN:
            warnings.warn(
                "LTX-2 packages ... (license text)",
                UserWarning,
                stacklevel=3,
            )
            _LTX2_LICENSE_WARNING_SHOWN = True

    TI2VidTwoStagesPipeline = _TI2VidTwoStagesPipeline
except Exception:
    TI2VidTwoStagesPipeline = None
```
Then call `_show_ltx2_license_warning()` in functions that actually use LTX-2.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/diffusers_utils.py` around lines 49 - 59, The module currently emits a loud license UserWarning at import via the warnings.warn call; change this to a one-time or deferred warning: wrap the import of ltx_pipelines and the current warnings.warn invocation behind a module-level flag (e.g., _LTX2_LICENSE_WARNING_SHOWN) or remove the warn from import and instead call a small helper like _show_ltx2_license_warning() from actual LTX-2 entrypoints (for example inside TI2VidTwoStagesPipeline usage code or _ltx2_dummy_forward) so the warning is emitted at first use only; keep stacklevel at an appropriate value (e.g., 3) when invoking warnings.warn to point to user code and ensure TI2VidTwoStagesPipeline remains set to None on import failure.

modelopt/torch/export/unified_export_hf.py (1)
1261-1266: Remove redundant import.

The `yaml` module is already imported at line 31. This inline import is unnecessary.

♻️ Proposed fix
```diff
     config_data["sparse_attention_config"] = sparse_attn_config

     # Also save as standalone YAML for easy inspection and reuse
-    import yaml
-
     yaml_path = Path(export_dir) / "sparse.yaml"
     with open(yaml_path, "w") as file:
         yaml.dump(sparse_attn_config, file, default_flow_style=False, sort_keys=False)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/unified_export_hf.py` around lines 1261 - 1266, Remove the redundant inline "import yaml" inside the export block; the module is already imported at the top of the file, so delete the inline import and keep the yaml usage (yaml_path = Path(export_dir) / "sparse.yaml" and yaml.dump(sparse_attn_config, ...)) as-is to avoid duplicate imports and retain the YAML file write behavior.

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)
314-318: Consider documenting this asymmetry.

The decode phase always requires `calibration_data` and `tokenizer` (RULER-based), even when a custom `forward_loop` was provided for prefill. This asymmetry between prefill (supports a custom forward_loop) and decode (always requires RULER) could be confusing.

Consider adding a docstring note or raising a more descriptive error message explaining why decode requires the RULER dataset.
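The more descriptive guard suggested here could look like the sketch below; the function name is hypothetical and the message follows the review's suggested wording:

```python
def require_ruler_inputs(calibration_data, tokenizer) -> None:
    """Fail fast with an explanation of the prefill/decode asymmetry.

    Decode calibration is RULER-based even when prefill used a custom
    forward_loop, so both inputs are mandatory for the decode phase.
    """
    if calibration_data is None or tokenizer is None:
        raise RuntimeError(
            "decode requires a RULER-style calibration_data and tokenizer "
            "(used by create_decode_calibration_forward_loop) even if a "
            "custom prefill forward_loop was provided."
        )
```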
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around lines 314 - 318, Update the documentation and error to make the prefill/decode asymmetry explicit: add a docstring note on the top-level function in calibrate.py (where decode_forward_loop is created) explaining that prefill accepts a custom forward_loop but decode always requires a RULER-style calibration_data and tokenizer because decode uses create_decode_calibration_forward_loop; and replace the RuntimeError raised when calibration_data or tokenizer is missing with a more descriptive message that states "decode requires a RULER-style calibration_data and tokenizer (used by create_decode_calibration_forward_loop) even if a custom prefill forward_loop was provided." Reference the symbols calibration_data, tokenizer, create_decode_calibration_forward_loop, and decode_forward_loop in the docstring/error.

examples/diffusers/sparsity/wan22_skip_softmax.py (1)
56-63: Lazy-load the optional `datasets`/`diffusers` dependencies.

This example hard-imports optional integrations at module load time, which makes simple imports fail unless the full diffusers stack is already installed. Moving these imports into `build_pipeline()`, `load_calib_prompts()`, and `main()` keeps the example gated behind the right extras and avoids breaking unrelated tooling that imports example modules. As per coding guidelines, "Gate optional features by install extras (`[onnx]`, `[hf]`, `[all]`); avoid hard imports of optional dependencies at module level" and "Use the plugin system with `import_plugin()` for lazy loading of optional integrations (HuggingFace, Megatron, etc.)".
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 56 - 63, The module currently hard-imports optional packages (datasets, diffusers, diffusers.utils.export_to_video and diffusers classes AutoencoderKLWan/WanPipeline) at top-level; move those imports into the functions that actually use them (e.g., build_pipeline(), load_calib_prompts(), and main()) so the example can be imported without the optional extras installed, and prefer the plugin loader where available (e.g., call import_plugin or similar before importing HuggingFace/diffusers modules) to gate the integrations; update references to SparseAttentionModule and any diffusers types after the local imports so runtime code uses the lazily-loaded symbols.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 123-213: The function _diffusers_triton_attention currently
accepts attn_mask, dropout_p, and enable_gqa but ignores them; either remove
these params from the signature and ensure the backend's _supported_arg_names
(where supported args are derived) no longer lists them, or implement their
semantics: detect attn_mask != None and convert it into the Triton/kernel
expected mask metadata (or pass it via kw as "attn_mask" / appropriate key),
handle dropout_p > 0 by passing a "dropout_p" kw and ensuring training mode
semantics are respected, and honor enable_gqa by adjusting q/k/v shapes/scale
(grouped-query attention layout changes) and passing an "enable_gqa" or
equivalent flag into kw; update callers/registration so _supported_arg_names
matches the final signature/behavior.
In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`:
- Around line 164-168: The current __call__ wrapper incorrectly routes masked
attention through _ltx_triton_attention which ignores mask and forces
is_causal=False; change the conditional to only use the Triton path when active
and mask is None (e.g., if active and mask is None: return
_ltx_triton_attention(...)); otherwise always call and return
self._original_fn(q, k, v, heads, mask) so mask semantics are preserved. Ensure
you reference _get_ltx_triton_context, _ltx_triton_attention, and
self._original_fn when making the change.
In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`:
- Around line 255-269: get_threshold_info() currently always reports the static
lambda threshold even when _triton_inference_context() supplies a
skip_softmax_raw_threshold used by the kernel; update get_threshold_info to
check the triton inference context (via _triton_inference_context()) and, if a
skip_softmax_raw_threshold is present, return a "raw" type with that raw
threshold value (and note any related keys like skip_softmax_threshold or
target_sparse_ratio as applicable); otherwise preserve the existing
dynamic_calibrated/static return paths so the reported info matches the actual
value used by the kernel.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: b4db3917-e462-4d17-b588-79e5f63acc4d
📒 Files selected for processing (23)
- CHANGELOG.rst
- examples/diffusers/README.md
- examples/diffusers/sparsity/README.md
- examples/diffusers/sparsity/wan22_skip_softmax.py
- modelopt/torch/export/diffusers_utils.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/kernels/__init__.py
- modelopt/torch/kernels/triton_fa.py
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/conversion.py
- modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
- modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py
- modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
- modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
- modelopt/torch/sparsity/attention_sparsity/stats_manager.py
- tests/_test_utils/torch/diffusers_models.py
- tests/examples/diffusers/test_sparsity.py
- tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py
```python
def load_calib_prompts(calib_size: int) -> list[str]:
    """Load calibration prompts from OpenVid-1M dataset."""
    dataset = load_dataset("nkp37/OpenVid-1M", split="train")
    prompts = list(dataset["caption"][:calib_size])
    print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M")
    return prompts
```
Don’t load the full OpenVid-1M split just to grab a few prompts.
load_dataset(..., split="train") materializes the whole split metadata, and then this code only uses the first calib_size captions. For the default calib_size=4, that is a lot of unnecessary I/O and makes calibration much slower and more fragile than it needs to be.
🔧 Suggested fix

```diff
 def load_calib_prompts(calib_size: int) -> list[str]:
     """Load calibration prompts from OpenVid-1M dataset."""
-    dataset = load_dataset("nkp37/OpenVid-1M", split="train")
-    prompts = list(dataset["caption"][:calib_size])
+    dataset = load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]")
+    prompts = list(dataset["caption"])
     print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M")
     return prompts
```

🤖 Prompt for AI Agents
return prompts🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 271 - 276,
The function load_calib_prompts currently loads the entire "train" split which
is wasteful; change it to only load the first calib_size examples by using a
sliced split (e.g. "train[:{calib_size}]" or equivalent) or streaming so only
the needed items are materialized, then collect the captions and return them;
update load_calib_prompts to call load_dataset with the sliced split string
based on the calib_size parameter and build the prompts list from that smaller
dataset.
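The sliced split in this fix can be sanity-checked offline. The helper below only builds the split string; the `load_dataset` call in the usage comment follows the example above and assumes the `datasets` package and network access:

```python
def sliced_split(calib_size: int, split: str = "train") -> str:
    # Hugging Face `datasets` split-slicing syntax: "train[:4]" loads only
    # the first 4 rows instead of materializing the whole split.
    return f"{split}[:{calib_size}]"


# Usage sketch (requires `datasets` and network access):
# dataset = load_dataset("nkp37/OpenVid-1M", split=sliced_split(4))
```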
```python
        # --- Write per-program counters (no atomics, just stores) ---
        # Compute unique flat program index for this (batch, head, q_tile)
        num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M)  # conservative upper bound
        num_heads = tl.num_programs(1)
        prog_idx = batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q
        base = prog_idx * NUM_THRESHOLDS
```
Program indexing collides for variable-length batches.
attention_calibrate() allocates one counter row per launched program using triton.cdiv(max_input_len, BLOCK_M), but the kernel flattens (batch, head, tile) with tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M). If batch element 0 is shorter than another sequence, different programs write into the same counter slots and the exported calibration stats are wrong.
🔧 Suggested fix

```diff
-    num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M)  # conservative upper bound
+    num_q_tiles = tl.num_programs(2)
     num_heads = tl.num_programs(1)
     prog_idx = batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/triton_fa.py` around lines 1241 - 1246, The program
index calculation uses a per-program tile count computed from tl.load(b_seq_len
+ 0) which collides when batches have variable lengths; replace the local
num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M) with a tile count derived
from the same max_input_len used when allocating counters (e.g. num_q_tiles =
tl.cdiv(max_input_len, BLOCK_M)), ensure the kernel signature accepts that
max_input_len scalar (or otherwise pass the same allocation param) and update
uses of num_q_tiles/prog_idx/base so each (batch, head, tile) maps to a unique
counter slot matching attention_calibrate()'s allocation.
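The collision this comment describes can be reproduced with plain integer arithmetic, assuming the kernel's flattening formula:

```python
def flat_prog_idx(batch_idx: int, head_idx: int, tile_q: int,
                  num_heads: int, num_q_tiles: int) -> int:
    # Mirrors the kernel's flattening; unique only if num_q_tiles matches
    # the tile count used when the counter buffer was allocated.
    return batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q


# Suppose batch 0 has 1 query tile but batch 1 has 2 (variable lengths),
# while the counter buffer was allocated for max 2 tiles per sequence.
num_heads = 2
# Kernel derives num_q_tiles from batch 0's length -> 1: two distinct
# programs collapse onto the same counter slot.
collide_a = flat_prog_idx(1, 0, 1, num_heads, num_q_tiles=1)
collide_b = flat_prog_idx(1, 1, 0, num_heads, num_q_tiles=1)
# Using the allocation-side tile count (2) keeps every slot unique.
unique_a = flat_prog_idx(1, 0, 1, num_heads, num_q_tiles=2)
unique_b = flat_prog_idx(1, 1, 0, num_heads, num_q_tiles=2)
```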
```python
    try:
        from diffusers.models.modeling_utils import ModelMixin

        if isinstance(model, ModelMixin):
            from .kernels import register_diffusers_triton_attention

            if register_diffusers_triton_attention is not None:
                register_diffusers_triton_attention()
    except (ImportError, Exception):
        pass
```
Don’t swallow diffusers backend registration failures.
If register_diffusers_triton_attention() raises here, conversion still continues and later paths can still select "modelopt_triton". That turns a real integration failure into a later runtime error or a silent fallback. Only missing optional deps should be suppressed here; actual registration errors should surface.
🔧 Suggested fix

```diff
-    try:
-        from diffusers.models.modeling_utils import ModelMixin
-
-        if isinstance(model, ModelMixin):
-            from .kernels import register_diffusers_triton_attention
-
-            if register_diffusers_triton_attention is not None:
-                register_diffusers_triton_attention()
-    except (ImportError, Exception):
-        pass
+    try:
+        from diffusers.models.modeling_utils import ModelMixin
+    except ImportError:
+        ModelMixin = None
+
+    if ModelMixin is not None and isinstance(model, ModelMixin):
+        from .kernels import register_diffusers_triton_attention
+
+        if register_diffusers_triton_attention is not None:
+            register_diffusers_triton_attention()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 127 -
136, The current broad try/except hides real errors from
register_diffusers_triton_attention() so integration failures are swallowed;
change the logic to only suppress ImportError when importing ModelMixin but
allow exceptions from register_diffusers_triton_attention() to surface: first
try importing ModelMixin and on ImportError simply return/skip, then if
isinstance(model, ModelMixin) import register_diffusers_triton_attention and if
it's not None call register_diffusers_triton_attention() without a broad except
so any runtime errors propagate (or re-raise after logging) — refer to
ModelMixin and register_diffusers_triton_attention to locate and update the
conversion.py block.
```python
def _diffusers_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Compute attention via Triton FA kernel on diffusers layout ``[B, S, H, D]``."""
    batch, seq_q, num_heads_q, head_dim = query.shape
    seq_k = key.shape[1]
    device = query.device

    # Reshape from diffusers [B, S, H, D] -> flat [B*S, H, D]
    q = query.reshape(batch * seq_q, num_heads_q, head_dim).contiguous()
    k = key.reshape(batch * seq_k, key.shape[2], head_dim).contiguous()
    v = value.reshape(batch * seq_k, value.shape[2], head_dim).contiguous()

    # Build varlen metadata
    b_start_loc_q = torch.arange(batch, device=device, dtype=torch.int32) * seq_q
    b_seq_len_q = torch.full((batch,), seq_q, device=device, dtype=torch.int32)

    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)

    kw: dict = {
        "b_start_loc": b_start_loc_q,
        "b_seq_len": b_seq_len_q,
        "max_input_len": seq_q,
        "is_causal": is_causal,
        "softmax_scale": scale,
    }

    if seq_q != seq_k:
        b_start_loc_k = torch.arange(batch, device=device, dtype=torch.int32) * seq_k
        b_seq_len_k = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
        kw["b_start_loc_k"] = b_start_loc_k
        kw["b_seq_len_k"] = b_seq_len_k
        kw["max_input_len_k"] = seq_k

    # --- Calibration mode: collect multi-threshold stats ---
    calib_mode = getattr(_thread_local, "calibration_mode", False)
    if calib_mode:
        trials = getattr(_thread_local, "threshold_trials", None)
        if trials and attention_calibrate is not None:
            o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)

            # Accumulate counters across all attention calls in this forward pass
            prev = getattr(_thread_local, "calibration_counters", None)
            if prev is None:
                _thread_local.calibration_counters = counters
            else:
                _thread_local.calibration_counters = prev + counters

            # Store actual KV sequence length for calibration stats
            _thread_local.calibration_seq_k = seq_k

            return o.view(batch, seq_q, num_heads_q, head_dim)

    # --- Inference mode: skip-softmax with raw, dynamic, or static threshold ---
    raw_thresh = getattr(_thread_local, "raw_threshold", None)
    if raw_thresh is not None:
        # Raw threshold: passed directly to kernel as skip_threshold_log2
        kw["skip_softmax_raw_threshold"] = raw_thresh
    else:
        scale_factor = getattr(_thread_local, "scale_factor", None)
        if scale_factor is not None and scale_factor > 0.0:
            # Dynamic threshold: adapt to actual sequence length
            kw["skip_softmax_threshold"] = scale_factor / seq_k
        else:
            threshold = getattr(_thread_local, "skip_threshold", None)
            if threshold is not None and threshold > 0.0:
                kw["skip_softmax_threshold"] = threshold

    assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)"
    do_measure = getattr(_thread_local, "measure_sparsity", False)
    if do_measure:
        kw["measure_sparsity"] = True
    o = attention(q, k, v, **kw)

    # Accumulate runtime sparsity counters from the kernel output
    if do_measure and hasattr(o, "_sparsity_total"):
        prev_total = getattr(_thread_local, "sparsity_total", 0)
        prev_skipped = getattr(_thread_local, "sparsity_skipped", 0)
        _thread_local.sparsity_total = prev_total + o._sparsity_total
        _thread_local.sparsity_skipped = prev_skipped + o._sparsity_skipped

    return o.view(batch, seq_q, num_heads_q, head_dim)
```
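The varlen metadata in the snippet above reduces to plain offset arithmetic. A CPU-side sketch with lists standing in for the int32 tensors:

```python
def varlen_metadata(batch: int, seq_len: int) -> tuple[list[int], list[int]]:
    # Each batch element b occupies rows [b * seq_len, (b + 1) * seq_len)
    # of the flattened [B*S, H, D] layout, so start offsets are multiples
    # of seq_len and every per-sequence length is identical.
    b_start_loc = [b * seq_len for b in range(batch)]
    b_seq_len = [seq_len] * batch
    return b_start_loc, b_seq_len
```

Because all sequences in the diffusers layout share one length, the "varlen" interface degenerates to this regular pattern; true variable lengths would need per-sequence values instead.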
🧩 Analysis chain
🌐 Web query:
In diffusers 0.32.2 custom attention backends, can attn_mask, dropout_p, or enable_gqa be non-default for registered backends, and are custom backends expected to preserve those semantics?
💡 Result:
In diffusers 0.32.2, custom attention backends registered via _AttentionBackendRegistry can have non-default support for attn_mask, dropout_p, and enable_gqa. The dispatch_attention_fn passes these parameters (along with query, key, value, is_causal, scale, attention_kwargs) to the registered backend function. Only parameters matching the backend's supported_arg_names are passed; unsupported ones are filtered out without error (though checks may raise if enabled). Custom backends are expected to handle these parameters if listed in their supported args during registration, or ignore/filter them as appropriate to preserve semantics when non-None/non-default. Some registered backends do not support attn_mask, leading to silent ignoring unless checks are enabled (DIFFUSERS_ATTN_CHECKS=yes), which validates shapes and raises errors for mismatches.
Citations:
- 1: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_dispatch.py
- 2: https://github.com/huggingface/diffusers/blob/e5aa719241f9b74d6700be3320a777799bfab70a/src/diffusers/models/attention_dispatch.py
- 3: https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py
- 4: dispatch_attention_fn silently ignores attn_mask for certain backends huggingface/diffusers#12605
- 5: https://huggingface.co/docs/diffusers/optimization/attention_backends
- 6: https://github.com/huggingface/diffusers/tags
- 7: https://github.com/huggingface/diffusers
- 8: https://github.com/huggingface/diffusers/releases/tag/v0.32.1
🏁 Script executed:

```shell
# Find the file and look for registration/decorator code
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | head -150
```

Repository: NVIDIA/Model-Optimizer
Length of output: 6967
🏁 Script executed:

```shell
# Continue reading to find registration/decorator code
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | tail -100
```

Repository: NVIDIA/Model-Optimizer
Length of output: 4972
🏁 Script executed:

```shell
# Double-check the signature one more time to ensure accuracy
python3 << 'EOF'
import inspect
sig = """def _diffusers_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> torch.Tensor:"""
# Extract parameter names
params = ['query', 'key', 'value', 'attn_mask', 'dropout_p', 'is_causal', 'scale', 'enable_gqa']
print("Function parameters:", params)
print("\nParameters with defaults that could be non-default:")
print("- attn_mask (default None, but can be passed as Tensor)")
print("- dropout_p (default 0.0, but can be non-zero)")
print("- enable_gqa (default False, but can be True)")
EOF
```

Repository: NVIDIA/Model-Optimizer
Length of output: 376
Remove unsupported parameters from the function signature or implement their semantics.
This backend is registered with attn_mask, dropout_p, and enable_gqa in its _supported_arg_names (lines 237–239 derive supported args from the function signature), which means diffusers will pass these arguments to the implementation. However, lines 123–212 silently ignore all three, computing plain attention regardless. This breaks correctness if a caller provides a non-None attention mask, non-zero dropout, or enables GQA.
Either:
- Remove these parameters from the signature and handle the filtered-out case explicitly, or
- Implement the masking, dropout, and GQA semantics in the Triton path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`
around lines 123 - 213, The function _diffusers_triton_attention currently
accepts attn_mask, dropout_p, and enable_gqa but ignores them; either remove
these params from the signature and ensure the backend's _supported_arg_names
(where supported args are derived) no longer lists them, or implement their
semantics: detect attn_mask != None and convert it into the Triton/kernel
expected mask metadata (or pass it via kw as "attn_mask" / appropriate key),
handle dropout_p > 0 by passing a "dropout_p" kw and ensuring training mode
semantics are respected, and honor enable_gqa by adjusting q/k/v shapes/scale
(grouped-query attention layout changes) and passing an "enable_gqa" or
equivalent flag into kw; update callers/registration so _supported_arg_names
matches the final signature/behavior.
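The first remediation option can be illustrated without the real diffusers/Triton machinery. This is a hypothetical stand-in, not the PR's code: the two attention functions are stubs, and the point is only the guard — take the fast path solely when every argument the Triton kernel ignores is at its default, otherwise defer to a reference implementation so semantics are preserved:

```python
def reference_attention(q, k, v, attn_mask=None, dropout_p=0.0, enable_gqa=False):
    # Stand-in for a full-semantics backend such as
    # torch.nn.functional.scaled_dot_product_attention.
    return "reference"

def triton_attention(q, k, v):
    # Stand-in for the Triton skip-softmax kernel, which supports
    # no mask, no dropout, and no grouped-query attention.
    return "triton"

def dispatch(q, k, v, attn_mask=None, dropout_p=0.0, enable_gqa=False):
    # Only use the fast path when all unsupported arguments are defaulted;
    # any non-default value falls back rather than being silently dropped.
    if attn_mask is not None or dropout_p != 0.0 or enable_gqa:
        return reference_attention(q, k, v, attn_mask, dropout_p, enable_gqa)
    return triton_attention(q, k, v)

print(dispatch(0, 0, 0))                 # triton
print(dispatch(0, 0, 0, attn_mask=[1]))  # reference
```

Keeping the parameters in the signature (rather than removing them) preserves the registry's argument passing while making the unsupported cases correct instead of silently wrong.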
```python
def __call__(self, q, k, v, heads, mask=None):
    active, threshold, _scale_factor = _get_ltx_triton_context()
    if active:
        return _ltx_triton_attention(q, k, v, heads, mask, threshold)
    return self._original_fn(q, k, v, heads, mask)
```
🧩 Analysis chain
🌐 Web query:
In ltx_core's Attention.attention_function API, can the mask argument be non-None during normal inference or training, and what semantics must a custom backend preserve for it?
💡 Result:
Yes, the mask argument in ltx_core's Attention.attention_function API can be non-None during both normal inference and training. The API signature for AttentionCallable.call is def call(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None) -> torch.Tensor:. The Attention module's forward method passes a potentially non-None mask from the Modality.attention_mask, which is described as "Optional 2-D self-attention mask, shape (B, T, T). Values in [0, 1] where 1 = full attention and 0 = no attention. None means unrestricted (full) attention between all tokens. Built incrementally by conditioning items." Custom backends must preserve the semantics of applying the mask as an additive attention bias in scaled dot-product attention, matching PyTorch's torch.nn.functional.scaled_dot_product_attention behavior where attn_mask=mask is passed with dropout_p=0.0 and is_causal=False. The PytorchAttention implementation reshapes the mask if ndim==2 by unsqueezing to add batch dim, then passes directly. XFormers and FlashAttention implementations also accept and use the mask parameter. During inference, dropout_p=0.0 is used (no dropout), and the same API applies to training where masks from conditioning (e.g., attention_strength_wrapper) can be non-None. Perturbations skip entire attention computations but do not alter mask handling when attention runs.
Citations:
- 1: https://huggingface.co/spaces/Lightricks/ltx-2/blob/f6c232a6e998406ff35d8322f6ca595a2a196e5d/packages/ltx-core/src/ltx_core/model/transformer/attention.py
- 2: https://huggingface.co/spaces/Lightricks/ltx-2/blob/95b6572a453c1d3c9e91946965e41e9d9c17450b/packages/ltx-core/src/ltx_core/model/transformer/attention.py
- 3: https://github.com/Lightricks/LTX-2/blob/main/packages/ltx-core/src/ltx_core/model/transformer/modality.py
Fall back when mask is provided until Triton path supports masking.
The _ltx_triton_attention() function ignores the mask parameter and always sets is_causal=False. Since ltx_core passes non-None masks during normal inference and training, and custom backends must preserve mask semantics as an additive attention bias, this wrapper changes attention behavior for any masked case. Only take the Triton path when mask is None; otherwise, defer to self._original_fn(...).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`
around lines 164 - 168, The current __call__ wrapper incorrectly routes masked
attention through _ltx_triton_attention which ignores mask and forces
is_causal=False; change the conditional to only use the Triton path when active
and mask is None (e.g., if active and mask is None: return
_ltx_triton_attention(...)); otherwise always call and return
self._original_fn(q, k, v, heads, mask) so mask semantics are preserved. Ensure
you reference _get_ltx_triton_context, _ltx_triton_attention, and
self._original_fn when making the change.
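The suggested routing can be sketched with stubs in place of the real ltx_core backends — a hypothetical illustration of the one-line guard (`TritonAttentionWrapper` and both lambdas are invented stand-ins): the Triton path runs only when no mask is supplied, so mask semantics always reach the original backend.

```python
class TritonAttentionWrapper:
    """Sketch of the reviewer's fix: Triton only when active AND mask is None."""

    def __init__(self, original_fn, triton_fn, active=True, threshold=1e-4):
        self._original_fn = original_fn
        self._triton_fn = triton_fn
        self._active = active
        self._threshold = threshold

    def __call__(self, q, k, v, heads, mask=None):
        # Masked calls must keep additive-bias semantics, which the Triton
        # kernel does not implement, so they always fall back.
        if self._active and mask is None:
            return self._triton_fn(q, k, v, heads, mask, self._threshold)
        return self._original_fn(q, k, v, heads, mask)

wrapper = TritonAttentionWrapper(
    original_fn=lambda q, k, v, heads, mask: "original",
    triton_fn=lambda q, k, v, heads, mask, thr: "triton",
)
print(wrapper(None, None, None, heads=8))            # triton
print(wrapper(None, None, None, heads=8, mask=[1]))  # original
```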
```python
def get_threshold_info(self) -> dict:
    """Get threshold information for debugging/display."""
    scale_factor = self._get_scale_factor()
    if scale_factor is not None:
        return {
            "type": "dynamic_calibrated",
            "formula": "threshold = scale_factor / seq_k (computed at runtime)",
            "scale_factor": scale_factor,
            "calibration_params": self.calibration_params,
            "target_sparse_ratio": self.target_sparse_ratio,
        }
    return {
        "type": "static",
        "value": self.skip_softmax_threshold,
    }
```
Report raw-threshold mode in get_threshold_info().
_triton_inference_context() gives skip_softmax_raw_threshold highest priority, but get_threshold_info() still reports the static lambda threshold. In raw-threshold runs the printed summary is therefore misleading even though the kernel is using a different value.
🔧 Suggested fix

```diff
 def get_threshold_info(self) -> dict:
     """Get threshold information for debugging/display."""
+    if self.skip_softmax_raw_threshold is not None:
+        return {
+            "type": "raw",
+            "value": self.skip_softmax_raw_threshold,
+        }
     scale_factor = self._get_scale_factor()
     if scale_factor is not None:
         return {
             "type": "dynamic_calibrated",
             "formula": "threshold = scale_factor / seq_k (computed at runtime)",
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 255 - 269, get_threshold_info() currently always reports the static
lambda threshold even when _triton_inference_context() supplies a
skip_softmax_raw_threshold used by the kernel; update get_threshold_info to
check the triton inference context (via _triton_inference_context()) and, if a
skip_softmax_raw_threshold is present, return a "raw" type with that raw
threshold value (and note any related keys like skip_softmax_threshold or
target_sparse_ratio as applicable); otherwise preserve the existing
dynamic_calibrated/static return paths so the reported info matches the actual
value used by the kernel.
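The calibrated-threshold arithmetic discussed in this thread can be sketched on its own. This is a hedged illustration, not code from the PR: `log_a` and `b` follow the log-space fit described in this review, the numeric values are invented, and only two facts from the source are encoded — `scale = exp(log_a + b * S)` (equivalently `a * exp(b * S)` with `a = exp(log_a)`) and the runtime rule `threshold = scale_factor / seq_k`.

```python
import math

def threshold_scale_factor(log_a: float, b: float, target_sparsity: float) -> float:
    # Log-space fit: log(scale) = log_a + b * S, so scale = exp(log_a + b * S).
    return math.exp(log_a + b * target_sparsity)

def runtime_threshold(scale_factor: float, seq_k: int) -> float:
    # Calibrated mode: the skip threshold shrinks as the key length grows.
    return scale_factor / seq_k

# Illustrative, invented calibration values.
scale = threshold_scale_factor(log_a=-2.0, b=1.5, target_sparsity=0.5)
print(runtime_threshold(scale, seq_k=1024))
```

Dividing by `seq_k` makes the threshold comparable to a uniform post-softmax weight of `1 / seq_k`, which is why a single calibrated scale factor can transfer across sequence lengths.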
What does this PR do?
Type of change: New Feature
Adds HuggingFace checkpoint export for diffusion pipelines calibrated with skip-softmax, on top of the base skip-softmax MR (`jingyux/diffusion-skip-softmax`). Concretely:

- `_export_diffusers_checkpoint` now walks every `nn.Module` component of a diffusers pipeline, calls `export_sparse_attention_config`, injects the result into that component's `config.json` as `sparse_attention_config`, and additionally writes a unified top-level `sparse.yaml` keyed by pipeline component (`transformer`, `transformer_2`, …). The existing LLM `export_hf_checkpoint` path also gains a sibling `sparse.yaml` dump.
- `export_sparse_attention_config` is generalized: per-group nesting (`group_0.threshold_scale_factor`, `group_0.raw_threshold`, `group_0.disabled_layers`) so future sparse methods can coexist, plus per-layer `disabled_layers` reporting and a `raw_threshold`-only path for uncalibrated use.
- `log_a`/`fit_logspace` are passed through the calibrator's result dict, and the exporter emits the matching `formula: "log_a + b * target_sparsity"` for diffusion (linear-space `a * exp(b * S)` is still used for LLMs).
- `examples/diffusers/sparsity/wan22_skip_softmax.py` gets an `--export-dir` flag that calls `export_hf_checkpoint(pipe, export_dir=...)` after calibration.
- `CHANGELOG.rst` is updated to note diffusion coverage for skip-softmax.

Usage
```shell
python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \
    --calibrate \
    --target-sparsity 0.5 \
    --export-dir ./wan22_skip_softmax_ckpt
```

Equivalent Python:
Resulting layout:
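The layout listing itself did not survive extraction; based on the components this PR names (`sparse.yaml` at the top level, patched `config.json` per component), it presumably resembles the following hypothetical sketch — any file not named in the description is an assumption:

```
wan22_skip_softmax_ckpt/
├── sparse.yaml          # unified sparse config, keyed by pipeline component
├── transformer/
│   └── config.json      # includes injected sparse_attention_config
└── transformer_2/
    └── config.json      # includes injected sparse_attention_config
```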
A representative `config.json` entry for a diffusion component:

Testing
- Extends `tests/examples/diffusers/test_sparsity.py` with a calibrate → export → reload round-trip on a small diffusion pipeline, asserting the presence and shape of `sparse_attention_config` in each component's `config.json` and the unified top-level `sparse.yaml`.
- Verified that `sparse.yaml` and `transformer{,_2}/config.json` contain the expected log-space `threshold_scale_factor`, any disabled layers, and producer metadata, and that a freshly loaded pipeline from the exported checkpoint reproduces the calibrated sparsity target end-to-end.
- Existing `test_sparsity` tests are unaffected — the new `sparse.yaml` sibling is emitted without changing the existing `config.json` patching behavior.

Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

CONTRIBUTING.md: ✅

Additional Information
Summary by CodeRabbit
Release Notes
New Features
Documentation
Chores