
Speculative Decoding

Documentation

Speculative decoding accelerates auto-regressive generation in large language models (LLMs) by leveraging a lightweight draft model to predict the next γ tokens. The main LLM then verifies these candidate tokens in a single forward pass. If the draft model correctly predicts α tokens, the LLM can accept and generate α+1 tokens per verification step, significantly improving generation speed.
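The accept-and-verify step above can be sketched in a few lines. This is an illustrative toy, not the library's implementation: `verify_draft` is a hypothetical helper, and in a real system the target tokens would come from a single batched forward pass of the base LLM.

```python
def verify_draft(draft_tokens, target_tokens):
    """Count how many draft tokens match the target model's greedy
    predictions (alpha), then accept that prefix plus one bonus token
    produced by the verifier itself."""
    alpha = 0
    for d, t in zip(draft_tokens, target_tokens):
        if d == t:
            alpha += 1
        else:
            break
    return alpha + 1

# gamma = 4 draft tokens; the first two agree with the target model,
# so 2 + 1 = 3 tokens are generated in this verification step
print(verify_draft([5, 9, 2, 7], [5, 9, 4, 7]))
```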

This folder contains an end-to-end runnable speculative decoding fine‑tuning pipeline in which Llama‑3.2‑1B (Hugging Face) is trained on the Daring‑Anteater dataset.

This example focuses on training with Hugging Face. To train with Megatron‑LM, see the Megatron‑LM example.

Contents

Section Description Jump To
Pre-Requisites Required & optional dependencies [Link]
Simplified Workflow Train, evaluate, and export EAGLE model with one-line command [Link]
Online Training Train draft model alongside base model in GPU memory [Link]
Offline Training Train draft model using pre-computed hidden states [Link]
After Training Evaluation, export and deployment [Link]
Advanced Usage Data synthesis, vocab compression, and configuration [Link]
Support Matrix Supported models for speculative decoding training [Link]
Speculation Module Checkpoints View pre-trained speculation modules ready to deploy! [Link]
Resources Extra links to relevant resources [Link]

Pre-Requisites

Docker

Please use the PyTorch docker image (e.g., nvcr.io/nvidia/pytorch:25.08-py3) or visit our installation docs for more information.

Also follow the installation steps below to upgrade to the latest version of Model Optimizer and install dataset and example-specific dependencies.

Local Installation

Install ModelOpt with the hf dependencies, plus the other requirements for this example:

pip install -U nvidia-modelopt[hf]
pip install -r requirements.txt

Data Preparation

We use the Daring-Anteater dataset in this example. Prepare the data by running:

python prepare_input_conversations/add_daring_anteater.py

See the other-datasets section for more dataset options and instructions for user-provided data.

Getting Started: Simplified Workflow

bash train_eagle3_and_export.sh --base_model meta-llama/Llama-3.2-1B-Instruct

This one-line command runs a minimal example workflow that trains an EAGLE draft model in ModelOpt and exports the resulting checkpoint.

Training Draft Model with Online Base Model

For small base models that fit in GPU memory, we can collocate them with draft models and train with the following command:

./launch_train.sh --model $BASE_MODEL \
            --output_dir $OUTPUT_DIR \
            --data input_conversations/daring-anteater.jsonl  \
            --num_epochs $NUM_EPOCH \
            --eagle_config eagle_config.json

FSDP2 is used by default. To enable context parallelism for long-context training, specify --cp_size n. The saved ModelOpt checkpoint is similar in architecture to HF models and can be further optimized through ModelOpt, e.g., with PTQ and QAT.

Training Draft Model with Offline Base Model

For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of disk storage depending on dataset size.
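To get a feel for the "several to tens of terabytes" figure, the sketch below estimates hidden-state storage. The assumptions are ours, not from the source: bf16 storage (2 bytes per value) and 4 captured layers per token (e.g. 3 auxiliary layers plus the final hidden state, as in EAGLE-3-style training).

```python
def hidden_state_storage_gb(num_tokens, hidden_size,
                            num_layers_saved=4, bytes_per_value=2):
    """Rough disk estimate (GiB) for dumped hidden states.

    Assumed defaults: bf16 values and 4 saved layers per token;
    adjust both to match your actual dump configuration.
    """
    return num_tokens * hidden_size * num_layers_saved * bytes_per_value / 1024**3

# e.g. 1B training tokens with hidden_size 4096: roughly 30 TiB
print(f"{hidden_state_storage_gb(1_000_000_000, 4096):.0f} GiB")
```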

Dumping Hidden States to Disk

We support two backends for generating base model hidden states. For better efficiency, we recommend TRT-LLM:

python collect_hidden_states/compute_hidden_states_trtllm.py \
            --model $BASE_MODEL \
            --input-file input_conversations/daring-anteater.jsonl \
            --output-dir $HIDDEN_STATES_DIR

NOTE: The above command requires a TRT-LLM installation.

Alternatively, you can generate the same hidden states with HF:

python collect_hidden_states/compute_hidden_states_hf.py \
            --model $BASE_MODEL \
            --input-file input_conversations/daring-anteater.jsonl  \
            --output-dir $HIDDEN_STATES_DIR

NOTE: See run_hf_compute_hiddens_dp.sh and run_trtllm_compute_hiddens_dp.sh for a simple example using data parallelism (DP) to accelerate hidden state generation.

Train Draft Model with Dumped Hidden States

Once we finish dumping hidden states, launch offline training with an extra --offline-data argument:

./launch_train.sh --model $BASE_MODEL \
            --output_dir $OUTPUT_DIR \
            --data $DATA \
            --num_epochs $NUM_EPOCH \
            --eagle_config eagle_config.json \
            --offline-data $HIDDEN_STATES_DIR

Model Validation

For online training checkpoints, we can run in-framework evaluation on MT-bench:

python scripts/ar_validate.py --model_path $ONLINE_CKPT

Note: In-framework evaluation is supported only for online training. For offline training checkpoints, please export the model and evaluate it using serving frameworks.

Export

python scripts/export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH

This exports the model from a ModelOpt checkpoint to a deployment-compatible format.

Deployment

The exported checkpoint can be deployed on TRT-LLM, vLLM, or SGLang.

TRT-LLM

To serve the checkpoint with TRT-LLM, run trtllm-serve with:

trtllm-serve <base_model_checkpoint> --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml

where extra-llm-api-config.yml contains:

enable_attention_dp: false
disable_overlap_scheduler: true
enable_autotuner: false

cuda_graph_config:
    max_batch_size: 1

speculative_config:
    decoding_type: Eagle
    max_draft_len: 3
    speculative_model_dir: <draft_model_checkpoint>

kv_cache_config:
    enable_block_reuse: false

Please refer to TRT-LLM Doc: Speculative Decoding for detailed usage.

vLLM

Please refer to VLLM Doc: Speculative Decoding for detailed usage.

Optionally, you can convert the exported checkpoint to include target model information, which vLLM accepts to simplify deployment:

python scripts/convert_to_vllm_ckpt.py --input <exported_ckpt> --verifier <target_model> --output <output_dir>

SGLang

Please refer to SGLang Doc: Speculative Decoding for detailed usage.

SpecDec Bench

One can also use examples/specdec_bench to validate the trained Eagle3 checkpoints in a variety of frameworks (vLLM, SGLang, TRT-LLM) on a set of datasets.

Deploying Quantized model

See more details on deploying quantized models to TRT-LLM here.

Advanced Usage

Other Datasets

In addition to daring-anteater, we provide scripts for adding several other commonly used datasets in prepare_input_conversations:

prepare_input_conversations/
    ├── add_daring_anteater.py
    ├── add_mtbench.py
    ├── add_sharegpt.py
    ├── add_ultrachat.py
    └── example_make_prompt_dataset.sh

To use your own datasets, please preprocess your data into a .jsonl file with each line in the format:

{
    "conversation_id": <unique id>,
    "conversations": [{"role":<user or assistant>, "content":<content>}]
}
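A small converter can bring your own records into this schema. The sketch below is hypothetical (the sample data and the `my_dataset.jsonl` filename are illustrative); only the output schema comes from the format above.

```python
import json

# Illustrative source records; replace with your own data loading.
samples = [
    {"prompt": "What is speculative decoding?",
     "answer": "A technique that drafts tokens and verifies them in one pass."},
]

with open("my_dataset.jsonl", "w") as f:
    for i, s in enumerate(samples):
        record = {
            "conversation_id": str(i),
            "conversations": [
                {"role": "user", "content": s["prompt"]},
                {"role": "assistant", "content": s["answer"]},
            ],
        }
        # one JSON object per line, as expected by the training scripts
        f.write(json.dumps(record) + "\n")
```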

Data Synthesis

To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data. This ensures that the draft model's output distribution closely aligns with that of the base model.

To prepare such data, we launch an inference server with the base model:

pip install vllm
vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key token-abc123 --port 8000  --tensor-parallel-size 1

Note: Add --quantization=modelopt flag for quantized models.

Then, we generate conversations with the base model using prompts from Daring-Anteater:

python scripts/server_generate.py --data_path input_conversations/daring-anteater.jsonl --output_path synthetic/train.jsonl

To add a system prompt, use the --system_prompt <system_prompt_text> argument.

For large scale data generation, please see SLURM prepare data for SLURM support.

Configuring Draft Model

For EAGLE‑1 and EAGLE‑3 we provide a default model architecture config in ModelOpt. You can override the defaults by providing an additional JSON dict, e.g., to use a 2-layer EAGLE module with an MLP intermediate size of 8192, set eagle_config.json to:

{
    "num_hidden_layers": 2,
    "intermediate_size": 8192
}

Draft Vocabulary Compression

We can optionally use a smaller vocabulary for the draft model for faster training and inference. For example, Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most frequent tokens in our training set:

python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data input_conversations/daring-anteater.jsonl --draft_vocab_size 32000 --save_dir draft_vocab_cache

This will produce a d2t.pt file in save_dir, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by target_token = draft_token + d2t[draft_token].
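The offset-based mapping can be illustrated with a toy example (the values below are made up; a real d2t.pt holds one offset per draft-vocab index):

```python
# d2t[i] stores the offset from draft-vocab index i to the full-vocab id,
# so target_token = draft_token + d2t[draft_token].
d2t = [0, 4, 7, 95]  # a toy 4-token draft vocabulary

def draft_to_target(draft_token):
    return draft_token + d2t[draft_token]

print(draft_to_target(2))  # draft id 2 maps to full-vocab id 2 + 7 = 9
```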

Then, simply set {"draft_vocab_size":32000} in eagle_config.json and include --draft_vocab_cache <path_to_d2t.pt> when running ./launch_train.sh. The draft model will use this provided vocab table during training and export.

Interact with modelopt.torch.speculative

main.py provides an example for converting a HF base model for speculative decoding and training it. It consists of a few simple steps: First, load the base model and tokenizer from Hugging Face:

model = transformers.AutoModelForCausalLM.from_pretrained(
    "<path to your pretrained model>"
)

Then, load default eagle config and make necessary overwrites:

# Load default config
config = {
    "eagle1": EAGLE1_DEFAULT_CFG,
    "eagle3": EAGLE3_DEFAULT_CFG,
}[training_args.mode]["config"]

# overwrite config with custom config
config["eagle_architecture_config"].update({"<overwrite_keys>": "<overwrite_values>"})

# Mandatory: hidden size, vocab size and max position embeddings must match base model
config["eagle_architecture_config"].update(
    {
        "hidden_size": model.config.hidden_size,
        "vocab_size": model.config.vocab_size,
        "max_position_embeddings": model.config.max_position_embeddings,
    }
)

Then, we convert model to a speculative decoding model:

mtsp.convert(model, [("eagle", config)])

This modifies the model in place, adding the EAGLE training forward pass and making it compatible with the HF trainer:

# Create a trainer
trainer = transformers.Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
trainer._move_model_to_device(model, trainer.args.device)

# Enable HF checkpointing so that the saved model will contain the speculative decoding module
mto.enable_huggingface_checkpointing()

trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_state()
trainer.save_model("<path to the output directory>")

Support Matrix

Model                 Medusa   EAGLE1/2   EAGLE3
LLAMA 2               ✅        ✅          ✅
LLAMA 3, 3.1          ✅        ✅          ✅
Mistral               ✅        ✅          ✅
Phi 3                 ✅        ✅          ✅
QWen 1.5, 2, 2.5, 3   ✅        ✅          ✅

Speculation Module Checkpoints

Ready-to-deploy speculation module checkpoints are available in the 🤗 Hugging Face - NVIDIA Speculative Decoding Modules Collection, deployable on TensorRT-LLM and SGLang. More models coming soon!

Resources