Fix SM120 forward pass crash: parent __init__ overwrites arch, enabling unsupported TMA path #2416
Open
moghon92 wants to merge 1 commit into Dao-AILab:main from
Conversation
`FlashAttentionForwardSm120` sets the class variable `arch = 80` to force CpAsync code paths (no TMA for output). However, `FlashAttentionForwardSm80.__init__()` calls `self.arch = BaseDSL._get_dsl().get_arch_enum()`, which returns the real GPU architecture (`Arch.sm_120`), overwriting the class variable with an instance variable. This causes `use_tma_O = (self.arch >= Arch.sm_90)` to evaluate `True`, and the epilogue enters the TMA output path where `tma_atom_O` is `None` (never created for SM120), resulting in:

```
AttributeError: 'NoneType' object has no attribute '_trait'
```

in `copy_utils.tma_get_copy_fn` -> `cpasync.tma_partition`.

Fix: override `__init__` to reset `self.arch = Arch.sm_80` after `super().__init__()`.

Tested on NVIDIA B200 (SM 12.0) with:

- torch 2.9.1+cu129
- nvidia-cutlass-dsl 4.4.2
- quack-kernels 0.3.7
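The shadowing described above can be reproduced in isolation. A minimal, self-contained sketch follows; the class names, the `detect_gpu_arch` helper, and the enum values are illustrative stand-ins, not the real flash-attn classes:

```python
# Minimal sketch of the failure mode (names are stand-ins, not the
# real flash-attn classes).
from enum import IntEnum

class Arch(IntEnum):
    sm_80 = 80
    sm_90 = 90
    sm_120 = 120

def detect_gpu_arch():
    # Stand-in for BaseDSL._get_dsl().get_arch_enum() on an SM120 GPU.
    return Arch.sm_120

class ForwardSm80:
    def __init__(self):
        # Parent queries the real GPU, creating an *instance* attribute.
        self.arch = detect_gpu_arch()

class ForwardSm120(ForwardSm80):
    arch = Arch.sm_80  # class attribute: intended to force CpAsync paths

kernel = ForwardSm120()
# The instance attribute written by the parent __init__ shadows the
# class attribute, so the TMA guard (wrongly) evaluates True:
use_tma_O = kernel.arch >= Arch.sm_90
print(kernel.arch, use_tma_O)  # Arch.sm_120 True
```

This is ordinary Python attribute lookup: instance attributes always win over class attributes, so setting `arch` at class level cannot survive a parent `__init__` that assigns `self.arch`.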
Bug
Running FlashAttention-4 on SM120 GPUs (e.g., NVIDIA RTX PRO 6000 Blackwell Server Edition) crashes immediately with:
AttributeError: 'NoneType' object has no attribute '_trait'
Full traceback points to:
flash_fwd.py:399 → epilogue() → copy_utils.tma_get_copy_fn(tma_atom_O, ...)
where `tma_atom_O` is `None`.

Root Cause
`FlashAttentionForwardSm120` is designed to use SM80-era CpAsync code paths (no TMA). To achieve this, it sets a class variable `arch = 80`. The intent is that `use_tma_O = (self.arch >= Arch.sm_90)` evaluates to `False`, skipping the TMA output path.

However, the parent class `FlashAttentionForwardSm80.__init__()` (in flash_fwd.py, line 110) does:

```python
self.arch = BaseDSL._get_dsl().get_arch_enum()
```
This queries the actual GPU and returns `Arch.sm_120`, creating an instance variable that shadows the class variable `arch = 80`. In Python, instance attributes always take precedence over class attributes. As a result:

- `self.arch` becomes `Arch.sm_120` (not `80` as intended)
- `self.use_tma_O = self.arch >= Arch.sm_90` → `True`
- the `__call__` method passes `tma_atom_O = None` to `epilogue()` (because SM120 never creates a TMA descriptor for output)
- `tma_get_copy_fn(None, ...)` → crash

Fix
Add an `__init__` override in `FlashAttentionForwardSm120` that resets `self.arch` back to `Arch.sm_80` after calling `super().__init__()`:

```python
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.arch = Arch.sm_80
This ensures the CpAsync output path is used, matching the original design intent documented in the class comment.
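The fix pattern can likewise be sketched in isolation. As before, the class names and values below are illustrative stand-ins for the real flash-attn classes, not the actual implementation:

```python
# Minimal sketch of the fix: reset the attribute *after* the parent
# __init__ has overwritten it (names are stand-ins).
from enum import IntEnum

class Arch(IntEnum):
    sm_80 = 80
    sm_90 = 90
    sm_120 = 120

class ForwardSm80:
    def __init__(self):
        # Parent assigns an instance attribute from the real GPU arch.
        self.arch = Arch.sm_120

class ForwardSm120(ForwardSm80):
    arch = Arch.sm_80  # class attribute alone is not enough

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Re-assert the SM80 identity now that the parent has run.
        self.arch = Arch.sm_80

kernel = ForwardSm120()
use_tma_O = kernel.arch >= Arch.sm_90
print(kernel.arch, use_tma_O)  # Arch.sm_80 False
```

Because the reset happens after `super().__init__()` returns, the instance attribute ends up with the intended value and the TMA guard correctly evaluates `False`, restoring the CpAsync output path.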
Environment

- NVIDIA B200 (SM 12.0)
- torch 2.9.1+cu129
- nvidia-cutlass-dsl 4.4.2
- quack-kernels 0.3.7