diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 2a29f644e6..96f75f8a81 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -319,15 +319,96 @@ trainer.save_state() trainer.save_model("") ``` +## DFlash: Block Diffusion for Flash Speculative Decoding + +DFlash ([arXiv:2602.06036](https://arxiv.org/abs/2602.06036)) is a parallel speculative decoding method that predicts multiple tokens simultaneously using block diffusion. Unlike autoregressive methods (EAGLE, Medusa) that draft one token at a time, DFlash predicts an entire block of tokens in parallel, then iteratively denoises them. + +### Architecture + +DFlash uses three key mechanisms: + +- **Feature Fusion**: Multi-layer hidden states from the target model are projected via a fully-connected layer and RMSNorm to create context features +- **KV Injection**: Context features are injected as K/V in every draft decoder layer, while Q comes from the noise embeddings. QK-Norm (RMSNorm on Q and K before RoPE) stabilizes attention +- **Parallel Drafting**: Within each block of size B, unknown positions use a `mask_token_id` token. Only block-start positions get the real token. 
The attention mask allows noise tokens to attend to all context tokens from previous blocks, plus causally within the same block + +### Training + +```bash +./launch_train.sh --model $BASE_MODEL \ + --output_dir $OUTPUT_DIR \ + --data input_conversations/train.jsonl \ + --num_epochs $NUM_EPOCH \ + --mode dflash \ + --dflash_block_size 16 \ + --dflash_num_layers 5 +``` + +Key arguments: + +| Flag | Default | Description | +|------|---------|-------------| +| `--mode dflash` | - | Enable DFlash mode | +| `--dflash_block_size` | 16 | Block size for parallel prediction | +| `--dflash_num_layers` | 5 | Number of decoder layers in draft module | +| `--dflash_config` | None | Path to JSON config for custom architecture | +| `--dflash_mask_token_id` | auto | Mask token ID (auto-detected from model) | +| `--dflash_disable_torch_compile` | False | Disable torch.compile | +| `--dflash_use_logit_distillation` | False | Use KD from target model logits instead of hard CE | + +### mask_token_id + +The `mask_token_id` is critical for DFlash training and inference. It must be consistent between training and deployment. Auto-detection logic: + +| Model Family | mask_token_id | Source | +|-------------|---------------|--------| +| Qwen3.5 | 248070 | Built-in `[MASK]` token | +| Qwen3 (8B) | 151643 | `eos_token_id` | +| Llama 3 | 128002 | `reserved_special_token_0` | +| Others | `pad_token_id` | Fallback | + +Override with `--dflash_mask_token_id <ID>` if auto-detection is incorrect. + +### Configuring Draft Model + +Similar to EAGLE, provide a JSON config to customize the draft architecture: + +```json +{ + "num_hidden_layers": 5, + "rms_norm_eps": 1e-6 +} +``` + +Model dimensions (hidden_size, num_attention_heads, etc.) are automatically inherited from the base model. 
+ +### Current Status (WIP) + +| Feature | Status | +|---------|--------| +| Architecture (Feature Fusion, KV Injection, Parallel Drafting) | Working | +| Online training with HF Trainer | Working | +| Inference / AR validation (`pseudo_speculative_generate`) | Working | +| z-lab checkpoint loading and inference (AR 7-9) | Working | +| Logit distillation option | Working | +| Response-only loss masking | Working | +| DDP training | Working (with `find_unused_parameters=True`) | + +**Known gap**: Training with ModelOpt achieves ~35% per-token accuracy (matching SpecForge's ~30%), but acceptance rate (AR) is lower than SpecForge-trained checkpoints (1.15 vs 1.95). Investigation shows the **data pipeline** differs significantly: + +- SpecForge uses its own tokenizer template with system prompt and response-only loss mask +- ModelOpt's `LanguageDataCollator` uses `apply_chat_template` with different formatting + +Aligning the data pipeline is the next step to close the AR gap. + ## Support Matrix -| Model | Medusa | EAGLE1/2 | EAGLE3 | -| :---: | :---: | :---: | :---: | -| LLAMA 2 | ✅ | ✅ | ✅ | -| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | -| Mistral | ✅ | ✅ | ✅ | -| Phi 3 | ✅ | ✅ | ✅ | -| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ | +| Model | Medusa | EAGLE1/2 | EAGLE3 | DFlash | +| :---: | :---: | :---: | :---: | :---: | +| LLAMA 2 | ✅ | ✅ | ✅ | ✅ | +| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | ✅ | +| Mistral | ✅ | ✅ | ✅ | ✅ | +| Phi 3 | ✅ | ✅ | ✅ | ✅ | +| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ | ✅ | ## Speculation Module Checkpoints diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 1bc7df9810..0267ac684c 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -139,6 +139,7 @@ def make_eagle_supervised_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, train_len=None, + answer_only_loss=False, ) -> dict: if data_args.offline_data_path is None: train_dataset = ShardedDataset("json", 
data_files=data_args.data_path) @@ -148,6 +149,7 @@ def make_eagle_supervised_data_module( tokenizer=tokenizer, train_len=train_len, return_labels=True, + answer_only_loss=answer_only_loss, ) else: data_collator = VisionLanguageDataCollator( @@ -203,6 +205,12 @@ def on_log(self, args, state, control, **kwargs): if not hasattr(state, "training_accs") or len(state.training_accs) == 0: return control average_acc = np.mean(state.training_accs, axis=0) + # Always print accuracy to console + try: + acc_str = ", ".join(f"{a:.4f}" for a in np.array(average_acc).flatten()) + print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]") + except Exception: + print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}") if self.estimate_ar: # Calculate mean training AR since last log # NOTE: This is only an estimate of the real AR. @@ -217,41 +225,64 @@ def on_log(self, args, state, control, **kwargs): est_ar += acc_cumprod print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}") + # Log accuracy to HF Trainer's logs dict (picked up by TensorBoard) + logs = kwargs.get("logs") or {} + for i, draft_acc in enumerate(average_acc): + for j, step_acc in enumerate(draft_acc): + logs[f"train_acc/parallel_{i}_step_{j}"] = float(step_acc) + if self.estimate_ar: + logs["estimated_training_ar"] = est_ar + # log to wandb if wandb and is_master(): - logs = kwargs.get("logs") or {} if logs: wandb.log({k: v for k, v in logs.items() if v is not None}, step=state.global_step) - for i, draft_acc in enumerate(average_acc): - for j, step_acc in enumerate(draft_acc): - wandb.log( - {f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step - ) - if self.estimate_ar: - wandb.log({"estimated_training_ar": est_ar}, step=state.global_step) # reset training_accs state.training_accs = [] return control def on_step_end(self, args, state, control, **kwargs): - """Run AR validation periodically, if available.""" + """Run AR validation periodically (single-GPU 
only). + + AR validation with DDP is not supported because pseudo_speculative_generate + runs only on rank 0 while other ranks deadlock waiting for collective ops. + When world_size > 1, AR validation is skipped with a one-time warning. + Use post-training AR validation instead (online_training.sh runs it after training). + """ if self.ar_validate_steps <= 0: return control if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0: + if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + if not hasattr(self, "_ar_ddp_warned"): + self._ar_ddp_warned = True + print_rank_0( + "=== WARNING === AR validation during training is not supported with " + "DDP (world_size > 1). Skipping. Use post-training AR validation." + ) + return control + + model = kwargs["model"] + raw_model = model.module if hasattr(model, "module") else model + was_training = raw_model.training + raw_model.eval() print_rank_0("Running AR validation...") try: - ars = validate_ar( - model=kwargs["model"], - tokenizer=kwargs["processing_class"], - ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], - device=kwargs["model"].device, - ) + with torch.no_grad(): + ars = validate_ar( + model=raw_model, + tokenizer=kwargs["processing_class"], + ds=load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"], + device=next(raw_model.parameters()).device, + num_samples=8, + ) print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") - if wandb and is_master(): + if wandb: wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) - except Exception: - print_rank_0("AR validation not available.") + except Exception as e: + print_rank_0(f"AR validation failed: {e}") + if was_training: + raw_model.train() return control diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index c15b97bdaa..86fc400a35 100755 --- a/examples/speculative_decoding/launch_train.sh +++ 
b/examples/speculative_decoding/launch_train.sh @@ -134,6 +134,33 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi FSDP="${1#*=}" ;; + --dflash_block_size*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_BLOCK_SIZE="${1#*=}" + ;; + --dflash_num_layers*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_NUM_LAYERS="${1#*=}" + ;; + --dflash_config*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_CONFIG="${1#*=}" + ;; + --dflash_mask_token_id*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_MASK_TOKEN_ID="${1#*=}" + ;; + --dflash_num_anchors*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_NUM_ANCHORS="${1#*=}" + ;; + --dflash_loss_decay_gamma*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_LOSS_DECAY_GAMMA="${1#*=}" + ;; + --dflash_use_logit_distillation*) + DFLASH_USE_LOGIT_DISTILLATION="True" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -195,8 +222,34 @@ if [[ "$MODE" == "eagle3" ]]; then else SPECULATIVE_ARGS="" fi +elif [[ "$MODE" == "dflash" ]]; then + DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16} + DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5} + SPECULATIVE_ARGS="--dflash_block_size $DFLASH_BLOCK_SIZE --dflash_num_layers $DFLASH_NUM_LAYERS --dflash_disable_torch_compile" + if [[ -n "$DFLASH_CONFIG" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_config $DFLASH_CONFIG" + fi + if [[ -n "$DFLASH_MASK_TOKEN_ID" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_mask_token_id $DFLASH_MASK_TOKEN_ID" + fi + if [[ -n "$DFLASH_NUM_ANCHORS" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_num_anchors $DFLASH_NUM_ANCHORS" + fi + if [[ -n "$DFLASH_LOSS_DECAY_GAMMA" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_loss_decay_gamma $DFLASH_LOSS_DECAY_GAMMA" + fi + if [[ "$DFLASH_USE_LOGIT_DISTILLATION" == "True" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_use_logit_distillation" + fi + # DFlash: DDP by default, FSDP if --fsdp True is passed + if [[ "$FSDP" == "True" ]]; then + FSDP_ARGS="--fsdp 
'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" + else + FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 1800" + DP_SHARD_SIZE=1 + fi else - echo "Only eagle3 supported for now!" + echo "Unsupported mode: $MODE. Supported: eagle3, dflash" exit 1 fi @@ -218,12 +271,14 @@ else VLM_ARGS="" fi -if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then - #Use FSDP2 when multi GPU available - FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" -else - #Otherwise, single GPU training - FSDP_ARGS="" +if [[ "$MODE" != "dflash" ]]; then + if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then + #Use FSDP2 when multi GPU available + FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" + else + #Otherwise, single GPU training + FSDP_ARGS="" + fi fi @@ -267,6 +322,8 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai --warmup_steps 100 \ --lr_scheduler_type linear \ --logging_steps $LOG_STEPS \ + --report_to tensorboard \ + --logging_dir $OUTPUT_DIR/tensorboard \ --tf32 True \ $DATA_ARGS \ --disable_tqdm $DISABLE_TQDM \ diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 3369d399c2..b5f2b3e26f 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -104,7 +104,7 @@ class TrainingArguments(transformers.TrainingArguments): ) dataloader_drop_last: bool = field(default=True) bf16: bool = field(default=True) - mode: Literal["eagle3", "medusa"] = "eagle3" + mode: Literal["eagle3", "medusa", "dflash"] = "eagle3" estimate_ar: bool = field( default=False, metadata={"help": "Whether to estimate AR during training for logging."} ) @@ -144,6 +144,43 @@ class EagleArguments: ) +@dataclass +class DFlashArguments: + dflash_block_size: int = field( + default=16, metadata={"help": "Block size for DFlash parallel prediction."} + ) + dflash_num_layers: int = field( + default=5, metadata={"help": "Number of 
decoder layers in the DFlash draft module."} + ) + dflash_config: str | None = field(default=None, metadata={"help": "Path to dflash_config.json"}) + dflash_disable_torch_compile: bool = field( + default=False, + metadata={"help": "Disable torch.compile on DFlash forward/loss methods."}, + ) + dflash_mask_token_id: int | None = field( + default=None, + metadata={"help": "Mask token ID for DFlash. If not set, auto-detected from model."}, + ) + dflash_use_logit_distillation: bool = field( + default=False, + metadata={ + "help": "Use logit distillation (KD from target model) instead of hard CE. " + "Enables training with data not synthesized by the target model." + }, + ) + dflash_num_anchors: int = field( + default=512, + metadata={"help": "Number of random anchor positions per sequence during training."}, + ) + dflash_loss_decay_gamma: float = field( + default=0.0, + metadata={ + "help": "Gamma for loss decay weighting (paper Eq.4). " + "Suggested: 7 for block_size=16. 0 disables." + }, + ) + + def train(): parser = transformers.HfArgumentParser( ( @@ -152,9 +189,10 @@ def train(): TrainingArguments, MedusaArguments, EagleArguments, + DFlashArguments, ) ) - model_args, data_args, training_args, medusa_args, eagle_args = ( + model_args, data_args, training_args, medusa_args, eagle_args, dflash_args = ( parser.parse_args_into_dataclasses() ) if not data_args.data_path and not data_args.offline_data_path: @@ -183,12 +221,19 @@ def train(): use_offline_training = data_args.offline_data_path is not None if checkpoint: + # Load model from output_dir (top-level save) rather than checkpoint subdir + # to avoid meta tensor errors. The checkpoint path is passed to + # trainer.train(resume_from_checkpoint=...) for optimizer/step resume. 
+ model_load_path = training_args.output_dir with patch_transformers5_params_loading(): model = load_vlm_or_llm( - checkpoint, torch_dtype="auto", trust_remote_code=model_args.trust_remote_code + model_load_path, + torch_dtype="auto", + device_map="cpu", + trust_remote_code=model_args.trust_remote_code, ) tokenizer = transformers.AutoTokenizer.from_pretrained( - checkpoint, trust_remote_code=model_args.trust_remote_code + model_load_path, trust_remote_code=model_args.trust_remote_code ) else: # To avoid OOM for large models, we load and convert model on CPU first. @@ -236,13 +281,34 @@ def train(): ) model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") + elif training_args.mode == "dflash": + custom_config = ( + json.load(open(dflash_args.dflash_config)) if dflash_args.dflash_config else {} + ) + custom_config.setdefault("num_hidden_layers", dflash_args.dflash_num_layers) + if dflash_args.dflash_mask_token_id is not None: + custom_config["mask_token_id"] = dflash_args.dflash_mask_token_id + + config = { + "dflash_block_size": dflash_args.dflash_block_size, + "dflash_use_torch_compile": not dflash_args.dflash_disable_torch_compile, + "dflash_self_logit_distillation": dflash_args.dflash_use_logit_distillation, + "dflash_num_anchors": dflash_args.dflash_num_anchors, + "dflash_loss_decay_factor": dflash_args.dflash_loss_decay_gamma, + "dflash_architecture_config": custom_config, + } + + mtsp.convert(model, [("dflash", config)]) else: raise Exception(f"{training_args.mode} is not supported!") print_rank_0("Loading dataset...") - if training_args.mode == "eagle3": + if training_args.mode in ("eagle3", "dflash"): data_module = make_eagle_supervised_data_module( - tokenizer, data_args, train_len=training_args.training_seq_len + tokenizer, + data_args, + train_len=training_args.training_seq_len, + answer_only_loss=(training_args.mode == "dflash"), ) trainer = EagleTrainerWithAccLog( 
diff --git a/examples/speculative_decoding/train_dflash.py b/examples/speculative_decoding/train_dflash.py new file mode 100644 index 0000000000..20be8a85f0 --- /dev/null +++ b/examples/speculative_decoding/train_dflash.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Standalone DFlash training script using SpecForge's data pipeline. + +Uses SpecForge's tokenizer template + offset-mapping loss mask for data +preprocessing, and ModelOpt's DFlash module for the draft model. This +isolates data pipeline differences from model architecture differences. 
+ +Usage: + torchrun --nproc_per_node=8 train_dflash.py \ + --model /path/to/Qwen3-8B \ + --data /path/to/train.jsonl \ + --chat-template qwen \ + --block-size 16 \ + --num-draft-layers 5 \ + --num-epochs 3 \ + --lr 1e-4 \ + --output-dir /path/to/output +""" + +import argparse +import math +import os + +import torch +import torch.distributed as dist +from datasets import load_dataset +from torch.utils.data import DataLoader, DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="DFlash training with SpecForge data pipeline") + parser.add_argument("--model", type=str, required=True, help="Target model path") + parser.add_argument("--data", type=str, required=True, help="Training data JSONL path") + parser.add_argument("--chat-template", type=str, default="qwen", help="Chat template name") + parser.add_argument("--block-size", type=int, default=16) + parser.add_argument("--num-draft-layers", type=int, default=5) + parser.add_argument("--mask-token-id", type=int, default=None) + parser.add_argument("--max-length", type=int, default=512) + parser.add_argument("--num-epochs", type=int, default=3) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--warmup-ratio", type=float, default=0.01) + parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--save-interval", type=int, default=0, help="0 = save at end only") + parser.add_argument("--output-dir", type=str, required=True) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--num-ar-samples", type=int, default=20, help="AR validation samples") + return parser.parse_args() + + +def is_rank0(): + """Check if current process is rank 0.""" + return not dist.is_initialized() 
or dist.get_rank() == 0 + + +def print_rank0(msg): + """Print only on rank 0.""" + if is_rank0(): + print(msg, flush=True) + + +def build_dataset(tokenizer, data_path, chat_template_name, max_length): + """Build dataset using SpecForge's data pipeline. + + Uses SpecForge's GeneralParser to tokenize conversations with the + proper chat template and compute offset-mapping-based loss masks. + """ + from specforge.data.parse import GeneralParser + from specforge.data.template import TEMPLATE_REGISTRY + + template = TEMPLATE_REGISTRY.get(chat_template_name) + parser = GeneralParser(tokenizer, template) + + raw_dataset = load_dataset("json", data_files=data_path)["train"] + + processed = {"input_ids": [], "loss_mask": []} + skipped = 0 + for sample in raw_dataset: + convs = sample.get("conversations", sample.get("messages", [])) + if not convs: + skipped += 1 + continue + try: + input_ids, loss_mask = parser.parse(convs, max_length=max_length) + processed["input_ids"].append(input_ids) + processed["loss_mask"].append(loss_mask) + except Exception: + skipped += 1 + + print_rank0(f"Processed {len(processed['input_ids'])} samples, skipped {skipped}") + return processed + + +class DFlashDataset(torch.utils.data.Dataset): + """Simple dataset wrapping tokenized input_ids and loss_mask.""" + + def __init__(self, data): + self.input_ids = data["input_ids"] + self.loss_mask = data["loss_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "loss_mask": self.loss_mask[idx], + } + + +def collate_fn(batch): + """Collate batch of samples.""" + input_ids = torch.stack([b["input_ids"] for b in batch]) + loss_mask = torch.stack([b["loss_mask"] for b in batch]) + return {"input_ids": input_ids, "loss_mask": loss_mask} + + +def train(args): + """Main training loop.""" + # Init distributed + dist.init_process_group("nccl") + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + 
torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + + torch.manual_seed(args.seed) + mto.enable_huggingface_checkpointing() + + # Load model + print_rank0(f"Loading model: {args.model}") + model = AutoModelForCausalLM.from_pretrained( + args.model, torch_dtype=torch.bfloat16, device_map={"": device}, trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + + # Detect mask_token_id + mask_token_id = args.mask_token_id + if mask_token_id is None: + if hasattr(tokenizer, "mask_token_id") and tokenizer.mask_token_id is not None: + mask_token_id = tokenizer.mask_token_id + elif hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None: + mask_token_id = tokenizer.pad_token_id + else: + mask_token_id = tokenizer.eos_token_id + print_rank0(f"mask_token_id: {mask_token_id}") + + # Convert to DFlash + config = { + "dflash_block_size": args.block_size, + "dflash_use_torch_compile": False, + "dflash_architecture_config": { + "num_hidden_layers": args.num_draft_layers, + "mask_token_id": mask_token_id, + }, + } + mtsp.convert(model, [("dflash", config)]) + print_rank0( + f"DFlash module created: {sum(p.numel() for p in model.dflash_module.parameters()):,} params" + ) + + # Build dataset using SpecForge pipeline + print_rank0("Building dataset with SpecForge pipeline...") + data = build_dataset(tokenizer, args.data, args.chat_template, args.max_length) + + # Filter samples with too few loss tokens + min_loss_tokens = 2 * args.block_size + filtered_ids = [] + filtered_masks = [] + for i in range(len(data["input_ids"])): + if data["loss_mask"][i].sum() >= min_loss_tokens: + filtered_ids.append(data["input_ids"][i]) + filtered_masks.append(data["loss_mask"][i]) + print_rank0(f"After filtering: {len(filtered_ids)} samples (min {min_loss_tokens} loss tokens)") + data = {"input_ids": filtered_ids, "loss_mask": filtered_masks} + + dataset = DFlashDataset(data) + sampler = 
DistributedSampler(dataset, shuffle=True) + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + collate_fn=collate_fn, + num_workers=2, + pin_memory=True, + drop_last=True, + ) + + # Wrap with DDP + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[local_rank], + find_unused_parameters=True, + ) + raw_model = model.module + + # Optimizer — only train dflash_module + optimizer = torch.optim.AdamW( + [p for p in raw_model.dflash_module.parameters() if p.requires_grad], + lr=args.lr, + weight_decay=0.0, + ) + + # LR scheduler + steps_per_epoch = len(dataloader) + total_steps = args.num_epochs * steps_per_epoch + warmup_steps = int(total_steps * args.warmup_ratio) + + def lr_lambda(step): + if step < warmup_steps: + return step / max(warmup_steps, 1) + progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1) + return 0.5 * (1.0 + math.cos(math.pi * progress)) + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + print_rank0(f"Training: {total_steps} steps, {warmup_steps} warmup, {steps_per_epoch}/epoch") + + # Training loop + global_step = 0 + for epoch in range(args.num_epochs): + sampler.set_epoch(epoch) + model.train() + + for batch in dataloader: + input_ids = batch["input_ids"].to(device) + loss_mask = batch["loss_mask"].to(device) + + # Create labels from loss_mask: -100 for masked positions + labels = input_ids.clone() + labels[loss_mask == 0] = -100 + + output = model( + input_ids=input_ids, + attention_mask=torch.ones_like(input_ids), + labels=labels, + ) + + loss = output.loss + loss.backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + global_step += 1 + + if global_step % args.log_interval == 0: + acc = output.train_acc[0][0] if hasattr(output, "train_acc") else 0.0 + lr = scheduler.get_last_lr()[0] + print_rank0( + f"Step {global_step} | loss={loss.item():.4f} | acc={acc:.4f} | lr={lr:.2e}" + ) + + if args.save_interval > 0 and global_step % 
args.save_interval == 0: + if is_rank0(): + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + raw_model.save_pretrained(save_path) + print_rank0(f"Saved checkpoint: {save_path}") + + # Save final model + if is_rank0(): + os.makedirs(args.output_dir, exist_ok=True) + raw_model.save_pretrained(args.output_dir) + print_rank0(f"Saved final model: {args.output_dir}") + + dist.barrier() + + # AR validation on rank 0 + if is_rank0() and args.num_ar_samples > 0: + print_rank0("\n=== AR Validation ===") + model.eval() + from modelopt.torch.speculative.plugins.transformers import HFARValidation + + validator = HFARValidation(raw_model, tokenizer) + ds = load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"] + + ars = [] + for i in range(min(args.num_ar_samples, len(ds))): + prompt = ds[i]["prompt"][0] + chat = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + inp = tokenizer(text, return_tensors="pt").input_ids.to(device) + try: + _, ar = validator.validate(osl=32, input_ids=inp, steps=3) + ars.append(ar) + print_rank0(f" AR={ar:.2f} | {prompt[:60]}") + except Exception as e: + print_rank0(f" ERROR | {prompt[:60]}... 
| {e}") + + if ars: + avg = sum(ars) / len(ars) + print_rank0("\n==== DFlash AR Results ====") + print_rank0(f"Average AR: {avg:.4f}") + print_rank0(f"Min: {min(ars):.4f}, Max: {max(ars):.4f}") + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_args() + train(args) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index aca19a1580..29b2cf633f 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -27,7 +27,7 @@ from .hf_spec_configs import kimik2_eagle_template_config, llama_eagle_template_config -ALL_SPEC_MODES = ["eagle"] +ALL_SPEC_MODES = ["eagle", "dflash"] LLAMA_EAGLE_SINGLE_LAYER = { "required": { @@ -243,3 +243,106 @@ def _extract_state_dict(self, full_state_dict: dict): export_sd.pop(f"parallel_draft_heads.medusa_heads.{i}.{j}.linear.bias") ) return export_sd + + +class DFlashExporter(SpeculativeDecodingExporter): + """Draft model exporter for DFlash. + + Exports in z-lab compatible format: + - model.safetensors: draft module weights (no prefix) + - config.json: Qwen3-style config with dflash_config field + """ + + def __init__(self, model: nn.Module): + """Initialize the DFlashExporter.""" + super().__init__(model) + + def _extract_state_dict(self, full_state_dict: dict): + """Extract DFlash module weights, stripping the dflash_module prefix.""" + export_sd = {} + for key, value in full_state_dict.items(): + if "dflash_module." 
in key: + export_key = key.split("dflash_module.", 1)[1] + # Skip rotary embedding buffers (not needed, recomputed) + if "rotary_emb" in export_key: + continue + export_sd[export_key] = value.clone() + return export_sd + + def _export_config(self): + """Build config.json matching z-lab DFlash format.""" + model = self.model + base_config = ( + getattr(model.config, "text_config", None) + or getattr(model.config, "llm_config", None) + or model.config + ) + draft_config = model.dflash_config + + config = { + "architectures": ["DFlashDraftModel"], + "model_type": getattr(base_config, "model_type", "qwen3"), + "block_size": model.dflash_block_size, + "dflash_config": { + "mask_token_id": model.mask_token_id, + "target_layer_ids": list(model.target_layer_ids), + }, + # Architecture dimensions + "hidden_size": getattr(draft_config, "hidden_size", base_config.hidden_size), + "num_hidden_layers": draft_config.num_hidden_layers, + "num_attention_heads": getattr( + draft_config, "num_attention_heads", base_config.num_attention_heads + ), + "num_key_value_heads": getattr( + draft_config, "num_key_value_heads", base_config.num_key_value_heads + ), + "head_dim": getattr( + draft_config, + "head_dim", + base_config.hidden_size // base_config.num_attention_heads, + ), + "intermediate_size": getattr( + draft_config, "intermediate_size", base_config.intermediate_size + ), + "hidden_act": getattr(draft_config, "hidden_act", "silu"), + "rms_norm_eps": getattr(draft_config, "rms_norm_eps", 1e-6), + "vocab_size": base_config.vocab_size, + "max_position_embeddings": getattr(base_config, "max_position_embeddings", 32768), + "initializer_range": getattr(base_config, "initializer_range", 0.02), + "attention_bias": getattr(draft_config, "attention_bias", False), + "attention_dropout": getattr(draft_config, "attention_dropout", 0.0), + "rope_theta": getattr(base_config, "rope_theta", 1000000.0), + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "num_target_layers": 
getattr(base_config, "num_hidden_layers", 36), + } + + # Add layer_types if present (Qwen3-style) + if hasattr(draft_config, "layer_types"): + config["layer_types"] = draft_config.layer_types + else: + config["layer_types"] = ["full_attention"] * draft_config.num_hidden_layers + + return config + + def export(self, export_dir: Path | str, dtype: torch.dtype | None = None): + """Export the DFlash draft model to deployment format.""" + export_dir = Path(export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + + # Export state dict + full_sd = self.model.state_dict() + drafter_sd = self._extract_state_dict(full_sd) + if dtype is not None: + drafter_sd = {k: v.to(dtype) for k, v in drafter_sd.items()} + save_file(drafter_sd, f"{export_dir}/model.safetensors") + + # Export config + drafter_config = self._export_config() + with open(f"{export_dir}/config.json", "w") as f: + json.dump(drafter_config, f, indent=2) + + print( + f"Exported DFlash draft model: {len(drafter_sd)} tensors, " + f"config keys: {list(drafter_config.keys())[:5]}..." + ) diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 69491c6599..5202865efb 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -46,6 +46,61 @@ } +def _get_dflash_default_config(): + from .dflash.default_config import default_dflash_config + + return default_dflash_config + + +DFLASH_DEFAULT_CFG = { + "algorithm": "dflash", + "config": { + "dflash_architecture_config": {}, # merged with default at convert time + }, +} + + +class DFlashConfig(ModeloptBaseConfig): + """DFlash config for block-wise parallel speculative decoding.""" + + dflash_block_size: int = ModeloptField( + default=16, + description="Block size for parallel prediction. Draft predicts this many tokens per block.", + ) + + dflash_freeze_base_model: bool = ModeloptField( + default=True, description="Whether to freeze base model during DFlash module training." 
+ ) + + dflash_self_logit_distillation: bool = ModeloptField( + default=True, description="Whether to use logit distillation from base model." + ) + + dflash_loss_decay_factor: float = ModeloptField( + default=0.0, + description="Gamma for exponential loss decay weighting (paper Eq.4). " + "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables.", + ) + + dflash_num_anchors: int = ModeloptField( + default=512, + description="Number of random anchor positions sampled per sequence during training.", + ) + + dflash_report_acc: bool = ModeloptField( + default=True, description="Whether to report eval accuracy." + ) + + dflash_architecture_config: dict = ModeloptField( + default={}, description="Config for the DFlash draft module architecture." + ) + + dflash_use_torch_compile: bool = ModeloptField( + default=True, + description="Whether to use torch.compile on DFlash forward/loss methods.", + ) + + class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" diff --git a/modelopt/torch/speculative/dflash/__init__.py b/modelopt/torch/speculative/dflash/__init__.py new file mode 100644 index 0000000000..912b8d47a2 --- /dev/null +++ b/modelopt/torch/speculative/dflash/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""DFlash Optimization Method.""" + +from .conversion import * +from .default_config import * +from .dflash_model import * diff --git a/modelopt/torch/speculative/dflash/conversion.py b/modelopt/torch/speculative/dflash/conversion.py new file mode 100644 index 0000000000..943be90ca0 --- /dev/null +++ b/modelopt/torch/speculative/dflash/conversion.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""DFlash conversion/restore utilities.""" + +from torch import nn + +from modelopt.torch.opt.conversion import ModelLikeModule +from modelopt.torch.opt.dynamic import _DMRegistryCls +from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict + +from ..config import DFlashConfig + +DFlashDMRegistry = _DMRegistryCls(prefix="DFlash") # global instance for the registry + + +def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertReturnType: + """Convert the model to a DFlash model as per `config`.""" + model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + + original_cls = type(model) + if original_cls not in DFlashDMRegistry: + for cls in DFlashDMRegistry._registry: + if issubclass(original_cls, cls): + DFlashDMRegistry.register({original_cls: "base_model_class"})(DFlashDMRegistry[cls]) + break + + # merge custom config with default config (lazy import to avoid circular) + from .default_config import default_dflash_config + + custom_config = config.dflash_architecture_config + config.dflash_architecture_config = {**default_dflash_config, **custom_config} + + dflash_model = DFlashDMRegistry.convert(model) + dflash_model.modify(config) + + metadata = {} + return dflash_model, metadata + + +def restore_dflash_model( + model: nn.Module, config: DFlashConfig, metadata: MetadataDict +) -> nn.Module: + """Function for restoring a previously converted model to a DFlash model.""" + assert not metadata, "No metadata expected!" + return convert_to_dflash_model(model, config)[0] diff --git a/modelopt/torch/speculative/dflash/default_config.py b/modelopt/torch/speculative/dflash/default_config.py new file mode 100644 index 0000000000..5536e0d4df --- /dev/null +++ b/modelopt/torch/speculative/dflash/default_config.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default DFlash architecture config. + +Model-specific settings (hidden_size, num_attention_heads, rope_*, etc.) +are inherited from the base model in HFDFlashModel.modify(). Only +DFlash-specific defaults are set here. +""" + +default_dflash_config = { + "num_hidden_layers": 5, + "rms_norm_eps": 1e-06, + "attention_bias": False, + "attention_dropout": 0.0, +} diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py new file mode 100644 index 0000000000..0a10f065eb --- /dev/null +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""DFlash model to support block-wise parallel speculative decoding.""" + +from modelopt.torch.opt.dynamic import DynamicModule + + +class DFlashModel(DynamicModule): + """Base DFlash Model.""" + + def _setup(self): + """Register temporary attributes for the DFlash module.""" + self._register_temp_attribute("dflash_module", None) + + def modify(self, config): + """Base DFlash Model modify function. Child class should implement the details.""" + self.dflash_block_size = config.dflash_block_size + self.dflash_freeze_base_model = config.dflash_freeze_base_model + self.dflash_loss_decay_factor = config.dflash_loss_decay_factor + self.dflash_self_logit_distillation = config.dflash_self_logit_distillation + self.dflash_num_anchors = config.dflash_num_anchors + self.dflash_report_acc = config.dflash_report_acc + self.dflash_use_torch_compile = config.dflash_use_torch_compile diff --git a/modelopt/torch/speculative/mode.py b/modelopt/torch/speculative/mode.py index 866449e155..ae965354a9 100644 --- a/modelopt/torch/speculative/mode.py +++ b/modelopt/torch/speculative/mode.py @@ -23,7 +23,8 @@ _ModeRegistryCls, ) -from .config import EagleConfig, MedusaConfig +from .config import DFlashConfig, EagleConfig, MedusaConfig +from .dflash.conversion import convert_to_dflash_model, restore_dflash_model from .eagle.conversion import convert_to_eagle_model, restore_eagle_model from .medusa.conversion import convert_to_medusa_model, restore_medusa_model @@ -58,6 +59,34 @@ def restore(self) -> RestoreEntrypoint: return restore_medusa_model +@SpeculativeDecodingModeRegistry.register_mode +class DFlashModeDescriptor(ModeDescriptor): + """Class to describe the ``"dflash"`` mode. + + The properties of this mode can be inspected via the source code. 
+ """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "dflash" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return DFlashConfig + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return convert_to_dflash_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_dflash_model + + @SpeculativeDecodingModeRegistry.register_mode class EagleModeDescriptor(ModeDescriptor): """Class to describe the ``"eagle"`` mode. diff --git a/modelopt/torch/speculative/plugins/__init__.py b/modelopt/torch/speculative/plugins/__init__.py index 5e3f4bff2f..d59aed37d5 100644 --- a/modelopt/torch/speculative/plugins/__init__.py +++ b/modelopt/torch/speculative/plugins/__init__.py @@ -31,3 +31,6 @@ with import_plugin("transformers"): from .transformers import * + +with import_plugin("hf_dflash"): + from .hf_dflash import * diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py new file mode 100644 index 0000000000..9257f50e0d --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -0,0 +1,882 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash speculative decoding plugin for HuggingFace models. + +Matches the reference SpecForge implementation (github.com/sgl-project/SpecForge PR #415). + +Architecture: +- Feature Fusion: multi-layer target hidden states → FC + RMSNorm +- KV Injection: fused features as K/V in every draft layer with QK-norm +- Parallel Drafting: mask_token_id for unknown positions, causal within blocks +- Loss: hard CE on input_ids[i] (position i predicts token i) + +Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +""" + +import importlib + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel +from transformers.utils import ModelOutput + +from ..dflash.conversion import DFlashDMRegistry +from ..dflash.dflash_model import DFlashModel + + +def _resolve_model_components(model_type): + """Resolve MLP, RMSNorm, RotaryEmbedding from the base model's transformers module. + + Falls back to Llama components if the model type is unknown. 
+ """ + fallback = "llama" + model_type = model_type or fallback + try: + mod = importlib.import_module(f"transformers.models.{model_type}.modeling_{model_type}") + except (ImportError, ModuleNotFoundError): + mod = importlib.import_module(f"transformers.models.{fallback}.modeling_{fallback}") + model_type = fallback + + prefix = model_type.capitalize() + # Handle multi-word model types (e.g., "qwen3" -> "Qwen3") + for attr in dir(mod): + if attr.lower() == f"{model_type}mlp": + prefix = attr.replace("MLP", "") + break + + mlp_cls = getattr(mod, f"{prefix}MLP", None) + norm_cls = getattr(mod, f"{prefix}RMSNorm", None) + rotary_cls = getattr(mod, f"{prefix}RotaryEmbedding", None) + rotate_half_fn = getattr(mod, "rotate_half", None) + + # Fallback to Llama if any component is missing + if not all([mlp_cls, norm_cls, rotary_cls, rotate_half_fn]): + from transformers.models.llama.modeling_llama import ( + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, + ) + from transformers.models.llama.modeling_llama import rotate_half as _rotate_half + + mlp_cls = mlp_cls or LlamaMLP + norm_cls = norm_cls or LlamaRMSNorm + rotary_cls = rotary_cls or LlamaRotaryEmbedding + rotate_half_fn = rotate_half_fn or _rotate_half + + return mlp_cls, norm_cls, rotary_cls, rotate_half_fn + + +# Default to Llama components; overridden per-model during convert() +_MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half = _resolve_model_components("llama") + +__all__ = ["HFDFlashModel"] + + +def build_target_layer_ids(num_target_layers, num_draft_layers): + """Select layers uniformly from the target model for feature extraction.""" + if num_draft_layers == 1: + return [num_target_layers // 2] + start = 1 + end = num_target_layers - 3 + span = end - start + return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)] + + +def apply_rotary_pos_emb(q, k, cos, sin): + """Apply RoPE. 
Q uses last q_len positions, K uses all positions.""" + cos = cos.unsqueeze(1) # [B, 1, seq, dim] + sin = sin.unsqueeze(1) + q_len = q.size(2) + q_embed = (q * cos[:, :, -q_len:, :]) + (_rotate_half(q) * sin[:, :, -q_len:, :]) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +class DFlashAttention(nn.Module): + """Attention with KV injection, using HF's attention dispatch for exact SpecForge parity.""" + + def __init__(self, config, layer_idx): + """Initialize DFlash attention with KV injection projections and QK-norm.""" + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_kv_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = getattr(config, "attention_dropout", 0.0) + self.is_causal = False + + attn_bias = getattr(config, "attention_bias", False) + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=attn_bias) + self.k_proj = nn.Linear( + config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=attn_bias) + + self.q_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) + + # Resolve HF attention function matching SpecForge's dispatch + self._attn_fn = None + self.sliding_window = None + + def _get_attn_fn(self): + """Lazily resolve the HF attention function.""" + if self._attn_fn is not None: + return self._attn_fn + try: + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + impl = getattr(self.config, "_attn_implementation", "eager") + 
if impl and impl != "eager" and impl in ALL_ATTENTION_FUNCTIONS: + self._attn_fn = ALL_ATTENTION_FUNCTIONS[impl] + else: + self._attn_fn = self._eager_attention + except (ImportError, AttributeError): + self._attn_fn = self._eager_attention + return self._attn_fn + + def _eager_attention(self, module, q, k, v, attention_mask, **kwargs): + """Eager attention matching HF's eager_attention_forward.""" + scaling = kwargs.get("scaling", self.scaling) + n_rep = self.num_key_value_groups + if n_rep > 1: + k = k.repeat_interleave(n_rep, dim=1) + v = v.repeat_interleave(n_rep, dim=1) + attn_weights = torch.matmul(q, k.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + q.dtype + ) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, None + + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + """Forward with KV injection: Q from noise, K/V from context+noise.""" + bsz, q_len, _ = hidden_states.shape + ctx_len = target_hidden.shape[1] + + # Q from noise only, with QK-norm + q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim) + q = self.q_norm(q).transpose(1, 2) + + # K from context + noise, with QK-norm + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) + k = self.k_norm(k).transpose(1, 2) + + # V from context + noise (no norm) + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + v = ( + torch.cat([v_ctx, v_noise], dim=1) + .view(bsz, ctx_len + q_len, -1, self.head_dim) + .transpose(1, 2) + ) + + # RoPE + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + # Use HF's attention dispatch (handles GQA internally) + attn_fn = 
self._get_attn_fn() + attn_output, _ = attn_fn( + self, + q, + k, + v, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) + return self.o_proj(attn_output) + + +class DFlashDecoderLayer(nn.Module): + """Draft decoder layer with KV injection.""" + + def __init__(self, config, layer_idx): + """Initialize decoder layer with attention, MLP, and layer norms.""" + super().__init__() + self.self_attn = DFlashAttention(config, layer_idx) + self.mlp = _MLP_CLS(config) + self.input_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + """Forward pass with residual connections.""" + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states, target_hidden, position_embeddings, attention_mask + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class DFlashModule(nn.Module): + """DFlash draft module matching SpecForge DFlashDraftModel.""" + + def __init__(self, config): + """Initialize DFlash module with feature fusion, decoder layers, and rotary embeddings.""" + super().__init__() + self.config = config + self.block_size = config.block_size + + # Feature fusion + num_fused_layers = len(config.target_layer_ids) + self.fc = nn.Linear(num_fused_layers * config.hidden_size, config.hidden_size, bias=False) + self.hidden_norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + + # Decoder layers + self.layers = nn.ModuleList( + [DFlashDecoderLayer(config, layer_idx) for 
layer_idx in range(config.num_hidden_layers)] + ) + self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = _ROTARY_CLS(config=config) + + # Initialize weights matching HF PreTrainedModel (normal_ with initializer_range) + # SpecForge's DFlashDraftModel uses Qwen3PreTrainedModel.post_init() which does this. + self._init_weights(config) + + def _init_weights(self, config): + """Initialize weights matching HF PreTrainedModel._init_weights.""" + std = getattr(config, "initializer_range", 0.02) + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None): + """Forward matching SpecForge DFlashDraftModel.forward.""" + hidden_states = noise_embedding + target_hidden = self.hidden_norm(self.fc(target_hidden)) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states, target_hidden, position_embeddings, attention_mask) + + return self.norm(hidden_states) + + +def create_dflash_attention_mask( + seq_len, block_size, device, dtype +): # Legacy: used for inference only + """Create [L, 2L] attention mask matching SpecForge. + + Context (cols 0..L-1): Block B sees blocks 0..B-1 (strictly previous). + Noise (cols L..2L-1): causal within same block only. + """ + indices = torch.arange(seq_len, device=device) + block_ids = indices // block_size + + q_block_ids = block_ids.unsqueeze(1) # [L, 1] + k_block_ids = block_ids.unsqueeze(0) # [1, L] + + ctx_mask = k_block_ids < q_block_ids + same_block = q_block_ids == k_block_ids + causal = indices.unsqueeze(0) >= indices.unsqueeze(1) # matching SpecForge: j >= i + noise_mask = same_block & causal + + full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) + + # Create in f32 then cast, matching SpecForge. 
This ensures masked + # positions get -inf in bf16 (f32 min overflows to -inf when cast), + # not the largest finite negative bf16 value. + full_mask = torch.zeros(seq_len, 2 * seq_len, device=device, dtype=torch.float32) + full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min) + full_mask = full_mask.to(dtype=dtype) + + return full_mask.unsqueeze(0).unsqueeze(0) # [1, 1, L, 2L] + + +def create_dflash_loss_mask(seq_len, block_size, device): # Legacy: used for inference only + """Create loss mask: exclude Block 0 and block starts.""" + positions = torch.arange(seq_len, device=device) + block_ids = positions // block_size + is_block_0 = block_ids == 0 + is_block_start = (positions % block_size) == 0 + return (~is_block_0 & ~is_block_start).float() + + +@DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) +class HFDFlashModel(DFlashModel): + """DFlash Model matching SpecForge OnlineDFlashModel.""" + + @property + def _base_model(self): + return self.get_submodule(self.base_model_path) + + @property + def _base_model_embeddings(self): + return self.get_submodule(self.base_model_embeddings_path) + + @property + def _base_model_lm_head(self): + return self.get_submodule(self.base_model_lm_head_path) + + @property + def _base_llm_config(self): + return ( + getattr(self.config, "text_config", None) + or getattr(self.config, "llm_config", None) + or self.config + ) + + @staticmethod + def _auto_detect_mask_token_id(base_config): + """Auto-detect an appropriate mask token ID for DFlash. 
+ + Different model families use different strategies: + - Qwen3/3.5: built-in [MASK] token in vocabulary + - Llama3: reserved special tokens (128002 = reserved_special_token_0) + - Others: try tokenizer.mask_token_id, then fall back to pad/eos + """ + model_type = getattr(base_config, "model_type", "") + vocab_size = getattr(base_config, "vocab_size", 0) + + # Qwen3/3.5: known mask token positions + if "qwen3" in model_type.lower() or "qwen" in model_type.lower(): + # Qwen3 vocab has dedicated mask tokens + # Qwen3.5-4B: 248070, Qwen3-8B: similar range + # Heuristic: eos_token_id + some offset, or check known values + eos = getattr(base_config, "eos_token_id", None) + if isinstance(eos, list): + eos = eos[0] + if eos and vocab_size > 200000: + # Large Qwen vocab — mask token is typically near end of special tokens + # Known: Qwen3.5 eos=248044, mask=248070 (offset ~26) + # Try common offsets + for offset in [26, 25, 24]: + candidate = eos + offset + if candidate < vocab_size: + return candidate + # Fallback for smaller Qwen models + if vocab_size > 150000: + return vocab_size - 250 # heuristic for Qwen special token region + + # Llama3: use reserved_special_token_0 (128002) + if "llama" in model_type.lower(): + if vocab_size >= 128256: # Llama3 vocab size + return 128002 # <|reserved_special_token_0|> + + # Generic: try pad_token_id, then eos + pad_id = getattr(base_config, "pad_token_id", None) + eos_id = getattr(base_config, "eos_token_id", None) + if isinstance(eos_id, list): + eos_id = eos_id[0] + + # Prefer pad over eos (pad is less likely to interfere) + if pad_id is not None and pad_id != eos_id: + return pad_id + + # Last resort + return eos_id or 0 + + def _find_base_model_parts(self): + """Locate base model submodules (backbone, embeddings, lm_head) by probing known paths.""" + for name, paths in { + "base_model_path": ["model.language_model", "model", "backbone"], + "base_model_embeddings_path": [ + "model.embed_tokens", + "backbone.embeddings", + 
"model.language_model.embed_tokens", + ], + "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], + }.items(): + for path in paths: + try: + submodule = self.get_submodule(path) + assert isinstance(submodule, torch.nn.Module) + setattr(self, name, path) + break + except Exception: + continue + else: + raise ValueError(f"Part {name} not found in model") + + def modify(self, config): + """Initialize DFlash draft module.""" + super().modify(config) + + base_config = self._base_llm_config + self.dflash_config = PretrainedConfig.from_dict(config.dflash_architecture_config) + + # Inherit settings from base model, but only those NOT already in the user config. + # hidden_size and vocab_size MUST match. Others (heads, intermediate_size) can differ. + # This allows the draft model to have a different architecture than the base model. + self.dflash_config.hidden_size = base_config.hidden_size + self.dflash_config.vocab_size = base_config.vocab_size + + # These use base model defaults if not specified in dflash_architecture_config + for attr, default_from_base in [ + ("max_position_embeddings", True), + ("intermediate_size", True), + ("num_attention_heads", True), + ("num_key_value_heads", True), + ("hidden_act", True), + ("rope_theta", True), + ("rope_scaling", True), + ("rope_type", False), + ("position_embedding_type", False), + ("rope_interleaved", False), + ("rms_norm_eps", True), + ("attention_bias", False), + ("tie_word_embeddings", False), + ]: + if not hasattr(self.dflash_config, attr) or getattr(self.dflash_config, attr) is None: + if default_from_base and hasattr(base_config, attr): + setattr(self.dflash_config, attr, getattr(base_config, attr)) + + # Ensure required attrs have defaults + if not hasattr(self.dflash_config, "mlp_bias") or self.dflash_config.mlp_bias is None: + self.dflash_config.mlp_bias = False + + self.dflash_config.head_dim = getattr( + self.dflash_config, + "head_dim", + self.dflash_config.hidden_size // 
self.dflash_config.num_attention_heads, + ) + self.dflash_config.block_size = self.dflash_block_size + # Default to sdpa, matching SpecForge's DFlashDraftModel(Qwen3PreTrainedModel) + # which resolves to sdpa via post_init() + if self.dflash_config._attn_implementation is None: + self.dflash_config._attn_implementation = "sdpa" + + # Target layer IDs + num_target_layers = base_config.num_hidden_layers + num_draft_layers = self.dflash_config.num_hidden_layers + self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers) + self.dflash_config.target_layer_ids = self.target_layer_ids + + # mask_token_id resolution order: + # 1. Explicit in dflash_architecture_config (user override) + # 2. Auto-detect from model vocabulary: + # - Qwen3/3.5: built-in [MASK] token + # - Llama3: reserved_special_token_0 (128002) + # - Others: tokenizer.mask_token_id + # 3. Fallback to pad_token_id or eos_token_id (suboptimal) + mask_id = config.dflash_architecture_config.get("mask_token_id", None) + if mask_id is None: + mask_id = self._auto_detect_mask_token_id(base_config) + self.mask_token_id = mask_id[0] if isinstance(mask_id, list) else mask_id + print(f"DFlash mask_token_id: {self.mask_token_id}") + + # Freeze base model + if self.dflash_freeze_base_model: + for param in self.parameters(): + param.requires_grad = False + + self._find_base_model_parts() + + # Resolve model-specific components (MLP, RMSNorm, RotaryEmbedding) + # from the base model's architecture for weight compatibility + global _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half + _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half = _resolve_model_components( + getattr(base_config, "model_type", "llama") + ) + self.dflash_module = DFlashModule(self.dflash_config) + self.dflash_module.to(self._base_model.dtype).to( + next(self._base_model.layers[-1].parameters()).device + ) + + self.is_quantized = False + self._num_anchors = self.dflash_num_anchors + + # Store bound reference to the original model 
class's forward. + # DynamicModule changes type(self) but the original class is in _original_cls. + # Find the original HF model class (e.g., Qwen3_5ForConditionalGeneration) + # by walking MRO and skipping DFlash/DynamicModule classes + skip_names = { + "HFDFlashModel", + "DFlashModel", + "DynamicModule", + "DFlashPreTrainedModel", + "DFlashDraftModel", + } + original_cls = None + for cls in type(self).__mro__: + if ( + hasattr(cls, "forward") + and cls.__name__ not in skip_names + and cls is not type(self) + and issubclass(cls, PreTrainedModel) + and cls is not PreTrainedModel + ): + original_cls = cls + break + if original_cls is None: + # Last resort: use the class two levels up (skip DFlash wrapper + DynamicModule) + original_cls = type(self).__mro__[2] + self._original_forward_cls = original_cls + print(f"DFlash: using {original_cls.__name__}.forward as base forward") + + def get_exporter(self): + """Get the exporter for the DFlash draft model.""" + from modelopt.torch.export.plugins.hf_spec_export import DFlashExporter + + return DFlashExporter(self) + + def _base_forward(self, **kwargs): + """Call the original model's forward, bypassing DFlash wrapper.""" + return self._original_forward_cls.forward(self, **kwargs) + + def _sample_anchor_positions(self, seq_len, loss_mask, device): + """Randomly sample anchor positions per sample, matching SpecForge PR #473. + + Returns (anchor_positions [B, N], block_keep_mask [B, N]). 
+ """ + bs = self.dflash_block_size + bsz = loss_mask.shape[0] + max_anchor = max(seq_len - bs, 0) + num_anchors = getattr(self, "_num_anchors", 512) + + valid = loss_mask[:, : max_anchor + 1] > 0.5 + valid_counts = valid.sum(dim=1) + max_n = min(num_anchors, int(valid_counts.max().item()) - 1) + + if max_n <= 0: + # No valid anchors — return empty + anchors = torch.zeros(bsz, 1, dtype=torch.long, device=device) + keep = torch.zeros(bsz, 1, dtype=torch.bool, device=device) + return anchors, keep + + indices = torch.arange(max_anchor + 1, device=device).unsqueeze(0).expand(bsz, -1) + masked_indices = torch.where(valid, indices, torch.tensor(seq_len + 1, device=device)) + + random_vals = torch.rand(bsz, max_anchor + 1, device=device) + random_vals = torch.where(valid, random_vals, torch.tensor(2.0, device=device)) + + _, sorted_idx = random_vals.sort(dim=1) + gathered = torch.gather(masked_indices, 1, sorted_idx) + anchors = gathered[:, :max_n].sort(dim=1).values + + keep = torch.arange(max_n, device=device).unsqueeze(0) < valid_counts.unsqueeze(1).clamp( + max=max_n + ) + anchors = torch.where(keep, anchors, torch.tensor(0, dtype=torch.long, device=device)) + return anchors, keep + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + **kwargs, + ): + """Training forward matching SpecForge latest (post-PR #473). 
+ + Key changes from original PR #415: + - Random anchor sampling instead of uniform block division + - Bidirectional intra-block attention (no causal constraint) + - Context sees strictly before anchor position + - Label alignment: position k predicts token at anchor+k + - Optional loss decay weighting + """ + if not self.training: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + bsz, seq_len = input_ids.shape + block_size = self.dflash_block_size + device = input_ids.device + + # 1. Run base model → hidden states + with torch.no_grad(): + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + offset = 1 + selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] + + # 2. Build loss mask from labels or attention_mask + if labels is not None: + loss_mask = (labels != -100).float() + elif attention_mask is not None: + loss_mask = attention_mask.float() + else: + loss_mask = torch.ones(bsz, seq_len, device=device) + + # 3. Random anchor sampling (SpecForge PR #463/#473) + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + n_blocks = anchor_positions.shape[1] + + if n_blocks == 0 or not block_keep_mask.any(): + # Zero loss that still flows through dflash_module for DDP gradient sync + dummy = self.dflash_module.fc.weight.sum() * 0.0 + return ModelOutput(loss=dummy, logits=base_outputs.logits, train_acc=[[0.0]]) + + # 4. 
Create noise embeddings: anchor token at block start, mask_token elsewhere + noise_ids = torch.full( + (bsz, n_blocks * block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_starts = torch.arange(n_blocks, device=device) * block_size + block_starts_exp = block_starts.unsqueeze(0).expand(bsz, -1) + valid_anchors = anchor_positions.clamp(0, seq_len - 1) + anchor_tokens = torch.gather(input_ids, 1, valid_anchors) + batch_idx = torch.arange(bsz, device=device).unsqueeze(1).expand(bsz, n_blocks) + noise_ids[batch_idx, block_starts_exp] = torch.where( + block_keep_mask, + anchor_tokens, + torch.tensor(self.mask_token_id, dtype=torch.long, device=device), + ) + noise_embedding = self._base_model_embeddings(noise_ids) + + # 5. Position IDs: context [0..S-1], draft blocks [anchor+0..anchor+B-1] + ctx_pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) + offsets = torch.arange(block_size, device=device).view(1, 1, -1) + draft_pos = (anchor_positions.unsqueeze(-1) + offsets).view(bsz, -1) + full_pos = torch.cat([ctx_pos, draft_pos], dim=1) + + # 6. 
Attention mask: SDPA bool mask [B, 1, Q_LEN, KV_LEN] + q_len = n_blocks * block_size + kv_len = seq_len + q_len + + q_indices = torch.arange(q_len, device=device).view(1, 1, -1, 1) + kv_indices = torch.arange(kv_len, device=device).view(1, 1, 1, -1) + q_block_ids = q_indices // block_size + + anchor_exp = anchor_positions.view(bsz, 1, n_blocks, 1).repeat_interleave(block_size, dim=2) + + # Context: kv < S and kv < anchor + mask_ctx = (kv_indices < seq_len) & (kv_indices < anchor_exp) + # Draft: kv >= S and same block + is_draft = kv_indices >= seq_len + kv_block_ids = (kv_indices - seq_len) // block_size + mask_draft = is_draft & (q_block_ids == kv_block_ids) + # Valid block + valid_block = block_keep_mask.view(bsz, 1, n_blocks, 1).repeat_interleave(block_size, dim=2) + + final_mask = (mask_ctx | mask_draft) & valid_block # [B, 1, Q, KV] + + # Convert bool mask to float additive mask for SDPA + dtype = target_hidden.dtype + attn_mask = torch.zeros(bsz, 1, q_len, kv_len, device=device, dtype=torch.float32) + attn_mask.masked_fill_(~final_mask, torch.finfo(torch.float32).min) + attn_mask = attn_mask.to(dtype=dtype) + + # 7. Draft forward + hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=full_pos, + attention_mask=attn_mask, + ) + + # 8. 
Loss: same-position prediction (position k predicts token at anchor+k) + logits = self._base_model_lm_head(hidden) + + label_offsets = torch.arange(0, block_size, device=device).view(1, 1, -1) + label_indices = anchor_positions.unsqueeze(-1) + label_offsets + valid_label = label_indices < seq_len + safe_label_indices = label_indices.clamp(max=seq_len - 1) + + target_ids = torch.gather( + input_ids.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + + # Weight mask: valid block * in bounds * exclude anchor (pos 0) * loss_mask + weight_mask = block_keep_mask.unsqueeze(-1).expand(-1, -1, block_size).float() + weight_mask = weight_mask * valid_label.float() + pos_in_block = torch.arange(block_size, device=device).view(1, 1, -1) + weight_mask = weight_mask * (pos_in_block > 0).float() + + orig_loss_mask = torch.gather( + loss_mask.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + weight_mask = weight_mask * orig_loss_mask + + binary_eval_mask = weight_mask.view(-1) + + # Optional loss decay + if self.dflash_loss_decay_factor > 0: + k = torch.arange(block_size, device=device).view(1, 1, -1) + decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor) + weight_mask = weight_mask * decay + + # Cross entropy or logit distillation + flat_logits = logits.view(-1, logits.size(-1)) + flat_targets = target_ids.view(-1) + flat_weights = weight_mask.view(-1) + + valid_count = flat_weights.sum() + 1e-6 + + if valid_count > 1.0: + if self.dflash_self_logit_distillation: + # Teacher logits at position p predict token p+1 (autoregressive). + # Draft position k predicts token at anchor+k (same position). + # So teacher logits for token anchor+k are at position anchor+k-1. 
+ base_logits = base_outputs.logits # [B, seq, vocab] + teacher_indices = (safe_label_indices - 1).clamp(min=0) + teacher_logits = torch.gather( + base_logits.unsqueeze(1).expand(-1, n_blocks, -1, -1), + 2, + teacher_indices.unsqueeze(-1).expand(-1, -1, -1, base_logits.size(-1)), + ) # [B, N, block_size, vocab] + flat_teacher = teacher_logits.reshape(-1, base_logits.size(-1)).detach() + target_soft = torch.softmax(flat_teacher, dim=-1) + draft_logsoft = torch.log_softmax(flat_logits, dim=-1) + kd_loss = -(target_soft * draft_logsoft).sum(dim=-1) + loss = (kd_loss * flat_weights).sum() / valid_count + else: + loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none") + loss = (loss_per_token * flat_weights).sum() / valid_count + + with torch.no_grad(): + preds = flat_logits.argmax(dim=-1) + correct = (preds == flat_targets) & (binary_eval_mask > 0.5) + accuracy = correct.sum().float() / (binary_eval_mask.sum() + 1e-6) + accuracy = accuracy.item() + else: + loss = flat_logits.sum() * 0.0 + accuracy = 0.0 + + return ModelOutput( + loss=loss, + logits=base_outputs.logits, + train_acc=[[accuracy]], + ) + + @torch.no_grad() + def pseudo_speculative_generate(self, input_ids, steps=1): + """Generate draft tokens using one DFlash block. + + DFlash generates block_size-1 draft tokens in a single forward pass. + The `steps` parameter is used as the number of tokens to return + (capped at block_size-1). + + Returns: + base_token: Next token from base model [B, 1]. + draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None. 
+ """ + # Call the base model's inner model directly (avoids DynamicModule dispatch) + model_output = self._base_model( + input_ids=input_ids, + output_hidden_states=True, + ) + # Compute logits via lm_head + base_logits = self._base_model_lm_head(model_output.last_hidden_state) + # Build output with hidden_states + base_outputs = ModelOutput( + logits=base_logits, + hidden_states=model_output.hidden_states, + ) + base_logits = base_outputs.logits + base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) + + if steps < 1: + return base_token, None + + # Extract target hidden states (raw, before FC projection) + hid_offset = 1 + if not hasattr(self, "_psg_debug"): + self._psg_debug = True + sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + th_dbg = torch.cat(sel, dim=-1) + n_layers = len(base_outputs.hidden_states) + th_norm = th_dbg.norm().item() + print( + f"[psg] hidden layers: {n_layers}, target_hidden: {th_dbg.shape}, norm: {th_norm:.2f}" + ) + print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}") + seq_len = input_ids.shape[1] + blk = self.dflash_block_size + print(f"[psg] pos: ctx=[0..{seq_len - 1}], blk=[{seq_len}..{seq_len + blk - 1}]") + selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) + + block_size = self.dflash_block_size + bsz = input_ids.shape[0] + seq_len = input_ids.shape[1] + device = input_ids.device + + # Block: first token is base_token (anchor), rest are mask + block_ids = torch.full( + (bsz, block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_ids[:, 0] = base_token.squeeze(-1) + noise_embedding = self._base_model_embeddings(block_ids) + + # Position IDs: training uses [0..L-1, 0..L-1] where noise positions + # mirror context positions. At inference, block predicts tokens at + # seq_len..seq_len+B-1, so noise positions continue from ctx_len. 
+ ctx_len = target_hidden.shape[1] + ctx_positions = torch.arange(ctx_len, device=device) + block_positions = torch.arange(ctx_len, ctx_len + block_size, device=device) + pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) + + # No attention mask at inference — matching SpecForge's spec_generate + # which uses KV cache with no mask. All positions attend freely to + # context and each other within the block. + + # Draft forward + draft_hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=None, + ) + + # Logits on positions 1..block_size-1 (skip anchor at position 0) + draft_logits = self._base_model_lm_head(draft_hidden[:, 1:, :]) + draft_tokens = draft_logits.argmax(dim=-1) # [B, block_size-1] + + # Return up to `steps` tokens + num_tokens = min(steps, block_size - 1) + return base_token, draft_tokens[:, :num_tokens] diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index e147ebf2c2..b9a5367cd9 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -153,6 +153,9 @@ def __init__( if self.tokenizer.chat_template is None: raise ValueError("No valid chat template!") + if self.answer_only_loss: + self._ensure_generation_tags() + def _post_process_tokenizer(self): if self.tokenizer.pad_token_id is None: print_rank_0("The tokenizer has no pad_token_id, using eos_token_id instead.") @@ -171,6 +174,166 @@ def _post_process_chat_template(self): REMOVE_THINK_CHAT_TEMPLATE, "" ) + # Simplified chat templates with {% generation %} tags for answer_only_loss. + # + # PURPOSE: + # HuggingFace's return_assistant_tokens_mask requires {% generation %} / + # {% endgeneration %} tags in the Jinja chat template to identify which tokens + # belong to assistant responses. Many models (Qwen3, Llama3) ship without these + # tags. 
These simplified templates add them so that answer_only_loss works
+    # reliably without regex fallbacks.
+    #
+    # HOW IT WORKS:
+    # When answer_only_loss=True, _ensure_generation_tags() detects the model's
+    # template style (ChatML, Llama3) and replaces the tokenizer's chat_template
+    # with one of these simplified versions. The {% generation %} tags tell HF
+    # exactly which tokens are assistant content for loss masking.
+    #
+    # WHAT IS PRESERVED:
+    # - System / user / assistant role formatting (exact token match)
+    # - Multi-turn conversation structure
+    # - <think> block injection on last assistant turn (Qwen3-style, chatml_think)
+    # - Content is output as-is — training data with <think> blocks is handled correctly
+    #
+    # WHAT IS DROPPED (vs original model templates):
+    # - Tool call formatting (tool_call XML tags, function signatures)
+    # - Multi-step tool response handling
+    # - reasoning_content vs content splitting logic
+    # - enable_thinking parameter support
+    # - VLM/multimodal content handling
+    #
+    # LIMITATIONS:
+    # - Training data with tool_call messages will not be formatted correctly.
+    #   Use the original template with manually added {% generation %} tags for
+    #   tool-use training data.
+    # - The chatml_think variant adds <think>\n\n</think>\n\n only to the last
+    #   assistant turn (matching Qwen3 behavior). Non-last turns without <think>
+    #   in their content will differ from the original template which also
+    #   conditionally adds think wrappers based on multi-step reasoning context.
+    # - Only ChatML (<|im_start|>/<|im_end|>) and Llama3
+    #   (<|start_header_id|>/<|eot_id|>) styles are supported. Other template
+    #   styles fall back to regex-based assistant span detection.
+    #
+    # TO USE A CUSTOM TEMPLATE INSTEAD:
+    # Pass chat_template= to LanguageDataCollator with your own template that
+    # includes {% generation %}...{% endgeneration %} around assistant content.
+    _GENERATION_TEMPLATES = {
+        # Basic ChatML without <think> injection (Phi, older Qwen, generic ChatML)
+        "chatml": (
+            "{% for message in messages %}"
+            "{% if message['role'] == 'system' %}"
+            "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"
+            "{% elif message['role'] == 'user' %}"
+            "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ '<|im_start|>assistant\n' }}"
+            "{% generation %}"
+            "{{ message['content'] }}"
+            "{% endgeneration %}"
+            "{{ '<|im_end|>\n' }}"
+            "{% endif %}"
+            "{% endfor %}"
+            "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        ),
+        # ChatML with <think> wrapper on last assistant turn (Qwen3-style)
+        "chatml_think": (
+            "{% for message in messages %}"
+            "{% if message['role'] == 'system' %}"
+            "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"
+            "{% elif message['role'] == 'user' %}"
+            "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ '<|im_start|>assistant\n' }}"
+            "{% generation %}"
+            "{% if loop.last and not message['content'].startswith('<think>') %}"
+            "{{ '<think>\n\n</think>\n\n' }}"
+            "{% endif %}"
+            "{{ message['content'] }}"
+            "{% endgeneration %}"
+            "{{ '<|im_end|>\n' }}"
+            "{% endif %}"
+            "{% endfor %}"
+            "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+        ),
+        "llama3": (
+            "{% for message in messages %}"
+            "{% if message['role'] == 'system' %}"
+            "{{ '<|start_header_id|>system<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}"
+            "{% elif message['role'] == 'user' %}"
+            "{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% generation %}"
+            "{{ message['content'] }}{% endgeneration %}{{ '<|eot_id|>' }}"
+            "{% endif %}"
+            "{% endfor %}"
+            "{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
+        ),
+    }
+
+    def _ensure_generation_tags(self):
+        """Ensure chat template has {% generation %} tags for answer_only_loss.
+
+        If the template already has generation tags, no action taken.
+        Otherwise, detect the template style and replace with a simplified
+        version that includes proper generation tags.
+        """
+        template = self.tokenizer.chat_template
+        if template is None:
+            return
+
+        if "{% generation %}" in template or "{%generation%}" in template:
+            return
+
+        # Detect template style and replace with generation-tagged version
+        old_template = template
+        if "<|im_start|>" in template and "<|im_end|>" in template:
+            # Check if original template injects <think> (Qwen3-style)
+            style = "chatml_think" if "<think>" in template else "chatml"
+        elif "<|start_header_id|>" in template and "<|eot_id|>" in template:
+            style = "llama3"
+        else:
+            print_rank_0(
+                "=== WARNING === Cannot auto-inject {% generation %} tags for this chat "
+                "template. answer_only_loss will not work correctly. Provide a template "
+                "with {% generation %} tags via the chat_template parameter."
+            )
+            return
+
+        new_template = self._GENERATION_TEMPLATES[style]
+        self.tokenizer.chat_template = new_template
+
+        # Verify
+        try:
+            test_msgs = [
+                [
+                    {"role": "user", "content": "hi"},
+                    {"role": "assistant", "content": "hello"},
+                ]
+            ]
+            result = self.tokenizer.apply_chat_template(
+                test_msgs,
+                return_dict=True,
+                return_assistant_tokens_mask=True,
+                padding=True,
+                return_tensors="pt",
+            )
+            mask = result.get("assistant_masks", None)
+            if mask is not None and mask.any():
+                print_rank_0(
+                    f"Replaced chat template with {style} generation-tagged version "
+                    f"for answer_only_loss."
+                )
+                return
+        except Exception:
+            pass
+
+        # Revert on failure
+        self.tokenizer.chat_template = old_template
+        print_rank_0(
+            f"=== WARNING === Failed to apply {style} generation template. "
+            "answer_only_loss will not work correctly."
+ ) + def _process_chat_sample(self, examples: list): tokenized_examples = self.tokenizer.apply_chat_template( examples, @@ -186,6 +349,20 @@ def _process_chat_sample(self, examples: list): input_ids = tokenized_examples["input_ids"] labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID) labels[..., :-1] = input_ids[..., 1:] + if self.answer_only_loss: + if "assistant_masks" in tokenized_examples: + assistant_mask = tokenized_examples["assistant_masks"] + if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any(): + labels[assistant_mask == 0] = IGNORE_TOKEN_ID + else: + # All assistant content truncated or no assistant in batch — mask all + labels[:] = IGNORE_TOKEN_ID + else: + raise ValueError( + "answer_only_loss requires {% generation %} tags in the chat " + "template but assistant_masks was not returned by the tokenizer. " + "Ensure _ensure_generation_tags() ran successfully." + ) tokenized_examples["labels"] = labels return tokenized_examples @@ -211,15 +388,34 @@ def __call__(self, examples): batch.append(text) else: messages = example.get("messages", None) - if messages is None: - conversations = example.get("conversations", None) - if conversations is None: - raise ValueError( - "The sample must in either OpenAI messages format or ShareGPT conversations format." + conversations = example.get("conversations", None) + # Prefer whichever has an assistant turn for training + if messages and any(m.get("role") == "assistant" for m in messages): + batch.append(messages) + elif conversations: + converted = _sharegpt_to_openai_messages(conversations) + if not any(m.get("role") == "assistant" for m in converted): + print_rank_0( + "=== WARNING === Skipping sample with no assistant turn in conversations." 
                    )
-                else:
-                    messages = _sharegpt_to_openai_messages(conversations)
-            batch.append(messages)
+                    continue
+                batch.append(converted)
+            elif messages:
+                if not any(m.get("role") == "assistant" for m in messages):
+                    print_rank_0(
+                        "=== WARNING === Skipping sample with no assistant turn in messages."
+                    )
+                    continue
+                batch.append(messages)
+            else:
+                raise ValueError(
+                    "The sample must be in either OpenAI messages format or ShareGPT conversations format."
+                )
+
+        if not batch:
+            # All samples skipped — create a dummy batch with all-masked labels
+            # so the training step produces zero loss without crashing DDP
+            batch = [[{"role": "user", "content": ""}, {"role": "assistant", "content": ""}]]  # type: ignore[list-item]

         return self._process_chat_sample(batch)
diff --git a/tests/gpu/torch/speculative/plugins/test_hf_dflash.py b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py
new file mode 100644
index 0000000000..230b67c45d
--- /dev/null
+++ b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py
@@ -0,0 +1,196 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""GPU tests for DFlash speculative decoding plugin.
+
+These tests require a CUDA GPU. CPU-only tests are in tests/unit/.
+""" + +from copy import deepcopy + +import pytest +import torch +from _test_utils.torch.transformers_models import get_tiny_llama + +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be multiple of BLOCK_SIZE + + +def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a DFlash config for testing.""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "mask_token_id": 0, + } + return config + + +@pytest.fixture +def dflash_model(): + """Create a tiny DFlash model on GPU.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model = model.cuda() + return model + + +class TestDFlashModuleGPU: + """Test DFlash draft module forward pass on GPU.""" + + def test_dflash_module_forward_shape(self, dflash_model): + """Test that draft module produces correct output shape.""" + model = dflash_model + bsz = 2 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + dtype = next(model.dflash_module.parameters()).dtype + target_hidden = torch.randn( + bsz, SEQ_LEN, num_layers * hidden_size, device="cuda", dtype=dtype + ) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda", dtype=dtype) + pos_ids = ( + torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]) + .unsqueeze(0) + .expand(bsz, -1) + .cuda() + ) + + output = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=None, + ) + assert output.shape == (bsz, SEQ_LEN, hidden_size) + + def test_dflash_module_deterministic(self, dflash_model): + """Test that draft module produces identical outputs for same input.""" + model = 
dflash_model + model.eval() + bsz = 1 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + dtype = next(model.dflash_module.parameters()).dtype + target_hidden = torch.randn( + bsz, SEQ_LEN, num_layers * hidden_size, device="cuda", dtype=dtype + ) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda", dtype=dtype) + pos_ids = torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]).unsqueeze(0).cuda() + + with torch.no_grad(): + out1 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + out2 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + assert torch.allclose(out1, out2) + + +class TestDFlashTrainingForwardGPU: + """Test DFlash training forward pass end-to-end on GPU.""" + + @pytest.fixture + def model(self): + """Create a tiny DFlash model in training mode on GPU.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model = model.cuda() + model.train() + return model + + def test_training_forward_returns_loss(self, model): + """Test that training forward returns a differentiable loss.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_returns_accuracy(self, model): + """Test that training forward returns train_acc.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "train_acc") + + def 
test_training_forward_with_labels(self, model): + """Test that labels are used for response-only loss masking.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + # Labels with -100 for first half (masked), real labels for second half + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long, device="cuda") + labels[:, SEQ_LEN // 2 :] = input_ids[:, SEQ_LEN // 2 :] + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_all_masked_labels(self, model): + """Test that all-masked labels produce zero loss without crashing.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert output.loss.item() == 0.0 + + def test_training_backward(self, model): + """Test that gradients flow to dflash_module.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + output.loss.backward() + + has_grad = False + for name, param in model.dflash_module.named_parameters(): + if param.grad is not None and param.grad.abs().sum() > 0: + has_grad = True + break + assert has_grad, "DFlash module should receive gradients" + + def test_eval_forward_uses_base_model(self, model): + """In eval mode, forward should use base model (not DFlash training).""" + model.eval() + bsz = 1 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + 
+ with torch.no_grad(): + output = model(input_ids=input_ids) + assert output.logits.shape == (bsz, SEQ_LEN, model.config.vocab_size) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py new file mode 100644 index 0000000000..50d3c9768b --- /dev/null +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for DFlash speculative decoding plugin. + +GPU-dependent tests (training forward, module forward) are in tests/gpu/. 
+""" + +import os +from copy import deepcopy + +import torch +from _test_utils.torch.transformers_models import ( + get_tiny_llama, + tf_modelopt_state_and_output_tester, +) +from transformers import AutoModelForCausalLM + +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG +from modelopt.torch.speculative.plugins.hf_dflash import ( + DFlashModule, + HFDFlashModel, + create_dflash_attention_mask, + create_dflash_loss_mask, +) + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be multiple of BLOCK_SIZE + + +def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a DFlash config for testing.""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "mask_token_id": 0, # use token 0 as mask for tiny model + } + return config + + +class TestDFlashConvert: + """Test DFlash model conversion.""" + + def test_convert_creates_dflash_model(self): + """Test that convert produces an HFDFlashModel.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert isinstance(model, HFDFlashModel) + + def test_convert_creates_dflash_module(self): + """Test that convert attaches a DFlashModule.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "dflash_module") + assert isinstance(model.dflash_module, DFlashModule) + + def test_convert_freezes_base_model(self): + """Test that base model parameters are frozen after convert.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + for name, param in model.named_parameters(): + if "dflash_module" not in name: + 
assert not param.requires_grad, f"Base param {name} should be frozen" + + def test_convert_dflash_module_trainable(self): + """Test that DFlash module parameters are trainable after convert.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + dflash_params = [(n, p) for n, p in model.named_parameters() if "dflash_module" in n] + assert len(dflash_params) > 0 + for name, param in dflash_params: + assert param.requires_grad, f"DFlash param {name} should be trainable" + + def test_convert_sets_target_layer_ids(self): + """Test that target layer IDs are set correctly.""" + model = get_tiny_llama(num_hidden_layers=8) + config = _get_dflash_config(num_layers=3) + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "target_layer_ids") + assert len(model.target_layer_ids) == 3 + for lid in model.target_layer_ids: + assert 0 <= lid < 8 + + def test_convert_sets_mask_token_id(self): + """Test that mask_token_id is set from config.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "mask_token_id") + assert model.mask_token_id == 0 + + +class TestDFlashSaveRestore: + """Test DFlash model save and restore.""" + + def test_save_and_restore(self, tmp_path): + """Test round-trip save/load preserves modelopt state and outputs.""" + mto.enable_huggingface_checkpointing() + model_ref = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model_ref, [("dflash", config)]) + + model_ref.save_pretrained(tmp_path / "modelopt_model") + assert os.path.exists(tmp_path / "modelopt_model/modelopt_state.pth") + + model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model") + assert isinstance(model_test, HFDFlashModel) + tf_modelopt_state_and_output_tester(model_ref, model_test) + + +class TestDFlashAttentionMask: + """Test DFlash attention mask construction.""" 
+ + def test_mask_shape(self): + """Test mask has shape [1, 1, L, 2L].""" + mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) + assert mask.shape == (1, 1, SEQ_LEN, 2 * SEQ_LEN) + + def test_mask_context_strictly_previous_blocks(self): + """Context (left half): block B can only see blocks 0..B-1.""" + mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) + mask_2d = mask[0, 0] # [8, 16] + ctx_mask = mask_2d[:, :8] # context part + + # Block 0 (rows 0-3) should NOT see any context + assert (ctx_mask[:4, :] < 0).all() + + # Block 1 (rows 4-7) should see block 0 context only + assert (ctx_mask[4:8, :4] == 0).all() # can see block 0 + assert (ctx_mask[4:8, 4:8] < 0).all() # cannot see own block + + def test_mask_noise_causal_within_block(self): + """Noise (right half): reverse-causal within same block, matching SpecForge. + + SpecForge uses j >= i: position 0 (anchor) sees all positions in block, + position B-1 sees only itself. Cross-block noise is fully masked. 
+ """ + mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) + mask_2d = mask[0, 0] + noise_mask = mask_2d[:, 8:] # noise part + + # Block 0, position 0: can see all positions in block (0-3) + assert (noise_mask[0, :4] == 0).all() + + # Block 0, position 3: can only see position 3 + assert (noise_mask[3, :3] < 0).all() + assert noise_mask[3, 3] == 0 + + # Block 1 cannot see block 0 noise + assert (noise_mask[4:8, :4] < 0).all() + + def test_mask_values_are_zero_or_neg_inf(self): + """Test mask contains only 0 (attend) and -inf (mask).""" + mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) + unique_vals = mask.unique() + assert len(unique_vals) == 2 + assert 0.0 in unique_vals + assert unique_vals.min() == torch.finfo(torch.float32).min + + +class TestDFlashLossMask: + """Test DFlash loss mask construction.""" + + def test_loss_mask_shape(self): + """Test loss mask has shape [L].""" + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + assert mask.shape == (SEQ_LEN,) + + def test_loss_mask_excludes_block_zero(self): + """Test all positions in block 0 are masked out.""" + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + assert (mask[:BLOCK_SIZE] == 0).all() + + def test_loss_mask_excludes_block_starts(self): + """Test block start positions are masked.""" + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + for i in range(0, SEQ_LEN, BLOCK_SIZE): + assert mask[i] == 0, f"Block start position {i} should be masked" + + def test_loss_mask_includes_non_start_positions(self): + """Test non-start positions in non-zero blocks are included.""" + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + for b in range(1, SEQ_LEN // BLOCK_SIZE): + for offset in range(1, BLOCK_SIZE): + pos = b * BLOCK_SIZE + offset + assert mask[pos] == 1, f"Position {pos} should be in loss" + + def test_loss_mask_count(self): + """Test total active positions matches expected count.""" + mask = 
create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + num_blocks = SEQ_LEN // BLOCK_SIZE + expected = (num_blocks - 1) * (BLOCK_SIZE - 1) + assert mask.sum().item() == expected + + +class TestBuildTargetLayerIds: + """Test target layer selection.""" + + def test_single_draft_layer(self): + """Test single draft layer selects middle target layer.""" + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(32, 1) + assert len(ids) == 1 + assert ids[0] == 16 # middle layer + + def test_multiple_draft_layers(self): + """Test multiple draft layers are monotonically increasing and in bounds.""" + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(36, 5) + assert len(ids) == 5 + assert ids == sorted(ids) + assert all(1 <= lid <= 33 for lid in ids) + + def test_layer_ids_spread(self): + """Test layer IDs have no duplicates.""" + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(32, 5) + assert len(ids) == 5 + assert len(set(ids)) == 5 diff --git a/tools/launcher/common/dflash/ar_validate.sh b/tools/launcher/common/dflash/ar_validate.sh new file mode 100644 index 0000000000..b9df0b5c6f --- /dev/null +++ b/tools/launcher/common/dflash/ar_validate.sh @@ -0,0 +1,127 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# DFlash AR (Acceptance Rate) validation script. +# Loads a trained DFlash checkpoint and evaluates speculative decoding AR on MT-Bench. +# +# Required env vars: +# HF_MODEL_CKPT — path to the target HuggingFace model +# DFLASH_CKPT — path to the trained DFlash checkpoint +# DFLASH_BLOCK_SIZE — block size (default: 16) +# DFLASH_NUM_LAYERS — number of draft layers (default: 5) +# DFLASH_MASK_TOKEN_ID — mask token ID (default: auto-detect) +# NUM_SAMPLES — number of MT-Bench samples to evaluate (default: 20) + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh +trap 'error_handler $0 $LINENO' ERR + +pip install --upgrade "transformers>=4.57" 2>&1 | tail -3 + +DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16} +DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5} +NUM_SAMPLES=${NUM_SAMPLES:-20} + +# Build mask_token_id arg +if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then + MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID}," +else + MASK_ARG="" +fi + +echo "=== DFlash AR Validation ===" +echo "Target model: ${HF_MODEL_CKPT}" +echo "DFlash checkpoint: ${DFLASH_CKPT}" +echo "Block size: ${DFLASH_BLOCK_SIZE}" +echo "Draft layers: ${DFLASH_NUM_LAYERS}" +echo "Samples: ${NUM_SAMPLES}" + +CUDA_VISIBLE_DEVICES=0 python3 -c " +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from modelopt.torch.speculative.plugins.transformers import HFARValidation +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + +mto.enable_huggingface_checkpointing() + +model = AutoModelForCausalLM.from_pretrained( + '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True) + +config = { + 'dflash_block_size': ${DFLASH_BLOCK_SIZE}, + 'dflash_architecture_config': { + 
'num_hidden_layers': ${DFLASH_NUM_LAYERS}, + ${MASK_ARG} + }, + 'dflash_use_torch_compile': False, +} +mtsp.convert(model, [('dflash', config)]) + +# Load trained DFlash weights +import glob +from safetensors.torch import load_file +ckpt_files = sorted(glob.glob('${DFLASH_CKPT}/model*.safetensors')) +if ckpt_files: + state = {} + for f in ckpt_files: + state.update(load_file(f)) + # Try with dflash_module prefix first (ModelOpt format) + dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} + if dflash_keys: + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights (with prefix)') + else: + # No prefix — SpecForge format, load directly into dflash_module + result = model.dflash_module.load_state_dict(state, strict=False) + loaded = len(state) - len(result.unexpected_keys) + print(f'Loaded {loaded} DFlash weights (no prefix), missing={len(result.missing_keys)}') +else: + print('WARNING: No checkpoint files found, using random weights') + +model.eval() +validator = HFARValidation(model, tokenizer) + +ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] +num_samples = min(${NUM_SAMPLES}, len(ds)) + +ars = [] +for i in range(num_samples): + prompt = ds[i]['prompt'][0] + chat = [{'role': 'user', 'content': prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() + try: + _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3) + ars.append(ar) + print(f' AR={ar:.2f} | {prompt[:60]}') + except Exception as e: + print(f' ERROR | {prompt[:60]}... 
| {e}') + +if ars: + avg_ar = sum(ars) / len(ars) + print(f'\n==== DFlash AR Results ====') + print(f'Samples: {len(ars)}') + print(f'Average AR: {avg_ar:.4f}') + print(f'Min AR: {min(ars):.4f}') + print(f'Max AR: {max(ars):.4f}') +else: + print('No AR results collected') +" diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh new file mode 100644 index 0000000000..62fe51fcc2 --- /dev/null +++ b/tools/launcher/common/dflash/online_training.sh @@ -0,0 +1,221 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DFlash online training + AR validation script for the ModelOpt Launcher. +# Trains a DFlash draft model alongside the frozen target model, +# then evaluates acceptance rate on MT-Bench. +# +# Required env vars: +# HF_MODEL_CKPT — path to the target HuggingFace model +# +# Optional env vars: +# NUM_AR_SAMPLES — number of MT-Bench samples for AR validation (default: 20, 0 to skip) +# +# All other args are passed through to launch_train.sh. 
+
+SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
+source ${SCRIPT_DIR}/../service_utils.sh
+
+pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt
+pip install "huggingface-hub>=1.2.1"
+export PATH=$PATH:/workspace/.local/bin
+
+###################################################################################################
+
+trap 'error_handler $0 $LINENO' ERR
+
+# Auto-detect head node IP for multi-node training
+if [ -z "$HEAD_NODE_IP" ]; then
+    # Method 1: scontrol (works outside container)
+    HEAD_NODE_IP=$(scontrol show hostnames "$SLURM_JOB_NODELIST" 2>/dev/null | head -1)
+    # Method 2: SLURM_LAUNCH_NODE_IPADDR (some Slurm versions)
+    HEAD_NODE_IP=${HEAD_NODE_IP:-$SLURM_LAUNCH_NODE_IPADDR}
+    # Method 3: Parse SLURM_NODELIST and resolve via Python
+    if [ -z "$HEAD_NODE_IP" ] && [ -n "$SLURM_JOB_NODELIST" ]; then
+        HEAD_NODE_IP=$(python3 -c "
+import socket, re, os
+nl = os.environ.get('SLURM_JOB_NODELIST', '')
+# Extract first hostname: 'node[001-002]' -> 'node001', 'node001,node002' -> 'node001'
+m = re.match(r'([a-zA-Z0-9-]+)(?:\[(\d+))?', nl)
+if m:
+    host = m.group(1) + (m.group(2) or '')
+    try:
+        print(socket.gethostbyname(host))
+    except:
+        print(host)
+" 2>/dev/null)
+    fi
+    # Method 4: Use rank 0's hostname
+    if [ -z "$HEAD_NODE_IP" ] && [ "${SLURM_PROCID:-0}" = "0" ]; then
+        HEAD_NODE_IP=$(hostname -I 2>/dev/null | awk '{print $1}')
+    fi
+    export HEAD_NODE_IP
+    echo "Auto-detected HEAD_NODE_IP: ${HEAD_NODE_IP}"
+fi
+
+# Parse DFlash-specific args from the command line for AR validation
+DFLASH_BLOCK_SIZE=16
+DFLASH_NUM_LAYERS=5
+DFLASH_MASK_TOKEN_ID=""
+OUTPUT_DIR=""
+for arg in "$@"; do
+    case "$arg" in
+        --dflash_block_size) next_is_block_size=1 ;;
+        --dflash_num_layers) next_is_num_layers=1 ;;
+        --dflash_mask_token_id) next_is_mask_id=1 ;;
+        --output_dir) next_is_output=1 ;;
+        *)
+            if [ "$next_is_block_size" = "1" ]; then DFLASH_BLOCK_SIZE="$arg"; next_is_block_size=0; fi
+            if [ "$next_is_num_layers" = "1" ]; then 
DFLASH_NUM_LAYERS="$arg"; next_is_num_layers=0; fi
+            if [ "$next_is_mask_id" = "1" ]; then DFLASH_MASK_TOKEN_ID="$arg"; next_is_mask_id=0; fi
+            if [ "$next_is_output" = "1" ]; then OUTPUT_DIR="$arg"; next_is_output=0; fi
+            ;;
+    esac
+done
+
+# Step 1: Training
+bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \
+    --model ${HF_MODEL_CKPT} \
+    --mode dflash \
+    "$@"
+
+# Skip export + AR validation when disabled or output dir is unknown
+NUM_AR_SAMPLES=${NUM_AR_SAMPLES:-20}
+if [ "${NUM_AR_SAMPLES}" = "0" ]; then
+    echo "Skipping AR validation (NUM_AR_SAMPLES=0)"
+    exit 0
+fi
+
+if [ -z "$OUTPUT_DIR" ]; then
+    echo "WARNING: --output_dir not found in args, skipping export and AR validation"
+    exit 0
+fi
+
+# Step 2: Export checkpoint to z-lab HF format
+EXPORT_DIR=${OUTPUT_DIR}/export
+echo ""
+echo "=== Exporting DFlash checkpoint ==="
+echo "Source: ${OUTPUT_DIR}"
+echo "Export: ${EXPORT_DIR}"
+
+python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \
+    --model_path ${OUTPUT_DIR} \
+    --export_path ${EXPORT_DIR} \
+    || echo "WARNING: Export failed, continuing with AR validation"
+
+echo ""
+echo "Export contents:"
+ls -la ${EXPORT_DIR}/ 2>/dev/null || echo "No export dir"
+
+# Step 3: AR Validation
+# Build mask_token_id config
+if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then
+    MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID},"
+else
+    MASK_ARG=""
+fi
+
+echo ""
+echo "=== DFlash AR Validation ==="
+echo "Target model: ${HF_MODEL_CKPT}"
+# Prefer exported checkpoint (no prefix), fall back to training output (with prefix)
+if [ -f "${EXPORT_DIR}/model.safetensors" ]; then
+    AR_CKPT=${EXPORT_DIR}
+    echo "Using exported checkpoint: ${AR_CKPT}"
+else
+    AR_CKPT=${OUTPUT_DIR}
+    echo "Using training checkpoint: ${AR_CKPT}"
+fi
+echo "Block size: ${DFLASH_BLOCK_SIZE}"
+echo "Draft layers: ${DFLASH_NUM_LAYERS}"
+echo "Samples: ${NUM_AR_SAMPLES}"
+
+CUDA_VISIBLE_DEVICES=0 python3 -c "
+import torch
+from datasets import load_dataset
+from transformers import 
AutoModelForCausalLM, AutoTokenizer +from modelopt.torch.speculative.plugins.transformers import HFARValidation +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + +mto.enable_huggingface_checkpointing() + +model = AutoModelForCausalLM.from_pretrained( + '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True) + +config = { + 'dflash_block_size': ${DFLASH_BLOCK_SIZE}, + 'dflash_architecture_config': { + 'num_hidden_layers': ${DFLASH_NUM_LAYERS}, + ${MASK_ARG} + }, + 'dflash_use_torch_compile': False, +} +mtsp.convert(model, [('dflash', config)]) + +# Load trained DFlash weights +import glob +from safetensors.torch import load_file +ckpt_files = sorted(glob.glob('${AR_CKPT}/model*.safetensors')) +if ckpt_files: + state = {} + for f in ckpt_files: + state.update(load_file(f)) + dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} + if dflash_keys: + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights (with prefix)') + else: + result = model.dflash_module.load_state_dict(state, strict=False) + loaded = len(state) - len(result.unexpected_keys) + print(f'Loaded {loaded} DFlash weights (no prefix), missing={len(result.missing_keys)}') +else: + print('WARNING: No checkpoint files found, using random weights') + +model.eval() +validator = HFARValidation(model, tokenizer) + +ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] +num_samples = min(${NUM_AR_SAMPLES}, len(ds)) + +ars = [] +for i in range(num_samples): + prompt = ds[i]['prompt'][0] + chat = [{'role': 'user', 'content': prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() + try: + _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3) + ars.append(ar) + print(f' 
AR={ar:.2f} | {prompt[:60]}') + except Exception as e: + print(f' ERROR | {prompt[:60]}... | {e}') + +if ars: + avg_ar = sum(ars) / len(ars) + print(f'\n==== DFlash AR Results ====') + print(f'Samples: {len(ars)}') + print(f'Average AR: {avg_ar:.4f}') + print(f'Min AR: {min(ars):.4f}') + print(f'Max AR: {max(ars):.4f}') +else: + print('No AR results collected') +" + +################################################################################################### diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml new file mode 100644 index 0000000000..5f094f3a14 --- /dev/null +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml @@ -0,0 +1,63 @@ +# DFlash online speculative decoding training for Qwen3-8B. +# +# Trains a DFlash draft model (block diffusion) using the frozen target model +# to extract multi-layer hidden states on the fly, then evaluates AR on MT-Bench. +# +# 2-step pipeline: +# task_0: Online DFlash training + AR validation +# task_1: Benchmark speculative decoding speedup via VLLM +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes + +job_name: Qwen3-8B_DFlash_online +pipeline: + # Step 1: Online DFlash training + AR validation + task_0: + script: common/dflash/online_training.sh + args: + - --data /hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/sample-100K.jsonl + - --output_dir /scratchspace/dflash + - --num_epochs 1 + - --lr 6e-4 + - --training_seq_len 512 + - --save_steps 500000 + - --log_steps 100 + - --disable_tqdm True + - --ar_validate_steps 0 + - --dflash_block_size 16 + - --dflash_num_layers 5 + - --dflash_mask_token_id 151643 + environment: + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + - 
NUM_AR_SAMPLES: "20" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + + # Step 2: Benchmark speculative decoding (VLLM backend) + task_1: + script: common/specdec_bench/quick_check.sh + args: + - --draft_model_dir /scratchspace/dflash + - --draft_length 3 + - --output_length 4096 + - --engine VLLM + - --tp_size 4 + - --ep_size 1 + - --speculative_algorithm EAGLE3 + - --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl + - --concurrency 1 + environment: + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest diff --git a/uv.lock b/uv.lock index d890e361cb..3cd6db2083 100644 --- a/uv.lock +++ b/uv.lock @@ -20,9 +20,6 @@ resolution-markers = [ "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", ] -[manifest] -overrides = [{ name = "torch", marker = "sys_platform == 'never'" }] - [[package]] name = "accelerate" version = "1.13.0" @@ -35,7 +32,7 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" } wheels = [ @@ -480,6 +477,21 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/54/27/01d9078a77b9e31b79b9716e66ca4db74f4744c5232bcb3e8769395c4280/cppimport-22.8.2.tar.gz", hash = "sha256:bbb4957102db41bc99ad72c233bce92f9d1fd91be352fc07878c4361033a401f", size = 26635, upload-time = "2022-08-02T16:50:36.872Z" } +[[package]] +name = "cuda-bindings" +version = "12.9.4" +source = { registry = 
"https://pypi.org/simple" } +dependencies = [ + { name = "cuda-pathfinder", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" }, + { url = "https://files.pythonhosted.org/packages/45/e7/b47792cc2d01c7e1d37c32402182524774dadd2d26339bd224e0e913832e/cuda_bindings-12.9.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c912a3d9e6b6651853eed8eed96d6800d69c08e94052c292fec3f282c5a817c9", size = 12210593, upload-time = "2025-10-21T14:51:36.574Z" }, + { url = "https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, + { url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" }, + { url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" }, +] + [[package]] name = "cuda-pathfinder" version = "1.4.3" @@ -554,7 +566,7 @@ 
dependencies = [ { name = "psutil", marker = "sys_platform != 'win32'" }, { name = "py-cpuinfo", marker = "sys_platform != 'win32'" }, { name = "pydantic", marker = "sys_platform != 'win32'" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch", marker = "sys_platform != 'win32'" }, { name = "tqdm", marker = "sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/19/11/46b9eb3806ca7a5e9bdddb7e873855a2d59a9f87f0675ae8231678d98434/deepspeed-0.18.8.tar.gz", hash = "sha256:e4e051a144b0c74270c46e4970139f9a86a61ff26959c5e463000c4a93b99304", size = 1647226, upload-time = "2026-03-13T18:49:48.568Z" } @@ -1311,7 +1323,9 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'win32'", "(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version < '3.11' and sys_platform == 'darwin')", + "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'win32'", "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } @@ -1324,12 +1338,18 @@ name = "networkx" version = "3.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", 
"(python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version >= '3.13' and sys_platform == 'darwin')", "(python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version == '3.12.*' and sys_platform == 'darwin')", "(python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version == '3.11.*' and sys_platform == 'darwin')", "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ @@ -1526,6 +1546,108 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/57/a7/b35835e278c18b85206834b3aa3abe68e77a98769c59233d1f6300284781/numpy-2.4.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:4b42639cdde6d24e732ff823a3fa5b701d8acad89c4142bc1d0bd6dc85200ba5", size = 12504685, upload-time = "2026-03-09T07:58:50.525Z" }, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and 
sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.9.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, +] + +[[package]] +name 
= "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, +] + [[package]] name = 
"nvidia-ml-py" version = "13.595.45" @@ -1554,7 +1676,7 @@ dependencies = [ { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scipy", version = "1.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "setuptools" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "tqdm" }, ] @@ -1756,11 +1878,43 @@ requires-dist = [ { name = "tox", marker = "extra == 'dev-test'", specifier = ">4.18" }, { name = "tox-current-env", marker = "extra == 'dev-test'", specifier = ">=0.0.12" }, { name = "tqdm" }, - { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.53,<5.0" }, + { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.56,<5.0" }, { name = "wonderwords", marker = "extra == 'hf'" }, ] provides-extras = ["onnx", "hf", "dev-lint", "dev-docs", "dev-test", "all", "dev"] +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, +] + +[[package]] +name = "nvidia-nvshmem-cu12" +version = 
"3.4.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, +] + [[package]] name = "omegaconf" version = "2.3.0" @@ -2157,7 +2311,7 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "tqdm" }, { name = "transformers" }, ] @@ -3403,7 +3557,7 @@ dependencies = [ { name = "huggingface-hub" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d7/2c/593109822fe735e637382aca6640c1102c19797f7791f1fd1dab2d6c3cb1/timm-1.0.25.tar.gz", hash = "sha256:47f59fc2754725735cc81bb83bcbfce5bec4ebd5d4bb9e69da57daa92fcfa768", size = 2414743, upload-time = "2026-02-23T16:49:00.137Z" } @@ -3491,15 +3645,63 @@ name = "torch" version = "2.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "cuda-bindings", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, { name = "networkx", 
version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 
'linux'" }, { name = "typing-extensions" }, ] +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/30/bfebdd8ec77db9a79775121789992d6b3b75ee5494971294d7b4b7c999bc/torch-2.10.0-2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2b980edd8d7c0a68c4e951ee1856334a43193f98730d97408fbd148c1a933313", size = 79411457, upload-time = "2026-02-10T21:44:59.189Z" }, + { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, + { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, + { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/16/ee/efbd56687be60ef9af0c9c0ebe106964c07400eade5b0af8902a1d8cd58c/torch-2.10.0-3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a1ff626b884f8c4e897c4c33782bdacdff842a165fee79817b1dd549fdda1321", size = 915510070, upload-time = "2026-03-11T14:16:39.386Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/7b562f1808d3f65414cd80a4f7d4bb00979d9355616c034c171249e1a303/torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac", size = 915518691, upload-time = "2026-03-11T14:15:43.147Z" }, + { url = 
"https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, + { url = "https://files.pythonhosted.org/packages/0c/1a/c61f36cfd446170ec27b3a4984f072fd06dab6b5d7ce27e11adb35d6c838/torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d", size = 145992962, upload-time = "2026-01-21T16:24:14.04Z" }, + { url = "https://files.pythonhosted.org/packages/b5/60/6662535354191e2d1555296045b63e4279e5a9dbad49acf55a5d38655a39/torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444", size = 915599237, upload-time = "2026-01-21T16:23:25.497Z" }, + { url = "https://files.pythonhosted.org/packages/40/b8/66bbe96f0d79be2b5c697b2e0b187ed792a15c6c4b8904613454651db848/torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb", size = 113720931, upload-time = "2026-01-21T16:24:23.743Z" }, + { url = 
"https://files.pythonhosted.org/packages/76/bb/d820f90e69cda6c8169b32a0c6a3ab7b17bf7990b8f2c680077c24a3c14c/torch-2.10.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:35e407430795c8d3edb07a1d711c41cc1f9eaddc8b2f1cc0a165a6767a8fb73d", size = 79411450, upload-time = "2026-01-21T16:25:30.692Z" }, + { url = "https://files.pythonhosted.org/packages/78/89/f5554b13ebd71e05c0b002f95148033e730d3f7067f67423026cc9c69410/torch-2.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3282d9febd1e4e476630a099692b44fdc214ee9bf8ee5377732d9d9dfe5712e4", size = 145992610, upload-time = "2026-01-21T16:25:26.327Z" }, + { url = "https://files.pythonhosted.org/packages/ae/30/a3a2120621bf9c17779b169fc17e3dc29b230c29d0f8222f499f5e159aa8/torch-2.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a2f9edd8dbc99f62bc4dfb78af7bf89499bca3d753423ac1b4e06592e467b763", size = 915607863, upload-time = "2026-01-21T16:25:06.696Z" }, + { url = "https://files.pythonhosted.org/packages/6f/3d/c87b33c5f260a2a8ad68da7147e105f05868c281c63d65ed85aa4da98c66/torch-2.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:29b7009dba4b7a1c960260fc8ac85022c784250af43af9fb0ebafc9883782ebd", size = 113723116, upload-time = "2026-01-21T16:25:21.916Z" }, + { url = "https://files.pythonhosted.org/packages/61/d8/15b9d9d3a6b0c01b883787bd056acbe5cc321090d4b216d3ea89a8fcfdf3/torch-2.10.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:b7bd80f3477b830dd166c707c5b0b82a898e7b16f59a7d9d42778dd058272e8b", size = 79423461, upload-time = "2026-01-21T16:24:50.266Z" }, + { url = "https://files.pythonhosted.org/packages/cc/af/758e242e9102e9988969b5e621d41f36b8f258bb4a099109b7a4b4b50ea4/torch-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5fd4117d89ffd47e3dcc71e71a22efac24828ad781c7e46aaaf56bf7f2796acf", size = 145996088, upload-time = "2026-01-21T16:24:44.171Z" }, + { url = 
"https://files.pythonhosted.org/packages/23/8e/3c74db5e53bff7ed9e34c8123e6a8bfef718b2450c35eefab85bb4a7e270/torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:787124e7db3b379d4f1ed54dd12ae7c741c16a4d29b49c0226a89bea50923ffb", size = 915711952, upload-time = "2026-01-21T16:23:53.503Z" }, + { url = "https://files.pythonhosted.org/packages/6e/01/624c4324ca01f66ae4c7cd1b74eb16fb52596dce66dbe51eff95ef9e7a4c/torch-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c66c61f44c5f903046cc696d088e21062644cbe541c7f1c4eaae88b2ad23547", size = 113757972, upload-time = "2026-01-21T16:24:39.516Z" }, + { url = "https://files.pythonhosted.org/packages/c9/5c/dee910b87c4d5c0fcb41b50839ae04df87c1cfc663cf1b5fca7ea565eeaa/torch-2.10.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6d3707a61863d1c4d6ebba7be4ca320f42b869ee657e9b2c21c736bf17000294", size = 79498198, upload-time = "2026-01-21T16:24:34.704Z" }, + { url = "https://files.pythonhosted.org/packages/c9/6f/f2e91e34e3fcba2e3fc8d8f74e7d6c22e74e480bbd1db7bc8900fdf3e95c/torch-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5c4d217b14741e40776dd7074d9006fd28b8a97ef5654db959d8635b2fe5f29b", size = 146004247, upload-time = "2026-01-21T16:24:29.335Z" }, + { url = "https://files.pythonhosted.org/packages/98/fb/5160261aeb5e1ee12ee95fe599d0541f7c976c3701d607d8fc29e623229f/torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738", size = 915716445, upload-time = "2026-01-21T16:22:45.353Z" }, + { url = "https://files.pythonhosted.org/packages/6a/16/502fb1b41e6d868e8deb5b0e3ae926bbb36dab8ceb0d1b769b266ad7b0c3/torch-2.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2ee399c644dc92ef7bc0d4f7e74b5360c37cdbe7c5ba11318dda49ffac2bc57", size = 113757050, upload-time = "2026-01-21T16:24:19.204Z" }, + { url = 
"https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" }, + { url = "https://files.pythonhosted.org/packages/d8/14/21fbce63bc452381ba5f74a2c0a959fdf5ad5803ccc0c654e752e0dbe91a/torch-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:aae1b29cd68e50a9397f5ee897b9c24742e9e306f88a807a27d617f07adb3bd8", size = 146005472, upload-time = "2026-01-21T16:22:29.022Z" }, + { url = "https://files.pythonhosted.org/packages/54/fd/b207d1c525cb570ef47f3e9f836b154685011fce11a2f444ba8a4084d042/torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f", size = 915612644, upload-time = "2026-01-21T16:21:47.019Z" }, + { url = "https://files.pythonhosted.org/packages/36/53/0197f868c75f1050b199fe58f9bf3bf3aecac9b4e85cc9c964383d745403/torch-2.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff43db38af76fda183156153983c9a096fc4c78d0cd1e07b14a2314c7f01c2c8", size = 113997015, upload-time = "2026-01-21T16:23:00.767Z" }, + { url = "https://files.pythonhosted.org/packages/0e/13/e76b4d9c160e89fff48bf16b449ea324bda84745d2ab30294c37c2434c0d/torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f", size = 79498248, upload-time = "2026-01-21T16:23:09.315Z" }, +] [[package]] name = "torch-geometric" @@ -3529,7 +3731,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "torch", marker = "sys_platform == 'never'" }, + 
{ name = "torch" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/6f/36/574c0c46e818533b78b3c09505211162918188325ab4165ef11a3f295755/torchprofile-0.0.4.tar.gz", hash = "sha256:96b6da17d752a06b02977e078aea95614893b31d4117dd5dcd081f30ce65611b", size = 4557, upload-time = "2021-06-22T04:58:03.592Z" } @@ -3545,7 +3747,7 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pillow" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/50/ae/cbf727421eb73f1cf907fbe5788326a08f111b3f6b6ddca15426b53fec9a/torchvision-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a95c47abb817d4e90ea1a8e57bd0d728e3e6b533b3495ae77d84d883c4d11f56", size = 1874919, upload-time = "2026-01-21T16:27:47.617Z" }, @@ -3638,6 +3840,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/b8/e484ef633af3887baeeb4b6ad12743363af7cce68ae51e938e00aaa0529d/transformers-4.57.6-py3-none-any.whl", hash = "sha256:4c9e9de11333ddfe5114bc872c9f370509198acf0b87a832a0ab9458e2bd0550", size = 11993498, upload-time = "2026-01-16T10:38:31.289Z" }, ] +[[package]] +name = "triton" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201, upload-time = "2026-01-20T16:00:29.272Z" }, + { url = 
"https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640, upload-time = "2026-01-20T16:00:35.869Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = "2026-01-20T16:00:43.041Z" }, + { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0"