Background
We train LoRA models on an RTX 5060 Ti — a consumer Blackwell GPU with compute capability SM120. Flash Attention 4 (the CuTeDSL rewrite) launched with support for SM80 (Ampere), SM90 (Hopper), and SM100/SM110 (Blackwell data center), but not SM120.
SM120 is an odd architecture. It reports arch = 120 at runtime, yet uses SM80-era MMA instructions (mma.sync.aligned.m16n8k16) instead of the newer WGMMA units found in SM90+. It also has a much tighter shared memory budget — roughly 99-101 KB compared to 163 KB on SM80 and 227 KB on SM100. This means SM120 is essentially an SM80 execution model wearing a Blackwell badge.
We found recent upstream PRs (#2329, #2330, #2333) that add SM120 subclasses for forward, backward, and varlen attention. These PRs lay the groundwork, but after pulling them in, we still hit three critical bugs that prevented actual use.
Bug 1: Forward Pass TMA O-Store Crash
The very first run crashed with:
AttributeError: 'NoneType' object has no attribute '_trait'
deep inside the TMA (Tensor Memory Accelerator) epilogue path. The culprit was in flash_fwd.py:
self.use_tma_O = self.arch >= Arch.sm_90
SM120 is >= SM90, so the code tried to use TMA for the output epilogue. But SM120 doesn't support SM90's TMA output path — the copy descriptor was never initialized, leaving atom._trait as None.
The SM120 subclass tried to work around this by setting arch = 80 as a class attribute, but the parent __init__ immediately overwrites it:
self.arch = BaseDSL._get_dsl().get_arch_enum() # returns 120 at runtime
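Why the override fails is ordinary Python attribute semantics: `__init__` assigns an instance attribute, and instance attributes shadow class attributes. A minimal sketch, using a constant as a stand-in for the runtime arch query:

```python
class Base:
    def __init__(self):
        # Mirrors the parent __init__: queries the runtime arch and
        # assigns it as an instance attribute, clobbering any
        # class-level override in a subclass.
        self.arch = 120  # stand-in for BaseDSL._get_dsl().get_arch_enum()

class Sm120Kernel(Base):
    arch = 80  # class attribute: the intended "behave like SM80" override

k = Sm120Kernel()
assert k.arch == 120  # the instance attribute wins; the override is lost
```

Any fix therefore has to live in the feature gates themselves, not in a class-level `arch` override.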
Fix: Exclude SM120 from TMA O-store explicitly:
self.use_tma_O = self.arch >= Arch.sm_90 and self.arch < Arch.sm_120
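The same guard can be written as a standalone predicate against a simplified `Arch` enum (an illustrative subset, not the full enum from the FA4 tree):

```python
from enum import IntEnum

class Arch(IntEnum):
    # Hypothetical subset of FA4's Arch enum, for illustration only.
    sm_80 = 80
    sm_90 = 90
    sm_100 = 100
    sm_120 = 120

def use_tma_O(arch: int) -> bool:
    """TMA O-store requires SM90+ data-center hardware. SM120 reports a
    higher arch number but lacks the feature, so gate on a range, not a
    lower bound."""
    return Arch.sm_90 <= arch < Arch.sm_120
```

With the upper bound in place, SM90 and SM100 still take the TMA path while SM120 falls through to the plain epilogue.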
Bug 2: Backward Pass dQ_single_wg Unbound
With the forward pass fixed, the backward pass crashed immediately:
UnboundLocalError: cannot access local variable 'dQ_single_wg'
The SM120 configuration block in interface.py sets tile sizes, swap flags, and atom layouts — but simply forgot to set dQ_single_wg. Since SM120 backward follows SM80's single warp group model:
dQ_single_wg = False
One line. That was the fix.
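The failure mode is plain Python scoping: a variable assigned in some configuration branches but not others is a local, and reading it on the unbound path raises at the use site. A minimal sketch with a hypothetical `configure` helper, not FA4's actual code:

```python
def configure(arch: int):
    if 90 <= arch < 120:
        dQ_single_wg = True   # the SM90 branch binds it
    elif arch >= 120:
        tile_m = 64           # the SM120 branch sets other knobs
        # ...but forgets to bind dQ_single_wg
    return dQ_single_wg       # UnboundLocalError on the SM120 path

configure(90)   # fine: returns True
# configure(120) would raise UnboundLocalError
```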
Bug 3: NaN Loss — StMatrix Register Layout Mismatch
This was the hard one. After fixing Bugs 1 and 2, training ran without errors but produced nan loss from the very first step:
steps: 1%|▌ | 10/744 [00:30<36:53, 3.02s/it, avr_loss=nan]
No crash, no warning — just silent data corruption. We narrowed it down to the forward output being filled with garbage values.
The root cause was in utils.py. The function get_smem_store_atom() selects the stmatrix.sync.aligned instruction for storing results from registers to shared memory:
if const_expr(arch >= 90 and element_type.width == 16):
    return StMatrix8x8x16bOp  # Uses SM90+ register layout
SM120 passes arch >= 90, so it gets the stmatrix path. But stmatrix assumes SM90+ WGMMA output register layout. SM120 uses SM80's mma.sync with a completely different register-to-thread mapping. The result: registers get written to the wrong shared memory locations, producing corrupted output tensors.
This affected two critical sites:
- Forward epilogue (flash_fwd.py:347) — corrupted O output
- Backward postprocess (flash_bwd_postprocess.py:537) — corrupted dQ gradients
Fix: Exclude SM120 from the stmatrix path, falling back to the universal (but slightly slower) copy:
if const_expr(arch < 90 or arch >= 120 or element_type.width != 16):
    # Universal copy — safe for all architectures
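Condensed into a standalone sketch, the corrected selection logic looks like this (the atom names are illustrative stand-ins, not the real CuTeDSL types):

```python
def select_smem_store_atom(arch: int, element_width: int) -> str:
    """Pick the register-to-shared-memory store path.

    stmatrix assumes the SM90+ WGMMA output register layout; SM120 uses
    SM80-style mma.sync, so it must take the universal copy instead.
    """
    if arch < 90 or arch >= 120 or element_width != 16:
        return "universal_copy"   # correct on every architecture
    return "stmatrix_8x8x16b"     # SM90/SM100 fast path only
```

The universal copy is slightly slower but writes each register to the shared-memory location the consumer actually expects, which is what restores correct output on SM120.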
After this fix, test_sm120_nan.py confirmed that FA4 SM120 output matches PyTorch's reference attention implementation.
The Pattern
All three bugs share the same root cause: Flash Attention 4 assumes SM90+ features are available whenever arch >= 90. SM120 breaks this assumption by pairing a high arch number with SM80 execution capabilities. Each bug was a different place where that assumption leaked through: the TMA O-store gate in flash_fwd.py, the backward configuration block in interface.py, and the stmatrix selection in utils.py.
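The recurring fix condenses to a single range predicate. A minimal sketch, with a hypothetical helper name rather than anything from the FA4 codebase:

```python
def has_sm90_execution_model(arch: int) -> bool:
    """True only for architectures with the SM90-style WGMMA/TMA
    execution model.

    A lower bound alone (arch >= 90) wrongly includes SM120, which
    reports arch 120 at runtime but executes like SM80.
    """
    return 90 <= arch < 120

# Each of the three patches rewrites an 'arch >= 90' feature gate
# in these terms:
#   use_tma_O      -> has_sm90_execution_model(arch)
#   stmatrix path  -> has_sm90_execution_model(arch) and width == 16
```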
Why FA4 Over FA2 on SM120
Beyond the attention kernel itself, there's a practical reason to prefer FA4: torch.compile compatibility. FA4 is written in CuTe DSL (Python-level kernel definitions), which plays nicely with torch.compile. FA2's handwritten CUDA C++ kernels are opaque to the compiler — they work as custom ops but the surrounding code can't be fused through them. With FA4 on SM120 we can set torch_compile=True and let the compiler optimize the full training loop, not just everything-except-attention.
Results: FA4 vs FA2 on SM120
Both runs use torch_compile=True, 2 epochs, 186 steps, batch size 1, 8-bit AdamW, bfloat16, on a single RTX 5060 Ti.
FA4 is 25% faster and uses 12% less VRAM than FA2 — with comparable convergence. The speed gain comes not just from the attention kernel itself, but from torch.compile being able to optimize through CuTe DSL kernels where it can't with FA2's opaque CUDA C++ ops.
The VRAM reduction is a meaningful win on a 16 GB consumer card. That's headroom for larger batch sizes or higher resolution inputs.
The fixes are minimal and backward-compatible — three surgical one-line changes that exclude SM120 from SM90+ code paths. Upstream PRs provided the SM120 subclass scaffolding; we filled in the gaps where architecture assumptions didn't hold.
Environment: RTX 5060 Ti, CUDA 13.0, PyTorch 2.11.0+cu130, flash-attn 4.0.0b6.dev10