The files in `configs/` define reproducible experiment presets using
`ml_collections.ConfigDict`. Each module exposes a single `get_config`
function that returns a nested configuration consumed by `main.py` and the
trainers under `trainers/`.
Pass the Python file and preset identifier when launching a job:
```bash
python main.py \
  --config=configs/dit_imagenet.py:imagenet_256-XL_2 \
  --bucket=$GCS_BUCKET \
  --workdir=exp_name
```

The suffix `imagenet_256-XL_2` is split into an input dataset preset
(`imagenet_256`) and a model preset (`XL_2`). You can override any field from
the command line, e.g. `--config.interface.train_time_dist_type=uniform`.
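The `:<preset>` suffix and the dotted overrides are standard `ml_collections.config_flags` behavior; a plausible sketch of how `main.py` consumes the flag (the actual wiring may differ):

```python
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS
# Parses --config=<path>:<preset> and forwards the preset string to get_config.
config_flags.DEFINE_config_file("config")


def main(argv):
    config = FLAGS.config  # dotted --config.* overrides are already applied
    print(config)


if __name__ == "__main__":
    app.run(main)
```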
`common_specs.py` collects reusable dictionaries describing ImageNet data
locations, encoder choices, and network sizes. Config modules clone these
entries into `ConfigDict` instances, so edits made in `common_specs.py`
immediately propagate to all presets (see the cloning sketch after this list):
- `_imagenet_data_presets`: file-system paths, image resolution, batch size, latent vs. raw pipeline flags, and cached statistics per variant.
- `_imagenet_encoder_presets` / `_imagenet_rae_encoder_presets`: encoder class names plus keyword arguments (Stability VAE, RGB passthrough, or RAE).
- `_dit_network_presets` and `_ddt_network_presets`: transformer widths, depths, and attention heads for DiT, Lightning DiT, and Lightning DDT models.
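A minimal sketch of the cloning pattern, assuming the import path and preset key shown:

```python
import ml_collections

from configs import common_specs  # assumed import path

# Clone a shared spec into a per-preset ConfigDict. The copy is
# independent at runtime, but because it is built at import time, any
# edit to common_specs.py is picked up the next time a preset loads.
data = ml_collections.ConfigDict(
    common_specs._imagenet_data_presets["imagenet_256"]  # key assumed
)
data.batch_size = 256  # hypothetical per-preset override
```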
- `dit_imagenet.py`: Baseline DiT training loop. Enables logit-normal time sampling, Heun-32 evaluation, EMA tracking, and optional evaluation jobs. The preset controls sharding (replicate by default) and defines standard logging intervals.
- `lightning_dit_imagenet.py`: Swaps in the Lightning DiT architecture with continuous time embeddings, RoPE attention, SwiGLU MLPs, and a slightly higher learning rate (2e-4). Keeps the DiT sampler/eval defaults.
- `lightning_ddt_imagenet.py`: Targets the diffusion-decoder transformer (DDT). Shares the Lightning tweaks but uses `_ddt_network_presets`, provides explicit sampling range overrides, and reduces EMA decay to `0.9995`.
- `dit_imagenet_repa.py`: Wraps the DiT preset with REPA alignment. Adds a DINOv2 feature detector, REPA loss coefficients, Euler-Maruyama-250 sampling, and switches the sharding strategy to FSDP along the `data` axis.
- `mf_imagenet.py`: Extends the DiT preset for MeanFlow training. Updates the interface to MeanFlow-specific guidance parameters, enables `take_dt` on the network, swaps in the Euler-Jump sampler, and keeps EMA/visualization in sync with the base preset.
- `rae_imagenet.py`: Configures an RAE autoencoder with the Lightning DDT backbone. Uses the RAE encoder preset, increases sampler steps to 50, enables reconstruction visualization, and leaves `config.pretrained_ckpt` empty for caller-provided checkpoints.
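The wrapper presets follow a simple extension pattern; a hypothetical sketch in the style of `dit_imagenet_repa.py` (the `repa` field names here are assumptions):

```python
import ml_collections

from configs import dit_imagenet  # reuse the baseline DiT preset


def get_config(preset: str) -> ml_collections.ConfigDict:
    config = dit_imagenet.get_config(preset)
    # Hang a new block off the root rather than editing the base module.
    config.repa = ml_collections.ConfigDict()
    config.repa.loss_coeff = 0.5  # illustrative value
    config.repa.detector = ml_collections.ConfigDict({"name": "dinov2"})
    return config
```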
Every preset returns a `ConfigDict` with consistent top-level sections. The
exact knobs vary per experiment, but the following structure is shared across
the shipped configs:
- `trainer` / `exp_name` / `project_name` / `seed`: global metadata used by `main.py` and logging utilities to find the correct trainer and namespace checkpoints.
- `data`: dataloader paths, image size, batch size, and latent/raw toggles for choosing the pixel dataset or pre-extracted SD-VAE latents. Most fields come directly from `_imagenet_data_presets`, so new presets usually just point to a different preset key.
- `encoder`: choice of latent encoder plus keyword arguments (e.g. RGB encoder, Stability VAE statistics, or RAE checkpoint paths). Values are cloned from the encoder presets in `common_specs.py`.
- `network`: model hyperparameters such as patch size, hidden size, depth, attention heads, and feature flags (Lightning-specific options, MeanFlow switches, etc.). Presets swap between `_dit_network_presets` and `_ddt_network_presets` depending on the architecture.
- `interface`: diffusion/flow formulation parameters. This includes the interface class (`sit`, `mean_flow`, etc.), time/noise distribution settings, and any auxiliary guidance knobs. Specialized configs (REPA, MeanFlow) extend this block with extra keys.
- `optimizer` / `learning_rate`: optimizer type and schedule metadata. These keys feed directly into `utils.initialize` helpers.
- `sampler`: evaluation sampler selection (`heun`, `euler`, `euler-maruyama`, `euler_jump`) and step counts. Optional `sampling_time_kwargs` tune the time grid.
- `ema`: exponential moving average configuration (`ema_class`, `decay`).
- `checkpoint`: options for `utils.checkpoint`, including save cadence, retention count, and async flags.
- `visualize`: toggles for on-host sampling, guidance scale, and the reconstruction flag for latent encoders.
- `eval`: evaluation loop parameters: detector choice, guidance sweeps, sample counts, and cadence.
- `sharding`: mesh definition and partitioning rules. DiT defaults to replication, while REPA and Lightning variants demonstrate the FSDP template.
- Optional blocks such as `repa` or `repa.detector`, and experiment-specific fields like `pretrained_ckpt`, can be added as new `ConfigDict`s hanging off the root.
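For concreteness, two of these sections might look as follows; the knob names are assumptions that mirror the defaults described above:

```python
import ml_collections

config = ml_collections.ConfigDict()
# `sampler_class` / `ema_class` follow the *_class string convention noted below.
config.sampler = ml_collections.ConfigDict(
    {"sampler_class": "heun", "num_steps": 32}  # the Heun-32 eval default
)
config.ema = ml_collections.ConfigDict(
    {"ema_class": "ema", "decay": 0.9999}  # DDT presets lower decay to 0.9995
)
```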
Every `*_class` attribute must be defined as a string, which is resolved by
name lookup in `initialize.py` (see the sketch below). We plan to replace this
string-based mechanism with lazy-init objects for better scalability.
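A sketch of what that string lookup could look like; the registry contents and class names here are stand-ins, not the real contents of `initialize.py`:

```python
class SimpleEma:  # stand-in class
    pass


class HeunSampler:  # stand-in class
    pass


# Maps the *_class strings from the config to concrete classes.
_CLASS_REGISTRY = {
    "ema": SimpleEma,
    "heun": HeunSampler,
}


def resolve_class(name: str):
    # e.g. resolve_class(config.ema.ema_class) -> SimpleEma
    return _CLASS_REGISTRY[name]
```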
- Add new dataset or model entries to `common_specs.py`.
- Create a new `get_config` wrapper (or modify an existing one) to reference those presets.
- Document the invocation pattern so it can be launched via `python main.py --config=configs/<file>.py:<preset>`.
- Keep overrides CLI-friendly: every new leaf should be reachable through a dotted flag (see the end-to-end sketch below).
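Putting the checklist together, a hypothetical new preset module `configs/my_experiment.py` (all keys and values here are assumptions) could look like:

```python
import ml_collections

from configs import common_specs


def get_config(preset: str) -> ml_collections.ConfigDict:
    data_name, network_name = preset.split("-", 1)
    config = ml_collections.ConfigDict()
    config.data = ml_collections.ConfigDict(
        common_specs._imagenet_data_presets[data_name]
    )
    config.network = ml_collections.ConfigDict(
        common_specs._dit_network_presets[network_name]
    )
    config.my_new_knob = 1.0  # overridable as --config.my_new_knob=2.0
    return config


# Launch with:
#   python main.py --config=configs/my_experiment.py:imagenet_256-XL_2
```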