From 75f700737cca7cfc297ba58bff7264065e3bee3b Mon Sep 17 00:00:00 2001
From: Prakhar Agarwal
Date: Sun, 22 Mar 2026 02:10:04 -0700
Subject: [PATCH] fix(dreambooth): batch size mismatch with
 --with_prior_preservation in flux2 scripts

When `--with_prior_preservation` is enabled, `collate_fn` doubles the
`prompts` list (instance + class), but `prompt_embeds` was already doubled
via `torch.cat([instance, class])` during pre-computation. Using
`len(prompts)` as the repeat count produces 2x too many embeddings, causing
a shape mismatch against the latent batch.

Fix: use `len(prompts) // 2` when prior preservation is active, so the
repeat count matches the actual number of unique prompt groups rather than
the doubled collated list.

Applied to all three affected scripts:
- train_dreambooth_lora_flux2_klein.py
- train_dreambooth_lora_flux2.py
- train_dreambooth_lora_flux2_klein_img2img.py

Fixes https://github.com/huggingface/diffusers/issues/13292

Made-with: Cursor
---
 examples/dreambooth/train_dreambooth_lora_flux2.py          | 5 ++++-
 examples/dreambooth/train_dreambooth_lora_flux2_klein.py    | 5 ++++-
 .../dreambooth/train_dreambooth_lora_flux2_klein_img2img.py | 5 ++++-
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py
index 24d098add017..1c860cdd5d41 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -1734,7 +1734,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     prompt_embeds = prompt_embeds_cache[step]
                     text_ids = text_ids_cache[step]
                 else:
-                    num_repeat_elements = len(prompts)
+                    # With prior preservation, prompt_embeds already contains [instance, class] embeddings
+                    # from the cat above, but collate_fn also doubles the prompts list. Use half the
+                    # prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
+                    num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
                     prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
                     text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
 
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py
index 268d0148e446..d90b7425912f 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py
@@ -1674,7 +1674,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     prompt_embeds = prompt_embeds_cache[step]
                     text_ids = text_ids_cache[step]
                 else:
-                    num_repeat_elements = len(prompts)
+                    # With prior preservation, prompt_embeds already contains [instance, class] embeddings
+                    # from the cat above, but collate_fn also doubles the prompts list. Use half the
+                    # prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
+                    num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
                     prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
                     text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
 
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py
index 0205f2e9e65f..bce2cca3b25c 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py
@@ -1618,7 +1618,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     prompt_embeds = prompt_embeds_cache[step]
                     text_ids = text_ids_cache[step]
                 else:
-                    num_repeat_elements = len(prompts)
+                    # With prior preservation, prompt_embeds already contains [instance, class] embeddings
+                    # from the cat above, but collate_fn also doubles the prompts list. Use half the
+                    # prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
+                    num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
                     prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
                     text_ids = text_ids.repeat(num_repeat_elements, 1, 1)
 