2 min read

Speculative Decoding: Accelerating LLM Generation

Speculative decoding uses a small draft model to propose several tokens at a time, which a larger target model then verifies in a single forward pass. Accepted tokens come essentially for free, which can speed up generation substantially without changing the output distribution.

Speculative Decoding Implementation

# speculative_decoding.py - Implement speculative decoding

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Tuple, List

class SpeculativeDecoder:
    """Accelerate LLM generation with speculative decoding."""

    def __init__(
        self,
        draft_model_name: str,
        target_model_name: str,
        num_speculative_tokens: int = 4
    ):
        self.num_speculative = num_speculative_tokens

        # Load draft model (small, fast)
        self.draft_model = AutoModelForCausalLM.from_pretrained(
            draft_model_name,
            torch_dtype=torch.float16,
            device_map="cuda:0"
        )

        # Load target model (large, accurate)
        self.target_model = AutoModelForCausalLM.from_pretrained(
            target_model_name,
            torch_dtype=torch.float16,
            device_map="cuda:0"
        )

        # Draft and target models are assumed to share the same tokenizer/vocabulary
        self.tokenizer = AutoTokenizer.from_pretrained(target_model_name)

    @torch.no_grad()
    def generate(self, prompt: str, max_tokens: int = 100) -> str:
        """Generate with speculative decoding."""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").cuda()
        generated = input_ids.clone()

        while generated.shape[1] - input_ids.shape[1] < max_tokens:
            # Step 1: Draft tokens with small model
            draft_tokens, draft_probs = self.draft_tokens(generated)

            # Step 2: Verify with target model
            accepted, target_token = self.verify_tokens(
                generated, draft_tokens, draft_probs
            )

            # Step 3: Accept verified tokens
            generated = torch.cat([generated, accepted], dim=1)

            # Add one token from target model
            if target_token is not None:
                generated = torch.cat([generated, target_token], dim=1)

            # Check for EOS among the tokens appended this iteration
            # (at most num_speculative accepted tokens plus one target token)
            if self.tokenizer.eos_token_id in generated[0, -(self.num_speculative + 1):]:
                break

        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

    def draft_tokens(self, context: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate draft tokens and their full draft distributions."""
        draft_tokens = []
        draft_dists = []
        current = context.clone()

        for _ in range(self.num_speculative):
            # For simplicity this re-encodes the whole sequence each step;
            # a production implementation would reuse past_key_values
            outputs = self.draft_model(current)
            logits = outputs.logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)

            # Sample the next draft token
            token = torch.multinomial(probs, num_samples=1)
            draft_tokens.append(token)
            # Keep the full distribution - verification needs it for the
            # acceptance ratio and for residual resampling on rejection
            draft_dists.append(probs)

            current = torch.cat([current, token], dim=1)

        # Shapes: tokens (1, k), distributions (1, k, vocab_size)
        return torch.cat(draft_tokens, dim=1), torch.stack(draft_dists, dim=1)

    def verify_tokens(
        self,
        context: torch.Tensor,
        draft_tokens: torch.Tensor,
        draft_probs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Verify draft tokens with target model."""
        # Run target model on full sequence
        full_sequence = torch.cat([context, draft_tokens], dim=1)
        outputs = self.target_model(full_sequence)
        target_logits = outputs.logits

        # Verify each draft token in order
        accepted_tokens = []
        for i in range(self.num_speculative):
            # Logits at position pos predict the token at pos + 1, i.e. draft token i
            pos = context.shape[1] + i - 1
            target_probs = torch.softmax(target_logits[:, pos, :], dim=-1)

            draft_token = draft_tokens[:, i:i+1]
            draft_dist = draft_probs[:, i, :]
            target_prob = target_probs.gather(-1, draft_token)
            draft_prob = draft_dist.gather(-1, draft_token)

            # Accept with probability min(1, p_target / p_draft)
            acceptance_ratio = target_prob / (draft_prob + 1e-10)

            if torch.rand(1, device=target_prob.device) < acceptance_ratio:
                accepted_tokens.append(draft_token)
            else:
                # Reject: resample from the residual distribution max(0, p_target - p_draft),
                # renormalized, so the output distribution matches the target model exactly
                residual = torch.clamp(target_probs - draft_dist, min=0)
                residual = residual / residual.sum(dim=-1, keepdim=True)
                target_token = torch.multinomial(residual, num_samples=1)
                accepted = (
                    torch.cat(accepted_tokens, dim=1)
                    if accepted_tokens
                    else torch.empty((1, 0), dtype=draft_tokens.dtype, device=draft_tokens.device)
                )
                return accepted, target_token

        # All draft tokens accepted - sample one extra token from the target model
        last_logits = target_logits[:, -1, :]
        target_token = torch.multinomial(torch.softmax(last_logits, dim=-1), num_samples=1)

        return draft_tokens, target_token
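
To try the class end to end, here is a minimal usage sketch. The model names are placeholders rather than recommendations; any pair of causal LMs that share a tokenizer (for example, a small and a large checkpoint from the same family) should work.

# Example usage - model names are illustrative placeholders
decoder = SpeculativeDecoder(
    draft_model_name="gpt2",        # small, fast draft model
    target_model_name="gpt2-xl",    # larger target model with the same vocabulary
    num_speculative_tokens=4
)

print(decoder.generate("Speculative decoding speeds up LLM inference by", max_tokens=64))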

Speculative decoding typically yields a 2-3x speedup, and because rejected drafts are replaced by samples from the residual distribution, the output matches what the target model would have produced on its own, with no quality degradation.
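
To sanity-check the speedup on your own hardware, a rough wall-clock comparison along these lines works; this is a sketch that reuses the decoder object from the usage example above and makes both runs sample 100 new tokens.

# Rough timing comparison (sketch; wall-clock only, single prompt)
import time

prompt = "Explain speculative decoding in one paragraph:"

# Baseline: plain autoregressive sampling with the target model alone
inputs = decoder.tokenizer(prompt, return_tensors="pt").to("cuda:0")
start = time.perf_counter()
decoder.target_model.generate(**inputs, max_new_tokens=100, do_sample=True)
baseline = time.perf_counter() - start

# Speculative decoding with the draft + target pair
start = time.perf_counter()
decoder.generate(prompt, max_tokens=100)
speculative = time.perf_counter() - start

print(f"baseline: {baseline:.2f}s  speculative: {speculative:.2f}s  "
      f"speedup: {baseline / speculative:.2f}x")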

Michael John Peña

Senior Data Engineer based in Sydney. Writing about data, cloud, and technology.