diff --git a/examples/puzzletron/BYPASS.md b/examples/puzzletron/BYPASS.md new file mode 100644 index 0000000000..8a4bc9d249 --- /dev/null +++ b/examples/puzzletron/BYPASS.md @@ -0,0 +1,159 @@ +# Bypass Distillation (Blockwise Local Distillation) + +Bypass distillation (also called **Blockwise Local Distillation / BLD**) is an optional pipeline +stage that trains alternative transformer block configurations using per-block knowledge +distillation from the teacher model. It significantly improves the quality of aggressively +compressed models by producing better "puzzle pieces" for the MIP solver. + +## When to use bypass + +Bypass is most beneficial whenever the pruned block structure deviates significantly from the +teacher — either because the weight-initialisation heuristic is too coarse, or because one +sub-block must compensate for something the other no longer provides. Specifically, use bypass +when: + +- **KV head reduction (any amount)**: the `AverageKV` initialisation is a naive starting point + that averages existing KV heads together. The resulting weights are only a crude approximation + of the teacher's attention, and bypass distillation is needed to repair the quality loss. This + applies even to moderate reductions (e.g., 8 → 4 heads). +- **Attention removed (`no_op: true`)**: removing an entire attention block leaves the co-located + FFN doing all the work for that block. Bypass trains the FFN to compensate for the missing + attention and recover the representational capacity. +- **FFN removed (`no_op: true`)**: similarly, when an FFN block is removed, bypass trains the + remaining attention to compensate. +- **Extreme FFN / MoE compression**: when the target `intermediate_size` is reduced to roughly + a quarter of the teacher width or less, or the number of MoE experts is reduced by half or more, + simple weight truncation / expert selection leaves the block far from a good solution and bypass + significantly improves quality.
For example, on Llama-3.1-8B (`intermediate_size=14336`), + bypass is strongly recommended for sizes ≤ 3584. + +## Time cost + +Bypass distillation is a full training loop. Plan for several hours per configuration when using +~1B training tokens on H100 GPUs. Total time scales with +`len(bypass.configs) × training_tokens`. This is comparable to lightweight fine-tuning. + +## Sequential execution + +Each entry in `bypass.configs` trains **sequentially** (one config at a time). There is no +parallelism across configurations. Distribute jobs across different runs if time is a +constraint. + +## Enabling bypass + +In your concrete model YAML, uncomment the bypass line: + +```yaml +defaults: + - Llama-3_1-8B + - bypass: defaults # enables bypass distillation + - _self_ +``` + +A shared `bypass/defaults.yaml` is located at +[`configs/bypass/defaults.yaml`](configs/bypass/defaults.yaml). It is used by all models. +Adjust `training.training_tokens` (the default of `1e+5` tokens is only a sanity-check budget; +set to `1e+9` for production runs) and the `auto_configs` or `configs` settings to match your +compression targets. + +## Decoupled vs. coupled BLD + +**Decoupled BLD** trains only one sub-block type at a time while keeping the other frozen: + +| `keys_to_learn` | What is trained | +|---|---| +| `subblock_ffn` | FFN weights only (attention frozen) | +| `subblock_attention` | Attention weights only (FFN frozen) | +| `subblock_mamba` | Mamba SSM weights (hybrid models, e.g. NemotronH) | +| `entire_block` | Full transformer block (coupled BLD) | + +**Coupled BLD** (`keys_to_learn: entire_block`) trains the whole block end-to-end and captures +interactions between attention and FFN. The main cost is combinatorial: if you have N FFN sizes +and M attention sizes in your replacement library, coupled BLD requires N × M training runs +instead of N + M for decoupled. Decoupled BLD is therefore the default and usually sufficient.
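The N + M versus N × M trade-off above can be sanity-checked with a few lines of Python. The library sizes below are illustrative values chosen for this sketch, not pipeline defaults:

```python
# Illustrative comparison of BLD training-run counts for a replacement library.
# The candidate lists here are hypothetical examples, not shipped defaults.
ffn_sizes = [1792, 3584, 7168]  # N candidate FFN intermediate sizes
kv_head_counts = [4, 2]         # M candidate KV-head counts

# Decoupled BLD: each sub-block variant is trained once, independently.
decoupled_runs = len(ffn_sizes) + len(kv_head_counts)  # N + M

# Coupled BLD: every (FFN size, KV-head count) pair is trained together.
coupled_runs = len(ffn_sizes) * len(kv_head_counts)    # N x M

print(f"decoupled: {decoupled_runs} runs, coupled: {coupled_runs} runs")
# decoupled: 5 runs, coupled: 6 runs
```

The gap widens quickly: with 6 FFN sizes and 4 KV-head counts, decoupled BLD needs 10 runs while coupled BLD needs 24, which is why decoupled is the default.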
+## Training multiple configurations + +Use `bypass.configs` to train multiple block configurations sequentially: + +```yaml +bypass: + training: + training_tokens: 1e+9 # ~1B tokens per config + configs: + - model_config_overrides: + ffn: + - intermediate_size: 1792 # aggressive — bypass strongly recommended + attention: + - num_key_value_heads: null + keys_to_learn: subblock_ffn + - model_config_overrides: + ffn: + - intermediate_size: 3584 + attention: + - num_key_value_heads: null + keys_to_learn: subblock_ffn +``` + +> **Note:** Always include `num_key_value_heads: null` under `attention:` even when not +> changing KV heads. Omitting it when `no_op: true` is set on another field can cause +> a config parsing error. + +Trained checkpoints are automatically symlinked into `$PUZZLE_DIR/ckpts/`, where the replacement +library builder picks them up in the next pipeline stage. + +## Auto-generating configs from the pruning search space + +Instead of listing each config manually, use `bypass.auto_configs` to generate configs +automatically from the pruning search space. For example, `auto_configs.attn: true` trains +one attention-only bypass per KV-head reduction specified in `pruning.n_heads_in_group_list` +(the shipped `bypass/defaults.yaml` enables only `ffn`): + +```yaml +bypass: + auto_configs: + attn: true # one subblock_attention config per pruned kv-head count + ffn: false # set true: one subblock_ffn config per size in pruning.intermediate_size_list + blk: false # set true: cartesian product (FFN size × kv-head count), entire_block BLD + training: + training_tokens: 1e+9 +``` + +Teacher-size subblocks are automatically excluded (no redundant training). For `blk`, all +combinations where **both** FFN and attention are at teacher values are skipped. + +All three flags can be combined. Order of generated configs: FFN → attn → blk. + +## Attention no-op + FFN-only bypass + +A common aggressive compression pattern removes entire attention blocks (`no_op: true`) and +trains only the FFN in those blocks.
Example config: + +```yaml +configs: + - model_config_overrides: + ffn: + - intermediate_size: 12288 + attention: + - num_key_value_heads: null + no_op: true + keys_to_learn: subblock_ffn +``` + +When attention is removed, only the FFN parameters are trained. The bypass code automatically +skips attention-related weights (including model-specific ones such as Qwen3's `q_norm`/`k_norm`) +during student weight initialisation. + +## Weights & Biases logging + +Enable W&B to track per-block distillation loss and validation metrics: + +```yaml +bypass: + wandb_log: true + wandb: + project: my-puzzletron-project + entity: my-org +``` + +W&B logs iteration number, token count, learning rate, and per-block loss at each log interval. +If `wandb` is not installed, logging is silently disabled. diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index a7e3aedfc1..5ab73cca95 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -57,7 +57,7 @@ hf auth login We can also set the target size of the resulting model using `num_params = 7_000_000_000`. This will be used as an upper bound for the number of parameters of the model. -3. Run the puzzletron pipeline. +3. Run the puzzletron pipeline. Bypass distillation is **disabled by default**; to enable it see [BYPASS.md](./BYPASS.md). ```bash torchrun --nproc_per_node 2 examples/puzzletron/main.py --config examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/llama-3_1-8B_pruneffn_memory.yaml 2>&1 | tee ./log.txt | grep "Puzzletron Progress" @@ -130,6 +130,14 @@ hf auth login 30% GPU memory reduction leads to nearly 5% regression in token_accuracy_top_10 metric (0.898 / 0.942). +## Bypass Distillation (Local Knowledge Distillation) + +Bypass distillation (Blockwise Local Distillation / BLD) is an **optional** pipeline stage that +trains compressed block configurations via per-block knowledge distillation from the teacher. 
+It is most beneficial for aggressive compression targets (heavy FFN pruning, KV head reduction, +or full attention removal). See **[BYPASS.md](./BYPASS.md)** for the full guide, including +decoupled vs. coupled BLD, auto-config generation, and a worked example with attention no-op blocks. + ## Re-run MIP Search with different constraints If you want to try different constraints without re-running the expensive pruning and scoring steps, use the `--mip-only` flag. diff --git a/examples/puzzletron/configs/bypass/defaults.yaml b/examples/puzzletron/configs/bypass/defaults.yaml new file mode 100644 index 0000000000..8c20bfafc9 --- /dev/null +++ b/examples/puzzletron/configs/bypass/defaults.yaml @@ -0,0 +1,164 @@ +# @package bypass +# Bypass Distillation Configuration (shared across all models) +# This config defines parameters for blockwise local distillation (BLD), +# which trains alternative transformer block configurations using per-block +# knowledge distillation from a teacher model. + +# Runtime Configuration +dtype: "bf16" # Model precision: bf16 for efficiency, fp32 for stability +seed: 42 # Random seed for reproducibility + +# Experiment Tracking +experiment_id: # Unique identifier for this experiment. Will be dynamically set +experiment_dir: # Directory for this experiment. 
Will be dynamically set +iter_num: 1 # Current iteration number +step_num: 1 # Current step number within iteration +token_count: 0 # Token count tracker (auto-updated during training) + +# Data Configuration +data: + data_column: "messages" + block_size: 512 # Sequence length (tokens per sample) + bos_rate: 0.5 + fim_rate: 0 + fim_spm_rate: 0 + source_datasets_to_discard: [] + load_from_disk: true # Load preprocessed data from disk or from stream + keep_in_memory: false + val_dataset_name: valid + max_eval_samples: 4 + eval_samples_per_process: null # Samples per GPU during distributed eval (auto if null) + shuffle_train_data_seed: ${random_int:0,9999} # Seed for shuffling train data + +# Training Configuration +training: + learning_rate: 1e-4 # Initial learning rate + training_tokens: 1e+5 # Total training tokens (1e+5 = sanity check; set 1e+9 for production) + micro_batch_size: 2 + val_micro_batch_size: 1 + warmup_ratio: 0.05 + warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.warmup_ratio}} + min_lr_factor: 1e-5 + grad_accumulation_steps: 1 + skip_first_batches: 0 + weight_decay: 0.1 + decay_lr: true + beta1: 0.9 + beta2: 0.95 + use_grad_scaling: false + grad_clip: 1.0 + grad_clip_type: norm + clipping_count: 0 + log_interval: 5 + eval_interval: 5 + +# Model Loading Configuration +resume_checkpoint_path: null # Path to resume training from checkpoint +find_last_ckpt_for_resume: true # Auto-resume by finding the last checkpoint +parameter_count: null +init_checkpoint_path: null # Path to initialize weights from + +model: + student_weights_dtype: "bf16" # Student model weight precision + + model_overrides: + delete_old_checkpoints: true # Clean up old checkpoints to save disk space + save_interval_seconds: 12900 # Save checkpoint every ~3.6 hours + save_interval: 1e+9 # Save checkpoint every 1B steps (effectively disabled) + save_checkpoint_when_done: true # Save final checkpoint when training completes + + #
Architecture modifications for student model (used when not overridden by configs: or auto_configs) + model_config_overrides: + ffn: + - intermediate_size: + no_op: # Disable FFN entirely (true/false) + attention: + - num_key_value_heads: # Number of kv-heads (for GQA) + no_op: # Disable attention entirely (true/false) + +# Model Factory Configuration - Controls student model creation and initialization +model_factory: + factory: bypass_factory_fn # Unified factory supporting all layer types + block_loss_func: normalized_mse_loss # Loss function for comparing teacher/student blocks + gqa_init_mode: AverageKV # How to initialize K/V heads in GQA + mlp_init_mode: Truncate # MLP initialization + mlp_init_config: # Configuration for MLP initialization (if needed) + activations_log_dir: null # Directory with activation statistics + linear_init_mode: FromTeacher # How to initialize linear layers + submodule_for_loss_calculation: null # Specific submodule for loss calc + keys_to_learn: null # What parameters to train. Computed dynamically. + +# Validation Configuration +disable_initial_validate: false +validate_teacher_model: true +validate_student_model: true +disable_validation: false +best_val_loss: 1e+9 + +# Performance Optimization +compile: false # Use PyTorch compilation +disable_fa2: false # Disable Flash Attention 2 (false = use FA2 if available) +teacher_model_load_on_cpu: false + +# Checkpoint Management +save_checkpoint_before_training: false # Save initial checkpoint before training +disable_checkpoint_save: false # Disable all checkpoint saving +save_best_ckpt: false # Save checkpoint when validation improves (disabled by default) +kill_after_first_save: false # Exit after first checkpoint save (for testing) +realize_best_or_latest: "best" + +wandb_log: false +wandb: + project: + entity: + +# Auto-generate bypass configurations from the pruning search space. 
+# Reads pruning.intermediate_size_list (for ffn/blk) and pruning.n_heads_in_group_list (for attn/blk). +# Teacher-size subblocks are excluded automatically (no redundant training). +# +# attn: true — one subblock_attention config per pruned kv-head count (teacher kv excluded) +# ffn: true — one subblock_ffn config per pruned FFN size (teacher size excluded) +# blk: true — cartesian product of ALL FFN sizes × ALL kv-head counts (entire_block); +# only the single pair (teacher_ffn, teacher_kv) is skipped +# +# Set to false or omit any key to disable that variant. +auto_configs: + attn: false # set true to train attention-only bypass for each KV-head reduction in pruning.n_heads_in_group_list + ffn: true # train FFN-only bypass for each size in pruning.intermediate_size_list + blk: false # set true to train coupled (full-block) bypass for all (FFN size × KV heads) pairs + +# Explicit bypass configurations (alternative to auto_configs). +# Each entry overrides model.model_config_overrides and optionally model_factory.keys_to_learn. +# If both auto_configs and configs are set, configs takes precedence.
+# Uncomment and adapt one of the examples below: +# +# # Example: two FFN sizes, attention unchanged (decoupled FFN-only BLD) +# configs: +# - model_config_overrides: +# ffn: +# - intermediate_size: 3072 +# - intermediate_size: 5888 +# attention: +# - num_key_value_heads: null +# keys_to_learn: subblock_ffn +# +# # Example: two KV-head reductions, FFN unchanged (decoupled attention-only BLD) +# configs: +# - model_config_overrides: +# ffn: +# - intermediate_size: null +# attention: +# - num_key_value_heads: 4 +# - num_key_value_heads: 2 +# keys_to_learn: subblock_attention +# +# # Example: two coupled block configs (FFN + attention together, entire_block BLD) +# configs: +# - model_config_overrides: +# ffn: +# - intermediate_size: 3072 +# - intermediate_size: 5888 +# attention: +# - num_key_value_heads: 4 +# - num_key_value_heads: 2 +# keys_to_learn: entire_block diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/bypass/defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/bypass/defaults.yaml new file mode 120000 index 0000000000..5614740c97 --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/bypass/defaults.yaml @@ -0,0 +1 @@ +../../bypass/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml index b48f1de78c..395c4eeb2c 100644 --- a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml @@ -2,7 +2,7 @@ defaults: - pruning: ffn_pruning - scoring: ../validate_solutions_defaults - realize_model: ../validate_solutions_defaults - - bypass: + - bypass: # uncomment and set to "defaults" to enable bypass distillation - override hydra/hydra_logging: disabled - _self_ diff --git 
a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml deleted file mode 100644 index b80faea5f5..0000000000 --- a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml +++ /dev/null @@ -1,18 +0,0 @@ -model_dtype: torch.bfloat16 # dtype to cast the model for validate_model -autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model -block_size: 8192 -bos_rate: 0.5 -data_column: messages -val_dataset_name: valid -shuffle_seed: 81436 -seed: 42 -fim_rate: 0 -fim_spm_rate: 0 -source_datasets_to_discard: -varlen: false -write_results: false -calc_losses_on_cpu: false -activations_log_dir: -model_name_or_path: -load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} - diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml new file mode 120000 index 0000000000..8114ca754c --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml @@ -0,0 +1 @@ +../validate_model_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml deleted file mode 100644 index ab8c892182..0000000000 --- a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml +++ /dev/null @@ -1,11 +0,0 @@ -defaults: - - /validate_model_defaults - - _self_ - -solutions_to_validate: -skip_validation: false -save_models: false -bigger_is_better: false -sort_solutions_by: -calculate_full_score_ablations: false - diff --git 
a/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml new file mode 120000 index 0000000000..695e65bf9c --- /dev/null +++ b/examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml @@ -0,0 +1 @@ +../validate_solutions_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml index 21903db162..7c087f07cc 100644 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml @@ -2,7 +2,7 @@ defaults: - pruning: ffn_pruning - scoring: ../validate_solutions_defaults - realize_model: ../validate_solutions_defaults - - bypass: + - bypass: defaults # bypass distillation enabled here; comment out "defaults" to disable - override hydra/hydra_logging: disabled - _self_ @@ -42,7 +42,7 @@ scoring: teacher_dir: ${to_path:${teacher_dir}} output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation - eval_samples: 128 + eval_samples: 4 micro_batch_size: 1 seed: 42 shuffle_seed: 444 diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml new file mode 120000 index 0000000000..5614740c97 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml @@ -0,0 +1 @@ +../../bypass/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml deleted file mode 100644 index e05e775bee..0000000000 ---
a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml +++ /dev/null @@ -1,33 +0,0 @@ -defaults: - - /validate_model_defaults - -descriptor: ${descriptor} -model_name_or_path: ${teacher_dir} -experiment_id: ${pruning.eval_samples}samples_diverse_mini -activations_log_dir: ??? -activation_hooks_kwargs: ??? - -# Data: -eval_samples: 1000 # default is 10000 -micro_batch_size: 4 -dataset_path: ${dataset_path} -val_dataset_name: train - -# Prune ckpts -pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} - -## FFN pruning -ffn_list: -mlp_init_mode: "Truncate" # PruneByActivationsLog - -## KV-heads pruning -n_heads_in_group_list: -gqa_init_mode: "AverageKV" - -## Hidden dimension pruning -hidden_size_list: -hidden_size_init_mode: "PruneByChannelRanking" -linear_init_mode: "FromTeacher" - -mlp_init_config_yaml: - activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 120000 index 0000000000..3d599095c3 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1 @@ +../../pruning/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml deleted file mode 100644 index ce1749d969..0000000000 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml +++ /dev/null @@ -1,17 +0,0 @@ -model_dtype: torch.bfloat16 # dtype to cast the model for validate_model -autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model -block_size: 8192 -bos_rate: 0.5 -data_column: messages -val_dataset_name: valid -shuffle_seed: 81436 -seed: 42 -fim_rate: 0 -fim_spm_rate: 0 
-source_datasets_to_discard: -varlen: false -write_results: false -calc_losses_on_cpu: false -activations_log_dir: -model_name_or_path: -load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml new file mode 120000 index 0000000000..8114ca754c --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1 @@ +../validate_model_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml deleted file mode 100644 index ec13902379..0000000000 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml +++ /dev/null @@ -1,10 +0,0 @@ -defaults: - - /validate_model_defaults - - _self_ - -solutions_to_validate: -skip_validation: false -save_models: false -bigger_is_better: false -sort_solutions_by: -calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml new file mode 120000 index 0000000000..695e65bf9c --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1 @@ +../validate_solutions_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml index 7de281e788..f5434c0588 100644 --- a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml +++ 
b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml @@ -2,7 +2,7 @@ defaults: - pruning: ffn_pruning - scoring: ../validate_solutions_defaults - realize_model: ../validate_solutions_defaults - - bypass: + - bypass: # uncomment and set to "defaults" to enable bypass distillation - override hydra/hydra_logging: disabled - _self_ diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/bypass/defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/bypass/defaults.yaml new file mode 120000 index 0000000000..5614740c97 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/bypass/defaults.yaml @@ -0,0 +1 @@ +../../bypass/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml deleted file mode 100644 index e05e775bee..0000000000 --- a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml +++ /dev/null @@ -1,33 +0,0 @@ -defaults: - - /validate_model_defaults - -descriptor: ${descriptor} -model_name_or_path: ${teacher_dir} -experiment_id: ${pruning.eval_samples}samples_diverse_mini -activations_log_dir: ??? -activation_hooks_kwargs: ??? 
- -# Data: -eval_samples: 1000 # default is 10000 -micro_batch_size: 4 -dataset_path: ${dataset_path} -val_dataset_name: train - -# Prune ckpts -pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} - -## FFN pruning -ffn_list: -mlp_init_mode: "Truncate" # PruneByActivationsLog - -## KV-heads pruning -n_heads_in_group_list: -gqa_init_mode: "AverageKV" - -## Hidden dimension pruning -hidden_size_list: -hidden_size_init_mode: "PruneByChannelRanking" -linear_init_mode: "FromTeacher" - -mlp_init_config_yaml: - activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 120000 index 0000000000..3d599095c3 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1 @@ +../../pruning/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml deleted file mode 100644 index b80faea5f5..0000000000 --- a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml +++ /dev/null @@ -1,18 +0,0 @@ -model_dtype: torch.bfloat16 # dtype to cast the model for validate_model -autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model -block_size: 8192 -bos_rate: 0.5 -data_column: messages -val_dataset_name: valid -shuffle_seed: 81436 -seed: 42 -fim_rate: 0 -fim_spm_rate: 0 -source_datasets_to_discard: -varlen: false -write_results: false -calc_losses_on_cpu: false -activations_log_dir: -model_name_or_path: -load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} - diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml 
b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml new file mode 120000 index 0000000000..8114ca754c --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1 @@ +../validate_model_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml deleted file mode 100644 index ab8c892182..0000000000 --- a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml +++ /dev/null @@ -1,11 +0,0 @@ -defaults: - - /validate_model_defaults - - _self_ - -solutions_to_validate: -skip_validation: false -save_models: false -bigger_is_better: false -sort_solutions_by: -calculate_full_score_ablations: false - diff --git a/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml new file mode 120000 index 0000000000..695e65bf9c --- /dev/null +++ b/examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1 @@ +../validate_solutions_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml index 18213f9b7a..9eb29d3afa 100644 --- a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml @@ -2,7 +2,7 @@ defaults: - pruning: ffn_pruning - scoring: ../validate_solutions_defaults - realize_model: ../validate_solutions_defaults - - bypass: + - bypass: # uncomment and set to "defaults" to 
enable bypass distillation - override hydra/hydra_logging: disabled - _self_ diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/bypass/defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/bypass/defaults.yaml new file mode 120000 index 0000000000..5614740c97 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/bypass/defaults.yaml @@ -0,0 +1 @@ +../../bypass/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml deleted file mode 100644 index e05e775bee..0000000000 --- a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml +++ /dev/null @@ -1,33 +0,0 @@ -defaults: - - /validate_model_defaults - -descriptor: ${descriptor} -model_name_or_path: ${teacher_dir} -experiment_id: ${pruning.eval_samples}samples_diverse_mini -activations_log_dir: ??? -activation_hooks_kwargs: ??? 
- -# Data: -eval_samples: 1000 # default is 10000 -micro_batch_size: 4 -dataset_path: ${dataset_path} -val_dataset_name: train - -# Prune ckpts -pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} - -## FFN pruning -ffn_list: -mlp_init_mode: "Truncate" # PruneByActivationsLog - -## KV-heads pruning -n_heads_in_group_list: -gqa_init_mode: "AverageKV" - -## Hidden dimension pruning -hidden_size_list: -hidden_size_init_mode: "PruneByChannelRanking" -linear_init_mode: "FromTeacher" - -mlp_init_config_yaml: - activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 120000 index 0000000000..3d599095c3 --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1 @@ +../../pruning/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml deleted file mode 100644 index ce1749d969..0000000000 --- a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml +++ /dev/null @@ -1,17 +0,0 @@ -model_dtype: torch.bfloat16 # dtype to cast the model for validate_model -autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model -block_size: 8192 -bos_rate: 0.5 -data_column: messages -val_dataset_name: valid -shuffle_seed: 81436 -seed: 42 -fim_rate: 0 -fim_spm_rate: 0 -source_datasets_to_discard: -varlen: false -write_results: false -calc_losses_on_cpu: false -activations_log_dir: -model_name_or_path: -load_dataset_fn: 
${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml new file mode 120000 index 0000000000..8114ca754c --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1 @@ +../validate_model_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml deleted file mode 100644 index ec13902379..0000000000 --- a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml +++ /dev/null @@ -1,10 +0,0 @@ -defaults: - - /validate_model_defaults - - _self_ - -solutions_to_validate: -skip_validation: false -save_models: false -bigger_is_better: false -sort_solutions_by: -calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml new file mode 120000 index 0000000000..695e65bf9c --- /dev/null +++ b/examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1 @@ +../validate_solutions_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/bypass/defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/bypass/defaults.yaml new file mode 120000 index 0000000000..5614740c97 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/bypass/defaults.yaml @@ -0,0 +1 @@ 
+../../bypass/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml index 62b6ecb4cb..e39599f355 100644 --- a/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml @@ -2,7 +2,7 @@ defaults: - pruning: ffn_pruning - scoring: ../validate_solutions_defaults - realize_model: ../validate_solutions_defaults - - bypass: + - bypass: # uncomment and set to "defaults" to enable bypass distillation - override hydra/hydra_logging: disabled - _self_ diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml deleted file mode 100644 index 8816eecc4a..0000000000 --- a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml +++ /dev/null @@ -1,34 +0,0 @@ -defaults: - - /validate_model_defaults - -model_name_or_path: ${teacher_dir} -experiment_id: ${pruning.eval_samples}samples_diverse_mini -activations_log_dir: ??? -activation_hooks_kwargs: ??? 
- -descriptor: ${descriptor} - -# Data: -eval_samples: 1000 # default is 10000 -micro_batch_size: 4 -dataset_path: ${dataset_path} -val_dataset_name: train - -# Prune ckpts -pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} - -## FFN pruning -ffn_list: -mlp_init_mode: "Truncate" - -## KV-heads pruning -n_heads_in_group_list: -gqa_init_mode: "AverageKV" - -## Hidden dimension pruning -hidden_size_list: -hidden_size_init_mode: "PruneByChannelRanking" -linear_init_mode: "FromTeacher" - -mlp_init_config_yaml: - activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml new file mode 120000 index 0000000000..3d599095c3 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml @@ -0,0 +1 @@ +../../pruning/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml deleted file mode 100644 index ce1749d969..0000000000 --- a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml +++ /dev/null @@ -1,17 +0,0 @@ -model_dtype: torch.bfloat16 # dtype to cast the model for validate_model -autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model -block_size: 8192 -bos_rate: 0.5 -data_column: messages -val_dataset_name: valid -shuffle_seed: 81436 -seed: 42 -fim_rate: 0 -fim_spm_rate: 0 -source_datasets_to_discard: -varlen: false -write_results: false -calc_losses_on_cpu: false -activations_log_dir: -model_name_or_path: -load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml 
b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml new file mode 120000 index 0000000000..8114ca754c --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml @@ -0,0 +1 @@ +../validate_model_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml deleted file mode 100644 index ec13902379..0000000000 --- a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml +++ /dev/null @@ -1,10 +0,0 @@ -defaults: - - /validate_model_defaults - - _self_ - -solutions_to_validate: -skip_validation: false -save_models: false -bigger_is_better: false -sort_solutions_by: -calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml new file mode 120000 index 0000000000..695e65bf9c --- /dev/null +++ b/examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml @@ -0,0 +1 @@ +../validate_solutions_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/pruning/defaults.yaml b/examples/puzzletron/configs/pruning/defaults.yaml new file mode 100644 index 0000000000..8816eecc4a --- /dev/null +++ b/examples/puzzletron/configs/pruning/defaults.yaml @@ -0,0 +1,34 @@ +defaults: + - /validate_model_defaults + +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? 
+ +descriptor: ${descriptor} + +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/bypass/defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/bypass/defaults.yaml new file mode 120000 index 0000000000..5614740c97 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/bypass/defaults.yaml @@ -0,0 +1 @@ +../../bypass/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml index 3f7a248ee7..5fdaa3c31b 100644 --- a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml @@ -1,7 +1,8 @@ defaults: - pruning_defaults -activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${modelopt.torch.puzzletron.pruning.activation_hooks_kwargs.method}/${modelopt.torch.puzzletron.pruning.experiment_id} +activations_log_dir: + ${puzzle_dir}/pruning/pruning_scores/attn_${modelopt.torch.puzzletron.pruning.activation_hooks_kwargs.method}/${modelopt.torch.puzzletron.pruning.experiment_id} activation_hooks_kwargs: method: independent_kv_head_contribution diff --git 
a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml index af8af990b7..75d4a53085 100644 --- a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml @@ -1,7 +1,8 @@ defaults: - pruning_defaults -activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${modelopt.torch.puzzletron.pruning.activation_hooks_kwargs.method}/${modelopt.torch.puzzletron.pruning.experiment_id} +activations_log_dir: + ${puzzle_dir}/pruning/pruning_scores/hidden_dim_${modelopt.torch.puzzletron.pruning.activation_hooks_kwargs.method}/${modelopt.torch.puzzletron.pruning.experiment_id} activation_hooks_kwargs: method: layer_norm_contribution diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml deleted file mode 100644 index 8816eecc4a..0000000000 --- a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml +++ /dev/null @@ -1,34 +0,0 @@ -defaults: - - /validate_model_defaults - -model_name_or_path: ${teacher_dir} -experiment_id: ${pruning.eval_samples}samples_diverse_mini -activations_log_dir: ??? -activation_hooks_kwargs: ??? 
- -descriptor: ${descriptor} - -# Data: -eval_samples: 1000 # default is 10000 -micro_batch_size: 4 -dataset_path: ${dataset_path} -val_dataset_name: train - -# Prune ckpts -pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} - -## FFN pruning -ffn_list: -mlp_init_mode: "Truncate" - -## KV-heads pruning -n_heads_in_group_list: -gqa_init_mode: "AverageKV" - -## Hidden dimension pruning -hidden_size_list: -hidden_size_init_mode: "PruneByChannelRanking" -linear_init_mode: "FromTeacher" - -mlp_init_config_yaml: - activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 120000 index 0000000000..3d599095c3 --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1 @@ +../../pruning/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml index aa11499a3c..bbb1286231 100644 --- a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml @@ -2,7 +2,7 @@ defaults: - pruning: ffn_pruning - scoring: ../validate_solutions_defaults - realize_model: ../validate_solutions_defaults - - bypass: + - bypass: # uncomment and set to "defaults" to enable bypass distillation - override hydra/hydra_logging: disabled - _self_ diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml deleted file mode 100644 index ce1749d969..0000000000 --- 
a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml +++ /dev/null @@ -1,17 +0,0 @@ -model_dtype: torch.bfloat16 # dtype to cast the model for validate_model -autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model -block_size: 8192 -bos_rate: 0.5 -data_column: messages -val_dataset_name: valid -shuffle_seed: 81436 -seed: 42 -fim_rate: 0 -fim_spm_rate: 0 -source_datasets_to_discard: -varlen: false -write_results: false -calc_losses_on_cpu: false -activations_log_dir: -model_name_or_path: -load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml new file mode 120000 index 0000000000..8114ca754c --- /dev/null +++ b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1 @@ +../validate_model_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml deleted file mode 100644 index ec13902379..0000000000 --- a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml +++ /dev/null @@ -1,10 +0,0 @@ -defaults: - - /validate_model_defaults - - _self_ - -solutions_to_validate: -skip_validation: false -save_models: false -bigger_is_better: false -sort_solutions_by: -calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml new file mode 120000 index 0000000000..695e65bf9c --- /dev/null +++ 
b/examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1 @@ +../validate_solutions_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/bypass/defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/bypass/defaults.yaml new file mode 120000 index 0000000000..5614740c97 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/bypass/defaults.yaml @@ -0,0 +1 @@ +../../bypass/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml deleted file mode 100644 index 8816eecc4a..0000000000 --- a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml +++ /dev/null @@ -1,34 +0,0 @@ -defaults: - - /validate_model_defaults - -model_name_or_path: ${teacher_dir} -experiment_id: ${pruning.eval_samples}samples_diverse_mini -activations_log_dir: ??? -activation_hooks_kwargs: ??? 
- -descriptor: ${descriptor} - -# Data: -eval_samples: 1000 # default is 10000 -micro_batch_size: 4 -dataset_path: ${dataset_path} -val_dataset_name: train - -# Prune ckpts -pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} - -## FFN pruning -ffn_list: -mlp_init_mode: "Truncate" - -## KV-heads pruning -n_heads_in_group_list: -gqa_init_mode: "AverageKV" - -## Hidden dimension pruning -hidden_size_list: -hidden_size_init_mode: "PruneByChannelRanking" -linear_init_mode: "FromTeacher" - -mlp_init_config_yaml: - activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml new file mode 120000 index 0000000000..3d599095c3 --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml @@ -0,0 +1 @@ +../../pruning/defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml index eec82a7d63..6e73ddba0d 100644 --- a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml @@ -2,7 +2,7 @@ defaults: - pruning: ffn_pruning - scoring: ../validate_solutions_defaults - realize_model: ../validate_solutions_defaults - - bypass: + - bypass: # uncomment and set to "defaults" to enable bypass distillation - override hydra/hydra_logging: disabled - _self_ diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml deleted file mode 100644 index ce1749d969..0000000000 --- a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml +++ /dev/null @@ -1,17 +0,0 @@ -model_dtype: torch.bfloat16 # dtype to cast 
the model for validate_model -autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model -block_size: 8192 -bos_rate: 0.5 -data_column: messages -val_dataset_name: valid -shuffle_seed: 81436 -seed: 42 -fim_rate: 0 -fim_spm_rate: 0 -source_datasets_to_discard: -varlen: false -write_results: false -calc_losses_on_cpu: false -activations_log_dir: -model_name_or_path: -load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml new file mode 120000 index 0000000000..8114ca754c --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml @@ -0,0 +1 @@ +../validate_model_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml deleted file mode 100644 index ec13902379..0000000000 --- a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml +++ /dev/null @@ -1,10 +0,0 @@ -defaults: - - /validate_model_defaults - - _self_ - -solutions_to_validate: -skip_validation: false -save_models: false -bigger_is_better: false -sort_solutions_by: -calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml new file mode 120000 index 0000000000..695e65bf9c --- /dev/null +++ b/examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml @@ -0,0 +1 @@ +../validate_solutions_defaults.yaml \ No newline at end of file diff --git a/examples/puzzletron/configs/realize_model/validate_solutions_defaults.yaml 
b/examples/puzzletron/configs/realize_model/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/realize_model/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/scoring/validate_solutions_defaults.yaml b/examples/puzzletron/configs/scoring/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ b/examples/puzzletron/configs/scoring/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/configs/validate_model_defaults.yaml b/examples/puzzletron/configs/validate_model_defaults.yaml new file mode 100644 index 0000000000..ce1749d969 --- /dev/null +++ b/examples/puzzletron/configs/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/validate_solutions_defaults.yaml b/examples/puzzletron/configs/validate_solutions_defaults.yaml new file mode 100644 index 0000000000..ec13902379 --- /dev/null +++ 
b/examples/puzzletron/configs/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/main.py b/examples/puzzletron/main.py index 5bb04818e5..5d4e8c7753 100644 --- a/examples/puzzletron/main.py +++ b/examples/puzzletron/main.py @@ -40,7 +40,10 @@ import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models import modelopt.torch.puzzletron.mip.sweep as sweep import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel +from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import ( + PuzzletronModel, + _total_steps, +) from modelopt.torch.puzzletron.tools.hydra_utils import ( initialize_hydra_config_for_dir, register_hydra_resolvers, @@ -68,28 +71,37 @@ def parse_args(): return parser.parse_args() -def run_full_puzzletron(hydra_config_path: str): - """Run the full puzzletron pipeline. +def _setup(hydra_config_path: str): + """Common setup for all entry points: distributed init, Hydra config load. - Args: - config_path: Path to the YAML configuration file + Returns: + Tuple of (hydra_cfg, hydra_config_dir, hydra_config_name, n) where n is + the total number of pipeline steps. 
""" - mprint("Puzzletron Progress 1/8: starting puzzletron pipeline") - dist.setup(timeout=timedelta(10)) - - # Register Hydra custom resolvers (needed for config resolution) + dist.setup(timeout=timedelta(minutes=10)) register_hydra_resolvers() - hydra_config_path = Path(hydra_config_path).resolve() - hydra_config_dir = str(hydra_config_path.parent) - hydra_config_name = hydra_config_path.stem + resolved = Path(hydra_config_path).resolve() + hydra_config_dir = str(resolved.parent) + hydra_config_name = resolved.stem - # Load hydra config hydra_cfg = initialize_hydra_config_for_dir( config_dir=hydra_config_dir, config_name=hydra_config_name, overrides=[], ) + return hydra_cfg, hydra_config_dir, hydra_config_name, _total_steps(hydra_cfg) + + +def run_full_puzzletron(hydra_config_path: str): + """Run the full puzzletron pipeline. + + Args: + config_path: Path to the YAML configuration file + """ + hydra_cfg, hydra_config_dir, hydra_config_name, n = _setup(hydra_config_path) + + mprint(f"Puzzletron Progress 1/{n}: starting puzzletron pipeline") # Convert model (convert from HF to DeciLM, score pruning activations, # prune the model and save pruned checkpoints) @@ -120,7 +132,7 @@ def run_full_puzzletron(hydra_config_path: str): ) dist.cleanup() - mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + mprint(f"Puzzletron Progress {n}/{n}: puzzletron pipeline completed (multi-gpu)") def run_mip_only(hydra_config_path: str): @@ -132,36 +144,23 @@ def run_mip_only(hydra_config_path: str): Args: hydra_config_path: Path to the YAML configuration file """ - dist.setup(timeout=timedelta(10)) - - # Register Hydra custom resolvers (needed for config resolution) - register_hydra_resolvers() - - hydra_config_path = Path(hydra_config_path).resolve() - hydra_config_dir = str(hydra_config_path.parent) - hydra_config_name = hydra_config_path.stem - - # Load hydra config - hydra_cfg = initialize_hydra_config_for_dir( - config_dir=hydra_config_dir, - 
config_name=hydra_config_name, - overrides=[], - ) + hydra_cfg, _hydra_config_dir, _hydra_config_name, n = _setup(hydra_config_path) + mip_step = n - 1 # Check if sweep mode is enabled if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False): mprint( - "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)" + f"Puzzletron Progress {mip_step}/{n}: running MIP sweep for multiple compression rates (multi-gpu)" ) sweep.run_mip_sweep(hydra_cfg) else: # mip_and_realize_models (distributed processing) # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API - mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") + mprint(f"Puzzletron Progress {mip_step}/{n}: running MIP and realizing models (multi-gpu)") mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg) dist.cleanup() - mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + mprint(f"Puzzletron Progress {n}/{n}: puzzletron pipeline completed (multi-gpu)") def main(): diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py index 4cc4356c8e..9bcddad186 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py @@ -160,6 +160,19 @@ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]: """ raise NotImplementedError + @staticmethod + def pruning_mixins() -> Dict[str, Any]: + """Return available pruning mixins for bypass distillation. + + Override in subclasses to provide model-specific pruning mixins, e.g. + ``{"kv_heads": KVHeadsPruningMixIn(...), "experts_removal": ExpertRemovalPruningMixIn(...)}``. + + Returns an empty dict by default so that descriptors that do not need + model-specific weight-slicing (e.g. 
Llama with standard FFN truncation) + can rely on the generic ``create_child_state_dict`` fallback path. + """ + return {} + @staticmethod def uses_autocast() -> bool: """Whether this model supports torch.autocast. diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py index 55d9ef56ca..50dc2db4b9 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py @@ -34,6 +34,10 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) +from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import ( + KVHeadsLayerDescriptor, + KVHeadsPruningMixIn, +) from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn @@ -52,6 +56,15 @@ def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: return matches +@dataclass +class NemotronHKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "mixer.o_proj" + attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + @dataclass class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): target_name: str = "mixer.gate" @@ -253,4 +266,5 @@ def build_attention_predicates() -> Dict[str, re.Pattern]: def pruning_mixins() -> Dict[str, PruningMixIn]: return { "experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(NemotronHKVHeadsLayerDescriptor()), } diff --git a/modelopt/torch/puzzletron/bypass_distillation/__init__.py b/modelopt/torch/puzzletron/bypass_distillation/__init__.py new file mode 100644 index 0000000000..790166b451 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: 
Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bypass distillation (blockwise local distillation) for the PUZZLE framework. + +This module implements Stage 1 of the PUZZLE pipeline: training alternative transformer +block configurations using per-block knowledge distillation from a teacher model. +""" + +from .training_loop import launch_bypass_distillation diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py new file mode 100644 index 0000000000..c2064717dd --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Checkpoint utilities for bypass distillation.""" + +import re +from collections import OrderedDict +from pathlib import Path +from typing import Optional, Type, Union + +import torch +from omegaconf import DictConfig +from tqdm import tqdm + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump + +from .stitched_model_factory import StitchedModuleDescriptor + + +def find_latest_run_dir(run_parent_dir: Union[str, Path]) -> str | None: + """Find the latest checkpoint directory within a run parent directory.""" + run_parent_dir = Path(run_parent_dir) + + # Check for the "latest" directory + latest_dir = run_parent_dir / "latest" + if latest_dir.exists() and (latest_dir / "saving_completed").exists(): + return str(latest_dir) + + # If "latest" doesn't exist, look explicitly into directories with `*iter-*` + candidate_dirs = [d for d in run_parent_dir.glob("*iter-*") if d.is_dir()] + + if not candidate_dirs: + return None + + def get_iter_num(dir_name): + match = re.search(r"iter-(\d+)", dir_name.name) + return int(match.group(1)) if match else 0 + + checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True) + for latest_dir in checkpoint_dirs: + if (latest_dir / "saving_completed").exists(): + return str(latest_dir) + return None + + +def find_best_run_dir(run_parent_dir: Union[str, Path]) -> str | None: + """Find the best-validation checkpoint directory within a run parent directory. + + Returns the ``best-iter-*`` directory with the highest iteration number that has a + ``saving_completed`` marker. Falls back to ``None`` when no best checkpoint exists + (e.g. validation was disabled or no improvement was recorded). 
+ """ + run_parent_dir = Path(run_parent_dir) + best_dirs = [d for d in run_parent_dir.glob("best-iter-*") if d.is_dir()] + if not best_dirs: + return None + + def get_iter_num(d): + m = re.search(r"iter-(\d+)", d.name) + return int(m.group(1)) if m else 0 + + for best_dir in sorted(best_dirs, key=get_iter_num, reverse=True): + if (best_dir / "saving_completed").exists(): + return str(best_dir) + return None + + +def load_local_state( + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_path: str | Path, + verbose=True, +) -> None: + """Load local state from a checkpoint. + + Loads both optimizer and state dicts into stitched module descriptors. + Modifies stitched_module_descriptors in place. + """ + device = torch.device(f"cuda:{dist.local_rank()}") + load_dir = Path(checkpoint_path) + + if not load_dir.exists(): + raise RuntimeError(f'Can\'t load local state. "{load_dir}" does not exist.') + + for stitched_module_name, stitched_module_descriptor in stitched_module_descriptors.items(): + stitched_module = stitched_module_descriptor.stitched_module + optimizer = stitched_module_descriptor.optimizer + + state_dict_path = load_dir / "stitched" / f"{stitched_module_name}.state_dict.pth" + if verbose: + mprint(f"Loading state dict for module {stitched_module_name} from {state_dict_path}") + loaded_state_dict = torch.load(state_dict_path, map_location=device, weights_only=True) + # Use strict=False so parameters absent in the checkpoint (e.g. newly added adapter + # keys not yet saved) retain their initialised values rather than raising an error. 
+ stitched_module.load_state_dict(loaded_state_dict, strict=False) + del loaded_state_dict + + if optimizer is not None: + optimizer_state_path = ( + load_dir / "stitched" / f"{stitched_module_name}.optimizer_state.pth" + ) + if verbose: + mprint( + f"Loading optimizer state for module {stitched_module_name} from {optimizer_state_path}" + ) + loaded_optimizer_state = torch.load( + optimizer_state_path, map_location=device, weights_only=True + ) + optimizer.load_state_dict(loaded_optimizer_state) + del loaded_optimizer_state + + +def _save_local_file(obj, save_path: Path | str, overwrite=True): + save_path = Path(save_path) + if save_path.exists(): + if not overwrite: + mprint(f'WARNING: Local save path "{save_path}" already exists. Skipping') + return + torch.save(obj, save_path) + + +def _save_local_state( + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_dir: Path | str, + overwrite=True, + verbose=True, +) -> None: + save_dir = Path(checkpoint_dir) / "stitched" + + if dist.is_master(): + save_dir.mkdir(parents=True, exist_ok=True) + + # Main process creates the directory, so we must wait for it to finish + dist.barrier() + + for stitched_module_name, stitched_module_descriptor in tqdm( + stitched_module_descriptors.items() + ): + optimizer = stitched_module_descriptor.optimizer + + state_dict_path = save_dir / f"{stitched_module_name}.state_dict.pth" + if verbose: + aprint(f"Saving state dict for module {stitched_module_name} to {state_dict_path}") + state_dict = { + **stitched_module_descriptor.owned_parameters, + **stitched_module_descriptor.owned_buffers, + } + _save_local_file(state_dict, state_dict_path, overwrite=overwrite) + + if optimizer is not None: + optimizer_state_path = save_dir / f"{stitched_module_name}.optimizer_state.pth" + if verbose: + mprint( + f"Saving optimizer state for module {stitched_module_name} to {optimizer_state_path}" + ) + _save_local_file(optimizer.state_dict(), optimizer_state_path, 
overwrite=overwrite) + + dist.barrier() + + +def save_bypass_checkpoint( + cfg: DictConfig, + descriptor: Type[ModelDescriptor], + model: torch.nn.Module, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_dir: Path | str, + reference_checkpoint_dir: Optional[Path] = None, +) -> None: + """Save a bypass distillation checkpoint.""" + checkpoint_dir = Path(checkpoint_dir) + mprint("Starting checkpoint save") + mprint(f"Saving checkpoint to {checkpoint_dir}") + + # Save stitched module states + _save_local_state( + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=checkpoint_dir, + overwrite=cfg.bypass.model.model_overrides.delete_old_checkpoints, + verbose=dist.is_master() and False, + ) + # Save as HF checkpoint + save_checkpoint(model=model, checkpoint_dir=checkpoint_dir, descriptor=descriptor) + + if dist.is_master(): + # Create 'latest' symlink + latest_symlink = Path(cfg.bypass.experiment_dir) / "latest" + latest_symlink.unlink(missing_ok=True) + latest_symlink.symlink_to(checkpoint_dir.name) + # Save config args json + json_dump(cfg.bypass, checkpoint_dir / "args.json") + # Save completed file + completed_file = checkpoint_dir / "saving_completed" + completed_file.touch() + + dist.barrier() + mprint("Checkpoint save done") diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py new file mode 100644 index 0000000000..8d3ecea7d1 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
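The tail of `save_bypass_checkpoint` follows a marker-last protocol: the `latest` symlink is re-pointed and the `saving_completed` file is touched only after all state is on disk, so readers that require the marker never observe a half-written checkpoint. A minimal stand-alone sketch of that pattern (the `mark_saved` helper is a hypothetical name for illustration):

```python
import tempfile
from pathlib import Path


def mark_saved(experiment_dir: Path, checkpoint_dir: Path) -> None:
    """Re-point ``latest`` at the new checkpoint, then drop the completion marker.

    The marker is touched *last*: anything that filters on ``saving_completed``
    (e.g. a latest-checkpoint scan) will ignore this dir until the save is done.
    """
    latest = experiment_dir / "latest"
    latest.unlink(missing_ok=True)            # tolerate a stale link from a prior save
    latest.symlink_to(checkpoint_dir.name)    # relative target, robust to dir moves
    (checkpoint_dir / "saving_completed").touch()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        exp = Path(tmp)
        ckpt = exp / "iter-500"
        ckpt.mkdir()
        mark_saved(exp, ckpt)
        print((exp / "latest").resolve().name)                 # -> iter-500
        print((exp / "latest" / "saving_completed").exists())  # -> True
```

Using `unlink(missing_ok=True)` before `symlink_to` is what makes repeated saves idempotent; `symlink_to` alone raises `FileExistsError` when `latest` already points somewhere.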
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for bypass distillation.""" + +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig + +import modelopt.torch.utils.distributed as dist + +# --------------------------------------------------------------------------- +# Experiment ID generation +# --------------------------------------------------------------------------- + +# Priority-ordered specs: (section, dot-path into first entry, tag prefix). +# For each section the *first* matching non-None field contributes one +# component; later fields in the same section are skipped. +# Add entries here to support new block types or dimension fields. +_OVERRIDE_COMPONENT_SPECS: list[tuple[str, str, str]] = [ + ("ffn", "intermediate_size", "ffn"), + ("ffn", "moe.num_local_experts", "experts"), + ("attention", "num_key_value_heads", "kv"), + ("attention", "mamba.state_dim", "mambastate"), +] + +# Fallback type tag when no structural change exists in model_config_overrides. +_KEYS_TO_LEARN_FALLBACK: dict[str, str] = { + "subblock_ffn": "ffn", + "subblock_attention": "attn", + "subblock_mamba": "mamba", + "entire_block": "block", +} + + +def _get_nested(obj: Any, dotpath: str) -> Any: + """Return a nested value from a dict/DictConfig via dot-separated path. + + Returns ``None`` for any missing key or traversal failure so callers can + safely treat absent and ``None``-valued fields identically. 
+ """ + for key in dotpath.split("."): + if obj is None: + return None + try: + obj = obj[key] + except Exception: + return None + return obj + + +def _build_experiment_id_components(overrides: Any) -> list[str]: + """Return ID components derived from non-None values in *overrides*. + + Each section (``ffn``, ``attention``, …) contributes at most one + component, chosen by the first matching entry in + ``_OVERRIDE_COMPONENT_SPECS``. When per-layer entries hold multiple + distinct non-None values they are listed ascending with ``-`` as + separator (e.g. ``ffn256-3072``). + """ + seen_sections: set[str] = set() + components: list[str] = [] + + for section, field_path, tag_prefix in _OVERRIDE_COMPONENT_SPECS: + if section in seen_sections: + continue + if section not in overrides or not overrides[section]: + continue + + values = [ + v for entry in overrides[section] if (v := _get_nested(entry, field_path)) is not None + ] + if not values: + continue + + unique_vals = sorted(set(values)) + components.append(tag_prefix + "-".join(str(v) for v in unique_vals)) + seen_sections.add(section) + + return components + + +def set_experiment_id(cfg: DictConfig) -> None: + """Set the experiment ID derived from model config overrides and keys_to_learn. + + The ID has the form ``bypass_{component1}_{component2}...`` where each + component encodes one structural change: + + * ``ffn{size}`` — FFN ``intermediate_size`` (e.g. ``ffn256``) + * ``experts{n}`` — MoE ``num_local_experts`` (e.g. ``experts4``) + * ``kv{n}`` — Attention ``num_key_value_heads`` (e.g. ``kv4``) + * ``mambastate{dim}`` — Mamba ``state_dim`` change + + Multiple distinct per-layer values are joined with ``-`` + (e.g. ``ffn256-3072``). When no structural change is present (pure + training bypass) the ``keys_to_learn`` type is used as a fallback + (e.g. ``bypass_mamba``, ``bypass_block``). 
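The dot-path lookup and the component-building loop above compose into the experiment-ID scheme described in `set_experiment_id`. A plain-dict sketch (stand-ins for the `DictConfig` overrides; `get_nested` / `experiment_id` are illustrative names, and only two of the priority specs are reproduced):

```python
def get_nested(obj, dotpath):
    """Walk a mapping via a dot-separated path; None for any missing key."""
    for key in dotpath.split("."):
        if obj is None:
            return None
        try:
            obj = obj[key]
        except Exception:
            return None
    return obj


# Mirrors the shape of _OVERRIDE_COMPONENT_SPECS: (section, dot-path, tag prefix).
SPECS = [
    ("ffn", "intermediate_size", "ffn"),
    ("attention", "num_key_value_heads", "kv"),
]


def experiment_id(overrides):
    components, seen = [], set()
    for section, path, prefix in SPECS:
        if section in seen or not overrides.get(section):
            continue
        values = [v for entry in overrides[section] if (v := get_nested(entry, path)) is not None]
        if values:
            # Distinct per-layer values are sorted ascending and joined with "-".
            components.append(prefix + "-".join(str(v) for v in sorted(set(values))))
            seen.add(section)
    return "bypass_" + "_".join(components)


overrides = {
    "ffn": [{"intermediate_size": 3072}, {"intermediate_size": 256}],
    "attention": [{"num_key_value_heads": 4}],
}
print(experiment_id(overrides))  # -> bypass_ffn256-3072_kv4
```

This reproduces the documented ID format (`ffn256-3072` for two distinct per-layer FFN sizes, `kv4` for the KV-head change) and the one-component-per-section rule.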
+ """ + if cfg.bypass.experiment_id is not None: + return + + overrides = cfg.bypass.model.model_config_overrides + components = _build_experiment_id_components(overrides) + + if not components: + keys_to_learn = cfg.bypass.model_factory.get("keys_to_learn", "entire_block") + fallback = ( + _KEYS_TO_LEARN_FALLBACK.get(keys_to_learn, keys_to_learn) + if isinstance(keys_to_learn, str) + else "block" + ) + components = [fallback] + + cfg.bypass.experiment_id = "bypass_" + "_".join(components) + + +def set_experiment_dir(cfg: DictConfig) -> None: + """Set the experiment directory for the bypass run.""" + experiment_dir = Path(cfg.puzzle_dir) / "bypass" / "bypass_runs" / cfg.bypass.experiment_id + cfg.bypass.experiment_dir = str(experiment_dir) + if dist.is_master(): + experiment_dir.mkdir(parents=True, exist_ok=True) + + +def get_distributed_modules_ownership(module_count: int, world_size: int) -> list[int]: + """Map module (block) indices to GPU ranks for pipeline-parallel distribution.""" + modules_process_ownership: list[int] = [] + + for i in range(world_size): + num_modules_for_process = module_count // world_size + if i < module_count % world_size: + num_modules_for_process += 1 + + modules_process_ownership.extend([i] * num_modules_for_process) + + return modules_process_ownership diff --git a/modelopt/torch/puzzletron/bypass_distillation/data_classes.py b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py new file mode 100644 index 0000000000..a6b37099ce --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
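The block-ownership mapping in `get_distributed_modules_ownership` assigns each rank a contiguous run of blocks, with earlier ranks absorbing the remainder. Contiguity is what makes the pipeline-parallel stitching cheap: each rank exchanges activations with at most one predecessor and one successor. A standalone sketch of the same arithmetic (`distribute_blocks` is an illustrative name):

```python
def distribute_blocks(module_count: int, world_size: int) -> list[int]:
    """Contiguous block->rank map; ranks < (module_count % world_size) get one extra."""
    ownership: list[int] = []
    for rank in range(world_size):
        n = module_count // world_size
        if rank < module_count % world_size:
            n += 1
        ownership.extend([rank] * n)
    return ownership


print(distribute_blocks(10, 4))  # -> [0, 0, 0, 1, 1, 1, 2, 2, 3, 3]
```

With 10 blocks on 4 ranks, ranks 0 and 1 own three blocks each and ranks 2 and 3 own two, matching the factory's `prev_rank` / `next_rank` derivation from neighbouring indices.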
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data classes for bypass distillation training.""" + +import dataclasses +from typing import TypeAlias + +IterNum: TypeAlias = int +GlobalRank: TypeAlias = int + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class IterStatistics: + step_num: int + token_count: int + iter_duration: float + lr: float + clipping_count: int + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class LocalTrainingStats: + iter_num: int + stitched_module_losses: dict[str, float] + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class TimeToSaveSignal: + step_num: int diff --git a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py new file mode 100644 index 0000000000..d6e084e72e --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py @@ -0,0 +1,640 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Factory for creating stitched teacher/student models for bypass distillation.""" + +import copy +import dataclasses +import re +from argparse import Namespace +from collections import OrderedDict +from pathlib import Path +from typing import Any, Callable, Mapping, Optional, Sequence, Type + +import torch +from omegaconf import DictConfig, OmegaConf +from torch.amp.grad_scaler import GradScaler +from torch.optim import AdamW, Optimizer +from transformers import PretrainedConfig, PreTrainedModel + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.pruning.pruning_utils import GQAInitMode, LinearInitMode, MlpInitMode +from modelopt.torch.puzzletron.sewing_kit import ( + ExternalTarget, + FunctionTarget, + InputArgs, + ModuleTarget, + Needle, + RemoteTarget, + StitchedModule, + always_true_predicate, +) +from modelopt.torch.puzzletron.sewing_kit.core import InputReducer +from modelopt.torch.puzzletron.sewing_kit.utils import ( + batched_normalized_mse_loss, + normalized_mse_loss, + vectorwise_normalized_mse_loss, +) +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( + create_child_state_dict, + update_model_config, +) +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import create_sharded_model +from modelopt.torch.puzzletron.utils.parsing import format_block_configs, parse_dtype + +StitchedModulesProcessOwnership = list[int] +SyncDistributedModelWeightsFn = Callable[[], None] +Config = Mapping[str, Any] +Args = Namespace + + +@dataclasses.dataclass +class StitchedModuleDescriptor: + stitched_module: StitchedModule + owned_parameters: dict[str, torch.nn.Parameter] + owned_buffers: dict[str, torch.Tensor] + optimizer: Optional[Optimizer] = None + grad_scaler: Optional[GradScaler] = None + + +def default_factory( + teacher_model: PreTrainedModel, + descriptor: 
Type[ModelDescriptor], + config: Config, + model_blocks_process_ownership: Sequence[int], + student_model: Optional[PreTrainedModel] = None, +) -> tuple[ + PreTrainedModel, + StitchedModule, + StitchedModule, + StitchedModule, + OrderedDict[str, StitchedModuleDescriptor], + PretrainedConfig, +]: + raise NotImplementedError() + + +StitchedModelFactoryFn = type(default_factory) + +_SUBBLOCK_KEYS_TO_LEARN = frozenset( + {"subblock_ffn", "subblock_attention", "subblock_mamba", "entire_block"} +) + + +def _set_keys_to_learn( + model: PreTrainedModel, + descriptor: Type[ModelDescriptor], + keys_to_learn: str | Sequence[str], +) -> None: + """Set ``requires_grad=True`` on parameters selected by ``keys_to_learn``. + + * A **sequence of strings** (not a bare ``str``): each string is a full parameter + name; gradients are enabled only where ``named_parameters()`` names match exactly. + * A **single string**: if it is ``"subblock_ffn"``, ``"subblock_attention"``, or + ``"entire_block"``, enables gradients for the corresponding descriptor weight + groups; otherwise ``re.search`` is applied to each parameter name. + """ + # If keys_to_learn is a sequence of strings. + if isinstance(keys_to_learn, Sequence) and not isinstance(keys_to_learn, str): + param_names = set(keys_to_learn) + # If keys_to_learn is a single string. + else: + # If keys_to_learn is a single string that is a subblock key. 
+ if keys_to_learn in _SUBBLOCK_KEYS_TO_LEARN: + lm_config = descriptor.get_language_model_config(model.config) + weight_groups = descriptor.get_weight_groups( + model.state_dict().keys(), lm_config.num_hidden_layers + ) + + attn_group_names = [ + group_name + for group_name in weight_groups.keys() + if group_name.endswith("_attention") + ] + ffn_group_names = [ + group_name for group_name in weight_groups.keys() if group_name.endswith("_ffn") + ] + if keys_to_learn == "subblock_attention": + group_names = attn_group_names + elif keys_to_learn == "subblock_ffn": + group_names = ffn_group_names + elif keys_to_learn == "subblock_mamba": + group_names = attn_group_names # Mamba params live in _attention groups + else: # entire_block + group_names = attn_group_names + ffn_group_names + + block_configs = getattr(lm_config, "block_configs", None) + + param_names = [] + for group_name in group_names: + # For hybrid models (e.g. NemotronH), a single "_attention" group + # name can contain either Mamba SSM params *or* GQA params depending + # on the block. Use the block config — not the keys_to_learn string + # — to decide whether each block belongs to the current subblock type. + if block_configs is not None: + m = re.match(r"block_(\d+)_attention", group_name) + if m: + block_idx = int(m.group(1)) + if block_idx < len(block_configs): + is_mamba = ( + getattr(block_configs[block_idx].attention, "mamba", None) + is not None + ) + # subblock_attention → GQA blocks only (not Mamba) + # subblock_mamba → Mamba blocks only (not GQA) + # entire_block → all blocks (no filtering) + if keys_to_learn == "subblock_attention" and is_mamba: + continue + if keys_to_learn == "subblock_mamba" and not is_mamba: + continue + param_names.extend(weight_groups[group_name]) + param_names = set(param_names) + # If keys_to_learn is a single string that is not a subblock key, treat as regex. 
+ else: + param_names = { + param_name + for param_name, _ in model.named_parameters() + if re.search(keys_to_learn, param_name) + } + # In pipeline-parallel training a rank may own only blocks that don't match + # keys_to_learn (e.g. a rank with only Mamba blocks during subblock_attention + # bypass has no GQA params after the _mamba rename). That is a valid state: + # all its blocks will produce NaN loss and be excluded from statistics. + if not param_names: + return + + # Set requires_grad to True for the selected parameters. + for param_name, param in model.named_parameters(): + if param_name in param_names and torch.is_floating_point(param): + param.requires_grad_(True) + + +def _get_all_non_persistent_buffers_set(module: torch.nn.Module) -> set[str]: + all_non_persistent = set() + for module_name, submodule in module.named_modules(): + for buffer_name in submodule._non_persistent_buffers_set: + full_name = f"{module_name}.{buffer_name}" if module_name else buffer_name + all_non_persistent.add(full_name) + return all_non_persistent + + +def _initialize_student_model( + teacher_model: PreTrainedModel, + descriptor: Type[ModelDescriptor], + cfg: DictConfig, + owned_block_indexes: set[int], + device: torch.device, + student_model: Optional[PreTrainedModel] = None, +) -> tuple[PreTrainedModel, PretrainedConfig]: + """Create and initialise the student model, or extract its config if already provided. + + If *student_model* is ``None``, builds a new model from *teacher_model* using the + pruning / init modes specified in *cfg* and loads the derived weight state dict. + Otherwise simply returns the provided model together with its config. + + Returns: + Tuple of ``(student_model, student_model_config)``. 
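The three-way dispatch in `_set_keys_to_learn` (explicit name list vs. subblock keyword vs. regex fallback) can be shown without a model by operating on parameter names alone. A sketch under that simplification — `select_params` is a hypothetical helper, and the subblock branch is stubbed because it needs the descriptor's weight groups:

```python
import re
from collections.abc import Sequence

SUBBLOCK_KEYS = {"subblock_ffn", "subblock_attention", "subblock_mamba", "entire_block"}


def select_params(all_names: list[str], keys_to_learn) -> set[str]:
    """Mirror of the name-selection logic only (no requires_grad side effects)."""
    # Sequence of strings -> exact-match set of full parameter names.
    if isinstance(keys_to_learn, Sequence) and not isinstance(keys_to_learn, str):
        return set(keys_to_learn)
    # Subblock keywords resolve via descriptor weight groups in the real code.
    if keys_to_learn in SUBBLOCK_KEYS:
        raise NotImplementedError("needs descriptor weight groups; see _set_keys_to_learn")
    # Any other string is treated as a regex applied with re.search.
    return {n for n in all_names if re.search(keys_to_learn, n)}


names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.1.mlp.up_proj.weight",
]
print(sorted(select_params(names, r"\.mlp\.")))
# -> ['model.layers.0.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight']
```

Note the `str` exclusion in the `Sequence` check: a bare string is itself a `Sequence`, so without it `"q_proj"` would be misread as a list of one-character parameter names.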
+ """ + if student_model is None: + mprint("Creating student model from teacher model") + + model_config_overrides = cfg.model.model_config_overrides + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if isinstance(model_config_overrides, DictConfig): + config_to_override = OmegaConf.to_container(model_config_overrides, resolve=True) + else: + config_to_override = model_config_overrides + mprint(f"{config_to_override=}") + student_model_config = update_model_config( + model_config=teacher_model.config, + model_config_overrides=config_to_override, + ) + student_model_config.use_cache = False + + mprint(f"Student block configs:\n {format_block_configs(student_model_config)}") + + from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + runtime = Namespace( + device=device, + dtype=torch.bfloat16, + global_rank=dist.rank(), + world_size=dist.size(), + is_main_process=dist.is_master(), + is_last_process=dist.is_last_process(), + ) + + with deci_x_patcher( + model_descriptor=descriptor, + block_configs=getattr(student_model_config, "block_configs", None), + ): + student_model = create_sharded_model( + runtime=runtime, + descriptor=descriptor, + model_config=student_model_config, + owned_block_indexes=owned_block_indexes, + device=device, + ) + student_model._init_weights(student_model) + + student_weights_dtype = parse_dtype(cfg.model.student_weights_dtype) + descriptor.init_rotary_embedding(student_model, runtime) + student_model.type(student_weights_dtype) + + mlp_init_mode = MlpInitMode(cfg.model_factory.mlp_init_mode or MlpInitMode.CopyAsIs) + + # For expert removal, use the model-specific pruning mixin so that model-specific + # key paths (e.g. backbone.layers.{i}.mixer for Nemotron-H vs model.layers.{i}.mlp + # for GPT-OSS) are handled correctly. For all other init modes the legacy inline + # key logic in create_child_state_dict is sufficient. 
+ _mixins = [] + if mlp_init_mode == MlpInitMode.ExpertRemoval: + _expert_mixin = descriptor.pruning_mixins().get("experts_removal") + if _expert_mixin is not None: + _mixins.append(_expert_mixin) + + # If any attention layer has fewer KV heads in the student than the teacher, use the + # model-specific KV heads mixin so that k_proj/v_proj weights are correctly sliced + # rather than copied verbatim from the (larger) teacher state dict. + _kv_mixin = descriptor.pruning_mixins().get("kv_heads") + if _kv_mixin is not None: + _student_kv = [ + b.attention.num_key_value_heads + for b in student_model_config.block_configs + if b.attention is not None and b.attention.num_key_value_heads is not None + ] + _teacher_kv = [ + b.attention.num_key_value_heads + for b in teacher_model.config.block_configs + if b.attention is not None and b.attention.num_key_value_heads is not None + ] + if _student_kv != _teacher_kv: + _mixins.append(_kv_mixin) + + if len(_mixins) == 0: + pruning_mixin = None + elif len(_mixins) == 1: + pruning_mixin = _mixins[0] + else: + pruning_mixin = _mixins + + # GQA init mode is optional: only relevant when the student has fewer KV heads than + # the teacher. Defaults to AverageKV and is a no-op when head counts are equal. 
+ gqa_init_mode = GQAInitMode(cfg.model_factory.get("gqa_init_mode", GQAInitMode.AverageKV)) + + student_state_dict = create_child_state_dict( + pruning_mixin=pruning_mixin, + descriptor=descriptor, + original_state_dict=teacher_model.state_dict(), + new_state_dict=student_model.state_dict(), + original_config=teacher_model.config, + new_config=student_model_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=cfg.model_factory.mlp_init_config, + owned_block_indexes=owned_block_indexes, + linear_init_mode=LinearInitMode( + cfg.model_factory.linear_init_mode or LinearInitMode.Random + ), + ) + + # Load student state dict + missing_keys, unexpected_keys = student_model.load_state_dict( + student_state_dict, strict=False + ) + assert len(unexpected_keys) == 0, f"{unexpected_keys=}" + # GQA models have learnable logit parameters not present in the teacher state dict; + # allow those to be absent and assert nothing else is missing. + non_gqa_missing = [k for k in missing_keys if not re.search(r"gqa_\w+_logits", k)] + assert len(non_gqa_missing) == 0, f"Unexpected missing keys: {non_gqa_missing}" + + else: + mprint("Student model provided explicitly, not using teacher model to instantiate") + student_model_config = student_model.config + + return student_model, student_model_config + + +def bypass_factory_fn( + teacher_model: PreTrainedModel, + descriptor: Type[ModelDescriptor], + cfg: DictConfig, + model_blocks_process_ownership: Sequence[int], + student_model: Optional[PreTrainedModel] = None, +) -> tuple[ + PreTrainedModel, + StitchedModule, + StitchedModule, + StitchedModule, + OrderedDict[str, StitchedModuleDescriptor], + PretrainedConfig, +]: + """Unified factory function for bypass (blockwise local) distillation. + + Handles all layer types — FFN, attention (GQA/MHA), MoE experts, Mamba, and whole blocks — + through a single pipeline. 
Behavior is driven entirely by ``model_factory`` config fields: + + - ``mlp_init_mode``: how student FFN / MoE weights are initialised + - ``"ExpertRemoval"``: select top-N experts from teacher (MoE models) + - ``"Truncate"`` / ``"PruneByActivationsLog"``: prune FFN channels (dense models) + - ``"CopyAsIs"``: copy weights unchanged (attention-only or Mamba-only runs) + - ``gqa_init_mode``: how attention KV heads are initialised (optional, default ``AverageKV``). + Irrelevant when the student has the same number of KV heads as the teacher. + - ``keys_to_learn``: which parameters to train. + Accepts ``"subblock_ffn"``, ``"subblock_attention"``, ``"entire_block"``, or a regex string. + + The stitching logic (pipeline-parallel per-block KD) is architecture-agnostic and unchanged + regardless of which layer type is being distilled. + + Args: + teacher_model: The teacher model to use for stitching. + descriptor: Model descriptor for layer naming and pruning mixin lookup. + cfg: The bypass config section. + model_blocks_process_ownership: Ownership mapping of model blocks to process ranks. + student_model: Optionally provided pre-built student model (skips initialisation). 
+ + Returns: + Tuple of (student_model, teacher_stitched, teacher_val_stitched, + student_val_stitched, stitched_module_descriptors, student_config) + """ + device = torch.device(f"cuda:{dist.local_rank()}") + + block_loss_func = { + "normalized_mse_loss": normalized_mse_loss, + "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss, + "batched_normalized_mse_loss": batched_normalized_mse_loss, + }[cfg.model_factory.block_loss_func] + mprint(f"{block_loss_func.__name__=}") + + owned_block_indexes = set( + block_index + for block_index, owner_rank in enumerate(model_blocks_process_ownership) + if owner_rank == dist.rank() + ) + + student_model, student_model_config = _initialize_student_model( + teacher_model=teacher_model, + descriptor=descriptor, + cfg=cfg, + owned_block_indexes=owned_block_indexes, + device=device, + student_model=student_model, + ) + + # Set up training parameters + lm_config = descriptor.get_language_model_config(student_model_config) + all_block_indices = list(range(lm_config.num_hidden_layers)) + + student_model.requires_grad_(False) + keys_to_learn = cfg.model_factory.keys_to_learn + mprint(f"Keys to learn: {keys_to_learn}") + + _set_keys_to_learn(model=student_model, descriptor=descriptor, keys_to_learn=keys_to_learn) + + dist.barrier() + mprint(f"Global rank: {dist.rank()}, {owned_block_indexes=}") + dist.barrier() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + dist.barrier() + + min_owned_index = min(owned_block_indexes) + max_owned_index = max(owned_block_indexes) + prev_rank: Optional[int] = ( + None + if min_owned_index == min(all_block_indices) + else model_blocks_process_ownership[min_owned_index - 1] + ) + next_rank: Optional[int] = ( + None + if max_owned_index == max(all_block_indices) + else model_blocks_process_ownership[max_owned_index + 1] + ) + + teacher_parameters = set(teacher_model.parameters()) + teacher_buffers = set(teacher_model.buffers()) + + # Setup the student model's submodules for knowledge 
distillation training + with torch.autocast(device_type="cuda", dtype=torch.bfloat16), torch.device(device): + stitched_module_descriptors = OrderedDict[str, StitchedModuleDescriptor]() + submodule_for_loss_calculation = cfg.model_factory.submodule_for_loss_calculation + + teacher_target = ModuleTarget("teacher", teacher_model) + teacher_stitcher = Needle() + teacher_val_stitcher = Needle() + + student_target = ModuleTarget("student", student_model) + student_val_stitcher = Needle() + + for local_block_index, global_block_index in enumerate(sorted(owned_block_indexes)): + module_name = descriptor.layer_block_name(global_block_index) + module = student_model.get_submodule(module_name) + + submodule_name = "" + submodule_input_descriptor = submodule_name + submodule_output_descriptor = submodule_name + + if submodule_for_loss_calculation is not None: + assert hasattr(module, submodule_for_loss_calculation) + submodule_output_descriptor = submodule_for_loss_calculation + + input_descriptor = f"{module_name}.{submodule_input_descriptor}".rstrip(".") + output_descriptor = f"{module_name}.{submodule_output_descriptor}".rstrip(".") + + # Receive activations from previous rank + if global_block_index > 0 and local_block_index == 0 and prev_rank is not None: + teacher_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="teacher_activations", adapter=lambda x: InputArgs(x) + ), + teacher_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + teacher_val_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="teacher_activations", adapter=lambda x: InputArgs(x) + ), + teacher_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + student_val_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="student_activations", adapter=lambda x: InputArgs(x) + ), + 
student_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + + # Send activations to next rank or register model output + if local_block_index + 1 == len(owned_block_indexes): + if next_rank is None: + student_val_stitcher.stitch( + student_target.output(name=""), + ExternalTarget().output("model_output"), + ) + teacher_val_stitcher.stitch( + teacher_target.output(name=""), + ExternalTarget().output("model_output"), + ) + else: + teacher_stitcher.stitch( + teacher_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="teacher_activations"), + ) + teacher_val_stitcher.stitch( + teacher_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="teacher_activations"), + ) + student_val_stitcher.stitch( + student_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="student_activations"), + ) + + # Bypass training stitches + teacher_stitcher.stitch( + teacher_target.input(name=input_descriptor), + ExternalTarget().input(name=input_descriptor), + ).stitch( + teacher_target.output(name=output_descriptor), + ExternalTarget().output(name=output_descriptor), + ) + + # Create the student block stitched module + student_stitched_module_loss_target = FunctionTarget( + "module_loss_func", block_loss_func + ) + student_stitched_module_name = f"block_{global_block_index}" + student_submodule_target = ModuleTarget("student_submodule", module) + # CONTRACT: block_loss_func must accept exactly two keyword arguments: + # input= (student activations) and target= (teacher activations). + # The adapters below wire activations into those kwargs via InputArgs. + # All built-in loss functions (normalized_mse_loss, vectorwise_normalized_mse_loss, + # batched_normalized_mse_loss) satisfy this contract. If you substitute a custom + # block_loss_func, ensure it accepts (input=..., target=...) as keyword arguments. 
+ student_stitched_module = ( + Needle() + .stitch( + ExternalTarget().input(name=input_descriptor), + student_submodule_target.input(name=submodule_input_descriptor), + ) + .stitch( + ExternalTarget().output( + name=output_descriptor, + adapter=lambda v: InputArgs(target=v) + if not isinstance(v, tuple) + else InputArgs(target=v[0]), + ), + student_stitched_module_loss_target.input(), + ) + .stitch( + student_submodule_target.output( + name=submodule_output_descriptor, + adapter=lambda v: InputArgs(input=v) + if not isinstance(v, tuple) + else InputArgs(input=v[0]), + ), + student_stitched_module_loss_target.input(), + ) + .stitch( + student_stitched_module_loss_target.output(), + ExternalTarget().output(name="loss"), + ) + .knot( + ignore_extra_overrides=True, + capture_cache_outputs_predicate=always_true_predicate, + ) + ) + + assert "learning_rate" in cfg.training + # Do NOT enable dummy params: blocks with no real trainable parameters + # (e.g. Mamba blocks during an attention-only bypass run) should produce + # NaN loss so they are excluded from statistics — identical to the + # optimizer=None path in the training loop. 
+
+            student_module_parameters = {
+                p_name: p
+                for p_name, p in student_stitched_module.named_parameters()
+                if p not in teacher_parameters and "dummy_param" not in p_name
+            }
+            student_module_buffers = {
+                p_name: p
+                for p_name, p in student_stitched_module.named_buffers()
+                if p not in teacher_buffers
+                and p_name not in _get_all_non_persistent_buffers_set(student_stitched_module)
+            }
+
+            trainable_params = {
+                p_name: p for p_name, p in student_module_parameters.items() if p.requires_grad
+            }
+            mprint(
+                f"Block {student_stitched_module_name}: {len(trainable_params)} trainable "
+                f"parameter tensors ({sum(p.numel() for p in trainable_params.values()):,} params)"
+            )
+
+            optimizer = (
+                AdamW(
+                    list(trainable_params.values()),
+                    lr=cfg.training.learning_rate,
+                    weight_decay=cfg.training.weight_decay,
+                    betas=(cfg.training.beta1, cfg.training.beta2),
+                    fused=True,
+                )
+                if len(trainable_params) > 0
+                else None
+            )
+
+            grad_scaler = (
+                None
+                if optimizer is None
+                else GradScaler(device=device.type, enabled=cfg.training.use_grad_scaling)
+            )
+
+            stitched_module_descriptors[student_stitched_module_name] = StitchedModuleDescriptor(
+                stitched_module=student_stitched_module,
+                owned_parameters=student_module_parameters,
+                owned_buffers=student_module_buffers,
+                optimizer=optimizer,
+                grad_scaler=grad_scaler,
+            )
+
+        teacher_stitched_module = teacher_stitcher.knot(ignore_extra_overrides=True)
+        teacher_val_stitched_module = teacher_val_stitcher.knot(ignore_extra_overrides=True)
+        student_val_stitched_module = student_val_stitcher.knot(ignore_extra_overrides=True)
+
+    return (
+        student_model,
+        teacher_stitched_module,
+        teacher_val_stitched_module,
+        student_val_stitched_module,
+        stitched_module_descriptors,
+        student_model_config,
+    )
+
diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py
new file mode 100644
index 0000000000..ffa84a58e7
--- /dev/null
+++ 
b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -0,0 +1,1151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bypass distillation training loop for per-block knowledge distillation. + +This module implements the blockwise local distillation (BLD) stage of the PUZZLE framework. +It trains alternative transformer block configurations using per-block knowledge distillation +from a teacher model, producing a library of "puzzle pieces" with different efficiency/performance +trade-offs. 
+""" + +import logging +import math +import os +import shutil +import sys +import time +import traceback +from collections import OrderedDict, defaultdict +from pathlib import Path +from statistics import mean +from typing import Optional, Type, cast + +import torch +import torch.distributed +import transformers +from omegaconf import DictConfig, OmegaConf +from torch.utils.data.dataloader import DataLoader +from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase + +import datasets +import modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory as stitched_model_factory_module +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.sewing_kit import InputArgs, StitchedModule +from modelopt.torch.puzzletron.sewing_kit.utils import fake_tensor +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.tools.robust_json import json_load +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch.puzzletron.utils.parsing import format_global_config, format_stitched_losses + +from .bypass_checkpoint_utils import ( + find_best_run_dir, + find_latest_run_dir, + load_local_state, + save_bypass_checkpoint, +) +from .bypass_utils import get_distributed_modules_ownership, set_experiment_dir, set_experiment_id +from .data_classes import GlobalRank, IterNum, IterStatistics, LocalTrainingStats, TimeToSaveSignal +from .stitched_model_factory import StitchedModuleDescriptor, StitchedModulesProcessOwnership + +time_start = time.time() + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def launch_bypass_distillation(hydra_cfg: DictConfig) -> None: + """Top-level entry point for bypass distillation stage. 
+ + Supports multiple bypass configurations via ``bypass.configs`` list. + Each entry overrides ``bypass.model.model_config_overrides`` and optionally + ``bypass.model_factory.keys_to_learn``, then runs a full bypass training. + + If ``bypass.configs`` is absent or empty, runs a single bypass training + with the settings already in ``bypass``. + + Args: + hydra_cfg: The full Hydra configuration with a 'bypass' section. + """ + configs_list = hydra_cfg.bypass.get("configs", None) + + # Auto-generate bypass configs from the pruning search space. + # bypass.auto_configs.ffn: true → one subblock_ffn config per size in + # pruning.intermediate_size_list (teacher size excluded) + # bypass.auto_configs.attn: true → one subblock_attention config per kv-head count derived + # from pruning.n_heads_in_group_list (teacher kv excluded) + # bypass.auto_configs.blk: true → cartesian product of ALL FFN sizes × ALL kv-head counts, + # trained as entire_block; only the single combination where + # both FFN and attention equal teacher values is skipped + if not configs_list: + auto_cfg = hydra_cfg.bypass.get("auto_configs", {}) or {} + + do_ffn = bool(auto_cfg.get("ffn", False)) + do_attn = bool(auto_cfg.get("attn", False)) + do_blk = bool(auto_cfg.get("blk", False)) + + if do_ffn or do_attn or do_blk: + from transformers import AutoConfig as HFAutoConfig + + teacher_hf_cfg = HFAutoConfig.from_pretrained(str(hydra_cfg.teacher_dir)) + teacher_intermediate_size = getattr(teacher_hf_cfg, "intermediate_size", None) + teacher_num_heads = getattr(teacher_hf_cfg, "num_attention_heads", None) + teacher_kv_heads = ( + getattr(teacher_hf_cfg, "num_key_value_heads", None) or teacher_num_heads + ) + + # ffn_sizes: pruning candidates with teacher size removed (for ffn/blk) + ffn_sizes: list = [] + # all_ffn_sizes: full list including teacher size (for blk cartesian) + all_ffn_sizes: list = [] + # kv_heads_list: pruning candidates with teacher kv removed (for attn/blk) + kv_heads_list: list = [] + 
# all_kv_heads: full list including teacher kv (for blk cartesian) + all_kv_heads: list = [] + + if do_ffn or do_blk: + all_ffn_sizes = list(hydra_cfg.pruning.intermediate_size_list) + ffn_sizes = [s for s in all_ffn_sizes if s != teacher_intermediate_size] + if len(ffn_sizes) < len(all_ffn_sizes): + mprint( + f"auto_configs: skipped teacher intermediate_size={teacher_intermediate_size}" + f" from ffn configs" + ) + + if do_attn or do_blk: + n_heads_in_group_list = list(hydra_cfg.pruning.get("n_heads_in_group_list") or []) + if n_heads_in_group_list and teacher_num_heads: + all_kv_heads = [teacher_num_heads // n for n in n_heads_in_group_list] + kv_heads_list = [kv for kv in all_kv_heads if kv != teacher_kv_heads] + if len(kv_heads_list) < len(all_kv_heads): + mprint( + f"auto_configs: skipped teacher kv_heads={teacher_kv_heads}" + f" from attn configs" + ) + elif do_attn or do_blk: + mprint( + "auto_configs: pruning.n_heads_in_group_list is not set; " + "skipping attn/blk auto_configs" + ) + + generated: list = [] + + if do_ffn: + for size in ffn_sizes: + generated.append( + { + "model_config_overrides": { + "ffn": [{"intermediate_size": size}], + "attention": [{"num_key_value_heads": None}], + }, + "keys_to_learn": "subblock_ffn", + } + ) + + if do_attn: + for kv in kv_heads_list: + generated.append( + { + "model_config_overrides": { + "ffn": [{"intermediate_size": None}], + "attention": [{"num_key_value_heads": kv}], + }, + "keys_to_learn": "subblock_attention", + } + ) + + if do_blk and all_ffn_sizes and all_kv_heads: + # Cartesian product of ALL sizes × ALL kv_heads. + # Only skip the one combination where both equal the teacher's values. 
+ for size in all_ffn_sizes: + for kv in all_kv_heads: + if size == teacher_intermediate_size and kv == teacher_kv_heads: + continue + generated.append( + { + "model_config_overrides": { + "ffn": [{"intermediate_size": size}], + "attention": [{"num_key_value_heads": kv}], + }, + "keys_to_learn": "entire_block", + } + ) + + if generated: + mprint( + f"auto_configs: generated {len(generated)} bypass configs " + f"(ffn={do_ffn}, attn={do_attn}, blk={do_blk})" + ) + configs_list = OmegaConf.create(generated) + + if not configs_list: + # Single config mode — run once with whatever is in bypass already + mprint("Starting bypass distillation (single config)") + run_bypassed_training(hydra_cfg) + mprint("Bypass distillation completed") + return + + mprint(f"Starting bypass distillation sweep ({len(configs_list)} configs)") + for i, override in enumerate(configs_list): + mprint(f"Bypass config {i + 1}/{len(configs_list)}: {override}") + + # Apply overrides for this run + if "model_config_overrides" in override: + hydra_cfg.bypass.model.model_config_overrides = override.model_config_overrides + if "keys_to_learn" in override: + hydra_cfg.bypass.model_factory.keys_to_learn = override.keys_to_learn + + # Reset per-run state so each config starts fresh + hydra_cfg.bypass.experiment_id = None + hydra_cfg.bypass.iter_num = 1 + hydra_cfg.bypass.step_num = 1 + hydra_cfg.bypass.token_count = 0 + hydra_cfg.bypass.best_val_loss = 1e9 + hydra_cfg.bypass.training.clipping_count = 0 + + # Resolve the experiment dir now (cheap) to check for a completion marker. + # set_experiment_id checks experiment_id is None before overwriting, so the + # call inside run_bypassed_training becomes a no-op. 
+ set_experiment_id(hydra_cfg) + set_experiment_dir(hydra_cfg) + training_marker = Path(hydra_cfg.bypass.experiment_dir) / "training.complete" + if training_marker.exists(): + mprint( + f"Bypass config {i + 1}/{len(configs_list)}: already completed " + f"({hydra_cfg.bypass.experiment_id}), skipping" + ) + continue + + run_bypassed_training(hydra_cfg) + mprint(f"Bypass config {i + 1}/{len(configs_list)} completed") + + mprint("Bypass distillation sweep completed") + + +def train( + cfg: DictConfig, + descriptor: Type[ModelDescriptor], + student_model: torch.nn.Module, + student_stitched_model: StitchedModule, + teacher_stitched_model: StitchedModule, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + stitched_modules_process_ownership: StitchedModulesProcessOwnership, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + student_model_config: PretrainedConfig, + skip_first_batches: int = 0, + tokenizer: Optional[PreTrainedTokenizerBase] = None, +) -> None: + """Inner training loop for bypass distillation.""" + device = torch.device(f"cuda:{dist.local_rank()}") + + dist.barrier() + + time_last_save = time_start + iter_t0 = time.time() + + resumed_iter_num = cfg.bypass.iter_num + mprint(f"resumed_iter_num: {resumed_iter_num}") + + # Number of total stitched modules + global_stitched_modules_count = len(stitched_modules_process_ownership) + # Number of stitched modules per process + num_stitched_modules_per_process = [ + sum(1 for x in stitched_modules_process_ownership if x == owner_rank) + for owner_rank in range(dist.size()) + ] + # Indices of stitched modules owned by the current process + owned_stitched_module_indices = [ + i for i, owner in enumerate(stitched_modules_process_ownership) if owner == dist.rank() + ] + mprint(f"{global_stitched_modules_count=}") + mprint(f"{num_stitched_modules_per_process=}") + dist.barrier() + + if dist.is_master(): + # {iter_num: {stitched_module_name: loss}} + stitched_losses_history 
= dict[IterNum, dict[str, float]]() + else: + stitched_losses_history = None + + # Save checkpoint before training starts + if cfg.bypass.save_checkpoint_before_training and not cfg.bypass.disable_checkpoint_save: + subdir_name = f"start-iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=Path(cfg.bypass.experiment_dir) / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + + # Track statistics for each iteration + iter_stats_history: dict[IterNum, IterStatistics] = {} + + # Create fake input ids for the teacher model + fake_input_ids = fake_tensor( + torch.ones( + size=(cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size), + dtype=torch.long, + device=device, + ) + ) + + # Get pipeline neighbor ranks + min_owned_index = min(owned_stitched_module_indices) + max_owned_index = max(owned_stitched_module_indices) + prev_rank: Optional[int] = ( + None if min_owned_index - 1 < 0 else stitched_modules_process_ownership[min_owned_index - 1] + ) + next_rank: Optional[int] = ( + None + if max_owned_index + 1 >= global_stitched_modules_count + else stitched_modules_process_ownership[max_owned_index + 1] + ) + + torch.cuda.synchronize() + + mprint( + f"Grad scaling status: {'enabled' if cfg.bypass.training.use_grad_scaling else 'disabled'}" + ) + + train_iterator = iter(train_dataloader) + + mprint("Waiting for everyone before training starts") + dist.barrier() + + step_to_save = None + # Track best loss value for each block + best_losses_by_name = dict[str, float]() + best_steps_by_name = dict[str, int]() + # Buffer variables + input_ids = torch.zeros(1, 1, dtype=torch.int64) + + aprint( + f"previous rank: {str(prev_rank):<5} next rank: {str(next_rank):<5} {owned_stitched_module_indices=}" + ) + + # Train loop start + while True: + time_now = time.time() + # Check if we've reached the maximum number of steps + 
if cfg.bypass.step_num >= cfg.bypass.training.max_steps: + _save_final_checkpoint(cfg, descriptor, student_model, stitched_module_descriptors) + break + + is_accumulating = cfg.bypass.iter_num % cfg.bypass.training.grad_accumulation_steps != 0 + # Determine and set the learning rate for this iteration + lr = ( + _get_lr(cfg, cfg.bypass.step_num) + if cfg.bypass.training.decay_lr + else cfg.bypass.training.learning_rate + ) + for stitched_module_descriptor in stitched_module_descriptors.values(): + optimizer = stitched_module_descriptor.optimizer + if optimizer is not None: + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + if dist.is_master(): + train_data = next(train_iterator) + input_ids = train_data["input_ids"] + input_ids = input_ids.to(device) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16), torch.no_grad(): + teacher_input_ids = input_ids if prev_rank is None else fake_input_ids + teacher_output = teacher_stitched_model({}, {}, teacher_input_ids) + + input_overrides = teacher_output.captured_inputs + output_overrides = teacher_output.captured_outputs + + del teacher_output + + input_overrides["teacher_inputs"] = InputArgs(fake_input_ids) + + iter_stitched_module_losses: dict[str, float] = {} + + for local_stitched_module_index, ( + stitched_module_name, + stitched_module_descriptor, + ) in enumerate(stitched_module_descriptors.items()): + stitched_module = stitched_module_descriptor.stitched_module + optimizer = stitched_module_descriptor.optimizer + grad_scaler = stitched_module_descriptor.grad_scaler + + if optimizer is not None: + assert grad_scaler is not None + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + stitched_module_output = stitched_module( + input_overrides=input_overrides, + output_overrides=output_overrides, + ) + stitched_module_loss = stitched_module_output.captured_outputs["loss"] + del stitched_module_output + grad_scaler.scale(stitched_module_loss).backward() + else: + 
stitched_module_loss = torch.full([1], fill_value=torch.nan, dtype=torch.float32)
+
+            iter_stitched_module_losses[stitched_module_name] = stitched_module_loss.to(
+                "cpu"
+            ).item()
+
+            del stitched_module_loss
+
+            if not is_accumulating:
+                if optimizer is not None:
+                    assert grad_scaler is not None
+                    # Unscale gradients before clipping so the thresholds apply to the
+                    # true gradient magnitudes (a no-op when grad scaling is disabled).
+                    grad_scaler.unscale_(optimizer)
+                    grad_clip = cfg.bypass.training.grad_clip
+                    if grad_clip is not None:
+                        if cfg.bypass.training.grad_clip_type == "norm":
+                            grad_norm = torch.nn.utils.clip_grad_norm_(
+                                parameters=stitched_module.parameters(),
+                                max_norm=grad_clip,
+                            )
+                            if grad_norm > grad_clip:
+                                cfg.bypass.training.clipping_count += 1
+                        elif cfg.bypass.training.grad_clip_type == "value":
+                            max_abs_grad_per_param = [
+                                p.grad.abs().max().item()
+                                for p in stitched_module.parameters()
+                                if p.grad is not None
+                            ]
+                            max_abs_grad = (
+                                max(max_abs_grad_per_param)
+                                if len(max_abs_grad_per_param) > 0
+                                else 0.0
+                            )
+                            if max_abs_grad > grad_clip:
+                                cfg.bypass.training.clipping_count += 1
+                            torch.nn.utils.clip_grad_value_(
+                                parameters=stitched_module.parameters(),
+                                clip_value=grad_clip,
+                            )
+                        else:
+                            raise RuntimeError(
+                                f"Invalid grad_clip_type: {cfg.bypass.training.grad_clip_type}"
+                            )
+
+                    grad_scaler.step(optimizer)
+                    grad_scaler.update()
+                    optimizer.zero_grad(set_to_none=True)
+
+        # Collect losses from all ranks using all_gather_object
+        local_training_stats = LocalTrainingStats(
+            iter_num=cfg.bypass.iter_num,
+            stitched_module_losses=iter_stitched_module_losses,
+        )
+        all_training_stats = [None] * dist.size()
+        torch.distributed.all_gather_object(all_training_stats, local_training_stats)
+
+        if dist.is_master():
+            if cfg.bypass.iter_num == resumed_iter_num:
+                mprint(f"Starting from iter {cfg.bypass.iter_num}")
+
+            # Merge all stats into the losses history
+            assert stitched_losses_history is not None
+            merged_losses: dict[str, float] = {}
+            for stats in all_training_stats:
+                if stats is not None:
+                    merged_losses.update(stats.stitched_module_losses)
+            stitched_losses_history[cfg.bypass.iter_num] = merged_losses
+
+            
cfg.bypass.token_count += cfg.bypass.training.tokens_per_iter + iter_t1 = time.time() + iter_duration = iter_t1 - iter_t0 + iter_stats_history[cfg.bypass.iter_num] = IterStatistics( + token_count=cfg.bypass.token_count, + iter_duration=iter_duration, + step_num=cfg.bypass.step_num, + lr=lr, + clipping_count=cfg.bypass.training.clipping_count, + ) + iter_t0 = iter_t1 + + # Time-based save signal (broadcast from master) + save_signal = [step_to_save] + if dist.is_master(): + if cfg.bypass.model.model_overrides.save_interval_seconds is not None: + time_now = time.time() + if ( + time_now - time_last_save + >= cfg.bypass.model.model_overrides.save_interval_seconds + ): + mprint( + f"Time to save! {cfg.bypass.model.model_overrides.save_interval_seconds=}, " + f"{time_last_save=}, {time_now=}" + ) + step_to_save = cfg.bypass.step_num + 5 + save_signal = [step_to_save] + time_last_save = time_now + + torch.distributed.broadcast_object_list(save_signal, src=0) + step_to_save = save_signal[0] + + # Logging + if dist.is_master(): + _log_training_stats( + cfg, + stitched_losses_history, + iter_stats_history, + best_losses_by_name, + best_steps_by_name, + ) + + # Validation + if ( + not is_accumulating + and (cfg.bypass.step_num % cfg.bypass.training.eval_interval) == 0 + and val_dataloader is not None + ): + _run_validation( + cfg, descriptor, student_model, student_stitched_model, + stitched_module_descriptors, val_dataloader, device, + ) + + # Checkpoint saving (step-based or time-based) + _save_interval_checkpoint( + cfg, descriptor, student_model, stitched_module_descriptors, + step_to_save, is_accumulating, + ) + + cfg.bypass.iter_num += 1 + if not is_accumulating: + cfg.bypass.step_num += 1 + + mprint("Finished successfully!") + + +def _save_final_checkpoint( + cfg: DictConfig, + descriptor: Type[ModelDescriptor], + student_model: torch.nn.Module, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], +) -> None: + """Save the final training 
checkpoint and delete old iter-* checkpoints if configured.""" + if not ( + cfg.bypass.model.model_overrides.save_checkpoint_when_done + and not cfg.bypass.disable_checkpoint_save + ): + return + mprint("Saving final checkpoint before training completion") + subdir_name = f"final-iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=Path(cfg.bypass.experiment_dir) / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + if cfg.bypass.model.model_overrides.delete_old_checkpoints and dist.is_master(): + for old_ckpt_path in Path(cfg.bypass.experiment_dir).glob("iter-*"): + if old_ckpt_path.name != subdir_name: + shutil.rmtree(str(old_ckpt_path)) + + +def _log_training_stats( + cfg: DictConfig, + stitched_losses_history: dict, + iter_stats_history: dict, + best_losses_by_name: dict, + best_steps_by_name: dict, +) -> None: + """Log training statistics for the current log-interval chunk (master only). + + Processes ``stitched_losses_history`` in chunks of ``log_interval`` iters, prints + per-block loss tables and speed metrics, optionally logs to W&B, then removes + processed entries from both history dicts. Mutates ``best_losses_by_name`` and + ``best_steps_by_name`` in place. 
+ """ + assert stitched_losses_history is not None + while len(stitched_losses_history) >= cfg.bypass.training.log_interval: + lowest_iter = next(iter(stitched_losses_history.keys())) + + log_chunk = { + it: losses + for it, losses in stitched_losses_history.items() + if it - lowest_iter < cfg.bypass.training.log_interval + } + if len(log_chunk) < cfg.bypass.training.log_interval: + break + + highest_iter = list(log_chunk.keys())[-1] + highest_iter_stats = iter_stats_history[highest_iter] + + losses_by_name = defaultdict[str, list[float]](lambda: []) + for losses in log_chunk.values(): + for name, loss in losses.items(): + losses_by_name[name].append(loss) + + losses_by_name_avg = {name: mean(losses) for name, losses in losses_by_name.items()} + + for name, current_loss in losses_by_name_avg.items(): + if name not in best_losses_by_name or current_loss < best_losses_by_name[name]: + best_losses_by_name[name] = current_loss + best_steps_by_name[name] = highest_iter + + chunk_iter_durations = [iter_stats_history[it].iter_duration for it in log_chunk.keys()] + avg_chunk_iter_duration = mean(chunk_iter_durations) + avg_token_speed = cfg.bypass.training.tokens_per_iter / avg_chunk_iter_duration + mprint( + f"iter {highest_iter}/{cfg.bypass.training.max_steps:,}:" + f" avg_iter_time={avg_chunk_iter_duration * 1000:.2f}ms" + f" avg_token_speed={avg_token_speed:,.0f}[tok/s]" + ) + mprint( + format_stitched_losses( + losses_dict=losses_by_name_avg, + best_steps_dict=best_steps_by_name, + best_values_dict=best_losses_by_name, + step_number=highest_iter, + title="Stitched Module Losses", + ) + ) + + if cfg.bypass.wandb_log: + try: + import wandb + + wandb.log( + { + "iter": highest_iter, + "step": highest_iter_stats.step_num, + "token_count": highest_iter_stats.token_count, + "token_speed": avg_token_speed, + "lr": highest_iter_stats.lr, + "grad_clipping": highest_iter_stats.clipping_count, + }, + step=highest_iter, + ) + except ImportError: + pass + + for it in 
log_chunk.keys(): + del iter_stats_history[it] + del stitched_losses_history[it] + + +def _run_validation( + cfg: DictConfig, + descriptor: Type[ModelDescriptor], + student_model: torch.nn.Module, + student_stitched_model: StitchedModule, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + val_dataloader: DataLoader, + device: torch.device, +) -> None: + """Run validation, broadcast val_loss, and save best checkpoint if loss improved.""" + from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( + calculate_losses_pipeline, + ) + + losses, _ = calculate_losses_pipeline( + stitched_model=student_stitched_model, + dataloader=val_dataloader, + descriptor=descriptor, + ) + + val_loss = float("inf") + if losses is not None and "lm_loss" in losses: + val_loss = losses["lm_loss"]["avg"] + mprint(f"Validation loss at iter {cfg.bypass.iter_num}: {val_loss:.4f}") + + # Broadcast val_loss so all ranks agree on checkpoint decisions + val_loss_tensor = torch.tensor([val_loss], device=device) + torch.distributed.broadcast(val_loss_tensor, src=dist.size() - 1) + val_loss = val_loss_tensor.item() + + if val_loss < cfg.bypass.best_val_loss: + cfg.bypass.best_val_loss = val_loss + if not cfg.bypass.disable_checkpoint_save and cfg.bypass.save_best_ckpt: + subdir_name = f"best-iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=Path(cfg.bypass.experiment_dir) / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + if cfg.bypass.kill_after_first_save: + raise RuntimeError("Done saving checkpoint, kill_after_first_save=True") + + +def _save_interval_checkpoint( + cfg: DictConfig, + descriptor: Type[ModelDescriptor], + student_model: torch.nn.Module, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + step_to_save: Optional[int], + is_accumulating: bool, +) -> None: + 
"""Save a step-interval or time-based checkpoint if the current step qualifies.""" + if is_accumulating: + return + if not ( + (cfg.bypass.step_num % cfg.bypass.model.model_overrides.save_interval) == 0 + or step_to_save == cfg.bypass.step_num + or ( + cfg.bypass.model.model_overrides.save_checkpoint_when_done + and cfg.bypass.step_num >= cfg.bypass.training.max_steps + ) + ): + return + if cfg.bypass.disable_checkpoint_save: + return + + if (cfg.bypass.step_num % cfg.bypass.model.model_overrides.save_interval) == 0: + mprint("Saving step-interval checkpoint") + elif step_to_save == cfg.bypass.step_num: + mprint("Saving time-based checkpoint") + elif ( + cfg.bypass.model.model_overrides.save_checkpoint_when_done + and cfg.bypass.step_num >= cfg.bypass.training.max_steps - 100 + ): + mprint("Saving final checkpoint") + + subdir_name = f"iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=Path(cfg.bypass.experiment_dir) / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + + if cfg.bypass.kill_after_first_save: + dist.barrier() + raise RuntimeError("Done saving checkpoint, kill_after_first_save=True") + + if cfg.bypass.model.model_overrides.delete_old_checkpoints and dist.is_master(): + for old_ckpt_path in Path(cfg.bypass.experiment_dir).glob("iter-*"): + if old_ckpt_path.name != subdir_name: + shutil.rmtree(str(old_ckpt_path)) + + +# Learning rate decay scheduler (cosine with warmup) +def _get_lr(cfg: DictConfig, step: int) -> float: + # 1) linear warmup for warmup_steps steps + if step <= cfg.bypass.training.warmup_steps: + lr = cfg.bypass.training.learning_rate * step / cfg.bypass.training.warmup_steps + # 2) if step > lr_decay_steps, return min learning rate + elif step > cfg.bypass.training.lr_decay_steps: + lr = cfg.bypass.training.min_lr + # 3) in between, use cosine decay down to min learning rate + 
+    else:
+        decay_ratio = (step - cfg.bypass.training.warmup_steps - 1) / (
+            cfg.bypass.training.lr_decay_steps - cfg.bypass.training.warmup_steps
+        )
+        assert 0 <= decay_ratio <= 1
+        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
+        lr = cfg.bypass.training.min_lr + coeff * (
+            cfg.bypass.training.learning_rate - cfg.bypass.training.min_lr
+        )
+
+    return lr
+
+
+def run_bypassed_training(cfg: DictConfig):
+    """Setup and orchestrate bypass distillation training."""
+    logging.basicConfig(
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.WARN
+    )
+
+    # Suppress debug messages from HuggingFace libraries
+    datasets.utils.logging.set_verbosity_error()
+    transformers.utils.logging.set_verbosity_error()
+
+    device = torch.device(f"cuda:{dist.local_rank()}")
+
+    descriptor = ModelDescriptorFactory.get(cfg.descriptor)
+    trust_remote_code = descriptor.requires_trust_remote_code()
+    teacher_model_config = load_model_config(cfg.teacher_dir, trust_remote_code=trust_remote_code)
+
+    try:
+        mprint("Waiting for distributed setup...")
+        dist.barrier()
+
+        if cfg.bypass.disable_initial_validate:
+            cfg.bypass.validate_teacher_model = False
+            cfg.bypass.validate_student_model = False
+
+        if cfg.bypass.teacher_model_load_on_cpu:
+            assert not cfg.bypass.validate_teacher_model, (
+                "Teacher model validation is too slow on CPU"
+            )
+
+        num_hidden_layers = descriptor.get_language_model_config(
+            teacher_model_config
+        ).num_hidden_layers
+
+        model_blocks_process_ownership = get_distributed_modules_ownership(
+            module_count=num_hidden_layers,
+            world_size=dist.size(),
+        )
+
+        owned_block_indexes = set(
+            block_index
+            for block_index, owner_rank in enumerate(model_blocks_process_ownership)
+            if owner_rank == dist.rank()
+        )
+
+        cfg.teacher_dir = str(Path(cfg.teacher_dir).expanduser())
+        teacher_model_config = load_model_config(
+            cfg.teacher_dir,
+            model_config_overrides={"use_cache": False},
+            trust_remote_code=trust_remote_code,
+        )
+
+        student_model = None
+        if cfg.bypass.init_checkpoint_path is not None:
+            mprint(f"Loading student model from {cfg.bypass.init_checkpoint_path}")
+            student_model = load_and_shard_model(
+                descriptor=descriptor,
+                checkpoint_path=cfg.bypass.init_checkpoint_path,
+                owned_block_indexes=owned_block_indexes,
+            )
+
+        cfg.bypass.training.min_lr = (
+            cfg.bypass.training.learning_rate * cfg.bypass.training.min_lr_factor
+        )
+        cfg.bypass.training.batch_size_per_iter = cfg.bypass.training.micro_batch_size
+        cfg.bypass.training.tokens_per_iter = (
+            cfg.bypass.data.block_size * cfg.bypass.training.batch_size_per_iter
+        )
+        cfg.bypass.training.max_steps = math.ceil(
+            cfg.bypass.training.training_tokens / cfg.bypass.training.tokens_per_iter
+        )
+        cfg.bypass.training.max_iters = (
+            cfg.bypass.training.max_steps * cfg.bypass.training.grad_accumulation_steps
+        )
+        cfg.bypass.training.max_token_count = (
+            cfg.bypass.training.max_iters * cfg.bypass.training.tokens_per_iter
+        )
+        cfg.bypass.training.lr_decay_steps = cfg.bypass.training.max_steps
+
+        if cfg.bypass.training.val_micro_batch_size is None:
+            cfg.bypass.training.val_micro_batch_size = cfg.bypass.training.micro_batch_size
+
+        if cfg.bypass.training.warmup_steps is None:
+            cfg.bypass.training.warmup_steps = 0
+
+        mprint(f"\n{format_global_config(cfg.bypass, 'Bypass Configurations')}")
+        mprint(f"Max token count: {cfg.bypass.training.max_token_count:,}")
+
+        seed = cfg.bypass.seed
+        torch.manual_seed(seed)
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            cfg.teacher_dir,
+            trust_remote_code=trust_remote_code,
+            token=True,
+        )
+
+        assert teacher_model_config is not None
+
+        mprint(f"Load and shard model with: {owned_block_indexes=}, {cfg.teacher_dir=}")
+        teacher_model = load_and_shard_model(
+            descriptor=descriptor,
+            checkpoint_path=cfg.teacher_dir,
+            owned_block_indexes=owned_block_indexes,
+            model_config=teacher_model_config,
+        )
+
+        teacher_model.requires_grad_(False)
+
+        # Create dataloaders
+        from modelopt.torch.puzzletron.utils.data.dataloaders import (
+            create_train_dataloader,
+            create_validation_dataloader,
+            load_from_disk_fn,
+            load_streaming_fn,
+        )
+
+        if cfg.bypass.data.eval_samples_per_process is not None:
+            max_eval_samples = cfg.bypass.data.eval_samples_per_process * dist.size()
+        else:
+            max_eval_samples = cfg.bypass.data.max_eval_samples
+
+        load_dataset_fn = (
+            load_streaming_fn if not cfg.bypass.data.load_from_disk else load_from_disk_fn
+        )
+
+        train_dataloader = create_train_dataloader(
+            seed=seed,
+            tokenizer=tokenizer,
+            block_size=cfg.bypass.data.block_size,
+            dataset_path=cfg.dataset_path,
+            content_field=cfg.bypass.data.data_column,
+            fim_rate=cfg.bypass.data.fim_rate,
+            fim_spm_rate=cfg.bypass.data.fim_spm_rate,
+            micro_batch_size=cfg.bypass.training.micro_batch_size,
+            load_dataset_fn=load_dataset_fn,
+            keep_in_memory=cfg.bypass.data.keep_in_memory,
+            source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()),
+            bos_rate=cfg.bypass.data.bos_rate,
+            shuffle_seed=cfg.bypass.data.shuffle_train_data_seed,
+        )
+
+        val_dataloader = None
+        if not cfg.bypass.disable_validation:
+            val_dataloader = create_validation_dataloader(
+                accelerator=None,
+                seed=seed,
+                tokenizer=tokenizer,
+                block_size=cfg.bypass.data.block_size,
+                dataset=cfg.dataset_path,
+                content_field=cfg.bypass.data.data_column,
+                fim_rate=cfg.bypass.data.fim_rate,
+                fim_spm_rate=cfg.bypass.data.fim_spm_rate,
+                micro_batch_size=cfg.bypass.training.val_micro_batch_size,
+                eval_samples=max_eval_samples,
+                load_dataset_fn=load_dataset_fn,
+                dataset_name=cfg.bypass.data.val_dataset_name,
+                keep_in_memory=cfg.bypass.data.keep_in_memory,
+                source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()),
+                bos_rate=cfg.bypass.data.bos_rate,
+            )
+
+        # Set ID from experiment configuration
+        set_experiment_id(cfg)
+        # Set directory for experiment ID
+        set_experiment_dir(cfg)
+
+        dist.barrier()
+
+        with torch.device(device):
+            stitched_model_factory_fn = cast(
+                stitched_model_factory_module.StitchedModelFactoryFn,
+                getattr(stitched_model_factory_module, cfg.bypass.model_factory.factory),
+            )
+            (
+                student_model,
+                teacher_stitched_model,
+                teacher_val_stitched_module,
+                student_val_stitched_model,
+                stitched_module_descriptors,
+                student_model_config,
+            ) = stitched_model_factory_fn(
+                teacher_model=teacher_model,
+                descriptor=descriptor,
+                cfg=cfg.bypass,
+                model_blocks_process_ownership=model_blocks_process_ownership,
+                student_model=student_model,
+            )
+
+        # Check whether to resume from checkpoint
+        resume_checkpoint_path = None
+        if cfg.bypass.resume_checkpoint_path is not None:
+            resume_checkpoint_path = cfg.bypass.resume_checkpoint_path
+        elif cfg.bypass.find_last_ckpt_for_resume:
+            _ckpt_dir = find_latest_run_dir(run_parent_dir=cfg.bypass.experiment_dir)
+            if _ckpt_dir is None:
+                mprint("Couldn't find any run dir for resume, assuming this is the first job")
+            else:
+                mprint(
+                    f"`cfg.bypass.find_last_ckpt_for_resume` is True. "
+                    f"Auto-found a checkpoint to resume: `{_ckpt_dir}`"
+                )
+                resume_checkpoint_path = _ckpt_dir
+
+        if resume_checkpoint_path:
+            load_local_state(
+                stitched_module_descriptors=stitched_module_descriptors,
+                checkpoint_path=resume_checkpoint_path,
+            )
+
+            # Load resume ckpt bypass configs and extract resume iter_num
+            resume_cfg = DictConfig(json_load(Path(resume_checkpoint_path) / "args.json"))
+
+            # Resume stats
+            cfg.bypass.iter_num = resume_cfg.iter_num
+            cfg.bypass.token_count = resume_cfg.token_count
+            cfg.bypass.step_num = resume_cfg.step_num
+            cfg.bypass.best_val_loss = resume_cfg.best_val_loss
+            cfg.bypass.training.clipping_count = resume_cfg.training.clipping_count
+            mprint(f"Resume from iter_num: {cfg.bypass.iter_num}")
+
+            # Only copy wandb.run_id if it exists in resume config
+            if hasattr(resume_cfg, "wandb") and hasattr(resume_cfg.wandb, "run_id"):
+                cfg.bypass.wandb.run_id = resume_cfg.wandb.run_id
+
+            cfg.bypass.save_checkpoint_before_training = False
+            cfg.bypass.validate_teacher_model = False
+            cfg.bypass.validate_student_model = False
+
+        cfg.bypass.resume_checkpoint_path = resume_checkpoint_path
+
+        # Initialize Weights and Biases
+        if cfg.bypass.wandb_log:
+            try:
+                import wandb
+
+                wandb.init(
+                    project=cfg.bypass.wandb.project,
+                    entity=cfg.bypass.wandb.entity,
+                    config=dict(cfg.bypass),
+                )
+            except ImportError:
+                mprint("wandb not installed, disabling wandb logging")
+                cfg.bypass.wandb_log = False
+        else:
+            mprint("Weights & Biases logging disabled (wandb_log=False)")
+
+        if cfg.bypass.validate_teacher_model and val_dataloader is not None:
+            from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import (
+                calculate_losses_pipeline,
+            )
+
+            mprint("Evaluating teacher model:")
+            losses, _ = calculate_losses_pipeline(
+                stitched_model=teacher_val_stitched_module,
+                dataloader=val_dataloader,
+                descriptor=descriptor,
+            )
+            if losses is not None:
+                mprint(f"Teacher validation losses: {losses}")
+            mprint("Evaluated teacher model")
+
+        torch.cuda.empty_cache()
+        dist.barrier()
+
+        parameter_count = sum(p.numel() for p in student_model.parameters())
+        aprint(f"Model parameter count: {parameter_count:,}")
+        cfg.bypass.parameter_count = parameter_count
+
+        dist.barrier()
+        mprint("Performing dummy runs on stitched modules:")
+        torch.cuda.synchronize()
+        with (
+            torch.no_grad(),
+            torch.autocast(device_type="cuda", dtype=torch.bfloat16),
+            torch.device(device),
+        ):
+            input_ids = torch.ones(
+                (cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size),
+                dtype=torch.long,
+            )
+            dummy_fake_input_ids = fake_tensor(input_ids)
+            mprint(f"Dummy runs on stitched modules with shape: {dummy_fake_input_ids.shape=}")
+            teacher_output = teacher_stitched_model({}, {}, input_ids)
+            for stitched_module_descriptor in stitched_module_descriptors.values():
+                stitched_module = stitched_module_descriptor.stitched_module
+                stitched_module(
+                    input_overrides={
+                        **teacher_output.captured_inputs,
+                        "teacher_inputs": InputArgs(dummy_fake_input_ids),
+                    },
+                    output_overrides=teacher_output.captured_outputs,
+                )
+                for name, param in stitched_module.named_parameters(recurse=True):
+                    if "iter_num" in name:
+                        param.data = torch.zeros_like(param.data)
+                del name, param
+            del input_ids, dummy_fake_input_ids, teacher_output
+        torch.cuda.synchronize()
+        dist.barrier()
+
+        del teacher_model
+
+        if cfg.bypass.validate_student_model and val_dataloader is not None:
+            from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import (
+                calculate_losses_pipeline,
+            )
+
+            mprint("Validating model before training:")
+            losses, _ = calculate_losses_pipeline(
+                stitched_model=student_val_stitched_model,
+                dataloader=val_dataloader,
+                descriptor=descriptor,
+            )
+            if losses is not None:
+                mprint(f"Student validation losses: {losses}")
+
+        dist.barrier()
+        torch.cuda.empty_cache()
+        dist.barrier()
+
+        train(
+            cfg=cfg,
+            descriptor=descriptor,
+            student_model=student_model,
+            student_stitched_model=student_val_stitched_model,
+            teacher_stitched_model=teacher_stitched_model,
+            stitched_module_descriptors=stitched_module_descriptors,
+            stitched_modules_process_ownership=model_blocks_process_ownership,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            student_model_config=student_model_config,
+            skip_first_batches=cfg.bypass.training.skip_first_batches,
+            tokenizer=tokenizer,
+        )
+
+        aprint("Finished training successfully!")
+        dist.barrier()
+
+    except Exception as e:
+        print(traceback.format_exc(), file=sys.stderr)
+        raise
+
+    dist.barrier()
+    if dist.is_master():
+        mprint("Realizing bypass checkpoints")
+        realize_bypass_checkpoints(cfg)
+        training_marker = Path(cfg.bypass.experiment_dir) / "training.complete"
+        training_marker.touch()
+
+
+def realize_bypass_checkpoints(cfg: DictConfig):
+    """Create symlinks from bypass checkpoint directories to the ckpts directory.
+
+    When ``cfg.bypass.realize_best_or_latest == "best"``, the symlink points to the
+    highest-iteration ``best-iter-*`` checkpoint (saved when validation loss improved).
+    Falls back to ``latest`` if no best checkpoint exists (e.g. validation was disabled).
+    When set to ``"latest"``, always uses the ``latest`` symlink (last saved checkpoint).
+    """
+    experiment_dir = Path(cfg.bypass.experiment_dir)
+    realize_mode = cfg.bypass.get("realize_best_or_latest", "latest")
+
+    checkpoint_dir_str = None
+    if realize_mode == "best":
+        checkpoint_dir_str = find_best_run_dir(experiment_dir)
+        if checkpoint_dir_str is None:
+            mprint(
+                "realize_best_or_latest='best' but no best checkpoint found "
+                "(validation may be disabled); falling back to latest"
+            )
+    if checkpoint_dir_str is None:
+        checkpoint_dir = experiment_dir / "latest"
+    else:
+        checkpoint_dir = Path(checkpoint_dir_str)
+
+    if not checkpoint_dir.exists():
+        mprint(f"Could not find checkpoint directory: {checkpoint_dir}")
+        return
+
+    ckpts_dir = Path(cfg.puzzle_dir) / "ckpts"
+    ckpts_dir.mkdir(parents=True, exist_ok=True)
+
+    symlink_name = ckpts_dir / cfg.bypass.experiment_id
+    if symlink_name.exists() or symlink_name.is_symlink():
+        symlink_name.unlink()
+
+    symlink_name.symlink_to(checkpoint_dir, target_is_directory=True)
+    mprint(f"Created symlink: {symlink_name} -> {checkpoint_dir}")
diff --git a/modelopt/torch/puzzletron/dataset/prepare_dataset.py b/modelopt/torch/puzzletron/dataset/prepare_dataset.py
index 6f1749697c..5c18a03dd1 100644
--- a/modelopt/torch/puzzletron/dataset/prepare_dataset.py
+++ b/modelopt/torch/puzzletron/dataset/prepare_dataset.py
@@ -15,10 +15,10 @@

 import os

-import datasets
 import fire
 import numpy as np
+import datasets

 from modelopt.torch.puzzletron.tools.logger import mprint

diff --git a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
index e5025dea7d..3390c551a2 100644
--- a/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
+++ b/modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
@@ -27,6 +27,7 @@
 import torch
 from torch import nn

+import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation
 import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models
 import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts
 import modelopt.torch.puzzletron.scoring.scoring as scoring
@@ -92,10 +93,28 @@ class PuzzletronConfig(ModeloptBaseConfig):
     )


+def _total_steps(hydra_cfg) -> int:
+    """Return total pipeline step count: 9 with bypass, 8 without.
+
+    Steps:
+      1 starting (main.py)
+      2 convert model
+      3 score pruning activations
+      4 prune checkpoints
+      [5 bypass distillation — only when bypass is configured]
+      5/6 build replacement library & subblock stats
+      6/7 calculate one block scores
+      7/8 MIP and realize models
+      8/9 completed (main.py)
+    """
+    return 9 if hydra_cfg.get("bypass", None) is not None else 8
+
+
 def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType:
     """1. Convert the model from HF format to AnyModel format.
     2. Score the pruning activations.
-    3. Prune the model and save pruned checkpoints
+    3. Prune the model and save pruned checkpoints.
+    4. (Optional) Run bypass distillation.

     The output of this step will be used by mnt.search() to perform the NAS search.
     """
@@ -117,37 +136,100 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv
     # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class)
     hydra_cfg = hydra.utils.instantiate(hydra_cfg)

-    # Convert HuggingFace model to Puzzletron heterogeneous format (generic, uses descriptor from config)
-    if dist.is_master():
-        mprint(
-            "Puzzletron Progress 2/8: converting model to Puzzletron heterogeneous format (single-gpu)"
-        )
-    hf_ckpt_teacher_dir = "ckpts/teacher"  # TODO: make it configurable
+    has_bypass = hydra_cfg.get("bypass", None) is not None
+    n = _total_steps(hydra_cfg)

-    # Get descriptor and converter from the hydra config
-    descriptor_name = hydra_cfg.descriptor
-    descriptor = ModelDescriptorFactory.get(descriptor_name)
-    converter = ConverterFactory.get(descriptor_name)
+    puzzle_dir = Path(config.puzzle_dir)

-    converter.convert(
-        descriptor=descriptor,
-        input_dir=Path(config.input_model_path),
-        output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir,
-    )
+    # Step 2: Convert HuggingFace model to Puzzletron heterogeneous format
+    hf_ckpt_teacher_dir = "ckpts/teacher"  # TODO: make it configurable
+    teacher_dir = puzzle_dir / hf_ckpt_teacher_dir
+    convert_marker = puzzle_dir / "convert.complete"
+    if dist.is_master():
+        # Check durable marker first; fall back to artifact existence for backward compat
+        already_done = convert_marker.exists() or (teacher_dir / "config.json").exists()
+        if already_done:
+            mprint(
+                f"Puzzletron Progress 2/{n}: teacher checkpoint already exists, skipping conversion"
+            )
+        else:
+            mprint(
+                f"Puzzletron Progress 2/{n}: converting model to Puzzletron heterogeneous format (single-gpu)"
+            )
+
+            # Get descriptor and converter from the hydra config
+            descriptor_name = hydra_cfg.descriptor
+            descriptor = ModelDescriptorFactory.get(descriptor_name)
+            converter = ConverterFactory.get(descriptor_name)
+
+            # Auto-download from HuggingFace if path doesn't exist locally.
+            # input_model_path is only used on rank 0 (conversion is single-process);
+            # other ranks wait at the dist.barrier() below and never read this variable.
+            input_model_path = config.input_model_path
+            if not Path(input_model_path).exists():
+                from huggingface_hub import snapshot_download
+
+                if input_model_path.startswith("https://huggingface.co/"):
+                    model_id = "/".join(input_model_path.rstrip("/").split("/")[-2:])
+                else:
+                    model_id = input_model_path  # assume HF model ID like "org/model-name"
+                mprint(
+                    f"Downloading HuggingFace model '{model_id}' — this may take several minutes "
+                    f"for large models. Other ranks are waiting at a barrier."
+                )
+                input_model_path = snapshot_download(repo_id=model_id)
+                mprint(f"Downloaded to: {input_model_path}")
+
+            converter.convert(
+                descriptor=descriptor,
+                input_dir=Path(input_model_path),
+                output_dir=teacher_dir,
+            )
+            convert_marker.touch()

     dist.barrier()

-    # Score_pruning_activations (distributed processing)
-    mprint("Puzzletron Progress 3/8: scoring pruning activations (multi-gpu)")
-    score_pruning_activations.launch_score_activations(hydra_cfg)
+    # Step 3: Score pruning activations (distributed processing)
+    activations_log_dir = Path(hydra_cfg.pruning.activations_log_dir)
+    score_marker = puzzle_dir / "score_activations.complete"
+    # Check durable marker first; fall back to artifact existence for backward compat
+    already_scored = score_marker.exists() or (
+        activations_log_dir.exists() and any(activations_log_dir.glob("rank_*.pth"))
+    )
+    if already_scored:
+        mprint(
+            f"Puzzletron Progress 3/{n}: pruning activation scores already exist, skipping scoring"
+        )
+        dist.barrier()
+    else:
+        mprint(f"Puzzletron Progress 3/{n}: scoring pruning activations (multi-gpu)")
+        score_pruning_activations.launch_score_activations(hydra_cfg)
+        if dist.is_master():
+            score_marker.touch()
+        dist.barrier()

-    # Prune the model and save pruned checkpoints
+    # Step 4: Prune the model and save pruned checkpoints (single process)
+    pruned_ckpts_dir = Path(hydra_cfg.pruning.pruned_ckpts_output_dir)
+    prune_marker = puzzle_dir / "prune.complete"
     if dist.is_master():
-        mprint(
-            "Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)"
+        # Check durable marker first; fall back to artifact existence for backward compat
+        already_pruned = prune_marker.exists() or (
+            pruned_ckpts_dir.exists() and any(pruned_ckpts_dir.iterdir())
         )
-        pruning_ckpts.launch_prune_ckpt(hydra_cfg)
+        if already_pruned:
+            mprint(f"Puzzletron Progress 4/{n}: pruned checkpoints already exist, skipping pruning")
+        else:
+            mprint(
+                f"Puzzletron Progress 4/{n}: pruning the model and saving pruned checkpoints (single-gpu)"
+            )
+            pruning_ckpts.launch_prune_ckpt(hydra_cfg)
+        prune_marker.touch()

     dist.barrier()

+    # Step 5: Bypass distillation (optional, distributed processing)
+    if has_bypass:
+        mprint(f"Puzzletron Progress 5/{n}: running bypass distillation (multi-gpu)")
+        bypass_distillation.launch_bypass_distillation(hydra_cfg)
+
     return model, {}

@@ -218,18 +300,40 @@ def run_search(self) -> None:
         # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class)
         hydra_cfg = hydra.utils.instantiate(hydra_cfg)

-        # Build_library_and_stats (single process)
+        has_bypass = hydra_cfg.get("bypass", None) is not None
+        n = _total_steps(hydra_cfg)
+        # With bypass: library=6, scoring=7, mip=8 (out of 9)
+        # Without bypass: library=5, scoring=6, mip=7 (out of 8)
+        library_step = 6 if has_bypass else 5
+        scoring_step = 7 if has_bypass else 6
+        mip_step = 8 if has_bypass else 7
+
+        # Build replacement library and subblock statistics (single process)
+        puzzle_dir = Path(self.model.puzzle_dir)
+        replacement_library_path = puzzle_dir / "replacement_library.json"
+        subblock_stats_path = puzzle_dir / hydra_cfg.calc_subblock_stats.subblock_stats_filename
+        library_marker = puzzle_dir / "library.complete"
         if dist.is_master():
-            mprint(
-                "Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu)"
+            # Check durable marker first; fall back to artifact existence for backward compat
+            already_built = library_marker.exists() or (
+                replacement_library_path.exists() and subblock_stats_path.exists()
             )
-            build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
+            if already_built:
+                mprint(
+                    f"Puzzletron Progress {library_step}/{n}: replacement library and subblock stats already exist, skipping"
+                )
+            else:
+                mprint(
+                    f"Puzzletron Progress {library_step}/{n}: building replacement library and subblock statistics (single-gpu)"
+                )
+                build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
+            library_marker.touch()

         dist.barrier()

-        # Calc_one_block_scores (distributed processing)
-        mprint("Puzzletron Progress 6/8: calculating one block scores (multi-gpu)")
+        # Calculate one block scores (distributed processing)
+        mprint(f"Puzzletron Progress {scoring_step}/{n}: calculating one block scores (multi-gpu)")
         scoring.launch_scoring(hydra_cfg)

-        # mip_and_realize_models (distributed processing)
-        mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
+        # MIP search and realize models (distributed processing)
+        mprint(f"Puzzletron Progress {mip_step}/{n}: running MIP and realizing models (multi-gpu)")
         mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py
index 82ba675c94..44b388a92e 100644
--- a/modelopt/torch/puzzletron/pruning/pruning_utils.py
+++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py
@@ -44,6 +44,7 @@ class MlpInitMode(Enum):
     PruneByActivationsLog = "PruneByActivationsLog"
     ExpertRemoval = "ExpertRemoval"
     ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN"
+    MoEChannelPruning = "MoEChannelPruning"


 class LinearInitMode(Enum):
@@ -136,6 +137,7 @@ def _init_mlp_module(
     elif mlp_init_mode in (
         MlpInitMode.Truncate,
         MlpInitMode.PruneByActivationsLog,
+        MlpInitMode.MoEChannelPruning,
     ):
         assert original_intermediate_size >= new_intermediate_size, (
             f"({original_intermediate_size=}) < ({new_intermediate_size=}), can't be truncated."
@@ -149,7 +151,7 @@ def _init_mlp_module(
         )
         mlp_module_weight = truncated_weight

-    elif mlp_init_mode == MlpInitMode.PruneByActivationsLog:
+    elif mlp_init_mode in (MlpInitMode.PruneByActivationsLog, MlpInitMode.MoEChannelPruning):
         if pruned_filters is None:
             filter_importance = _load_activations_log(
                 mlp_init_config, module_name=f"{mlp_prefix}.down_proj"
diff --git a/modelopt/torch/puzzletron/puzzletron.py b/modelopt/torch/puzzletron/puzzletron.py
index 5a1484e07a..457fef6df4 100644
--- a/modelopt/torch/puzzletron/puzzletron.py
+++ b/modelopt/torch/puzzletron/puzzletron.py
@@ -20,6 +20,7 @@
 import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations
 import modelopt.torch.puzzletron.build_library_and_stats as build_library_and_stats
+import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation
 import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models
 import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts
 import modelopt.torch.puzzletron.scoring.scoring as scoring
@@ -62,6 +63,10 @@ def puzzletron(
         pruning_ckpts.launch_prune_ckpt(hydra_cfg)
     dist.barrier()

+    # Step 3: bypass distillation (optional, distributed processing)
+    if hydra_cfg.get("bypass", None) is not None:
+        bypass_distillation.launch_bypass_distillation(hydra_cfg)
+
     # Step 4: build_library_and_stats (single process)
     if dist.is_master():
         build_library_and_stats.launch_build_library_and_stats(hydra_cfg)
diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py
index 19c1bd6c83..7b885c181b 100644
--- a/modelopt/torch/puzzletron/sewing_kit/utils.py
+++ b/modelopt/torch/puzzletron/sewing_kit/utils.py
@@ -43,6 +43,8 @@
 from torch._subclasses import FakeTensor, FakeTensorMode
 from typing_extensions import override

+from modelopt.torch.puzzletron.tools.kd_model import normalized_mse_loss
+
 Fn = TypeVar("Fn", bound=Callable)

@@ -429,3 +431,38 @@ def _get_group_kwarg_if_necessary() -> dict:
         torch.distributed.distributed_c10d._object_to_tensor
     ).parameters.keys()
     return dict(group=None) if "group" in arg_names else dict()
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# Loss functions for bypass distillation (blockwise local knowledge distillation)
+# ──────────────────────────────────────────────────────────────────────────────
+
+Reduction = Literal["none", "mean", "sum"]
+
+
+def vectorwise_normalized_mse_loss(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    epsilon: float = 1e-6,
+) -> torch.Tensor:
+    """Like normalized_mse_loss, but normalization is done per-vector (last dim), then averaged."""
+    return batched_normalized_mse_loss(input, target, epsilon, batch_dims=range(input.ndim - 1))
+
+
+def batched_normalized_mse_loss(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    epsilon: float = 1e-6,
+    batch_dims: Sequence[int] = (0,),
+) -> torch.Tensor:
+    """Like normalized_mse_loss, but normalization is done on non-batch dims, then averaged.
+
+    Useful when activations within a batch item should be normalized independently
+    rather than normalizing across the full batch.
+    """
+    norm_dims = list(set(range(input.ndim)) - set(batch_dims))
+    norm_of_target_vectors = (
+        F.mse_loss(target, torch.zeros_like(target), reduction="none").mean(norm_dims) + epsilon
+    )
+    loss = F.mse_loss(input, target, reduction="none").mean(norm_dims) / norm_of_target_vectors
+    return loss.mean()
diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
index b30e7eefa9..72df36ac41 100644
--- a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
+++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
@@ -91,26 +91,28 @@ def _process_single_layer(
     keys_to_remove = {}
     layer_out_state_dict = {}

-    # Delegate to pruning_mixin if available
+    # Delegate to pruning_mixin if available (supports a single mixin or a list of mixins)
     if pruning_mixin is not None:
-        _layer_out = pruning_mixin.prune_single_layer(
-            layer_idx=layer_idx,
-            parent_state_dict=parent_state_dict,
-            new_state_dict=new_state_dict,
-            original_config=original_config,
-            new_config=new_config,
-            gqa_init_mode=gqa_init_mode,
-            mlp_init_mode=mlp_init_mode,
-            mlp_init_config=mlp_init_config,
-            linear_init_mode=linear_init_mode,
-            ignored_keys=ignored_keys,
-            keys=keys,
-            is_original_mha=is_original_mha,
-            head_size=head_size,
-            hidden_size=hidden_size,
-            keys_to_remove=keys_to_remove,
-        )
-        layer_out_state_dict.update(_layer_out)
+        _mixins = pruning_mixin if isinstance(pruning_mixin, list) else [pruning_mixin]
+        for _mixin in _mixins:
+            _layer_out = _mixin.prune_single_layer(
+                layer_idx=layer_idx,
+                parent_state_dict=parent_state_dict,
+                new_state_dict=new_state_dict,
+                original_config=original_config,
+                new_config=new_config,
+                gqa_init_mode=gqa_init_mode,
+                mlp_init_mode=mlp_init_mode,
+                mlp_init_config=mlp_init_config,
+                linear_init_mode=linear_init_mode,
+                ignored_keys=ignored_keys,
+                keys=keys,
+                is_original_mha=is_original_mha,
+                head_size=head_size,
+                hidden_size=hidden_size,
+                keys_to_remove=keys_to_remove,
+            )
+            layer_out_state_dict.update(_layer_out)
         return layer_out_state_dict, keys_to_remove

     # Legacy inline processing (fallback when no pruning_mixin)
@@ -486,8 +488,15 @@ def create_child_state_dict(
     )

     # Phase 3: Copy remaining keys from original model
+    # Only copy keys that exist in the student model (expected_keys_and_shapes).
+    # When a layer type is removed (e.g. attention no_op), the teacher may have
+    # model-specific sub-parameters (e.g. Qwen3 q_norm/k_norm) that are absent
+    # from the student; those keys must be skipped here rather than copied and
+    # later rejected by the verification assert.
     copy_start_time = time.time()
-    keys_to_copy_from_orig_model = set(keys.values()) - ignored_keys
+    keys_to_copy_from_orig_model = (set(keys.values()) - ignored_keys) & set(
+        expected_keys_and_shapes.keys()
+    )
     for key in keys_to_copy_from_orig_model:
         # Memory optimization: avoid unnecessary copies
         tensor = original_state_dict[key]
@@ -801,7 +810,7 @@ def update_model_config(

     def override(item, item_overrides):
         if item_overrides is None:
-            return item_overrides
+            return item  # None override means "keep original value"
         if dataclasses.is_dataclass(item):
             assert isinstance(item_overrides, dict)
             return dataclass_override(item, item_overrides)
diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
index 020afdfadd..d326aff938 100644
--- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
+++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
@@ -22,7 +22,9 @@
 import concurrent.futures
 import dataclasses
 import fcntl
+import inspect
 import os
+import shutil
 import time
 import warnings
 from collections import defaultdict
@@ -368,6 +370,41 @@ def _build_safetensors_weight_map(
     return weight_map


+def _copy_auto_map_code_files(model_config: PretrainedConfig, checkpoint_dir: Path) -> None:
+    """Copy custom modeling Python files referenced in auto_map to the checkpoint directory.
+
+    PretrainedConfig.save_pretrained() only copies the config class's own source file.
+    This copies any additional files (e.g., modeling_*.py) also referenced in auto_map,
+    which are required when loading the checkpoint with trust_remote_code=True.
+
+    Models that need this include NemotronH (nvidia/Minitron-*) and other models that
+    ship custom modeling code via the HuggingFace auto_map mechanism.
+    """
+    if not hasattr(model_config, "auto_map"):
+        return
+
+    # The config class's source file lives in the HF cache together with all other
+    # custom code files for this model. Walk the auto_map values to find every
+    # module file that needs to be present alongside config.json.
+    source_dir = Path(inspect.getfile(type(model_config))).parent
+
+    module_files = set()
+    for class_ref in model_config.auto_map.values():
+        # Normalize: lists/tuples carry multiple class refs — take the first
+        if isinstance(class_ref, (list, tuple)):
+            class_ref = class_ref[0]
+        # Strip repo qualifier: "repo_id--module.ClassName" → "module.ClassName"
+        if "--" in class_ref:
+            class_ref = class_ref.split("--", 1)[1]
+        module_files.add(f"{class_ref.split('.')[0]}.py")
+
+    for filename in module_files:
+        src = source_dir / filename
+        dst = Path(checkpoint_dir) / filename
+        if src.exists() and not dst.exists():
+            shutil.copy(src, dst)
+
+
 def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str) -> None:
     if hasattr(model_config, "block_configs"):
         model_config.block_configs = [
@@ -375,3 +412,4 @@ def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str
             for conf in model_config.block_configs
         ]
     model_config.save_pretrained(checkpoint_dir)
+    _copy_auto_map_code_files(model_config, Path(checkpoint_dir))
diff --git a/modelopt/torch/puzzletron/tools/kd_model.py b/modelopt/torch/puzzletron/tools/kd_model.py
index 8590c3f56c..6fcafdf4f6 100644
--- a/modelopt/torch/puzzletron/tools/kd_model.py
+++ b/modelopt/torch/puzzletron/tools/kd_model.py
@@ -35,8 +35,8 @@ def normalized_mse_loss(
     reduction: Literal["none", "mean", "sum"] = "mean",
     epsilon: float = 1e-6,
 ) -> Tensor:
-    loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss(
-        target, torch.zeros_like(target) + epsilon, reduction=reduction
+    loss = F.mse_loss(input, target, reduction=reduction) / (
+        F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon
     )
     return loss
diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py
index 892d1f3c2c..69e8d2e830 100644
--- a/modelopt/torch/puzzletron/utils/data/dataloaders.py
+++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py
@@ -19,7 +19,6 @@
 from functools import partial
 from typing import Protocol, TypeVar

-import datasets
 import torch
 import torch.distributed
 from accelerate import Accelerator
@@ -28,6 +27,7 @@
 from tqdm import tqdm
 from transformers import PreTrainedTokenizerBase

+import datasets
 from modelopt.torch.puzzletron.tools.logger import mprint
 from modelopt.torch.puzzletron.utils.data.dataset import ConstantLengthDataset

@@ -71,6 +71,54 @@ def load_streaming_fn(
     return dataset


+def create_train_dataloader(
+    seed: int,
+    tokenizer: PreTrainedTokenizerBase,
+    block_size: int,
+    dataset_path: str | Mapping[str, Dataset],
+    content_field: str,
+    fim_rate: float,
+    fim_spm_rate: float,
+    micro_batch_size: int,
+    load_dataset_fn: LoadDatasetFn = load_from_disk_fn,
+    dataset_name: str = "train",
+    keep_in_memory: bool = False,
+    shuffle_seed: int | None = None,
+    source_datasets_to_discard: Sequence[str] = (),
+    bos_rate: float = 1.0,
+    num_workers: int = 0,
+) -> DataLoader:
+    """Create an infinite training DataLoader over ConstantLengthDataset."""
+    if isinstance(dataset_path, str):
+        dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory)
+    else:
+        dataset = dataset_path
+
+    train_data = dataset[dataset_name]
+    if shuffle_seed is not None:
+        train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
+
+    train_dataset = ConstantLengthDataset(
+        tokenizer,
+        train_data,
+        infinite=True,
+        seq_length=block_size,
+        content_field=content_field,
+        fim_rate=fim_rate,
+        fim_spm_rate=fim_spm_rate,
+        seed=seed,
+        source_datasets_to_discard=source_datasets_to_discard,
+        bos_rate=bos_rate,
+    )
+
+    return DataLoader(
+        train_dataset,
+        batch_size=micro_batch_size,
+        pin_memory=True,
+        num_workers=num_workers,
+    )
+
+
 def create_validation_dataloader(
     accelerator: Accelerator | None,
     seed: int,
diff --git a/modelopt/torch/puzzletron/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py
index fffc2a3a1d..c1278f0541 100644
--- a/modelopt/torch/puzzletron/utils/data/dataset.py
+++ b/modelopt/torch/puzzletron/utils/data/dataset.py
@@ -120,7 +120,14 @@ def __iter__(self) -> dict[str, torch.Tensor]:
                     and {"content", "role"}.issubset(sample[0])
                 ):
                     if len(sample) > 1:
-                        sample = self.tokenizer.apply_chat_template(sample, tokenize=False)
+                        if getattr(self.tokenizer, "chat_template", None) is not None:
+                            sample = self.tokenizer.apply_chat_template(
+                                sample, tokenize=False
+                            )
+                        else:
+                            # Base models have no chat template — concatenate message
+                            # contents separated by newlines as plain text.
+                            sample = "\n".join(m["content"] for m in sample)
                     else:
                         sample = sample[0]["content"]
                 else:
diff --git a/modelopt/torch/puzzletron/utils/parsing.py b/modelopt/torch/puzzletron/utils/parsing.py
index ff5bb6963a..f684ac9bb3 100644
--- a/modelopt/torch/puzzletron/utils/parsing.py
+++ b/modelopt/torch/puzzletron/utils/parsing.py
@@ -24,12 +24,16 @@
 # mypy: ignore-errors

 import json
+import logging
+import math
 from pathlib import Path
 from typing import Any

 import torch
 from omegaconf import DictConfig

+_logger = logging.getLogger(__name__)
+

 def handle_arg_string(arg):
     if arg.lower() == "true":
@@ -332,6 +336,24 @@ def format_stitched_losses(
     if not losses_dict:
         return "❌ No losses found"

+    # Filter out nan entries. NaN is expected for no-op blocks (e.g. Mamba) that have no
+    # trainable parameters. For any other block, NaN signals divergence — warn loudly.
+    nan_keys = [k for k, v in losses_dict.items() if math.isnan(v)]
+    if nan_keys:
+        _logger.warning(
+            "NaN loss detected for block(s): %s. "
+            "Expected for no-op/skipped blocks (e.g. Mamba); indicates divergence otherwise.",
+            nan_keys,
+        )
+        losses_dict = {k: v for k, v in losses_dict.items() if not math.isnan(v)}
+        if best_steps_dict:
+            best_steps_dict = {k: v for k, v in best_steps_dict.items() if k in losses_dict}
+        if best_values_dict:
+            best_values_dict = {k: v for k, v in best_values_dict.items() if k in losses_dict}
+
+    if not losses_dict:
+        return "❌ No trainable blocks found"
+
     lines = []

     # Calculate statistics
diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py
index e147ebf2c2..b70ae5875b 100644
--- a/modelopt/torch/utils/plugins/transformers_dataset.py
+++ b/modelopt/torch/utils/plugins/transformers_dataset.py
@@ -21,9 +21,9 @@

 import torch
 import transformers
-from datasets import load_dataset
 from transformers.trainer_pt_utils import LabelSmoother

+from datasets import load_dataset
 from modelopt.torch.utils import print_rank_0

 REMOVE_THINK_CHAT_TEMPLATE = (
diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py
index b5e32566de..ea3fc9c5f5 100644
--- a/tests/_test_utils/torch/puzzletron/utils.py
+++ b/tests/_test_utils/torch/puzzletron/utils.py
@@ -18,10 +18,10 @@
 from pathlib import Path

 import torch
-from datasets import Dataset, DatasetDict
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

 import modelopt.torch.utils.distributed as dist
+from datasets import Dataset, DatasetDict
 from modelopt.torch.puzzletron.tools.hydra_utils import register_hydra_resolvers

diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py
index 8a5bad0c62..30a6e669c6 100644
--- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py
+++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py
@@ -38,7 +38,7 @@ def test_nas_convert_ffn_pruning(project_root_path: Path, tmp_path: Path):
 def _test_nas_convert_ffn_pruning_multiprocess_job(
     project_root_path: Path, tmp_path: Path, rank: int, size: int
 ):
-    dist.setup(timeout=timedelta(10))
+    dist.setup(timeout=timedelta(minutes=10))

     # Setup the test model and data.
     puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
         project_root_path, tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct"
@@ -94,7 +94,7 @@ def test_nas_convert_attn_pruning(project_root_path: Path, tmp_path: Path):
 def _test_nas_convert_attn_pruning_multiprocess_job(
     project_root_path: Path, tmp_path: Path, rank: int, size: int
 ):
-    dist.setup(timeout=timedelta(10))
+    dist.setup(timeout=timedelta(minutes=10))

     # Setup the test model and data.
     puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data(
         project_root_path, tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct"
diff --git a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py
index 2af371e5ca..c1410f4e65 100644
--- a/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py
+++ b/tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py
@@ -37,7 +37,7 @@ def test_nas_search(project_root_path: Path, tmp_path: Path):
 def _test_nas_search_multiprocess_job(
     project_root_path: Path, tmp_path: Path, rank: int, size: int
 ):
-    dist.setup(timeout=timedelta(10))
+    dist.setup(timeout=timedelta(minutes=10))

     # Setup the test model and data.
puzzle_dir, llama_checkpoint_path, dataset_path = setup_test_model_and_data( project_root_path, tmp_path, rank, "meta-llama/Llama-3.1-8B-Instruct" diff --git a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml index 81c5f35ba5..23173d78e0 100644 --- a/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml +++ b/tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml @@ -6,7 +6,8 @@ activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/expert_removal/${pruni pruning_mixin: _target_: modelopt.torch.puzzletron.pruning.expert_removal_pruning_mixin.ExpertRemovalPruningMixIn layer_descriptor: - _target_: modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_model_descriptor.Qwen3VL30BA3BInstructExpertRemovalLayerDescriptor + _target_: + modelopt.torch.puzzletron.anymodel.models.qwen3_vl_30b_a3b_instruct.qwen3_vl_30b_a3b_instruct_model_descriptor.Qwen3VL30BA3BInstructExpertRemovalLayerDescriptor target_name: "mlp" hook_class: ${get_object:modelopt.torch.nas.plugins.megatron_hooks.base_hooks.Qwen3VLRemoveExpertsIndependentHook} diff --git a/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml new file mode 100644 index 0000000000..0d78205c1e --- /dev/null +++ b/tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml @@ -0,0 +1,99 @@ +# @package bypass +# Minimal bypass config for GPU integration tests. +# Uses tiny training budget (128 tokens) and tiny model (hidden_size=256, +# intermediate_size=512, 2 layers) to run fast on CI. 
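A quick sanity check on the tiny budget this config declares: assuming steps are derived as training_tokens divided by tokens per optimizer step (block_size × micro_batch_size × grad_accumulation_steps, which is how the test module's docstring describes it), the run finishes in two steps. The variable names below are local illustrations, not APIs from this repo:

```python
# Hypothetical back-of-envelope check; the values mirror the test config above.
training_tokens = 128
block_size = 64
micro_batch_size = 1
grad_accumulation_steps = 1

tokens_per_step = block_size * micro_batch_size * grad_accumulation_steps
max_steps = training_tokens // tokens_per_step
print(max_steps)  # 2
```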
+ +dtype: "bf16" +seed: 42 +experiment_id: +experiment_dir: +iter_num: 1 +step_num: 1 +token_count: 0 + +data: + data_column: "conversation" + block_size: 64 + bos_rate: 0.5 + fim_rate: 0 + fim_spm_rate: 0 + source_datasets_to_discard: [] + load_from_disk: true + keep_in_memory: false + val_dataset_name: valid + max_eval_samples: 1 + eval_samples_per_process: + shuffle_train_data_seed: 42 + +training: + learning_rate: 1e-4 + training_tokens: 128 + micro_batch_size: 1 + val_micro_batch_size: 1 + warmup_ratio: 0.05 + warmup_steps: + min_lr_factor: 1e-5 + grad_accumulation_steps: 1 + skip_first_batches: 0 + weight_decay: 0.1 + decay_lr: true + beta1: 0.9 + beta2: 0.95 + use_grad_scaling: false + grad_clip: 1.0 + grad_clip_type: norm + clipping_count: 0 + log_interval: 5 + eval_interval: 100 + +resume_checkpoint_path: +find_last_ckpt_for_resume: false +parameter_count: +init_checkpoint_path: + +model: + student_weights_dtype: "bf16" + model_overrides: + delete_old_checkpoints: true + save_interval_seconds: + save_interval: 1000000000 + save_checkpoint_when_done: true + model_config_overrides: + ffn: + - intermediate_size: + no_op: + attention: + - num_key_value_heads: + no_op: + +model_factory: + factory: bypass_factory_fn + block_loss_func: normalized_mse_loss + gqa_init_mode: AverageKV + mlp_init_mode: Truncate + mlp_init_config: + activations_log_dir: + linear_init_mode: FromTeacher + submodule_for_loss_calculation: + keys_to_learn: entire_block + +disable_initial_validate: true +validate_teacher_model: false +validate_student_model: false +disable_validation: true +best_val_loss: 1.0e+9 + +compile: false +disable_fa2: false +teacher_model_load_on_cpu: false + +save_checkpoint_before_training: false +disable_checkpoint_save: false +save_best_ckpt: true +kill_after_first_save: false +realize_best_or_latest: "best" + +wandb_log: false +wandb: + project: + entity: diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py new 
file mode 100644 index 0000000000..93f7244f34 --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_bypass.py @@ -0,0 +1,598 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU integration tests for bypass distillation (blockwise local distillation). + +These tests verify that: +- Bypass distillation runs end-to-end with a tiny Llama model (hidden_size=256, + intermediate_size=512, num_layers=max(2, world_size)). +- FFN pruning, KV-head compression, and multi-config sweep all produce the expected + checkpoint symlinks in puzzle_dir/ckpts/. +- The bypass config injection pattern via OmegaConf works correctly for tests that + do not load a full bypass Hydra config file. 
+ +Model parameters used throughout: + - teacher intermediate_size: 512 -> pruned to 256 (half) for FFN tests + - teacher num_key_value_heads: 8 -> pruned to 4 for KV-head tests + - training_tokens: 128, block_size: 64, micro_batch_size: 1 -> max_steps = 2 +""" + +from datetime import timedelta +from functools import partial +from pathlib import Path + +import hydra +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.misc import set_seed +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data +from omegaconf import OmegaConf + +import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations +import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation +import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +SEED = 1234 +HF_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct" +CONVERTER = "llama" +HYDRA_CONFIG_NAME = "meta-llama/Llama-3.1-8B-Instruct/Llama-3.1-8B-Instruct" + +# Teacher model dimensions (set by setup_test_model_and_data for Llama) +TEACHER_INTERMEDIATE_SIZE = 512 +TEACHER_NUM_KV_HEADS = 8 + +# Pruned sizes used in tests +PRUNED_INTERMEDIATE_SIZE = 256 # half of teacher +PRUNED_NUM_KV_HEADS = 4 # half of teacher + +# Training budget: 128 tokens / (64 block * 1 mbs) = 2 steps — completes fast +TRAINING_TOKENS = 128 +BLOCK_SIZE = 64 + + +# --------------------------------------------------------------------------- +# Helper: build the bypass config dict for injection into hydra_cfg +# 
--------------------------------------------------------------------------- + + +def _make_bypass_cfg_dict( + intermediate_size: int = PRUNED_INTERMEDIATE_SIZE, + num_key_value_heads: int = PRUNED_NUM_KV_HEADS, + configs_list: list | None = None, +) -> dict: + """Return a plain-dict bypass config suitable for OmegaConf.update injection. + + Args: + intermediate_size: FFN intermediate size for the student model. + num_key_value_heads: Number of KV heads for the student model. + configs_list: If provided, populates bypass.configs for a multi-config sweep. + Each entry is a dict with ``model_config_overrides`` and optionally + ``keys_to_learn``. + """ + cfg = { + "dtype": "bf16", + "seed": 42, + "experiment_id": None, + "experiment_dir": None, + "iter_num": 1, + "step_num": 1, + "token_count": 0, + "data": { + # The dummy test dataset stores conversations under the "conversation" column. + "data_column": "conversation", + "block_size": BLOCK_SIZE, + "bos_rate": 0.5, + "fim_rate": 0, + "fim_spm_rate": 0, + "source_datasets_to_discard": [], + "load_from_disk": True, + "keep_in_memory": False, + "val_dataset_name": "valid", + "max_eval_samples": 1, + "eval_samples_per_process": None, + "shuffle_train_data_seed": 42, + }, + "training": { + "learning_rate": 1e-4, + "training_tokens": TRAINING_TOKENS, + "micro_batch_size": 1, + "val_micro_batch_size": 1, + "warmup_ratio": 0.05, + "warmup_steps": None, + "min_lr_factor": 1e-5, + "grad_accumulation_steps": 1, + "skip_first_batches": 0, + "weight_decay": 0.1, + "decay_lr": True, + "beta1": 0.9, + "beta2": 0.95, + "use_grad_scaling": False, + "grad_clip": 1.0, + "grad_clip_type": "norm", + "clipping_count": 0, + "log_interval": 5, + # Large eval_interval so validation is skipped during this short run. + # Validation is fully disabled anyway (disable_validation=True below). 
+ "eval_interval": 100, + }, + "resume_checkpoint_path": None, + "find_last_ckpt_for_resume": False, + "parameter_count": None, + "init_checkpoint_path": None, + "model": { + "student_weights_dtype": "bf16", + "model_overrides": { + "delete_old_checkpoints": True, + "save_interval_seconds": None, + # Effectively disable step-interval saving; rely on save_checkpoint_when_done. + "save_interval": 1_000_000_000, + "save_checkpoint_when_done": True, + }, + "model_config_overrides": { + "ffn": [{"intermediate_size": intermediate_size, "no_op": None}], + "attention": [{"num_key_value_heads": num_key_value_heads, "no_op": None}], + }, + }, + "model_factory": { + "factory": "bypass_factory_fn", + "block_loss_func": "normalized_mse_loss", + "gqa_init_mode": "AverageKV", + "mlp_init_mode": "Truncate", + "mlp_init_config": {"activations_log_dir": None}, + "linear_init_mode": "FromTeacher", + "submodule_for_loss_calculation": None, + "keys_to_learn": "entire_block", + }, + # Disable all validation to keep tests fast. + "disable_initial_validate": True, + "validate_teacher_model": False, + "validate_student_model": False, + "disable_validation": True, + "best_val_loss": 1e9, + "compile": False, + "disable_fa2": False, + "teacher_model_load_on_cpu": False, + "save_checkpoint_before_training": False, + "disable_checkpoint_save": False, + "save_best_ckpt": True, + # Do NOT use kill_after_first_save — it raises RuntimeError which becomes sys.exit(1). + # Instead let the short training run (2 steps) complete naturally. 
+ "kill_after_first_save": False, + "realize_best_or_latest": "best", + "wandb_log": False, + "wandb": {"project": None, "entity": None}, + } + + if configs_list is not None: + cfg["configs"] = configs_list + + return cfg + + +# --------------------------------------------------------------------------- +# Helper: load hydra config and run pruning prerequisites +# --------------------------------------------------------------------------- + + +def _setup_hydra_cfg_and_pruning( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +) -> tuple: + """Set up the tiny model, convert it, score activations, and create pruning ckpts. + + This is the shared preamble for all bypass tests. Returns + ``(puzzle_dir, dataset_path, hydra_cfg)``. + + Steps performed: + 1. Create a small HF model and dummy dataset via ``setup_test_model_and_data``. + 2. Convert the HF checkpoint to AnyModel/DeciLM format (rank 0 only). + 3. Load the Hydra config with ``puzzle_dir`` and ``dataset_path`` overrides. + 4. Run ``score_pruning_activations`` (distributed). + 5. Run ``pruning_ckpts`` (rank 0 only) then barrier. + """ + set_seed(SEED) + dist.setup(timeout=timedelta(minutes=10)) + + puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( + project_root_path, tmp_path, rank, HF_MODEL_NAME + ) + + hydra_config_dir = str(project_root_path / "tests/gpu/torch/puzzletron/resources/configs") + + # Step 0: Convert HF checkpoint to AnyModel/DeciLM format. + if rank == 0: + convert_model( + input_dir=str(hf_checkpoint_path), + output_dir=str(puzzle_dir / "ckpts/teacher"), + converter=CONVERTER, + ) + dist.barrier() + + # Step 1: Load Hydra config. + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=HYDRA_CONFIG_NAME, + overrides=[ + f"puzzle_dir={puzzle_dir}", + f"dataset_path={dataset_path}", + ], + ) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) + + # Step 2: Score pruning activations (distributed). 
+ score_pruning_activations.launch_score_activations(hydra_cfg) + + # Step 3: Create pruning checkpoints (rank 0 only). + if rank == 0: + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + dist.barrier() + + return puzzle_dir, dataset_path, hydra_cfg + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_bypass_ffn_pruning(project_root_path: Path, tmp_path: Path): + """Bypass distillation with FFN pruned to intermediate_size=256. + + Verifies that after training: + - The experiment directory ``bypass/bypass_runs/bypass_ffn256_kv4`` exists. + - A symlink ``ckpts/bypass_ffn256_kv4`` pointing into the experiment dir + is created by ``realize_bypass_checkpoints``. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_ffn_pruning_job, + project_root_path, + tmp_path, + ), + backend="nccl", + ) + + +def _test_bypass_ffn_pruning_job( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +): + puzzle_dir, dataset_path, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size + ) + + # Inject bypass config: prune FFN to 256, keep num_key_value_heads=4. + # experiment_id will be set dynamically to "bypass_ffn256_kv4". 
+ bypass_cfg_dict = _make_bypass_cfg_dict( + intermediate_size=PRUNED_INTERMEDIATE_SIZE, + num_key_value_heads=PRUNED_NUM_KV_HEADS, + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = f"bypass_ffn{PRUNED_INTERMEDIATE_SIZE}_kv{PRUNED_NUM_KV_HEADS}" + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_ffn_pruning completed successfully. " + f"Puzzle directory: {puzzle_dir}" + ) + + +def test_bypass_kv_head_compression(project_root_path: Path, tmp_path: Path): + """Bypass distillation with KV heads reduced from 8 to 4, FFN kept at 512. + + The experiment_id is ``bypass_ffn512_kv4`` because both FFN and attention + overrides are specified (FFN is kept at teacher size, attention is halved). + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_kv_head_compression_job, + project_root_path, + tmp_path, + ), + backend="nccl", + ) + + +def _test_bypass_kv_head_compression_job( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +): + puzzle_dir, dataset_path, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size + ) + + # Keep FFN at teacher size (512) but halve KV heads (8 -> 4). + # experiment_id will be "bypass_ffn512_kv4". 
+ bypass_cfg_dict = _make_bypass_cfg_dict( + intermediate_size=TEACHER_INTERMEDIATE_SIZE, + num_key_value_heads=PRUNED_NUM_KV_HEADS, + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = f"bypass_ffn{TEACHER_INTERMEDIATE_SIZE}_kv{PRUNED_NUM_KV_HEADS}" + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_kv_head_compression completed successfully. " + f"Puzzle directory: {puzzle_dir}" + ) + + +def test_bypass_multi_config_sequential(project_root_path: Path, tmp_path: Path): + """Bypass distillation sweep: two configs run sequentially via bypass.configs list. + + Config 0: FFN=256, heads=4 -> experiment_id ``bypass_ffn256_kv4`` + Config 1: FFN=512, heads=4 -> experiment_id ``bypass_ffn512_kv4`` + + Both symlinks must exist after the sweep completes. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_multi_config_sequential_job, + project_root_path, + tmp_path, + ), + backend="nccl", + ) + + +def _test_bypass_multi_config_sequential_job( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +): + puzzle_dir, dataset_path, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size + ) + + # Build base bypass config (model_config_overrides will be overwritten by configs list). 
+ configs_list = [ + { + "model_config_overrides": { + "ffn": [{"intermediate_size": PRUNED_INTERMEDIATE_SIZE, "no_op": None}], + "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], + }, + "keys_to_learn": "entire_block", + }, + { + "model_config_overrides": { + "ffn": [{"intermediate_size": TEACHER_INTERMEDIATE_SIZE, "no_op": None}], + "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], + }, + "keys_to_learn": "entire_block", + }, + ] + bypass_cfg_dict = _make_bypass_cfg_dict( + intermediate_size=PRUNED_INTERMEDIATE_SIZE, + num_key_value_heads=PRUNED_NUM_KV_HEADS, + configs_list=configs_list, + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_ids = [ + f"bypass_ffn{PRUNED_INTERMEDIATE_SIZE}_kv{PRUNED_NUM_KV_HEADS}", + f"bypass_ffn{TEACHER_INTERMEDIATE_SIZE}_kv{PRUNED_NUM_KV_HEADS}", + ] + for experiment_id in expected_ids: + experiment_dir = puzzle_dir / "bypass/bypass_runs" / experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_multi_config_sequential completed successfully. " + f"Puzzle directory: {puzzle_dir}" + ) + + +def test_bypass_checkpoint_contents(project_root_path: Path, tmp_path: Path): + """Verify that a bypass checkpoint contains expected HuggingFace model files. + + After bypass completes, the checkpoint directory (reachable via the symlink at + ``ckpts/{experiment_id}``) must contain a ``config.json`` (saved by + ``save_checkpoint`` / ``save_bypass_checkpoint``). 
+ """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_checkpoint_contents_job, + project_root_path, + tmp_path, + ), + backend="nccl", + ) + + +def _test_bypass_checkpoint_contents_job( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +): + puzzle_dir, dataset_path, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size + ) + + bypass_cfg_dict = _make_bypass_cfg_dict( + intermediate_size=PRUNED_INTERMEDIATE_SIZE, + num_key_value_heads=PRUNED_NUM_KV_HEADS, + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = f"bypass_ffn{PRUNED_INTERMEDIATE_SIZE}_kv{PRUNED_NUM_KV_HEADS}" + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink: {ckpt_symlink}" + ) + + # The symlink resolves to the latest checkpoint dir; verify HF config exists. + resolved = ckpt_symlink.resolve() + config_json = resolved / "config.json" + assert config_json.exists(), ( + f"Expected HuggingFace config.json inside checkpoint: {config_json}" + ) + + # The saving_completed marker must be present (set by save_bypass_checkpoint). + saving_completed = resolved / "saving_completed" + assert saving_completed.exists(), ( + f"Expected saving_completed marker inside checkpoint: {saving_completed}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_checkpoint_contents completed successfully. " + f"Puzzle directory: {puzzle_dir}" + ) + + +def test_bypass_checkpoint_resume(project_root_path: Path, tmp_path: Path): + """Verify that bypass distillation can resume from a previous checkpoint. + + Runs bypass twice with the same experiment_id: + - First run: completes 2 training steps and saves a checkpoint. 
+ - Second run: uses ``find_last_ckpt_for_resume=True`` to auto-detect the + saved checkpoint and resume from it. + + Checks that the second run finds the checkpoint, loads it without error, + and produces a final checkpoint in the experiment directory. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_checkpoint_resume_job, + project_root_path, + tmp_path, + ), + backend="nccl", + ) + + +def _test_bypass_checkpoint_resume_job( + project_root_path: Path, + tmp_path: Path, + rank: int, + size: int, +): + puzzle_dir, dataset_path, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size + ) + + bypass_cfg_dict = _make_bypass_cfg_dict( + intermediate_size=PRUNED_INTERMEDIATE_SIZE, + num_key_value_heads=PRUNED_NUM_KV_HEADS, + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + # --- First run: train and save a checkpoint. --- + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + expected_experiment_id = f"bypass_ffn{PRUNED_INTERMEDIATE_SIZE}_kv{PRUNED_NUM_KV_HEADS}" + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + + if rank == 0: + assert experiment_dir.exists(), ( + f"First run should have created experiment directory: {experiment_dir}" + ) + + dist.barrier() + + # --- Second run: resume from the checkpoint saved by the first run. --- + # Reset training counters so the second run starts fresh in terms of config, + # but find_last_ckpt_for_resume=True causes it to reload the saved state. + OmegaConf.update(hydra_cfg, "bypass.iter_num", 1, merge=True) + OmegaConf.update(hydra_cfg, "bypass.step_num", 1, merge=True) + OmegaConf.update(hydra_cfg, "bypass.token_count", 0, merge=True) + OmegaConf.update(hydra_cfg, "bypass.find_last_ckpt_for_resume", True, merge=True) + + # The second run should not raise; it should load the checkpoint and complete. 
+ bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Second (resume) run should produce a checkpoint symlink: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_checkpoint_resume completed successfully. " + f"Puzzle directory: {puzzle_dir}" + ) diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index f3f49bed27..f40e5e757f 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -90,7 +90,7 @@ def _test_puzzletron_multiprocess_job( ): # Set seed BEFORE dist.setup() to ensure reproducibility across all processes set_seed(SEED) - dist.setup(timeout=timedelta(10)) + dist.setup(timeout=timedelta(minutes=10)) # Setup the test model and data. puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( @@ -206,28 +206,44 @@ def _assert_score_pruning_activations(puzzle_dir: Path, hf_model_name: str): size = dist.size() if expected is not None: - # In multi-GPU: layers are distributed across ranks - # Each rank processes len(expected) // size layers - expected_layers_per_rank = len(expected) // size - assert len(layer_names) == expected_layers_per_rank, ( - f"Expected {expected_layers_per_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}" + # The test model has num_hidden_layers = max(2, size), so every rank owns at least + # one layer. Compute the actual expected count for *this* rank. 
+        total_layers = max(2, size)
+        layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
+        assert len(layer_names) == layers_this_rank, (
+            f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
         )
-        # Check each layer's values
-        for i, layer_name in enumerate(layer_names):
-            layer_data = pruning_scores[layer_name]
-            # Calculate global layer index from rank and local index
-            global_idx = rank * expected_layers_per_rank + i
-            assert layer_data["score"][0].item() == expected[global_idx]["score"]
-            assert (
-                layer_data["channels_importance_ascending"][0].item()
-                == expected[global_idx]["channels"]
+
+        # Numerical score checks are only meaningful when the expected table was
+        # collected with the same GPU count (same total_layers). When running on
+        # more GPUs than the table covers, skip the per-value assertions rather than
+        # failing: the layer-count check above already confirms the distribution is right.
+        if len(expected) == total_layers:
+            global_start = sum(
+                max(2, size) // size + (1 if r < max(2, size) % size else 0) for r in range(rank)
             )
-    else:
-        # Print values for new models - update EXPECTED_PRUNING_VALUES with these
-        print(f"\n=== PRUNING VALUES for {hf_model_name} (num_layers={len(layer_names)}) ===")
+            for i, layer_name in enumerate(layer_names):
+                layer_data = pruning_scores[layer_name]
+                global_idx = global_start + i
+                assert layer_data["score"][0].item() == expected[global_idx]["score"]
+                assert (
+                    layer_data["channels_importance_ascending"][0].item()
+                    == expected[global_idx]["channels"]
+                )
+    # Print values for new models — update EXPECTED_PRUNING_VALUES with these.
+    # Only rank 0 prints: it loads all rank_*.pth files and outputs the complete
+    # ordered table so multi-GPU runs produce a single, uninterleaved snippet.
+    elif rank == 0:
+        total_layers = max(2, size)
+        scores_dir = (puzzle_dir / rank_filepath).parent
+        all_scores: dict = {}
+        for r in range(size):
+            all_scores.update(torch.load(scores_dir / f"rank_{r}.pth"))
+        sorted_names = sorted(all_scores.keys())
+        print(f"\n=== PRUNING VALUES for {hf_model_name} (num_layers={total_layers}) ===")
         print(f'"{hf_model_name}": [')
-        for layer_name in layer_names:
-            layer_data = pruning_scores[layer_name]
+        for layer_name in sorted_names:
+            layer_data = all_scores[layer_name]
             score = layer_data["score"][0].item()
             channels = layer_data["channels_importance_ascending"][0].item()
             print(f'    {{"score": {score}, "channels": {channels}}},')
diff --git a/tests/unit/torch/puzzletron/__init__.py b/tests/unit/torch/puzzletron/__init__.py
new file mode 100644
index 0000000000..1275d78dff
--- /dev/null
+++ b/tests/unit/torch/puzzletron/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/unit/torch/puzzletron/test_bypass_losses.py b/tests/unit/torch/puzzletron/test_bypass_losses.py
new file mode 100644
index 0000000000..4f6869c4e6
--- /dev/null
+++ b/tests/unit/torch/puzzletron/test_bypass_losses.py
@@ -0,0 +1,115 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for normalized MSE loss functions in sewing_kit/utils.py."""
+
+import torch
+
+from modelopt.torch.puzzletron.sewing_kit.utils import (
+    batched_normalized_mse_loss,
+    normalized_mse_loss,
+    vectorwise_normalized_mse_loss,
+)
+
+# ---------------------------------------------------------------------------
+# normalized_mse_loss
+# ---------------------------------------------------------------------------
+
+
+def test_normalized_mse_loss_identical_tensors():
+    """Identical input and target should produce a loss of approximately 0."""
+    torch.manual_seed(42)
+    x = torch.randn(4, 8)
+    loss = normalized_mse_loss(x, x)
+    assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6)
+
+
+def test_normalized_mse_loss_basic():
+    """Loss should be positive and finite for random, non-identical tensors."""
+    torch.manual_seed(42)
+    input_ = torch.randn(4, 8)
+    target = torch.randn(4, 8)
+    loss = normalized_mse_loss(input_, target)
+    assert loss.item() > 0.0
+    assert torch.isfinite(loss)
+
+
+def test_normalized_mse_loss_reduction_none():
+    """With reduction='none' the output shape should match the input shape."""
+    torch.manual_seed(42)
+    input_ = torch.randn(4, 8)
+    target = torch.randn(4, 8)
+    loss = normalized_mse_loss(input_, target, reduction="none")
+    assert loss.shape == input_.shape
+
+
+def test_normalized_mse_loss_reduction_sum():
+    """With reduction='sum' the output should be a scalar tensor."""
+    torch.manual_seed(42)
+    input_ = torch.randn(4, 8)
+    target = torch.randn(4, 8)
+    loss = normalized_mse_loss(input_, target, reduction="sum")
+    assert loss.ndim == 0  # scalar
+    assert torch.isfinite(loss)
+
+
+# ---------------------------------------------------------------------------
+# vectorwise_normalized_mse_loss
+# ---------------------------------------------------------------------------
+
+
+def test_vectorwise_normalized_mse_loss_shape():
+    """vectorwise_normalized_mse_loss should return a scalar for any 2-D input."""
+    torch.manual_seed(42)
+    input_ = torch.randn(4, 16)
+    target = torch.randn(4, 16)
+    loss = vectorwise_normalized_mse_loss(input_, target)
+    assert loss.ndim == 0  # scalar
+    assert torch.isfinite(loss)
+
+
+def test_vectorwise_normalized_mse_loss_identical():
+    """Identical input and target should give a loss of approximately 0."""
+    torch.manual_seed(42)
+    x = torch.randn(4, 16)
+    loss = vectorwise_normalized_mse_loss(x, x)
+    assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6)
+
+
+# ---------------------------------------------------------------------------
+# batched_normalized_mse_loss
+# ---------------------------------------------------------------------------
+
+
+def test_batched_normalized_mse_loss_basic():
+    """Should return a scalar with a positive, finite value for random tensors."""
+    torch.manual_seed(42)
+    input_ = torch.randn(4, 8)
+    target = torch.randn(4, 8)
+    loss = batched_normalized_mse_loss(input_, target)
+    assert loss.ndim == 0  # scalar
+    assert loss.item() > 0.0
+    assert torch.isfinite(loss)
+
+
+def test_batched_normalized_mse_loss_custom_dims():
+    """Custom batch_dims=(0, 1) on a 3-D tensor should still return a scalar."""
+    torch.manual_seed(42)
+    input_ = torch.randn(2, 3, 8)
+    target = torch.randn(2, 3, 8)
+    loss = batched_normalized_mse_loss(input_, target, batch_dims=(0, 1))
+    assert loss.ndim == 0  # scalar
+    assert torch.isfinite(loss)
+    assert loss.item() > 0.0
diff --git a/tests/unit/torch/puzzletron/test_bypass_utils.py b/tests/unit/torch/puzzletron/test_bypass_utils.py
new file mode 100644
index 0000000000..9e817e3991
--- /dev/null
+++ b/tests/unit/torch/puzzletron/test_bypass_utils.py
@@ -0,0 +1,388 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for bypass_utils helpers and _set_keys_to_learn."""
+
+import types
+
+import pytest
+import torch
+import torch.nn as nn
+from omegaconf import OmegaConf
+
+from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import (
+    get_distributed_modules_ownership,
+    set_experiment_id,
+)
+from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import _set_keys_to_learn
+
+
+def test_single_gpu_all_to_rank_0():
+    """With world_size=1, all 4 modules should be assigned to rank 0."""
+    ownership = get_distributed_modules_ownership(module_count=4, world_size=1)
+    assert ownership == [0, 0, 0, 0]
+
+
+def test_even_distribution():
+    """With world_size=2 and 4 modules, each rank should own exactly 2 modules."""
+    ownership = get_distributed_modules_ownership(module_count=4, world_size=2)
+    assert ownership.count(0) == 2
+    assert ownership.count(1) == 2
+    assert len(ownership) == 4
+
+
+def test_uneven_distribution():
+    """With world_size=2 and 3 modules, rank 0 should own 2 and rank 1 should own 1."""
+    ownership = get_distributed_modules_ownership(module_count=3, world_size=2)
+    assert ownership.count(0) == 2
+    assert ownership.count(1) == 1
+    assert len(ownership) == 3
+
+
+@pytest.mark.parametrize(
+    ("module_count", "world_size"),
+    [
+        (1, 1),
+        (4, 1),
+        (4, 2),
+        (4, 4),
+        (7, 3),
+        (10, 4),
+        (1, 2),
+    ],
+)
+def test_total_equals_module_count(module_count, world_size):
+    """The length of the ownership list must always equal module_count."""
+    ownership = get_distributed_modules_ownership(module_count=module_count, world_size=world_size)
+    assert len(ownership) == module_count
+
+
+def test_consecutive_ownership():
+    """Each rank should own a contiguous block of indices (no interleaving)."""
+    ownership = get_distributed_modules_ownership(module_count=7, world_size=3)
+    # Verify that once we see a new rank, we never see the previous rank again.
+    seen_ranks = set()
+    prev_rank = ownership[0]
+    seen_ranks.add(prev_rank)
+    for rank in ownership[1:]:
+        if rank != prev_rank:
+            assert rank not in seen_ranks, (
+                f"Rank {rank} appears non-consecutively in ownership list: {ownership}"
+            )
+            seen_ranks.add(rank)
+            prev_rank = rank
+
+
+def test_single_module():
+    """With world_size=2 and only 1 module, rank 0 should be the sole owner."""
+    ownership = get_distributed_modules_ownership(module_count=1, world_size=2)
+    assert ownership == [0]
+    assert len(ownership) == 1
+
+
+# ---------------------------------------------------------------------------
+# Helpers for _set_keys_to_learn tests
+# ---------------------------------------------------------------------------
+
+
+def _make_flat_model(*param_names: str) -> nn.Module:
+    """Return a flat module whose named_parameters() yields exactly the given names.
+
+    Parameter names must not contain dots (use underscores instead).
+    All parameters are float32 and start with requires_grad=False.
+    """
+    model = nn.Module()
+    model.config = types.SimpleNamespace()
+    for name in param_names:
+        assert "." not in name, f"Use underscores, not dots, in flat model param names: {name}"
+        model.register_parameter(name, nn.Parameter(torch.randn(4), requires_grad=False))
+    return model
+
+
+class _FakeLMConfig:
+    def __init__(self, num_hidden_layers, block_configs=None):
+        self.num_hidden_layers = num_hidden_layers
+        self.block_configs = block_configs
+
+
+class _FakeDescriptor:
+    """Minimal descriptor stub for _set_keys_to_learn tests."""
+
+    def __init__(self, lm_config, weight_groups):
+        self._lm_config = lm_config
+        self._weight_groups = weight_groups
+
+    def get_language_model_config(self, model_config):
+        return self._lm_config
+
+    def get_weight_groups(self, state_dict_keys, num_hidden_layers):
+        return self._weight_groups
+
+
+# ---------------------------------------------------------------------------
+# _set_keys_to_learn tests
+# ---------------------------------------------------------------------------
+
+
+def test_set_keys_to_learn_sequence():
+    """Passing a list of parameter names enables grad only on those params."""
+    model = _make_flat_model("weight_a", "weight_b", "weight_c")
+    _set_keys_to_learn(model, descriptor=None, keys_to_learn=["weight_a", "weight_c"])
+
+    assert model.get_parameter("weight_a").requires_grad is True
+    assert model.get_parameter("weight_b").requires_grad is False
+    assert model.get_parameter("weight_c").requires_grad is True
+
+
+def test_set_keys_to_learn_regex():
+    """A bare regex string selects parameters by re.search."""
+    model = _make_flat_model("block_0_ffn_weight", "block_0_attn_weight", "block_1_ffn_weight")
+    _set_keys_to_learn(model, descriptor=None, keys_to_learn=r"_ffn_")
+
+    assert model.get_parameter("block_0_ffn_weight").requires_grad is True
+    assert model.get_parameter("block_0_attn_weight").requires_grad is False
+    assert model.get_parameter("block_1_ffn_weight").requires_grad is True
+
+
+def test_set_keys_to_learn_no_match_is_noop():
+    """A regex that matches nothing should not raise and leave all params unchanged."""
+    model = _make_flat_model("weight_a", "weight_b")
+    _set_keys_to_learn(model, descriptor=None, keys_to_learn=r"NONEXISTENT_PATTERN_XYZ")
+
+    assert model.get_parameter("weight_a").requires_grad is False
+    assert model.get_parameter("weight_b").requires_grad is False
+
+
+def test_set_keys_to_learn_subblock_ffn():
+    """'subblock_ffn' should enable only params in _ffn weight groups."""
+    # Names use underscores throughout (no dots) so register_parameter accepts them.
+    model = _make_flat_model("block_0_ffn_w1", "block_0_ffn_w2", "block_0_attn_q", "block_0_attn_k")
+    weight_groups = {
+        "block_0_ffn": ["block_0_ffn_w1", "block_0_ffn_w2"],
+        "block_0_attention": ["block_0_attn_q", "block_0_attn_k"],
+    }
+    lm_config = _FakeLMConfig(num_hidden_layers=1)
+    descriptor = _FakeDescriptor(lm_config, weight_groups)
+
+    _set_keys_to_learn(model, descriptor=descriptor, keys_to_learn="subblock_ffn")
+
+    assert model.get_parameter("block_0_ffn_w1").requires_grad is True
+    assert model.get_parameter("block_0_ffn_w2").requires_grad is True
+    assert model.get_parameter("block_0_attn_q").requires_grad is False
+    assert model.get_parameter("block_0_attn_k").requires_grad is False
+
+
+def test_set_keys_to_learn_subblock_attention():
+    """'subblock_attention' should enable only params in _attention weight groups."""
+    model = _make_flat_model("block_0_ffn_w1", "block_0_attn_q", "block_0_attn_k")
+    weight_groups = {
+        "block_0_ffn": ["block_0_ffn_w1"],
+        "block_0_attention": ["block_0_attn_q", "block_0_attn_k"],
+    }
+    lm_config = _FakeLMConfig(num_hidden_layers=1)
+    descriptor = _FakeDescriptor(lm_config, weight_groups)
+
+    _set_keys_to_learn(model, descriptor=descriptor, keys_to_learn="subblock_attention")
+
+    assert model.get_parameter("block_0_ffn_w1").requires_grad is False
+    assert model.get_parameter("block_0_attn_q").requires_grad is True
+    assert model.get_parameter("block_0_attn_k").requires_grad is True
+
+
+def test_set_keys_to_learn_entire_block():
+    """'entire_block' should enable all attention and ffn params."""
+    model = _make_flat_model("block_0_ffn_w1", "block_0_attn_q")
+    weight_groups = {
+        "block_0_ffn": ["block_0_ffn_w1"],
+        "block_0_attention": ["block_0_attn_q"],
+    }
+    lm_config = _FakeLMConfig(num_hidden_layers=1)
+    descriptor = _FakeDescriptor(lm_config, weight_groups)
+
+    _set_keys_to_learn(model, descriptor=descriptor, keys_to_learn="entire_block")
+
+    assert model.get_parameter("block_0_ffn_w1").requires_grad is True
+    assert model.get_parameter("block_0_attn_q").requires_grad is True
+
+
+def test_set_keys_to_learn_hybrid_mamba_filtering():
+    """For hybrid models, subblock_attention skips Mamba blocks and vice-versa."""
+    model = _make_flat_model("block_0_attn_q", "block_1_attn_ssm")
+    weight_groups = {
+        "block_0_attention": ["block_0_attn_q"],
+        "block_1_attention": ["block_1_attn_ssm"],
+    }
+
+    # block_configs: block_0 is GQA (mamba=None), block_1 is Mamba (mamba != None)
+    block_cfg_0 = types.SimpleNamespace(attention=types.SimpleNamespace(mamba=None))
+    block_cfg_1 = types.SimpleNamespace(attention=types.SimpleNamespace(mamba=object()))
+
+    lm_config = _FakeLMConfig(num_hidden_layers=2, block_configs=[block_cfg_0, block_cfg_1])
+    descriptor = _FakeDescriptor(lm_config, weight_groups)
+
+    # subblock_attention: only GQA (block_0), not Mamba (block_1)
+    _set_keys_to_learn(model, descriptor=descriptor, keys_to_learn="subblock_attention")
+    assert model.get_parameter("block_0_attn_q").requires_grad is True
+    assert model.get_parameter("block_1_attn_ssm").requires_grad is False
+
+    # Reset and test subblock_mamba: only Mamba (block_1), not GQA (block_0)
+    for p in model.parameters():
+        p.requires_grad_(False)
+    _set_keys_to_learn(model, descriptor=descriptor, keys_to_learn="subblock_mamba")
+    assert model.get_parameter("block_0_attn_q").requires_grad is False
+    assert model.get_parameter("block_1_attn_ssm").requires_grad is True
+
+
+# ---------------------------------------------------------------------------
+# set_experiment_id tests
+# ---------------------------------------------------------------------------
+
+
+def _make_exp_cfg(overrides: dict, keys_to_learn: str = "entire_block", experiment_id=None):
+    """Build a minimal DictConfig that set_experiment_id can operate on."""
+    return OmegaConf.create(
+        {
+            "bypass": {
+                "experiment_id": experiment_id,
+                "model": {"model_config_overrides": overrides},
+                "model_factory": {"keys_to_learn": keys_to_learn},
+            }
+        }
+    )
+
+
+def test_exp_id_ffn_only():
+    """FFN intermediate_size change → bypass_ffn{size}."""
+    cfg = _make_exp_cfg({"ffn": [{"intermediate_size": 256}]})
+    set_experiment_id(cfg)
+    assert cfg.bypass.experiment_id == "bypass_ffn256"
+
+
+def test_exp_id_attention_only():
+    """KV-head change only (FFN None) → bypass_kv{n}."""
+    cfg = _make_exp_cfg(
+        {
+            "ffn": [{"intermediate_size": None}],
+            "attention": [{"num_key_value_heads": 4}],
+        }
+    )
+    set_experiment_id(cfg)
+    assert cfg.bypass.experiment_id == "bypass_kv4"
+
+
+def test_exp_id_ffn_and_attention():
+    """Combined FFN + attention change → bypass_ffn{size}_kv{n}."""
+    cfg = _make_exp_cfg(
+        {
+            "ffn": [{"intermediate_size": 256}],
+            "attention": [{"num_key_value_heads": 4}],
+        }
+    )
+    set_experiment_id(cfg)
+    assert cfg.bypass.experiment_id == "bypass_ffn256_kv4"
+
+
+def test_exp_id_moe():
+    """MoE expert-count change → bypass_experts{n}."""
+    cfg = _make_exp_cfg(
+        {
+            "ffn": [{"moe": {"num_local_experts": 4}}],
+        }
+    )
+    set_experiment_id(cfg)
+    assert cfg.bypass.experiment_id == "bypass_experts4"
+
+
+def test_exp_id_mamba_with_state_dim():
+    """Mamba state_dim change → bypass_mambastate{dim}."""
+    cfg = _make_exp_cfg(
+        {
+            "attention": [{"mamba": {"state_dim": 64}}],
+        }
+    )
+    set_experiment_id(cfg)
+    assert cfg.bypass.experiment_id == "bypass_mambastate64"
+
+
+def test_exp_id_mamba_no_structural_change():
+    """Mamba bypass with no structural override → fallback to keys_to_learn type."""
+    cfg = _make_exp_cfg(
+        overrides={"attention": [{"num_key_value_heads": None}]},
+        keys_to_learn="subblock_mamba",
+    )
+    set_experiment_id(cfg)
+    assert cfg.bypass.experiment_id == "bypass_mamba"
+
+
+def test_exp_id_fallback_keys_to_learn_variants():
+    """No non-None overrides → experiment_id from keys_to_learn."""
+    cases = [
+        ("subblock_ffn", "bypass_ffn"),
+        ("subblock_attention", "bypass_attn"),
+        ("subblock_mamba", "bypass_mamba"),
+        ("entire_block", "bypass_block"),
+    ]
+    for keys_to_learn, expected in cases:
+        cfg = _make_exp_cfg(overrides={}, keys_to_learn=keys_to_learn)
+        set_experiment_id(cfg)
+        assert cfg.bypass.experiment_id == expected, (
+            f"keys_to_learn={keys_to_learn!r}: expected {expected!r}, "
+            f"got {cfg.bypass.experiment_id!r}"
+        )
+
+
+def test_exp_id_per_layer_uniform():
+    """All layers same size → single value (no dash separator)."""
+    cfg = _make_exp_cfg({"ffn": [{"intermediate_size": 256}] * 4})
+    set_experiment_id(cfg)
+    assert cfg.bypass.experiment_id == "bypass_ffn256"
+
+
+def test_exp_id_per_layer_mixed():
+    """Different per-layer sizes → values joined with dash."""
+    cfg = _make_exp_cfg(
+        {
+            "ffn": [
+                {"intermediate_size": 256},
+                {"intermediate_size": 3072},
+                {"intermediate_size": 256},
+            ]
+        }
+    )
+    set_experiment_id(cfg)
+    assert cfg.bypass.experiment_id == "bypass_ffn256-3072"
+
+
+def test_exp_id_already_set_is_noop():
+    """experiment_id already populated → set_experiment_id is a no-op."""
+    cfg = _make_exp_cfg(
+        overrides={"ffn": [{"intermediate_size": 256}]},
+        experiment_id="my_custom_id",
+    )
+    set_experiment_id(cfg)
+    assert cfg.bypass.experiment_id == "my_custom_id"
+
+
+def test_exp_id_none_fields_not_included():
+    """None-valued fields do not contribute to the experiment ID."""
+    cfg = _make_exp_cfg(
+        {
+            "ffn": [{"intermediate_size": None, "moe": None}],
+            "attention": [{"num_key_value_heads": 4}],
+        }
+    )
+    set_experiment_id(cfg)
+    # Only the kv component should appear
+    assert cfg.bypass.experiment_id == "bypass_kv4"