
FLUX.2-klein-base-9B full fine tune hangs before starting #511

@Joakim-L

Description


I'm trying to do a full fine-tune of Klein base 9B, but it gets stuck without any errors.

PyTorch version: 2.8.0+cu128

VRAM and GPU utilization never change from this after it gets stuck:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A40                     On  |   00000000:CE:00.0 Off |                    0 |
|  0%   44C    P0            115W /  300W |    9051MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A40                     On  |   00000000:D1:00.0 Off |                    0 |
|  0%   45C    P0            117W /  300W |    9471MiB /  46068MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A40                     On  |   00000000:D2:00.0 Off |                    0 |
|  0%   45C    P0            115W /  300W |   17087MiB /  46068MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

It logs up until this point and then nothing else happens:

[2026-02-15 15:24:45,698] [INFO] [engine.py:104:__init__] CONFIG: micro_batches=1 micro_batch_size=1
[2026-02-15 15:24:45,698] [INFO] [engine.py:145:__init__] is_pipe_partitioned= False is_grad_partitioned= False
[2026-02-15 15:24:45,739] [INFO] [config.py:687:__init__] Config mesh_device None world_size = 1
[2026-02-15 15:24:45,747] [INFO] [engine.py:145:__init__] is_pipe_partitioned= False is_grad_partitioned= False
[2026-02-15 15:24:45,838] [INFO] [config.py:687:__init__] Config mesh_device None world_size = 1
[2026-02-15 15:24:45,846] [INFO] [engine.py:145:__init__] is_pipe_partitioned= False is_grad_partitioned= False
[2026-02-15 15:24:45,883] [INFO] [engine.py:164:__init__] RANK=1 STAGE=1 LAYERS=13 [7, 20) STAGE_PARAMS=3103788544 (3103.789M) TOTAL_PARAMS=9078581248 (9078.581M) UNIQUE_PARAMS=9078581248 (9078.581M)
[2026-02-15 15:24:45,883] [INFO] [engine.py:164:__init__] RANK=2 STAGE=2 LAYERS=15 [20, 35) STAGE_PARAMS=3087535616 (3087.536M) TOTAL_PARAMS=9078581248 (9078.581M) UNIQUE_PARAMS=9078581248 (9078.581M)
[2026-02-15 15:24:45,883] [INFO] [engine.py:164:__init__] RANK=0 STAGE=0 LAYERS=7 [0, 7) STAGE_PARAMS=2887257088 (2887.257M) TOTAL_PARAMS=9078581248 (9078.581M) UNIQUE_PARAMS=9078581248 (9078.581M)
Global batch size = 1
[2026-02-15 15:24:45,930] [INFO] [logging.py:123:log_dist] [Rank 0] Using client callable to create basic optimizer
[2026-02-15 15:24:45,930] [INFO] [logging.py:123:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2026-02-15 15:24:45,931] [INFO] [logging.py:123:log_dist] [Rank 0] DeepSpeed Basic Optimizer = AdamW
[2026-02-15 15:24:45,931] [INFO] [logging.py:123:log_dist] [Rank 0] DeepSpeed Final Optimizer = AdamW
Global batch size = 1
Global batch size = 1
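The utilization pattern above (rank 0 idle at 0% while ranks 1 and 2 spin at 100%) is typical of one rank never entering a collective that the others are blocked on. One way to see where each rank is actually stuck, using only the standard library (a general debugging sketch, not something specific to this trainer), is to register a faulthandler signal handler near the top of the training script:

```python
import faulthandler
import signal

# Register SIGUSR1 so that `kill -USR1 <pid>` (one kill per rank; PIDs are
# visible in nvidia-smi) prints every thread's Python stack to stderr,
# showing exactly which call each rank is blocked in.
faulthandler.register(signal.SIGUSR1, all_threads=True)

# Or dump automatically if the process makes no progress for 5 minutes:
faulthandler.dump_traceback_later(300, exit=False)
```

Alternatively, `py-spy dump --pid <pid> --native` gives the same information without modifying the script, and also shows native frames inside NCCL.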

Config:

output_dir = '/workspace/output'
dataset = 'examples/dataset.toml'

micro_batch_size_per_gpu = 1
pipeline_stages = 3
gradient_accumulation_steps = 1
gradient_clipping = 1.0

save_every_n_epochs = 5
activation_checkpointing = true
reentrant_activation_checkpointing = true

partition_method = 'parameters'

save_dtype = 'bfloat16'
caching_batch_size = 4
steps_per_print = 1

[model]
type = 'flux2'
diffusion_model = '/workspace/flux-2-klein-base-9b.safetensors'
vae = '/workspace/split_files/vae/flux2-vae.safetensors'
text_encoders = [
    {path = '/workspace/split_files/text_encoders/qwen_3_8b.safetensors', type = 'flux2'}
]
dtype = 'bfloat16'
diffusion_model_dtype = 'bfloat16'
timestep_sample_method = 'logit_normal'
shift = 3

[optimizer]
type = 'adamw_optimi'
lr = 1e-6 
betas = [0.9, 0.99]
weight_decay = 0.01
eps = 1e-8
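If the stack dumps point at an NCCL collective, re-running with NCCL and torch.distributed debug logging enabled usually shows which rank and which collective is involved. These are standard PyTorch/NCCL environment variables, not options of this trainer; the launch command stays whatever you normally use:

```shell
# Log NCCL init and collective activity, plus detailed torch.distributed
# consistency checks, then relaunch training as usual.
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=INIT,COLL
export TORCH_DISTRIBUTED_DEBUG=DETAIL
```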
