From 3f89ea978476ddbb5a912a06b8f615f5394a1c47 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 27 Mar 2026 16:26:18 -0700 Subject: [PATCH 01/72] add: DFlash block diffusion speculative decoding Implement DFlash (Block Diffusion for Flash Speculative Decoding) as a new mode in ModelOpt's speculative decoding framework. Key architecture: - Feature Fusion: extract hidden states from uniformly sampled target model layers, project via FC layer - KV Injection: fused target features injected as K/V entries in every draft decoder layer's attention (not just first layer input) - Parallel Drafting: all tokens in a block predicted simultaneously using learnable mask embeddings and bidirectional within-block attention Files: - dflash/ module: DFlashModel, DFlashConfig, conversion, default config - plugins/hf_dflash.py: HFDFlashModel with DFlashAttention (KV injection), DFlashModule (feature fusion + decoder), training forward pass with random anchor sampling and exponential position decay loss - main.py: --mode dflash support in training script Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/main.py | 35 +- modelopt/torch/speculative/config.py | 48 ++ modelopt/torch/speculative/dflash/__init__.py | 20 + .../torch/speculative/dflash/conversion.py | 58 ++ .../speculative/dflash/default_config.py | 43 ++ .../torch/speculative/dflash/dflash_model.py | 34 + modelopt/torch/speculative/mode.py | 31 +- .../torch/speculative/plugins/__init__.py | 3 + .../torch/speculative/plugins/hf_dflash.py | 622 ++++++++++++++++++ 9 files changed, 890 insertions(+), 4 deletions(-) create mode 100644 modelopt/torch/speculative/dflash/__init__.py create mode 100644 modelopt/torch/speculative/dflash/conversion.py create mode 100644 modelopt/torch/speculative/dflash/default_config.py create mode 100644 modelopt/torch/speculative/dflash/dflash_model.py create mode 100644 modelopt/torch/speculative/plugins/hf_dflash.py diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 3369d399c2..d08b648c36 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -104,7 +104,7 @@ class TrainingArguments(transformers.TrainingArguments): ) dataloader_drop_last: bool = field(default=True) bf16: bool = field(default=True) - mode: Literal["eagle3", "medusa"] = "eagle3" + mode: Literal["eagle3", "medusa", "dflash"] = "eagle3" estimate_ar: bool = field( default=False, metadata={"help": "Whether to estimate AR during training for logging."} ) @@ -144,6 +144,21 @@ class EagleArguments: ) +@dataclass +class DFlashArguments: + dflash_block_size: int = field( + default=16, metadata={"help": "Block size for DFlash parallel prediction."} + ) + dflash_num_layers: int = field( + default=5, metadata={"help": "Number of decoder layers in the DFlash draft module."} + ) + dflash_config: str = field(default=None, metadata={"help": "Path to dflash_config.json"}) + dflash_disable_torch_compile: bool = field( + default=False, + metadata={"help": "Disable torch.compile on DFlash forward/loss methods."}, + ) + + def train(): parser = transformers.HfArgumentParser( ( @@ -152,9 +167,10 @@ def train(): TrainingArguments, MedusaArguments, EagleArguments, + DFlashArguments, ) ) - model_args, data_args, training_args, medusa_args, eagle_args = ( + model_args, data_args, training_args, medusa_args, eagle_args, dflash_args = ( parser.parse_args_into_dataclasses() ) if not data_args.data_path and not data_args.offline_data_path: @@ -236,11 +252,24 @@ def train(): ) model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache) print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") + elif training_args.mode == "dflash": + custom_config = ( + json.load(open(dflash_args.dflash_config)) if dflash_args.dflash_config else {} + ) + custom_config.setdefault("num_hidden_layers", dflash_args.dflash_num_layers) + + config = { + "dflash_block_size": dflash_args.dflash_block_size, + "dflash_use_torch_compile": not dflash_args.dflash_disable_torch_compile, + "dflash_architecture_config": custom_config, + } + + mtsp.convert(model, [("dflash", config)]) else: raise Exception(f"{training_args.mode} is not supported!") print_rank_0("Loading dataset...") - if training_args.mode == "eagle3": + if training_args.mode in ("eagle3", "dflash"): data_module = make_eagle_supervised_data_module( tokenizer, data_args, train_len=training_args.training_seq_len ) diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 69491c6599..59aa98db4b 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -46,6 +46,54 @@ } +def _get_dflash_default_config(): + from .dflash.default_config import default_dflash_config + + return default_dflash_config + + +DFLASH_DEFAULT_CFG = { + "algorithm": "dflash", + "config": { + "dflash_architecture_config": {}, # merged with default at convert time + }, +} + + +class DFlashConfig(ModeloptBaseConfig): + """DFlash config for block-wise parallel speculative decoding.""" + + dflash_block_size: int = ModeloptField( + default=16, + description="Block size for parallel prediction. Draft predicts this many tokens per block.", + ) + + dflash_freeze_base_model: bool = ModeloptField( + default=True, description="Whether to freeze base model during DFlash module training." + ) + + dflash_self_logit_distillation: bool = ModeloptField( + default=True, description="Whether to use logit distillation from base model." + ) + + dflash_loss_decay_factor: float = ModeloptField( + default=0.9, description="Decay factor for per-block loss weighting." + ) + + dflash_report_acc: bool = ModeloptField( + default=True, description="Whether to report eval accuracy." + ) + + dflash_architecture_config: dict = ModeloptField( + default={}, description="Config for the DFlash draft module architecture." + ) + + dflash_use_torch_compile: bool = ModeloptField( + default=True, + description="Whether to use torch.compile on DFlash forward/loss methods.", + ) + + class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" diff --git a/modelopt/torch/speculative/dflash/__init__.py b/modelopt/torch/speculative/dflash/__init__.py new file mode 100644 index 0000000000..912b8d47a2 --- /dev/null +++ b/modelopt/torch/speculative/dflash/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash Optimization Method.""" + +from .conversion import * +from .default_config import * +from .dflash_model import * diff --git a/modelopt/torch/speculative/dflash/conversion.py b/modelopt/torch/speculative/dflash/conversion.py new file mode 100644 index 0000000000..943be90ca0 --- /dev/null +++ b/modelopt/torch/speculative/dflash/conversion.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash conversion/restore utilities.""" + +from torch import nn + +from modelopt.torch.opt.conversion import ModelLikeModule +from modelopt.torch.opt.dynamic import _DMRegistryCls +from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict + +from ..config import DFlashConfig + +DFlashDMRegistry = _DMRegistryCls(prefix="DFlash") # global instance for the registry + + +def convert_to_dflash_model(model: nn.Module, config: DFlashConfig) -> ConvertReturnType: + """Convert the model to a DFlash model as per `config`.""" + model = model.init_modellike() if isinstance(model, ModelLikeModule) else model + + original_cls = type(model) + if original_cls not in DFlashDMRegistry: + for cls in DFlashDMRegistry._registry: + if issubclass(original_cls, cls): + DFlashDMRegistry.register({original_cls: "base_model_class"})(DFlashDMRegistry[cls]) + break + + # merge custom config with default config (lazy import to avoid circular) + from .default_config import default_dflash_config + + custom_config = config.dflash_architecture_config + config.dflash_architecture_config = {**default_dflash_config, **custom_config} + + dflash_model = DFlashDMRegistry.convert(model) + dflash_model.modify(config) + + metadata = {} + return dflash_model, metadata + + +def restore_dflash_model( + model: nn.Module, config: DFlashConfig, metadata: MetadataDict +) -> nn.Module: + """Function for restoring a previously converted model to a DFlash model.""" + assert not metadata, "No metadata expected!" + return convert_to_dflash_model(model, config)[0] diff --git a/modelopt/torch/speculative/dflash/default_config.py b/modelopt/torch/speculative/dflash/default_config.py new file mode 100644 index 0000000000..b552d4e4ad --- /dev/null +++ b/modelopt/torch/speculative/dflash/default_config.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Default DFlash architecture config.""" + +default_dflash_config = { + "hidden_act": "silu", + "torch_dtype": "bfloat16", + "position_embedding_type": "rope", + "rope_scaling": { + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3", + }, + "rope_theta": 500000.0, + "num_hidden_layers": 5, + "intermediate_size": 14336, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "initializer_range": 0.01, + "rms_norm_eps": 1e-05, + "mlp_bias": False, + "attention_bias": False, + "attention_dropout": 0.0, + "use_input_layernorm_in_first_layer": True, + "use_last_layernorm": True, + "has_lm_head": False, + "head_dim": 128, +} diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py new file mode 100644 index 0000000000..0e81689a57 --- /dev/null +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash model to support block-wise parallel speculative decoding.""" + +from modelopt.torch.opt.dynamic import DynamicModule + + +class DFlashModel(DynamicModule): + """Base DFlash Model.""" + + def _setup(self): + self._register_temp_attribute("dflash_module", None) + + def modify(self, config): + """Base DFlash Model modify function. Child class should implement the details.""" + self.dflash_block_size = config.dflash_block_size + self.dflash_freeze_base_model = config.dflash_freeze_base_model + self.dflash_loss_decay_factor = config.dflash_loss_decay_factor + self.dflash_self_logit_distillation = config.dflash_self_logit_distillation + self.dflash_report_acc = config.dflash_report_acc + self.dflash_use_torch_compile = config.dflash_use_torch_compile diff --git a/modelopt/torch/speculative/mode.py b/modelopt/torch/speculative/mode.py index 866449e155..ae965354a9 100644 --- a/modelopt/torch/speculative/mode.py +++ b/modelopt/torch/speculative/mode.py @@ -23,7 +23,8 @@ _ModeRegistryCls, ) -from .config import EagleConfig, MedusaConfig +from .config import DFlashConfig, EagleConfig, MedusaConfig +from .dflash.conversion import convert_to_dflash_model, restore_dflash_model from .eagle.conversion import convert_to_eagle_model, restore_eagle_model from .medusa.conversion import convert_to_medusa_model, restore_medusa_model @@ -58,6 +59,34 @@ def restore(self) -> RestoreEntrypoint: return restore_medusa_model +@SpeculativeDecodingModeRegistry.register_mode +class DFlashModeDescriptor(ModeDescriptor): + """Class to describe the ``"dflash"`` mode. + + The properties of this mode can be inspected via the source code. + """ + + @property + def name(self) -> str: + """Returns the value (str representation) of the mode.""" + return "dflash" + + @property + def config_class(self) -> type[ModeloptBaseConfig]: + """Specifies the config class for the mode.""" + return DFlashConfig + + @property + def convert(self) -> ConvertEntrypoint: + """The mode's entrypoint for converting a model.""" + return convert_to_dflash_model + + @property + def restore(self) -> RestoreEntrypoint: + """The mode's entrypoint for restoring a model.""" + return restore_dflash_model + + @SpeculativeDecodingModeRegistry.register_mode class EagleModeDescriptor(ModeDescriptor): """Class to describe the ``"eagle"`` mode. diff --git a/modelopt/torch/speculative/plugins/__init__.py b/modelopt/torch/speculative/plugins/__init__.py index 5e3f4bff2f..d59aed37d5 100644 --- a/modelopt/torch/speculative/plugins/__init__.py +++ b/modelopt/torch/speculative/plugins/__init__.py @@ -31,3 +31,6 @@ with import_plugin("transformers"): from .transformers import * + +with import_plugin("hf_dflash"): + from .hf_dflash import * diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py new file mode 100644 index 0000000000..28a9c421eb --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -0,0 +1,622 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash speculative decoding plugin for HuggingFace models. + +DFlash (Block Diffusion for Flash Speculative Decoding) uses three key mechanisms: + +1. Feature Fusion: Extract hidden states from uniformly sampled target model layers, + concatenate and project via a lightweight FC layer. + +2. KV Injection: The fused features are injected as Key/Value entries into EVERY + draft model layer's attention. Unlike EAGLE-3 which only feeds features to the + first layer, DFlash ensures every layer has full target model context. + +3. Parallel Drafting: All tokens in a block are predicted in a single forward pass. + The draft model uses mask tokens for unknown positions and predicts them all + simultaneously via cross-entropy against target model logits. + +Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +""" + +import contextlib +import math +import random + +import torch +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel +from transformers.models.llama.modeling_llama import ( + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, + rotate_half, +) +from transformers.utils import ModelOutput + +from ..dflash.conversion import DFlashDMRegistry +from ..dflash.dflash_model import DFlashModel + +__all__ = ["HFDFlashModel"] + + +def build_target_layer_ids(num_target_layers, num_sample_layers): + """Select layers uniformly from the target model for feature extraction.""" + if num_sample_layers == 1: + return [num_target_layers // 2] + start = 1 + end = num_target_layers - 3 + span = end - start + return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)] + + +class DFlashAttention(nn.Module): + """Attention with KV injection from target model features. + + Key difference from standard attention: K and V are computed from BOTH + the target model's fused features (context) AND the draft tokens (noise). + Q is computed only from draft tokens. + + Attention pattern: [k_ctx | k_noise] where draft queries attend to + both context KV and draft KV with appropriate masking. + """ + + def __init__(self, config, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = self.hidden_size // self.num_heads + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states, + target_hidden, + position_embeddings, + attention_mask=None, + ): + bsz, q_len, _ = hidden_states.shape + ctx_len = target_hidden.shape[1] + + # Q from draft tokens only + q = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + # K, V from both context (target features) and noise (draft tokens) + k_ctx = ( + self.k_proj(target_hidden) + .view(bsz, ctx_len, self.num_kv_heads, self.head_dim) + .transpose(1, 2) + ) + v_ctx = ( + self.v_proj(target_hidden) + .view(bsz, ctx_len, self.num_kv_heads, self.head_dim) + .transpose(1, 2) + ) + k_noise = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_kv_heads, self.head_dim) + .transpose(1, 2) + ) + v_noise = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_kv_heads, self.head_dim) + .transpose(1, 2) + ) + + # Apply rotary embeddings + if position_embeddings is not None: + cos, sin = position_embeddings + # Rotary for noise Q and K + q = self._apply_rotary( + q, cos[:, ctx_len : ctx_len + q_len], sin[:, ctx_len : ctx_len + q_len] + ) + k_noise = self._apply_rotary( + k_noise, cos[:, ctx_len : ctx_len + q_len], sin[:, ctx_len : ctx_len + q_len] + ) + k_ctx = self._apply_rotary(k_ctx, cos[:, :ctx_len], sin[:, :ctx_len]) + + # Concatenate context and noise KV + k = torch.cat([k_ctx, k_noise], dim=2) # [B, num_kv_heads, ctx+q, head_dim] + v = torch.cat([v_ctx, v_noise], dim=2) + + # GQA: expand KV heads to match Q heads + if self.num_kv_heads != self.num_heads: + n_rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(n_rep, dim=1) + v = v.repeat_interleave(n_rep, dim=1) + + # Scaled dot product attention + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask, is_causal=False + ) + + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size) + return self.o_proj(attn_output) + + @staticmethod + def _apply_rotary(x, cos, sin): + """Apply rotary positional embeddings (HF Llama convention).""" + cos = cos.unsqueeze(1) # [B, 1, seq, dim] + sin = sin.unsqueeze(1) + return (x * cos) + (rotate_half(x) * sin) + + +class DFlashDecoderLayer(nn.Module): + """Draft decoder layer with KV injection.""" + + def __init__(self, config, layer_idx): + super().__init__() + self.self_attn = DFlashAttention(config, layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states, target_hidden, position_embeddings, attention_mask + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class DFlashModule(nn.Module): + """DFlash draft module with feature fusion + KV injection. + + Architecture: + - FC layer fuses multi-layer target hidden states → hidden_size + - N decoder layers, each with KV injection from fused target features + - Shares embeddings and lm_head with the target model + """ + + def __init__(self, config): + super().__init__() + self.config = config + + # Feature fusion: project concatenated multi-layer hidden states + num_fused_layers = len(config.target_layer_ids) + self.fc = nn.Linear(num_fused_layers * config.hidden_size, config.hidden_size, bias=False) + self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Learnable mask embedding for unknown positions in blocks + self.mask_embedding = nn.Parameter(torch.randn(config.hidden_size) * 0.02) + + # Draft decoder layers with KV injection + self.layers = nn.ModuleList( + [DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + + def fuse_target_features(self, target_hidden_states, target_layer_ids): + """Extract and fuse hidden states from sampled target layers.""" + selected = [target_hidden_states[lid + 1] for lid in target_layer_ids] + concatenated = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] + fused = self.hidden_norm(self.fc(concatenated)) # [B, seq, H] + return fused + + def forward(self, hidden_states, target_hidden, attention_mask=None): + """Forward pass with KV injection. + + Args: + hidden_states: Draft token embeddings [B, noise_len, H]. + target_hidden: Fused target features [B, ctx_len, H]. + attention_mask: Attention mask for [ctx + noise] positions. + """ + total_len = target_hidden.shape[1] + hidden_states.shape[1] + position_ids = torch.arange(total_len, device=hidden_states.device).unsqueeze(0) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states, target_hidden, position_embeddings, attention_mask) + + return self.norm(hidden_states) + + +def build_dflash_attention_mask(ctx_len, block_anchors, block_size, seq_len, device, dtype): + """Build DFlash attention mask for training. + + Each draft query can attend to: + 1. Context positions strictly before its block's anchor position + 2. All positions within the same block (bidirectional) + 3. Nothing from other blocks + + Args: + ctx_len: Number of context tokens (= seq_len of the input). + block_anchors: List of anchor positions (indices into the original sequence). + block_size: Number of tokens per block. + seq_len: Original sequence length. + device: Target device. + dtype: Target dtype. + + Returns: + Attention mask [1, 1, noise_len, ctx_len + noise_len]. + """ + num_blocks = len(block_anchors) + noise_len = num_blocks * block_size + + # Mask shape: [noise_len, ctx_len + noise_len] + # Q dimension = noise tokens, KV dimension = context + noise tokens + mask = torch.full((noise_len, ctx_len + noise_len), float("-inf"), device=device, dtype=dtype) + + for block_idx, anchor in enumerate(block_anchors): + block_start = block_idx * block_size + block_end = block_start + block_size + + # 1. Context: each block sees all context up to its anchor position + mask[block_start:block_end, : min(anchor, ctx_len)] = 0.0 + + # 2. Within-block: BIDIRECTIONAL attention (all positions see each other) + # This is key to DFlash's parallel drafting — all positions in a block + # have the same information and predict independently. + mask[block_start:block_end, ctx_len + block_start : ctx_len + block_end] = 0.0 + + return mask.unsqueeze(0).unsqueeze(0) # [1, 1, noise, ctx+noise] + + +@DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) +class HFDFlashModel(DFlashModel): + """DFlash Model for HuggingFace models with KV injection + parallel drafting.""" + + @property + def _base_model(self): + return self.get_submodule(self.base_model_path) + + @property + def _base_model_embeddings(self): + return self.get_submodule(self.base_model_embeddings_path) + + @property + def _base_model_lm_head(self): + return self.get_submodule(self.base_model_lm_head_path) + + @property + def _base_llm_config(self): + return ( + getattr(self.config, "text_config", None) + or getattr(self.config, "llm_config", None) + or self.config + ) + + def _find_base_model_parts(self): + for name, paths in { + "base_model_path": ["model.language_model", "model", "backbone"], + "base_model_embeddings_path": [ + "model.embed_tokens", + "backbone.embeddings", + "model.language_model.embed_tokens", + ], + "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], + }.items(): + for path in paths: + try: + submodule = self.get_submodule(path) + assert isinstance(submodule, torch.nn.Module) + setattr(self, name, path) + break + except Exception: + continue + else: + raise ValueError(f"Part {name} not found in model") + + def modify(self, config): + """Initialize DFlash with feature fusion + KV injection.""" + super().modify(config) + + base_config = self._base_llm_config + self.dflash_config = PretrainedConfig.from_dict(config.dflash_architecture_config) + self.dflash_config.hidden_size = base_config.hidden_size + self.dflash_config.vocab_size = base_config.vocab_size + self.dflash_config.max_position_embeddings = base_config.max_position_embeddings + self.dflash_config.intermediate_size = getattr( + self.dflash_config, "intermediate_size", base_config.intermediate_size + ) + # head_dim for rotary embeddings: match base model + actual_head_dim = base_config.hidden_size // base_config.num_attention_heads + self.dflash_config.head_dim = actual_head_dim + if self.dflash_config._attn_implementation is None: + self.dflash_config._attn_implementation = "sdpa" + + # Determine target layer IDs for feature extraction + num_target_layers = base_config.num_hidden_layers + num_sample_layers = self.dflash_config.num_hidden_layers # sample as many as draft layers + self.target_layer_ids = build_target_layer_ids(num_target_layers, num_sample_layers) + self.dflash_config.target_layer_ids = self.target_layer_ids + + # Freeze base model + if self.dflash_freeze_base_model: + for param in self.parameters(): + param.requires_grad = False + + self._find_base_model_parts() + + # Build DFlash module + self.dflash_module = DFlashModule(self.dflash_config) + self.dflash_module.to(self._base_model.dtype).to( + next(self._base_model.layers[-1].parameters()).device + ) + + # Register hooks to collect hidden states from target layers + self._target_hidden_states = [] + for layer_idx, layer in enumerate(self._base_model.layers): + if layer_idx in self.target_layer_ids: + layer.register_forward_hook(self._collect_hidden_hook) + + self._cached_masks = {} + self.is_quantized = False + + def _collect_hidden_hook(self, module, input, output): + hidden = ( + output.clone().detach() + if isinstance(output, torch.Tensor) + else output[0].clone().detach() + ) + self._target_hidden_states.append(hidden) + + def _run_base_model(self, input_ids, attention_mask, labels=None, **kwargs): + """Run base model, collect hidden states from target layers.""" + self._target_hidden_states = [] + with torch.no_grad() if self.dflash_freeze_base_model else contextlib.nullcontext(): + outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + **kwargs, + ) + logits = outputs.logits + + # Fuse collected hidden states + target_hidden = self.dflash_module.fuse_target_features( + outputs.hidden_states, self.target_layer_ids + ) + self._target_hidden_states = [] + + base_loss = None + if labels is not None and not self.dflash_freeze_base_model: + loss_fct = nn.CrossEntropyLoss() + base_loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1)) + + return target_hidden, logits, base_loss + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + cache_position=None, + **kwargs, + ): + """DFlash training forward pass. + + 1. Run base model → get hidden states from sampled layers + logits + 2. Fuse multi-layer hidden states via FC projection + 3. Sample random anchors from the sequence → form blocks + 4. Create noise input (mask tokens + anchor tokens at block starts) + 5. Run draft model with KV injection (fused features as K/V in every layer) + 6. Compute CE loss with exponential position decay + """ + if not self.training: + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + batch_size, seq_len = input_ids.shape + block_size = self.dflash_block_size + device = input_ids.device + + # 1. Run base model → fused target features + logits + target_hidden, base_logits, base_loss = self._run_base_model( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + **kwargs, + ) + + # 2. Sample anchor positions (start of each block) + # Anchors are random positions in the valid (non-padding) region + if attention_mask is not None: + actual_len = attention_mask.sum(dim=1).min().int().item() + else: + actual_len = seq_len + + # Number of blocks we can fit: leave room for block_size predictions after each anchor + max_anchor = actual_len - block_size + max_anchor = max(max_anchor, 1) + num_blocks = max(1, max_anchor // block_size) + # Sample anchor positions uniformly + anchors = sorted( + random.sample(range(1, max(2, max_anchor)), min(num_blocks, max(1, max_anchor - 1))) + ) + + noise_len = len(anchors) * block_size + + # 3. Create noise embeddings: anchor token at position 0, MASK for positions 1..B-1 + # This matches inference where only the anchor token is known. + # The draft model must predict all other positions from context KV alone. + noise_embeds = ( + self.dflash_module.mask_embedding.unsqueeze(0) + .unsqueeze(0) + .expand(batch_size, noise_len, -1) + .clone() + ) + for b, anchor in enumerate(anchors): + # Only position 0 of each block gets the real anchor token embedding + anchor_embed = self._base_model_embeddings(input_ids[:, anchor : anchor + 1]) + noise_embeds[:, b * block_size : b * block_size + 1] = anchor_embed + + # 4. Build attention mask + dtype = target_hidden.dtype + attn_mask = build_dflash_attention_mask( + seq_len, anchors, block_size, seq_len, device, dtype + ) + + # 5. Run DFlash draft model with KV injection + draft_hidden = self.dflash_module( + hidden_states=noise_embeds, + target_hidden=target_hidden, + attention_mask=attn_mask, + ) + draft_logits = self._base_model_lm_head(draft_hidden) # [B, noise_len, V] + + # 6. Compute loss with exponential position decay + # For block b, position k: target is token at anchors[b] + k + total_loss = torch.tensor(0.0, device=device, dtype=dtype) + total_correct = 0 + total_valid = 0 + decay_gamma = block_size # decay rate + + for b, anchor in enumerate(anchors): + for k in range(1, block_size): # skip position 0 (anchor itself) + target_pos = anchor + k + if target_pos >= seq_len: + break + + draft_idx = b * block_size + k + logit = draft_logits[:, draft_idx, :] # [B, V] + + # Target: base_logits[anchor + k - 1] predicts token at position anchor + k + # This is the base model's autoregressive prediction for this position + target = base_logits[:, target_pos - 1, :].detach() + + # Logit distillation loss + target_soft = torch.softmax(target, dim=-1) + draft_logsoft = torch.log_softmax(logit, dim=-1) + kd_loss = -torch.sum(target_soft * draft_logsoft, dim=-1).mean() + + # Position decay weight + weight = math.exp(-(k - 1) / decay_gamma) + total_loss = total_loss + weight * kd_loss + + # Accuracy: does draft predict the same token as the base model? + target_tok = input_ids[:, target_pos] + draft_tok = logit.detach().argmax(dim=-1) + total_correct += (target_tok == draft_tok).sum().item() + total_valid += batch_size + + # Normalize by number of predictions + num_predictions = sum(min(block_size - 1, seq_len - a - 1) for a in anchors) + if num_predictions > 0: + total_loss = total_loss / num_predictions + + accuracy = total_correct / max(total_valid, 1) + final_loss = (base_loss or 0) + total_loss + + return ModelOutput( + loss=final_loss, + logits=base_logits, + hidden_states=target_hidden, + train_acc=[[accuracy]], + ) + + @torch.no_grad() + def pseudo_speculative_generate(self, input_ids, steps=1): + """Generate draft tokens using DFlash parallel block prediction. + + Args: + input_ids: Prompt token IDs [B, seq_len]. + steps: Number of blocks to generate. + + Returns: + base_token: Next token from base model [B, 1]. + draft_tokens: Draft tokens [B, steps * block_size] or None. + """ + # Run base model + self._target_hidden_states = [] + base_outputs = super().forward(input_ids=input_ids, output_hidden_states=True) + base_logits = base_outputs.logits + base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) + + if steps < 1: + return base_token, None + + # Fuse target features + target_hidden = self.dflash_module.fuse_target_features( + base_outputs.hidden_states, self.target_layer_ids + ) + + block_size = self.dflash_block_size + batch_size = input_ids.shape[0] + seq_len = input_ids.shape[1] + device = input_ids.device + dtype = target_hidden.dtype + + all_draft_tokens = [] + current_token = base_token # [B, 1] + + for step in range(steps): + # Build noise: anchor token at position 0, mask embedding for rest + noise_embeds = ( + self.dflash_module.mask_embedding.unsqueeze(0) + .unsqueeze(0) + .expand(batch_size, block_size, -1) + .clone() + ) + anchor_embed = self._base_model_embeddings(current_token) # [B, 1, H] + noise_embeds[:, :1] = anchor_embed + + # Attention mask: block sees all context, bidirectional within block + anchor = seq_len + step * block_size + attn_mask = build_dflash_attention_mask( + seq_len, [anchor], block_size, seq_len + (step + 1) * block_size, device, dtype + ) + + # Run draft with KV injection + draft_hidden = self.dflash_module( + hidden_states=noise_embeds, + target_hidden=target_hidden, + attention_mask=attn_mask, + ) + draft_logits = self._base_model_lm_head(draft_hidden) + block_tokens = draft_logits.argmax(dim=-1) # [B, block_size] + all_draft_tokens.append(block_tokens[:, 1:]) # skip anchor position + + # Next block starts with last predicted token + current_token = block_tokens[:, -1:] + + draft_tokens = torch.cat(all_draft_tokens, dim=-1) + return base_token, draft_tokens From 190cb3a4f430e3ccde0123f91988980869abd23e Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 27 Mar 2026 20:29:21 -0700 Subject: [PATCH 02/72] fix: rewrite DFlash to match SpecForge reference Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 540 +++++++----------- 1 file changed, 202 insertions(+), 338 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 28a9c421eb..d8c469f36c 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -15,27 +15,19 @@ """DFlash speculative decoding plugin for HuggingFace models. -DFlash (Block Diffusion for Flash Speculative Decoding) uses three key mechanisms: +Matches the reference SpecForge implementation (github.com/sgl-project/SpecForge PR #415). -1. Feature Fusion: Extract hidden states from uniformly sampled target model layers, - concatenate and project via a lightweight FC layer. - -2. KV Injection: The fused features are injected as Key/Value entries into EVERY - draft model layer's attention. Unlike EAGLE-3 which only feeds features to the - first layer, DFlash ensures every layer has full target model context. - -3. Parallel Drafting: All tokens in a block are predicted in a single forward pass. - The draft model uses mask tokens for unknown positions and predicts them all - simultaneously via cross-entropy against target model logits. +Architecture: +- Feature Fusion: multi-layer target hidden states → FC + RMSNorm +- KV Injection: fused features as K/V in every draft layer with QK-norm +- Parallel Drafting: mask_token_id for unknown positions, causal within blocks +- Loss: hard CE on input_ids[i] (position i predicts token i) Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) """ -import contextlib -import math -import random - import torch +import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig, PreTrainedModel from transformers.models.llama.modeling_llama import ( @@ -52,116 +44,93 @@ __all__ = ["HFDFlashModel"] -def build_target_layer_ids(num_target_layers, num_sample_layers): +def build_target_layer_ids(num_target_layers, num_draft_layers): """Select layers uniformly from the target model for feature extraction.""" - if num_sample_layers == 1: + if num_draft_layers == 1: return [num_target_layers // 2] start = 1 end = num_target_layers - 3 span = end - start - return [round(start + (i * span) / (num_sample_layers - 1)) for i in range(num_sample_layers)] + return [ + int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers) + ] -class DFlashAttention(nn.Module): - """Attention with KV injection from target model features. +def apply_rotary_pos_emb(q, k, cos, sin): + """Apply RoPE. Q uses last q_len positions, K uses all positions.""" + cos = cos.unsqueeze(1) # [B, 1, seq, dim] + sin = sin.unsqueeze(1) + q_len = q.size(2) + q_embed = (q * cos[:, :, -q_len:, :]) + (rotate_half(q) * sin[:, :, -q_len:, :]) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed - Key difference from standard attention: K and V are computed from BOTH - the target model's fused features (context) AND the draft tokens (noise). - Q is computed only from draft tokens. - Attention pattern: [k_ctx | k_noise] where draft queries attend to - both context KV and draft KV with appropriate masking. - """ +class DFlashAttention(nn.Module): + """Attention with KV injection, matching SpecForge Qwen3DFlashAttention.""" def __init__(self, config, layer_idx): super().__init__() - self.hidden_size = config.hidden_size + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) self.num_heads = config.num_attention_heads self.num_kv_heads = config.num_key_value_heads - self.head_dim = self.hidden_size // self.num_heads - self.layer_idx = layer_idx + self.scaling = self.head_dim**-0.5 + self.is_causal = False - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) - def forward( - self, - hidden_states, - target_hidden, - position_embeddings, - attention_mask=None, - ): + # QK norm (matches Qwen3DFlashAttention) + self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): bsz, q_len, _ = hidden_states.shape ctx_len = target_hidden.shape[1] - # Q from draft tokens only - q = ( - self.q_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) + # Q from noise only, with QK-norm + q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + q = self.q_norm(q).transpose(1, 2) - # K, V from both context (target features) and noise (draft tokens) - k_ctx = ( - self.k_proj(target_hidden) - .view(bsz, ctx_len, self.num_kv_heads, self.head_dim) - .transpose(1, 2) + # K from context + noise, with QK-norm + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + k = torch.cat([k_ctx, k_noise], dim=1).view( + bsz, ctx_len + q_len, self.num_kv_heads, self.head_dim ) - v_ctx = ( - self.v_proj(target_hidden) - .view(bsz, ctx_len, self.num_kv_heads, self.head_dim) + k = self.k_norm(k).transpose(1, 2) + + # V from context + noise (no norm) + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + v = ( + torch.cat([v_ctx, v_noise], dim=1) + .view(bsz, ctx_len + q_len, self.num_kv_heads, self.head_dim) .transpose(1, 2) ) - k_noise = ( - self.k_proj(hidden_states) - .view(bsz, q_len, self.num_kv_heads, self.head_dim) - .transpose(1, 2) - ) - v_noise = ( - self.v_proj(hidden_states) - .view(bsz, q_len, self.num_kv_heads, self.head_dim) - .transpose(1, 2) - ) - - # Apply rotary embeddings - if position_embeddings is not None: - cos, sin = position_embeddings - # Rotary for noise Q and K - q = self._apply_rotary( - q, cos[:, ctx_len : ctx_len + q_len], sin[:, ctx_len : ctx_len + q_len] - ) - k_noise = self._apply_rotary( - k_noise, cos[:, ctx_len : ctx_len + q_len], sin[:, ctx_len : ctx_len + q_len] - ) - k_ctx = self._apply_rotary(k_ctx, cos[:, :ctx_len], sin[:, :ctx_len]) - # Concatenate context and noise KV - k = torch.cat([k_ctx, k_noise], dim=2) # [B, num_kv_heads, ctx+q, head_dim] - v = torch.cat([v_ctx, v_noise], dim=2) + # RoPE: applied to full 2L positions, Q gets last q_len, K gets all + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) - # GQA: expand KV heads to match Q heads + # GQA expand if self.num_kv_heads != self.num_heads: n_rep = self.num_heads // self.num_kv_heads k = k.repeat_interleave(n_rep, dim=1) v = v.repeat_interleave(n_rep, dim=1) - # Scaled dot product attention - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask, is_causal=False + attn_output = F.scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask, is_causal=False, scale=self.scaling ) - - attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size) + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1) return self.o_proj(attn_output) - @staticmethod - def _apply_rotary(x, cos, sin): - """Apply rotary positional embeddings (HF Llama convention).""" - cos = cos.unsqueeze(1) # [B, 1, seq, dim] - sin = sin.unsqueeze(1) - return (x * cos) + (rotate_half(x) * sin) - class DFlashDecoderLayer(nn.Module): """Draft decoder layer with KV injection.""" @@ -185,55 +154,33 @@ def forward(self, hidden_states, target_hidden, position_embeddings, attention_m hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return hidden_states class DFlashModule(nn.Module): - """DFlash draft module with feature fusion + KV injection. - - Architecture: - - FC layer fuses multi-layer target hidden states → hidden_size - - N decoder layers, each with KV injection from fused target features - - Shares embeddings and lm_head with the target model - """ + """DFlash draft module matching SpecForge DFlashDraftModel.""" def __init__(self, config): super().__init__() self.config = config + self.block_size = config.block_size - # Feature fusion: project concatenated multi-layer hidden states + # Feature fusion num_fused_layers = len(config.target_layer_ids) self.fc = nn.Linear(num_fused_layers * config.hidden_size, config.hidden_size, bias=False) self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # Learnable mask embedding for unknown positions in blocks - self.mask_embedding = nn.Parameter(torch.randn(config.hidden_size) * 0.02) - - # Draft decoder layers with KV injection + # Decoder layers self.layers = nn.ModuleList( [DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) - def fuse_target_features(self, target_hidden_states, target_layer_ids): - """Extract and fuse hidden states from sampled target layers.""" - selected = [target_hidden_states[lid + 1] for lid in target_layer_ids] - concatenated = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] - fused = self.hidden_norm(self.fc(concatenated)) # [B, seq, H] - return fused - - def forward(self, hidden_states, target_hidden, attention_mask=None): - """Forward pass with KV injection. - - Args: - hidden_states: Draft token embeddings [B, noise_len, H]. - target_hidden: Fused target features [B, ctx_len, H]. - attention_mask: Attention mask for [ctx + noise] positions. - """ - total_len = target_hidden.shape[1] + hidden_states.shape[1] - position_ids = torch.arange(total_len, device=hidden_states.device).unsqueeze(0) + def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None): + """Forward matching SpecForge DFlashDraftModel.forward.""" + hidden_states = noise_embedding + target_hidden = self.hidden_norm(self.fc(target_hidden)) position_embeddings = self.rotary_emb(hidden_states, position_ids) for layer in self.layers: @@ -242,50 +189,43 @@ def forward(self, hidden_states, target_hidden, attention_mask=None): return self.norm(hidden_states) -def build_dflash_attention_mask(ctx_len, block_anchors, block_size, seq_len, device, dtype): - """Build DFlash attention mask for training. +def create_dflash_attention_mask(seq_len, block_size, device, dtype): + """Create [L, 2L] attention mask matching SpecForge. - Each draft query can attend to: - 1. Context positions strictly before its block's anchor position - 2. All positions within the same block (bidirectional) - 3. Nothing from other blocks + Context (cols 0..L-1): Block B sees blocks 0..B-1 (strictly previous). + Noise (cols L..2L-1): causal within same block only. + """ + indices = torch.arange(seq_len, device=device) + block_ids = indices // block_size - Args: - ctx_len: Number of context tokens (= seq_len of the input). - block_anchors: List of anchor positions (indices into the original sequence). - block_size: Number of tokens per block. - seq_len: Original sequence length. - device: Target device. - dtype: Target dtype. + q_block_ids = block_ids.unsqueeze(1) # [L, 1] + k_block_ids = block_ids.unsqueeze(0) # [1, L] - Returns: - Attention mask [1, 1, noise_len, ctx_len + noise_len]. - """ - num_blocks = len(block_anchors) - noise_len = num_blocks * block_size + ctx_mask = k_block_ids < q_block_ids + same_block = q_block_ids == k_block_ids + causal = indices.unsqueeze(0) >= indices.unsqueeze(1) + noise_mask = same_block & causal - # Mask shape: [noise_len, ctx_len + noise_len] - # Q dimension = noise tokens, KV dimension = context + noise tokens - mask = torch.full((noise_len, ctx_len + noise_len), float("-inf"), device=device, dtype=dtype) + full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) - for block_idx, anchor in enumerate(block_anchors): - block_start = block_idx * block_size - block_end = block_start + block_size + full_mask = torch.zeros(seq_len, 2 * seq_len, device=device, dtype=dtype) + full_mask.masked_fill_(~full_mask_bool, torch.finfo(dtype).min) - # 1. Context: each block sees all context up to its anchor position - mask[block_start:block_end, : min(anchor, ctx_len)] = 0.0 + return full_mask.unsqueeze(0).unsqueeze(0) # [1, 1, L, 2L] - # 2. Within-block: BIDIRECTIONAL attention (all positions see each other) - # This is key to DFlash's parallel drafting — all positions in a block - # have the same information and predict independently. - mask[block_start:block_end, ctx_len + block_start : ctx_len + block_end] = 0.0 - return mask.unsqueeze(0).unsqueeze(0) # [1, 1, noise, ctx+noise] +def create_dflash_loss_mask(seq_len, block_size, device): + """Create loss mask: exclude Block 0 and block starts.""" + positions = torch.arange(seq_len, device=device) + block_ids = positions // block_size + is_block_0 = block_ids == 0 + is_block_start = (positions % block_size) == 0 + return (~is_block_0 & ~is_block_start).float() @DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) class HFDFlashModel(DFlashModel): - """DFlash Model for HuggingFace models with KV injection + parallel drafting.""" + """DFlash Model matching SpecForge OnlineDFlashModel.""" @property def _base_model(self): @@ -329,7 +269,7 @@ def _find_base_model_parts(self): raise ValueError(f"Part {name} not found in model") def modify(self, config): - """Initialize DFlash with feature fusion + KV injection.""" + """Initialize DFlash draft module.""" super().modify(config) base_config = self._base_llm_config @@ -340,18 +280,22 @@ def modify(self, config): self.dflash_config.intermediate_size = getattr( self.dflash_config, "intermediate_size", base_config.intermediate_size ) - # head_dim for rotary embeddings: match base model - actual_head_dim = base_config.hidden_size // base_config.num_attention_heads - self.dflash_config.head_dim = actual_head_dim + self.dflash_config.head_dim = base_config.hidden_size // base_config.num_attention_heads + self.dflash_config.block_size = self.dflash_block_size if self.dflash_config._attn_implementation is None: - self.dflash_config._attn_implementation = "sdpa" + self.dflash_config._attn_implementation = "eager" - # Determine target layer IDs for feature extraction + # Target layer IDs num_target_layers = base_config.num_hidden_layers - num_sample_layers = self.dflash_config.num_hidden_layers # sample as many as draft layers - self.target_layer_ids = build_target_layer_ids(num_target_layers, num_sample_layers) + num_draft_layers = self.dflash_config.num_hidden_layers + self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers) self.dflash_config.target_layer_ids = self.target_layer_ids + # mask_token_id: prefer pad_token_id, fallback to eos + self.mask_token_id = getattr(base_config, "pad_token_id", None) or getattr( + base_config, "eos_token_id", 0 + ) + # Freeze base model if self.dflash_freeze_base_model: for param in self.parameters(): @@ -359,54 +303,13 @@ def modify(self, config): self._find_base_model_parts() - # Build DFlash module self.dflash_module = DFlashModule(self.dflash_config) self.dflash_module.to(self._base_model.dtype).to( next(self._base_model.layers[-1].parameters()).device ) - # Register hooks to collect hidden states from target layers - self._target_hidden_states = [] - for layer_idx, layer in enumerate(self._base_model.layers): - if layer_idx in self.target_layer_ids: - layer.register_forward_hook(self._collect_hidden_hook) - - self._cached_masks = {} self.is_quantized = False - def _collect_hidden_hook(self, module, input, output): - hidden = ( - output.clone().detach() - if isinstance(output, torch.Tensor) - else output[0].clone().detach() - ) - self._target_hidden_states.append(hidden) - - def _run_base_model(self, input_ids, attention_mask, labels=None, **kwargs): - """Run base model, collect hidden states from target layers.""" - self._target_hidden_states = [] - with torch.no_grad() if self.dflash_freeze_base_model else contextlib.nullcontext(): - outputs = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - **kwargs, - ) - logits = outputs.logits - - # Fuse collected hidden states - target_hidden = self.dflash_module.fuse_target_features( - outputs.hidden_states, self.target_layer_ids - ) - self._target_hidden_states = [] - - base_loss = None - if labels is not None and not self.dflash_freeze_base_model: - loss_fct = nn.CrossEntropyLoss() - base_loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1)) - - return target_hidden, logits, base_loss - def forward( self, input_ids=None, @@ -421,15 +324,7 @@ def forward( cache_position=None, **kwargs, ): - """DFlash training forward pass. - - 1. Run base model → get hidden states from sampled layers + logits - 2. Fuse multi-layer hidden states via FC projection - 3. Sample random anchors from the sequence → form blocks - 4. Create noise input (mask tokens + anchor tokens at block starts) - 5. Run draft model with KV injection (fused features as K/V in every layer) - 6. Compute CE loss with exponential position decay - """ + """Training forward matching SpecForge OnlineDFlashModel.forward.""" if not self.training: return super().forward( input_ids=input_ids, @@ -445,128 +340,88 @@ def forward( **kwargs, ) - batch_size, seq_len = input_ids.shape + bsz, seq_len = input_ids.shape block_size = self.dflash_block_size device = input_ids.device - # 1. Run base model → fused target features + logits - target_hidden, base_logits, base_loss = self._run_base_model( - input_ids=input_ids, - attention_mask=attention_mask, - labels=labels, - **kwargs, - ) + # 1. Run base model → raw multi-layer hidden states + with torch.no_grad(): + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + **kwargs, + ) + + # Extract and concatenate target layer hidden states + offset = 1 + selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] - # 2. Sample anchor positions (start of each block) - # Anchors are random positions in the valid (non-padding) region + # 2. Truncate to multiple of block_size + n_blocks = seq_len // block_size + effective_len = n_blocks * block_size + input_ids_trunc = input_ids[:, :effective_len] + target_hidden = target_hidden[:, :effective_len, :] if attention_mask is not None: - actual_len = attention_mask.sum(dim=1).min().int().item() + loss_mask_input = attention_mask[:, :effective_len].float() else: - actual_len = seq_len - - # Number of blocks we can fit: leave room for block_size predictions after each anchor - max_anchor = actual_len - block_size - max_anchor = max(max_anchor, 1) - num_blocks = max(1, max_anchor // block_size) - # Sample anchor positions uniformly - anchors = sorted( - random.sample(range(1, max(2, max_anchor)), min(num_blocks, max(1, max_anchor - 1))) - ) + loss_mask_input = torch.ones(bsz, effective_len, device=device) - noise_len = len(anchors) * block_size + # 3. Prepare noise: mask_token_id everywhere, real token at block starts + positions = torch.arange(effective_len, device=device) + is_block_start = (positions % block_size) == 0 + noise_input_ids = torch.full_like(input_ids_trunc, self.mask_token_id) + noise_input_ids[:, is_block_start] = input_ids_trunc[:, is_block_start] + noise_embedding = self._base_model_embeddings(noise_input_ids) - # 3. Create noise embeddings: anchor token at position 0, MASK for positions 1..B-1 - # This matches inference where only the anchor token is known. - # The draft model must predict all other positions from context KV alone. - noise_embeds = ( - self.dflash_module.mask_embedding.unsqueeze(0) - .unsqueeze(0) - .expand(batch_size, noise_len, -1) - .clone() - ) - for b, anchor in enumerate(anchors): - # Only position 0 of each block gets the real anchor token embedding - anchor_embed = self._base_model_embeddings(input_ids[:, anchor : anchor + 1]) - noise_embeds[:, b * block_size : b * block_size + 1] = anchor_embed + # 4. Position IDs: [0..L-1, 0..L-1] + pos_seq = torch.arange(effective_len, device=device) + position_ids_2l = torch.cat([pos_seq, pos_seq]).unsqueeze(0).expand(bsz, -1) - # 4. Build attention mask + # 5. Attention mask: [1, 1, L, 2L] dtype = target_hidden.dtype - attn_mask = build_dflash_attention_mask( - seq_len, anchors, block_size, seq_len, device, dtype - ) + dflash_attn_mask = create_dflash_attention_mask(effective_len, block_size, device, dtype) - # 5. Run DFlash draft model with KV injection - draft_hidden = self.dflash_module( - hidden_states=noise_embeds, + # 6. Draft forward + hidden = self.dflash_module( + noise_embedding=noise_embedding, target_hidden=target_hidden, - attention_mask=attn_mask, + position_ids=position_ids_2l, + attention_mask=dflash_attn_mask, ) - draft_logits = self._base_model_lm_head(draft_hidden) # [B, noise_len, V] - - # 6. Compute loss with exponential position decay - # For block b, position k: target is token at anchors[b] + k - total_loss = torch.tensor(0.0, device=device, dtype=dtype) - total_correct = 0 - total_valid = 0 - decay_gamma = block_size # decay rate - - for b, anchor in enumerate(anchors): - for k in range(1, block_size): # skip position 0 (anchor itself) - target_pos = anchor + k - if target_pos >= seq_len: - break - draft_idx = b * block_size + k - logit = draft_logits[:, draft_idx, :] # [B, V] + # 7. Loss: hard CE, position i predicts token i + logits = self._base_model_lm_head(hidden) + dflash_loss_mask = create_dflash_loss_mask(effective_len, block_size, device) + combined_mask = loss_mask_input * dflash_loss_mask.unsqueeze(0) - # Target: base_logits[anchor + k - 1] predicts token at position anchor + k - # This is the base model's autoregressive prediction for this position - target = base_logits[:, target_pos - 1, :].detach() + logits_flat = logits.reshape(-1, logits.size(-1)) + labels_flat = input_ids_trunc.reshape(-1) + mask_flat = combined_mask.reshape(-1) - # Logit distillation loss - target_soft = torch.softmax(target, dim=-1) - draft_logsoft = torch.log_softmax(logit, dim=-1) - kd_loss = -torch.sum(target_soft * draft_logsoft, dim=-1).mean() + active_indices = mask_flat > 0.5 + active_logits = logits_flat[active_indices] + active_labels = labels_flat[active_indices] - # Position decay weight - weight = math.exp(-(k - 1) / decay_gamma) - total_loss = total_loss + weight * kd_loss - - # Accuracy: does draft predict the same token as the base model? - target_tok = input_ids[:, target_pos] - draft_tok = logit.detach().argmax(dim=-1) - total_correct += (target_tok == draft_tok).sum().item() - total_valid += batch_size - - # Normalize by number of predictions - num_predictions = sum(min(block_size - 1, seq_len - a - 1) for a in anchors) - if num_predictions > 0: - total_loss = total_loss / num_predictions - - accuracy = total_correct / max(total_valid, 1) - final_loss = (base_loss or 0) + total_loss + if active_logits.numel() > 0: + loss = F.cross_entropy(active_logits, active_labels) + with torch.no_grad(): + preds = active_logits.argmax(dim=-1) + accuracy = (preds == active_labels).float().mean().item() + else: + loss = torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True) + accuracy = 0.0 return ModelOutput( - loss=final_loss, - logits=base_logits, - hidden_states=target_hidden, + loss=loss, + logits=base_outputs.logits, train_acc=[[accuracy]], ) @torch.no_grad() def pseudo_speculative_generate(self, input_ids, steps=1): - """Generate draft tokens using DFlash parallel block prediction. - - Args: - input_ids: Prompt token IDs [B, seq_len]. - steps: Number of blocks to generate. - - Returns: - base_token: Next token from base model [B, 1]. - draft_tokens: Draft tokens [B, steps * block_size] or None. - """ - # Run base model - self._target_hidden_states = [] + """Generate draft tokens matching SpecForge spec_generate logic.""" base_outputs = super().forward(input_ids=input_ids, output_hidden_states=True) base_logits = base_outputs.logits base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) @@ -574,48 +429,57 @@ def pseudo_speculative_generate(self, input_ids, steps=1): if steps < 1: return base_token, None - # Fuse target features - target_hidden = self.dflash_module.fuse_target_features( - base_outputs.hidden_states, self.target_layer_ids - ) + # Extract target hidden states + offset = 1 + selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) block_size = self.dflash_block_size - batch_size = input_ids.shape[0] - seq_len = input_ids.shape[1] + bsz = input_ids.shape[0] device = input_ids.device dtype = target_hidden.dtype all_draft_tokens = [] - current_token = base_token # [B, 1] - - for step in range(steps): - # Build noise: anchor token at position 0, mask embedding for rest - noise_embeds = ( - self.dflash_module.mask_embedding.unsqueeze(0) - .unsqueeze(0) - .expand(batch_size, block_size, -1) - .clone() - ) - anchor_embed = self._base_model_embeddings(current_token) # [B, 1, H] - noise_embeds[:, :1] = anchor_embed + current_token = base_token - # Attention mask: block sees all context, bidirectional within block - anchor = seq_len + step * block_size - attn_mask = build_dflash_attention_mask( - seq_len, [anchor], block_size, seq_len + (step + 1) * block_size, device, dtype + for _ in range(steps): + # Block: first token real, rest mask + block_ids = torch.full( + (bsz, block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_ids[:, 0] = current_token.squeeze(-1) + noise_embedding = self._base_model_embeddings(block_ids) + + # Position IDs: context 0..ctx_len-1, block ctx_len..ctx_len+block_size-1 + ctx_len = target_hidden.shape[1] + ctx_positions = torch.arange(ctx_len, device=device) + block_positions = torch.arange(ctx_len, ctx_len + block_size, device=device) + pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) + + # Mask: block sees all context + causal within block + attn_mask = torch.zeros( + 1, 1, block_size, ctx_len + block_size, device=device, dtype=dtype + ) + block_causal = torch.triu( + torch.full( + (block_size, block_size), torch.finfo(dtype).min, device=device, dtype=dtype + ), + diagonal=1, ) + attn_mask[:, :, :, ctx_len:] = block_causal - # Run draft with KV injection + # Draft forward draft_hidden = self.dflash_module( - hidden_states=noise_embeds, + noise_embedding=noise_embedding, target_hidden=target_hidden, + position_ids=pos_ids, attention_mask=attn_mask, ) - draft_logits = self._base_model_lm_head(draft_hidden) - block_tokens = draft_logits.argmax(dim=-1) # [B, block_size] - all_draft_tokens.append(block_tokens[:, 1:]) # skip anchor position - # Next block starts with last predicted token + # Logits on positions 1..block_size-1 (skip known anchor) + draft_logits = self._base_model_lm_head(draft_hidden[:, 1:, :]) + block_tokens = draft_logits.argmax(dim=-1) # [B, block_size-1] + all_draft_tokens.append(block_tokens) current_token = block_tokens[:, -1:] draft_tokens = torch.cat(all_draft_tokens, dim=-1) From b7a2a7b1405400cc90792b96c1433e784a83ab6b Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Sat, 28 Mar 2026 19:31:07 -0700 Subject: [PATCH 03/72] fix: correct mask_token_id and base model forward dispatch Key fixes: - mask_token_id now read from dflash_architecture_config (e.g., 248070 for Qwen3) instead of defaulting to pad/eos token. Wrong mask_token_id caused garbage draft output despite correct weights. - Inherit model config from base model only as defaults; allow draft to have different num_heads/intermediate_size (needed for z-lab checkpoint) - Clean default_dflash_config to only contain DFlash-specific settings - pseudo_speculative_generate returns single block of tokens - Add dflash_mask_token_id CLI argument to main.py Validated: z-lab/Qwen3.5-4B-DFlash checkpoint produces AR=7.28 (expected ~6.08) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/main.py | 6 + .../speculative/dflash/default_config.py | 29 +-- .../torch/speculative/plugins/hf_dflash.py | 192 +++++++++++++----- 3 files changed, 151 insertions(+), 76 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index d08b648c36..6f439494b9 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -157,6 +157,10 @@ class DFlashArguments: default=False, metadata={"help": "Disable torch.compile on DFlash forward/loss methods."}, ) + dflash_mask_token_id: int = field( + default=None, + metadata={"help": "Mask token ID for DFlash. If not set, uses pad_token_id."}, + ) def train(): @@ -257,6 +261,8 @@ def train(): json.load(open(dflash_args.dflash_config)) if dflash_args.dflash_config else {} ) custom_config.setdefault("num_hidden_layers", dflash_args.dflash_num_layers) + if dflash_args.dflash_mask_token_id is not None: + custom_config["mask_token_id"] = dflash_args.dflash_mask_token_id config = { "dflash_block_size": dflash_args.dflash_block_size, diff --git a/modelopt/torch/speculative/dflash/default_config.py b/modelopt/torch/speculative/dflash/default_config.py index b552d4e4ad..5536e0d4df 100644 --- a/modelopt/torch/speculative/dflash/default_config.py +++ b/modelopt/torch/speculative/dflash/default_config.py @@ -13,31 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Default DFlash architecture config.""" +"""Default DFlash architecture config. + +Model-specific settings (hidden_size, num_attention_heads, rope_*, etc.) +are inherited from the base model in HFDFlashModel.modify(). Only +DFlash-specific defaults are set here. +""" default_dflash_config = { - "hidden_act": "silu", - "torch_dtype": "bfloat16", - "position_embedding_type": "rope", - "rope_scaling": { - "factor": 8.0, - "low_freq_factor": 1.0, - "high_freq_factor": 4.0, - "original_max_position_embeddings": 8192, - "rope_type": "llama3", - }, - "rope_theta": 500000.0, "num_hidden_layers": 5, - "intermediate_size": 14336, - "num_attention_heads": 32, - "num_key_value_heads": 8, - "initializer_range": 0.01, - "rms_norm_eps": 1e-05, - "mlp_bias": False, + "rms_norm_eps": 1e-06, "attention_bias": False, "attention_dropout": 0.0, - "use_input_layernorm_in_first_layer": True, - "use_last_layernorm": True, - "has_lm_head": False, - "head_dim": 128, } diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index d8c469f36c..312adb1d73 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -274,13 +274,41 @@ def modify(self, config): base_config = self._base_llm_config self.dflash_config = PretrainedConfig.from_dict(config.dflash_architecture_config) + + # Inherit settings from base model, but only those NOT already in the user config. + # hidden_size and vocab_size MUST match. Others (heads, intermediate_size) can differ. + # This allows the draft model to have a different architecture than the base model. self.dflash_config.hidden_size = base_config.hidden_size self.dflash_config.vocab_size = base_config.vocab_size - self.dflash_config.max_position_embeddings = base_config.max_position_embeddings - self.dflash_config.intermediate_size = getattr( - self.dflash_config, "intermediate_size", base_config.intermediate_size + + # These use base model defaults if not specified in dflash_architecture_config + for attr, default_from_base in [ + ("max_position_embeddings", True), + ("intermediate_size", True), + ("num_attention_heads", True), + ("num_key_value_heads", True), + ("hidden_act", True), + ("rope_theta", True), + ("rope_scaling", True), + ("rope_type", False), + ("position_embedding_type", False), + ("rope_interleaved", False), + ("rms_norm_eps", True), + ("attention_bias", False), + ("tie_word_embeddings", False), + ]: + if not hasattr(self.dflash_config, attr) or getattr(self.dflash_config, attr) is None: + if default_from_base and hasattr(base_config, attr): + setattr(self.dflash_config, attr, getattr(base_config, attr)) + + # Ensure required attrs have defaults + if not hasattr(self.dflash_config, "mlp_bias") or self.dflash_config.mlp_bias is None: + self.dflash_config.mlp_bias = False + + self.dflash_config.head_dim = getattr( + self.dflash_config, "head_dim", + self.dflash_config.hidden_size // self.dflash_config.num_attention_heads, ) - self.dflash_config.head_dim = base_config.hidden_size // base_config.num_attention_heads self.dflash_config.block_size = self.dflash_block_size if self.dflash_config._attn_implementation is None: self.dflash_config._attn_implementation = "eager" @@ -292,9 +320,12 @@ def modify(self, config): self.dflash_config.target_layer_ids = self.target_layer_ids # mask_token_id: prefer pad_token_id, fallback to eos - self.mask_token_id = getattr(base_config, "pad_token_id", None) or getattr( - base_config, "eos_token_id", 0 + # mask_token_id: prefer from dflash_architecture_config, fallback to pad/eos + mask_id = config.dflash_architecture_config.get( + "mask_token_id", + getattr(base_config, "pad_token_id", None) or getattr(base_config, "eos_token_id", 0), ) + self.mask_token_id = mask_id[0] if isinstance(mask_id, list) else mask_id # Freeze base model if self.dflash_freeze_base_model: @@ -310,6 +341,32 @@ def modify(self, config): self.is_quantized = False + # Store bound reference to the original model class's forward. + # DynamicModule changes type(self) but the original class is in _original_cls. + # Find the original HF model class (e.g., Qwen3_5ForConditionalGeneration) + # by walking MRO and skipping DFlash/DynamicModule classes + skip_names = {"HFDFlashModel", "DFlashModel", "DynamicModule", "DFlashPreTrainedModel", "DFlashDraftModel"} + original_cls = None + for cls in type(self).__mro__: + if ( + hasattr(cls, "forward") + and cls.__name__ not in skip_names + and cls is not type(self) + and issubclass(cls, PreTrainedModel) + and cls is not PreTrainedModel + ): + original_cls = cls + break + if original_cls is None: + # Last resort: use the class two levels up (skip DFlash wrapper + DynamicModule) + original_cls = type(self).__mro__[2] + self._original_forward_cls = original_cls + print(f"DFlash: using {original_cls.__name__}.forward as base forward") + + def _base_forward(self, **kwargs): + """Call the original model's forward, bypassing DFlash wrapper.""" + return self._original_forward_cls.forward(self, **kwargs) + def forward( self, input_ids=None, @@ -350,7 +407,6 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, - **kwargs, ) # Extract and concatenate target layer hidden states @@ -421,66 +477,94 @@ def forward( @torch.no_grad() def pseudo_speculative_generate(self, input_ids, steps=1): - """Generate draft tokens matching SpecForge spec_generate logic.""" - base_outputs = super().forward(input_ids=input_ids, output_hidden_states=True) + """Generate draft tokens using one DFlash block. + + DFlash generates block_size-1 draft tokens in a single forward pass. + The `steps` parameter is used as the number of tokens to return + (capped at block_size-1). + + Returns: + base_token: Next token from base model [B, 1]. + draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None. + """ + # Call the base model's inner model directly (avoids DynamicModule dispatch) + model_output = self._base_model( + input_ids=input_ids, + output_hidden_states=True, + ) + # Compute logits via lm_head + base_logits = self._base_model_lm_head(model_output.last_hidden_state) + # Build output with hidden_states + base_outputs = ModelOutput( + logits=base_logits, + hidden_states=model_output.hidden_states, + ) base_logits = base_outputs.logits base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) if steps < 1: return base_token, None - # Extract target hidden states - offset = 1 - selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] + # Extract target hidden states (raw, before FC projection) + hid_offset = 1 + if not hasattr(self, '_psg_debug'): + self._psg_debug = True + sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + th = torch.cat(sel, dim=-1) + print(f"[psg] hidden_states layers: {len(base_outputs.hidden_states)}, target_hidden norm: {th.norm().item():.2f}, shape: {th.shape}") + print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}") + print(f"[psg] block_ids: {[self.mask_token_id]*self.dflash_block_size}") + bi = torch.full((1, self.dflash_block_size), self.mask_token_id, dtype=torch.long, device=input_ids.device) + bi[0, 0] = base_token[0, 0] + ne = self._base_model_embeddings(bi) + print(f"[psg] noise_emb norm: {ne.norm().item():.2f}, shape: {ne.shape}") + print(f"[psg] pos_ids will be: ctx=[0..{input_ids.shape[1]-1}], blk=[{input_ids.shape[1]}..{input_ids.shape[1]+self.dflash_block_size-1}]") + selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] target_hidden = torch.cat(selected, dim=-1) block_size = self.dflash_block_size bsz = input_ids.shape[0] + seq_len = input_ids.shape[1] device = input_ids.device dtype = target_hidden.dtype - all_draft_tokens = [] - current_token = base_token + # Block: first token is base_token (anchor), rest are mask + block_ids = torch.full( + (bsz, block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_ids[:, 0] = base_token.squeeze(-1) + noise_embedding = self._base_model_embeddings(block_ids) - for _ in range(steps): - # Block: first token real, rest mask - block_ids = torch.full( - (bsz, block_size), self.mask_token_id, dtype=torch.long, device=device - ) - block_ids[:, 0] = current_token.squeeze(-1) - noise_embedding = self._base_model_embeddings(block_ids) - - # Position IDs: context 0..ctx_len-1, block ctx_len..ctx_len+block_size-1 - ctx_len = target_hidden.shape[1] - ctx_positions = torch.arange(ctx_len, device=device) - block_positions = torch.arange(ctx_len, ctx_len + block_size, device=device) - pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) - - # Mask: block sees all context + causal within block - attn_mask = torch.zeros( - 1, 1, block_size, ctx_len + block_size, device=device, dtype=dtype - ) - block_causal = torch.triu( - torch.full( - (block_size, block_size), torch.finfo(dtype).min, device=device, dtype=dtype - ), - diagonal=1, - ) - attn_mask[:, :, :, ctx_len:] = block_causal - - # Draft forward - draft_hidden = self.dflash_module( - noise_embedding=noise_embedding, - target_hidden=target_hidden, - position_ids=pos_ids, - attention_mask=attn_mask, - ) + # Position IDs: context 0..ctx_len-1, block seq_len..seq_len+block_size-1 + ctx_len = target_hidden.shape[1] + ctx_positions = torch.arange(ctx_len, device=device) + block_positions = torch.arange(seq_len, seq_len + block_size, device=device) + pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) + + # Attention mask: block sees ALL context + causal within block + attn_mask = torch.zeros( + 1, 1, block_size, ctx_len + block_size, device=device, dtype=dtype + ) + block_causal = torch.triu( + torch.full( + (block_size, block_size), torch.finfo(dtype).min, device=device, dtype=dtype + ), + diagonal=1, + ) + attn_mask[:, :, :, ctx_len:] = block_causal + + # Draft forward + draft_hidden = self.dflash_module( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=attn_mask, + ) - # Logits on positions 1..block_size-1 (skip known anchor) - draft_logits = self._base_model_lm_head(draft_hidden[:, 1:, :]) - block_tokens = draft_logits.argmax(dim=-1) # [B, block_size-1] - all_draft_tokens.append(block_tokens) - current_token = block_tokens[:, -1:] + # Logits on positions 1..block_size-1 (skip anchor at position 0) + draft_logits = self._base_model_lm_head(draft_hidden[:, 1:, :]) + draft_tokens = draft_logits.argmax(dim=-1) # [B, block_size-1] - draft_tokens = torch.cat(all_draft_tokens, dim=-1) - return base_token, draft_tokens + # Return up to `steps` tokens + num_tokens = min(steps, block_size - 1) + return base_token, draft_tokens[:, :num_tokens] From a310d963c097e64565e4fd8b2ffe195f50dd612d Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Sat, 28 Mar 2026 19:52:05 -0700 Subject: [PATCH 04/72] add: auto-detect mask_token_id for DFlash across model families Resolution order: 1. Explicit in dflash_architecture_config (user override) 2. Auto-detect from model vocabulary: - Qwen3/3.5: built-in [MASK] token (e.g., 248070) - Llama3: reserved_special_token_0 (128002) - Others: pad_token_id fallback 3. CLI override via --dflash_mask_token_id Based on z-lab checkpoints: - z-lab/Qwen3.5-4B-DFlash: mask=248070 - z-lab/LLaMA3.1-8B-Instruct-DFlash: mask=128002 - z-lab/gpt-oss-20b-DFlash: mask=200000 Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 67 +++++++++++++++++-- 1 file changed, 61 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 312adb1d73..c8ecd5eee7 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -247,6 +247,56 @@ def _base_llm_config(self): or self.config ) + @staticmethod + def _auto_detect_mask_token_id(base_config): + """Auto-detect an appropriate mask token ID for DFlash. + + Different model families use different strategies: + - Qwen3/3.5: built-in [MASK] token in vocabulary + - Llama3: reserved special tokens (128002 = reserved_special_token_0) + - Others: try tokenizer.mask_token_id, then fall back to pad/eos + """ + model_type = getattr(base_config, "model_type", "") + vocab_size = getattr(base_config, "vocab_size", 0) + + # Qwen3/3.5: known mask token positions + if "qwen3" in model_type.lower() or "qwen" in model_type.lower(): + # Qwen3 vocab has dedicated mask tokens + # Qwen3.5-4B: 248070, Qwen3-8B: similar range + # Heuristic: eos_token_id + some offset, or check known values + eos = getattr(base_config, "eos_token_id", None) + if isinstance(eos, list): + eos = eos[0] + if eos and vocab_size > 200000: + # Large Qwen vocab — mask token is typically near end of special tokens + # Known: Qwen3.5 eos=248044, mask=248070 (offset ~26) + # Try common offsets + for offset in [26, 25, 24]: + candidate = eos + offset + if candidate < vocab_size: + return candidate + # Fallback for smaller Qwen models + if vocab_size > 150000: + return vocab_size - 250 # heuristic for Qwen special token region + + # Llama3: use reserved_special_token_0 (128002) + if "llama" in model_type.lower(): + if vocab_size >= 128256: # Llama3 vocab size + return 128002 # <|reserved_special_token_0|> + + # Generic: try pad_token_id, then eos + pad_id = getattr(base_config, "pad_token_id", None) + eos_id = getattr(base_config, "eos_token_id", None) + if isinstance(eos_id, list): + eos_id = eos_id[0] + + # Prefer pad over eos (pad is less likely to interfere) + if pad_id is not None and pad_id != eos_id: + return pad_id + + # Last resort + return eos_id or 0 + def _find_base_model_parts(self): for name, paths in { "base_model_path": ["model.language_model", "model", "backbone"], @@ -319,13 +369,18 @@ def modify(self, config): self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers) self.dflash_config.target_layer_ids = self.target_layer_ids - # mask_token_id: prefer pad_token_id, fallback to eos - # mask_token_id: prefer from dflash_architecture_config, fallback to pad/eos - mask_id = config.dflash_architecture_config.get( - "mask_token_id", - getattr(base_config, "pad_token_id", None) or getattr(base_config, "eos_token_id", 0), - ) + # mask_token_id resolution order: + # 1. Explicit in dflash_architecture_config (user override) + # 2. Auto-detect from model vocabulary: + # - Qwen3/3.5: built-in [MASK] token + # - Llama3: reserved_special_token_0 (128002) + # - Others: tokenizer.mask_token_id + # 3. Fallback to pad_token_id or eos_token_id (suboptimal) + mask_id = config.dflash_architecture_config.get("mask_token_id", None) + if mask_id is None: + mask_id = self._auto_detect_mask_token_id(base_config) self.mask_token_id = mask_id[0] if isinstance(mask_id, list) else mask_id + print(f"DFlash mask_token_id: {self.mask_token_id}") # Freeze base model if self.dflash_freeze_base_model: From 972dfaaec7cb5b661d9e8310288eb5145f3d3a27 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Sat, 28 Mar 2026 21:19:24 -0700 Subject: [PATCH 05/72] fix: prevent DDP deadlock during AR validation AR validation runs pseudo_speculative_generate which does unsynchronized model forward passes. In multi-GPU DDP training, this caused NCCL timeout because other ranks were waiting at gradient sync. Fix: only run validate_ar on rank 0 (is_master()), add torch.distributed.barrier() after to synchronize all ranks. Signed-off-by: Chenhan Yu --- examples/speculative_decoding/eagle_utils.py | 36 ++++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 1bc7df9810..ed07c4d76a 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -235,23 +235,31 @@ def on_log(self, args, state, control, **kwargs): return control def on_step_end(self, args, state, control, **kwargs): - """Run AR validation periodically, if available.""" + """Run AR validation periodically, if available. + + Only runs on rank 0 to avoid DDP deadlock — other ranks skip and + synchronize via barrier. + """ if self.ar_validate_steps <= 0: return control if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0: - print_rank_0("Running AR validation...") - try: - ars = validate_ar( - model=kwargs["model"], - tokenizer=kwargs["processing_class"], - ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], - device=kwargs["model"].device, - ) - print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") - if wandb and is_master(): - wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) - except Exception: - print_rank_0("AR validation not available.") + if is_master(): + print_rank_0("Running AR validation...") + try: + ars = validate_ar( + model=kwargs["model"], + tokenizer=kwargs["processing_class"], + ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], + device=kwargs["model"].device, + ) + print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") + if wandb: + wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) + except Exception: + print_rank_0("AR validation not available.") + # Barrier to synchronize all ranks after validation + if torch.distributed.is_initialized(): + torch.distributed.barrier() return control From 6c4eb80d54e8c69c40aa031b27365a701f9f7864 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Sun, 29 Mar 2026 16:42:27 -0700 Subject: [PATCH 06/72] fix: avoid DynamicModule dispatch loop in forward/training paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit super().forward() from HFDFlashModel goes through DynamicModule which dispatches back to HFDFlashModel.forward(), causing infinite recursion → stack overflow → NCCL timeout in multi-GPU training. Fix: use self._base_model() directly (same as pseudo_speculative_generate) for both eval-mode and training base model forward passes. Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index c8ecd5eee7..1d538d1dde 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -438,31 +438,43 @@ def forward( ): """Training forward matching SpecForge OnlineDFlashModel.forward.""" if not self.training: - return super().forward( + # Call base model directly to avoid DynamicModule dispatch loop + model_output = self._base_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, **kwargs, ) + logits = self._base_model_lm_head(model_output.last_hidden_state) + return ModelOutput( + logits=logits, + past_key_values=getattr(model_output, "past_key_values", None), + hidden_states=getattr(model_output, "hidden_states", None), + attentions=getattr(model_output, "attentions", None), + ) bsz, seq_len = input_ids.shape block_size = self.dflash_block_size device = input_ids.device # 1. Run base model → raw multi-layer hidden states + # Use self._base_model directly to avoid DynamicModule dispatch loop with torch.no_grad(): - base_outputs = super().forward( + model_output = self._base_model( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, ) + base_outputs = ModelOutput( + logits=self._base_model_lm_head(model_output.last_hidden_state), + hidden_states=model_output.hidden_states, + ) # Extract and concatenate target layer hidden states offset = 1 From 2c4236383dbc6713e99b750a8a0e1436c4dfcbd0 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Sun, 29 Mar 2026 18:15:40 -0700 Subject: [PATCH 07/72] fix: revert training/eval to super().forward() matching EAGLE pattern MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The DynamicModule MRO correctly dispatches super().forward() to the original model class (e.g., Qwen3_5ForCausalLM.forward()) without looping — same pattern EAGLE uses successfully. The previous self._base_model() approach bypassed DDP, causing NCCL timeout because DDP's gradient sync couldn't track the forward pass. Keep pseudo_speculative_generate using self._base_model() since that runs outside DDP (single GPU AR validation). Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 20 +++++-------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 1d538d1dde..27483bd2b8 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -438,43 +438,33 @@ def forward( ): """Training forward matching SpecForge OnlineDFlashModel.forward.""" if not self.training: - # Call base model directly to avoid DynamicModule dispatch loop - model_output = self._base_model( + return super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, **kwargs, ) - logits = self._base_model_lm_head(model_output.last_hidden_state) - return ModelOutput( - logits=logits, - past_key_values=getattr(model_output, "past_key_values", None), - hidden_states=getattr(model_output, "hidden_states", None), - attentions=getattr(model_output, "attentions", None), - ) bsz, seq_len = input_ids.shape block_size = self.dflash_block_size device = input_ids.device # 1. Run base model → raw multi-layer hidden states - # Use self._base_model directly to avoid DynamicModule dispatch loop + # Use super().forward() which goes through DynamicModule → original model + # (same pattern as EAGLE's HFEagleModel) with torch.no_grad(): - model_output = self._base_model( + base_outputs = super().forward( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, ) - base_outputs = ModelOutput( - logits=self._base_model_lm_head(model_output.last_hidden_state), - hidden_states=model_output.hidden_states, - ) # Extract and concatenate target layer hidden states offset = 1 From a2799601625cbb39ca1cbe3eeace583eefba96cc Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Sun, 29 Mar 2026 20:06:12 -0700 Subject: [PATCH 08/72] fix: DDP deadlock when no valid loss positions on a rank MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a rank's batch has no valid loss positions (e.g., all tokens in Block 0 which is excluded), the loss was a detached zero tensor with no connection to dflash_module parameters. DDP waited forever for gradient sync on those parameters → NCCL ALLREDUCE timeout. Fix: use logits.sum() * 0.0 as zero loss, which maintains the computation graph through dflash_module parameters so DDP can sync zero gradients properly. Also revert to super().forward() for training (matching EAGLE pattern) and add --ddp_find_unused_parameters True, --ddp_timeout 300. Root cause analysis: rank 4 completed ALLREDUCE #272 and proceeded to ALLGATHER #273, while other ranks were stuck at ALLREDUCE #272. This indicated rank 4 had a different backward graph (no gradients for dflash_module on that rank). Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 27483bd2b8..43faae747b 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -523,7 +523,9 @@ def forward( preds = active_logits.argmax(dim=-1) accuracy = (preds == active_labels).float().mean().item() else: - loss = torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True) + # No valid positions — compute a zero loss that still flows through + # dflash_module parameters to keep DDP gradient sync happy + loss = logits.sum() * 0.0 accuracy = 0.0 return ModelOutput( From cbddc307370e80d1df71aa3101435f79112e724b Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Mon, 30 Mar 2026 11:55:34 -0700 Subject: [PATCH 09/72] add: logit distillation option for DFlash training Add --dflash_use_logit_distillation flag that switches from hard CE loss (predict ground truth tokens) to logit distillation (learn from target model's output distribution). Hard CE only works when training data is synthesized by the target model itself. Logit distillation works with any data because it learns from the target model's actual predictions, not the ground truth. Usage: python main.py --mode dflash --dflash_use_logit_distillation ... Config: dflash_self_logit_distillation (default=True in config, toggled via CLI flag) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/main.py | 10 +++++++++- modelopt/torch/speculative/plugins/hf_dflash.py | 17 +++++++++++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 6f439494b9..61245104bb 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -159,7 +159,14 @@ class DFlashArguments: ) dflash_mask_token_id: int = field( default=None, - metadata={"help": "Mask token ID for DFlash. If not set, uses pad_token_id."}, + metadata={"help": "Mask token ID for DFlash. If not set, auto-detected from model."}, + ) + dflash_use_logit_distillation: bool = field( + default=False, + metadata={ + "help": "Use logit distillation (KD from target model) instead of hard CE. " + "Enables training with data not synthesized by the target model." + }, ) @@ -267,6 +274,7 @@ def train(): config = { "dflash_block_size": dflash_args.dflash_block_size, "dflash_use_torch_compile": not dflash_args.dflash_disable_torch_compile, + "dflash_self_logit_distillation": dflash_args.dflash_use_logit_distillation, "dflash_architecture_config": custom_config, } diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 43faae747b..3fc165421c 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -504,7 +504,7 @@ def forward( attention_mask=dflash_attn_mask, ) - # 7. Loss: hard CE, position i predicts token i + # 7. Loss computation logits = self._base_model_lm_head(hidden) dflash_loss_mask = create_dflash_loss_mask(effective_len, block_size, device) combined_mask = loss_mask_input * dflash_loss_mask.unsqueeze(0) @@ -518,7 +518,20 @@ def forward( active_labels = labels_flat[active_indices] if active_logits.numel() > 0: - loss = F.cross_entropy(active_logits, active_labels) + if self.dflash_self_logit_distillation: + # Logit distillation: learn from target model's output distribution + # This works regardless of whether training data matches the target model + base_logits_trunc = base_outputs.logits[:, :effective_len, :] + base_logits_flat = base_logits_trunc.reshape(-1, base_logits_trunc.size(-1)) + active_base_logits = base_logits_flat[active_indices].detach() + target_soft = torch.softmax(active_base_logits, dim=-1) + draft_logsoft = torch.log_softmax(active_logits, dim=-1) + loss = -(target_soft * draft_logsoft).sum(dim=-1).mean() + else: + # Hard CE: predict ground truth tokens directly + # Only works well when training data is synthesized by the target model + loss = F.cross_entropy(active_logits, active_labels) + with torch.no_grad(): preds = active_logits.argmax(dim=-1) accuracy = (preds == active_labels).float().mean().item() From c53a66a32afe884a87a948d9cba6ae20ad64a072 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Mon, 30 Mar 2026 12:43:37 -0700 Subject: [PATCH 10/72] fix: print training accuracy to console at each log step Signed-off-by: Chenhan Yu --- examples/speculative_decoding/eagle_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index ed07c4d76a..8908d581a9 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -203,6 +203,12 @@ def on_log(self, args, state, control, **kwargs): if not hasattr(state, "training_accs") or len(state.training_accs) == 0: return control average_acc = np.mean(state.training_accs, axis=0) + # Always print accuracy to console + try: + acc_str = ", ".join(f"{a:.4f}" for a in np.array(average_acc).flatten()) + print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]") + except Exception: + print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}") if self.estimate_ar: # Calculate mean training AR since last log # NOTE: This is only an estimate of the real AR. From 2eabf578817e8b204b590c6844198b260bb3b556 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Mon, 30 Mar 2026 17:19:47 -0700 Subject: [PATCH 11/72] fix: use response-only loss mask for DFlash training Pass answer_only_loss=True to LanguageDataCollator for DFlash mode. This makes the tokenizer return assistant_masks via apply_chat_template with return_assistant_tokens_mask=True. HFDFlashModel.forward() now checks for assistant_masks in kwargs and uses it as loss_mask instead of attention_mask. This matches SpecForge's behavior of only computing loss on response tokens. SpecForge-trained checkpoint (response-only mask): AR=1.95 ModelOpt-trained checkpoint (all tokens mask): AR=1.15 Both with 30-35% training accuracy on same data. Signed-off-by: Chenhan Yu --- examples/speculative_decoding/eagle_utils.py | 2 ++ examples/speculative_decoding/main.py | 5 ++++- modelopt/torch/speculative/plugins/hf_dflash.py | 8 +++++++- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 8908d581a9..6f1ba87bbd 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -139,6 +139,7 @@ def make_eagle_supervised_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, train_len=None, + answer_only_loss=False, ) -> dict: if data_args.offline_data_path is None: train_dataset = ShardedDataset("json", data_files=data_args.data_path) @@ -148,6 +149,7 @@ def make_eagle_supervised_data_module( tokenizer=tokenizer, train_len=train_len, return_labels=True, + answer_only_loss=answer_only_loss, ) else: data_collator = VisionLanguageDataCollator( diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 61245104bb..812f5d8cce 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -285,7 +285,10 @@ def train(): print_rank_0("Loading dataset...") if training_args.mode in ("eagle3", "dflash"): data_module = make_eagle_supervised_data_module( - tokenizer, data_args, train_len=training_args.training_seq_len + tokenizer, + data_args, + train_len=training_args.training_seq_len, + answer_only_loss=(training_args.mode == "dflash"), ) trainer = EagleTrainerWithAccLog( diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 3fc165421c..6f9dbec151 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -476,7 +476,13 @@ def forward( effective_len = n_blocks * block_size input_ids_trunc = input_ids[:, :effective_len] target_hidden = target_hidden[:, :effective_len, :] - if attention_mask is not None: + # Loss mask: prefer assistant_masks (response-only) if available + # assistant_masks comes from LanguageDataCollator with answer_only_loss=True + # This matches SpecForge's loss_mask which only trains on response tokens + assistant_masks = kwargs.get("assistant_masks", None) + if assistant_masks is not None: + loss_mask_input = assistant_masks[:, :effective_len].float() + elif attention_mask is not None: loss_mask_input = attention_mask[:, :effective_len].float() else: loss_mask_input = torch.ones(bsz, effective_len, device=device) From 2a16232f262982c0a1a0ef92574f0ddafe622b72 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Mon, 30 Mar 2026 19:56:58 -0700 Subject: [PATCH 12/72] fix: apply assistant_masks to labels in LanguageDataCollator When answer_only_loss=True, set labels=-100 for non-assistant tokens using the assistant_masks from tokenizer.apply_chat_template. This ensures DFlash forward() can derive response-only loss mask from labels != -100, without relying on HF Trainer to pass assistant_masks. Also revert hf_dflash.py to use labels-based loss mask instead of kwargs-based assistant_masks (Trainer strips unknown keys). Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 12 ++++++------ modelopt/torch/utils/plugins/transformers_dataset.py | 4 ++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 6f9dbec151..f4619415e6 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -476,12 +476,12 @@ def forward( effective_len = n_blocks * block_size input_ids_trunc = input_ids[:, :effective_len] target_hidden = target_hidden[:, :effective_len, :] - # Loss mask: prefer assistant_masks (response-only) if available - # assistant_masks comes from LanguageDataCollator with answer_only_loss=True - # This matches SpecForge's loss_mask which only trains on response tokens - assistant_masks = kwargs.get("assistant_masks", None) - if assistant_masks is not None: - loss_mask_input = assistant_masks[:, :effective_len].float() + # Loss mask: use labels to identify response-only tokens + # labels has -100 (IGNORE_TOKEN_ID) for prompt/padding, valid ids for response + # This matches SpecForge's loss_mask behavior + if labels is not None: + labels_trunc = labels[:, :effective_len] + loss_mask_input = (labels_trunc != -100).float() elif attention_mask is not None: loss_mask_input = attention_mask[:, :effective_len].float() else: diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index e147ebf2c2..67c2906dfb 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -186,6 +186,10 @@ def _process_chat_sample(self, examples: list): input_ids = tokenized_examples["input_ids"] labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID) labels[..., :-1] = input_ids[..., 1:] + # When answer_only_loss=True, mask non-assistant tokens in labels + if self.answer_only_loss and "assistant_masks" in tokenized_examples: + assistant_mask = tokenized_examples["assistant_masks"] + labels[assistant_mask == 0] = IGNORE_TOKEN_ID tokenized_examples["labels"] = labels return tokenized_examples From e3b9930735634a6186dc62fe018d686c09ec8014 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Mon, 30 Mar 2026 20:35:47 -0700 Subject: [PATCH 13/72] fix: robust response-only loss mask via regex assistant span detection When answer_only_loss=True and the tokenizer's return_assistant_tokens_mask returns empty/unsupported results, fall back to regex-based detection of assistant spans in the formatted text (similar to SpecForge's approach). Supports Qwen/ChatML, Llama3, Llama2, and generic assistant patterns. Uses tokenizer offset_mapping to map character spans to token positions. DFlash forward uses labels != -100 to derive the response-only loss mask. Signed-off-by: Chenhan Yu --- examples/speculative_decoding/main.py | 5 +- .../torch/speculative/plugins/hf_dflash.py | 9 +-- .../utils/plugins/transformers_dataset.py | 79 ++++++++++++++++++- 3 files changed, 78 insertions(+), 15 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 812f5d8cce..61245104bb 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -285,10 +285,7 @@ def train(): print_rank_0("Loading dataset...") if training_args.mode in ("eagle3", "dflash"): data_module = make_eagle_supervised_data_module( - tokenizer, - data_args, - train_len=training_args.training_seq_len, - answer_only_loss=(training_args.mode == "dflash"), + tokenizer, data_args, train_len=training_args.training_seq_len ) trainer = EagleTrainerWithAccLog( diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index f4619415e6..4172e23f7e 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -476,13 +476,8 @@ def forward( effective_len = n_blocks * block_size input_ids_trunc = input_ids[:, :effective_len] target_hidden = target_hidden[:, :effective_len, :] - # Loss mask: use labels to identify response-only tokens - # labels has -100 (IGNORE_TOKEN_ID) for prompt/padding, valid ids for response - # This matches SpecForge's loss_mask behavior - if labels is not None: - labels_trunc = labels[:, :effective_len] - loss_mask_input = (labels_trunc != -100).float() - elif attention_mask is not None: + # Loss mask + if attention_mask is not None: loss_mask_input = attention_mask[:, :effective_len].float() else: loss_mask_input = torch.ones(bsz, effective_len, device=device) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index 67c2906dfb..7aecbbfb18 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -186,13 +186,84 @@ def _process_chat_sample(self, examples: list): input_ids = tokenized_examples["input_ids"] labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID) labels[..., :-1] = input_ids[..., 1:] - # When answer_only_loss=True, mask non-assistant tokens in labels - if self.answer_only_loss and "assistant_masks" in tokenized_examples: - assistant_mask = tokenized_examples["assistant_masks"] - labels[assistant_mask == 0] = IGNORE_TOKEN_ID + if self.answer_only_loss: + # Try tokenizer's assistant_masks first + if "assistant_masks" in tokenized_examples: + assistant_mask = tokenized_examples["assistant_masks"] + if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any(): + labels[assistant_mask == 0] = IGNORE_TOKEN_ID + else: + # Fallback: derive from formatted text using regex + labels = self._apply_answer_only_labels(examples, labels, input_ids) + else: + labels = self._apply_answer_only_labels(examples, labels, input_ids) tokenized_examples["labels"] = labels return tokenized_examples + def _apply_answer_only_labels(self, examples, labels, input_ids): + """Derive response-only labels by finding assistant spans in formatted text. + + Uses regex to find assistant response spans in the chat-template-formatted text, + then maps character positions to token positions via offset mapping. + Similar to SpecForge's _apply_loss_mask_from_chat_template. + """ + import re + + for batch_idx, conversation in enumerate(examples): + # Format with chat template + formatted = self.tokenizer.apply_chat_template( + conversation, tokenize=False, add_generation_prompt=False + ) + + # Tokenize with offset mapping + try: + encoding = self.tokenizer( + formatted, + return_offsets_mapping=True, + max_length=self.train_len, + truncation=True, + add_special_tokens=False, + ) + offsets = encoding["offset_mapping"] + except Exception: + # Tokenizer doesn't support offset mapping — keep all labels + continue + + # Find assistant response spans + # Common patterns across chat templates + # Try to detect the assistant marker from the formatted text + assistant_markers = [ + r"<\|im_start\|>assistant\n(.*?)(?:<\|im_end\|>|$)", # Qwen/ChatML + r"<\|start_header_id\|>assistant<\|end_header_id\|>\n\n(.*?)(?:<\|eot_id\|>|$)", # Llama3 + r"\[/INST\](.*?)(?:|$)", # Llama2 + r"assistant\n(.*?)(?:\n\n|$)", # Generic + ] + + found = False + for pattern in assistant_markers: + matches = list(re.finditer(pattern, formatted, re.DOTALL)) + if matches: + # Mask all tokens, then unmask assistant spans + labels[batch_idx, :] = IGNORE_TOKEN_ID + for match in matches: + start_char = match.start(1) + end_char = match.end(1) + for tok_idx, (tok_start, tok_end) in enumerate(offsets): + if tok_idx >= labels.shape[1]: + break + if tok_start >= start_char and tok_end <= end_char: + # Restore the shifted label for this position + if tok_idx < input_ids.shape[1] - 1: + labels[batch_idx, tok_idx] = input_ids[batch_idx, tok_idx + 1] + found = True + break + + if not found: + # No assistant pattern found — keep all labels (don't mask) + pass + + return labels + def _process_text_sample(self, examples: list): tokenized_examples = self.tokenizer( examples, From 07066c235a6cb48c6be749aaa40f1e606993461f Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Tue, 31 Mar 2026 07:44:51 -0700 Subject: [PATCH 14/72] docs: add DFlash section to speculative decoding README Documents DFlash architecture, training usage, mask_token_id auto-detection, and current status including the known AR gap from data pipeline differences. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/README.md | 95 +++++++++++++++++++++++-- 1 file changed, 88 insertions(+), 7 deletions(-) diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 2a29f644e6..96f75f8a81 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -319,15 +319,96 @@ trainer.save_state() trainer.save_model("") ``` +## DFlash: Block Diffusion for Flash Speculative Decoding + +DFlash ([arXiv:2602.06036](https://arxiv.org/abs/2602.06036)) is a parallel speculative decoding method that predicts multiple tokens simultaneously using block diffusion. Unlike autoregressive methods (EAGLE, Medusa) that draft one token at a time, DFlash predicts an entire block of tokens in parallel, then iteratively denoises them. + +### Architecture + +DFlash uses three key mechanisms: + +- **Feature Fusion**: Multi-layer hidden states from the target model are projected via a fully-connected layer and RMSNorm to create context features +- **KV Injection**: Context features are injected as K/V in every draft decoder layer, while Q comes from the noise embeddings. QK-Norm (RMSNorm on Q and K before RoPE) stabilizes attention +- **Parallel Drafting**: Within each block of size B, unknown positions use a `mask_token_id` token. Only block-start positions get the real token. The attention mask allows noise tokens to attend to all context tokens from previous blocks, plus causally within the same block + +### Training + +```bash +./launch_train.sh --model $BASE_MODEL \ + --output_dir $OUTPUT_DIR \ + --data input_conversations/train.jsonl \ + --num_epochs $NUM_EPOCH \ + --mode dflash \ + --dflash_block_size 16 \ + --dflash_num_layers 5 +``` + +Key arguments: + +| Flag | Default | Description | +|------|---------|-------------| +| `--mode dflash` | - | Enable DFlash mode | +| `--dflash_block_size` | 16 | Block size for parallel prediction | +| `--dflash_num_layers` | 5 | Number of decoder layers in draft module | +| `--dflash_config` | None | Path to JSON config for custom architecture | +| `--dflash_mask_token_id` | auto | Mask token ID (auto-detected from model) | +| `--dflash_disable_torch_compile` | False | Disable torch.compile | +| `--dflash_use_logit_distillation` | False | Use KD from target model logits instead of hard CE | + +### mask_token_id + +The `mask_token_id` is critical for DFlash training and inference. It must be consistent between training and deployment. Auto-detection logic: + +| Model Family | mask_token_id | Source | +|-------------|---------------|--------| +| Qwen3.5 | 248070 | Built-in `[MASK]` token | +| Qwen3 (8B) | 151643 | `eos_token_id` | +| Llama 3 | 128002 | `reserved_special_token_0` | +| Others | `pad_token_id` | Fallback | + +Override with `--dflash_mask_token_id ` if auto-detection is incorrect. + +### Configuring Draft Model + +Similar to EAGLE, provide a JSON config to customize the draft architecture: + +```json +{ + "num_hidden_layers": 5, + "rms_norm_eps": 1e-6 +} +``` + +Model dimensions (hidden_size, num_attention_heads, etc.) are automatically inherited from the base model. + +### Current Status (WIP) + +| Feature | Status | +|---------|--------| +| Architecture (Feature Fusion, KV Injection, Parallel Drafting) | Working | +| Online training with HF Trainer | Working | +| Inference / AR validation (`pseudo_speculative_generate`) | Working | +| z-lab checkpoint loading and inference (AR 7-9) | Working | +| Logit distillation option | Working | +| Response-only loss masking | Working | +| DDP training | Working (with `find_unused_parameters=True`) | + +**Known gap**: Training with ModelOpt achieves ~35% per-token accuracy (matching SpecForge's ~30%), but acceptance rate (AR) is lower than SpecForge-trained checkpoints (1.15 vs 1.95). Investigation shows the **data pipeline** differs significantly: + +- SpecForge uses its own tokenizer template with system prompt and response-only loss mask +- ModelOpt's `LanguageDataCollator` uses `apply_chat_template` with different formatting + +Aligning the data pipeline is the next step to close the AR gap. + ## Support Matrix -| Model | Medusa | EAGLE1/2 | EAGLE3 | -| :---: | :---: | :---: | :---: | -| LLAMA 2 | ✅ | ✅ | ✅ | -| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | -| Mistral | ✅ | ✅ | ✅ | -| Phi 3 | ✅ | ✅ | ✅ | -| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ | +| Model | Medusa | EAGLE1/2 | EAGLE3 | DFlash | +| :---: | :---: | :---: | :---: | :---: | +| LLAMA 2 | ✅ | ✅ | ✅ | ✅ | +| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | ✅ | +| Mistral | ✅ | ✅ | ✅ | ✅ | +| Phi 3 | ✅ | ✅ | ✅ | ✅ | +| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ | ✅ | ## Speculation Module Checkpoints From a32de63c4f1b67168e917e0e0302eb12dc43f403 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Tue, 31 Mar 2026 08:56:45 -0700 Subject: [PATCH 15/72] fix: resolve DFlash components from base model architecture Instead of hardcoding Llama components (LlamaMLP, LlamaRMSNorm, LlamaRotaryEmbedding), dynamically resolve them from the base model's transformers module (e.g., Qwen3MLP for Qwen3 models). Falls back to Llama components for unknown model types. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 120 +++++++++++++----- 1 file changed, 87 insertions(+), 33 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 4172e23f7e..5affbb7a2d 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -26,21 +26,63 @@ Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) """ +import importlib + import torch import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig, PreTrainedModel -from transformers.models.llama.modeling_llama import ( - LlamaMLP, - LlamaRMSNorm, - LlamaRotaryEmbedding, - rotate_half, -) from transformers.utils import ModelOutput from ..dflash.conversion import DFlashDMRegistry from ..dflash.dflash_model import DFlashModel + +def _resolve_model_components(model_type): + """Resolve MLP, RMSNorm, RotaryEmbedding from the base model's transformers module. + + Falls back to Llama components if the model type is unknown. + """ + fallback = "llama" + model_type = model_type or fallback + try: + mod = importlib.import_module(f"transformers.models.{model_type}.modeling_{model_type}") + except (ImportError, ModuleNotFoundError): + mod = importlib.import_module(f"transformers.models.{fallback}.modeling_{fallback}") + model_type = fallback + + prefix = model_type.capitalize() + # Handle multi-word model types (e.g., "qwen3" -> "Qwen3") + for attr in dir(mod): + if attr.lower() == f"{model_type}mlp": + prefix = attr.replace("MLP", "") + break + + mlp_cls = getattr(mod, f"{prefix}MLP", None) + norm_cls = getattr(mod, f"{prefix}RMSNorm", None) + rotary_cls = getattr(mod, f"{prefix}RotaryEmbedding", None) + rotate_half_fn = getattr(mod, "rotate_half", None) + + # Fallback to Llama if any component is missing + if not all([mlp_cls, norm_cls, rotary_cls, rotate_half_fn]): + from transformers.models.llama.modeling_llama import ( + LlamaMLP, + LlamaRMSNorm, + LlamaRotaryEmbedding, + ) + from transformers.models.llama.modeling_llama import rotate_half as _rotate_half + + mlp_cls = mlp_cls or LlamaMLP + norm_cls = norm_cls or LlamaRMSNorm + rotary_cls = rotary_cls or LlamaRotaryEmbedding + rotate_half_fn = rotate_half_fn or _rotate_half + + return mlp_cls, norm_cls, rotary_cls, rotate_half_fn + + +# Default to Llama components; overridden per-model during convert() +_MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half = _resolve_model_components("llama") + __all__ = ["HFDFlashModel"] @@ -51,9 +93,7 @@ def build_target_layer_ids(num_target_layers, num_draft_layers): start = 1 end = num_target_layers - 3 span = end - start - return [ - int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers) - ] + return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)] def apply_rotary_pos_emb(q, k, cos, sin): @@ -61,8 +101,8 @@ def apply_rotary_pos_emb(q, k, cos, sin): cos = cos.unsqueeze(1) # [B, 1, seq, dim] sin = sin.unsqueeze(1) q_len = q.size(2) - q_embed = (q * cos[:, :, -q_len:, :]) + (rotate_half(q) * sin[:, :, -q_len:, :]) - k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = (q * cos[:, :, -q_len:, :]) + (_rotate_half(q) * sin[:, :, -q_len:, :]) + k_embed = (k * cos) + (_rotate_half(k) * sin) return q_embed, k_embed @@ -87,8 +127,8 @@ def __init__(self, config, layer_idx): self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) # QK norm (matches Qwen3DFlashAttention) - self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): bsz, q_len, _ = hidden_states.shape @@ -138,9 +178,9 @@ class DFlashDecoderLayer(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.self_attn = DFlashAttention(config, layer_idx) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = _MLP_CLS(config) + self.input_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): residual = hidden_states @@ -168,14 +208,14 @@ def __init__(self, config): # Feature fusion num_fused_layers = len(config.target_layer_ids) self.fc = nn.Linear(num_fused_layers * config.hidden_size, config.hidden_size, bias=False) - self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hidden_norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) # Decoder layers self.layers = nn.ModuleList( [DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = _ROTARY_CLS(config=config) def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None): """Forward matching SpecForge DFlashDraftModel.forward.""" @@ -356,7 +396,8 @@ def modify(self, config): self.dflash_config.mlp_bias = False self.dflash_config.head_dim = getattr( - self.dflash_config, "head_dim", + self.dflash_config, + "head_dim", self.dflash_config.hidden_size // self.dflash_config.num_attention_heads, ) self.dflash_config.block_size = self.dflash_block_size @@ -389,6 +430,14 @@ def modify(self, config): self._find_base_model_parts() + # Resolve model-specific components (MLP, RMSNorm, RotaryEmbedding) + # from the base model's architecture for weight compatibility + global _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half + _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half = _resolve_model_components( + getattr(base_config, "model_type", "llama") + ) + print(f"DFlash: using {_MLP_CLS.__name__} from {base_config.model_type}") + self.dflash_module = DFlashModule(self.dflash_config) self.dflash_module.to(self._base_model.dtype).to( next(self._base_model.layers[-1].parameters()).device @@ -400,7 +449,13 @@ def modify(self, config): # DynamicModule changes type(self) but the original class is in _original_cls. # Find the original HF model class (e.g., Qwen3_5ForConditionalGeneration) # by walking MRO and skipping DFlash/DynamicModule classes - skip_names = {"HFDFlashModel", "DFlashModel", "DynamicModule", "DFlashPreTrainedModel", "DFlashDraftModel"} + skip_names = { + "HFDFlashModel", + "DFlashModel", + "DynamicModule", + "DFlashPreTrainedModel", + "DFlashDraftModel", + } original_cls = None for cls in type(self).__mro__: if ( @@ -580,18 +635,19 @@ def pseudo_speculative_generate(self, input_ids, steps=1): # Extract target hidden states (raw, before FC projection) hid_offset = 1 - if not hasattr(self, '_psg_debug'): + if not hasattr(self, "_psg_debug"): self._psg_debug = True sel = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] - th = torch.cat(sel, dim=-1) - print(f"[psg] hidden_states layers: {len(base_outputs.hidden_states)}, target_hidden norm: {th.norm().item():.2f}, shape: {th.shape}") + th_dbg = torch.cat(sel, dim=-1) + n_layers = len(base_outputs.hidden_states) + th_norm = th_dbg.norm().item() + print( + f"[psg] hidden layers: {n_layers}, target_hidden: {th_dbg.shape}, norm: {th_norm:.2f}" + ) print(f"[psg] base_token: {base_token.item()}, mask_token_id: {self.mask_token_id}") - print(f"[psg] block_ids: {[self.mask_token_id]*self.dflash_block_size}") - bi = torch.full((1, self.dflash_block_size), self.mask_token_id, dtype=torch.long, device=input_ids.device) - bi[0, 0] = base_token[0, 0] - ne = self._base_model_embeddings(bi) - print(f"[psg] noise_emb norm: {ne.norm().item():.2f}, shape: {ne.shape}") - print(f"[psg] pos_ids will be: ctx=[0..{input_ids.shape[1]-1}], blk=[{input_ids.shape[1]}..{input_ids.shape[1]+self.dflash_block_size-1}]") + seq_len = input_ids.shape[1] + blk = self.dflash_block_size + print(f"[psg] pos: ctx=[0..{seq_len - 1}], blk=[{seq_len}..{seq_len + blk - 1}]") selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] target_hidden = torch.cat(selected, dim=-1) @@ -615,9 +671,7 @@ def pseudo_speculative_generate(self, input_ids, steps=1): pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) # Attention mask: block sees ALL context + causal within block - attn_mask = torch.zeros( - 1, 1, block_size, ctx_len + block_size, device=device, dtype=dtype - ) + attn_mask = torch.zeros(1, 1, block_size, ctx_len + block_size, device=device, dtype=dtype) block_causal = torch.triu( torch.full( (block_size, block_size), torch.finfo(dtype).min, device=device, dtype=dtype From 6a6a9cafbc9e62ee51a49b42d7bbb921d1e4699e Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Tue, 31 Mar 2026 09:08:52 -0700 Subject: [PATCH 16/72] fix: enable response-only loss mask for DFlash training Two bugs prevented response-only masking from working: 1. main.py never passed answer_only_loss=True to the data collator for DFlash mode, so all tokens had labels (511/512 instead of response-only). 2. HFDFlashModel.forward() used attention_mask (padding mask) for loss masking instead of labels. When answer_only_loss is enabled, the response-only information is in labels (where -100 = ignore), but this was completely ignored. Now uses labels when available. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/main.py | 5 ++++- modelopt/torch/speculative/plugins/hf_dflash.py | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 61245104bb..812f5d8cce 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -285,7 +285,10 @@ def train(): print_rank_0("Loading dataset...") if training_args.mode in ("eagle3", "dflash"): data_module = make_eagle_supervised_data_module( - tokenizer, data_args, train_len=training_args.training_seq_len + tokenizer, + data_args, + train_len=training_args.training_seq_len, + answer_only_loss=(training_args.mode == "dflash"), ) trainer = EagleTrainerWithAccLog( diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 5affbb7a2d..61e930137d 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -531,8 +531,11 @@ def forward( effective_len = n_blocks * block_size input_ids_trunc = input_ids[:, :effective_len] target_hidden = target_hidden[:, :effective_len, :] - # Loss mask - if attention_mask is not None: + # Loss mask: use labels (response-only) if available, else attention_mask (padding) + if labels is not None: + # labels == -100 means "ignore" (system/user tokens when answer_only_loss=True) + loss_mask_input = (labels[:, :effective_len] != -100).float() + elif attention_mask is not None: loss_mask_input = attention_mask[:, :effective_len].float() else: loss_mask_input = torch.ones(bsz, effective_len, device=device) From a7778494d81be8f00aa1d438f7793a1958b605c0 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Tue, 31 Mar 2026 17:08:32 -0700 Subject: [PATCH 17/72] add: DFlash launcher example for Qwen3-8B - Add common/dflash/online_training.sh for launcher - Add examples/Qwen/Qwen3-8B/hf_online_dflash.yaml - Add --mode dflash support to launch_train.sh with DFlash-specific args (block_size, num_layers, mask_token_id, config) - DFlash uses DDP instead of FSDP for training Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/launch_train.sh | 44 +++++++++++-- .../launcher/common/dflash/online_training.sh | 42 ++++++++++++ .../Qwen/Qwen3-8B/hf_online_dflash.yaml | 66 +++++++++++++++++++ 3 files changed, 145 insertions(+), 7 deletions(-) create mode 100644 tools/launcher/common/dflash/online_training.sh create mode 100644 tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index c15b97bdaa..27e17e388d 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -134,6 +134,22 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi FSDP="${1#*=}" ;; + --dflash_block_size*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_BLOCK_SIZE="${1#*=}" + ;; + --dflash_num_layers*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_NUM_LAYERS="${1#*=}" + ;; + --dflash_config*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_CONFIG="${1#*=}" + ;; + --dflash_mask_token_id*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_MASK_TOKEN_ID="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -195,8 +211,20 @@ if [[ "$MODE" == "eagle3" ]]; then else SPECULATIVE_ARGS="" fi +elif [[ "$MODE" == "dflash" ]]; then + DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16} + DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5} + SPECULATIVE_ARGS="--dflash_block_size $DFLASH_BLOCK_SIZE --dflash_num_layers $DFLASH_NUM_LAYERS --dflash_disable_torch_compile" + if [[ -n "$DFLASH_CONFIG" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_config $DFLASH_CONFIG" + fi + if [[ -n "$DFLASH_MASK_TOKEN_ID" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_mask_token_id $DFLASH_MASK_TOKEN_ID" + fi + # DFlash uses DDP instead of FSDP + FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 300" else - echo "Only eagle3 supported for now!" + echo "Unsupported mode: $MODE. Supported: eagle3, dflash" exit 1 fi @@ -218,12 +246,14 @@ else VLM_ARGS="" fi -if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then - #Use FSDP2 when multi GPU available - FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" -else - #Otherwise, single GPU training - FSDP_ARGS="" +if [[ "$MODE" != "dflash" ]]; then + if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then + #Use FSDP2 when multi GPU available + FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" + else + #Otherwise, single GPU training + FSDP_ARGS="" + fi fi diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh new file mode 100644 index 0000000000..114a15ba3c --- /dev/null +++ b/tools/launcher/common/dflash/online_training.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DFlash online training script for the ModelOpt Launcher. +# Trains a DFlash draft model alongside the frozen target model. +# +# Required env vars: +# HF_MODEL_CKPT — path to the target HuggingFace model +# +# All other args are passed through to launch_train.sh. + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh + +pip install -r modules/Model-Optimizer/examples/speculative_decoding/requirements.txt +pip install huggingface-hub>=1.2.1 +export PATH=$PATH:/workspace/.local/bin + +################################################################################################### + +trap 'error_handler $0 $LINENO' ERR + +bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ + --model ${HF_MODEL_CKPT} \ + --mode dflash \ + ${@} + +################################################################################################### diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml new file mode 100644 index 0000000000..af0dd7d64e --- /dev/null +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml @@ -0,0 +1,66 @@ +# DFlash online speculative decoding training for Qwen3-8B. +# +# Trains a DFlash draft model (block diffusion) using the frozen target model +# to extract multi-layer hidden states on the fly. +# +# Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) +# +# Usage: +# uv run launch.py --yaml examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes +# uv run slurm.py --yaml modules/Model-Optimizer/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml --yes + +job_name: Qwen3-8B_DFlash_online +pipeline: + allow_to_fail: false + skip: false + note: + + global_vars: + hf_model: /hf-local/Qwen/Qwen3-8B + data: /hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/sample-100K.jsonl + + # Step 1: Online DFlash training + task_0: + script: common/dflash/online_training.sh + args: + - --data <> + - --output_dir /scratchspace/dflash + - --num_epochs 3 + - --lr 1e-4 + - --training_seq_len 512 + - --save_steps 500000 + - --log_steps 100 + - --disable_tqdm True + - --ar_validate_steps 0 + - --dflash_block_size 16 + - --dflash_num_layers 5 + - --dflash_mask_token_id 151643 + environment: + - HF_MODEL_CKPT: <> + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + + # Step 2: Benchmark speculative decoding (VLLM backend) + task_1: + script: common/specdec_bench/quick_check.sh + args: + - --draft_model_dir /scratchspace/dflash + - --draft_length 3 + - --output_length 4096 + - --engine VLLM + - --tp_size 4 + - --ep_size 1 + - --speculative_algorithm EAGLE3 + - --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl + - --concurrency 1 + environment: + - HF_MODEL_CKPT: <> + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 4 + container: vllm/vllm-openai:latest From 2c56acabd31656905029ff1363d136a4e81bdbac Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Tue, 31 Mar 2026 17:26:53 -0700 Subject: [PATCH 18/72] fix: inline values in DFlash launcher YAML for --yaml compatibility global_vars keys conflict with nemo_run's CLI parser when using --yaml format. Inline the values directly instead. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../examples/Qwen/Qwen3-8B/hf_online_dflash.yaml | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml index af0dd7d64e..f340ed6c9a 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml @@ -11,19 +11,11 @@ job_name: Qwen3-8B_DFlash_online pipeline: - allow_to_fail: false - skip: false - note: - - global_vars: - hf_model: /hf-local/Qwen/Qwen3-8B - data: /hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/sample-100K.jsonl - # Step 1: Online DFlash training task_0: script: common/dflash/online_training.sh args: - - --data <> + - --data /hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/sample-100K.jsonl - --output_dir /scratchspace/dflash - --num_epochs 3 - --lr 1e-4 @@ -36,7 +28,7 @@ pipeline: - --dflash_num_layers 5 - --dflash_mask_token_id 151643 environment: - - HF_MODEL_CKPT: <> + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B slurm_config: _factory_: "slurm_factory" nodes: 1 @@ -57,7 +49,7 @@ pipeline: - --mtbench /hf-local/HuggingFaceH4/mt_bench_prompts/raw/question.jsonl - --concurrency 1 environment: - - HF_MODEL_CKPT: <> + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B slurm_config: _factory_: "slurm_factory" nodes: 1 From 306fc3e586566a7dfe4a7ace5082688cbd5319bc Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Tue, 31 Mar 2026 17:36:36 -0700 Subject: [PATCH 19/72] add: unit tests for DFlash speculative decoding Tests cover: - Model conversion (creates HFDFlashModel, DFlashModule, freezes base, sets target_layer_ids and mask_token_id) - Save/restore via HuggingFace checkpointing - Attention mask (shape, strictly-previous-block context, causal noise) - Loss mask (excludes block 0 and block starts, correct count) - Draft module forward (output shape, determinism) - Training forward (loss, accuracy, labels masking, all-masked edge case, gradient flow, eval mode fallback) - Target layer ID selection (single/multiple layers, spread, bounds) Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../speculative/plugins/test_hf_dflash.py | 367 ++++++++++++++++++ 1 file changed, 367 insertions(+) create mode 100644 tests/unit/torch/speculative/plugins/test_hf_dflash.py diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py new file mode 100644 index 0000000000..bb7731a8a3 --- /dev/null +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -0,0 +1,367 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for DFlash speculative decoding plugin.""" + +import os +from copy import deepcopy + +import pytest +import torch +from _test_utils.torch.transformers_models import ( + get_tiny_llama, + tf_modelopt_state_and_output_tester, +) +from transformers import AutoModelForCausalLM + +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG +from modelopt.torch.speculative.plugins.hf_dflash import ( + DFlashModule, + HFDFlashModel, + create_dflash_attention_mask, + create_dflash_loss_mask, +) + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be multiple of BLOCK_SIZE + + +def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "mask_token_id": 0, # use token 0 as mask for tiny model + } + return config + + +class TestDFlashConvert: + """Test DFlash model conversion.""" + + def test_convert_creates_dflash_model(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert isinstance(model, HFDFlashModel) + + def test_convert_creates_dflash_module(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "dflash_module") + assert isinstance(model.dflash_module, DFlashModule) + + def test_convert_freezes_base_model(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + # Base model params should be frozen + for name, param in model.named_parameters(): + if "dflash_module" not in name: + assert not param.requires_grad, f"Base param {name} should be frozen" + + def test_convert_dflash_module_trainable(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + # DFlash module params should be trainable + dflash_params = [(n, p) for n, p in model.named_parameters() if "dflash_module" in n] + assert len(dflash_params) > 0 + for name, param in dflash_params: + assert param.requires_grad, f"DFlash param {name} should be trainable" + + def test_convert_sets_target_layer_ids(self): + model = get_tiny_llama(num_hidden_layers=8) + config = _get_dflash_config(num_layers=3) + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "target_layer_ids") + assert len(model.target_layer_ids) == 3 + # Layer IDs should be within target model range + for lid in model.target_layer_ids: + assert 0 <= lid < 8 + + def test_convert_sets_mask_token_id(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + assert hasattr(model, "mask_token_id") + assert model.mask_token_id == 0 + + +class TestDFlashSaveRestore: + """Test DFlash model save and restore.""" + + def test_save_and_restore(self, tmp_path): + mto.enable_huggingface_checkpointing() + model_ref = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model_ref, [("dflash", config)]) + + model_ref.save_pretrained(tmp_path / "modelopt_model") + assert os.path.exists(tmp_path / "modelopt_model/modelopt_state.pth") + + model_test = AutoModelForCausalLM.from_pretrained(tmp_path / "modelopt_model") + assert isinstance(model_test, HFDFlashModel) + tf_modelopt_state_and_output_tester(model_ref, model_test) + + +class TestDFlashAttentionMask: + """Test DFlash attention mask construction.""" + + def test_mask_shape(self): + mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) + assert mask.shape == (1, 1, SEQ_LEN, 2 * SEQ_LEN) + + def test_mask_context_strictly_previous_blocks(self): + """Context (left half): block B can only see blocks 0..B-1.""" + mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) + mask_2d = mask[0, 0] # [8, 16] + ctx_mask = mask_2d[:, :8] # context part + + # Block 0 (rows 0-3) should NOT see any context + assert (ctx_mask[:4, :] < 0).all() + + # Block 1 (rows 4-7) should see block 0 context only + assert (ctx_mask[4:8, :4] == 0).all() # can see block 0 + assert (ctx_mask[4:8, 4:8] < 0).all() # cannot see own block + + def test_mask_noise_causal_within_block(self): + """Noise (right half): causal within same block, blocked across blocks.""" + mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) + mask_2d = mask[0, 0] + noise_mask = mask_2d[:, 8:] # noise part + + # Block 0, position 0: can only see position 0 + assert noise_mask[0, 0] == 0 + assert (noise_mask[0, 1:4] < 0).all() + + # Block 0, position 3: can see positions 0-3 + assert (noise_mask[3, :4] == 0).all() + + # Block 1 cannot see block 0 noise + assert (noise_mask[4:8, :4] < 0).all() + + def test_mask_values_are_zero_or_neg_inf(self): + mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) + unique_vals = mask.unique() + assert len(unique_vals) == 2 + assert 0.0 in unique_vals + assert unique_vals.min() == torch.finfo(torch.float32).min + + +class TestDFlashLossMask: + """Test DFlash loss mask construction.""" + + def test_loss_mask_shape(self): + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + assert mask.shape == (SEQ_LEN,) + + def test_loss_mask_excludes_block_zero(self): + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + # All positions in block 0 should be masked out + assert (mask[:BLOCK_SIZE] == 0).all() + + def test_loss_mask_excludes_block_starts(self): + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + # Block start positions (every BLOCK_SIZE) should be masked + for i in range(0, SEQ_LEN, BLOCK_SIZE): + assert mask[i] == 0, f"Block start position {i} should be masked" + + def test_loss_mask_includes_non_start_positions(self): + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + # Non-start positions in non-zero blocks should be included + for b in range(1, SEQ_LEN // BLOCK_SIZE): + for offset in range(1, BLOCK_SIZE): + pos = b * BLOCK_SIZE + offset + assert mask[pos] == 1, f"Position {pos} should be in loss" + + def test_loss_mask_count(self): + mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") + num_blocks = SEQ_LEN // BLOCK_SIZE + # Block 0 excluded entirely (BLOCK_SIZE positions) + # Each remaining block excludes 1 start position + expected = (num_blocks - 1) * (BLOCK_SIZE - 1) + assert mask.sum().item() == expected + + +class TestDFlashModule: + """Test DFlash draft module forward pass.""" + + @pytest.fixture + def model_and_config(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + return model + + def test_dflash_module_forward_shape(self, model_and_config): + model = model_and_config + bsz = 2 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + # Create inputs matching training forward + target_hidden = torch.randn(bsz, SEQ_LEN, num_layers * hidden_size) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size) + pos_ids = ( + torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]).unsqueeze(0).expand(bsz, -1) + ) + + output = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=None, + ) + assert output.shape == (bsz, SEQ_LEN, hidden_size) + + def test_dflash_module_deterministic(self, model_and_config): + model = model_and_config + model.eval() + bsz = 1 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + target_hidden = torch.randn(bsz, SEQ_LEN, num_layers * hidden_size) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size) + pos_ids = torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]).unsqueeze(0) + + with torch.no_grad(): + out1 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + out2 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + assert torch.allclose(out1, out2) + + +class TestDFlashTrainingForward: + """Test DFlash training forward pass end-to-end.""" + + @pytest.fixture + def model(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model.train() + return model + + def test_training_forward_returns_loss(self, model): + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long) + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_returns_accuracy(self, model): + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long) + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "train_acc") + + def test_training_forward_with_labels(self, model): + """Test that labels are used for response-only loss masking.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long) + + # Labels with -100 for first half (masked), real labels for second half + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long) + labels[:, SEQ_LEN // 2 :] = input_ids[:, SEQ_LEN // 2 :] + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_all_masked_labels(self, model): + """Test that all-masked labels produce zero loss without crashing.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long) + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long) + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert output.loss.item() == 0.0 + + def test_training_backward(self, model): + """Test that gradients flow to dflash_module.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long) + + output = model(input_ids=input_ids, attention_mask=attention_mask) + output.loss.backward() + + # Check dflash_module has gradients + has_grad = False + for name, param in model.dflash_module.named_parameters(): + if param.grad is not None and param.grad.abs().sum() > 0: + has_grad = True + break + assert has_grad, "DFlash module should receive gradients" + + def test_eval_forward_uses_base_model(self, model): + """In eval mode, forward should use base model (not DFlash training).""" + model.eval() + bsz = 1 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) + + with torch.no_grad(): + output = model(input_ids=input_ids) + # Should return base model output (logits over vocab) + assert output.logits.shape == (bsz, SEQ_LEN, model.config.vocab_size) + + +class TestBuildTargetLayerIds: + """Test target layer selection.""" + + def test_single_draft_layer(self): + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(32, 1) + assert len(ids) == 1 + assert ids[0] == 16 # middle layer + + def test_multiple_draft_layers(self): + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(36, 5) + assert len(ids) == 5 + # Should be monotonically increasing + assert ids == sorted(ids) + # Should be within [1, 33] for 36-layer model + assert all(1 <= lid <= 33 for lid in ids) + + def test_layer_ids_spread(self): + from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids + + ids = build_target_layer_ids(32, 5) + assert len(ids) == 5 + # No duplicates + assert len(set(ids)) == 5 From c4a3ecbaedabda4f462b81f609b10c0f976d3b93 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Tue, 31 Mar 2026 18:38:47 -0700 Subject: [PATCH 20/72] fix: add docstrings to DFlash classes for coverage check Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/dflash/dflash_model.py | 1 + .../torch/speculative/plugins/hf_dflash.py | 6 + uv.lock | 236 +++++++++++++++++- 3 files changed, 232 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py index 0e81689a57..e44b17b505 100644 --- a/modelopt/torch/speculative/dflash/dflash_model.py +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -22,6 +22,7 @@ class DFlashModel(DynamicModule): """Base DFlash Model.""" def _setup(self): + """Register temporary attributes for the DFlash module.""" self._register_temp_attribute("dflash_module", None) def modify(self, config): diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 61e930137d..8e1b275f73 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -110,6 +110,7 @@ class DFlashAttention(nn.Module): """Attention with KV injection, matching SpecForge Qwen3DFlashAttention.""" def __init__(self, config, layer_idx): + """Initialize DFlash attention with KV injection projections and QK-norm.""" super().__init__() self.config = config self.layer_idx = layer_idx @@ -131,6 +132,7 @@ def __init__(self, config, layer_idx): self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + """Forward with KV injection: Q from noise, K/V from context+noise.""" bsz, q_len, _ = hidden_states.shape ctx_len = target_hidden.shape[1] @@ -176,6 +178,7 @@ class DFlashDecoderLayer(nn.Module): """Draft decoder layer with KV injection.""" def __init__(self, config, layer_idx): + """Initialize decoder layer with attention, MLP, and layer norms.""" super().__init__() self.self_attn = DFlashAttention(config, layer_idx) self.mlp = _MLP_CLS(config) @@ -183,6 +186,7 @@ def __init__(self, config, layer_idx): self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + """Forward pass with residual connections.""" residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( @@ -201,6 +205,7 @@ class DFlashModule(nn.Module): """DFlash draft module matching SpecForge DFlashDraftModel.""" def __init__(self, config): + """Initialize DFlash module with feature fusion, decoder layers, and rotary embeddings.""" super().__init__() self.config = config self.block_size = config.block_size @@ -338,6 +343,7 @@ def _auto_detect_mask_token_id(base_config): return eos_id or 0 def _find_base_model_parts(self): + """Locate base model submodules (backbone, embeddings, lm_head) by probing known paths.""" for name, paths in { "base_model_path": ["model.language_model", "model", "backbone"], "base_model_embeddings_path": [ diff --git a/uv.lock b/uv.lock index d890e361cb..3cd6db2083 100644 --- a/uv.lock +++ b/uv.lock @@ -20,9 +20,6 @@ resolution-markers = [ "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", ] -[manifest] -overrides = [{ name = "torch", marker = "sys_platform == 'never'" }] - [[package]] name = "accelerate" version = "1.13.0" @@ -35,7 +32,7 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" } wheels = [ @@ -480,6 +477,21 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/54/27/01d9078a77b9e31b79b9716e66ca4db74f4744c5232bcb3e8769395c4280/cppimport-22.8.2.tar.gz", hash = "sha256:bbb4957102db41bc99ad72c233bce92f9d1fd91be352fc07878c4361033a401f", size = 26635, upload-time = "2022-08-02T16:50:36.872Z" } +[[package]] +name = "cuda-bindings" +version = "12.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-pathfinder", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/d8/b546104b8da3f562c1ff8ab36d130c8fe1dd6a045ced80b4f6ad74f7d4e1/cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5", size = 12148218, upload-time = "2025-10-21T14:51:28.855Z" }, + { url = "https://files.pythonhosted.org/packages/45/e7/b47792cc2d01c7e1d37c32402182524774dadd2d26339bd224e0e913832e/cuda_bindings-12.9.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c912a3d9e6b6651853eed8eed96d6800d69c08e94052c292fec3f282c5a817c9", size = 12210593, upload-time = "2025-10-21T14:51:36.574Z" }, + { url = "https://files.pythonhosted.org/packages/a9/c1/dabe88f52c3e3760d861401bb994df08f672ec893b8f7592dc91626adcf3/cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8", size = 12151019, upload-time = "2025-10-21T14:51:43.167Z" }, + { url = "https://files.pythonhosted.org/packages/63/56/e465c31dc9111be3441a9ba7df1941fe98f4aa6e71e8788a3fb4534ce24d/cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f", size = 11906628, upload-time = "2025-10-21T14:51:49.905Z" }, + { url = "https://files.pythonhosted.org/packages/a3/84/1e6be415e37478070aeeee5884c2022713c1ecc735e6d82d744de0252eee/cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb", size = 11925991, upload-time = "2025-10-21T14:51:56.535Z" }, +] + [[package]] name = "cuda-pathfinder" version = "1.4.3" @@ -554,7 +566,7 @@ dependencies = [ { name = "psutil", marker = "sys_platform != 'win32'" }, { name = "py-cpuinfo", marker = "sys_platform != 'win32'" }, { name = "pydantic", marker = "sys_platform != 'win32'" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch", marker = "sys_platform != 'win32'" }, { name = "tqdm", marker = "sys_platform != 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/19/11/46b9eb3806ca7a5e9bdddb7e873855a2d59a9f87f0675ae8231678d98434/deepspeed-0.18.8.tar.gz", hash = "sha256:e4e051a144b0c74270c46e4970139f9a86a61ff26959c5e463000c4a93b99304", size = 1647226, upload-time = "2026-03-13T18:49:48.568Z" } @@ -1311,7 +1323,9 @@ name = "networkx" version = "3.4.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'win32'", "(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version < '3.11' and sys_platform == 'darwin')", + "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'win32'", "python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368, upload-time = "2024-10-21T12:39:38.695Z" } @@ -1324,12 +1338,18 @@ name = "networkx" version = "3.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", "(python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version >= '3.13' and sys_platform == 'darwin')", "(python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version == '3.12.*' and sys_platform == 'darwin')", "(python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform != 'win32') or (python_full_version == '3.11.*' and sys_platform == 'darwin')", "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'", + "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'win32'", + "python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'win32'", ] sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ @@ -1526,6 +1546,108 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/57/a7/b35835e278c18b85206834b3aa3abe68e77a98769c59233d1f6300284781/numpy-2.4.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:4b42639cdde6d24e732ff823a3fa5b701d8acad89c4142bc1d0bd6dc85200ba5", size = 12504685, upload-time = "2026-03-09T07:58:50.525Z" }, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.9.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'win32'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, +] + [[package]] name = "nvidia-ml-py" version = "13.595.45" @@ -1554,7 +1676,7 @@ dependencies = [ { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scipy", version = "1.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "setuptools" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "tqdm" }, ] @@ -1756,11 +1878,43 @@ requires-dist = [ { name = "tox", marker = "extra == 'dev-test'", specifier = ">4.18" }, { name = "tox-current-env", marker = "extra == 'dev-test'", specifier = ">=0.0.12" }, { name = "tqdm" }, - { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.53,<5.0" }, + { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.56,<5.0" }, { name = "wonderwords", marker = "extra == 'hf'" }, ] provides-extras = ["onnx", "hf", "dev-lint", "dev-docs", "dev-test", "all", "dev"] +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/89/f7a07dc961b60645dbbf42e80f2bc85ade7feb9a491b11a1e973aa00071f/nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457", size = 322348229, upload-time = "2025-06-26T04:11:28.385Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, +] + +[[package]] +name = "nvidia-nvshmem-cu12" +version = "3.4.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/09/6ea3ea725f82e1e76684f0708bbedd871fc96da89945adeba65c3835a64c/nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd", size = 139103095, upload-time = "2025-09-06T00:32:31.266Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, +] + [[package]] name = "omegaconf" version = "2.3.0" @@ -2157,7 +2311,7 @@ dependencies = [ { name = "psutil" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "tqdm" }, { name = "transformers" }, ] @@ -3403,7 +3557,7 @@ dependencies = [ { name = "huggingface-hub" }, { name = "pyyaml" }, { name = "safetensors" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d7/2c/593109822fe735e637382aca6640c1102c19797f7791f1fd1dab2d6c3cb1/timm-1.0.25.tar.gz", hash = "sha256:47f59fc2754725735cc81bb83bcbfce5bec4ebd5d4bb9e69da57daa92fcfa768", size = 2414743, upload-time = "2026-02-23T16:49:00.137Z" } @@ -3491,15 +3645,63 @@ name = "torch" version = "2.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "cuda-bindings", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "filelock" }, { name = "fsspec" }, { name = "jinja2" }, { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "networkx", version = "3.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "setuptools", marker = "python_full_version >= '3.12'" }, { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/30/bfebdd8ec77db9a79775121789992d6b3b75ee5494971294d7b4b7c999bc/torch-2.10.0-2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2b980edd8d7c0a68c4e951ee1856334a43193f98730d97408fbd148c1a933313", size = 79411457, upload-time = "2026-02-10T21:44:59.189Z" }, + { url = "https://files.pythonhosted.org/packages/0f/8b/4b61d6e13f7108f36910df9ab4b58fd389cc2520d54d81b88660804aad99/torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f", size = 79423467, upload-time = "2026-02-10T21:44:48.711Z" }, + { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, + { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/16/ee/efbd56687be60ef9af0c9c0ebe106964c07400eade5b0af8902a1d8cd58c/torch-2.10.0-3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a1ff626b884f8c4e897c4c33782bdacdff842a165fee79817b1dd549fdda1321", size = 915510070, upload-time = "2026-03-11T14:16:39.386Z" }, + { url = "https://files.pythonhosted.org/packages/36/ab/7b562f1808d3f65414cd80a4f7d4bb00979d9355616c034c171249e1a303/torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac", size = 915518691, upload-time = "2026-03-11T14:15:43.147Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, + { url = "https://files.pythonhosted.org/packages/0c/1a/c61f36cfd446170ec27b3a4984f072fd06dab6b5d7ce27e11adb35d6c838/torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d", size = 145992962, upload-time = "2026-01-21T16:24:14.04Z" }, + { url = "https://files.pythonhosted.org/packages/b5/60/6662535354191e2d1555296045b63e4279e5a9dbad49acf55a5d38655a39/torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444", size = 915599237, upload-time = "2026-01-21T16:23:25.497Z" }, + { url = "https://files.pythonhosted.org/packages/40/b8/66bbe96f0d79be2b5c697b2e0b187ed792a15c6c4b8904613454651db848/torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb", size = 113720931, upload-time = "2026-01-21T16:24:23.743Z" }, + { url = "https://files.pythonhosted.org/packages/76/bb/d820f90e69cda6c8169b32a0c6a3ab7b17bf7990b8f2c680077c24a3c14c/torch-2.10.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:35e407430795c8d3edb07a1d711c41cc1f9eaddc8b2f1cc0a165a6767a8fb73d", size = 79411450, upload-time = "2026-01-21T16:25:30.692Z" }, + { url = "https://files.pythonhosted.org/packages/78/89/f5554b13ebd71e05c0b002f95148033e730d3f7067f67423026cc9c69410/torch-2.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3282d9febd1e4e476630a099692b44fdc214ee9bf8ee5377732d9d9dfe5712e4", size = 145992610, upload-time = "2026-01-21T16:25:26.327Z" }, + { url = "https://files.pythonhosted.org/packages/ae/30/a3a2120621bf9c17779b169fc17e3dc29b230c29d0f8222f499f5e159aa8/torch-2.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a2f9edd8dbc99f62bc4dfb78af7bf89499bca3d753423ac1b4e06592e467b763", size = 915607863, upload-time = "2026-01-21T16:25:06.696Z" }, + { url = "https://files.pythonhosted.org/packages/6f/3d/c87b33c5f260a2a8ad68da7147e105f05868c281c63d65ed85aa4da98c66/torch-2.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:29b7009dba4b7a1c960260fc8ac85022c784250af43af9fb0ebafc9883782ebd", size = 113723116, upload-time = "2026-01-21T16:25:21.916Z" }, + { url = "https://files.pythonhosted.org/packages/61/d8/15b9d9d3a6b0c01b883787bd056acbe5cc321090d4b216d3ea89a8fcfdf3/torch-2.10.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:b7bd80f3477b830dd166c707c5b0b82a898e7b16f59a7d9d42778dd058272e8b", size = 79423461, upload-time = "2026-01-21T16:24:50.266Z" }, + { url = "https://files.pythonhosted.org/packages/cc/af/758e242e9102e9988969b5e621d41f36b8f258bb4a099109b7a4b4b50ea4/torch-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5fd4117d89ffd47e3dcc71e71a22efac24828ad781c7e46aaaf56bf7f2796acf", size = 145996088, upload-time = "2026-01-21T16:24:44.171Z" }, + { url = "https://files.pythonhosted.org/packages/23/8e/3c74db5e53bff7ed9e34c8123e6a8bfef718b2450c35eefab85bb4a7e270/torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:787124e7db3b379d4f1ed54dd12ae7c741c16a4d29b49c0226a89bea50923ffb", size = 915711952, upload-time = "2026-01-21T16:23:53.503Z" }, + { url = "https://files.pythonhosted.org/packages/6e/01/624c4324ca01f66ae4c7cd1b74eb16fb52596dce66dbe51eff95ef9e7a4c/torch-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c66c61f44c5f903046cc696d088e21062644cbe541c7f1c4eaae88b2ad23547", size = 113757972, upload-time = "2026-01-21T16:24:39.516Z" }, + { url = "https://files.pythonhosted.org/packages/c9/5c/dee910b87c4d5c0fcb41b50839ae04df87c1cfc663cf1b5fca7ea565eeaa/torch-2.10.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6d3707a61863d1c4d6ebba7be4ca320f42b869ee657e9b2c21c736bf17000294", size = 79498198, upload-time = "2026-01-21T16:24:34.704Z" }, + { url = "https://files.pythonhosted.org/packages/c9/6f/f2e91e34e3fcba2e3fc8d8f74e7d6c22e74e480bbd1db7bc8900fdf3e95c/torch-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5c4d217b14741e40776dd7074d9006fd28b8a97ef5654db959d8635b2fe5f29b", size = 146004247, upload-time = "2026-01-21T16:24:29.335Z" }, + { url = "https://files.pythonhosted.org/packages/98/fb/5160261aeb5e1ee12ee95fe599d0541f7c976c3701d607d8fc29e623229f/torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738", size = 915716445, upload-time = "2026-01-21T16:22:45.353Z" }, + { url = "https://files.pythonhosted.org/packages/6a/16/502fb1b41e6d868e8deb5b0e3ae926bbb36dab8ceb0d1b769b266ad7b0c3/torch-2.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2ee399c644dc92ef7bc0d4f7e74b5360c37cdbe7c5ba11318dda49ffac2bc57", size = 113757050, upload-time = "2026-01-21T16:24:19.204Z" }, + { url = "https://files.pythonhosted.org/packages/1a/0b/39929b148f4824bc3ad6f9f72a29d4ad865bcf7ebfc2fa67584773e083d2/torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382", size = 79851305, upload-time = "2026-01-21T16:24:09.209Z" }, + { url = "https://files.pythonhosted.org/packages/d8/14/21fbce63bc452381ba5f74a2c0a959fdf5ad5803ccc0c654e752e0dbe91a/torch-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:aae1b29cd68e50a9397f5ee897b9c24742e9e306f88a807a27d617f07adb3bd8", size = 146005472, upload-time = "2026-01-21T16:22:29.022Z" }, + { url = "https://files.pythonhosted.org/packages/54/fd/b207d1c525cb570ef47f3e9f836b154685011fce11a2f444ba8a4084d042/torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f", size = 915612644, upload-time = "2026-01-21T16:21:47.019Z" }, + { url = "https://files.pythonhosted.org/packages/36/53/0197f868c75f1050b199fe58f9bf3bf3aecac9b4e85cc9c964383d745403/torch-2.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff43db38af76fda183156153983c9a096fc4c78d0cd1e07b14a2314c7f01c2c8", size = 113997015, upload-time = "2026-01-21T16:23:00.767Z" }, + { url = "https://files.pythonhosted.org/packages/0e/13/e76b4d9c160e89fff48bf16b449ea324bda84745d2ab30294c37c2434c0d/torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f", size = 79498248, upload-time = "2026-01-21T16:23:09.315Z" }, +] [[package]] name = "torch-geometric" @@ -3529,7 +3731,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, { name = "torchvision" }, ] sdist = { url = "https://files.pythonhosted.org/packages/6f/36/574c0c46e818533b78b3c09505211162918188325ab4165ef11a3f295755/torchprofile-0.0.4.tar.gz", hash = "sha256:96b6da17d752a06b02977e078aea95614893b31d4117dd5dcd081f30ce65611b", size = 4557, upload-time = "2021-06-22T04:58:03.592Z" } @@ -3545,7 +3747,7 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.4.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pillow" }, - { name = "torch", marker = "sys_platform == 'never'" }, + { name = "torch" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/50/ae/cbf727421eb73f1cf907fbe5788326a08f111b3f6b6ddca15426b53fec9a/torchvision-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a95c47abb817d4e90ea1a8e57bd0d728e3e6b533b3495ae77d84d883c4d11f56", size = 1874919, upload-time = "2026-01-21T16:27:47.617Z" }, @@ -3638,6 +3840,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/b8/e484ef633af3887baeeb4b6ad12743363af7cce68ae51e938e00aaa0529d/transformers-4.57.6-py3-none-any.whl", hash = "sha256:4c9e9de11333ddfe5114bc872c9f370509198acf0b87a832a0ab9458e2bd0550", size = 11993498, upload-time = "2026-01-16T10:38:31.289Z" }, ] +[[package]] +name = "triton" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201, upload-time = "2026-01-20T16:00:29.272Z" }, + { url = "https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640, upload-time = "2026-01-20T16:00:35.869Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = "2026-01-20T16:00:43.041Z" }, + { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" From 1c23ceda24a998a993bd12d811b4e0957f2f65a7 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Tue, 31 Mar 2026 19:56:04 -0700 Subject: [PATCH 21/72] add: AR validation step to DFlash launcher pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add common/dflash/ar_validate.sh for acceptance rate evaluation - Insert as task_1 in hf_online_dflash.yaml (train → AR → benchmark) - Uses single GPU, loads trained checkpoint from /scratchspace/dflash - Evaluates on MT-Bench prompts with pseudo_speculative_generate Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- tools/launcher/common/dflash/ar_validate.sh | 118 ++++++++++++++++++ .../Qwen/Qwen3-8B/hf_online_dflash.yaml | 23 +++- 2 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 tools/launcher/common/dflash/ar_validate.sh diff --git a/tools/launcher/common/dflash/ar_validate.sh b/tools/launcher/common/dflash/ar_validate.sh new file mode 100644 index 0000000000..01ad61ffd1 --- /dev/null +++ b/tools/launcher/common/dflash/ar_validate.sh @@ -0,0 +1,118 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DFlash AR (Acceptance Rate) validation script. +# Loads a trained DFlash checkpoint and evaluates speculative decoding AR on MT-Bench. +# +# Required env vars: +# HF_MODEL_CKPT — path to the target HuggingFace model +# DFLASH_CKPT — path to the trained DFlash checkpoint +# DFLASH_BLOCK_SIZE — block size (default: 16) +# DFLASH_NUM_LAYERS — number of draft layers (default: 5) +# DFLASH_MASK_TOKEN_ID — mask token ID (default: auto-detect) +# NUM_SAMPLES — number of MT-Bench samples to evaluate (default: 20) + +SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" +source ${SCRIPT_DIR}/../service_utils.sh +trap 'error_handler $0 $LINENO' ERR + +DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16} +DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5} +NUM_SAMPLES=${NUM_SAMPLES:-20} + +# Build mask_token_id arg +if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then + MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID}," +else + MASK_ARG="" +fi + +echo "=== DFlash AR Validation ===" +echo "Target model: ${HF_MODEL_CKPT}" +echo "DFlash checkpoint: ${DFLASH_CKPT}" +echo "Block size: ${DFLASH_BLOCK_SIZE}" +echo "Draft layers: ${DFLASH_NUM_LAYERS}" +echo "Samples: ${NUM_SAMPLES}" + +CUDA_VISIBLE_DEVICES=0 python3 -c " +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from modelopt.torch.speculative.plugins.transformers import HFARValidation +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + +mto.enable_huggingface_checkpointing() + +model = AutoModelForCausalLM.from_pretrained( + '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True) + +config = { + 'dflash_block_size': ${DFLASH_BLOCK_SIZE}, + 'dflash_architecture_config': { + 'num_hidden_layers': ${DFLASH_NUM_LAYERS}, + ${MASK_ARG} + }, + 'dflash_use_torch_compile': False, +} +mtsp.convert(model, [('dflash', config)]) + +# Load trained DFlash weights +import glob +from safetensors.torch import load_file +ckpt_files = sorted(glob.glob('${DFLASH_CKPT}/model*.safetensors')) +if ckpt_files: + state = {} + for f in ckpt_files: + state.update(load_file(f)) + dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights from {len(ckpt_files)} file(s)') +else: + print('WARNING: No checkpoint files found, using random weights') + +model.eval() +validator = HFARValidation(model, tokenizer) + +ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] +num_samples = min(${NUM_SAMPLES}, len(ds)) + +ars = [] +for i in range(num_samples): + prompt = ds[i]['prompt'][0] + chat = [{'role': 'user', 'content': prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() + try: + _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3) + ars.append(ar) + print(f' AR={ar:.2f} | {prompt[:60]}') + except Exception as e: + print(f' ERROR | {prompt[:60]}... | {e}') + +if ars: + avg_ar = sum(ars) / len(ars) + print(f'\n==== DFlash AR Results ====') + print(f'Samples: {len(ars)}') + print(f'Average AR: {avg_ar:.4f}') + print(f'Min AR: {min(ars):.4f}') + print(f'Max AR: {max(ars):.4f}') +else: + print('No AR results collected') +" diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml index f340ed6c9a..4c3d499b2f 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml @@ -3,6 +3,11 @@ # Trains a DFlash draft model (block diffusion) using the frozen target model # to extract multi-layer hidden states on the fly. # +# 3-step pipeline: +# task_0: Online DFlash training +# task_1: AR (Acceptance Rate) validation on MT-Bench +# task_2: Benchmark speculative decoding speedup via VLLM +# # Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) # # Usage: @@ -35,8 +40,24 @@ pipeline: ntasks_per_node: 1 gpus_per_node: 8 - # Step 2: Benchmark speculative decoding (VLLM backend) + # Step 2: AR validation on MT-Bench task_1: + script: common/dflash/ar_validate.sh + environment: + - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + - DFLASH_CKPT: /scratchspace/dflash + - DFLASH_BLOCK_SIZE: "16" + - DFLASH_NUM_LAYERS: "5" + - DFLASH_MASK_TOKEN_ID: "151643" + - NUM_SAMPLES: "20" + slurm_config: + _factory_: "slurm_factory" + nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 1 + + # Step 3: Benchmark speculative decoding (VLLM backend) + task_2: script: common/specdec_bench/quick_check.sh args: - --draft_model_dir /scratchspace/dflash From 38450b063a1cd6e89d3be8cfe69e8e8eca623feb Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Tue, 31 Mar 2026 20:35:24 -0700 Subject: [PATCH 22/72] fix: split DFlash tests into CPU (unit) and GPU tests Move GPU-dependent tests (training forward, module forward) from tests/unit/ to tests/gpu/torch/speculative/plugins/. CPU-only tests (masks, layer IDs, convert, save/restore) remain in tests/unit/. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- tests/gpu/torch/speculative/__init__.py | 15 ++ .../gpu/torch/speculative/plugins/__init__.py | 15 ++ .../speculative/plugins/test_hf_dflash.py | 190 ++++++++++++++++++ .../speculative/plugins/test_hf_dflash.py | 172 ++-------------- 4 files changed, 242 insertions(+), 150 deletions(-) create mode 100644 tests/gpu/torch/speculative/__init__.py create mode 100644 tests/gpu/torch/speculative/plugins/__init__.py create mode 100644 tests/gpu/torch/speculative/plugins/test_hf_dflash.py diff --git a/tests/gpu/torch/speculative/__init__.py b/tests/gpu/torch/speculative/__init__.py new file mode 100644 index 0000000000..47f1c65a15 --- /dev/null +++ b/tests/gpu/torch/speculative/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/gpu/torch/speculative/plugins/__init__.py b/tests/gpu/torch/speculative/plugins/__init__.py new file mode 100644 index 0000000000..47f1c65a15 --- /dev/null +++ b/tests/gpu/torch/speculative/plugins/__init__.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/gpu/torch/speculative/plugins/test_hf_dflash.py b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py new file mode 100644 index 0000000000..d27fffc1c8 --- /dev/null +++ b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU tests for DFlash speculative decoding plugin. + +These tests require a CUDA GPU. CPU-only tests are in tests/unit/. +""" + +from copy import deepcopy + +import pytest +import torch +from _test_utils.torch.transformers_models import get_tiny_llama + +import modelopt.torch.speculative as mtsp +from modelopt.torch.speculative.config import DFLASH_DEFAULT_CFG + +BLOCK_SIZE = 4 +NUM_DRAFT_LAYERS = 2 +SEQ_LEN = 16 # must be multiple of BLOCK_SIZE + + +def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a DFlash config for testing.""" + config = deepcopy(DFLASH_DEFAULT_CFG["config"]) + config["dflash_block_size"] = block_size + config["dflash_use_torch_compile"] = False + config["dflash_architecture_config"] = { + "num_hidden_layers": num_layers, + "mask_token_id": 0, + } + return config + + +@pytest.fixture +def dflash_model(): + """Create a tiny DFlash model on GPU.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model = model.cuda() + return model + + +class TestDFlashModuleGPU: + """Test DFlash draft module forward pass on GPU.""" + + def test_dflash_module_forward_shape(self, dflash_model): + """Test that draft module produces correct output shape.""" + model = dflash_model + bsz = 2 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + target_hidden = torch.randn(bsz, SEQ_LEN, num_layers * hidden_size, device="cuda") + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda") + pos_ids = ( + torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]) + .unsqueeze(0) + .expand(bsz, -1) + .cuda() + ) + + output = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + attention_mask=None, + ) + assert output.shape == (bsz, SEQ_LEN, hidden_size) + + def test_dflash_module_deterministic(self, dflash_model): + """Test that draft module produces identical outputs for same input.""" + model = dflash_model + model.eval() + bsz = 1 + hidden_size = model.config.hidden_size + num_layers = len(model.target_layer_ids) + + target_hidden = torch.randn(bsz, SEQ_LEN, num_layers * hidden_size, device="cuda") + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda") + pos_ids = torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]).unsqueeze(0).cuda() + + with torch.no_grad(): + out1 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + out2 = model.dflash_module( + noise_embedding=noise_emb, + target_hidden=target_hidden, + position_ids=pos_ids, + ) + assert torch.allclose(out1, out2) + + +class TestDFlashTrainingForwardGPU: + """Test DFlash training forward pass end-to-end on GPU.""" + + @pytest.fixture + def model(self): + """Create a tiny DFlash model in training mode on GPU.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + mtsp.convert(model, [("dflash", config)]) + model = model.cuda() + model.train() + return model + + def test_training_forward_returns_loss(self, model): + """Test that training forward returns a differentiable loss.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_returns_accuracy(self, model): + """Test that training forward returns train_acc.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + assert hasattr(output, "train_acc") + + def test_training_forward_with_labels(self, model): + """Test that labels are used for response-only loss masking.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + # Labels with -100 for first half (masked), real labels for second half + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long, device="cuda") + labels[:, SEQ_LEN // 2 :] = input_ids[:, SEQ_LEN // 2 :] + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert hasattr(output, "loss") + assert output.loss.requires_grad + + def test_training_forward_all_masked_labels(self, model): + """Test that all-masked labels produce zero loss without crashing.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) + assert output.loss.item() == 0.0 + + def test_training_backward(self, model): + """Test that gradients flow to dflash_module.""" + bsz = 2 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long, device="cuda") + + output = model(input_ids=input_ids, attention_mask=attention_mask) + output.loss.backward() + + has_grad = False + for name, param in model.dflash_module.named_parameters(): + if param.grad is not None and param.grad.abs().sum() > 0: + has_grad = True + break + assert has_grad, "DFlash module should receive gradients" + + def test_eval_forward_uses_base_model(self, model): + """In eval mode, forward should use base model (not DFlash training).""" + model.eval() + bsz = 1 + input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN), device="cuda") + + with torch.no_grad(): + output = model(input_ids=input_ids) + assert output.logits.shape == (bsz, SEQ_LEN, model.config.vocab_size) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index bb7731a8a3..fa35bd5baf 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for DFlash speculative decoding plugin.""" +"""CPU unit tests for DFlash speculative decoding plugin. + +GPU-dependent tests (training forward, module forward) are in tests/gpu/. +""" import os from copy import deepcopy -import pytest import torch from _test_utils.torch.transformers_models import ( get_tiny_llama, @@ -42,6 +44,7 @@ def _get_dflash_config(block_size=BLOCK_SIZE, num_layers=NUM_DRAFT_LAYERS): + """Create a DFlash config for testing.""" config = deepcopy(DFLASH_DEFAULT_CFG["config"]) config["dflash_block_size"] = block_size config["dflash_use_torch_compile"] = False @@ -56,12 +59,14 @@ class TestDFlashConvert: """Test DFlash model conversion.""" def test_convert_creates_dflash_model(self): + """Test that convert produces an HFDFlashModel.""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() mtsp.convert(model, [("dflash", config)]) assert isinstance(model, HFDFlashModel) def test_convert_creates_dflash_module(self): + """Test that convert attaches a DFlashModule.""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() mtsp.convert(model, [("dflash", config)]) @@ -69,35 +74,36 @@ def test_convert_creates_dflash_module(self): assert isinstance(model.dflash_module, DFlashModule) def test_convert_freezes_base_model(self): + """Test that base model parameters are frozen after convert.""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() mtsp.convert(model, [("dflash", config)]) - # Base model params should be frozen for name, param in model.named_parameters(): if "dflash_module" not in name: assert not param.requires_grad, f"Base param {name} should be frozen" def test_convert_dflash_module_trainable(self): + """Test that DFlash module parameters are trainable after convert.""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() mtsp.convert(model, [("dflash", config)]) - # DFlash module params should be trainable dflash_params = [(n, p) for n, p in model.named_parameters() if "dflash_module" in n] assert len(dflash_params) > 0 for name, param in dflash_params: assert param.requires_grad, f"DFlash param {name} should be trainable" def test_convert_sets_target_layer_ids(self): + """Test that target layer IDs are set correctly.""" model = get_tiny_llama(num_hidden_layers=8) config = _get_dflash_config(num_layers=3) mtsp.convert(model, [("dflash", config)]) assert hasattr(model, "target_layer_ids") assert len(model.target_layer_ids) == 3 - # Layer IDs should be within target model range for lid in model.target_layer_ids: assert 0 <= lid < 8 def test_convert_sets_mask_token_id(self): + """Test that mask_token_id is set from config.""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() mtsp.convert(model, [("dflash", config)]) @@ -109,6 +115,7 @@ class TestDFlashSaveRestore: """Test DFlash model save and restore.""" def test_save_and_restore(self, tmp_path): + """Test round-trip save/load preserves modelopt state and outputs.""" mto.enable_huggingface_checkpointing() model_ref = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() @@ -126,6 +133,7 @@ class TestDFlashAttentionMask: """Test DFlash attention mask construction.""" def test_mask_shape(self): + """Test mask has shape [1, 1, L, 2L].""" mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) assert mask.shape == (1, 1, SEQ_LEN, 2 * SEQ_LEN) @@ -159,6 +167,7 @@ def test_mask_noise_causal_within_block(self): assert (noise_mask[4:8, :4] < 0).all() def test_mask_values_are_zero_or_neg_inf(self): + """Test mask contains only 0 (attend) and -inf (mask).""" mask = create_dflash_attention_mask(SEQ_LEN, BLOCK_SIZE, "cpu", torch.float32) unique_vals = mask.unique() assert len(unique_vals) == 2 @@ -170,178 +179,42 @@ class TestDFlashLossMask: """Test DFlash loss mask construction.""" def test_loss_mask_shape(self): + """Test loss mask has shape [L].""" mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") assert mask.shape == (SEQ_LEN,) def test_loss_mask_excludes_block_zero(self): + """Test all positions in block 0 are masked out.""" mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") - # All positions in block 0 should be masked out assert (mask[:BLOCK_SIZE] == 0).all() def test_loss_mask_excludes_block_starts(self): + """Test block start positions are masked.""" mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") - # Block start positions (every BLOCK_SIZE) should be masked for i in range(0, SEQ_LEN, BLOCK_SIZE): assert mask[i] == 0, f"Block start position {i} should be masked" def test_loss_mask_includes_non_start_positions(self): + """Test non-start positions in non-zero blocks are included.""" mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") - # Non-start positions in non-zero blocks should be included for b in range(1, SEQ_LEN // BLOCK_SIZE): for offset in range(1, BLOCK_SIZE): pos = b * BLOCK_SIZE + offset assert mask[pos] == 1, f"Position {pos} should be in loss" def test_loss_mask_count(self): + """Test total active positions matches expected count.""" mask = create_dflash_loss_mask(SEQ_LEN, BLOCK_SIZE, "cpu") num_blocks = SEQ_LEN // BLOCK_SIZE - # Block 0 excluded entirely (BLOCK_SIZE positions) - # Each remaining block excludes 1 start position expected = (num_blocks - 1) * (BLOCK_SIZE - 1) assert mask.sum().item() == expected -class TestDFlashModule: - """Test DFlash draft module forward pass.""" - - @pytest.fixture - def model_and_config(self): - model = get_tiny_llama(num_hidden_layers=4) - config = _get_dflash_config() - mtsp.convert(model, [("dflash", config)]) - return model - - def test_dflash_module_forward_shape(self, model_and_config): - model = model_and_config - bsz = 2 - hidden_size = model.config.hidden_size - num_layers = len(model.target_layer_ids) - - # Create inputs matching training forward - target_hidden = torch.randn(bsz, SEQ_LEN, num_layers * hidden_size) - noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size) - pos_ids = ( - torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]).unsqueeze(0).expand(bsz, -1) - ) - - output = model.dflash_module( - noise_embedding=noise_emb, - target_hidden=target_hidden, - position_ids=pos_ids, - attention_mask=None, - ) - assert output.shape == (bsz, SEQ_LEN, hidden_size) - - def test_dflash_module_deterministic(self, model_and_config): - model = model_and_config - model.eval() - bsz = 1 - hidden_size = model.config.hidden_size - num_layers = len(model.target_layer_ids) - - target_hidden = torch.randn(bsz, SEQ_LEN, num_layers * hidden_size) - noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size) - pos_ids = torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]).unsqueeze(0) - - with torch.no_grad(): - out1 = model.dflash_module( - noise_embedding=noise_emb, - target_hidden=target_hidden, - position_ids=pos_ids, - ) - out2 = model.dflash_module( - noise_embedding=noise_emb, - target_hidden=target_hidden, - position_ids=pos_ids, - ) - assert torch.allclose(out1, out2) - - -class TestDFlashTrainingForward: - """Test DFlash training forward pass end-to-end.""" - - @pytest.fixture - def model(self): - model = get_tiny_llama(num_hidden_layers=4) - config = _get_dflash_config() - mtsp.convert(model, [("dflash", config)]) - model.train() - return model - - def test_training_forward_returns_loss(self, model): - bsz = 2 - input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) - attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long) - - output = model(input_ids=input_ids, attention_mask=attention_mask) - assert hasattr(output, "loss") - assert output.loss.requires_grad - - def test_training_forward_returns_accuracy(self, model): - bsz = 2 - input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) - attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long) - - output = model(input_ids=input_ids, attention_mask=attention_mask) - assert hasattr(output, "train_acc") - - def test_training_forward_with_labels(self, model): - """Test that labels are used for response-only loss masking.""" - bsz = 2 - input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) - attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long) - - # Labels with -100 for first half (masked), real labels for second half - labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long) - labels[:, SEQ_LEN // 2 :] = input_ids[:, SEQ_LEN // 2 :] - - output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - assert hasattr(output, "loss") - assert output.loss.requires_grad - - def test_training_forward_all_masked_labels(self, model): - """Test that all-masked labels produce zero loss without crashing.""" - bsz = 2 - input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) - attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long) - labels = torch.full((bsz, SEQ_LEN), -100, dtype=torch.long) - - output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) - assert output.loss.item() == 0.0 - - def test_training_backward(self, model): - """Test that gradients flow to dflash_module.""" - bsz = 2 - input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) - attention_mask = torch.ones(bsz, SEQ_LEN, dtype=torch.long) - - output = model(input_ids=input_ids, attention_mask=attention_mask) - output.loss.backward() - - # Check dflash_module has gradients - has_grad = False - for name, param in model.dflash_module.named_parameters(): - if param.grad is not None and param.grad.abs().sum() > 0: - has_grad = True - break - assert has_grad, "DFlash module should receive gradients" - - def test_eval_forward_uses_base_model(self, model): - """In eval mode, forward should use base model (not DFlash training).""" - model.eval() - bsz = 1 - input_ids = torch.randint(0, model.config.vocab_size, (bsz, SEQ_LEN)) - - with torch.no_grad(): - output = model(input_ids=input_ids) - # Should return base model output (logits over vocab) - assert output.logits.shape == (bsz, SEQ_LEN, model.config.vocab_size) - - class TestBuildTargetLayerIds: """Test target layer selection.""" def test_single_draft_layer(self): + """Test single draft layer selects middle target layer.""" from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids ids = build_target_layer_ids(32, 1) @@ -349,19 +222,18 @@ def test_single_draft_layer(self): assert ids[0] == 16 # middle layer def test_multiple_draft_layers(self): + """Test multiple draft layers are monotonically increasing and in bounds.""" from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids ids = build_target_layer_ids(36, 5) assert len(ids) == 5 - # Should be monotonically increasing assert ids == sorted(ids) - # Should be within [1, 33] for 36-layer model assert all(1 <= lid <= 33 for lid in ids) def test_layer_ids_spread(self): + """Test layer IDs have no duplicates.""" from modelopt.torch.speculative.plugins.hf_dflash import build_target_layer_ids ids = build_target_layer_ids(32, 5) assert len(ids) == 5 - # No duplicates assert len(set(ids)) == 5 From 4c2fc7731fa30c054b85fb601e690a9cea06fbb4 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Tue, 31 Mar 2026 21:59:18 -0700 Subject: [PATCH 23/72] fix: correct DFlash attention mask test for reverse-causal pattern DFlash uses reverse-causal within blocks (matching SpecForge): earlier positions see more noise keys, later positions see fewer. This is intentional for block diffusion denoising. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/test_hf_dflash.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index fa35bd5baf..2ed117675c 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -151,17 +151,22 @@ def test_mask_context_strictly_previous_blocks(self): assert (ctx_mask[4:8, 4:8] < 0).all() # cannot see own block def test_mask_noise_causal_within_block(self): - """Noise (right half): causal within same block, blocked across blocks.""" + """Noise (right half): reverse-causal within same block, blocked across blocks. + + DFlash uses reverse-causal: earlier positions in the block see more noise keys. + Position 0 sees all positions in its block, position B-1 sees only itself. + This matches SpecForge's implementation. + """ mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) mask_2d = mask[0, 0] noise_mask = mask_2d[:, 8:] # noise part - # Block 0, position 0: can only see position 0 - assert noise_mask[0, 0] == 0 - assert (noise_mask[0, 1:4] < 0).all() + # Block 0, position 0: can see all positions in block (0-3) + assert (noise_mask[0, :4] == 0).all() - # Block 0, position 3: can see positions 0-3 - assert (noise_mask[3, :4] == 0).all() + # Block 0, position 3: can only see position 3 + assert (noise_mask[3, :3] < 0).all() + assert noise_mask[3, 3] == 0 # Block 1 cannot see block 0 noise assert (noise_mask[4:8, :4] < 0).all() From bce17cf4dd82b2405fb65da58c8d895ed4f8d373 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 1 Apr 2026 07:35:51 -0700 Subject: [PATCH 24/72] fix: remove __init__.py from GPU test dirs to avoid conftest conflict The tests/gpu/ tree does not use __init__.py files. Adding them turned the directory into a Python package, changing pytest's module resolution and breaking bare 'from conftest import' in existing sparsity tests. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- tests/gpu/torch/speculative/__init__.py | 15 --------------- tests/gpu/torch/speculative/plugins/__init__.py | 15 --------------- 2 files changed, 30 deletions(-) delete mode 100644 tests/gpu/torch/speculative/__init__.py delete mode 100644 tests/gpu/torch/speculative/plugins/__init__.py diff --git a/tests/gpu/torch/speculative/__init__.py b/tests/gpu/torch/speculative/__init__.py deleted file mode 100644 index 47f1c65a15..0000000000 --- a/tests/gpu/torch/speculative/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/tests/gpu/torch/speculative/plugins/__init__.py b/tests/gpu/torch/speculative/plugins/__init__.py deleted file mode 100644 index 47f1c65a15..0000000000 --- a/tests/gpu/torch/speculative/plugins/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - From 116527240adbe520d9420f668b9db6a3d4eecd2a Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 1 Apr 2026 08:52:19 -0700 Subject: [PATCH 25/72] fix: match dtype in DFlash GPU tests to model dtype get_tiny_llama() creates bfloat16 models but torch.randn defaults to float32, causing dtype mismatch in linear layers. Use the model's parameter dtype for test inputs. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/test_hf_dflash.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/gpu/torch/speculative/plugins/test_hf_dflash.py b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py index d27fffc1c8..230b67c45d 100644 --- a/tests/gpu/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/gpu/torch/speculative/plugins/test_hf_dflash.py @@ -64,8 +64,11 @@ def test_dflash_module_forward_shape(self, dflash_model): hidden_size = model.config.hidden_size num_layers = len(model.target_layer_ids) - target_hidden = torch.randn(bsz, SEQ_LEN, num_layers * hidden_size, device="cuda") - noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda") + dtype = next(model.dflash_module.parameters()).dtype + target_hidden = torch.randn( + bsz, SEQ_LEN, num_layers * hidden_size, device="cuda", dtype=dtype + ) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda", dtype=dtype) pos_ids = ( torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]) .unsqueeze(0) @@ -89,8 +92,11 @@ def test_dflash_module_deterministic(self, dflash_model): hidden_size = model.config.hidden_size num_layers = len(model.target_layer_ids) - target_hidden = torch.randn(bsz, SEQ_LEN, num_layers * hidden_size, device="cuda") - noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda") + dtype = next(model.dflash_module.parameters()).dtype + target_hidden = torch.randn( + bsz, SEQ_LEN, num_layers * hidden_size, device="cuda", dtype=dtype + ) + noise_emb = torch.randn(bsz, SEQ_LEN, hidden_size, device="cuda", dtype=dtype) pos_ids = torch.cat([torch.arange(SEQ_LEN), torch.arange(SEQ_LEN)]).unsqueeze(0).cuda() with torch.no_grad(): From 273ba32bb0a3a8ba7f3d1a93230110382b876e97 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 1 Apr 2026 10:00:50 -0700 Subject: [PATCH 26/72] fix: use Optional types for nullable DFlash arguments HfArgumentParser requires Optional[int] (not bare int) when the default is None. Fix dflash_config and dflash_mask_token_id types. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 812f5d8cce..650ff44336 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -152,12 +152,12 @@ class DFlashArguments: dflash_num_layers: int = field( default=5, metadata={"help": "Number of decoder layers in the DFlash draft module."} ) - dflash_config: str = field(default=None, metadata={"help": "Path to dflash_config.json"}) + dflash_config: str | None = field(default=None, metadata={"help": "Path to dflash_config.json"}) dflash_disable_torch_compile: bool = field( default=False, metadata={"help": "Disable torch.compile on DFlash forward/loss methods."}, ) - dflash_mask_token_id: int = field( + dflash_mask_token_id: int | None = field( default=None, metadata={"help": "Mask token ID for DFlash. If not set, auto-detected from model."}, ) From 9bf9c3401745e6063e5c9770c4ab57d61d1881b5 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 1 Apr 2026 11:52:52 -0700 Subject: [PATCH 27/72] fix: merge AR validation into DFlash training script Instead of a separate task_1 for AR validation, run it at the end of online_training.sh after training completes. This avoids needing a second Slurm job and container launch. Set NUM_AR_SAMPLES=0 to skip. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../launcher/common/dflash/online_training.sh | 125 +++++++++++++++++- .../Qwen/Qwen3-8B/hf_online_dflash.yaml | 30 +---- 2 files changed, 130 insertions(+), 25 deletions(-) diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh index 114a15ba3c..0cdb6a906d 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/dflash/online_training.sh @@ -15,12 +15,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -# DFlash online training script for the ModelOpt Launcher. -# Trains a DFlash draft model alongside the frozen target model. +# DFlash online training + AR validation script for the ModelOpt Launcher. +# Trains a DFlash draft model alongside the frozen target model, +# then evaluates acceptance rate on MT-Bench. # # Required env vars: # HF_MODEL_CKPT — path to the target HuggingFace model # +# Optional env vars: +# NUM_AR_SAMPLES — number of MT-Bench samples for AR validation (default: 20, 0 to skip) +# # All other args are passed through to launch_train.sh. SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" @@ -34,9 +38,126 @@ export PATH=$PATH:/workspace/.local/bin trap 'error_handler $0 $LINENO' ERR +# Parse DFlash-specific args from the command line for AR validation +DFLASH_BLOCK_SIZE=16 +DFLASH_NUM_LAYERS=5 +DFLASH_MASK_TOKEN_ID="" +OUTPUT_DIR="" +for arg in "$@"; do + case "$arg" in + --dflash_block_size) next_is_block_size=1 ;; + --dflash_num_layers) next_is_num_layers=1 ;; + --dflash_mask_token_id) next_is_mask_id=1 ;; + --output_dir) next_is_output=1 ;; + *) + if [ "$next_is_block_size" = "1" ]; then DFLASH_BLOCK_SIZE="$arg"; next_is_block_size=0; fi + if [ "$next_is_num_layers" = "1" ]; then DFLASH_NUM_LAYERS="$arg"; next_is_num_layers=0; fi + if [ "$next_is_mask_id" = "1" ]; then DFLASH_MASK_TOKEN_ID="$arg"; next_is_mask_id=0; fi + if [ "$next_is_output" = "1" ]; then OUTPUT_DIR="$arg"; next_is_output=0; fi + ;; + esac +done + +# Step 1: Training bash modules/Model-Optimizer/examples/speculative_decoding/launch_train.sh \ --model ${HF_MODEL_CKPT} \ --mode dflash \ ${@} +# Step 2: AR Validation +NUM_AR_SAMPLES=${NUM_AR_SAMPLES:-20} +if [ "${NUM_AR_SAMPLES}" = "0" ]; then + echo "Skipping AR validation (NUM_AR_SAMPLES=0)" + exit 0 +fi + +if [ -z "$OUTPUT_DIR" ]; then + echo "WARNING: --output_dir not found in args, skipping AR validation" + exit 0 +fi + +# Build mask_token_id config +if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then + MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID}," +else + MASK_ARG="" +fi + +echo "" +echo "=== DFlash AR Validation ===" +echo "Target model: ${HF_MODEL_CKPT}" +echo "DFlash checkpoint: ${OUTPUT_DIR}" +echo "Block size: ${DFLASH_BLOCK_SIZE}" +echo "Draft layers: ${DFLASH_NUM_LAYERS}" +echo "Samples: ${NUM_AR_SAMPLES}" + +CUDA_VISIBLE_DEVICES=0 python3 -c " +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from modelopt.torch.speculative.plugins.transformers import HFARValidation +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + +mto.enable_huggingface_checkpointing() + +model = AutoModelForCausalLM.from_pretrained( + '${HF_MODEL_CKPT}', torch_dtype=torch.bfloat16, device_map={'': 0}, trust_remote_code=True +) +tokenizer = AutoTokenizer.from_pretrained('${HF_MODEL_CKPT}', trust_remote_code=True) + +config = { + 'dflash_block_size': ${DFLASH_BLOCK_SIZE}, + 'dflash_architecture_config': { + 'num_hidden_layers': ${DFLASH_NUM_LAYERS}, + ${MASK_ARG} + }, + 'dflash_use_torch_compile': False, +} +mtsp.convert(model, [('dflash', config)]) + +# Load trained DFlash weights +import glob +from safetensors.torch import load_file +ckpt_files = sorted(glob.glob('${OUTPUT_DIR}/model*.safetensors')) +if ckpt_files: + state = {} + for f in ckpt_files: + state.update(load_file(f)) + dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights from {len(ckpt_files)} file(s)') +else: + print('WARNING: No checkpoint files found, using random weights') + +model.eval() +validator = HFARValidation(model, tokenizer) + +ds = load_dataset('/hf-local/HuggingFaceH4/mt_bench_prompts')['train'] +num_samples = min(${NUM_AR_SAMPLES}, len(ds)) + +ars = [] +for i in range(num_samples): + prompt = ds[i]['prompt'][0] + chat = [{'role': 'user', 'content': prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + input_ids = tokenizer(text, return_tensors='pt').input_ids.cuda() + try: + _, ar = validator.validate(osl=32, input_ids=input_ids, steps=3) + ars.append(ar) + print(f' AR={ar:.2f} | {prompt[:60]}') + except Exception as e: + print(f' ERROR | {prompt[:60]}... | {e}') + +if ars: + avg_ar = sum(ars) / len(ars) + print(f'\n==== DFlash AR Results ====') + print(f'Samples: {len(ars)}') + print(f'Average AR: {avg_ar:.4f}') + print(f'Min AR: {min(ars):.4f}') + print(f'Max AR: {max(ars):.4f}') +else: + print('No AR results collected') +" + ################################################################################################### diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml index 4c3d499b2f..c72b5aec48 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml @@ -1,12 +1,11 @@ # DFlash online speculative decoding training for Qwen3-8B. # # Trains a DFlash draft model (block diffusion) using the frozen target model -# to extract multi-layer hidden states on the fly. +# to extract multi-layer hidden states on the fly, then evaluates AR on MT-Bench. # -# 3-step pipeline: -# task_0: Online DFlash training -# task_1: AR (Acceptance Rate) validation on MT-Bench -# task_2: Benchmark speculative decoding speedup via VLLM +# 2-step pipeline: +# task_0: Online DFlash training + AR validation +# task_1: Benchmark speculative decoding speedup via VLLM # # Reference: "DFlash: Block Diffusion for Flash Speculative Decoding" (arXiv:2602.06036) # @@ -16,7 +15,7 @@ job_name: Qwen3-8B_DFlash_online pipeline: - # Step 1: Online DFlash training + # Step 1: Online DFlash training + AR validation task_0: script: common/dflash/online_training.sh args: @@ -34,30 +33,15 @@ pipeline: - --dflash_mask_token_id 151643 environment: - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B + - NUM_AR_SAMPLES: "20" slurm_config: _factory_: "slurm_factory" nodes: 1 ntasks_per_node: 1 gpus_per_node: 8 - # Step 2: AR validation on MT-Bench + # Step 2: Benchmark speculative decoding (VLLM backend) task_1: - script: common/dflash/ar_validate.sh - environment: - - HF_MODEL_CKPT: /hf-local/Qwen/Qwen3-8B - - DFLASH_CKPT: /scratchspace/dflash - - DFLASH_BLOCK_SIZE: "16" - - DFLASH_NUM_LAYERS: "5" - - DFLASH_MASK_TOKEN_ID: "151643" - - NUM_SAMPLES: "20" - slurm_config: - _factory_: "slurm_factory" - nodes: 1 - ntasks_per_node: 1 - gpus_per_node: 1 - - # Step 3: Benchmark speculative decoding (VLLM backend) - task_2: script: common/specdec_bench/quick_check.sh args: - --draft_model_dir /scratchspace/dflash From d19cd3b963171bf3eff6abfd555a5a974682cf4f Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 1 Apr 2026 17:10:12 -0700 Subject: [PATCH 28/72] fix: align pseudo_speculative_generate with training masks Two mismatches between training and inference: 1. Causal direction: training uses reverse-causal (pos 0 sees all, pos B-1 sees only itself), but inference used standard causal (opposite). Now uses matching reverse-causal. 2. Position IDs: training uses [0..L-1, 0..L-1] (shared positions), but inference used [0..ctx_len-1, seq_len..seq_len+B-1]. Now uses [0..ctx_len-1, 0..B-1] to match training's block positions. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 8e1b275f73..12a57d18b1 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -673,21 +673,25 @@ def pseudo_speculative_generate(self, input_ids, steps=1): block_ids[:, 0] = base_token.squeeze(-1) noise_embedding = self._base_model_embeddings(block_ids) - # Position IDs: context 0..ctx_len-1, block seq_len..seq_len+block_size-1 + # Position IDs must match training: [0..ctx_len-1, 0..block_size-1] + # Training uses [0..L-1, 0..L-1] — context and noise share positions. + # At inference, the block positions should start from 0 (same as training blocks). ctx_len = target_hidden.shape[1] ctx_positions = torch.arange(ctx_len, device=device) - block_positions = torch.arange(seq_len, seq_len + block_size, device=device) + block_positions = torch.arange(block_size, device=device) pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) - # Attention mask: block sees ALL context + causal within block + # Attention mask must match training pattern: + # - Context part: block sees ALL previous context (at inference there's only 1 block) + # - Noise part: reverse-causal within block (pos 0 sees all, pos B-1 sees only itself) + # This matches training: indices.unsqueeze(0) >= indices.unsqueeze(1) attn_mask = torch.zeros(1, 1, block_size, ctx_len + block_size, device=device, dtype=dtype) - block_causal = torch.triu( - torch.full( - (block_size, block_size), torch.finfo(dtype).min, device=device, dtype=dtype - ), - diagonal=1, - ) - attn_mask[:, :, :, ctx_len:] = block_causal + # Noise part: reverse-causal (lower-triangular is masked) + block_indices = torch.arange(block_size, device=device) + reverse_causal = block_indices.unsqueeze(0) >= block_indices.unsqueeze(1) # [B, B] + noise_mask = torch.zeros(block_size, block_size, device=device, dtype=dtype) + noise_mask.masked_fill_(~reverse_causal, torch.finfo(dtype).min) + attn_mask[:, :, :, ctx_len:] = noise_mask # Draft forward draft_hidden = self.dflash_module( From 73bb0cc01a4c9b61795b3a870989bfa922c7c134 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 1 Apr 2026 17:17:35 -0700 Subject: [PATCH 29/72] fix: use standard causal mask within DFlash blocks SpecForge PR #415 line 188 has a bug: comment says "j <= i" (standard causal) but code implements "j >= i" (reverse-causal). Fix to match the intended standard causal: position 0 (anchor) sees only itself, subsequent positions see all prior positions for progressive denoising. Both training mask (create_dflash_attention_mask) and inference mask (pseudo_speculative_generate) now use standard causal consistently. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 21 +++++++++---------- .../speculative/plugins/test_hf_dflash.py | 17 +++++++-------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 12a57d18b1..e7889f897c 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -248,7 +248,7 @@ def create_dflash_attention_mask(seq_len, block_size, device, dtype): ctx_mask = k_block_ids < q_block_ids same_block = q_block_ids == k_block_ids - causal = indices.unsqueeze(0) >= indices.unsqueeze(1) + causal = indices.unsqueeze(0) <= indices.unsqueeze(1) # standard causal: j <= i noise_mask = same_block & causal full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) @@ -681,17 +681,16 @@ def pseudo_speculative_generate(self, input_ids, steps=1): block_positions = torch.arange(block_size, device=device) pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) - # Attention mask must match training pattern: - # - Context part: block sees ALL previous context (at inference there's only 1 block) - # - Noise part: reverse-causal within block (pos 0 sees all, pos B-1 sees only itself) - # This matches training: indices.unsqueeze(0) >= indices.unsqueeze(1) + # Attention mask: block sees ALL context + standard causal within block + # Standard causal: position i can attend to positions j <= i (see anchor + previous) attn_mask = torch.zeros(1, 1, block_size, ctx_len + block_size, device=device, dtype=dtype) - # Noise part: reverse-causal (lower-triangular is masked) - block_indices = torch.arange(block_size, device=device) - reverse_causal = block_indices.unsqueeze(0) >= block_indices.unsqueeze(1) # [B, B] - noise_mask = torch.zeros(block_size, block_size, device=device, dtype=dtype) - noise_mask.masked_fill_(~reverse_causal, torch.finfo(dtype).min) - attn_mask[:, :, :, ctx_len:] = noise_mask + block_causal = torch.triu( + torch.full( + (block_size, block_size), torch.finfo(dtype).min, device=device, dtype=dtype + ), + diagonal=1, + ) + attn_mask[:, :, :, ctx_len:] = block_causal # Draft forward draft_hidden = self.dflash_module( diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index 2ed117675c..34ff72b9b2 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -151,22 +151,21 @@ def test_mask_context_strictly_previous_blocks(self): assert (ctx_mask[4:8, 4:8] < 0).all() # cannot see own block def test_mask_noise_causal_within_block(self): - """Noise (right half): reverse-causal within same block, blocked across blocks. + """Noise (right half): standard causal within same block, blocked across blocks. - DFlash uses reverse-causal: earlier positions in the block see more noise keys. - Position 0 sees all positions in its block, position B-1 sees only itself. - This matches SpecForge's implementation. + Standard causal: position i can attend to positions j <= i. + Position 0 (anchor) sees only itself, position B-1 sees all positions in block. """ mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) mask_2d = mask[0, 0] noise_mask = mask_2d[:, 8:] # noise part - # Block 0, position 0: can see all positions in block (0-3) - assert (noise_mask[0, :4] == 0).all() + # Block 0, position 0: can only see position 0 + assert noise_mask[0, 0] == 0 + assert (noise_mask[0, 1:4] < 0).all() - # Block 0, position 3: can only see position 3 - assert (noise_mask[3, :3] < 0).all() - assert noise_mask[3, 3] == 0 + # Block 0, position 3: can see positions 0-3 + assert (noise_mask[3, :4] == 0).all() # Block 1 cannot see block 0 noise assert (noise_mask[4:8, :4] < 0).all() From 3fa0d64983811372901ac3a944339428944e0c91 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Wed, 1 Apr 2026 20:44:41 -0700 Subject: [PATCH 30/72] fix: increase DDP timeout to 1800s for DFlash training 300s timeout caused transient NCCL failures. Increase to 30 min to tolerate temporary network glitches on multi-GPU nodes. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/launch_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 27e17e388d..7beb567425 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -222,7 +222,7 @@ elif [[ "$MODE" == "dflash" ]]; then SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_mask_token_id $DFLASH_MASK_TOKEN_ID" fi # DFlash uses DDP instead of FSDP - FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 300" + FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 1800" else echo "Unsupported mode: $MODE. Supported: eagle3, dflash" exit 1 From 80afde2a49dfb8b33f228104adf6a1bdf06595dd Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 07:32:24 -0700 Subject: [PATCH 31/72] fix: revert to SpecForge's reverse-causal mask (j >= i) The SpecForge checkpoint achieves AR=1.77 with our inference code, confirming that the reverse-causal pattern works. Revert both training and inference to match SpecForge exactly (j >= i), so training and inference are consistent and we can isolate other causes of the AR gap. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 18 ++++++++---------- .../speculative/plugins/test_hf_dflash.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index e7889f897c..4946e17f56 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -248,7 +248,7 @@ def create_dflash_attention_mask(seq_len, block_size, device, dtype): ctx_mask = k_block_ids < q_block_ids same_block = q_block_ids == k_block_ids - causal = indices.unsqueeze(0) <= indices.unsqueeze(1) # standard causal: j <= i + causal = indices.unsqueeze(0) >= indices.unsqueeze(1) # matching SpecForge: j >= i noise_mask = same_block & causal full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) @@ -681,16 +681,14 @@ def pseudo_speculative_generate(self, input_ids, steps=1): block_positions = torch.arange(block_size, device=device) pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) - # Attention mask: block sees ALL context + standard causal within block - # Standard causal: position i can attend to positions j <= i (see anchor + previous) + # Attention mask: block sees ALL context + reverse-causal within block + # Matching SpecForge training: j >= i (pos 0 sees all, pos B-1 sees only itself) attn_mask = torch.zeros(1, 1, block_size, ctx_len + block_size, device=device, dtype=dtype) - block_causal = torch.triu( - torch.full( - (block_size, block_size), torch.finfo(dtype).min, device=device, dtype=dtype - ), - diagonal=1, - ) - attn_mask[:, :, :, ctx_len:] = block_causal + block_indices = torch.arange(block_size, device=device) + reverse_causal = block_indices.unsqueeze(0) >= block_indices.unsqueeze(1) + noise_mask = torch.zeros(block_size, block_size, device=device, dtype=dtype) + noise_mask.masked_fill_(~reverse_causal, torch.finfo(dtype).min) + attn_mask[:, :, :, ctx_len:] = noise_mask # Draft forward draft_hidden = self.dflash_module( diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index 34ff72b9b2..50d3c9768b 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -151,21 +151,21 @@ def test_mask_context_strictly_previous_blocks(self): assert (ctx_mask[4:8, 4:8] < 0).all() # cannot see own block def test_mask_noise_causal_within_block(self): - """Noise (right half): standard causal within same block, blocked across blocks. + """Noise (right half): reverse-causal within same block, matching SpecForge. - Standard causal: position i can attend to positions j <= i. - Position 0 (anchor) sees only itself, position B-1 sees all positions in block. + SpecForge uses j >= i: position 0 (anchor) sees all positions in block, + position B-1 sees only itself. Cross-block noise is fully masked. """ mask = create_dflash_attention_mask(8, 4, "cpu", torch.float32) mask_2d = mask[0, 0] noise_mask = mask_2d[:, 8:] # noise part - # Block 0, position 0: can only see position 0 - assert noise_mask[0, 0] == 0 - assert (noise_mask[0, 1:4] < 0).all() + # Block 0, position 0: can see all positions in block (0-3) + assert (noise_mask[0, :4] == 0).all() - # Block 0, position 3: can see positions 0-3 - assert (noise_mask[3, :4] == 0).all() + # Block 0, position 3: can only see position 3 + assert (noise_mask[3, :3] < 0).all() + assert noise_mask[3, 3] == 0 # Block 1 cannot see block 0 noise assert (noise_mask[4:8, :4] < 0).all() From bfdd582648f808ed213f9eeede13a6d668b031e8 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 07:53:09 -0700 Subject: [PATCH 32/72] fix: use continuing position IDs for DFlash inference block Training uses [0..L-1, 0..L-1] where noise positions mirror context. At inference, the block predicts tokens at seq_len..seq_len+B-1, so noise positions should continue from ctx_len, matching SpecForge's spec_generate which uses sequential position IDs. Previously used [0..B-1] which gave wrong RoPE embeddings. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 4946e17f56..ffa2ed3541 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -673,12 +673,12 @@ def pseudo_speculative_generate(self, input_ids, steps=1): block_ids[:, 0] = base_token.squeeze(-1) noise_embedding = self._base_model_embeddings(block_ids) - # Position IDs must match training: [0..ctx_len-1, 0..block_size-1] - # Training uses [0..L-1, 0..L-1] — context and noise share positions. - # At inference, the block positions should start from 0 (same as training blocks). + # Position IDs: training uses [0..L-1, 0..L-1] where noise positions + # mirror context positions. At inference, block predicts tokens at + # seq_len..seq_len+B-1, so noise positions continue from ctx_len. ctx_len = target_hidden.shape[1] ctx_positions = torch.arange(ctx_len, device=device) - block_positions = torch.arange(block_size, device=device) + block_positions = torch.arange(ctx_len, ctx_len + block_size, device=device) pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) # Attention mask: block sees ALL context + reverse-causal within block From fb7acab55609b7f980eb321cfaf8ff0331364c13 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 08:48:49 -0700 Subject: [PATCH 33/72] fix: remove attention mask at DFlash inference, matching SpecForge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SpecForge's spec_generate passes no attention mask — with KV cache, all positions attend freely. Our pseudo_speculative_generate should match: no mask at inference, only at training. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index ffa2ed3541..d98e42b0f1 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -664,7 +664,6 @@ def pseudo_speculative_generate(self, input_ids, steps=1): bsz = input_ids.shape[0] seq_len = input_ids.shape[1] device = input_ids.device - dtype = target_hidden.dtype # Block: first token is base_token (anchor), rest are mask block_ids = torch.full( @@ -681,21 +680,16 @@ def pseudo_speculative_generate(self, input_ids, steps=1): block_positions = torch.arange(ctx_len, ctx_len + block_size, device=device) pos_ids = torch.cat([ctx_positions, block_positions]).unsqueeze(0).expand(bsz, -1) - # Attention mask: block sees ALL context + reverse-causal within block - # Matching SpecForge training: j >= i (pos 0 sees all, pos B-1 sees only itself) - attn_mask = torch.zeros(1, 1, block_size, ctx_len + block_size, device=device, dtype=dtype) - block_indices = torch.arange(block_size, device=device) - reverse_causal = block_indices.unsqueeze(0) >= block_indices.unsqueeze(1) - noise_mask = torch.zeros(block_size, block_size, device=device, dtype=dtype) - noise_mask.masked_fill_(~reverse_causal, torch.finfo(dtype).min) - attn_mask[:, :, :, ctx_len:] = noise_mask + # No attention mask at inference — matching SpecForge's spec_generate + # which uses KV cache with no mask. All positions attend freely to + # context and each other within the block. # Draft forward draft_hidden = self.dflash_module( noise_embedding=noise_embedding, target_hidden=target_hidden, position_ids=pos_ids, - attention_mask=attn_mask, + attention_mask=None, ) # Logits on positions 1..block_size-1 (skip anchor at position 0) From 290670fceeaec00ace2e8ebabb54d5d9cc353370 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 10:28:45 -0700 Subject: [PATCH 34/72] add: standalone DFlash training script with SpecForge data pipeline train_dflash.py uses SpecForge's GeneralParser + chat template for data preprocessing (system prompt, offset-mapping loss mask) with ModelOpt's DFlash module for the draft model. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/train_dflash.py | 319 ++++++++++++++++++ 1 file changed, 319 insertions(+) create mode 100644 examples/speculative_decoding/train_dflash.py diff --git a/examples/speculative_decoding/train_dflash.py b/examples/speculative_decoding/train_dflash.py new file mode 100644 index 0000000000..20be8a85f0 --- /dev/null +++ b/examples/speculative_decoding/train_dflash.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Standalone DFlash training script using SpecForge's data pipeline. + +Uses SpecForge's tokenizer template + offset-mapping loss mask for data +preprocessing, and ModelOpt's DFlash module for the draft model. This +isolates data pipeline differences from model architecture differences. + +Usage: + torchrun --nproc_per_node=8 train_dflash.py \ + --model /path/to/Qwen3-8B \ + --data /path/to/train.jsonl \ + --chat-template qwen \ + --block-size 16 \ + --num-draft-layers 5 \ + --num-epochs 3 \ + --lr 1e-4 \ + --output-dir /path/to/output +""" + +import argparse +import math +import os + +import torch +import torch.distributed as dist +from datasets import load_dataset +from torch.utils.data import DataLoader, DistributedSampler +from transformers import AutoModelForCausalLM, AutoTokenizer + +import modelopt.torch.opt as mto +import modelopt.torch.speculative as mtsp + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="DFlash training with SpecForge data pipeline") + parser.add_argument("--model", type=str, required=True, help="Target model path") + parser.add_argument("--data", type=str, required=True, help="Training data JSONL path") + parser.add_argument("--chat-template", type=str, default="qwen", help="Chat template name") + parser.add_argument("--block-size", type=int, default=16) + parser.add_argument("--num-draft-layers", type=int, default=5) + parser.add_argument("--mask-token-id", type=int, default=None) + parser.add_argument("--max-length", type=int, default=512) + parser.add_argument("--num-epochs", type=int, default=3) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--warmup-ratio", type=float, default=0.01) + parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument("--save-interval", type=int, default=0, help="0 = save at end only") + parser.add_argument("--output-dir", type=str, required=True) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--num-ar-samples", type=int, default=20, help="AR validation samples") + return parser.parse_args() + + +def is_rank0(): + """Check if current process is rank 0.""" + return not dist.is_initialized() or dist.get_rank() == 0 + + +def print_rank0(msg): + """Print only on rank 0.""" + if is_rank0(): + print(msg, flush=True) + + +def build_dataset(tokenizer, data_path, chat_template_name, max_length): + """Build dataset using SpecForge's data pipeline. + + Uses SpecForge's GeneralParser to tokenize conversations with the + proper chat template and compute offset-mapping-based loss masks. + """ + from specforge.data.parse import GeneralParser + from specforge.data.template import TEMPLATE_REGISTRY + + template = TEMPLATE_REGISTRY.get(chat_template_name) + parser = GeneralParser(tokenizer, template) + + raw_dataset = load_dataset("json", data_files=data_path)["train"] + + processed = {"input_ids": [], "loss_mask": []} + skipped = 0 + for sample in raw_dataset: + convs = sample.get("conversations", sample.get("messages", [])) + if not convs: + skipped += 1 + continue + try: + input_ids, loss_mask = parser.parse(convs, max_length=max_length) + processed["input_ids"].append(input_ids) + processed["loss_mask"].append(loss_mask) + except Exception: + skipped += 1 + + print_rank0(f"Processed {len(processed['input_ids'])} samples, skipped {skipped}") + return processed + + +class DFlashDataset(torch.utils.data.Dataset): + """Simple dataset wrapping tokenized input_ids and loss_mask.""" + + def __init__(self, data): + self.input_ids = data["input_ids"] + self.loss_mask = data["loss_mask"] + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + return { + "input_ids": self.input_ids[idx], + "loss_mask": self.loss_mask[idx], + } + + +def collate_fn(batch): + """Collate batch of samples.""" + input_ids = torch.stack([b["input_ids"] for b in batch]) + loss_mask = torch.stack([b["loss_mask"] for b in batch]) + return {"input_ids": input_ids, "loss_mask": loss_mask} + + +def train(args): + """Main training loop.""" + # Init distributed + dist.init_process_group("nccl") + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + + torch.manual_seed(args.seed) + mto.enable_huggingface_checkpointing() + + # Load model + print_rank0(f"Loading model: {args.model}") + model = AutoModelForCausalLM.from_pretrained( + args.model, torch_dtype=torch.bfloat16, device_map={"": device}, trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + + # Detect mask_token_id + mask_token_id = args.mask_token_id + if mask_token_id is None: + if hasattr(tokenizer, "mask_token_id") and tokenizer.mask_token_id is not None: + mask_token_id = tokenizer.mask_token_id + elif hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None: + mask_token_id = tokenizer.pad_token_id + else: + mask_token_id = tokenizer.eos_token_id + print_rank0(f"mask_token_id: {mask_token_id}") + + # Convert to DFlash + config = { + "dflash_block_size": args.block_size, + "dflash_use_torch_compile": False, + "dflash_architecture_config": { + "num_hidden_layers": args.num_draft_layers, + "mask_token_id": mask_token_id, + }, + } + mtsp.convert(model, [("dflash", config)]) + print_rank0( + f"DFlash module created: {sum(p.numel() for p in model.dflash_module.parameters()):,} params" + ) + + # Build dataset using SpecForge pipeline + print_rank0("Building dataset with SpecForge pipeline...") + data = build_dataset(tokenizer, args.data, args.chat_template, args.max_length) + + # Filter samples with too few loss tokens + min_loss_tokens = 2 * args.block_size + filtered_ids = [] + filtered_masks = [] + for i in range(len(data["input_ids"])): + if data["loss_mask"][i].sum() >= min_loss_tokens: + filtered_ids.append(data["input_ids"][i]) + filtered_masks.append(data["loss_mask"][i]) + print_rank0(f"After filtering: {len(filtered_ids)} samples (min {min_loss_tokens} loss tokens)") + data = {"input_ids": filtered_ids, "loss_mask": filtered_masks} + + dataset = DFlashDataset(data) + sampler = DistributedSampler(dataset, shuffle=True) + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + sampler=sampler, + collate_fn=collate_fn, + num_workers=2, + pin_memory=True, + drop_last=True, + ) + + # Wrap with DDP + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[local_rank], + find_unused_parameters=True, + ) + raw_model = model.module + + # Optimizer — only train dflash_module + optimizer = torch.optim.AdamW( + [p for p in raw_model.dflash_module.parameters() if p.requires_grad], + lr=args.lr, + weight_decay=0.0, + ) + + # LR scheduler + steps_per_epoch = len(dataloader) + total_steps = args.num_epochs * steps_per_epoch + warmup_steps = int(total_steps * args.warmup_ratio) + + def lr_lambda(step): + if step < warmup_steps: + return step / max(warmup_steps, 1) + progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1) + return 0.5 * (1.0 + math.cos(math.pi * progress)) + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + print_rank0(f"Training: {total_steps} steps, {warmup_steps} warmup, {steps_per_epoch}/epoch") + + # Training loop + global_step = 0 + for epoch in range(args.num_epochs): + sampler.set_epoch(epoch) + model.train() + + for batch in dataloader: + input_ids = batch["input_ids"].to(device) + loss_mask = batch["loss_mask"].to(device) + + # Create labels from loss_mask: -100 for masked positions + labels = input_ids.clone() + labels[loss_mask == 0] = -100 + + output = model( + input_ids=input_ids, + attention_mask=torch.ones_like(input_ids), + labels=labels, + ) + + loss = output.loss + loss.backward() + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + global_step += 1 + + if global_step % args.log_interval == 0: + acc = output.train_acc[0][0] if hasattr(output, "train_acc") else 0.0 + lr = scheduler.get_last_lr()[0] + print_rank0( + f"Step {global_step} | loss={loss.item():.4f} | acc={acc:.4f} | lr={lr:.2e}" + ) + + if args.save_interval > 0 and global_step % args.save_interval == 0: + if is_rank0(): + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + raw_model.save_pretrained(save_path) + print_rank0(f"Saved checkpoint: {save_path}") + + # Save final model + if is_rank0(): + os.makedirs(args.output_dir, exist_ok=True) + raw_model.save_pretrained(args.output_dir) + print_rank0(f"Saved final model: {args.output_dir}") + + dist.barrier() + + # AR validation on rank 0 + if is_rank0() and args.num_ar_samples > 0: + print_rank0("\n=== AR Validation ===") + model.eval() + from modelopt.torch.speculative.plugins.transformers import HFARValidation + + validator = HFARValidation(raw_model, tokenizer) + ds = load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"] + + ars = [] + for i in range(min(args.num_ar_samples, len(ds))): + prompt = ds[i]["prompt"][0] + chat = [{"role": "user", "content": prompt}] + text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + inp = tokenizer(text, return_tensors="pt").input_ids.to(device) + try: + _, ar = validator.validate(osl=32, input_ids=inp, steps=3) + ars.append(ar) + print_rank0(f" AR={ar:.2f} | {prompt[:60]}") + except Exception as e: + print_rank0(f" ERROR | {prompt[:60]}... | {e}") + + if ars: + avg = sum(ars) / len(ars) + print_rank0("\n==== DFlash AR Results ====") + print_rank0(f"Average AR: {avg:.4f}") + print_rank0(f"Min: {min(ars):.4f}, Max: {max(ars):.4f}") + + dist.destroy_process_group() + + +if __name__ == "__main__": + args = parse_args() + train(args) From eb6a0c9cde794e28020e2a7309e6fcc220629f14 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 12:56:12 -0700 Subject: [PATCH 35/72] fix: create attention mask in f32 then cast, matching SpecForge SpecForge creates the mask in f32 (using f32 min = -3.4e38) then casts to bf16, which overflows to -inf. Our code used bf16 min directly (-3.39e38), a large finite negative. This caused different softmax behavior and draft output divergence (max diff 5.3). Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index d98e42b0f1..6374ef8a59 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -253,8 +253,12 @@ def create_dflash_attention_mask(seq_len, block_size, device, dtype): full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1) - full_mask = torch.zeros(seq_len, 2 * seq_len, device=device, dtype=dtype) - full_mask.masked_fill_(~full_mask_bool, torch.finfo(dtype).min) + # Create in f32 then cast, matching SpecForge. This ensures masked + # positions get -inf in bf16 (f32 min overflows to -inf when cast), + # not the largest finite negative bf16 value. + full_mask = torch.zeros(seq_len, 2 * seq_len, device=device, dtype=torch.float32) + full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min) + full_mask = full_mask.to(dtype=dtype) return full_mask.unsqueeze(0).unsqueeze(0) # [1, 1, L, 2L] From 2c853c1cf697ccf75afa3f5f5aa0b97a7068cbef Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 13:46:11 -0700 Subject: [PATCH 36/72] fix: use HF attention dispatch in DFlashAttention for SpecForge parity Replace direct F.scaled_dot_product_attention with HF's attention function dispatch (eager/sdpa/flash). This matches SpecForge's Qwen3DFlashAttention which uses ALL_ATTENTION_FUNCTIONS[impl]. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 87 ++++++++++++++----- 1 file changed, 66 insertions(+), 21 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 6374ef8a59..bff07239cb 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -107,7 +107,7 @@ def apply_rotary_pos_emb(q, k, cos, sin): class DFlashAttention(nn.Module): - """Attention with KV injection, matching SpecForge Qwen3DFlashAttention.""" + """Attention with KV injection, using HF's attention dispatch for exact SpecForge parity.""" def __init__(self, config, layer_idx): """Initialize DFlash attention with KV injection projections and QK-norm.""" @@ -119,33 +119,75 @@ def __init__(self, config, layer_idx): ) self.num_heads = config.num_attention_heads self.num_kv_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_kv_heads self.scaling = self.head_dim**-0.5 + self.attention_dropout = getattr(config, "attention_dropout", 0.0) self.is_causal = False - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + attn_bias = getattr(config, "attention_bias", False) + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=attn_bias) + self.k_proj = nn.Linear( + config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=attn_bias) - # QK norm (matches Qwen3DFlashAttention) self.q_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) + # Resolve HF attention function matching SpecForge's dispatch + self._attn_fn = None + self.sliding_window = None + + def _get_attn_fn(self): + """Lazily resolve the HF attention function.""" + if self._attn_fn is not None: + return self._attn_fn + try: + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + impl = getattr(self.config, "_attn_implementation", "eager") + if impl and impl != "eager" and impl in ALL_ATTENTION_FUNCTIONS: + self._attn_fn = ALL_ATTENTION_FUNCTIONS[impl] + else: + # Fall back to eager (manual matmul + softmax) + self._attn_fn = self._eager_attention + except (ImportError, AttributeError): + self._attn_fn = self._eager_attention + return self._attn_fn + + def _eager_attention(self, module, q, k, v, attention_mask, **kwargs): + """Eager attention matching HF's eager_attention_forward.""" + scaling = kwargs.get("scaling", self.scaling) + n_rep = self.num_key_value_groups + if n_rep > 1: + k = k.repeat_interleave(n_rep, dim=1) + v = v.repeat_interleave(n_rep, dim=1) + attn_weights = torch.matmul(q, k.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + q.dtype + ) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, None + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): """Forward with KV injection: Q from noise, K/V from context+noise.""" bsz, q_len, _ = hidden_states.shape ctx_len = target_hidden.shape[1] # Q from noise only, with QK-norm - q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim) q = self.q_norm(q).transpose(1, 2) # K from context + noise, with QK-norm k_ctx = self.k_proj(target_hidden) k_noise = self.k_proj(hidden_states) - k = torch.cat([k_ctx, k_noise], dim=1).view( - bsz, ctx_len + q_len, self.num_kv_heads, self.head_dim - ) + k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) k = self.k_norm(k).transpose(1, 2) # V from context + noise (no norm) @@ -153,24 +195,27 @@ def forward(self, hidden_states, target_hidden, position_embeddings, attention_m v_noise = self.v_proj(hidden_states) v = ( torch.cat([v_ctx, v_noise], dim=1) - .view(bsz, ctx_len + q_len, self.num_kv_heads, self.head_dim) + .view(bsz, ctx_len + q_len, -1, self.head_dim) .transpose(1, 2) ) - # RoPE: applied to full 2L positions, Q gets last q_len, K gets all + # RoPE cos, sin = position_embeddings q, k = apply_rotary_pos_emb(q, k, cos, sin) - # GQA expand - if self.num_kv_heads != self.num_heads: - n_rep = self.num_heads // self.num_kv_heads - k = k.repeat_interleave(n_rep, dim=1) - v = v.repeat_interleave(n_rep, dim=1) - - attn_output = F.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask, is_causal=False, scale=self.scaling + # Use HF's attention dispatch (handles GQA internally) + attn_fn = self._get_attn_fn() + attn_output, _ = attn_fn( + self, + q, + k, + v, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, ) - attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, q_len, -1) return self.o_proj(attn_output) From d6adadb24b94fe89fcdafca6bc62b3bb58c1d43f Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 14:11:06 -0700 Subject: [PATCH 37/72] fix: default DFlash attention to sdpa matching SpecForge SpecForge's DFlashDraftModel extends Qwen3PreTrainedModel, which resolves _attn_implementation to 'sdpa' via post_init(). We were forcing 'eager', causing different attention computation paths (eager does fp32 softmax, sdpa uses fused kernels). This was the root cause of the 5.3 max diff in decoder outputs despite identical inputs. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index bff07239cb..fde99e3fec 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -456,8 +456,10 @@ def modify(self, config): self.dflash_config.hidden_size // self.dflash_config.num_attention_heads, ) self.dflash_config.block_size = self.dflash_block_size + # Default to sdpa, matching SpecForge's DFlashDraftModel(Qwen3PreTrainedModel) + # which resolves to sdpa via post_init() if self.dflash_config._attn_implementation is None: - self.dflash_config._attn_implementation = "eager" + self.dflash_config._attn_implementation = "sdpa" # Target layer IDs num_target_layers = base_config.num_hidden_layers From 65df160077ca62a2321f2da0aefb0a72917f7283 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 15:08:56 -0700 Subject: [PATCH 38/72] fix: initialize DFlash weights with normal_(std=0.02) matching SpecForge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SpecForge's DFlashDraftModel extends Qwen3PreTrainedModel which calls post_init() → _init_weights(), initializing all Linear layers with normal_(mean=0, std=initializer_range=0.02). Our DFlashModule used PyTorch's default kaiming_uniform_ which has a much larger range, leading to different training dynamics and convergence. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index fde99e3fec..f8869f6de7 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -267,6 +267,19 @@ def __init__(self, config): self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = _ROTARY_CLS(config=config) + # Initialize weights matching HF PreTrainedModel (normal_ with initializer_range) + # SpecForge's DFlashDraftModel uses Qwen3PreTrainedModel.post_init() which does this. + self._init_weights(config) + + def _init_weights(self, config): + """Initialize weights matching HF PreTrainedModel._init_weights.""" + std = getattr(config, "initializer_range", 0.02) + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + nn.init.zeros_(module.bias) + def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None): """Forward matching SpecForge DFlashDraftModel.forward.""" hidden_states = noise_embedding From 4451101447d686f062f3ac4d6a5e5ee4a4703f37 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 15:21:17 -0700 Subject: [PATCH 39/72] debug: add attn_fn resolution and per-layer comparison prints Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index f8869f6de7..d5e3e1a02a 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -151,11 +151,13 @@ def _get_attn_fn(self): impl = getattr(self.config, "_attn_implementation", "eager") if impl and impl != "eager" and impl in ALL_ATTENTION_FUNCTIONS: self._attn_fn = ALL_ATTENTION_FUNCTIONS[impl] + print(f"[DFlash] attn_fn resolved to: {impl} -> {self._attn_fn.__name__}") else: - # Fall back to eager (manual matmul + softmax) self._attn_fn = self._eager_attention - except (ImportError, AttributeError): + print(f"[DFlash] attn_fn fallback to eager (impl={impl})") + except (ImportError, AttributeError) as e: self._attn_fn = self._eager_attention + print(f"[DFlash] attn_fn fallback to eager (error: {e})") return self._attn_fn def _eager_attention(self, module, q, k, v, attention_mask, **kwargs): From 27260680d9df7bf21bc7743e84e1da9078c8e0fb Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 17:20:57 -0700 Subject: [PATCH 40/72] feat: update DFlash training to match SpecForge latest (post-PR #473) Major rewrite of the training forward to match SpecForge's latest: - Random anchor sampling (#463) - Bidirectional intra-block attention (#427) - Label alignment with anchor positions (#473) - Loss decay weighting (#463) - New config: dflash_num_anchors, dflash_loss_decay_gamma Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/main.py | 13 ++ modelopt/torch/speculative/config.py | 9 +- .../torch/speculative/dflash/dflash_model.py | 1 + .../torch/speculative/plugins/hf_dflash.py | 217 +++++++++++++----- 4 files changed, 179 insertions(+), 61 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 650ff44336..8d3f5c1c0b 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -168,6 +168,17 @@ class DFlashArguments: "Enables training with data not synthesized by the target model." }, ) + dflash_num_anchors: int = field( + default=512, + metadata={"help": "Number of random anchor positions per sequence during training."}, + ) + dflash_loss_decay_gamma: float = field( + default=0.0, + metadata={ + "help": "Gamma for loss decay weighting (paper Eq.4). " + "Suggested: 7 for block_size=16. 0 disables." + }, + ) def train(): @@ -275,6 +286,8 @@ def train(): "dflash_block_size": dflash_args.dflash_block_size, "dflash_use_torch_compile": not dflash_args.dflash_disable_torch_compile, "dflash_self_logit_distillation": dflash_args.dflash_use_logit_distillation, + "dflash_num_anchors": dflash_args.dflash_num_anchors, + "dflash_loss_decay_factor": dflash_args.dflash_loss_decay_gamma, "dflash_architecture_config": custom_config, } diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 59aa98db4b..5202865efb 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -77,7 +77,14 @@ class DFlashConfig(ModeloptBaseConfig): ) dflash_loss_decay_factor: float = ModeloptField( - default=0.9, description="Decay factor for per-block loss weighting." + default=0.0, + description="Gamma for exponential loss decay weighting (paper Eq.4). " + "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables.", + ) + + dflash_num_anchors: int = ModeloptField( + default=512, + description="Number of random anchor positions sampled per sequence during training.", ) dflash_report_acc: bool = ModeloptField( diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py index e44b17b505..0a10f065eb 100644 --- a/modelopt/torch/speculative/dflash/dflash_model.py +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -31,5 +31,6 @@ def modify(self, config): self.dflash_freeze_base_model = config.dflash_freeze_base_model self.dflash_loss_decay_factor = config.dflash_loss_decay_factor self.dflash_self_logit_distillation = config.dflash_self_logit_distillation + self.dflash_num_anchors = config.dflash_num_anchors self.dflash_report_acc = config.dflash_report_acc self.dflash_use_torch_compile = config.dflash_use_torch_compile diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index d5e3e1a02a..47efbd3eb9 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -516,6 +516,7 @@ def modify(self, config): ) self.is_quantized = False + self._num_anchors = self.dflash_num_anchors # Store bound reference to the original model class's forward. # DynamicModule changes type(self) but the original class is in _original_cls. @@ -549,6 +550,42 @@ def _base_forward(self, **kwargs): """Call the original model's forward, bypassing DFlash wrapper.""" return self._original_forward_cls.forward(self, **kwargs) + def _sample_anchor_positions(self, seq_len, loss_mask, device): + """Randomly sample anchor positions per sample, matching SpecForge PR #473. + + Returns (anchor_positions [B, N], block_keep_mask [B, N]). + """ + bs = self.dflash_block_size + bsz = loss_mask.shape[0] + max_anchor = max(seq_len - bs, 0) + num_anchors = getattr(self, "_num_anchors", 512) + + valid = loss_mask[:, : max_anchor + 1] > 0.5 + valid_counts = valid.sum(dim=1) + max_n = min(num_anchors, int(valid_counts.max().item()) - 1) + + if max_n <= 0: + # No valid anchors — return empty + anchors = torch.zeros(bsz, 1, dtype=torch.long, device=device) + keep = torch.zeros(bsz, 1, dtype=torch.bool, device=device) + return anchors, keep + + indices = torch.arange(max_anchor + 1, device=device).unsqueeze(0).expand(bsz, -1) + masked_indices = torch.where(valid, indices, torch.tensor(seq_len + 1, device=device)) + + random_vals = torch.rand(bsz, max_anchor + 1, device=device) + random_vals = torch.where(valid, random_vals, torch.tensor(2.0, device=device)) + + _, sorted_idx = random_vals.sort(dim=1) + gathered = torch.gather(masked_indices, 1, sorted_idx) + anchors = gathered[:, :max_n].sort(dim=1).values + + keep = torch.arange(max_n, device=device).unsqueeze(0) < valid_counts.unsqueeze(1).clamp( + max=max_n + ) + anchors = torch.where(keep, anchors, torch.tensor(0, dtype=torch.long, device=device)) + return anchors, keep + def forward( self, input_ids=None, @@ -563,7 +600,15 @@ def forward( cache_position=None, **kwargs, ): - """Training forward matching SpecForge OnlineDFlashModel.forward.""" + """Training forward matching SpecForge latest (post-PR #473). + + Key changes from original PR #415: + - Random anchor sampling instead of uniform block division + - Bidirectional intra-block attention (no causal constraint) + - Context sees strictly before anchor position + - Label alignment: position k predicts token at anchor+k + - Optional loss decay weighting + """ if not self.training: return super().forward( input_ids=input_ids, @@ -583,9 +628,7 @@ def forward( block_size = self.dflash_block_size device = input_ids.device - # 1. Run base model → raw multi-layer hidden states - # Use super().forward() which goes through DynamicModule → original model - # (same pattern as EAGLE's HFEagleModel) + # 1. Run base model → hidden states with torch.no_grad(): base_outputs = super().forward( input_ids=input_ids, @@ -593,83 +636,137 @@ def forward( output_hidden_states=True, ) - # Extract and concatenate target layer hidden states offset = 1 selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] - # 2. Truncate to multiple of block_size - n_blocks = seq_len // block_size - effective_len = n_blocks * block_size - input_ids_trunc = input_ids[:, :effective_len] - target_hidden = target_hidden[:, :effective_len, :] - # Loss mask: use labels (response-only) if available, else attention_mask (padding) + # 2. Build loss mask from labels or attention_mask if labels is not None: - # labels == -100 means "ignore" (system/user tokens when answer_only_loss=True) - loss_mask_input = (labels[:, :effective_len] != -100).float() + loss_mask = (labels != -100).float() elif attention_mask is not None: - loss_mask_input = attention_mask[:, :effective_len].float() + loss_mask = attention_mask.float() else: - loss_mask_input = torch.ones(bsz, effective_len, device=device) + loss_mask = torch.ones(bsz, seq_len, device=device) + + # 3. Random anchor sampling (SpecForge PR #463/#473) + anchor_positions, block_keep_mask = self._sample_anchor_positions( + seq_len, loss_mask, device + ) + n_blocks = anchor_positions.shape[1] + + if n_blocks == 0 or not block_keep_mask.any(): + loss = ( + self._base_model_lm_head(target_hidden[:, :1, : self.config.hidden_size]).sum() + * 0.0 + ) + return ModelOutput(loss=loss, logits=base_outputs.logits, train_acc=[[0.0]]) + + # 4. Create noise embeddings: anchor token at block start, mask_token elsewhere + noise_ids = torch.full( + (bsz, n_blocks * block_size), self.mask_token_id, dtype=torch.long, device=device + ) + block_starts = torch.arange(n_blocks, device=device) * block_size + block_starts_exp = block_starts.unsqueeze(0).expand(bsz, -1) + valid_anchors = anchor_positions.clamp(0, seq_len - 1) + anchor_tokens = torch.gather(input_ids, 1, valid_anchors) + batch_idx = torch.arange(bsz, device=device).unsqueeze(1).expand(bsz, n_blocks) + noise_ids[batch_idx, block_starts_exp] = torch.where( + block_keep_mask, + anchor_tokens, + torch.tensor(self.mask_token_id, dtype=torch.long, device=device), + ) + noise_embedding = self._base_model_embeddings(noise_ids) + + # 5. Position IDs: context [0..S-1], draft blocks [anchor+0..anchor+B-1] + ctx_pos = torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) + offsets = torch.arange(block_size, device=device).view(1, 1, -1) + draft_pos = (anchor_positions.unsqueeze(-1) + offsets).view(bsz, -1) + full_pos = torch.cat([ctx_pos, draft_pos], dim=1) + + # 6. Attention mask: SDPA bool mask [B, 1, Q_LEN, KV_LEN] + q_len = n_blocks * block_size + kv_len = seq_len + q_len + + q_indices = torch.arange(q_len, device=device).view(1, 1, -1, 1) + kv_indices = torch.arange(kv_len, device=device).view(1, 1, 1, -1) + q_block_ids = q_indices // block_size + + anchor_exp = anchor_positions.view(bsz, 1, n_blocks, 1).repeat_interleave(block_size, dim=2) - # 3. Prepare noise: mask_token_id everywhere, real token at block starts - positions = torch.arange(effective_len, device=device) - is_block_start = (positions % block_size) == 0 - noise_input_ids = torch.full_like(input_ids_trunc, self.mask_token_id) - noise_input_ids[:, is_block_start] = input_ids_trunc[:, is_block_start] - noise_embedding = self._base_model_embeddings(noise_input_ids) + # Context: kv < S and kv < anchor + mask_ctx = (kv_indices < seq_len) & (kv_indices < anchor_exp) + # Draft: kv >= S and same block + is_draft = kv_indices >= seq_len + kv_block_ids = (kv_indices - seq_len) // block_size + mask_draft = is_draft & (q_block_ids == kv_block_ids) + # Valid block + valid_block = block_keep_mask.view(bsz, 1, n_blocks, 1).repeat_interleave(block_size, dim=2) - # 4. Position IDs: [0..L-1, 0..L-1] - pos_seq = torch.arange(effective_len, device=device) - position_ids_2l = torch.cat([pos_seq, pos_seq]).unsqueeze(0).expand(bsz, -1) + final_mask = ((mask_ctx | mask_draft) & valid_block).unsqueeze(1) # [B, 1, Q, KV] - # 5. Attention mask: [1, 1, L, 2L] + # Convert bool mask to float additive mask for SDPA dtype = target_hidden.dtype - dflash_attn_mask = create_dflash_attention_mask(effective_len, block_size, device, dtype) + attn_mask = torch.zeros(bsz, 1, q_len, kv_len, device=device, dtype=dtype) + attn_mask.masked_fill_(~final_mask, torch.finfo(torch.float32).min) + attn_mask = attn_mask.to(dtype) - # 6. Draft forward + # 7. Draft forward hidden = self.dflash_module( noise_embedding=noise_embedding, target_hidden=target_hidden, - position_ids=position_ids_2l, - attention_mask=dflash_attn_mask, + position_ids=full_pos, + attention_mask=attn_mask, ) - # 7. Loss computation + # 8. Loss: same-position prediction (position k predicts token at anchor+k) logits = self._base_model_lm_head(hidden) - dflash_loss_mask = create_dflash_loss_mask(effective_len, block_size, device) - combined_mask = loss_mask_input * dflash_loss_mask.unsqueeze(0) - - logits_flat = logits.reshape(-1, logits.size(-1)) - labels_flat = input_ids_trunc.reshape(-1) - mask_flat = combined_mask.reshape(-1) - - active_indices = mask_flat > 0.5 - active_logits = logits_flat[active_indices] - active_labels = labels_flat[active_indices] - - if active_logits.numel() > 0: - if self.dflash_self_logit_distillation: - # Logit distillation: learn from target model's output distribution - # This works regardless of whether training data matches the target model - base_logits_trunc = base_outputs.logits[:, :effective_len, :] - base_logits_flat = base_logits_trunc.reshape(-1, base_logits_trunc.size(-1)) - active_base_logits = base_logits_flat[active_indices].detach() - target_soft = torch.softmax(active_base_logits, dim=-1) - draft_logsoft = torch.log_softmax(active_logits, dim=-1) - loss = -(target_soft * draft_logsoft).sum(dim=-1).mean() - else: - # Hard CE: predict ground truth tokens directly - # Only works well when training data is synthesized by the target model - loss = F.cross_entropy(active_logits, active_labels) + + label_offsets = torch.arange(0, block_size, device=device).view(1, 1, -1) + label_indices = anchor_positions.unsqueeze(-1) + label_offsets + valid_label = label_indices < seq_len + safe_label_indices = label_indices.clamp(max=seq_len - 1) + + target_ids = torch.gather( + input_ids.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + + # Weight mask: valid block * in bounds * exclude anchor (pos 0) * loss_mask + weight_mask = block_keep_mask.unsqueeze(-1).expand(-1, -1, block_size).float() + weight_mask = weight_mask * valid_label.float() + pos_in_block = torch.arange(block_size, device=device).view(1, 1, -1) + weight_mask = weight_mask * (pos_in_block > 0).float() + + orig_loss_mask = torch.gather( + loss_mask.unsqueeze(1).expand(-1, n_blocks, -1), 2, safe_label_indices + ) + weight_mask = weight_mask * orig_loss_mask + + binary_eval_mask = weight_mask.view(-1) + + # Optional loss decay + if self.dflash_loss_decay_factor > 0: + k = torch.arange(block_size, device=device).view(1, 1, -1) + decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor) + weight_mask = weight_mask * decay + + # Cross entropy + flat_logits = logits.view(-1, logits.size(-1)) + flat_targets = target_ids.view(-1) + flat_weights = weight_mask.view(-1) + + valid_count = flat_weights.sum() + 1e-6 + + if valid_count > 1.0: + loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none") + loss = (loss_per_token * flat_weights).sum() / valid_count with torch.no_grad(): - preds = active_logits.argmax(dim=-1) - accuracy = (preds == active_labels).float().mean().item() + preds = flat_logits.argmax(dim=-1) + correct = (preds == flat_targets) & (binary_eval_mask > 0.5) + accuracy = correct.sum().float() / (binary_eval_mask.sum() + 1e-6) + accuracy = accuracy.item() else: - # No valid positions — compute a zero loss that still flows through - # dflash_module parameters to keep DDP gradient sync happy - loss = logits.sum() * 0.0 + loss = flat_logits.sum() * 0.0 accuracy = 0.0 return ModelOutput( From 3516c0b1f7763fd6d6b02fca3bbfa55fb2285400 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 19:31:13 -0700 Subject: [PATCH 41/72] fix: remove extra unsqueeze in DFlash training attention mask The mask tensors are already 4D [B, 1, Q, KV] from the broadcast. The extra unsqueeze(1) created 5D [B, 1, 1, Q, KV] causing shape mismatch at runtime. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 47efbd3eb9..ca88136d0b 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -702,7 +702,7 @@ def forward( # Valid block valid_block = block_keep_mask.view(bsz, 1, n_blocks, 1).repeat_interleave(block_size, dim=2) - final_mask = ((mask_ctx | mask_draft) & valid_block).unsqueeze(1) # [B, 1, Q, KV] + final_mask = (mask_ctx | mask_draft) & valid_block # [B, 1, Q, KV] # Convert bool mask to float additive mask for SDPA dtype = target_hidden.dtype From 606e31d567659601aa17a994485a7e28f15f7dde Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Thu, 2 Apr 2026 19:47:47 -0700 Subject: [PATCH 42/72] fix: create training attention mask in f32 to avoid bf16 overflow torch.finfo(torch.float32).min overflows when masked_fill_ targets a bf16 tensor. Create in f32 first, then cast to bf16. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index ca88136d0b..448929c1e4 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -706,9 +706,9 @@ def forward( # Convert bool mask to float additive mask for SDPA dtype = target_hidden.dtype - attn_mask = torch.zeros(bsz, 1, q_len, kv_len, device=device, dtype=dtype) + attn_mask = torch.zeros(bsz, 1, q_len, kv_len, device=device, dtype=torch.float32) attn_mask.masked_fill_(~final_mask, torch.finfo(torch.float32).min) - attn_mask = attn_mask.to(dtype) + attn_mask = attn_mask.to(dtype=dtype) # 7. Draft forward hidden = self.dflash_module( From e1237f7f023ddfb9c5ba71300a089c1e348ab414 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 07:51:32 -0700 Subject: [PATCH 43/72] fix: add dflash_num_anchors/loss_decay_gamma to launch_train.sh Also fix AR validation to handle both ModelOpt (prefixed) and SpecForge (no prefix) weight formats, and upgrade transformers for Qwen3.5 support. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/launch_train.sh | 14 ++++++++++++++ tools/launcher/common/dflash/ar_validate.sh | 13 +++++++++++-- tools/launcher/common/dflash/online_training.sh | 9 +++++++-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 7beb567425..504f1b8b60 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -150,6 +150,14 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi DFLASH_MASK_TOKEN_ID="${1#*=}" ;; + --dflash_num_anchors*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_NUM_ANCHORS="${1#*=}" + ;; + --dflash_loss_decay_gamma*) + if [[ "$1" != *=* ]]; then shift; fi + DFLASH_LOSS_DECAY_GAMMA="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -221,6 +229,12 @@ elif [[ "$MODE" == "dflash" ]]; then if [[ -n "$DFLASH_MASK_TOKEN_ID" ]]; then SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_mask_token_id $DFLASH_MASK_TOKEN_ID" fi + if [[ -n "$DFLASH_NUM_ANCHORS" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_num_anchors $DFLASH_NUM_ANCHORS" + fi + if [[ -n "$DFLASH_LOSS_DECAY_GAMMA" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_loss_decay_gamma $DFLASH_LOSS_DECAY_GAMMA" + fi # DFlash uses DDP instead of FSDP FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 1800" else diff --git a/tools/launcher/common/dflash/ar_validate.sh b/tools/launcher/common/dflash/ar_validate.sh index 01ad61ffd1..b9df0b5c6f 100644 --- a/tools/launcher/common/dflash/ar_validate.sh +++ b/tools/launcher/common/dflash/ar_validate.sh @@ -30,6 +30,8 @@ SCRIPT_DIR="$(dirname "$(readlink -f "$0")")" source ${SCRIPT_DIR}/../service_utils.sh trap 'error_handler $0 $LINENO' ERR +pip install --upgrade "transformers>=4.57" 2>&1 | tail -3 + DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16} DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5} NUM_SAMPLES=${NUM_SAMPLES:-20} @@ -81,9 +83,16 @@ if ckpt_files: state = {} for f in ckpt_files: state.update(load_file(f)) + # Try with dflash_module prefix first (ModelOpt format) dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} - model.load_state_dict(dflash_keys, strict=False) - print(f'Loaded {len(dflash_keys)} DFlash weights from {len(ckpt_files)} file(s)') + if dflash_keys: + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights (with prefix)') + else: + # No prefix — SpecForge format, load directly into dflash_module + result = model.dflash_module.load_state_dict(state, strict=False) + loaded = len(state) - len(result.unexpected_keys) + print(f'Loaded {loaded} DFlash weights (no prefix), missing={len(result.missing_keys)}') else: print('WARNING: No checkpoint files found, using random weights') diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh index 0cdb6a906d..8b244e6e50 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/dflash/online_training.sh @@ -125,8 +125,13 @@ if ckpt_files: for f in ckpt_files: state.update(load_file(f)) dflash_keys = {k: v for k, v in state.items() if 'dflash_module' in k} - model.load_state_dict(dflash_keys, strict=False) - print(f'Loaded {len(dflash_keys)} DFlash weights from {len(ckpt_files)} file(s)') + if dflash_keys: + model.load_state_dict(dflash_keys, strict=False) + print(f'Loaded {len(dflash_keys)} DFlash weights (with prefix)') + else: + result = model.dflash_module.load_state_dict(state, strict=False) + loaded = len(state) - len(result.unexpected_keys) + print(f'Loaded {loaded} DFlash weights (no prefix), missing={len(result.missing_keys)}') else: print('WARNING: No checkpoint files found, using random weights') From b8e5eb75c94274f548a0dd1c9f05da3afcb8113b Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 08:57:43 -0700 Subject: [PATCH 44/72] feat: add logit distillation to new random-anchor DFlash training Gather teacher logits at anchor+offset positions from base model output and use KL divergence as loss. Enabled via --dflash_use_logit_distillation flag. Also fix launch_train.sh to pass dflash_num_anchors and dflash_loss_decay_gamma, and fix AR validation to handle SpecForge weight format. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 448929c1e4..20fa9bea05 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -749,7 +749,7 @@ def forward( decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor) weight_mask = weight_mask * decay - # Cross entropy + # Cross entropy or logit distillation flat_logits = logits.view(-1, logits.size(-1)) flat_targets = target_ids.view(-1) flat_weights = weight_mask.view(-1) @@ -757,8 +757,22 @@ def forward( valid_count = flat_weights.sum() + 1e-6 if valid_count > 1.0: - loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none") - loss = (loss_per_token * flat_weights).sum() / valid_count + if self.dflash_self_logit_distillation: + # Gather teacher logits at anchor+offset positions + base_logits = base_outputs.logits # [B, seq, vocab] + teacher_logits = torch.gather( + base_logits.unsqueeze(1).expand(-1, n_blocks, -1, -1), + 2, + safe_label_indices.unsqueeze(-1).expand(-1, -1, -1, base_logits.size(-1)), + ) # [B, N, block_size, vocab] + flat_teacher = teacher_logits.reshape(-1, base_logits.size(-1)).detach() + target_soft = torch.softmax(flat_teacher, dim=-1) + draft_logsoft = torch.log_softmax(flat_logits, dim=-1) + kd_loss = -(target_soft * draft_logsoft).sum(dim=-1) + loss = (kd_loss * flat_weights).sum() / valid_count + else: + loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none") + loss = (loss_per_token * flat_weights).sum() / valid_count with torch.no_grad(): preds = flat_logits.argmax(dim=-1) From b0df28cc608ea6689757f26dbf99b0e79afd1a39 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 08:58:38 -0700 Subject: [PATCH 45/72] fix: add dflash_use_logit_distillation to launch_train.sh Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/launch_train.sh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 504f1b8b60..8dc9ab216e 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -158,6 +158,9 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi DFLASH_LOSS_DECAY_GAMMA="${1#*=}" ;; + --dflash_use_logit_distillation*) + DFLASH_USE_LOGIT_DISTILLATION="True" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -235,6 +238,9 @@ elif [[ "$MODE" == "dflash" ]]; then if [[ -n "$DFLASH_LOSS_DECAY_GAMMA" ]]; then SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_loss_decay_gamma $DFLASH_LOSS_DECAY_GAMMA" fi + if [[ "$DFLASH_USE_LOGIT_DISTILLATION" == "True" ]]; then + SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_use_logit_distillation" + fi # DFlash uses DDP instead of FSDP FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 1800" else From 4226349671f9948de54d039edfd5c8bf54034f07 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 09:14:32 -0700 Subject: [PATCH 46/72] fix: shift teacher logits by -1 for DFlash logit distillation Base model logits at position p predict token p+1 (autoregressive), but DFlash draft at position k predicts token anchor+k (same position). Teacher logits for KD should be gathered at anchor+k-1, not anchor+k. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 20fa9bea05..e8690266f6 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -758,12 +758,15 @@ def forward( if valid_count > 1.0: if self.dflash_self_logit_distillation: - # Gather teacher logits at anchor+offset positions + # Teacher logits at position p predict token p+1 (autoregressive). + # Draft position k predicts token at anchor+k (same position). + # So teacher logits for token anchor+k are at position anchor+k-1. base_logits = base_outputs.logits # [B, seq, vocab] + teacher_indices = (safe_label_indices - 1).clamp(min=0) teacher_logits = torch.gather( base_logits.unsqueeze(1).expand(-1, n_blocks, -1, -1), 2, - safe_label_indices.unsqueeze(-1).expand(-1, -1, -1, base_logits.size(-1)), + teacher_indices.unsqueeze(-1).expand(-1, -1, -1, base_logits.size(-1)), ) # [B, N, block_size, vocab] flat_teacher = teacher_logits.reshape(-1, base_logits.size(-1)).detach() target_soft = torch.softmax(flat_teacher, dim=-1) From 818eb74ee3d4685a07efe61887d20bcbc59b9430 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 09:43:16 -0700 Subject: [PATCH 47/72] fix: mask all tokens when assistant pattern not found Previously kept all tokens as valid when regex couldn't find assistant spans, causing training on system/user tokens and inflated per-token accuracy (70% vs SpecForge's 15%). Now masks everything, so these samples contribute zero loss. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/utils/plugins/transformers_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index 7aecbbfb18..f9fbeea65a 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -259,8 +259,9 @@ def _apply_answer_only_labels(self, examples, labels, input_ids): break if not found: - # No assistant pattern found — keep all labels (don't mask) - pass + # No assistant pattern found — mask all labels to avoid + # training on system/user tokens which inflates accuracy + labels[batch_idx, :] = IGNORE_TOKEN_ID return labels From 49038c7da6d778a9a18d68a57fdfe921d3b3c129 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 10:24:25 -0700 Subject: [PATCH 48/72] feat: auto-inject generation tags for reliable answer_only_loss Replace complex chat templates with simplified versions that include {% generation %} tags when answer_only_loss=True. Supports ChatML (Qwen, Phi) and Llama3 template styles. This fixes the inflated per-token accuracy (70% vs SpecForge's 15%) caused by the regex fallback silently training on system/user tokens when {% generation %} tags were missing. The simplified templates correctly: - Mark only assistant content for loss (including blocks) - Support multi-turn conversations - Mask system and user tokens Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../utils/plugins/transformers_dataset.py | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index f9fbeea65a..acc02dee50 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -153,6 +153,9 @@ def __init__( if self.tokenizer.chat_template is None: raise ValueError("No valid chat template!") + if self.answer_only_loss: + self._ensure_generation_tags() + def _post_process_tokenizer(self): if self.tokenizer.pad_token_id is None: print_rank_0("The tokenizer has no pad_token_id, using eos_token_id instead.") @@ -171,6 +174,102 @@ def _post_process_chat_template(self): REMOVE_THINK_CHAT_TEMPLATE, "" ) + # Simplified chat templates with {% generation %} tags for answer_only_loss. + # These drop complex features (tool_calls, thinking parsing) but correctly mark + # assistant content for loss masking. The training data should already contain + # the full assistant response including blocks. + _GENERATION_TEMPLATES = { + "chatml": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}{% generation %}{{ message['content'] }}" + "{% endgeneration %}{{ '<|im_end|>\n' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + ), + "llama3": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|start_header_id|>system<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% generation %}" + "{{ message['content'] }}{% endgeneration %}{{ '<|eot_id|>' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}" + ), + } + + def _ensure_generation_tags(self): + """Ensure chat template has {% generation %} tags for answer_only_loss. + + If the template already has generation tags, no action taken. + Otherwise, detect the template style and replace with a simplified + version that includes proper generation tags. + """ + template = self.tokenizer.chat_template + if template is None: + return + + if "{% generation %}" in template or "{%generation%}" in template: + return + + # Detect template style and replace with generation-tagged version + old_template = template + if "<|im_start|>" in template and "<|im_end|>" in template: + style = "chatml" + elif "<|start_header_id|>" in template and "<|eot_id|>" in template: + style = "llama3" + else: + print_rank_0( + "WARNING: Cannot auto-inject {% generation %} tags for this chat template. " + "answer_only_loss will use regex fallback. Consider providing a template " + "with {% generation %} tags via the chat_template parameter." + ) + return + + new_template = self._GENERATION_TEMPLATES[style] + self.tokenizer.chat_template = new_template + + # Verify + try: + test_msgs = [ + [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + ] + result = self.tokenizer.apply_chat_template( + test_msgs, + return_dict=True, + return_assistant_tokens_mask=True, + padding=True, + return_tensors="pt", + ) + mask = result.get("assistant_masks", None) + if mask is not None and mask.any(): + print_rank_0( + f"Replaced chat template with {style} generation-tagged version " + f"for answer_only_loss." + ) + return + except Exception: + pass + + # Revert on failure + self.tokenizer.chat_template = old_template + print_rank_0( + f"WARNING: Failed to apply {style} generation template. " + "Using regex fallback for answer_only_loss." + ) + def _process_chat_sample(self, examples: list): tokenized_examples = self.tokenizer.apply_chat_template( examples, From c49f6d91ead1f4a7803a2005cbb612cb17892972 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 10:37:54 -0700 Subject: [PATCH 49/72] fix: add wrapper to simplified ChatML template for Qwen3 Qwen3's original template auto-injects \n\n\n\n before assistant content. Match this in our simplified template by adding the think wrapper when content doesn't already start with . Minor difference from original: we add it to all assistant turns, while Qwen3 only adds to the last turn. This doesn't affect training. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/utils/plugins/transformers_dataset.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index acc02dee50..e4bbf730f9 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -186,8 +186,14 @@ def _post_process_chat_template(self): "{% elif message['role'] == 'user' %}" "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" "{% elif message['role'] == 'assistant' %}" - "{{ '<|im_start|>assistant\n' }}{% generation %}{{ message['content'] }}" - "{% endgeneration %}{{ '<|im_end|>\n' }}" + "{{ '<|im_start|>assistant\n' }}" + "{% generation %}" + "{% if not message['content'].startswith('') %}" + "{{ '\n\n\n\n' }}" + "{% endif %}" + "{{ message['content'] }}" + "{% endgeneration %}" + "{{ '<|im_end|>\n' }}" "{% endif %}" "{% endfor %}" "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" From ae2e7bd5e99ebbff835190f3b1bd77f360905401 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 11:04:36 -0700 Subject: [PATCH 50/72] fix: remove think wrapper from simplified ChatML template MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The original Qwen3 template adds \n\n\n\n only to the last assistant turn, not all turns. Rather than replicating this complex logic, keep the simplified template clean — just output message content as-is. Training data already contains blocks when present. Llama3 template has no think logic at all. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/utils/plugins/transformers_dataset.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index e4bbf730f9..e806003c4b 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -188,9 +188,6 @@ def _post_process_chat_template(self): "{% elif message['role'] == 'assistant' %}" "{{ '<|im_start|>assistant\n' }}" "{% generation %}" - "{% if not message['content'].startswith('') %}" - "{{ '\n\n\n\n' }}" - "{% endif %}" "{{ message['content'] }}" "{% endgeneration %}" "{{ '<|im_end|>\n' }}" From 82eedb202f7868c19a56a3f24072719072e2faa2 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 11:10:53 -0700 Subject: [PATCH 51/72] feat: add chatml_think template variant for Qwen3 think injection Qwen3's original template adds \n\n\n\n to the last assistant turn when content doesn't start with . Detect this by checking if '' appears in the original template and use the chatml_think variant which replicates this behavior exactly. Models without think logic (Llama3, basic ChatML) use the plain chatml template. All three samples now match the original tokenization. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../utils/plugins/transformers_dataset.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index e806003c4b..95830d9f98 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -179,6 +179,7 @@ def _post_process_chat_template(self): # assistant content for loss masking. The training data should already contain # the full assistant response including blocks. _GENERATION_TEMPLATES = { + # Basic ChatML without injection "chatml": ( "{% for message in messages %}" "{% if message['role'] == 'system' %}" @@ -195,6 +196,26 @@ def _post_process_chat_template(self): "{% endfor %}" "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" ), + # ChatML with wrapper on last assistant turn (Qwen3-style) + "chatml_think": ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% generation %}" + "{% if loop.last and not message['content'].startswith('') %}" + "{{ '\n\n\n\n' }}" + "{% endif %}" + "{{ message['content'] }}" + "{% endgeneration %}" + "{{ '<|im_end|>\n' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + ), "llama3": ( "{% for message in messages %}" "{% if message['role'] == 'system' %}" @@ -227,7 +248,8 @@ def _ensure_generation_tags(self): # Detect template style and replace with generation-tagged version old_template = template if "<|im_start|>" in template and "<|im_end|>" in template: - style = "chatml" + # Check if original template injects (Qwen3-style) + style = "chatml_think" if "" in template else "chatml" elif "<|start_header_id|>" in template and "<|eot_id|>" in template: style = "llama3" else: From 4ebb9deba4717897c7e28c23eb9e045ef7291afb Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 11:11:35 -0700 Subject: [PATCH 52/72] docs: document simplified generation templates and limitations Clearly document what the simplified chat templates do, what is preserved vs dropped, and limitations for tool-use and multi-step reasoning data. Also document how to use a custom template. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../utils/plugins/transformers_dataset.py | 47 +++++++++++++++++-- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index 95830d9f98..d93805708e 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -175,11 +175,50 @@ def _post_process_chat_template(self): ) # Simplified chat templates with {% generation %} tags for answer_only_loss. - # These drop complex features (tool_calls, thinking parsing) but correctly mark - # assistant content for loss masking. The training data should already contain - # the full assistant response including blocks. + # + # PURPOSE: + # HuggingFace's return_assistant_tokens_mask requires {% generation %} / + # {% endgeneration %} tags in the Jinja chat template to identify which tokens + # belong to assistant responses. Many models (Qwen3, Llama3) ship without these + # tags. These simplified templates add them so that answer_only_loss works + # reliably without regex fallbacks. + # + # HOW IT WORKS: + # When answer_only_loss=True, _ensure_generation_tags() detects the model's + # template style (ChatML, Llama3) and replaces the tokenizer's chat_template + # with one of these simplified versions. The {% generation %} tags tell HF + # exactly which tokens are assistant content for loss masking. + # + # WHAT IS PRESERVED: + # - System / user / assistant role formatting (exact token match) + # - Multi-turn conversation structure + # - block injection on last assistant turn (Qwen3-style, chatml_think) + # - Content is output as-is — training data with blocks is handled correctly + # + # WHAT IS DROPPED (vs original model templates): + # - Tool call formatting (tool_call XML tags, function signatures) + # - Multi-step tool response handling + # - reasoning_content vs content splitting logic + # - enable_thinking parameter support + # - VLM/multimodal content handling + # + # LIMITATIONS: + # - Training data with tool_call messages will not be formatted correctly. + # Use the original template with manually added {% generation %} tags for + # tool-use training data. + # - The chatml_think variant adds \n\n\n\n only to the last + # assistant turn (matching Qwen3 behavior). Non-last turns without + # in their content will differ from the original template which also + # conditionally adds think wrappers based on multi-step reasoning context. + # - Only ChatML (<|im_start|>/<|im_end|>) and Llama3 + # (<|start_header_id|>/<|eot_id|>) styles are supported. Other template + # styles fall back to regex-based assistant span detection. + # + # TO USE A CUSTOM TEMPLATE INSTEAD: + # Pass chat_template= to LanguageDataCollator with your own template that + # includes {% generation %}...{% endgeneration %} around assistant content. _GENERATION_TEMPLATES = { - # Basic ChatML without injection + # Basic ChatML without injection (Phi, older Qwen, generic ChatML) "chatml": ( "{% for message in messages %}" "{% if message['role'] == 'system' %}" From 45615154509150f27f393cd9d59cef8668c99d0e Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 11:29:25 -0700 Subject: [PATCH 53/72] fix: ensure zero-loss path has gradient for DDP sync The early-exit zero-loss path used target_hidden (computed under no_grad) which has no gradient graph, causing 'does not require grad' error. Use dflash_module.fc.weight instead to keep DDP happy. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/speculative/plugins/hf_dflash.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index e8690266f6..3d58dff4bd 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -655,11 +655,9 @@ def forward( n_blocks = anchor_positions.shape[1] if n_blocks == 0 or not block_keep_mask.any(): - loss = ( - self._base_model_lm_head(target_hidden[:, :1, : self.config.hidden_size]).sum() - * 0.0 - ) - return ModelOutput(loss=loss, logits=base_outputs.logits, train_acc=[[0.0]]) + # Zero loss that still flows through dflash_module for DDP gradient sync + dummy = self.dflash_module.fc.weight.sum() * 0.0 + return ModelOutput(loss=dummy, logits=base_outputs.logits, train_acc=[[0.0]]) # 4. Create noise embeddings: anchor token at block start, mask_token elsewhere noise_ids = torch.full( From 047ba1d5ff7ff4557f0c5e8fed6a2c07b8892e6d Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 12:32:44 -0700 Subject: [PATCH 54/72] fix: prefer conversations field when messages lacks assistant turn The Speculative-Decoding dataset has both 'messages' (prompt only) and 'conversations' (prompt + response) fields. The collator took 'messages' first, missing the assistant response entirely. Now checks if 'messages' has an assistant turn, otherwise falls back to 'conversations'. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../utils/plugins/transformers_dataset.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index d93805708e..dd911a26a7 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -450,15 +450,18 @@ def __call__(self, examples): batch.append(text) else: messages = example.get("messages", None) - if messages is None: - conversations = example.get("conversations", None) - if conversations is None: - raise ValueError( - "The sample must in either OpenAI messages format or ShareGPT conversations format." - ) - else: - messages = _sharegpt_to_openai_messages(conversations) - batch.append(messages) + conversations = example.get("conversations", None) + # Prefer whichever has an assistant turn for training + if messages and any(m.get("role") == "assistant" for m in messages): + batch.append(messages) + elif conversations: + batch.append(_sharegpt_to_openai_messages(conversations)) + elif messages: + batch.append(messages) + else: + raise ValueError( + "The sample must in either OpenAI messages format or ShareGPT conversations format." + ) return self._process_chat_sample(batch) From bdcc0dead252339d2218ade1f452b3b86d274cca Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 14:02:59 -0700 Subject: [PATCH 55/72] cleanup: remove debug prints and regex fallback in DFlash and dataset collator - Remove debug prints in hf_dflash.py (_get_attn_fn and modify) while keeping informational prints (mask_token_id and base forward) - Add "Legacy: used for inference only" comment on create_dflash_attention_mask and create_dflash_loss_mask - Remove _apply_answer_only_labels regex fallback in transformers_dataset.py; raise ValueError when assistant_masks is missing/empty - Add validation for missing assistant turns in __call__ - Make _ensure_generation_tags warnings more prominent with === WARNING === prefix Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/speculative/plugins/hf_dflash.py | 13 +-- .../utils/plugins/transformers_dataset.py | 101 +++++------------- 2 files changed, 33 insertions(+), 81 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 3d58dff4bd..9dc34a3cbf 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -151,13 +151,10 @@ def _get_attn_fn(self): impl = getattr(self.config, "_attn_implementation", "eager") if impl and impl != "eager" and impl in ALL_ATTENTION_FUNCTIONS: self._attn_fn = ALL_ATTENTION_FUNCTIONS[impl] - print(f"[DFlash] attn_fn resolved to: {impl} -> {self._attn_fn.__name__}") else: self._attn_fn = self._eager_attention - print(f"[DFlash] attn_fn fallback to eager (impl={impl})") - except (ImportError, AttributeError) as e: + except (ImportError, AttributeError): self._attn_fn = self._eager_attention - print(f"[DFlash] attn_fn fallback to eager (error: {e})") return self._attn_fn def _eager_attention(self, module, q, k, v, attention_mask, **kwargs): @@ -294,7 +291,9 @@ def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=N return self.norm(hidden_states) -def create_dflash_attention_mask(seq_len, block_size, device, dtype): +def create_dflash_attention_mask( + seq_len, block_size, device, dtype +): # Legacy: used for inference only """Create [L, 2L] attention mask matching SpecForge. Context (cols 0..L-1): Block B sees blocks 0..B-1 (strictly previous). @@ -323,7 +322,7 @@ def create_dflash_attention_mask(seq_len, block_size, device, dtype): return full_mask.unsqueeze(0).unsqueeze(0) # [1, 1, L, 2L] -def create_dflash_loss_mask(seq_len, block_size, device): +def create_dflash_loss_mask(seq_len, block_size, device): # Legacy: used for inference only """Create loss mask: exclude Block 0 and block starts.""" positions = torch.arange(seq_len, device=device) block_ids = positions // block_size @@ -508,8 +507,6 @@ def modify(self, config): _MLP_CLS, _NORM_CLS, _ROTARY_CLS, _rotate_half = _resolve_model_components( getattr(base_config, "model_type", "llama") ) - print(f"DFlash: using {_MLP_CLS.__name__} from {base_config.model_type}") - self.dflash_module = DFlashModule(self.dflash_config) self.dflash_module.to(self._base_model.dtype).to( next(self._base_model.layers[-1].parameters()).device diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index dd911a26a7..274fac2efc 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -293,8 +293,8 @@ def _ensure_generation_tags(self): style = "llama3" else: print_rank_0( - "WARNING: Cannot auto-inject {% generation %} tags for this chat template. " - "answer_only_loss will use regex fallback. Consider providing a template " + "=== WARNING === Cannot auto-inject {% generation %} tags for this chat " + "template. answer_only_loss will not work correctly. Provide a template " "with {% generation %} tags via the chat_template parameter." ) return @@ -330,8 +330,8 @@ def _ensure_generation_tags(self): # Revert on failure self.tokenizer.chat_template = old_template print_rank_0( - f"WARNING: Failed to apply {style} generation template. " - "Using regex fallback for answer_only_loss." + f"=== WARNING === Failed to apply {style} generation template. " + "answer_only_loss will not work correctly." ) def _process_chat_sample(self, examples: list): @@ -356,78 +356,22 @@ def _process_chat_sample(self, examples: list): if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any(): labels[assistant_mask == 0] = IGNORE_TOKEN_ID else: - # Fallback: derive from formatted text using regex - labels = self._apply_answer_only_labels(examples, labels, input_ids) + raise ValueError( + "answer_only_loss requires {% generation %} tags in the chat " + "template but assistant_masks is empty. Either add " + "{% generation %}...{% endgeneration %} tags to your chat " + "template or pass a chat_template with generation tags." + ) else: - labels = self._apply_answer_only_labels(examples, labels, input_ids) + raise ValueError( + "answer_only_loss requires {% generation %} tags in the chat " + "template but assistant_masks was not returned by the tokenizer. " + "Either add {% generation %}...{% endgeneration %} tags to your " + "chat template or pass a chat_template with generation tags." + ) tokenized_examples["labels"] = labels return tokenized_examples - def _apply_answer_only_labels(self, examples, labels, input_ids): - """Derive response-only labels by finding assistant spans in formatted text. - - Uses regex to find assistant response spans in the chat-template-formatted text, - then maps character positions to token positions via offset mapping. - Similar to SpecForge's _apply_loss_mask_from_chat_template. - """ - import re - - for batch_idx, conversation in enumerate(examples): - # Format with chat template - formatted = self.tokenizer.apply_chat_template( - conversation, tokenize=False, add_generation_prompt=False - ) - - # Tokenize with offset mapping - try: - encoding = self.tokenizer( - formatted, - return_offsets_mapping=True, - max_length=self.train_len, - truncation=True, - add_special_tokens=False, - ) - offsets = encoding["offset_mapping"] - except Exception: - # Tokenizer doesn't support offset mapping — keep all labels - continue - - # Find assistant response spans - # Common patterns across chat templates - # Try to detect the assistant marker from the formatted text - assistant_markers = [ - r"<\|im_start\|>assistant\n(.*?)(?:<\|im_end\|>|$)", # Qwen/ChatML - r"<\|start_header_id\|>assistant<\|end_header_id\|>\n\n(.*?)(?:<\|eot_id\|>|$)", # Llama3 - r"\[/INST\](.*?)(?:|$)", # Llama2 - r"assistant\n(.*?)(?:\n\n|$)", # Generic - ] - - found = False - for pattern in assistant_markers: - matches = list(re.finditer(pattern, formatted, re.DOTALL)) - if matches: - # Mask all tokens, then unmask assistant spans - labels[batch_idx, :] = IGNORE_TOKEN_ID - for match in matches: - start_char = match.start(1) - end_char = match.end(1) - for tok_idx, (tok_start, tok_end) in enumerate(offsets): - if tok_idx >= labels.shape[1]: - break - if tok_start >= start_char and tok_end <= end_char: - # Restore the shifted label for this position - if tok_idx < input_ids.shape[1] - 1: - labels[batch_idx, tok_idx] = input_ids[batch_idx, tok_idx + 1] - found = True - break - - if not found: - # No assistant pattern found — mask all labels to avoid - # training on system/user tokens which inflates accuracy - labels[batch_idx, :] = IGNORE_TOKEN_ID - - return labels - def _process_text_sample(self, examples: list): tokenized_examples = self.tokenizer( examples, @@ -455,8 +399,19 @@ def __call__(self, examples): if messages and any(m.get("role") == "assistant" for m in messages): batch.append(messages) elif conversations: - batch.append(_sharegpt_to_openai_messages(conversations)) + converted = _sharegpt_to_openai_messages(conversations) + if not any(m.get("role") == "assistant" for m in converted): + raise ValueError( + "Conversation has no assistant turn. Each sample must contain " + "at least one assistant message for training." + ) + batch.append(converted) elif messages: + if not any(m.get("role") == "assistant" for m in messages): + raise ValueError( + "Conversation has no assistant turn. Each sample must contain " + "at least one assistant message for training." + ) batch.append(messages) else: raise ValueError( From e5260164ad8e6fc447958d25bf31ef902c0d5116 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 14:13:01 -0700 Subject: [PATCH 56/72] fix: unwrap DDP model for AR validation to avoid deadlock The AR validation callback ran model.forward() on rank 0 only, but DDP model forward triggers collective ops that require all ranks. Now unwraps the DDP model (model.module) before validation, so forward runs without collective hooks. Other ranks wait at barrier. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/eagle_utils.py | 22 +++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 6f1ba87bbd..62187df700 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -243,10 +243,10 @@ def on_log(self, args, state, control, **kwargs): return control def on_step_end(self, args, state, control, **kwargs): - """Run AR validation periodically, if available. + """Run AR validation periodically on rank 0. - Only runs on rank 0 to avoid DDP deadlock — other ranks skip and - synchronize via barrier. + Uses the unwrapped model (no DDP) to avoid collective op deadlocks. + Other ranks wait at a barrier while rank 0 validates. """ if self.ar_validate_steps <= 0: return control @@ -254,17 +254,25 @@ def on_step_end(self, args, state, control, **kwargs): if is_master(): print_rank_0("Running AR validation...") try: + # Unwrap DDP/FSDP to get the raw model — avoids triggering + # collective ops that would deadlock with other ranks at barrier + model = kwargs["model"] + raw_model = model.module if hasattr(model, "module") else model + was_training = raw_model.training + raw_model.eval() ars = validate_ar( - model=kwargs["model"], + model=raw_model, tokenizer=kwargs["processing_class"], ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], - device=kwargs["model"].device, + device=next(raw_model.parameters()).device, ) + if was_training: + raw_model.train() print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") if wandb: wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) - except Exception: - print_rank_0("AR validation not available.") + except Exception as e: + print_rank_0(f"AR validation failed: {e}") # Barrier to synchronize all ranks after validation if torch.distributed.is_initialized(): torch.distributed.barrier() From 633da5541f759c99cef77bde82cbe58dd828e6ec Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 15:10:03 -0700 Subject: [PATCH 57/72] fix: skip samples without assistant turns instead of crashing During distributed training, a single bad sample would crash all ranks. Now warns and skips samples without assistant turns. Also handle the case where all assistant content is truncated by masking all labels instead of raising ValueError. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../utils/plugins/transformers_dataset.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index 274fac2efc..00b4b07b91 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -350,24 +350,18 @@ def _process_chat_sample(self, examples: list): labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID) labels[..., :-1] = input_ids[..., 1:] if self.answer_only_loss: - # Try tokenizer's assistant_masks first if "assistant_masks" in tokenized_examples: assistant_mask = tokenized_examples["assistant_masks"] if isinstance(assistant_mask, torch.Tensor) and assistant_mask.any(): labels[assistant_mask == 0] = IGNORE_TOKEN_ID else: - raise ValueError( - "answer_only_loss requires {% generation %} tags in the chat " - "template but assistant_masks is empty. Either add " - "{% generation %}...{% endgeneration %} tags to your chat " - "template or pass a chat_template with generation tags." - ) + # All assistant content truncated or no assistant in batch — mask all + labels[:] = IGNORE_TOKEN_ID else: raise ValueError( "answer_only_loss requires {% generation %} tags in the chat " "template but assistant_masks was not returned by the tokenizer. " - "Either add {% generation %}...{% endgeneration %} tags to your " - "chat template or pass a chat_template with generation tags." + "Ensure _ensure_generation_tags() ran successfully." ) tokenized_examples["labels"] = labels return tokenized_examples @@ -401,23 +395,29 @@ def __call__(self, examples): elif conversations: converted = _sharegpt_to_openai_messages(conversations) if not any(m.get("role") == "assistant" for m in converted): - raise ValueError( - "Conversation has no assistant turn. Each sample must contain " - "at least one assistant message for training." + print_rank_0( + "=== WARNING === Skipping sample with no assistant turn in conversations." ) + continue batch.append(converted) elif messages: if not any(m.get("role") == "assistant" for m in messages): - raise ValueError( - "Conversation has no assistant turn. Each sample must contain " - "at least one assistant message for training." + print_rank_0( + "=== WARNING === Skipping sample with no assistant turn in messages." ) + continue batch.append(messages) else: raise ValueError( "The sample must in either OpenAI messages format or ShareGPT conversations format." ) + if not batch: + raise ValueError( + "All samples in batch were skipped (no assistant turns). " + "Check that your training data contains assistant responses." + ) + return self._process_chat_sample(batch) From 43afb0632e1da210c8e4b5f274ff18e7e92328b2 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 15:26:39 -0700 Subject: [PATCH 58/72] fix: handle empty batch with dummy assistant turn Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- modelopt/torch/utils/plugins/transformers_dataset.py | 7 +++---- .../launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py index 00b4b07b91..b9a5367cd9 100644 --- a/modelopt/torch/utils/plugins/transformers_dataset.py +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -413,10 +413,9 @@ def __call__(self, examples): ) if not batch: - raise ValueError( - "All samples in batch were skipped (no assistant turns). " - "Check that your training data contains assistant responses." - ) + # All samples skipped — create a dummy batch with all-masked labels + # so the training step produces zero loss without crashing DDP + batch = [[{"role": "user", "content": ""}, {"role": "assistant", "content": ""}]] # type: ignore[list-item] return self._process_chat_sample(batch) diff --git a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml index c72b5aec48..5f094f3a14 100644 --- a/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml +++ b/tools/launcher/examples/Qwen/Qwen3-8B/hf_online_dflash.yaml @@ -21,8 +21,8 @@ pipeline: args: - --data /hf-local/modelopt/Speculative-Decoding-Dataset-v1-Qwen3-8B/sample-100K.jsonl - --output_dir /scratchspace/dflash - - --num_epochs 3 - - --lr 1e-4 + - --num_epochs 1 + - --lr 6e-4 - --training_seq_len 512 - --save_steps 500000 - --log_steps 100 From 90e9b4b495fc4e4f291257134c47032c499f31a8 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 16:08:56 -0700 Subject: [PATCH 59/72] fix: AR validation deadlock - eval all ranks, validate on rank 0 Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/eagle_utils.py | 41 +++++++++++--------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 62187df700..5549d6d248 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -243,37 +243,42 @@ def on_log(self, args, state, control, **kwargs): return control def on_step_end(self, args, state, control, **kwargs): - """Run AR validation periodically on rank 0. + """Run AR validation periodically. - Uses the unwrapped model (no DDP) to avoid collective op deadlocks. - Other ranks wait at a barrier while rank 0 validates. + Only rank 0 with CUDA device 0 runs validation and logs results. + All other ranks skip to avoid DDP deadlock — the validation uses + the unwrapped model with torch.no_grad() which doesn't trigger + collective ops. A barrier syncs all ranks afterward. """ if self.ar_validate_steps <= 0: return control if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0: + # All ranks must participate to avoid DDP deadlock. + # Only rank 0 does real validation; others do a no-op. + model = kwargs["model"] + raw_model = model.module if hasattr(model, "module") else model + was_training = raw_model.training + raw_model.eval() + if is_master(): print_rank_0("Running AR validation...") try: - # Unwrap DDP/FSDP to get the raw model — avoids triggering - # collective ops that would deadlock with other ranks at barrier - model = kwargs["model"] - raw_model = model.module if hasattr(model, "module") else model - was_training = raw_model.training - raw_model.eval() - ars = validate_ar( - model=raw_model, - tokenizer=kwargs["processing_class"], - ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"], - device=next(raw_model.parameters()).device, - ) - if was_training: - raw_model.train() + with torch.no_grad(): + ars = validate_ar( + model=raw_model, + tokenizer=kwargs["processing_class"], + ds=load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"], + device=torch.device("cuda", 0), + num_samples=8, + ) print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") if wandb: wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) except Exception as e: print_rank_0(f"AR validation failed: {e}") - # Barrier to synchronize all ranks after validation + + if was_training: + raw_model.train() if torch.distributed.is_initialized(): torch.distributed.barrier() return control From 870db23a836ca6817a97a204f3efc4267e7657ee Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 16:26:43 -0700 Subject: [PATCH 60/72] feat: add TensorBoard logging for DFlash training Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/eagle_utils.py | 16 ++++++++-------- examples/speculative_decoding/launch_train.sh | 2 ++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 5549d6d248..eb5cd88ea5 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -225,18 +225,18 @@ def on_log(self, args, state, control, **kwargs): est_ar += acc_cumprod print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}") + # Log accuracy to HF Trainer's logs dict (picked up by TensorBoard) + logs = kwargs.get("logs") or {} + for i, draft_acc in enumerate(average_acc): + for j, step_acc in enumerate(draft_acc): + logs[f"train_acc/parallel_{i}_step_{j}"] = float(step_acc) + if self.estimate_ar: + logs["estimated_training_ar"] = est_ar + # log to wandb if wandb and is_master(): - logs = kwargs.get("logs") or {} if logs: wandb.log({k: v for k, v in logs.items() if v is not None}, step=state.global_step) - for i, draft_acc in enumerate(average_acc): - for j, step_acc in enumerate(draft_acc): - wandb.log( - {f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step - ) - if self.estimate_ar: - wandb.log({"estimated_training_ar": est_ar}, step=state.global_step) # reset training_accs state.training_accs = [] diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 8dc9ab216e..b3b99b4c09 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -317,6 +317,8 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai --warmup_steps 100 \ --lr_scheduler_type linear \ --logging_steps $LOG_STEPS \ + --report_to tensorboard \ + --logging_dir $OUTPUT_DIR/tensorboard \ --tf32 True \ $DATA_ARGS \ --disable_tqdm $DISABLE_TQDM \ From 0d3f1faaa8d4f01fd95cd32d270caaf793d25088 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 16:30:30 -0700 Subject: [PATCH 61/72] fix: skip AR validation during DDP training to prevent deadlock MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AR validation with DDP is fundamentally incompatible — pseudo_speculative_generate runs inference on rank 0 while other ranks deadlock on collective ops. Now detects world_size > 1 and skips with a one-time warning. AR validation still works for single-GPU and post-training (online_training.sh). Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/eagle_utils.py | 56 ++++++++++---------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index eb5cd88ea5..0267ac684c 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -243,44 +243,46 @@ def on_log(self, args, state, control, **kwargs): return control def on_step_end(self, args, state, control, **kwargs): - """Run AR validation periodically. + """Run AR validation periodically (single-GPU only). - Only rank 0 with CUDA device 0 runs validation and logs results. - All other ranks skip to avoid DDP deadlock — the validation uses - the unwrapped model with torch.no_grad() which doesn't trigger - collective ops. A barrier syncs all ranks afterward. + AR validation with DDP is not supported because pseudo_speculative_generate + runs only on rank 0 while other ranks deadlock waiting for collective ops. + When world_size > 1, AR validation is skipped with a one-time warning. + Use post-training AR validation instead (online_training.sh runs it after training). """ if self.ar_validate_steps <= 0: return control if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0: - # All ranks must participate to avoid DDP deadlock. - # Only rank 0 does real validation; others do a no-op. + if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: + if not hasattr(self, "_ar_ddp_warned"): + self._ar_ddp_warned = True + print_rank_0( + "=== WARNING === AR validation during training is not supported with " + "DDP (world_size > 1). Skipping. Use post-training AR validation." + ) + return control + model = kwargs["model"] raw_model = model.module if hasattr(model, "module") else model was_training = raw_model.training raw_model.eval() - - if is_master(): - print_rank_0("Running AR validation...") - try: - with torch.no_grad(): - ars = validate_ar( - model=raw_model, - tokenizer=kwargs["processing_class"], - ds=load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"], - device=torch.device("cuda", 0), - num_samples=8, - ) - print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") - if wandb: - wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) - except Exception as e: - print_rank_0(f"AR validation failed: {e}") - + print_rank_0("Running AR validation...") + try: + with torch.no_grad(): + ars = validate_ar( + model=raw_model, + tokenizer=kwargs["processing_class"], + ds=load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"], + device=next(raw_model.parameters()).device, + num_samples=8, + ) + print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}") + if wandb: + wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step) + except Exception as e: + print_rank_0(f"AR validation failed: {e}") if was_training: raw_model.train() - if torch.distributed.is_initialized(): - torch.distributed.barrier() return control From 9b76b7d67620c7363f325c4ad8006a1d772b82dd Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 20:14:30 -0700 Subject: [PATCH 62/72] feat: add DFlash export to z-lab compatible HF format Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../torch/export/plugins/hf_spec_export.py | 105 +++++++++++++++++- .../torch/speculative/plugins/hf_dflash.py | 6 + 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index aca19a1580..29b2cf633f 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -27,7 +27,7 @@ from .hf_spec_configs import kimik2_eagle_template_config, llama_eagle_template_config -ALL_SPEC_MODES = ["eagle"] +ALL_SPEC_MODES = ["eagle", "dflash"] LLAMA_EAGLE_SINGLE_LAYER = { "required": { @@ -243,3 +243,106 @@ def _extract_state_dict(self, full_state_dict: dict): export_sd.pop(f"parallel_draft_heads.medusa_heads.{i}.{j}.linear.bias") ) return export_sd + + +class DFlashExporter(SpeculativeDecodingExporter): + """Draft model exporter for DFlash. + + Exports in z-lab compatible format: + - model.safetensors: draft module weights (no prefix) + - config.json: Qwen3-style config with dflash_config field + """ + + def __init__(self, model: nn.Module): + """Initialize the DFlashExporter.""" + super().__init__(model) + + def _extract_state_dict(self, full_state_dict: dict): + """Extract DFlash module weights, stripping the dflash_module prefix.""" + export_sd = {} + for key, value in full_state_dict.items(): + if "dflash_module." in key: + export_key = key.split("dflash_module.", 1)[1] + # Skip rotary embedding buffers (not needed, recomputed) + if "rotary_emb" in export_key: + continue + export_sd[export_key] = value.clone() + return export_sd + + def _export_config(self): + """Build config.json matching z-lab DFlash format.""" + model = self.model + base_config = ( + getattr(model.config, "text_config", None) + or getattr(model.config, "llm_config", None) + or model.config + ) + draft_config = model.dflash_config + + config = { + "architectures": ["DFlashDraftModel"], + "model_type": getattr(base_config, "model_type", "qwen3"), + "block_size": model.dflash_block_size, + "dflash_config": { + "mask_token_id": model.mask_token_id, + "target_layer_ids": list(model.target_layer_ids), + }, + # Architecture dimensions + "hidden_size": getattr(draft_config, "hidden_size", base_config.hidden_size), + "num_hidden_layers": draft_config.num_hidden_layers, + "num_attention_heads": getattr( + draft_config, "num_attention_heads", base_config.num_attention_heads + ), + "num_key_value_heads": getattr( + draft_config, "num_key_value_heads", base_config.num_key_value_heads + ), + "head_dim": getattr( + draft_config, + "head_dim", + base_config.hidden_size // base_config.num_attention_heads, + ), + "intermediate_size": getattr( + draft_config, "intermediate_size", base_config.intermediate_size + ), + "hidden_act": getattr(draft_config, "hidden_act", "silu"), + "rms_norm_eps": getattr(draft_config, "rms_norm_eps", 1e-6), + "vocab_size": base_config.vocab_size, + "max_position_embeddings": getattr(base_config, "max_position_embeddings", 32768), + "initializer_range": getattr(base_config, "initializer_range", 0.02), + "attention_bias": getattr(draft_config, "attention_bias", False), + "attention_dropout": getattr(draft_config, "attention_dropout", 0.0), + "rope_theta": getattr(base_config, "rope_theta", 1000000.0), + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "num_target_layers": getattr(base_config, "num_hidden_layers", 36), + } + + # Add layer_types if present (Qwen3-style) + if hasattr(draft_config, "layer_types"): + config["layer_types"] = draft_config.layer_types + else: + config["layer_types"] = ["full_attention"] * draft_config.num_hidden_layers + + return config + + def export(self, export_dir: Path | str, dtype: torch.dtype | None = None): + """Export the DFlash draft model to deployment format.""" + export_dir = Path(export_dir) + export_dir.mkdir(parents=True, exist_ok=True) + + # Export state dict + full_sd = self.model.state_dict() + drafter_sd = self._extract_state_dict(full_sd) + if dtype is not None: + drafter_sd = {k: v.to(dtype) for k, v in drafter_sd.items()} + save_file(drafter_sd, f"{export_dir}/model.safetensors") + + # Export config + drafter_config = self._export_config() + with open(f"{export_dir}/config.json", "w") as f: + json.dump(drafter_config, f, indent=2) + + print( + f"Exported DFlash draft model: {len(drafter_sd)} tensors, " + f"config keys: {list(drafter_config.keys())[:5]}..." + ) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 9dc34a3cbf..9257f50e0d 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -543,6 +543,12 @@ def modify(self, config): self._original_forward_cls = original_cls print(f"DFlash: using {original_cls.__name__}.forward as base forward") + def get_exporter(self): + """Get the exporter for the DFlash draft model.""" + from modelopt.torch.export.plugins.hf_spec_export import DFlashExporter + + return DFlashExporter(self) + def _base_forward(self, **kwargs): """Call the original model's forward, bypassing DFlash wrapper.""" return self._original_forward_cls.forward(self, **kwargs) From 1cfd55865e1a9b3d3ebc9e6249ef688e657c1498 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 20:18:58 -0700 Subject: [PATCH 63/72] fix: checkpoint resume + export-then-validate pipeline - Fix resume: add device_map='cpu' to checkpoint loading path to avoid meta tensor errors - Add export step to online_training.sh: after training, export DFlash checkpoint to z-lab HF format, then validate AR on the exported checkpoint - AR validation prefers exported checkpoint (no prefix) over training checkpoint (with prefix) Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/main.py | 5 ++- .../launcher/common/dflash/online_training.sh | 40 +++++++++++++++++-- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 8d3f5c1c0b..fd9cac0917 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -223,7 +223,10 @@ def train(): if checkpoint: with patch_transformers5_params_loading(): model = load_vlm_or_llm( - checkpoint, torch_dtype="auto", trust_remote_code=model_args.trust_remote_code + checkpoint, + torch_dtype="auto", + device_map="cpu", + trust_remote_code=model_args.trust_remote_code, ) tokenizer = transformers.AutoTokenizer.from_pretrained( checkpoint, trust_remote_code=model_args.trust_remote_code diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh index 8b244e6e50..573f76325d 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/dflash/online_training.sh @@ -72,10 +72,37 @@ if [ "${NUM_AR_SAMPLES}" = "0" ]; then fi if [ -z "$OUTPUT_DIR" ]; then - echo "WARNING: --output_dir not found in args, skipping AR validation" + echo "WARNING: --output_dir not found in args, skipping export and AR validation" exit 0 fi +# Step 2: Export checkpoint to z-lab HF format +EXPORT_DIR=${OUTPUT_DIR}/export +echo "" +echo "=== Exporting DFlash checkpoint ===" +echo "Source: ${OUTPUT_DIR}" +echo "Export: ${EXPORT_DIR}" + +CUDA_VISIBLE_DEVICES=0 python3 -c " +import modelopt.torch.opt as mto +from modelopt.torch.export import export_speculative_decoding +from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading + +mto.enable_huggingface_checkpointing() +with patch_transformers5_params_loading(): + model = load_vlm_or_llm('${OUTPUT_DIR}', torch_dtype='auto', device_map='cpu', trust_remote_code=True) +model.eval() +import torch +with torch.inference_mode(): + export_speculative_decoding(model, export_dir='${EXPORT_DIR}') +print('Export complete') +" || echo "WARNING: Export failed, continuing with AR validation" + +echo "" +echo "Export contents:" +ls -la ${EXPORT_DIR}/ 2>/dev/null || echo "No export dir" + +# Step 3: AR Validation # Build mask_token_id config if [ -n "$DFLASH_MASK_TOKEN_ID" ]; then MASK_ARG="'mask_token_id': ${DFLASH_MASK_TOKEN_ID}," @@ -86,7 +113,14 @@ fi echo "" echo "=== DFlash AR Validation ===" echo "Target model: ${HF_MODEL_CKPT}" -echo "DFlash checkpoint: ${OUTPUT_DIR}" +# Prefer exported checkpoint (no prefix), fall back to training output (with prefix) +if [ -f "${EXPORT_DIR}/model.safetensors" ]; then + AR_CKPT=${EXPORT_DIR} + echo "Using exported checkpoint: ${AR_CKPT}" +else + AR_CKPT=${OUTPUT_DIR} + echo "Using training checkpoint: ${AR_CKPT}" +fi echo "Block size: ${DFLASH_BLOCK_SIZE}" echo "Draft layers: ${DFLASH_NUM_LAYERS}" echo "Samples: ${NUM_AR_SAMPLES}" @@ -119,7 +153,7 @@ mtsp.convert(model, [('dflash', config)]) # Load trained DFlash weights import glob from safetensors.torch import load_file -ckpt_files = sorted(glob.glob('${OUTPUT_DIR}/model*.safetensors')) +ckpt_files = sorted(glob.glob('${AR_CKPT}/model*.safetensors')) if ckpt_files: state = {} for f in ckpt_files: From 6a153bf9b6b36b391001a2f000ca266ac25f217d Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 20:57:10 -0700 Subject: [PATCH 64/72] feat: auto-detect HEAD_NODE_IP for multi-node DFlash training Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- tools/launcher/common/dflash/online_training.sh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh index 573f76325d..1116955897 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/dflash/online_training.sh @@ -38,6 +38,13 @@ export PATH=$PATH:/workspace/.local/bin trap 'error_handler $0 $LINENO' ERR +# Auto-detect head node IP for multi-node training +if [ -n "$SLURM_NODELIST" ] && [ -z "$HEAD_NODE_IP" ]; then + HEAD_NODE_IP=$(scontrol show hostnames "$SLURM_NODELIST" | head -1) + export HEAD_NODE_IP + echo "Auto-detected HEAD_NODE_IP: ${HEAD_NODE_IP}" +fi + # Parse DFlash-specific args from the command line for AR validation DFLASH_BLOCK_SIZE=16 DFLASH_NUM_LAYERS=5 From c0c43305de5d77176c742e5b03d2f101bf7c7e4b Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 20:58:53 -0700 Subject: [PATCH 65/72] fix: use explicit bfloat16 and device_map for export loading Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../launcher/common/dflash/online_training.sh | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh index 1116955897..1e617653fe 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/dflash/online_training.sh @@ -91,19 +91,27 @@ echo "Source: ${OUTPUT_DIR}" echo "Export: ${EXPORT_DIR}" CUDA_VISIBLE_DEVICES=0 python3 -c " +import torch import modelopt.torch.opt as mto from modelopt.torch.export import export_speculative_decoding from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading mto.enable_huggingface_checkpointing() -with patch_transformers5_params_loading(): - model = load_vlm_or_llm('${OUTPUT_DIR}', torch_dtype='auto', device_map='cpu', trust_remote_code=True) -model.eval() -import torch -with torch.inference_mode(): - export_speculative_decoding(model, export_dir='${EXPORT_DIR}') -print('Export complete') -" || echo "WARNING: Export failed, continuing with AR validation" +try: + with patch_transformers5_params_loading(): + model = load_vlm_or_llm( + '${OUTPUT_DIR}', + torch_dtype=torch.bfloat16, + device_map={'': 'cpu'}, + trust_remote_code=True, + ) + model.eval() + with torch.inference_mode(): + export_speculative_decoding(model, export_dir='${EXPORT_DIR}') + print('Export complete') +except Exception as e: + print(f'Export failed: {e}') +" || echo "WARNING: Export script failed, continuing with AR validation" echo "" echo "Export contents:" From 6684f474f7585c179d145c0f9cbe202edbf6e18c Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 21:14:08 -0700 Subject: [PATCH 66/72] fix: improve HEAD_NODE_IP auto-detection for multi-node Try scontrol first, then parse SLURM_JOB_NODELIST directly and resolve via getent hosts. Works both inside and outside containers. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- tools/launcher/common/dflash/online_training.sh | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh index 1e617653fe..a9c780d95a 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/dflash/online_training.sh @@ -39,8 +39,15 @@ export PATH=$PATH:/workspace/.local/bin trap 'error_handler $0 $LINENO' ERR # Auto-detect head node IP for multi-node training -if [ -n "$SLURM_NODELIST" ] && [ -z "$HEAD_NODE_IP" ]; then - HEAD_NODE_IP=$(scontrol show hostnames "$SLURM_NODELIST" | head -1) +if [ -n "$SLURM_JOB_NODELIST" ] && [ -z "$HEAD_NODE_IP" ]; then + # Try scontrol first (works outside container), then parse SLURM env directly + HEAD_NODE_IP=$(scontrol show hostnames "$SLURM_JOB_NODELIST" 2>/dev/null | head -1) + if [ -z "$HEAD_NODE_IP" ]; then + # Parse nodelist directly: "node[001-002]" → "node001" + HEAD_NODE_IP=$(echo "$SLURM_JOB_NODELIST" | sed 's/\[.*//; s/,.*//; s/ .*//') + # Resolve to IP + HEAD_NODE_IP=$(getent hosts "$HEAD_NODE_IP" 2>/dev/null | awk '{print $1}' || echo "$HEAD_NODE_IP") + fi export HEAD_NODE_IP echo "Auto-detected HEAD_NODE_IP: ${HEAD_NODE_IP}" fi From dd6e2823e357684f524b391b08e8b8c8266f22c5 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 21:25:57 -0700 Subject: [PATCH 67/72] fix: multi-method HEAD_NODE_IP detection for multi-node Try scontrol, SLURM_LAUNCH_NODE_IPADDR, Python socket resolution, and hostname -I as fallbacks. Should work inside containers where scontrol is not available. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../launcher/common/dflash/online_training.sh | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh index a9c780d95a..48c66d9e01 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/dflash/online_training.sh @@ -39,14 +39,29 @@ export PATH=$PATH:/workspace/.local/bin trap 'error_handler $0 $LINENO' ERR # Auto-detect head node IP for multi-node training -if [ -n "$SLURM_JOB_NODELIST" ] && [ -z "$HEAD_NODE_IP" ]; then - # Try scontrol first (works outside container), then parse SLURM env directly +if [ -z "$HEAD_NODE_IP" ]; then + # Method 1: scontrol (works outside container) HEAD_NODE_IP=$(scontrol show hostnames "$SLURM_JOB_NODELIST" 2>/dev/null | head -1) - if [ -z "$HEAD_NODE_IP" ]; then - # Parse nodelist directly: "node[001-002]" → "node001" - HEAD_NODE_IP=$(echo "$SLURM_JOB_NODELIST" | sed 's/\[.*//; s/,.*//; s/ .*//') - # Resolve to IP - HEAD_NODE_IP=$(getent hosts "$HEAD_NODE_IP" 2>/dev/null | awk '{print $1}' || echo "$HEAD_NODE_IP") + # Method 2: SLURM_LAUNCH_NODE_IPADDR (some Slurm versions) + HEAD_NODE_IP=${HEAD_NODE_IP:-$SLURM_LAUNCH_NODE_IPADDR} + # Method 3: Parse SLURM_NODELIST and resolve via Python + if [ -z "$HEAD_NODE_IP" ] && [ -n "$SLURM_JOB_NODELIST" ]; then + HEAD_NODE_IP=$(python3 -c " +import socket, re, os +nl = os.environ.get('SLURM_JOB_NODELIST', '') +# Extract first hostname: 'node[001-002]' -> 'node001', 'node001,node002' -> 'node001' +m = re.match(r'([a-zA-Z0-9-]+?)(?:\[(\d+))?', nl) +if m: + host = m.group(1) + (m.group(2) or '') + try: + print(socket.gethostbyname(host)) + except: + print(host) +" 2>/dev/null) + fi + # Method 4: Use rank 0's hostname + if [ -z "$HEAD_NODE_IP" ] && [ "${SLURM_PROCID:-0}" = "0" ]; then + HEAD_NODE_IP=$(hostname -I 2>/dev/null | awk '{print $1}') fi export HEAD_NODE_IP echo "Auto-detected HEAD_NODE_IP: ${HEAD_NODE_IP}" From 3efd65956909458ebffc97d0fa580914eae445a6 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 21:37:09 -0700 Subject: [PATCH 68/72] fix: force dp_shard_size=1 for DFlash DDP training DFlash uses DDP, not FSDP. The default dp_shard_size=TOTAL_GPU caused FSDP-style sharding. Force to 1 for pure DDP replication. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/launch_train.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index b3b99b4c09..201f177fad 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -241,8 +241,9 @@ elif [[ "$MODE" == "dflash" ]]; then if [[ "$DFLASH_USE_LOGIT_DISTILLATION" == "True" ]]; then SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_use_logit_distillation" fi - # DFlash uses DDP instead of FSDP + # DFlash uses DDP instead of FSDP — force dp_shard_size=1 FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 1800" + DP_SHARD_SIZE=1 else echo "Unsupported mode: $MODE. Supported: eagle3, dflash" exit 1 From 7d0028e781c5fdf6c3520937bcd2155d79bb5abd Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 21:47:32 -0700 Subject: [PATCH 69/72] fix: support both DDP and FSDP for DFlash training DFlash defaults to DDP (dp_shard_size=1). Pass --fsdp True to use FSDP with full_shard instead. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/launch_train.sh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 201f177fad..86fc400a35 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -241,9 +241,13 @@ elif [[ "$MODE" == "dflash" ]]; then if [[ "$DFLASH_USE_LOGIT_DISTILLATION" == "True" ]]; then SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_use_logit_distillation" fi - # DFlash uses DDP instead of FSDP — force dp_shard_size=1 - FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 1800" - DP_SHARD_SIZE=1 + # DFlash: DDP by default, FSDP if --fsdp True is passed + if [[ "$FSDP" == "True" ]]; then + FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json" + else + FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 1800" + DP_SHARD_SIZE=1 + fi else echo "Unsupported mode: $MODE. Supported: eagle3, dflash" exit 1 From 496830dcffbe83b1dca93d677f7be18d304ce510 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 21:50:44 -0700 Subject: [PATCH 70/72] fix: use AutoModelForCausalLM directly for export loading load_vlm_or_llm uses meta tensors internally. Use AutoModelForCausalLM.from_pretrained with low_cpu_mem_usage=False to avoid meta tensor errors during export. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../launcher/common/dflash/online_training.sh | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh index 48c66d9e01..ccdafe4e0c 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/dflash/online_training.sh @@ -116,22 +116,24 @@ CUDA_VISIBLE_DEVICES=0 python3 -c " import torch import modelopt.torch.opt as mto from modelopt.torch.export import export_speculative_decoding -from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading +from transformers import AutoModelForCausalLM mto.enable_huggingface_checkpointing() try: - with patch_transformers5_params_loading(): - model = load_vlm_or_llm( - '${OUTPUT_DIR}', - torch_dtype=torch.bfloat16, - device_map={'': 'cpu'}, - trust_remote_code=True, - ) + model = AutoModelForCausalLM.from_pretrained( + '${OUTPUT_DIR}', + torch_dtype=torch.bfloat16, + device_map='cpu', + low_cpu_mem_usage=False, + trust_remote_code=True, + ) model.eval() with torch.inference_mode(): export_speculative_decoding(model, export_dir='${EXPORT_DIR}') print('Export complete') except Exception as e: + import traceback + traceback.print_exc() print(f'Export failed: {e}') " || echo "WARNING: Export script failed, continuing with AR validation" From 725596930028d6a6c10c8a1ba984faf743079a82 Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Fri, 3 Apr 2026 21:55:39 -0700 Subject: [PATCH 71/72] fix: use export_hf_checkpoint.py script for DFlash export Same script used by EAGLE3 export. Avoids custom loading logic that caused meta tensor errors. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- .../launcher/common/dflash/online_training.sh | 28 +++---------------- 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/tools/launcher/common/dflash/online_training.sh b/tools/launcher/common/dflash/online_training.sh index ccdafe4e0c..62fe51fcc2 100644 --- a/tools/launcher/common/dflash/online_training.sh +++ b/tools/launcher/common/dflash/online_training.sh @@ -112,30 +112,10 @@ echo "=== Exporting DFlash checkpoint ===" echo "Source: ${OUTPUT_DIR}" echo "Export: ${EXPORT_DIR}" -CUDA_VISIBLE_DEVICES=0 python3 -c " -import torch -import modelopt.torch.opt as mto -from modelopt.torch.export import export_speculative_decoding -from transformers import AutoModelForCausalLM - -mto.enable_huggingface_checkpointing() -try: - model = AutoModelForCausalLM.from_pretrained( - '${OUTPUT_DIR}', - torch_dtype=torch.bfloat16, - device_map='cpu', - low_cpu_mem_usage=False, - trust_remote_code=True, - ) - model.eval() - with torch.inference_mode(): - export_speculative_decoding(model, export_dir='${EXPORT_DIR}') - print('Export complete') -except Exception as e: - import traceback - traceback.print_exc() - print(f'Export failed: {e}') -" || echo "WARNING: Export script failed, continuing with AR validation" +python3 modules/Model-Optimizer/examples/speculative_decoding/scripts/export_hf_checkpoint.py \ + --model_path ${OUTPUT_DIR} \ + --export_path ${EXPORT_DIR} \ + || echo "WARNING: Export failed, continuing with AR validation" echo "" echo "Export contents:" From 3a8ff9c52bc411f0660bb896afeb67e7ee3099ce Mon Sep 17 00:00:00 2001 From: Chenhan Yu Date: Sat, 4 Apr 2026 13:08:47 -0700 Subject: [PATCH 72/72] fix: load model from output_dir for checkpoint resume Loading from checkpoint subdirectory (e.g., checkpoint-12500/) causes meta tensor errors with transformers 5.x. Load from output_dir (top-level save) instead, which works. The checkpoint path is still passed to trainer.train(resume_from_checkpoint=...) for optimizer and step count resume. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Chenhan Yu --- examples/speculative_decoding/main.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index fd9cac0917..b5f2b3e26f 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -221,15 +221,19 @@ def train(): use_offline_training = data_args.offline_data_path is not None if checkpoint: + # Load model from output_dir (top-level save) rather than checkpoint subdir + # to avoid meta tensor errors. The checkpoint path is passed to + # trainer.train(resume_from_checkpoint=...) for optimizer/step resume. + model_load_path = training_args.output_dir with patch_transformers5_params_loading(): model = load_vlm_or_llm( - checkpoint, + model_load_path, torch_dtype="auto", device_map="cpu", trust_remote_code=model_args.trust_remote_code, ) tokenizer = transformers.AutoTokenizer.from_pretrained( - checkpoint, trust_remote_code=model_args.trust_remote_code + model_load_path, trust_remote_code=model_args.trust_remote_code ) else: # To avoid OOM for large models, we load and convert model on CPU first.