The Problem
A diffusion model that takes 28 sampling steps with CFG=4 runs two forward passes per step — one conditional, one unconditional — for 56 backbone forwards per image. A turbo model collapses that to 4 steps with CFG baked in: 1 forward per step, 4 forwards per image. A ~14× inference cut.
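Here is the forward-count arithmetic as a throwaway sketch. The helper name is ours, not from the fork, and it counts backbone forwards only, ignoring text-encoder and VAE cost:

```python
def backbone_forwards_per_image(steps: int, passes_per_step: int) -> int:
    """Backbone forwards for one image: sampler steps times passes per step.

    passes_per_step is 2 when CFG runs separate conditional and unconditional
    passes, and 1 when guidance has been baked into the weights.
    """
    return steps * passes_per_step


teacher = backbone_forwards_per_image(steps=28, passes_per_step=2)
turbo = backbone_forwards_per_image(steps=4, passes_per_step=1)
print(teacher, turbo, teacher / turbo)  # 56 4 14.0
```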
The standard way to get there is few-step distillation: a slow, high-quality teacher supervises a student that learns to take giant strides. The catch is that distillation recipes are written for clusters. Distribution Matching Distillation (DMD2), the variant we use, needs three models live at training time:
- The teacher — the frozen reference distribution.
- The student — the few-step generator you're training.
- A "fake" score model — an auxiliary that learns what the student currently produces, so the loss can measure how far the student's distribution is from the teacher's.
Three diffusion transformers resident at once, plus optimizer state for two of them, plus the BPTT graph if you unroll the sampler — that is a multi-GPU recipe. This post is about getting the whole thing onto one 16 GB consumer card, with the output falling out as a plain LoRA you can load through the normal inference path.
The work is our anima_lora fork; the method is a port of Liu et al., "CFG Augmentation as the Spear, Distribution Matching as the Shield" (arXiv:2511.22677), the Decoupled-Hybrid DMD2 row.
The Memory Trick: Three Roles, One Frozen Backbone
The naive reading of "teacher + student + fake" is three sets of weights. But the teacher, student, and fake all share the same frozen DiT — only thin LoRA adapters differ between them. So we keep one backbone in memory (~5 GB in bf16) and attach two LoRA stacks to it: one is the student, one is the fake. The teacher is just the backbone with both stacks switched off.
                    frozen Anima DiT (no grad, ~5 GB bf16)
                                       │
               ┌───────────────────────┼───────────────────────┐
               │                       │                       │
         teacher view            student view              fake view
        both LoRAs OFF       student ON, fake OFF    fake ON, student OFF
        → base velocity      → v_student → x_pred  → v_fake (score tracker)
        (CFG'd at α=4)          (the generator)
A "role" is selected by toggling which adapters are active. Each LoRA module short-circuits on a disabled flag, so switching from the teacher view to the student view is an O(num_modules) Python flag flip — negligible next to a DiT forward. The backbone is never an optimizer target; there are exactly two AdamW states, one per LoRA, each a few tens of MB.
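Below is a minimal pure-Python sketch of role toggling. It assumes adapters that add scale · B(Ax) to a frozen linear layer and skip themselves when disabled. LoRALinear, role and ROLES are hypothetical names. The fork does this on PyTorch modules, and this toy only shows the control flow:

```python
from contextlib import contextmanager


def matvec(m, x):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


class LoRALinear:
    """Frozen base weight plus named low-rank adapters, each behind an enabled flag."""

    def __init__(self, weight):
        self.weight = weight  # frozen base W, shape (d_out, d_in); never trained
        self.adapters = {}    # name -> (A: r x d_in, B: d_out x r, scale)
        self.enabled = {}     # name -> bool

    def add_adapter(self, name, a, b, scale=1.0):
        self.adapters[name] = (a, b, scale)
        self.enabled[name] = False

    def forward(self, x):
        y = matvec(self.weight, x)
        for name, (a, b, scale) in self.adapters.items():
            if not self.enabled[name]:
                continue  # disabled: short-circuit, no extra matmuls
            delta = matvec(b, matvec(a, x))
            y = [yi + scale * di for yi, di in zip(y, delta)]
        return y


# A role is just the set of adapters that should be live.
ROLES = {"teacher": set(), "student": {"student"}, "fake": {"fake"}}


@contextmanager
def role(modules, name):
    """Flip flags on every module so exactly the role's adapters are active."""
    active = ROLES[name]
    saved = [dict(m.enabled) for m in modules]
    for m in modules:
        for adapter in m.enabled:
            m.enabled[adapter] = adapter in active
    try:
        yield
    finally:
        for m, flags in zip(modules, saved):
            m.enabled = flags


layer = LoRALinear([[1.0, 0.0], [0.0, 1.0]])
layer.add_adapter("student", a=[[1.0, 1.0]], b=[[0.5], [0.0]])
layer.add_adapter("fake", a=[[1.0, -1.0]], b=[[0.0], [2.0]])
x = [1.0, 2.0]
with role([layer], "teacher"):
    print(layer.forward(x))  # [1.0, 2.0]
with role([layer], "student"):
    print(layer.forward(x))  # [2.5, 2.0]
with role([layer], "fake"):
    print(layer.forward(x))  # [1.0, 0.0]
```

Switching roles touches only flags, so its cost scales with the number of adapted modules rather than with the size of the weights.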
The memory math that makes 16 GB work reduces to a single line:
Three diffusion transformers' worth of behavior, one diffusion transformer's worth of weights.
Counter-intuitive corollary: the fake LoRA's rank should be ≥ the student's, not smaller. Usual LoRA intuition says "lower rank = better regularized," but the fake isn't a generator — it's a score tracker. If it under-fits the student's moving output distribution, its corrective signal goes noisy.
fake_rank ≥ student_rank is a capacity floor on the regularizer, so we run both at rank 64.
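The one-backbone budget can be turned into a back-of-the-envelope calculator. This sketch rests on stated assumptions and is not the fork's accounting. It assumes the frozen backbone is bf16. It also assumes each LoRA stack keeps fp32 weights, fp32 gradients and AdamW's two fp32 moments, for 16 bytes per trainable parameter. Activations are left out. The function names and the example layer shape are hypothetical, not Anima's actual layer list:

```python
BF16_BYTES = 2
FP32_BYTES = 4


def lora_params(layer_shapes, rank):
    """Trainable parameters in one LoRA stack.

    Each adapted linear layer (d_in -> d_out) gets A: rank x d_in and B: d_out x rank.
    """
    return sum(rank * (d_in + d_out) for d_in, d_out in layer_shapes)


def lora_training_bytes(n_params):
    """fp32 weights + fp32 grads + AdamW exp_avg and exp_avg_sq (assumed fp32)."""
    return n_params * FP32_BYTES * 4


def resident_bytes(backbone_params, layer_shapes, student_rank, fake_rank):
    """One bf16 backbone plus two trainable LoRA stacks; activations excluded."""
    backbone = backbone_params * BF16_BYTES
    student = lora_training_bytes(lora_params(layer_shapes, student_rank))
    fake = lora_training_bytes(lora_params(layer_shapes, fake_rank))
    return backbone + student + fake


# One hypothetical 2048x2048 projection at rank 64.
print(lora_params([(2048, 2048)], 64), lora_training_bytes(lora_params([(2048, 2048)], 64)))
```

Raising fake_rank grows only the fake's slice, which stays small next to what a second backbone copy would cost.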
The Compute Trick: A Single-Call Generator, No BPTT
The other thing that blows up distillation memory is unrolling the sampler at train time. If the student is "a 4-step generator," the obvious training graph runs all 4 steps and backpropagates through the whole trajectory — a 4× activation graph plus the autograd tape across steps.
DMD2 avoids this with a single-call generator. We never unroll the 4-step sampler during training. Each step:
- Sample a generator timestep t and build a noised latent x_t = (1−t)·x_0 + t·ε from a dataset latent x_0.
- Run one student forward to get a velocity v_student, and convert it to a clean-image estimate in one ODE step: x_pred = x_t − t·v_student.
The gradient comes from a single step at the sampled t, not a rollout. Exactly one forward carries gradients into the student per training step. The teacher's CFG evaluation and the fake's score are computed under no_grad, and the fake's own updates train only the fake LoRA on the student's detached output. So the student's expensive activation graph exists for a single forward, and the no-grad forwards cost only what an inference pass costs.
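The two list steps can be sketched in pure Python under one assumption: the velocity convention is v = ε − x_0, the time derivative of x_t on this path. With the exact velocity in place of the student's prediction, the one-step conversion recovers x_0. During training, v comes from the single grad-bearing student forward. The helper names are ours:

```python
import random


def add_noise(x0, eps, t):
    """Flow-matching forward path: x_t = (1 - t) * x0 + t * eps."""
    return [(1.0 - t) * a + t * e for a, e in zip(x0, eps)]


def one_step_x0(x_t, v, t):
    """One Euler step from t straight to 0: x0_hat = x_t - t * v."""
    return [xt - t * vi for xt, vi in zip(x_t, v)]


rng = random.Random(0)
x0 = [rng.gauss(0.0, 1.0) for _ in range(4)]   # stand-in for a dataset latent
eps = [rng.gauss(0.0, 1.0) for _ in range(4)]
t = rng.random()                                # generator timestep

x_t = add_noise(x0, eps, t)
v_exact = [e - a for a, e in zip(x0, eps)]      # what a perfect model would predict
x_pred = one_step_x0(x_t, v_exact, t)
print(max(abs(p - a) for p, a in zip(x_pred, x0)))  # ~0 up to float rounding
```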
Per training step, on that one shared backbone:
- 1 grad-bearing student forward → x_pred
- 2 no-grad teacher forwards (conditional + unconditional) for the CFG signal
- 2 no-grad forwards (teacher + fake) for the distribution-matching signal
- 2 fake-update forwards (the fake trains on the student's detached output)
All on the same ~5 GB of weights, view-toggled. No second model is ever loaded.
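Writing the per-step budget as plain bookkeeping makes the gradient accounting explicit. This is an illustrative table, not the fork's code. The order of the forwards within a step is an assumption, and the two fake updates are listed with the step even though they are the fake's own optimizer steps:

```python
from collections import Counter
from typing import NamedTuple


class Forward(NamedTuple):
    view: str      # which adapter view of the shared backbone is active
    purpose: str
    grad_to: str   # LoRA that receives gradients, or "none" under no_grad


STEP_PLAN = [
    Forward("student", "x_pred (generator)", "student"),
    Forward("teacher", "CFG signal, conditional", "none"),
    Forward("teacher", "CFG signal, unconditional", "none"),
    Forward("teacher", "DM signal, real score", "none"),
    Forward("fake", "DM signal, fake score", "none"),
    Forward("fake", "fake update 1 on detached x_pred", "fake"),
    Forward("fake", "fake update 2 on detached x_pred", "fake"),
]

grads = Counter(f.grad_to for f in STEP_PLAN)
print(len(STEP_PLAN), dict(grads))  # 7 {'student': 1, 'none': 4, 'fake': 2}
```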
What the Two Signals Are For
DMD2's gradient on the student factors into two terms with different jobs — the "spear and shield" of the paper's title:
- CA (CFG-augmentation) is the engine. δ_cfg = v_cond − v_uncond, the teacher's conditional minus unconditional velocity, is the classifier-free guidance direction. Baking (α − 1) times this into the student is what converts a many-step model into a few-step one, and it does so fast. Left alone, it collapses into artifacts after a few thousand steps.
- DM (distribution-matching) is the shield. δ_dm is the gap between the teacher's and the fake's velocities on the same re-noised input. It measures how far the student's distribution has drifted from the teacher's, and it cancels the artifacts CA introduces. It's a regularizer, not the engine.
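The spear takes the form (α − 1)·δ_cfg because the CFG'd teacher velocity v_uncond + α·(v_cond − v_uncond) equals v_cond + (α − 1)·δ_cfg, with δ_cfg = v_cond − v_uncond. That term is exactly what separates plain conditional output from guided output. Here is a small numeric check in pure Python; the function names are ours:

```python
def cfg_velocity(v_cond, v_uncond, alpha):
    """Standard classifier-free guidance: v_uncond + alpha * (v_cond - v_uncond)."""
    return [u + alpha * (c - u) for c, u in zip(v_cond, v_uncond)]


def cond_plus_spear(v_cond, v_uncond, alpha):
    """The same quantity written as v_cond + (alpha - 1) * delta_cfg."""
    delta_cfg = [c - u for c, u in zip(v_cond, v_uncond)]
    return [c + (alpha - 1.0) * d for c, d in zip(v_cond, delta_cfg)]


v_cond = [0.3, -1.2, 0.8]
v_uncond = [0.1, -0.4, 1.0]
guided = cfg_velocity(v_cond, v_uncond, alpha=4.0)
rewritten = cond_plus_spear(v_cond, v_uncond, alpha=4.0)
print(guided, rewritten)  # equal up to float rounding
```

At α = 1 the extra term vanishes and guidance reduces to the conditional velocity.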
The "decoupled" part is that the two terms are evaluated at different re-noise levels, sampled per step. To score x_pred at a fresh noise level τ, we re-noise it by applying the same forward path to the predicted clean image: x_τ = (1−τ)·x_pred + τ·ε′, with fresh noise ε′.
- CA re-noises strictly noisier than the current step (τ_ca ~ U(t, 1)). This focuses the engine on still-unresolved content and dodges a subtle trap. At small generator t, the student sits near the regime where the teacher's iterative state has drifted out of the training distribution, so evaluating the CFG gap there would read off-manifold scores. Re-noising upward puts the teacher back where its score is well calibrated. (We skip CA entirely when t > 0.95, where the U(t, 1) interval collapses to pure noise.)
- DM spans the full range (τ_dm ~ U(0, 1)) so it can fix global artifacts, such as color drift and oversaturation, that the student carries regardless of step.
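Here is a sketch of the decoupled sampling using Python's random module. The 0.95 cutoff comes from the text. The function names and the fresh Gaussian draw per re-noise are our assumptions:

```python
import random

CA_SKIP_T = 0.95  # above this, U(t, 1) is essentially pure noise


def sample_renoise_levels(t, rng):
    """Return (tau_ca, tau_dm) for one step.

    CA is drawn from U(t, 1), i.e. no less noisy than the generator step;
    DM is drawn from U(0, 1). tau_ca is None when CA is skipped.
    """
    tau_ca = rng.uniform(t, 1.0) if t <= CA_SKIP_T else None
    tau_dm = rng.uniform(0.0, 1.0)
    return tau_ca, tau_dm


def renoise(x_pred, tau, rng):
    """Same forward path, applied to the predicted clean image with fresh noise."""
    return [(1.0 - tau) * x + tau * rng.gauss(0.0, 1.0) for x in x_pred]


rng = random.Random(0)
tau_ca, tau_dm = sample_renoise_levels(0.3, rng)
x_ca = renoise([0.5, -0.1, 0.2], tau_ca, rng)
x_dm = renoise([0.5, -0.1, 0.2], tau_dm, rng)
```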
The sign is load-bearing
Anima predicts velocity (v = ε − x_0) on a flow-matching path, but the DMD math is written for score/ε-prediction. When the update is re-derived in velocity/x0 space, converting a velocity gap to its x0 gap at level τ picks up a factor of τ: since x_0 = x_τ − τ·v, an x0 gap is −τ times the corresponding velocity gap. We want x_pred to move toward the teacher's clean endpoint, so the surrogate loss is assembled so that gradient descent steps in that direction:
grad_signal = tau_dm * delta_dm + tau_ca * (alpha_eff - 1) * delta_cfg
loss_student = (grad_signal.detach() * x_pred).mean() # ∂/∂x_pred = grad_signal
loss_student.backward() # walks x_pred → v_student → θ
This sign was once inverted, and the failure mode is instructive: an inverted student gradient doesn't blow up — it produces output that looks like "base model, 4-step, never trained" (blurry, un-distilled). A blow-up would have been easier to catch. If your distillation looks like it's doing nothing, check the sign before you check the learning rate.
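The DM half of the sign can be checked numerically. This sketch makes two assumptions: v = ε − x_0, so x_0 = x_τ − τ·v, and delta_dm = v_teacher − v_fake. Under those assumptions, descending on grad_signal · x_pred moves x_pred along the x0 gap from the fake's prediction toward the teacher's, and flipping delta_dm reverses the direction. The .mean() in the real loss only rescales the gradient by 1/N, so directions are unaffected. All names are illustrative:

```python
def x0_from_velocity(x_tau, v, tau):
    """Invert the flow-matching path at level tau: x0 = x_tau - tau * v."""
    return [x - tau * vi for x, vi in zip(x_tau, v)]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


x_tau = [0.5, -0.2, 1.0]
tau = 0.6
v_teacher = [0.2, 0.1, -0.3]
v_fake = [-0.4, 0.5, 0.2]

# Where each model places the clean image, from the same noised input.
toward_teacher = [
    a - b
    for a, b in zip(x0_from_velocity(x_tau, v_teacher, tau),
                    x0_from_velocity(x_tau, v_fake, tau))
]

# grad_signal = tau * delta_dm is d(loss)/d(x_pred) for loss = sum(grad_signal * x_pred),
# so one gradient-descent step moves x_pred by -lr * grad_signal.
grad_signal = [tau * (vt - vf) for vt, vf in zip(v_teacher, v_fake)]
step = [-g for g in grad_signal]

inverted = [tau * (vf - vt) for vt, vf in zip(v_teacher, v_fake)]
step_inverted = [-g for g in inverted]

print(dot(step, toward_teacher) > 0, dot(step_inverted, toward_teacher) > 0)  # True False
```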
Two Stability Levers That Earned Their Keep
α warmup. Full CFG strength (α = 4) is a large kick. Applied from step 0, it NaNs a low-capacity LoRA student before any image structure forms. So α_eff ramps linearly from 1 to α over the first 1000 steps. At step 0 the spear term (α_eff − 1)·δ_cfg vanishes, and the student starts inside the regime the shield can regularize. This is structural, not a band-aid: the LoRA student has far less capacity than the paper's full-finetune student, so it needs the gentler entry.
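The ramp itself is one line. In this sketch, the step count and α come from the text, and the function name is ours:

```python
def alpha_eff(step, alpha=4.0, warmup_steps=1000):
    """Linearly ramp the CFG strength from 1 (spear term off) to alpha."""
    frac = min(max(step, 0) / warmup_steps, 1.0)
    return 1.0 + (alpha - 1.0) * frac


for s in (0, 250, 500, 1000, 5000):
    print(s, alpha_eff(s), alpha_eff(s) - 1.0)  # step, alpha_eff, spear weight
```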
Fake warmup + a hotter fake. The fake's target is moving — the student sharpens under it every step — so the fake has to stay ahead. Three things keep it there: it runs 2 inner updates per student update, at a higher LR (3e-5 vs the student's 2e-5), and it gets a 50-step head-start before the main loop so the critic is calibrated against the initial (≈ teacher) output distribution before the student starts moving. That head-start kills an early gradient spike around step 50.
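The fake-side knobs from this section and the rank floor from earlier can be gathered into a hypothetical config with an update schedule. The field names are ours. The order of fake and student updates within an iteration is an assumption, since the text fixes only the ratio, the learning rates and the 50-step head-start:

```python
from collections import Counter
from dataclasses import dataclass


@dataclass(frozen=True)
class DistillConfig:
    student_rank: int = 64
    fake_rank: int = 64
    student_lr: float = 2e-5
    fake_lr: float = 3e-5          # the fake runs hotter than the student
    fake_updates_per_student: int = 2
    fake_warmup_steps: int = 50    # head-start before the student moves
    alpha: float = 4.0
    alpha_warmup_steps: int = 1000

    def __post_init__(self):
        # The fake is a score tracker: give it at least the student's capacity.
        if self.fake_rank < self.student_rank:
            raise ValueError("fake_rank must be >= student_rank")


def update_schedule(cfg, student_steps):
    """Yield the name of the LoRA that takes each optimizer step, in order."""
    for _ in range(cfg.fake_warmup_steps):
        yield "fake"
    for _ in range(student_steps):
        for _ in range(cfg.fake_updates_per_student):
            yield "fake"
        yield "student"


cfg = DistillConfig()
print(Counter(update_schedule(cfg, student_steps=10)))
```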
A metric note that cost us time: do not trigger fake interventions on rising fake_loss. Against a moving, sharpening student, a rising fake loss is expected equilibrium. The real triggers are dm_rel_gap rising (the fake is falling behind) or dm_cos dropping (the fake is pointing the wrong way). The loss going up can mean everything is working.
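These trigger rules translate into a small check over logged scalars. The window size and thresholds below are hypothetical placeholders, not values from the fork. The point is the shape of the rule: it watches dm_rel_gap and dm_cos and deliberately never looks at fake_loss:

```python
from statistics import fmean


def should_intervene(dm_rel_gap, dm_cos, window=50, rel_rise=0.2, cos_floor=0.5):
    """Return True if the fake critic needs attention, judged from logged scalars.

    fake_loss is deliberately not an input: against a sharpening student,
    a rising fake loss is the expected equilibrium, not a failure signal.
    """
    if len(dm_rel_gap) < 2 * window or len(dm_cos) < window:
        return False  # not enough history to judge a trend
    earlier = fmean(dm_rel_gap[-2 * window:-window])
    recent = fmean(dm_rel_gap[-window:])
    falling_behind = recent > (1.0 + rel_rise) * earlier  # dm_rel_gap rising
    misaligned = fmean(dm_cos[-window:]) < cos_floor       # dm_cos dropped too low
    return falling_behind or misaligned
```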
One A/B Worth Reporting: Normalize the Shield
DMD's original recipe normalizes the DM gradient per-sample for scale-invariance. There are two ways to weight that term, and they turned out to be alternatives, not a stack:
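- (a) τ-damping: use τ·δ_dm directly.
- (b) per-sample x0-norm: divide τ·δ_dm by the per-sample magnitude of the x0-space residual (DMD's original weighting uses mean |x_pred − x0_teacher|). Because that residual itself scales roughly with τ, the τ roughly cancels, so (b) is essentially "drop the τ-weight and magnitude-normalize."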
(b) won decisively, and the signal only showed up in multi-seed samples, not in any scalar metric. Under (a), outputs collapsed to near-identical images across seeds — classic DMD mode-seeking — and text rendering blurred. The τ-damp multiplied by raw magnitude over-weights high-‖v‖ samples (high-frequency structure, text, off-mode tails) and over-pulls them to the dominant mode. (b)'s per-sample normalization gives every sample a unit-scale direction, so distribution-matching pressure is even: tails (seed diversity) and fine structure (text) survive.
We ruled out the effective-LR confound — (b) runs ~2× the DM-grad magnitude, and more DM pressure should mean more mode-seeking, yet diversity went up. So scale-invariance is the driver, not the magnitude bump. The per-step health scalars (x_pred_std, dm_cos) were blind to this entirely; the only decisive signal was eyeballing 4-step samples across seeds. Worth remembering when a metric dashboard says two configs are identical.
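Here is a pure-Python sketch of the two weightings. It assumes (b) divides by the per-sample mean |x_pred − x0_teacher|, as in DMD's original recipe, and the names are ours. It shows the two properties the argument relies on. First, under (a), a sample whose velocity gap and residual are both k times larger gets a k times larger push, while under (b) it does not. Second, under (b), τ cancels when the residual itself scales with τ:

```python
def mean_abs(xs):
    return sum(abs(x) for x in xs) / len(xs)


def weight_tau_damped(tau, delta_dm):
    """(a): tau * delta_dm, used as-is."""
    return [tau * d for d in delta_dm]


def weight_x0_normalized(tau, delta_dm, residual, eps=1e-8):
    """(b): tau * delta_dm divided by this sample's mean |x_pred - x0_teacher|."""
    scale = mean_abs(residual) + eps
    return [tau * d / scale for d in delta_dm]


tau, k = 0.5, 10.0
delta = [0.1, -0.2, 0.3]
resid = [0.05, -0.1, 0.2]
big_delta = [k * d for d in delta]   # a high-magnitude sample (text, fine detail, tails)
big_resid = [k * r for r in resid]

a_ratio = weight_tau_damped(tau, big_delta)[0] / weight_tau_damped(tau, delta)[0]
b_ratio = (weight_x0_normalized(tau, big_delta, big_resid)[0]
           / weight_x0_normalized(tau, delta, resid)[0])
print(a_ratio, b_ratio)  # (a) scales with k; (b) stays ~1
```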
It Ships as a Plain LoRA
The fake is training scaffolding and is never saved. save_student serializes only the student, in the standard LoRA layout. The result is an ordinary rank-64 LoRA with CFG=4 baked in — you load it through the existing adapter path and run:
--infer_steps 4 --cfg 1.0
No turbo-specific code runs at inference. It composes with concept LoRAs linearly, the same way LCM-LoRA composes with a style LoRA (ranks add). The plain-LoRA bake is also the hard constraint on what the method can't do. Anything that needs a step-size or per-t input at inference (Shortcut/MeanFlow Δt-conditioning, timestep-conditioned masks) buys nothing after the bake, because a plain LoRA must average antagonistic per-t corrections into one static delta. True multi-stride (2/4/8-step) robustness is out of scope by construction.
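"Ranks add" can be seen directly. Two LoRA deltas on the same layer, B1·A1 + B2·A2, equal a single LoRA whose B is [B1 | B2] and whose A stacks A1 over A2, with rank r1 + r2. This pure-Python sketch assumes both adapters target the same layer and their scale factors are already folded into B. The helper names are ours:

```python
def matmul(a, b):
    """Plain-Python matrix product of lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def stack_loras(b1, a1, b2, a2):
    """Merge two LoRAs into one of rank r1 + r2: B = [B1 | B2], A = [A1; A2]."""
    b = [r1 + r2 for r1, r2 in zip(b1, b2)]  # concatenate B factors column-wise
    a = a1 + a2                               # stack A factors row-wise
    return b, a


b1, a1 = [[1.0], [2.0]], [[3.0, 0.0]]    # rank-1 "turbo" delta (toy values)
b2, a2 = [[0.0], [1.0]], [[1.0, -1.0]]   # rank-1 "concept" delta (toy values)
b, a = stack_loras(b1, a1, b2, a2)
print(matmul(b, a), add(matmul(b1, a1), matmul(b2, a2)))  # same matrix, rank-2 factors
```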
Summary
The shape of the win is the same one that makes LoRA finetuning fit on consumer cards in the first place: keep the big thing frozen and resident, train only the small deltas, and never hold two copies of the big thing at once. DMD2 looks like it demands three models, but two of them are the same weights wearing different adapters, and the third is the bare weights — so 16 GB is enough.