fix(dreambooth): batch size mismatch with --with_prior_preservation in flux2 scripts#13307

Open
agarwalprakhar2511 wants to merge 1 commit into huggingface:main from agarwalprakhar2511:fix/dreambooth-prior-preservation-batch-mismatch
Conversation

@agarwalprakhar2511 agarwalprakhar2511 commented Mar 22, 2026

Summary

When `--with_prior_preservation` is enabled in the Flux2 DreamBooth LoRA training scripts, the prompt-embedding repeat logic double-counts the batch size, producing a shape mismatch against the latent tensor.

Root cause: `collate_fn` appends class prompts to the instance prompts list (doubling `len(prompts)`), but `prompt_embeds` has already been doubled earlier via `torch.cat([instance_embeds, class_embeds])`. Using the full `len(prompts)` as the repeat count therefore produces 4 embeddings for 2 latents at `batch_size=1`.

Fix: use `len(prompts) // 2` when `args.with_prior_preservation` is active, so the repeat count matches the number of unique prompt groups rather than the doubled collated list.

Applied to all three affected scripts:

  • examples/dreambooth/train_dreambooth_lora_flux2_klein.py
  • examples/dreambooth/train_dreambooth_lora_flux2.py
  • examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py
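
The collate/repeat interaction can be sketched in plain Python (the function name `repeat_count` and the example prompt lists are illustrative, not copied from the scripts):

```python
def repeat_count(prompts, with_prior_preservation):
    """Return how many times the pre-computed prompt embeddings should
    be repeated.

    With prior preservation on, collate_fn has already doubled
    `prompts` (instance prompts + class prompts), while `prompt_embeds`
    was doubled separately via torch.cat([instance, class]). Only half
    of len(prompts) therefore represents unique prompt groups.
    """
    if with_prior_preservation:
        return len(prompts) // 2  # undo the collate_fn doubling
    return len(prompts)


# batch_size=1 with prior preservation: collated list holds 2 prompts,
# but the embeddings should be repeated only once.
print(repeat_count(["instance prompt", "class prompt"], True))   # 1
print(repeat_count(["instance prompt"], False))                  # 1
```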

Test plan

Dimension trace verified for all four scenarios:

| batch_size | with_prior_preservation | len(prompts) | prompt_embeds pre-repeat | Repeat count (old → new) | Result (old → new) |
|---|---|---|---|---|---|
| 1 | True | 2 | `[2, seq, H]` | 2 → 1 | `[4, seq, H]` MISMATCH → `[2, seq, H]` OK |
| 2 | True | 4 | `[2, seq, H]` | 4 → 2 | `[8, seq, H]` MISMATCH → `[4, seq, H]` OK |
| 1 | False | 1 | `[1, seq, H]` | 1 → 1 | `[1, seq, H]` OK → `[1, seq, H]` OK |
| 2 | False | 2 | `[1, seq, H]` | 2 → 2 | `[2, seq, H]` OK → `[2, seq, H]` OK |
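
The table's dimension arithmetic can be checked with a small sketch (integer counts stand in for the tensor batch dimension; `traced_shapes` is a hypothetical helper, not code from the scripts):

```python
def traced_shapes(batch_size, with_prior):
    """Return (old_embed_batch, new_embed_batch, latent_batch)
    for one training step, counting only the batch dimension."""
    # len(prompts) after collate_fn: class prompts are appended when
    # prior preservation is on, doubling the list.
    n_prompts = batch_size * 2 if with_prior else batch_size
    # Pre-computed prompt_embeds batch: one instance embedding, plus
    # one class embedding when prior preservation is on (cat'ed).
    embeds = 2 if with_prior else 1
    # Latent batch is also doubled (instance + class images).
    latents = batch_size * 2 if with_prior else batch_size
    old = embeds * n_prompts                                  # repeat by len(prompts)
    new = embeds * (n_prompts // 2 if with_prior else n_prompts)  # fixed
    return old, new, latents


# batch_size=1 with prior preservation: old path yields 4 embeddings
# against 2 latents (mismatch); fixed path yields 2 (match).
print(traced_shapes(1, True))   # (4, 2, 2)
```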

Fixes #13292

fix(dreambooth): batch size mismatch with --with_prior_preservation in flux2 scripts

When `--with_prior_preservation` is enabled, `collate_fn` doubles the
`prompts` list (instance + class), but `prompt_embeds` was already
doubled via `torch.cat([instance, class])` during pre-computation.
Using `len(prompts)` as the repeat count produces 2x too many
embeddings, causing a shape mismatch against the latent batch.

Fix: use `len(prompts) // 2` when prior preservation is active, so the
repeat count matches the actual number of unique prompt groups rather
than the doubled collated list.

Applied to all three affected scripts:
- train_dreambooth_lora_flux2_klein.py
- train_dreambooth_lora_flux2.py
- train_dreambooth_lora_flux2_klein_img2img.py

Fixes huggingface#13292

Made-with: Cursor


Successfully merging this pull request may close these issues.

[Bug] train_dreambooth_lora_flux2_klein.py: batch size mismatch with --with_prior_preservation