Fix SM120 forward pass crash: parent __init__ overwrites arch, enabling unsupported TMA path#2416

Open
moghon92 wants to merge 1 commit into Dao-AILab:main from moghon92:fix/sm120-arch-override
Conversation

@moghon92

Bug

Running FlashAttention-4 on SM120 GPUs (e.g., NVIDIA RTX PRO 6000 Blackwell Server Edition) crashes immediately with:

AttributeError: 'NoneType' object has no attribute '_trait'

Full traceback points to:

flash_fwd.py:399 → epilogue() → copy_utils.tma_get_copy_fn(tma_atom_O, ...)
where tma_atom_O is None.

Root Cause

FlashAttentionForwardSm120 is designed to use SM80-era CpAsync code paths (no TMA). To achieve this, it sets a class variable arch = 80. The intent is that use_tma_O = (self.arch >= Arch.sm_90) evaluates to False, skipping the TMA output path.

However, the parent class FlashAttentionForwardSm80.__init__() (in flash_fwd.py, line 110) does:

```python
self.arch = BaseDSL._get_dsl().get_arch_enum()
```

This queries the actual GPU and returns Arch.sm_120, creating an instance variable that shadows the class variable arch = 80. In Python, instance attributes always take precedence over class attributes.
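The shadowing behavior can be demonstrated in isolation. This is a minimal sketch with hypothetical class names (`Parent`/`Child` stand in for the FlashAttention classes), not the actual flash_fwd.py code:

```python
# Minimal illustration of how an instance attribute assigned in a
# parent __init__ shadows a class variable defined on the subclass.
class Parent:
    def __init__(self):
        # Analogous to querying the real GPU architecture at runtime.
        self.arch = 120  # creates an *instance* attribute

class Child(Parent):
    arch = 80  # class variable; intended to force the older code path

c = Child()
print(c.arch)       # 120 -- the instance attribute wins the lookup
print(Child.arch)   # 80  -- the class variable is untouched but shadowed
```

Attribute lookup on an instance checks `c.__dict__` before the class's MRO, which is why the class-level `arch = 80` never takes effect once `__init__` runs.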

As a result:

  1. self.arch becomes Arch.sm_120 (not 80 as intended)
  2. self.use_tma_O = self.arch >= Arch.sm_90 evaluates to True
  3. The kernel's __call__ method passes tma_atom_O = None to epilogue() (because SM120 never creates a TMA descriptor for output)
  4. The epilogue enters the TMA branch and calls tma_get_copy_fn(None, ...) → crash
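The failure chain above can be sketched as follows. This is a hedged illustration of the control flow, with simplified names, not the real epilogue in flash_fwd.py:

```python
# Hedged sketch of the failure mode: the epilogue branches on use_tma_O,
# but SM120 never constructs a TMA descriptor for the output, so the TMA
# branch dereferences None and raises AttributeError.
def epilogue(tma_atom_O, use_tma_O):
    if use_tma_O:
        # Stand-in for copy_utils.tma_get_copy_fn(tma_atom_O, ...),
        # which accesses tma_atom_O._trait internally.
        return tma_atom_O._trait
    return "cp.async output path"

try:
    epilogue(None, use_tma_O=True)  # the SM120 situation before the fix
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute '_trait'
```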

Fix

Add an __init__ override in FlashAttentionForwardSm120 that resets self.arch back to Arch.sm_80 after calling super().__init__():

```python
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.arch = Arch.sm_80
```

This ensures the CpAsync output path is used, matching the original design intent documented in the class comment.
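To see why the reset restores the intended behavior, here is a small sketch with illustrative enum values (a stand-in for the real `Arch` enum, whose exact definition is not shown in this PR):

```python
from enum import IntEnum

class Arch(IntEnum):  # illustrative stand-in for the real Arch enum
    sm_80 = 80
    sm_90 = 90
    sm_120 = 120

# After the fix's reset in FlashAttentionForwardSm120.__init__:
arch = Arch.sm_80
use_tma_O = arch >= Arch.sm_90
print(use_tma_O)  # False -> the CpAsync output path is taken
```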

Environment

  • GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition (SM 12.0)
  • torch 2.9.1+cu129
  • nvidia-cutlass-dsl 4.4.2
  • quack-kernels 0.3.7
  • flash-attn-4 dev (commit 98024f9)

Fix SM120 forward pass crash: parent __init__ overwrites arch, enabling unsupported TMA path

FlashAttentionForwardSm120 sets class variable arch=80 to force CpAsync
code paths (no TMA for output). However, FlashAttentionForwardSm80.__init__()
calls self.arch = BaseDSL._get_dsl().get_arch_enum(), which returns the real
GPU architecture (Arch.sm_120), overwriting the class variable with an instance
variable.

This causes use_tma_O = (self.arch >= Arch.sm_90) to evaluate True, and the
epilogue enters the TMA output path where tma_atom_O is None (never created
for SM120), resulting in:
  AttributeError: 'NoneType' object has no attribute '_trait'
  in copy_utils.tma_get_copy_fn -> cpasync.tma_partition

Fix: override __init__ to reset self.arch = Arch.sm_80 after super().__init__().

Tested on NVIDIA B200 (SM 12.0) with:
  - torch 2.9.1+cu129
  - nvidia-cutlass-dsl 4.4.2
  - quack-kernels 0.3.7
