2026-04-03

Making torch.compile Actually Work for DiT LoRA Training

The Problem

torch.compile promises free speedups — let the compiler fuse ops, eliminate overhead, generate optimized kernels. In practice, for DiT (Diffusion Transformer) LoRA training, it's a trap. You enable it, and instead of faster training, you get a compilation loop that recompiles the model every step, eating all the gains and then some.

The root cause is shape dynamism. torch.compile generates and caches kernels based on tensor shapes. When shapes change, it recompiles. DiT training has three independent sources of shape variation:

  1. Spatial resolution — different aspect ratio buckets produce different (T, H, W) token counts
  2. Caption length — variable text encoder output lengths for cross-attention KV
  3. Batch size — trailing incomplete batches at epoch boundaries

Every unique combination of these three triggers a fresh compilation. With 20+ resolution buckets and arbitrary caption lengths, you quickly exceed dynamo's recompile limit and fall back to eager mode anyway.

This post documents how we eliminated all three in our anima_lora fork of sd-scripts.


Source 1: Spatial Resolution

A standard image training pipeline uses bucket sampling — images are grouped by similar aspect ratios and resized to bucket resolutions like 512×768, 768×1024, etc. Each bucket produces a different number of visual tokens after patchification: a 1024×1024 image becomes 4096 tokens, but 768×1024 becomes 3072.

Fix: Constant-Token Buckets + Static Padding

We designed 17 resolution buckets where every resolution produces approximately the same token count:

CONSTANT_TOKEN_BUCKETS = [
    (1024, 1024),   # 4096 tokens, 0.0% pad
    (960, 1088),    # 4080 tokens, 0.4% pad
    (1088, 960),
    # ... 14 more landscape/portrait pairs
    (2048, 512),    # 4096 tokens, 0.0% pad
]
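
The constant-token property is easy to check; a quick sketch, assuming the 16-pixel patch size and 4096-token target reflected in the bucket comments:

```python
# Verify that each bucket resolution lands near the 4096-token target
# (16x16 patchification assumed, as in the bucket comments above).
PATCH, TARGET = 16, 4096

def token_stats(w, h):
    tokens = (w // PATCH) * (h // PATCH)
    pad_pct = 100.0 * (TARGET - tokens) / TARGET
    return tokens, pad_pct

for res in [(1024, 1024), (960, 1088), (2048, 512)]:
    tokens, pad = token_stats(*res)
    print(f"{res[0]}x{res[1]}: {tokens} tokens, {pad:.1f}% pad")
# → 1024x1024: 4096 tokens, 0.0% pad
# → 960x1088: 4080 tokens, 0.4% pad
# → 2048x512: 4096 tokens, 0.0% pad
```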

Each resolution satisfies (W/16) * (H/16) ≈ 4096 with at most 1.6% padding overhead. The --static_token_count=4096 flag then pads every sequence to exactly 4096 tokens:

  1. Flatten 5D input (B, T, H, W, D) to (B, seq_len, D)
  2. Pad sequence dim to 4096 with zeros
  3. Reshape to fake-5D (B, 1, 4096, 1, D) — compatible with existing block code
  4. After all blocks: strip padding and restore original shape

The result: every forward pass sees identical spatial dimensions, regardless of the input image's aspect ratio.
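
Steps 1–4 can be sketched as a pair of helpers (hypothetical names; the fork's actual implementation lives behind --static_token_count):

```python
import torch
import torch.nn.functional as F

def pad_to_static_tokens(x, static_tokens=4096):
    """Flatten (B, T, H, W, D) to (B, seq, D), zero-pad the sequence dim to a
    fixed token count, and reshape to a fake-5D layout so downstream block
    code is unchanged. Illustrative sketch, not the fork's exact code."""
    B, T, H, W, D = x.shape
    seq_len = T * H * W
    x = x.reshape(B, seq_len, D)                      # 1. flatten spatial dims
    x = F.pad(x, (0, 0, 0, static_tokens - seq_len))  # 2. zero-pad sequence dim
    return x.reshape(B, 1, static_tokens, 1, D), seq_len  # 3. fake-5D

def strip_static_tokens(x, seq_len, orig_shape):
    """4. After the blocks: drop the padding and restore the original shape."""
    B = x.shape[0]
    D = x.shape[-1]
    return x.reshape(B, -1, D)[:, :seq_len].reshape(orig_shape)
```

A round trip through both helpers is lossless, which is what lets the padding stay invisible to the rest of the training loop.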


Source 2: Caption Length

Even with spatial shapes fixed, cross-attention KV sequences vary with every caption. A short caption might produce 20 tokens; a detailed one, 200+. Each unique length is a new shape for torch.compile.

Fix: KV Bucket Trimming

We quantize caption lengths to 4 fixed buckets:

_KV_BUCKETS = (64, 128, 256, 512)

Each caption's KV sequence is trimmed (or padded) to the smallest bucket that fits. This gives torch.compile at most 4 shape variants for cross-attention instead of one per unique caption.
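
The bucket lookup can be as small as this (a sketch, not the fork's exact helper):

```python
import bisect

_KV_BUCKETS = (64, 128, 256, 512)

def kv_bucket(n_tokens):
    """Smallest bucket that fits the caption; anything past the largest
    bucket is trimmed down to 512. Sketch of the quantization rule."""
    i = bisect.bisect_left(_KV_BUCKETS, n_tokens)
    return _KV_BUCKETS[min(i, len(_KV_BUCKETS) - 1)]

print(kv_bucket(20), kv_bucket(128), kv_bucket(200), kv_bucket(600))
# → 64 128 256 512
```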

For Flash Attention 4 (which can't use padding masks directly), we apply an LSE sink correction after trimming — when zero-padded KV positions are removed, the softmax denominator must be compensated:

# lse: per-query log-sum-exp returned by the attention kernel
correction = torch.sigmoid(lse - math.log(n_pad))   # = e^lse / (e^lse + n_pad)
x = out * correction.transpose(1, 2).unsqueeze(-1)  # broadcast over head dim
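
The identity behind this: assuming zero-padded K and V (so padded positions score 0 and contribute zero value vectors), the padded softmax output equals the trimmed output scaled by e^lse / (e^lse + n_pad), which is exactly sigmoid(lse − log n_pad). A single-query numerical check (shapes illustrative):

```python
import math
import torch

def lse_sink_correction_holds(n_real=8, n_pad=24, d=16):
    """Check: softmax attention over [real scores | n_pad zero scores with
    V = 0] equals unpadded attention times sigmoid(lse - log(n_pad))."""
    torch.manual_seed(0)
    s = torch.randn(n_real, dtype=torch.float64)     # real Q·K scores
    v = torch.randn(n_real, d, dtype=torch.float64)
    # reference: softmax including the zero-score padded positions (V = 0)
    full = torch.cat([s, torch.zeros(n_pad, dtype=torch.float64)])
    full = full.softmax(0)[:n_real] @ v
    # trimmed attention plus the sink correction
    out = s.softmax(0) @ v
    lse = torch.logsumexp(s, 0)
    corrected = out * torch.sigmoid(lse - math.log(n_pad))
    return torch.allclose(full, corrected)
```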

For flex attention, we pre-compute BlockMask objects before the block loop and pass them via AttentionParams, so mask creation never happens inside the compiled region.


Source 3: Batch Size

The last batch of each epoch is typically smaller than batch_size. One shape change per epoch sounds harmless, but it still triggers recompilation and dynamo guard re-evaluation.

Fix: Drop Incomplete Batches

batch_count = len(bucket) // self.batch_size  # integer division, not ceil

The trailing partial batch is dropped. When sample_ratio < 1.0 (runs where every sampled image counts), the drop is skipped instead.
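
As a sketch (names hypothetical):

```python
def num_batches(n_images, batch_size, drop_last=True):
    """Floor division drops the trailing partial batch so every step sees
    an identical batch dimension; ceiling division would keep one smaller
    batch and trigger a recompile at each epoch boundary."""
    if drop_last:
        return n_images // batch_size
    return -(-n_images // batch_size)  # ceiling division
```

At most batch_size − 1 images per bucket per epoch are lost, which is usually a cheap price for a fixed batch dimension.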


Making the Code Compile-Safe

Fixing shape dynamism is necessary but not sufficient. The model code itself had several patterns that cause graph breaks — points where dynamo gives up tracing and falls back to eager execution.

Removing einops

einops.rearrange uses string-based symbolic shape parsing that is opaque to dynamo. Every call is a graph break. We replaced all uses with explicit tensor operations:

einops call                                              Replacement
rearrange(t, "b ... (h d) -> b ... h d", h=..., d=...)   .unflatten(-1, (n_heads, head_dim))
rearrange(x, "B T H W (p1 p2 t C) -> ...")               .unflatten().permute().reshape() chain
rearrange(em, "t h w d -> (t h w) 1 1 d")                .flatten(0, 2).unsqueeze(1).unsqueeze(1)

Removing torch.autocast Context Managers

Context managers introduce overhead and are difficult for dynamo to trace through. torch.autocast blocks in RMSNorm.forward and FinalLayer.forward were replaced with explicit .float() / .to(x.dtype) casts.
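
A minimal autocast-free RMSNorm with explicit casts looks like this (an illustrative sketch, not the model's exact layer; eps is assumed):

```python
import torch

class RMSNorm(torch.nn.Module):
    """RMSNorm that computes the norm in fp32 via explicit casts instead
    of a torch.autocast block, which dynamo traces less cleanly."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        xf = x.float()                                   # explicit upcast
        normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        return (normed * self.weight).to(x.dtype)        # explicit downcast
```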

.repeat() → .expand()

expand() creates a view without allocating memory, while repeat() copies data:

# OLD: allocates a full (B, n_heads, S) copy
padding_mask.unsqueeze(1).repeat(1, n_heads, 1)
# NEW: stride-0 view, same shape, no allocation
padding_mask.unsqueeze(1).expand(-1, n_heads, -1)
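
A quick demonstration of the difference (illustrative shapes):

```python
import torch

mask = torch.zeros(4, 77, dtype=torch.bool)            # (B, S) padding mask
n_heads = 8
repeated = mask.unsqueeze(1).repeat(1, n_heads, 1)     # materializes (B, H, S)
expanded = mask.unsqueeze(1).expand(-1, n_heads, -1)   # stride-0 view, no copy
assert repeated.shape == expanded.shape == (4, 8, 77)
assert expanded.stride(1) == 0                         # broadcast dim: stride 0
assert expanded.data_ptr() == mask.data_ptr()          # shares storage
assert repeated.data_ptr() != mask.data_ptr()          # fresh allocation
```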

Flash Attention 4 Graph Breaks

FA4's CUTLASS/TVM kernels access raw DLPack data pointers, which fail with FakeTensors during dynamo tracing. Since FA4 is already a fused kernel that torch.compile can't improve, we wrap it with @torch.compiler.disable:

@torch.compiler.disable
def flash_attn_4_func(*args, **kwargs):
    out, _lse = _flash_attn_4_func_raw(*args, **kwargs)
    return out

This inserts a clean graph break while letting surrounding ops compile normally.

Flex Attention: Intentionally NOT Pre-Compiled

When blocks are individually compiled, the outer torch.compile already traces into flex_attention and fuses it. Pre-compiling causes nested compilation that exhausts dynamo's recompile limit and falls back to the slow unfused path:

# Intentionally NOT compiled — outer block compilation handles it
compiled_flex_attention = _flex_attention

Compilation Strategy

With all shape dynamism eliminated, the question is how to apply torch.compile. We compile each transformer block's _forward method individually:

for block in self.blocks:
    block._forward = torch.compile(
        block._forward, backend=backend, dynamic=True
    )

A critical subtlety: we compile _forward, NOT forward. The gradient checkpointing decorator (unsloth_checkpoint) uses @torch._disable_dynamo, which would cause an immediate graph break if forward itself were compiled — dynamo compiles nothing useful but still checks shape guards, causing recompile storms.

When --static_token_count is set, block-level compilation is used. Otherwise, the standard Accelerator-level dynamo path is used as a fallback.


Handling torch.compile State Dicts

torch.compile wraps modules in _orig_mod containers, inserting _orig_mod. or _orig_mod_ into state_dict keys. This breaks checkpoint loading and LoRA weight matching. We normalize keys in four places:

@staticmethod
def _strip_orig_mod_keys(state_dict):
    # torch.compile nests modules as `<parent>._orig_mod.<child>`; LoRA
    # name-mangling turns the dot into an underscore, so strip both forms.
    new_sd = {}
    for key, val in state_dict.items():
        new_key = re.sub(r"_orig_mod[._]", "", key)  # requires `import re`
        new_sd[new_key] = val
    return new_sd

Applied during: (1) loading external checkpoints, (2) LoRA target module discovery, (3) load_state_dict() override on the network, and (4) LoRA weight merging.

sd-scripts has zero _orig_mod_ awareness — loading a checkpoint trained with torch.compile would fail silently or error out.
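
For illustration, here are the two key forms such normalization has to handle, as a standalone sketch (the regex covers both separators mentioned above; the fork's exact pattern may differ):

```python
import re

def strip_orig_mod(key: str) -> str:
    # Handles both the dotted form from torch.compile's module wrapper and
    # the underscore-mangled form that appears in LoRA module names.
    return re.sub(r"_orig_mod[._]", "", key)

print(strip_orig_mod("model._orig_mod.blocks.0.attn.qkv.weight"))
# → model.blocks.0.attn.qkv.weight
print(strip_orig_mod("lora_unet__orig_mod_blocks_0_attn_qkv"))
# → lora_unet_blocks_0_attn_qkv
```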


Results

All benchmarks on a single RTX 5060 Ti 16 GB. LoRA r=32, lr=5e-5, batch size 2, 2 epochs (182 steps), seed=42. Gradient checkpointing and unsloth offload checkpointing enabled. Latent and text embeddings were pre-cached. Validation loss measured with fixed seed at timestep σ ∈ {0.05, 0.1, 0.2, 0.35}.

                          FA2 plain   FA2 compile        FA2 compile        FA4 compile
                                      (eager fallback)   (static tokens)    (static tokens)
Peak VRAM                 7.0 GB      7.7 GB             6.2 GB             6.3 GB
Total time                14:51       15:10              11:07              11:01
2nd epoch (post-compile)  7:26        7:26               5:01               5:17
Train loss                0.092       0.089              0.086              0.089
Val loss                  0.212       0.211              0.193              0.204

FA2 plain is the baseline — no compilation at all. FA2 compile (eager fallback) enables torch.compile without static shapes; dynamo hits recompile limits and falls back to eager, so there's no speedup — just extra overhead from compilation attempts.

The real gains come from static-token compilation. FA2 compile (static tokens) cuts training time by 25% and drops VRAM by 11% compared to baseline. Validation loss differences are within margin of error across runs.

FA4 compile (static tokens) performs comparably to FA2 with static tokens. The FA4 kernel itself doesn't add much on SM120 when the surrounding code is already compiled, but it does unlock further optimization potential for more complex attention patterns (flex attention, KV trimming without LSE correction overhead).


Summary

Shape dynamism source    Solution                                              Impact
Spatial resolution       CONSTANT_TOKEN_BUCKETS + static_token_count padding   20+ variants → 1
Caption length           _KV_BUCKETS bucketed trimming                         Arbitrary → 4 variants
Batch size               Drop incomplete last batches                          Eliminates trailing recompile

Code pattern             Fix
einops.rearrange         Explicit tensor ops
torch.autocast blocks    Direct dtype casts
.repeat()                .expand()
FA4 DLPack pointers      @torch.compiler.disable
Nested flex compilation  Skip pre-compilation
_orig_mod_ keys          Regex normalization

The key insight: torch.compile for training is not about flipping a flag. It's about reshaping your entire data pipeline and model code to present static shapes and clean graphs to the compiler. Once you do that, the compiler does its job and the speedups are real.