TorchSpec is a torch-native speculative decoding training framework. It disaggregates draft-model training: inference and training are fully decoupled, with hidden states streamed directly from inference engine groups to distributed training workers via Mooncake store, so each side can scale independently.
TorchSpec currently includes training flows and examples for:
- Kimi-K2.5
- MiniMax-M2.5
- Qwen3-Coder-Next
- PyTorch blog: TorchSpec: Speculative Decoding Training at Scale
- Release blog: TorchSpec: Speculative Decoding Training at Scale
- Released draft model: lightseekorg/kimi-k2.5-eagle3
- Architecture Overview
- Quick Start
- Setup
- Examples
- Training Modes
- Checkpoint Conversion
- Metrics Reporting
- Troubleshooting
## Architecture Overview

TorchSpec is built around a disaggregated training pipeline:
- Inference engines generate target-model hidden states with either vLLM or SGLang.
- Mooncake store transfers tensors between inference and training without materializing them on disk.
- Training workers consume streamed hidden states to train speculative decoding draft models.
This separation keeps the training side focused on optimization while letting the inference side scale for hidden-state generation throughput.
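The pipeline above can be sketched as a producer/consumer pair. This is only an illustration of the decoupling, not TorchSpec's actual API: a bounded in-memory queue stands in for Mooncake store, and plain lists stand in for hidden-state tensors.

```python
import queue
import threading

# Hypothetical stand-in for the Mooncake store: a bounded in-memory queue.
store = queue.Queue(maxsize=8)

def inference_engine(num_batches):
    """Producer: generate target-model hidden states and stream them out."""
    for step in range(num_batches):
        hidden_states = [float(step)] * 4   # stand-in for real activations
        store.put(hidden_states)            # streamed, never written to disk
    store.put(None)                         # end-of-stream marker

def training_worker():
    """Consumer: one draft-model optimization step per streamed batch."""
    steps = 0
    while store.get() is not None:
        steps += 1
    return steps

producer = threading.Thread(target=inference_engine, args=(5,))
producer.start()
steps_done = training_worker()
producer.join()
print(steps_done)  # 5
```

Because the two sides only share the store, either one can be scaled out (more inference engines, more training ranks) without changes to the other.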
## Quick Start

Train an Eagle3 draft model for Qwen3-8B on a single node with 4 GPUs (2 for training and 2 for inference):

```bash
./examples/qwen3-8b-single-node/run.sh
```

Override config values directly from the CLI:

```bash
./examples/qwen3-8b-single-node/run.sh training.learning_rate=5e-5 training.num_train_steps=500
```

## Setup

```bash
# Install with vLLM
./tools/build_conda.sh 1 vllm
micromamba activate torchspec

# Or install with SGLang
./tools/build_conda.sh
micromamba activate torchspec
```

To install into your current environment instead:

```bash
./tools/build_conda.sh current sglang  # or 'vllm' or 'both'
```

Optional: install Flash Attention support:

```bash
pip install -e ".[fa]"
```

### vLLM

```bash
./examples/qwen3-8b-single-node/run.sh --config configs/vllm_qwen3_8b.yaml
```

### SGLang

```bash
./examples/qwen3-8b-single-node/run.sh
```

TorchSpec uses vLLM's Worker Extension mechanism to hook into the model forward pass and capture hidden states directly inside worker processes, avoiding RPC serialization overhead during extraction. For SGLang, TorchSpec applies a patch to the existing codebase to enable hidden-state extraction.
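The in-process capture pattern can be sketched with a standard PyTorch forward hook. This is not TorchSpec's worker-extension code, only an illustration of why capturing inside the worker avoids serialization: the hook sees the layer's output tensor in place.

```python
import torch
import torch.nn as nn

captured = []

def save_hidden(module, inputs, output):
    # Runs inside the same process as the forward pass: the tensor is
    # recorded in place, with no RPC or serialization on this path.
    captured.append(output.detach())

# Stand-in for one target-model layer; the real extension hooks the
# target model's forward pass inside each inference worker.
layer = nn.Linear(4, 4)
handle = layer.register_forward_hook(save_hidden)

layer(torch.randn(2, 4))
handle.remove()
print(captured[0].shape)  # torch.Size([2, 4])
```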
## Examples

| Example | Backend | Model |
|---|---|---|
| hf-quickstart | HuggingFace | Qwen3-8B |
| qwen3-8b-single-node | Inference engine | Qwen3-8B |
| kimi-k25-2node-h200 | Inference engine | Kimi-K2.5 |
| kimi-k25-3node-h100 | Inference engine | Kimi-K2.5 |
| minimax-m25-5node-h200 | Inference engine | MiniMax-M2.5 |
See examples/README.md for more details about each example.
## Training Modes

Both modes use `training.load_path`, but they restore different states:

| Goal | `training.load_path` | `training.continual_training` | What gets restored |
|---|---|---|---|
| Resume an interrupted run | Required | `false` (default) | Model, optimizer, LR scheduler, RNG, and step metadata |
| Start a new run from existing weights | Required | `true` | Model weights only |
Resume the same run:

```yaml
training:
  load_path: /path/to/old_run/checkpoints
  output_dir: /path/to/old_run
```

Start a new run from existing weights:

```yaml
training:
  load_path: /path/to/old_run/checkpoints
  continual_training: true
  learning_rate: 1e-5
  warmup_ratio: 0.01
  num_epochs: 1
  output_dir: /path/to/new_run
```

## Checkpoint Conversion

Convert an FSDP checkpoint to HuggingFace format:
```bash
python tools/convert_to_hf.py --input-dir ./outputs/my_experiment/iter_0010000/
```

Vocabulary pruning, which reduces the draft model `lm_head` to a smaller token set and emits `d2t` and `t2d` mappings, can be applied either during training or at conversion time.
- Pre-pruning: set `draft_vocab_size` in your training config. The checkpoint already contains the pruned `lm_head` and `d2t`/`t2d` buffers, so the basic conversion command is enough.
- Post-pruning: train with the full vocabulary, then pass `--prune-vocab` at conversion time together with a representative dataset to compute token frequencies.
```bash
python tools/convert_to_hf.py \
    --input-dir ./outputs/my_experiment/iter_0010000/ \
    --prune-vocab \
    --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
    --draft-vocab-size 32000 \
    --tokenizer Qwen/Qwen3-8B \
    --chat-template qwen \
    --prompt-key conversations
```

Pass `--cache-dir ./cache` to reuse the tokenized dataset cache from training.
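The exact layout of the `d2t`/`t2d` buffers is model-specific, but the idea behind them can be sketched in a few lines. Everything here is illustrative (toy vocabulary, made-up frequencies): the draft vocabulary keeps only the most frequent target-token ids, `d2t` maps draft indices back to target ids, and `t2d` records which target tokens survived pruning.

```python
# Illustrative sketch of vocabulary pruning and its index mappings.
target_vocab_size = 8
token_frequencies = [50, 0, 30, 7, 0, 90, 1, 2]  # toy counts from a dataset
draft_vocab_size = 4

# Keep the most frequent target ids, stored in ascending id order.
by_frequency = sorted(range(target_vocab_size),
                      key=lambda t: token_frequencies[t], reverse=True)
kept = sorted(by_frequency[:draft_vocab_size])

d2t = kept                                           # draft index -> target id
t2d = [t in kept for t in range(target_vocab_size)]  # target id -> kept?

print(d2t)             # [0, 2, 3, 5]
print(t2d[5], t2d[1])  # True False
```

The draft model's `lm_head` then only has `draft_vocab_size` rows, and `d2t` translates its argmax back into the target model's token space at decode time.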
## Metrics Reporting

W&B logging is disabled by default with `report_to: none`. To enable it, set `report_to: wandb` in your config and provide your API key.
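For example (the exact location of `report_to` in the config hierarchy is an assumption here; match it to where your config sets it):

```yaml
training:
  report_to: wandb
```

W&B picks up the API key from the `WANDB_API_KEY` environment variable or from a prior `wandb login`.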
## Troubleshooting

Set `TORCHSPEC_LOG_LEVEL=DEBUG` for more verbose logging when diagnosing issues:

```bash
TORCHSPEC_LOG_LEVEL=DEBUG ./examples/qwen3-8b-single-node/run.sh
```

Set `TORCHSPEC_LOG_DIR` to an absolute path on a shared filesystem (NFS) to enable per-rank log files for every Ray actor on both training and inference:

```bash
export TORCHSPEC_LOG_DIR=/my_project/running_logs
```

This creates a structured directory with one file per actor, organized by role and node:
```
running_logs/
  training/
    10.0.0.1/
      training_g0_rank0_20260301_080012.log
      training_g0_rank1_20260301_080012.log
    10.0.0.2/
      training_g0_rank2_20260301_080013.log
  inference/
    10.0.0.1/
      inference_g0_rank0_20260301_080014.log
    10.0.0.2/
      inference_g0_rank1_20260301_080015.log
```
The path must be absolute and writable from all nodes. If `TORCHSPEC_LOG_DIR` is unset or not writable, per-rank file logging stays disabled and Ray falls back to stdout/stderr capture.
| Issue | Reference |
|---|---|
| Stuck or failing distributed runs, Ray actor errors | docs/debugging_ray_jobs.md |
| Ray cluster setup, actor hierarchy, placement groups | docs/ray.md |
| Pipeline bottlenecks, slow steps, throughput analysis | docs/performance_metrics.md |
