Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
3f89ea9
add: DFlash block diffusion speculative decoding
ChenhanYu Mar 27, 2026
190cb3a
fix: rewrite DFlash to match SpecForge reference
ChenhanYu Mar 28, 2026
b7a2a7b
fix: correct mask_token_id and base model forward dispatch
ChenhanYu Mar 29, 2026
a310d96
add: auto-detect mask_token_id for DFlash across model families
ChenhanYu Mar 29, 2026
972dfaa
fix: prevent DDP deadlock during AR validation
ChenhanYu Mar 29, 2026
6c4eb80
fix: avoid DynamicModule dispatch loop in forward/training paths
ChenhanYu Mar 29, 2026
2c42363
fix: revert training/eval to super().forward() matching EAGLE pattern
ChenhanYu Mar 30, 2026
a279960
fix: DDP deadlock when no valid loss positions on a rank
ChenhanYu Mar 30, 2026
cbddc30
add: logit distillation option for DFlash training
ChenhanYu Mar 30, 2026
c53a66a
fix: print training accuracy to console at each log step
ChenhanYu Mar 30, 2026
2eabf57
fix: use response-only loss mask for DFlash training
ChenhanYu Mar 31, 2026
2a16232
fix: apply assistant_masks to labels in LanguageDataCollator
ChenhanYu Mar 31, 2026
e3b9930
fix: robust response-only loss mask via regex assistant span detection
ChenhanYu Mar 31, 2026
07066c2
docs: add DFlash section to speculative decoding README
ChenhanYu Mar 31, 2026
a32de63
fix: resolve DFlash components from base model architecture
ChenhanYu Mar 31, 2026
6a6a9ca
fix: enable response-only loss mask for DFlash training
ChenhanYu Mar 31, 2026
a777849
add: DFlash launcher example for Qwen3-8B
ChenhanYu Apr 1, 2026
2c56aca
fix: inline values in DFlash launcher YAML for --yaml compatibility
ChenhanYu Apr 1, 2026
306fc3e
add: unit tests for DFlash speculative decoding
ChenhanYu Apr 1, 2026
c4a3ecb
fix: add docstrings to DFlash classes for coverage check
ChenhanYu Apr 1, 2026
1c23ced
add: AR validation step to DFlash launcher pipeline
ChenhanYu Apr 1, 2026
38450b0
fix: split DFlash tests into CPU (unit) and GPU tests
ChenhanYu Apr 1, 2026
4c2fc77
fix: correct DFlash attention mask test for reverse-causal pattern
ChenhanYu Apr 1, 2026
bce17cf
fix: remove __init__.py from GPU test dirs to avoid conftest conflict
ChenhanYu Apr 1, 2026
1165272
fix: match dtype in DFlash GPU tests to model dtype
ChenhanYu Apr 1, 2026
273ba32
fix: use Optional types for nullable DFlash arguments
ChenhanYu Apr 1, 2026
9bf9c34
fix: merge AR validation into DFlash training script
ChenhanYu Apr 1, 2026
d19cd3b
fix: align pseudo_speculative_generate with training masks
ChenhanYu Apr 2, 2026
73bb0cc
fix: use standard causal mask within DFlash blocks
ChenhanYu Apr 2, 2026
3fa0d64
fix: increase DDP timeout to 1800s for DFlash training
ChenhanYu Apr 2, 2026
80afde2
fix: revert to SpecForge's reverse-causal mask (j >= i)
ChenhanYu Apr 2, 2026
bfdd582
fix: use continuing position IDs for DFlash inference block
ChenhanYu Apr 2, 2026
fb7acab
fix: remove attention mask at DFlash inference, matching SpecForge
ChenhanYu Apr 2, 2026
290670f
add: standalone DFlash training script with SpecForge data pipeline
ChenhanYu Apr 2, 2026
eb6a0c9
fix: create attention mask in f32 then cast, matching SpecForge
ChenhanYu Apr 2, 2026
2c853c1
fix: use HF attention dispatch in DFlashAttention for SpecForge parity
ChenhanYu Apr 2, 2026
d6adadb
fix: default DFlash attention to sdpa matching SpecForge
ChenhanYu Apr 2, 2026
65df160
fix: initialize DFlash weights with normal_(std=0.02) matching SpecForge
ChenhanYu Apr 2, 2026
4451101
debug: add attn_fn resolution and per-layer comparison prints
ChenhanYu Apr 2, 2026
2726068
feat: update DFlash training to match SpecForge latest (post-PR #473)
ChenhanYu Apr 3, 2026
3516c0b
fix: remove extra unsqueeze in DFlash training attention mask
ChenhanYu Apr 3, 2026
606e31d
fix: create training attention mask in f32 to avoid bf16 overflow
ChenhanYu Apr 3, 2026
e1237f7
fix: add dflash_num_anchors/loss_decay_gamma to launch_train.sh
ChenhanYu Apr 3, 2026
b8e5eb7
feat: add logit distillation to new random-anchor DFlash training
ChenhanYu Apr 3, 2026
b0df28c
fix: add dflash_use_logit_distillation to launch_train.sh
ChenhanYu Apr 3, 2026
4226349
fix: shift teacher logits by -1 for DFlash logit distillation
ChenhanYu Apr 3, 2026
818eb74
fix: mask all tokens when assistant pattern not found
ChenhanYu Apr 3, 2026
49038c7
feat: auto-inject generation tags for reliable answer_only_loss
ChenhanYu Apr 3, 2026
c49f6d9
fix: add <think> wrapper to simplified ChatML template for Qwen3
ChenhanYu Apr 3, 2026
ae2e7bd
fix: remove think wrapper from simplified ChatML template
ChenhanYu Apr 3, 2026
82eedb2
feat: add chatml_think template variant for Qwen3 think injection
ChenhanYu Apr 3, 2026
4ebb9de
docs: document simplified generation templates and limitations
ChenhanYu Apr 3, 2026
4561515
fix: ensure zero-loss path has gradient for DDP sync
ChenhanYu Apr 3, 2026
047ba1d
fix: prefer conversations field when messages lacks assistant turn
ChenhanYu Apr 3, 2026
bdcc0de
cleanup: remove debug prints and regex fallback in DFlash and dataset…
ChenhanYu Apr 3, 2026
e526016
fix: unwrap DDP model for AR validation to avoid deadlock
ChenhanYu Apr 3, 2026
633da55
fix: skip samples without assistant turns instead of crashing
ChenhanYu Apr 3, 2026
43afb06
fix: handle empty batch with dummy assistant turn
ChenhanYu Apr 3, 2026
90e9b4b
fix: AR validation deadlock - eval all ranks, validate on rank 0
ChenhanYu Apr 3, 2026
870db23
feat: add TensorBoard logging for DFlash training
ChenhanYu Apr 3, 2026
0d3f1fa
fix: skip AR validation during DDP training to prevent deadlock
ChenhanYu Apr 3, 2026
9b76b7d
feat: add DFlash export to z-lab compatible HF format
ChenhanYu Apr 4, 2026
1cfd558
fix: checkpoint resume + export-then-validate pipeline
ChenhanYu Apr 4, 2026
6a153bf
feat: auto-detect HEAD_NODE_IP for multi-node DFlash training
ChenhanYu Apr 4, 2026
c0c4330
fix: use explicit bfloat16 and device_map for export loading
ChenhanYu Apr 4, 2026
6684f47
fix: improve HEAD_NODE_IP auto-detection for multi-node
ChenhanYu Apr 4, 2026
dd6e282
fix: multi-method HEAD_NODE_IP detection for multi-node
ChenhanYu Apr 4, 2026
3efd659
fix: force dp_shard_size=1 for DFlash DDP training
ChenhanYu Apr 4, 2026
7d0028e
fix: support both DDP and FSDP for DFlash training
ChenhanYu Apr 4, 2026
496830d
fix: use AutoModelForCausalLM directly for export loading
ChenhanYu Apr 4, 2026
7255969
fix: use export_hf_checkpoint.py script for DFlash export
ChenhanYu Apr 4, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 88 additions & 7 deletions examples/speculative_decoding/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -319,15 +319,96 @@ trainer.save_state()
trainer.save_model("<path to the output directory>")
```

## DFlash: Block Diffusion for Flash Speculative Decoding

DFlash ([arXiv:2602.06036](https://arxiv.org/abs/2602.06036)) is a parallel speculative decoding method that predicts multiple tokens simultaneously using block diffusion. Unlike autoregressive methods (EAGLE, Medusa) that draft one token at a time, DFlash predicts an entire block of tokens in parallel, then iteratively denoises them.

### Architecture

DFlash uses three key mechanisms:

- **Feature Fusion**: Multi-layer hidden states from the target model are projected via a fully-connected layer and RMSNorm to create context features
- **KV Injection**: Context features are injected as K/V in every draft decoder layer, while Q comes from the noise embeddings. QK-Norm (RMSNorm on Q and K before RoPE) stabilizes attention
- **Parallel Drafting**: Within each block of size B, unknown positions are filled with the `mask_token_id` token; only the block-start position receives the real token. The attention mask allows noise tokens to attend to all context tokens from previous blocks, plus causally to earlier positions within the same block

### Training

```bash
./launch_train.sh --model $BASE_MODEL \
--output_dir $OUTPUT_DIR \
--data input_conversations/train.jsonl \
--num_epochs $NUM_EPOCH \
--mode dflash \
--dflash_block_size 16 \
--dflash_num_layers 5
```

Key arguments:

| Flag | Default | Description |
|------|---------|-------------|
| `--mode dflash` | - | Enable DFlash mode |
| `--dflash_block_size` | 16 | Block size for parallel prediction |
| `--dflash_num_layers` | 5 | Number of decoder layers in draft module |
| `--dflash_config` | None | Path to JSON config for custom architecture |
| `--dflash_mask_token_id` | auto | Mask token ID (auto-detected from model) |
| `--dflash_disable_torch_compile` | False | Disable torch.compile |
| `--dflash_use_logit_distillation` | False | Use KD from target model logits instead of hard CE |

### mask_token_id

The `mask_token_id` is critical for DFlash training and inference: the same value must be used at training time and at deployment. The auto-detection logic is:

| Model Family | mask_token_id | Source |
|-------------|---------------|--------|
| Qwen3.5 | 248070 | Built-in `[MASK]` token |
| Qwen3 (8B) | 151643 | `eos_token_id` |
| Llama 3 | 128002 | `reserved_special_token_0` |
| Others | `pad_token_id` | Fallback |

Override with `--dflash_mask_token_id <id>` if auto-detection is incorrect.

### Configuring Draft Model

Similar to EAGLE, provide a JSON config to customize the draft architecture:

```json
{
"num_hidden_layers": 5,
"rms_norm_eps": 1e-6
}
```

Model dimensions (hidden_size, num_attention_heads, etc.) are automatically inherited from the base model.

### Current Status (WIP)

| Feature | Status |
|---------|--------|
| Architecture (Feature Fusion, KV Injection, Parallel Drafting) | Working |
| Online training with HF Trainer | Working |
| Inference / AR validation (`pseudo_speculative_generate`) | Working |
| z-lab checkpoint loading and inference (AR 7-9) | Working |
| Logit distillation option | Working |
| Response-only loss masking | Working |
| DDP training | Working (with `find_unused_parameters=True`) |

**Known gap**: Training with ModelOpt achieves ~35% per-token accuracy (comparable to SpecForge's ~30%), but the acceptance rate (AR) is lower than that of SpecForge-trained checkpoints (1.15 vs. 1.95). Investigation shows that the **data pipeline** differs significantly:

- SpecForge uses its own tokenizer template with system prompt and response-only loss mask
- ModelOpt's `LanguageDataCollator` uses `apply_chat_template` with different formatting

Aligning the data pipeline is the next step to close the AR gap.

## Support Matrix

| Model | Medusa | EAGLE1/2 | EAGLE3 |
| :---: | :---: | :---: | :---: |
| LLAMA 2 | ✅ | ✅ | ✅ |
| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ |
| Phi 3 | ✅ | ✅ | ✅ |
| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ |
| Model | Medusa | EAGLE1/2 | EAGLE3 | DFlash |
| :---: | :---: | :---: | :---: | :---: |
| LLAMA 2 | ✅ | ✅ | ✅ | ✅ |
| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ | ✅ |
| Phi 3 | ✅ | ✅ | ✅ | ✅ |
| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ | ✅ |

## Speculation Module Checkpoints

Expand Down
67 changes: 49 additions & 18 deletions examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def make_eagle_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer,
data_args,
train_len=None,
answer_only_loss=False,
) -> dict:
if data_args.offline_data_path is None:
train_dataset = ShardedDataset("json", data_files=data_args.data_path)
Expand All @@ -148,6 +149,7 @@ def make_eagle_supervised_data_module(
tokenizer=tokenizer,
train_len=train_len,
return_labels=True,
answer_only_loss=answer_only_loss,
)
else:
data_collator = VisionLanguageDataCollator(
Expand Down Expand Up @@ -203,6 +205,12 @@ def on_log(self, args, state, control, **kwargs):
if not hasattr(state, "training_accs") or len(state.training_accs) == 0:
return control
average_acc = np.mean(state.training_accs, axis=0)
# Always print accuracy to console
try:
acc_str = ", ".join(f"{a:.4f}" for a in np.array(average_acc).flatten())
print_rank_0(f"Step {state.global_step} Training Acc: [{acc_str}]")
except Exception:
print_rank_0(f"Step {state.global_step} Training Acc: {average_acc}")
if self.estimate_ar:
# Calculate mean training AR since last log
# NOTE: This is only an estimate of the real AR.
Expand All @@ -217,41 +225,64 @@ def on_log(self, args, state, control, **kwargs):
est_ar += acc_cumprod
print_rank_0(f"Step {state.global_step} Estimated Training AR: {est_ar:.4f}")

# Log accuracy to HF Trainer's logs dict (picked up by TensorBoard)
logs = kwargs.get("logs") or {}
for i, draft_acc in enumerate(average_acc):
for j, step_acc in enumerate(draft_acc):
logs[f"train_acc/parallel_{i}_step_{j}"] = float(step_acc)
if self.estimate_ar:
logs["estimated_training_ar"] = est_ar

# log to wandb
if wandb and is_master():
logs = kwargs.get("logs") or {}
if logs:
wandb.log({k: v for k, v in logs.items() if v is not None}, step=state.global_step)
for i, draft_acc in enumerate(average_acc):
for j, step_acc in enumerate(draft_acc):
wandb.log(
{f"parallel_{i}_step_{j}_train_acc": step_acc}, step=state.global_step
)
if self.estimate_ar:
wandb.log({"estimated_training_ar": est_ar}, step=state.global_step)

# reset training_accs
state.training_accs = []
return control

def on_step_end(self, args, state, control, **kwargs):
"""Run AR validation periodically, if available."""
"""Run AR validation periodically (single-GPU only).

AR validation with DDP is not supported because pseudo_speculative_generate
runs only on rank 0 while other ranks deadlock waiting for collective ops.
When world_size > 1, AR validation is skipped with a one-time warning.
Use post-training AR validation instead (online_training.sh runs it after training).
"""
if self.ar_validate_steps <= 0:
return control
if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1:
if not hasattr(self, "_ar_ddp_warned"):
self._ar_ddp_warned = True
print_rank_0(
"=== WARNING === AR validation during training is not supported with "
"DDP (world_size > 1). Skipping. Use post-training AR validation."
)
return control

model = kwargs["model"]
raw_model = model.module if hasattr(model, "module") else model
was_training = raw_model.training
raw_model.eval()
print_rank_0("Running AR validation...")
try:
ars = validate_ar(
model=kwargs["model"],
tokenizer=kwargs["processing_class"],
ds=load_dataset("HuggingFaceH4/mt_bench_prompts")["train"],
device=kwargs["model"].device,
)
with torch.no_grad():
ars = validate_ar(
model=raw_model,
tokenizer=kwargs["processing_class"],
ds=load_dataset("/hf-local/HuggingFaceH4/mt_bench_prompts")["train"],
device=next(raw_model.parameters()).device,
num_samples=8,
)
print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
if wandb and is_master():
if wandb:
wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
except Exception:
print_rank_0("AR validation not available.")
except Exception as e:
print_rank_0(f"AR validation failed: {e}")
if was_training:
raw_model.train()
return control


Expand Down
71 changes: 64 additions & 7 deletions examples/speculative_decoding/launch_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,33 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
FSDP="${1#*=}"
;;
--dflash_block_size*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_BLOCK_SIZE="${1#*=}"
;;
--dflash_num_layers*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_NUM_LAYERS="${1#*=}"
;;
--dflash_config*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_CONFIG="${1#*=}"
;;
--dflash_mask_token_id*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_MASK_TOKEN_ID="${1#*=}"
;;
--dflash_num_anchors*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_NUM_ANCHORS="${1#*=}"
;;
--dflash_loss_decay_gamma*)
if [[ "$1" != *=* ]]; then shift; fi
DFLASH_LOSS_DECAY_GAMMA="${1#*=}"
;;
--dflash_use_logit_distillation*)
DFLASH_USE_LOGIT_DISTILLATION="True"
;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
Expand Down Expand Up @@ -195,8 +222,34 @@ if [[ "$MODE" == "eagle3" ]]; then
else
SPECULATIVE_ARGS=""
fi
elif [[ "$MODE" == "dflash" ]]; then
DFLASH_BLOCK_SIZE=${DFLASH_BLOCK_SIZE:-16}
DFLASH_NUM_LAYERS=${DFLASH_NUM_LAYERS:-5}
SPECULATIVE_ARGS="--dflash_block_size $DFLASH_BLOCK_SIZE --dflash_num_layers $DFLASH_NUM_LAYERS --dflash_disable_torch_compile"
if [[ -n "$DFLASH_CONFIG" ]]; then
SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_config $DFLASH_CONFIG"
fi
if [[ -n "$DFLASH_MASK_TOKEN_ID" ]]; then
SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_mask_token_id $DFLASH_MASK_TOKEN_ID"
fi
if [[ -n "$DFLASH_NUM_ANCHORS" ]]; then
SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_num_anchors $DFLASH_NUM_ANCHORS"
fi
if [[ -n "$DFLASH_LOSS_DECAY_GAMMA" ]]; then
SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_loss_decay_gamma $DFLASH_LOSS_DECAY_GAMMA"
fi
if [[ "$DFLASH_USE_LOGIT_DISTILLATION" == "True" ]]; then
SPECULATIVE_ARGS="$SPECULATIVE_ARGS --dflash_use_logit_distillation"
fi
# DFlash: DDP by default, FSDP if --fsdp True is passed
if [[ "$FSDP" == "True" ]]; then
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
else
FSDP_ARGS="--ddp_find_unused_parameters True --ddp_timeout 1800"
DP_SHARD_SIZE=1
fi
else
echo "Only eagle3 supported for now!"
echo "Unsupported mode: $MODE. Supported: eagle3, dflash"
exit 1
fi

Expand All @@ -218,12 +271,14 @@ else
VLM_ARGS=""
fi

if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then
#Use FSDP2 when multi GPU available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
else
#Otherwise, single GPU training
FSDP_ARGS=""
if [[ "$MODE" != "dflash" ]]; then
if [[ "$TOTAL_GPU" -gt 1 && "$FSDP" == "True" ]]; then
#Use FSDP2 when multi GPU available
FSDP_ARGS="--fsdp 'full_shard' --fsdp_config ${SCRIPT_DIR}/fsdp_config.json"
else
#Otherwise, single GPU training
FSDP_ARGS=""
fi
fi


Expand Down Expand Up @@ -267,6 +322,8 @@ CMD="accelerate launch $MULTI_NODE_ARGS --mixed_precision bf16 ${SCRIPT_DIR}/mai
--warmup_steps 100 \
--lr_scheduler_type linear \
--logging_steps $LOG_STEPS \
--report_to tensorboard \
--logging_dir $OUTPUT_DIR/tensorboard \
--tf32 True \
$DATA_ARGS \
--disable_tqdm $DISABLE_TQDM \
Expand Down
Loading
Loading