diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 24d098add017..1c860cdd5d41 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -1734,7 +1734,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompt_embeds = prompt_embeds_cache[step] text_ids = text_ids_cache[step] else: - num_repeat_elements = len(prompts) + # With prior preservation, prompt_embeds already contains [instance, class] embeddings + # from the cat above, but collate_fn also doubles the prompts list. Use half the + # prompts count to avoid a 2x over-repeat that produces more embeddings than latents. + num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) text_ids = text_ids.repeat(num_repeat_elements, 1, 1) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 268d0148e446..d90b7425912f 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -1674,7 +1674,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompt_embeds = prompt_embeds_cache[step] text_ids = text_ids_cache[step] else: - num_repeat_elements = len(prompts) + # With prior preservation, prompt_embeds already contains [instance, class] embeddings + # from the cat above, but collate_fn also doubles the prompts list. Use half the + # prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
+ num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) text_ids = text_ids.repeat(num_repeat_elements, 1, 1) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 0205f2e9e65f..bce2cca3b25c 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -1618,7 +1618,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompt_embeds = prompt_embeds_cache[step] text_ids = text_ids_cache[step] else: - num_repeat_elements = len(prompts) + # With prior preservation, prompt_embeds already contains [instance, class] embeddings + # from the cat above, but collate_fn also doubles the prompts list. Use half the + # prompts count to avoid a 2x over-repeat that produces more embeddings than latents. + num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) text_ids = text_ids.repeat(num_repeat_elements, 1, 1)