speculative decoding
2026.03.23

my favorite trick for making large language models faster without making them dumber.
autoregressive decoding is slow because it’s serial. to generate a 500-token response from a 70B-parameter model, you call the model 500 times, one token at a time. each call is a full forward pass over 70B weights. the frustrating part is that on modern gpus, that pass is mostly memory-bound rather than compute-bound. the gpu spends most of its life waiting for weights to arrive in cache, and the tensor cores sit idle. the implication: if you could verify more tokens per forward pass, you’d barely pay extra compute for the privilege.
speculative decoding is how you get exactly that, for free, at the same output distribution. the idea, from leviathan et al. 2023, is easier to see worked through than stated abstractly.
the algorithm in four steps:
- draft the next k tokens with a small, cheap model (say a 1B). typical k is 4 to 8.
- run the big target model once, in parallel, over all k draft tokens. you get the target’s probability at each position in a single forward pass.
- walk down the k positions. at each one, accept the draft’s token with probability min(1, p_target / p_draft).
- on the first rejection, throw out the rest of the draft and resample that position from (p_target − p_draft) clipped to zero and renormalized.
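the four steps above can be sketched in a few lines. this is a toy sketch, not a real inference stack: `draft_tokens`, `p_draft`, and `p_target` are assumed to already exist as per-position categorical distributions from hypothetical draft and target models (the target’s rows coming from its single parallel forward pass).

```python
import numpy as np

def speculative_round(draft_tokens, p_draft, p_target, rng):
    """One round of speculative decoding: accept/reject k drafted tokens.

    draft_tokens: list of k token ids proposed by the draft model
    p_draft, p_target: arrays of shape (k, vocab), each model's full
        distribution at every drafted position
    returns the tokens emitted this round.
    """
    emitted = []
    for i, tok in enumerate(draft_tokens):
        # step 3: accept with probability min(1, p_target / p_draft)
        accept_prob = min(1.0, p_target[i, tok] / p_draft[i, tok])
        if rng.random() < accept_prob:
            emitted.append(tok)
        else:
            # step 4: first rejection — resample this position from
            # (p_target - p_draft) clipped to zero and renormalized,
            # and throw out the rest of the draft
            residual = np.clip(p_target[i] - p_draft[i], 0.0, None)
            residual /= residual.sum()
            emitted.append(int(rng.choice(len(residual), p=residual)))
            break
    return emitted
```

note that when the draft and target agree exactly, the accept probability is 1 everywhere and the whole draft goes through.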
worked through on a made-up four-token example:
| pos | draft tok | p_draft | p_target | accept prob | emitted |
|---|---|---|---|---|---|
| 1 | “the” | 0.72 | 0.85 | min(1, 0.85/0.72) = 1.00 | “the” |
| 2 | “cat” | 0.54 | 0.60 | min(1, 0.60/0.54) = 1.00 | “cat” |
| 3 | “sat” | 0.81 | 0.31 | min(1, 0.31/0.81) = 0.38 | resample from (p_T − p_D)₊ |
| 4 | “on” | 0.62 | 0.70 | — (never reached) | — |
in this round the draft was right about “the” and “cat,” disagreed with the target on “sat,” and the rejection at position 3 means we never even look at position 4. we emit three tokens from a single target forward pass: two draft accepts plus one resampled token. a regular decoder would have done three target passes to get here.
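the accept-probability column is just min(1, p_target / p_draft) at each position; recomputing it from the table’s numbers:

```python
# (p_draft, p_target) pairs from the table, positions 1-3
pairs = [(0.72, 0.85), (0.54, 0.60), (0.81, 0.31)]
accept = [min(1.0, pt / pd) for pd, pt in pairs]
# positions 1 and 2 accept with probability 1 (target is at least as
# confident as the draft); position 3 accepts with probability ~0.38
```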
the property that makes this not-a-hack is that the output distribution is exactly the target model’s. it’s not an approximation. it’s not a smaller model’s output dressed up with extra steps. the rejection-sampling math above (the min(1, p_target / p_draft) accept and the (p_target − p_draft)₊ resample) is the exact construction needed so that every emitted token is drawn from the target’s real distribution. you pay nothing in quality, and you pay compute only for the draft, which is tiny by construction.
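you don’t have to take the exactness claim on faith; it’s easy to check empirically. the sketch below runs the accept/resample rule many times at a single position, with made-up three-token distributions, and compares the empirical distribution of emitted tokens against p_target:

```python
import numpy as np

rng = np.random.default_rng(42)
p_draft = np.array([0.72, 0.18, 0.10])   # made-up draft distribution
p_target = np.array([0.31, 0.54, 0.15])  # made-up target distribution

def emit_one(rng):
    # draft proposes a token from its own distribution
    tok = rng.choice(3, p=p_draft)
    # accept with probability min(1, p_target/p_draft)
    if rng.random() < min(1.0, p_target[tok] / p_draft[tok]):
        return tok
    # on rejection, resample from (p_target - p_draft) clipped and renormalized
    residual = np.clip(p_target - p_draft, 0.0, None)
    residual /= residual.sum()
    return rng.choice(3, p=residual)

samples = np.array([emit_one(rng) for _ in range(100_000)])
empirical = np.bincount(samples, minlength=3) / len(samples)
# empirical matches p_target to within monte carlo noise, even though
# most proposals came from the (quite different) draft distribution
```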
the thing that made it click for me is that autoregressive sampling is sequential in name only. what’s sequential is the commitment: you can’t emit token 5 before you know what token 4 was. but verification is parallel. a transformer with a causal mask can score k candidate tokens in a single forward pass, because each of those positions already has the preceding context baked in. spec decoding is the observation that if you have a plausible guess for tokens 1 through k, you can pay one forward pass to find out which of those guesses the big model would have picked anyway.
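the index bookkeeping for parallel verification is worth making concrete. in the sketch below, `target_model` is a stand-in for any callable that returns, for a full token sequence, the next-token distribution at every position (the usual causal-transformer output shape); the names are illustrative, not a real API:

```python
def verify_probs(target_model, context, draft_tokens):
    """Score all k drafted tokens with one call to the target model.

    target_model(seq) is assumed to return a list of per-position
    next-token distributions, probs[i] being the distribution over the
    token that FOLLOWS position i (standard causal-LM output).
    """
    seq = context + draft_tokens
    probs = target_model(seq)  # the single forward pass
    # the distribution used to judge draft_tokens[j] is the one predicted
    # at the position just before it: index len(context) - 1 + j
    start = len(context) - 1
    return [probs[start + j] for j in range(len(draft_tokens))]
```

each drafted position already has its preceding context (including the earlier drafted tokens) in the sequence, which is exactly why the causal mask lets one pass score all k guesses.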
in practice you see 2 to 3x speedups on reasonable draft/target pairs, and more when the draft is well-aligned with the target (same tokenizer, similar training distribution). distilled drafts help a lot here. medusa-style heads, where the target model itself grows extra heads that propose the next few tokens, cut out the separate draft model entirely. eagle pushes the same idea further by speculating in the hidden-state space rather than the token space.
the core idea has been generalized in every direction since. tree speculation samples a tree of draft paths and accepts the longest path that verifies. self-speculative decoding uses the target’s own earlier layers as the draft. staged spec cascades small then medium then target. they all live in the same box though: decouple the sequential part of generation from the expensive part, and push as much of the expensive part into parallel as you can.
if you’re running llm inference at scale and not doing this, you’re leaving half or more of your throughput on the floor.