
2026-04-25

Training HydraLoRA for a Diffusion Model

The problem: style bleed

Train a single LoRA on three artists and the optimizer's best rank-$r$ compromise is the common subspace: the union of three styles becomes a muddled in-between average. Distinct fingerprints disappear. This is the well-known style-bleed failure of plain LoRA when the dataset is multi-modal.

The structural fix is a mixture of experts on the up-projection: keep the rank-$r$ bottleneck shared, but give the up-matrix several heads and let a router pick a per-sample mixture. That's HydraLoRA. On paper, sample $x_i$ from artist A and sample $x_j$ from artist B can push the router toward different experts, those experts then receive differentiated gradient updates, and per-expert specialization emerges.

In practice, getting that to actually happen took five attempts.


Plain HydraLoRA

The architecture is small. Per adapted Linear:

| Component | Shape | Role |
|---|---|---|
| lora_down | $(r,\ d_\text{in})$ | shared down-projection |
| lora_up_weight | $(E,\ d_\text{out},\ r)$ | stacked per-expert up heads |
| router.weight | $(E,\ r)$ | layer-local gate logits |
| router.bias | $(E,)$ | zero-init |

Forward:

lx       = F.linear(x, lora_down.weight)                      # (B, L, r)      shared down-projection
pooled   = rms_pool_over_seq(lx)                              # (B, r)         sqrt(mean(lx**2)) over L
gate     = F.softmax(router(pooled), dim=-1)                  # (B, E)         per-sample expert mixture
combined = torch.einsum("be,eod->bod", gate, lora_up_weight)  # (B, d_out, r)  gate-weighted up head
out      = torch.bmm(lx, combined.transpose(1, 2))            # (B, L, d_out)
return org_forward(x) + out * multiplier * scale

Layer-local: every adapted Linear has its own router, so the same sample can get different gates at different layers. Specialization is learned per-layer, not as a global "style pick" broadcast everywhere.
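This framework-light NumPy sketch of one adapted Linear lets you check the shapes in the forward pass above. NumPy stands in for torch. `hydra_lora_forward` and its argument names are hypothetical. `rms_pool_over_seq` is assumed to be a plain root-mean-square over the sequence axis.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)          # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def hydra_lora_forward(x, W0, lora_down, lora_up, router_w, router_b,
                       multiplier=1.0, scale=1.0):
    """One adapted Linear with its own (layer-local) router.

    x: (B, L, d_in)   W0: (d_out, d_in)   lora_down: (r, d_in)
    lora_up: (E, d_out, r)   router_w: (E, r)   router_b: (E,)
    """
    lx = x @ lora_down.T                                  # (B, L, r)
    pooled = np.sqrt(np.mean(lx ** 2, axis=1))            # (B, r) RMS over the sequence
    gate = softmax(pooled @ router_w.T + router_b)        # (B, E)
    combined = np.einsum("be,eod->bod", gate, lora_up)    # (B, d_out, r)
    out = lx @ combined.transpose(0, 2, 1)                # (B, L, d_out)
    return x @ W0.T + out * multiplier * scale, gate      # frozen base + LoRA delta

rng = np.random.default_rng(0)
B, L, d_in, d_out, r, E = 2, 16, 32, 24, 4, 3
x = rng.standard_normal((B, L, d_in))
W0 = rng.standard_normal((d_out, d_in))
lora_down = rng.standard_normal((r, d_in)) / np.sqrt(d_in)
router_w = rng.standard_normal((E, r)) / np.sqrt(r)      # std 1/sqrt(r), as in the text
router_b = np.zeros(E)                                   # zero-init bias

# Zero-init experts: the adapter is an exact no-op at step 0
y0, gate = hydra_lora_forward(x, W0, lora_down, np.zeros((E, d_out, r)), router_w, router_b)
# Non-zero experts: the output now carries a per-sample mixture of up heads
y1, _ = hydra_lora_forward(x, W0, lora_down, rng.standard_normal((E, d_out, r)), router_w, router_b)
```

Each call builds its own gate from its own pooled activation. Stacking several such layers therefore gives each one an independent routing decision.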

The router input, an RMS-pooled rank-$r$ activation, is itself the result of an earlier failure. Mean-pooling the raw $d_\text{in}$ inputs shrinks the signal by roughly $\sqrt{L}$ through sign cancellation (a factor of 64 over a 4096-token sequence), so the router never received gradient and the balance loss pinned all gates to uniform. RMS over rank-$r$ keeps the signal alive: squaring before averaging defeats the sign cancellation, and lora_down is small enough that bf16 softmax stays stable. That fix was a precondition. With it, the router can learn, but it still needs something to differentiate.
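Here is a toy illustration of the cancellation. It assumes one channel whose per-token values are i.i.d., zero-mean and unit-variance. Real activations are correlated, so this only shows the scaling.

```python
import math
import random

def mean_pool(xs):
    """Plain average over tokens: signed values cancel."""
    return sum(xs) / len(xs)

def rms_pool(xs):
    """Root mean square over tokens: squaring first removes the sign."""
    return math.sqrt(sum(v * v for v in xs) / len(xs))

rng = random.Random(0)
L = 4096
# One channel of a zero-mean, unit-variance activation, i.i.d. across tokens
channel = [rng.gauss(0.0, 1.0) for _ in range(L)]

print(f"mean-pool {mean_pool(channel):+.4f}  vs typical 1/sqrt(L) = {1 / math.sqrt(L):.4f}")
print(f"rms-pool  {rms_pool(channel):.4f}")
```

The mean-pooled value lands on the order of $1/\sqrt{L} \approx 0.016$. The RMS-pooled value stays near the per-token scale of 1.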


Attempt 1: zero-init experts (the cold-start deadlock)

The natural LoRA-safe init is $B = 0$, so $\Delta W = 0$ at step 0. With HydraLoRA that means all $E$ expert heads start identical. Under a near-uniform router, each expert receives identical gradient. They evolve permutation-symmetrically forever, the router has no signal to differentiate them, and the network trains as a single LoRA averaged across $E$ heads, paying $E\times$ the parameter cost for no specialization.
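The deadlock can be read off the chain rule. The notation is introduced here for the derivation and is not taken from the code:

- $U_e$ is expert $e$'s up head.
- $z_{be}$ and $g_{be}$ are sample $b$'s router logit and gate.
- $C_b = \sum_e g_{be} U_e$ is the combined head.
- $G_b = \partial\mathcal{L}/\partial C_b$ is its upstream gradient.

```latex
\begin{aligned}
s_{be} &= \langle G_b,\, U_e \rangle, \qquad
\frac{\partial \mathcal{L}}{\partial z_{be}} = g_{be}\Big(s_{be} - \sum_k g_{bk}\, s_{bk}\Big), \qquad
\frac{\partial \mathcal{L}}{\partial U_e} = \sum_b g_{be}\, G_b \\
U_1 &= \dots = U_E \;\Rightarrow\; s_{b1} = \dots = s_{bE} \;\Rightarrow\; \frac{\partial \mathcal{L}}{\partial z_{be}} = 0,
\qquad g_{be} = \tfrac{1}{E} \;\Rightarrow\; \frac{\partial \mathcal{L}}{\partial U_1} = \dots = \frac{\partial \mathcal{L}}{\partial U_E}
\end{aligned}
```

Identical heads give the router logits exactly zero gradient. With an exactly uniform gate, every head then gets the same update, so the heads stay identical at the next step. A near-uniform gate perturbs this only through its small per-expert differences.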

The Switch-Transformer balance loss does not save you here. It penalizes uneven dominant-expert distributions; it does nothing to break the symmetry that creates uniform gates in the first place. We confirmed the failure on anima-hydra-0420-644: $\|\text{router.weight}\|$ never moved from Kaiming init, median normalized gate entropy sat at 1.0000, and the dominant-top1 fraction was ≈ 2e-4. Four heads, one effective expert.
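Below is a minimal per-sample version of the Switch-Transformer load-balancing term $E \sum_e f_e P_e$. Here $f_e$ is the fraction of samples whose top-1 expert is $e$, and $P_e$ is the mean gate probability on $e$. Treating each sample's gate as one routing decision is an assumption, as is the function name. The balance_loss_weight multiplier is left out.

```python
def switch_balance_loss(gates):
    """Per-sample Switch-Transformer balance term: E * sum_e f_e * P_e.

    gates: list of per-sample gate vectors, each summing to 1.
    """
    n, num_experts = len(gates), len(gates[0])
    f = [0.0] * num_experts                                   # top-1 dispatch fractions
    for g in gates:
        top1 = max(range(num_experts), key=g.__getitem__)     # ties -> lowest index
        f[top1] += 1.0 / n
    P = [sum(g[e] for g in gates) / n for e in range(num_experts)]  # mean gate mass
    return num_experts * sum(fe * pe for fe, pe in zip(f, P))

uniform     = [[0.25] * 4 for _ in range(8)]
specialized = [[1.0 if e == b % 4 else 0.0 for e in range(4)] for b in range(8)]
collapsed   = [[1.0, 0.0, 0.0, 0.0] for _ in range(8)]

for name, gates in [("uniform", uniform), ("specialized", specialized), ("collapsed", collapsed)]:
    print(name, switch_balance_loss(gates))
```

Uniform gates and perfectly specialized one-hot gates both score 1, the balanced-routing value. Only the collapsed case is penalized, at 4. The loss cannot tell "no specialization yet" from "balanced specialization".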

Something has to break the symmetry. It cannot come from the loss. It has to come from initialization or from the training schedule itself.


Attempt 2: jitter the experts (expert_init_std)

The first thing we tried — and the most obvious one — was a small Gaussian perturbation on lora_up_weight at init:

self.lora_up_weight = torch.nn.Parameter(
    torch.zeros(E, out_dim, r) + std * torch.randn(E, out_dim, r)
)

This breaks $\Delta W = 0$ at step 0 by a small amount $\epsilon$ and, crucially, makes the experts numerically distinct from each other. The hope was that tiny initial differences would be amplified by the router's per-expert gradient, the experts would diverge, the balance loss would keep them all alive, and we'd be off to the races.

It didn't work. Bench 0424-484 showed the same plateau: experts still collapsed, gates still drifted to uniform under the balance pressure. Worse, it was misleading — measured expert divergence (e.g. cosine similarity between expert ups) looked healthy at init, so we kept attributing the failure to other things.

The diagnosis: jittering breaks expert symmetry, but the failure mode wasn't expert-side. It was router-side. The router sees softmax(router_weight @ pooled + bias). At init, router_weight ~ Kaiming(std=1/√r) is small, the bias is zero, and the per-expert score variance across the batch is dominated by content noise, not by which expert is which. Even with $\epsilon$-jittered experts, the per-sample expert scores are near-identical at step 0. So the gate is near-uniform, all experts get near-identical gradient, and they re-converge to the same direction faster than the jitter can amplify.
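Here is a small NumPy check of the "near-identical scores" point. It makes two simplifications of a real step 0: the gate is exactly uniform, and the upstream gradient $\partial\mathcal{L}/\partial C$ is held fixed. It uses the softmax-Jacobian router gradient $g_e(s_e - \sum_k g_k s_k)$ with $s_e = \langle \partial\mathcal{L}/\partial C, U_e\rangle$. `router_logit_grad` is a hypothetical helper.

```python
import numpy as np

def router_logit_grad(gate, up, upstream):
    """dL/dz_e = g_e * (s_e - sum_k g_k s_k), with s_e = <dL/dC, up[e]>."""
    scores = np.einsum("bod,eod->be", upstream, up)          # (B, E)
    return gate * (scores - (gate * scores).sum(axis=-1, keepdims=True))

rng = np.random.default_rng(0)
B, E, d_out, r = 8, 4, 16, 4
upstream = rng.standard_normal((B, d_out, r))    # dL/d combined, held fixed
noise = rng.standard_normal((E, d_out, r))       # one fixed jitter draw
gate = np.full((B, E), 1.0 / E)                  # uniform router at step 0

norms = {}
for eps in (0.0, 1e-3, 1e-2):
    up = eps * noise                             # zeros + eps * randn
    norms[eps] = np.linalg.norm(router_logit_grad(gate, up, upstream))
    print(f"eps={eps:g}  |dL/d logits| = {norms[eps]:.3e}")
```

The router gradient is exactly zero without jitter and scales linearly with $\epsilon$. An $\epsilon$-sized symmetry break therefore hands the router only an $\epsilon$-sized signal to amplify.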

In short: $\epsilon$-distinct experts plus a router that can't tell them apart still gives you uniform routing. We removed expert_init_std (2026-04-24).


Attempt 3: orthogonalize each expert via SVD slicing

The breakthrough was realizing the symmetry break has to live in the output subspace each expert writes to, not in their numerical values.

Build on OrthoLoRA's Cayley re-parameterization: factor each LoRA delta as $\Delta W = P \cdot S \cdot Q^\top$, where $P \in \mathbb{R}^{d_\text{out} \times r}$ and $Q \in \mathbb{R}^{d_\text{in} \times r}$ are tied to the top-$r$ left/right singular vectors of the frozen $W_0$, and $S$ is a learned $r \times r$ map that's Cayley-parameterized to stay orthogonal. The bases are frozen; the rotation is learned.

OrthoHydra makes one critical change: instead of every expert sharing the same P_basis, partition the top-$(E \cdot r)$ left singular vectors of $W_0$ into $E$ disjoint slices of $r$ columns each:

$$P_\text{bases}[e] \in \mathbb{R}^{d_\text{out}\times r}, \qquad P_\text{bases}[i]^{\top}\, P_\text{bases}[j] = 0 \quad \text{for}\ i \ne j$$

Each expert then rotates inside its own slice via its own Cayley rotation $R_p[e] \in \text{SO}(r)$. Because the rotation stays inside the slice, cross-expert orthogonality is preserved through all of training.
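The slicing can be sketched in NumPy as follows:

- Take the top-$(E \cdot r)$ left singular vectors of a random stand-in $W_0$.
- Cut them into $E$ slices.
- Rotate each slice with its own Cayley rotation $(I - A)(I + A)^{-1}$ of a skew-symmetric $A$.

The random generators stand in for learned ones. `cayley` and `disjoint_p_bases` are hypothetical names. Raising on narrow layers replaces the real fallback to a shared basis.

```python
import numpy as np

def cayley(A):
    """Cayley map: skew-symmetric A -> (I - A)(I + A)^{-1}, an element of SO(r)."""
    I = np.eye(A.shape[0])
    return (I - A) @ np.linalg.inv(I + A)

def disjoint_p_bases(W0, E, r):
    """Split the top E*r left singular vectors of W0 (d_out, d_in) into E slices."""
    if min(W0.shape) < E * r:
        raise ValueError("layer too narrow for a disjoint partition")
    U, _, _ = np.linalg.svd(W0, full_matrices=False)   # columns sorted by singular value
    return np.stack([U[:, e * r:(e + 1) * r] for e in range(E)])   # (E, d_out, r)

rng = np.random.default_rng(0)
d_out, d_in, E, r = 64, 48, 4, 8
W0 = rng.standard_normal((d_out, d_in))
P_bases = disjoint_p_bases(W0, E, r)

P_eff = []
for e in range(E):
    M = 0.1 * rng.standard_normal((r, r))
    R = cayley(M - M.T)                 # per-expert rotation in SO(r)
    P_eff.append(P_bases[e] @ R)        # rotation stays inside slice e

# Max |entry| of each cross-Gram block: ~1 on the diagonal (orthonormal
# columns), ~0 off the diagonal (disjoint slices)
gram = np.array([[np.abs(P_eff[i].T @ P_eff[j]).max() for j in range(E)] for i in range(E)])
print(np.round(gram, 6))

try:
    disjoint_p_bases(rng.standard_normal((16, 48)), E, r)   # 16 < E * r = 32
    narrow_fits = True
except ValueError:
    narrow_fits = False
```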

This is a structural deadlock breaker:

  • With a shared basis, every $P_\text{eff}[e]$ lives in the same rank-$r$ column span. The Gram matrix $P_\text{eff}[i]^\top P_\text{eff}[j]$ is itself an orthogonal matrix, so it cannot be zero. The router's per-expert score is near-identical at init, no gradient differentiates the experts, and training deadlocks.
  • With disjoint slices, $P_\text{eff}[i]^\top P_\text{eff}[j] = 0$ at init and forever. Each expert writes into a distinct output subspace. The router has signal to latch onto before any expert has been trained, because experts are already different in the only sense that matters: where their output lives.
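The two bullets reduce to the following. Here $P_\text{eff}[e] = P_\text{bases}[e]\,R_p[e]$, and in the shared case $P_\text{bases}[e] = P$ for every $e$. The derivation uses $P^\top P = I$ for orthonormal columns.

```latex
\begin{aligned}
\text{shared:}\quad & P_\text{eff}[i]^\top P_\text{eff}[j] = R_p[i]^\top P^\top P\, R_p[j] = R_p[i]^\top R_p[j] \in \text{SO}(r),
\qquad \big\|R_p[i]^\top R_p[j]\big\|_F = \sqrt{r} \\
\text{disjoint:}\quad & P_\text{eff}[i]^\top P_\text{eff}[j] = R_p[i]^\top \big(P_\text{bases}[i]^\top P_\text{bases}[j]\big) R_p[j] = 0 \qquad (i \ne j)
\end{aligned}
```

No choice of rotations can make the shared-basis Gram matrix vanish, since every orthogonal $r \times r$ matrix has Frobenius norm $\sqrt{r}$.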

The shared lora_down keeps the input-side Q_basis shared too — that matches "common features in shared down, specialization in per-expert up." Only the up-projection partitions.

Activated by use_ortho = true and use_hydra = true. This is now the configured default.


Fallback: expert_warmup_ratio for narrow layers

The disjoint partition needs $\min(d_\text{in}, d_\text{out}) \ge E \cdot r$. If the layer is too narrow (say, a final projection where $d_\text{out} < E \cdot r$), the partition can't fit, P_bases falls back to the legacy shared P_basis replicated $E$ times, and we're back in the deadlock.

For these narrow layers we keep a training-side schedule called expert-warmup:

# For the first warmup_ratio * max_train_steps steps:
#   - forward uses ALL experts via the learned gate (router still sees real diversity)
#   - backward zeros gradient on all but one randomly-chosen expert per module per step

Each expert is guaranteed to accumulate gradient from a distinct subset of training samples during the warmup window, even if the router is uniform, because we forcibly make their gradient masks disjoint over time. It's expensive (only one expert learns per step) but it breaks the deadlock without touching architecture.
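The schedule can be sketched framework-free as a per-step boolean mask over one module's experts. `warmup_grad_mask` is a hypothetical helper. In a trainer, the mask would be applied to the per-expert slices of the `lora_up_weight` gradient after backward. The torch.compile concerns are not modeled.

```python
import random

def warmup_grad_mask(step, max_train_steps, warmup_ratio, num_experts, rng):
    """Which experts keep their gradient at this step, for one module.

    The forward pass is untouched (all experts still mix through the gate);
    only the backward update is restricted during warmup.
    """
    if step < int(warmup_ratio * max_train_steps):
        active = rng.randrange(num_experts)          # one random expert learns
        return [e == active for e in range(num_experts)]
    return [True] * num_experts                      # after warmup: all experts learn

rng = random.Random(0)
masks = [warmup_grad_mask(s, 1000, 0.1, 4, rng) for s in range(1000)]
print(masks[0], masks[99], masks[100])
```

Over the first 100 steps each mask has exactly one `True`, so the experts' updates come from disjoint sets of steps even if the gate is uniform.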

Two implementation notes that mattered for torch.compile:

  • The active-expert pick is held in a per-module buffer whose value mutations don't trigger dynamo recompiles.
  • The "are we in warmup" gate is a Python bool that flips exactly twice per run (enter + leave). No per-step recompile storm.

For wide layers with disjoint slices, expert_warmup_ratio is redundant — the structural break already does the job. For narrow-layer fallback, it's load-bearing: do not run those layers with expert_warmup_ratio = 0.
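The per-layer rule above fits in a few lines. `plan_layer`, its return labels, and the example dimensions are hypothetical. It encodes only the width test and the "never run the fallback with zero warmup" rule from this section.

```python
def plan_layer(d_in, d_out, num_experts, rank, expert_warmup_ratio):
    """How one adapted Linear breaks expert symmetry (hypothetical helper)."""
    if min(d_in, d_out) >= num_experts * rank:
        return "disjoint_slices"                 # structural break; warmup redundant
    if expert_warmup_ratio <= 0:
        raise ValueError(
            f"{d_out}x{d_in} layer falls back to a shared P_basis; "
            "expert_warmup_ratio must be > 0 or its experts stay symmetric"
        )
    return "shared_basis+warmup"                 # schedule-side break is load-bearing

for d_in, d_out in [(3072, 3072), (3072, 64)]:
    print((d_in, d_out), plan_layer(d_in, d_out, num_experts=4, rank=32, expert_warmup_ratio=0.05))
```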


σ-conditional routing (borrowed from T-LoRA)

T-LoRA taught us something useful for the router. Diffusion timesteps are not interchangeable: at high noise ($\sigma \to 1$ in our convention) the model is reshaping coarse structure (composition, silhouette); at low noise ($\sigma \to 0$) it's refining fine detail (texture, edges). A LoRA that should specialize might want different experts at different noise levels. An "early-step composition expert" and a "late-step texture expert" are both reasonable specializations, and they're orthogonal to per-artist specialization.

The architectural change is small. Concatenate sinusoidal features of σ to the router input:

sigma_feat = sinusoidal(sigma, sigma_feature_dim)          # (B, F)
router_in  = cat([rms_pooled_lx, sigma_feat], dim=-1)      # (B, r + F)
gate       = softmax(router(router_in), dim=-1)            # (B, E)

router becomes a Linear(r + F, E) — a few extra columns. The σ-feature columns are zero-init, so step-0 behavior is identical to the no-σ router; σ-dependence only emerges if the gradient pushes those columns to non-zero values.
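This NumPy sketch checks the zero-init claim. The feature map is an assumption: Fourier features at frequencies $\pi, 2\pi, 4\pi, \dots$, which suit $\sigma \in [0, 1]$. The post only specifies 8 cos/sin pairs. `sigma_features` is a hypothetical name, and the pooled input is random stand-in data.

```python
import numpy as np

def sigma_features(sigma, dim):
    """dim/2 cos/sin pairs of sigma at frequencies pi, 2pi, 4pi, ... (assumed schedule)."""
    half = dim // 2
    freqs = np.pi * 2.0 ** np.arange(half)                          # (F/2,)
    args = np.asarray(sigma, dtype=float)[:, None] * freqs          # (B, F/2)
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)    # (B, F)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
B, r, F, E = 4, 8, 16, 4
pooled = np.abs(rng.standard_normal((B, r)))     # stand-in for RMS-pooled lx (non-negative)
sigma = np.array([0.05, 0.3, 0.7, 0.95])

# Router Linear(r + F, E): content columns ~ 1/sqrt(r), sigma columns zero-init
W_content = rng.standard_normal((E, r)) / np.sqrt(r)
W = np.concatenate([W_content, np.zeros((E, F))], axis=1)
b = np.zeros(E)

router_in = np.concatenate([pooled, sigma_features(sigma, F)], axis=-1)   # (B, r + F)
gate_sigma = softmax(router_in @ W.T + b)
gate_plain = softmax(pooled @ W_content.T + b)   # the no-sigma router
```

The two gates agree at init. Once training moves the σ columns off zero, samples with the same content pool but different σ can land on different experts.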

A previous design used a separate σ→bias MLP that added a bias to the gate logits. It didn't work, for the same reason expert_init_std didn't work: a bias-only σ path's gradient is dL/d_logits · d_sigma_feat, which vanishes whenever experts are undifferentiated (score_e ≈ const makes dL/d_logit_e ≈ 0). The chicken-and-egg problem: σ routing can't emerge until experts differ, but experts won't differ if the σ path can't influence which one to use.

Feeding σ into the router's input dodges this. The σ columns train alongside the content columns on the same chain rule, so σ routing emerges as soon as the router learns anything at all. With disjoint-slice ortho experts already differentiated at init, that's step 1.

Configuration in our hydralora_sigma.toml:

use_hydra            = true
use_ortho            = true
use_sigma_router     = true
sigma_feature_dim    = 16    # 8 cos/sin pairs
num_experts          = 4
balance_loss_weight  = 1e-5  # see below

Where we are now

Stacking these:

| Component | Role |
|---|---|
| Shared lora_down | Common features across experts |
| Per-expert lora_up | Specialization |
| Layer-local router | Per-sample, per-layer gate over experts |
| RMS over rank-$r$ | Pool that survives $\sqrt{L}$ cancellation in bf16 |
| Disjoint SVD slices | Structural symmetry break (wide layers) |
| expert_warmup | Schedule-side symmetry break (narrow-layer fallback) |
| σ-concat router | Lets the gate vary with denoising timestep |
| Switch balance loss | Keeps experts alive once they've started differentiating (1e-5, was 1e-2) |

The exit criteria we're targeting on the next retrain:

  • $\|\text{router.weight}\|$ at final step $> 1.5\times$ init,
  • median normalized gate entropy $\in [0.6,\ 0.95]$,
  • mean dominant-top1 $> 0.2$,
  • zero dead experts,
  • generation quality $\ge$ non-Hydra LoRA baseline,
  • ComfyUI live-routing visually matches CLI at strength 1.0.
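The gate-side criteria can be computed from gate vectors collected on an eval set. This pure-Python sketch makes two assumptions, as do its helper names:

- Normalized gate entropy is Shannon entropy divided by $\log E$, so uniform gates score 1.0, matching the 1.0000 reading above.
- A dead expert is one that is never the top-1 choice.

Mean dominant-top1 is left out because its exact definition isn't given here. The quality and ComfyUI criteria need image-level evaluation.

```python
import math
import statistics

def normalized_entropy(gate):
    """Shannon entropy of one gate vector over log(E): 1.0 = uniform, 0.0 = one-hot."""
    h = -sum(p * math.log(p) for p in gate if p > 0)
    return h / math.log(len(gate))

def dead_experts(gates):
    """Experts never chosen as top-1 (one possible definition of 'dead')."""
    E = len(gates[0])
    chosen = {max(range(E), key=g.__getitem__) for g in gates}
    return sorted(set(range(E)) - chosen)

def check_exit_criteria(gates, router_norm_init, router_norm_final):
    med_h = statistics.median([normalized_entropy(g) for g in gates])
    return {
        "router_norm_ratio_ok": router_norm_final > 1.5 * router_norm_init,
        "median_entropy": med_h,
        "entropy_ok": 0.6 <= med_h <= 0.95,
        "dead_experts": dead_experts(gates),
    }

uniform = [[0.25] * 4 for _ in range(8)]
specialized = [[0.7 if e == b % 4 else 0.1 for e in range(4)] for b in range(8)]
print(check_exit_criteria(uniform, 1.0, 1.0))
print(check_exit_criteria(specialized, 1.0, 2.0))
```

The uniform run reproduces the failure signature: entropy 1.0 and three experts never chosen. The soft-specialized run, with about 0.68 entropy and every expert used, passes the gate-side checks.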

Lessons

  1. The router is the hard part, not the experts. Three of the four failures (mean-pool, jitter, σ-bias-MLP) came from the router being unable to act on differences that did exist. Differentiating experts is necessary but not sufficient.

  2. Symmetry breaks must be structural where possible. Cold-start deadlocks resist gradient-based escape because the gradient itself is symmetric. Disjoint output subspaces break the symmetry at $t = 0$ in the only place the router can actually see: the output direction each expert writes to.

  3. Schedule-side fallbacks are cheap insurance. expert_warmup_ratio is overkill for layers where the structural break already works, but it's free to leave on, and it's the only thing keeping narrow-layer fallback alive.

  4. Borrow from your other adapters. T-LoRA's "the bottleneck behaves differently at different timesteps" insight ports directly to the router, with one change (σ in the input, not the bias) to dodge a vanishing-gradient trap that the static-weights setting doesn't have.

The MoE LoRA literature mostly assumes the router will figure itself out. For T2I diffusion at our parameter scale, with frozen base weights and a 4096-token sequence cancelling most of the natural router signal, it really won't. Every piece above earns its place.