feat: Eagle-3 speculator implementation#50

Merged
dsikka merged 4 commits into main from eagle-3
Jul 16, 2025

Conversation


@rahul-tuli rahul-tuli commented Jul 11, 2025

Eagle-3 Speculator Implementation

This PR adds support for Eagle-3

Architecture

(architecture diagram)

Key Features

  • Vocabulary mapping between draft (32K) and target (128K) vocabularies
  • Custom attention layer accepting 2×hidden_size input
  • Fusion layer processing 3 verifier layers (3×H → H)
  • Configurable layer normalization placement (before/after residual)
  • Full checkpoint compatibility with HuggingFace models
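The 3×H → H fusion listed above can be sketched as a single linear projection over the concatenated hidden states of three verifier layers. This is a minimal illustration, not the PR's actual module; the class name and bias-free choice are assumptions:

```python
import torch
import torch.nn as nn


class FusionLayer(nn.Module):
    """Illustrative 3×H → H fusion: one linear projection over the
    concatenated hidden states of three verifier layers (names assumed)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.fc = nn.Linear(3 * hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [batch, seq, 3 * hidden_size]
        return self.fc(hidden_states)


fusion = FusionLayer(hidden_size=64)
out = fusion(torch.randn(1, 5, 3 * 64))
print(out.shape)  # torch.Size([1, 5, 64])
```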

Implementation Details

  • Eagle3SpeculatorConfig: Configuration with vocabulary size settings
  • Eagle3Attention: Modified attention for 2×hidden_size input
  • Eagle3DecoderLayer: Processes concatenated embeddings and hidden states
  • Eagle3Speculator: Main model with vocabulary mapping support

Based on: https://arxiv.org/abs/2503.01840
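The d2t/t2d vocabulary mapping can be illustrated with toy sizes: draft logits are scattered into the target vocabulary, and unmapped positions stay at -inf. This is a hedged sketch with made-up token ids, mirroring the valid-logit check in the verification script:

```python
import torch

# Toy sizes; the real checkpoint uses 32000 (draft) and 128256 (target).
draft_vocab_size, target_vocab_size = 8, 32

# d2t[i] = target token id for draft token i (arbitrary example ids)
d2t = torch.tensor([0, 3, 5, 9, 11, 17, 20, 30])
# t2d[j] = True if target token j has a draft-vocabulary counterpart
t2d = torch.zeros(target_vocab_size, dtype=torch.bool)
t2d[d2t] = True

# Scatter draft logits into the target vocabulary; unmapped slots stay -inf
draft_logits = torch.randn(draft_vocab_size)
target_logits = torch.full((target_vocab_size,), float("-inf"))
target_logits[d2t] = draft_logits

num_valid = (~torch.isinf(target_logits)).sum().item()
print(num_valid)  # 8, i.e. draft_vocab_size
```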

Verification

To verify the implementation works with your checkpoint:

#!/usr/bin/env python3
"""
Simple verification script for Eagle-3 speculator implementation.

This script demonstrates how to:
1. Load an Eagle-3 checkpoint from HuggingFace
2. Create the appropriate configuration
3. Instantiate the Eagle3Speculator model with a verifier
4. Load checkpoint weights
5. Run a dummy forward pass

Usage:
    python simple_verification.py
    
    # Or with custom checkpoint and verifier:
    python simple_verification.py --checkpoint nm-testing/SpeculatorLlama3-1-8B-Eagle3 --verifier meta-llama/Llama-3.1-8B
"""

import argparse
import sys
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig, AutoModelForCausalLM

# Add parent directory to path so `speculators` imports work from a source checkout
sys.path.insert(0, str(Path(__file__).parent.parent))

from speculators.models.eagle3 import Eagle3Speculator, Eagle3SpeculatorConfig


def main():
    parser = argparse.ArgumentParser(description="Verify Eagle-3 implementation")
    parser.add_argument(
        "--checkpoint",
        default="nm-testing/SpeculatorLlama3-1-8B-Eagle3",
        help="Eagle-3 checkpoint to load from HuggingFace",
    )
    parser.add_argument(
        "--verifier",
        default="meta-llama/Llama-3.1-8B",
        help="Verifier model ID (e.g., meta-llama/Llama-3.1-8B)",
    )
    args = parser.parse_args()

    print("Eagle-3 Verification Script")
    print("=" * 60)
    print(f"Checkpoint: {args.checkpoint}")
    print(f"Verifier: {args.verifier}")
    print()

    # Step 1: Get checkpoint path
    print("1. Getting checkpoint path...")
    checkpoint_path = Path(args.checkpoint)

    if checkpoint_path.exists():
        # Local path
        print(f"   ✓ Using local checkpoint: {checkpoint_path}")
    else:
        # Download from HuggingFace
        print(f"   → Downloading from HuggingFace: {args.checkpoint}")
        checkpoint_path = Path(snapshot_download(args.checkpoint))
        print(f"   ✓ Downloaded to: {checkpoint_path}")

    # Step 2: Load checkpoint config to understand model architecture
    print("\n2. Loading checkpoint configuration...")
    checkpoint_config = AutoConfig.from_pretrained(checkpoint_path)

    # Extract key information from checkpoint
    hidden_size = checkpoint_config.hidden_size
    num_attention_heads = checkpoint_config.num_attention_heads
    num_key_value_heads = checkpoint_config.num_key_value_heads

    print(f"   ✓ Hidden size: {hidden_size}")
    print(f"   ✓ Attention heads: {num_attention_heads}")
    print(f"   ✓ KV heads: {num_key_value_heads}")

    # Step 3: Create Eagle3SpeculatorConfig
    print("\n3. Creating Eagle3SpeculatorConfig...")

    # Create transformer config matching the checkpoint
    transformer_config = type(checkpoint_config)(
        hidden_size=hidden_size,
        intermediate_size=checkpoint_config.intermediate_size,
        num_attention_heads=num_attention_heads,
        num_key_value_heads=num_key_value_heads,
        num_hidden_layers=1,  # Eagle-3 uses single layer
        vocab_size=128256,  # Target vocabulary size
        max_position_embeddings=checkpoint_config.max_position_embeddings,
        rope_theta=checkpoint_config.rope_theta,
        rms_norm_eps=checkpoint_config.rms_norm_eps,
        attention_bias=False,
    )

    eagle3_config = Eagle3SpeculatorConfig(
        transformer_layer_config=transformer_config,
        draft_vocab_size=32000,  # Draft vocabulary size
        norm_before_residual=True,  # HF checkpoint style
    )

    print(f"   ✓ Draft vocabulary: {eagle3_config.draft_vocab_size}")
    print(f"   ✓ Target vocabulary: {eagle3_config.target_vocab_size}")

    # Step 4: Load verifier model
    print("\n4. Loading verifier model...")
    verifier = AutoModelForCausalLM.from_pretrained(args.verifier)
    print(f"   ✓ Loaded verifier: {args.verifier}")

    # Step 5: Instantiate Eagle3Speculator
    print("\n5. Creating Eagle3Speculator...")
    model = Eagle3Speculator(eagle3_config, verifier=verifier)
    print("   ✓ Model created successfully")

    # Step 6: Load checkpoint weights
    print("\n6. Loading checkpoint weights...")
    weight_file = checkpoint_path / "model.safetensors"

    state_dict = {}
    with safe_open(weight_file, framework="pt") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    # Filter out expected missing keys (verifier weights)
    non_verifier_missing = [k for k in missing_keys if not k.startswith("verifier")]

    if non_verifier_missing:
        print(f"   ⚠ Missing keys: {non_verifier_missing}")
    else:
        print("   ✓ All Eagle-3 weights loaded successfully")

    if unexpected_keys:
        print(f"   ⚠ Unexpected keys: {unexpected_keys}")

    # Step 7: Verify vocabulary mapping
    print("\n7. Checking vocabulary mapping tensors...")
    print(f"   ✓ d2t shape: {model.d2t.shape} (draft to target mapping)")
    print(f"   ✓ t2d shape: {model.t2d.shape} (target to draft availability)")

    # Step 8: Run dummy forward pass
    print("\n8. Running forward pass test...")

    batch_size = 1
    seq_len = 5

    # Create dummy inputs
    input_ids = torch.randint(0, eagle3_config.target_vocab_size, (batch_size, seq_len))
    hidden_states = torch.randn(
        batch_size, seq_len, 3 * hidden_size
    )  # 3 verifier layers

    # Run forward pass
    with torch.no_grad():
        output = model(
            input_ids=input_ids, hidden_states=hidden_states, return_dict=True
        )

    # Verify output shape
    expected_shape = (batch_size, seq_len, eagle3_config.target_vocab_size)
    actual_shape = tuple(output.logits.shape)

    print(f"   ✓ Output shape: {actual_shape} (expected: {expected_shape})")

    # Check vocabulary mapping works
    non_inf_mask = ~torch.isinf(output.logits[0, 0])
    num_valid = non_inf_mask.sum().item()

    print(
        f"   ✓ Valid logit positions: {num_valid} (expected: {eagle3_config.draft_vocab_size})"
    )

    print(f"\n{'=' * 60}")
    print("✅ Verification completed successfully!")
    print("\nThe Eagle-3 model is working correctly with:")
    print(f"- Fusion layer: 3×{hidden_size} → {hidden_size}")
    print(f"- Custom attention: accepts 2×{hidden_size} input")
    print(
        f"- Vocabulary mapping: {eagle3_config.draft_vocab_size} → {eagle3_config.target_vocab_size}"
    )


if __name__ == "__main__":
    main()

Example Output

Eagle-3 Verification Script
============================================================
Checkpoint: nm-testing/SpeculatorLlama3-1-8B-Eagle3
Verifier: HuggingFaceTB/SmolLM2-135M

1. Getting checkpoint path...
   → Downloading from HuggingFace: nm-testing/SpeculatorLlama3-1-8B-Eagle3
   ✓ Downloaded to: /home/user/.cache/huggingface/hub/models--nm-testing--SpeculatorLlama3-1-8B-Eagle3/snapshots/2670830690477b928f17b1237b05e9ff3b9d0255

2. Loading checkpoint configuration...
   ✓ Hidden size: 4096
   ✓ Attention heads: 32
   ✓ KV heads: 8

3. Creating Eagle3SpeculatorConfig...
   ✓ Draft vocabulary: 32000
   ✓ Target vocabulary: 128256

4. Loading verifier model...
   ✓ Loaded verifier: HuggingFaceTB/SmolLM2-135M

5. Creating Eagle3Speculator...
   ✓ Model created successfully

6. Loading checkpoint weights...
   ✓ All Eagle-3 weights loaded successfully

7. Checking vocabulary mapping tensors...
   ✓ d2t shape: torch.Size([32000]) (draft to target mapping)
   ✓ t2d shape: torch.Size([128256]) (target to draft availability)

8. Running forward pass test...
   ✓ Output shape: (1, 5, 128256) (expected: (1, 5, 128256))
   ✓ Valid logit positions: 32000 (expected: 32000)

============================================================
✅ Verification completed successfully!

The Eagle-3 model is working correctly with:
- Fusion layer: 3×4096 → 4096
- Custom attention: accepts 2×4096 input
- Vocabulary mapping: 32000 → 128256

github-actions bot commented Jul 11, 2025

📦 Build Artifacts Available
The build artifacts (`.whl` and `.tar.gz`) have been successfully generated and are available for download: https://github.com/neuralmagic/speculators/actions/runs/16275592566/artifacts/3529661618.
They will be retained for up to 30 days.
Commit: 1de8bc2

Eagle-3 extends EAGLE with vocabulary mapping for cross-tokenizer speculation,
enabling draft models with different tokenizers than the target model.

Key features:
- Vocabulary mapping between draft (32K) and target (128K) vocabularies
- Custom attention layer accepting 2× hidden_size input
- Fusion layer processing 3 verifier layers (3×H → H)
- Configurable layer normalization placement (before/after residual)
- Full checkpoint compatibility with HuggingFace models

Implementation includes:
- Eagle3SpeculatorConfig with vocabulary size configuration
- Eagle3Attention module for modified attention computation
- Eagle3DecoderLayer processing concatenated embeddings and hidden states
- Eagle3Speculator main model with vocabulary mapping support
- Proper type annotations and mypy compatibility

Based on: https://arxiv.org/abs/2503.01840

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Rahul Tuli <rtuli@redhat.com>
@rahul-tuli rahul-tuli self-assigned this Jul 11, 2025
@rahul-tuli rahul-tuli marked this pull request as ready for review July 11, 2025 17:57

@shanjiaz shanjiaz left a comment


Looks great! Ran the verification and it works very well.

rahul-tuli and others added 3 commits July 14, 2025 09:11
- Add target_hidden_size field to Eagle3SpeculatorConfig
- Update fusion layer to use target model's hidden size (3 × target_hidden_size)
- Properly handle 70B models where target hidden size (8192) differs from draft model (6144)
- Update docstrings to clarify hidden states dimensions

This fixes compatibility with Eagle3-LLaMA3.3-Instruct-70B checkpoint.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Rahul Tuli <rtuli@redhat.com>
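The dimension mismatch this commit addresses can be sketched as follows; the concrete sizes come from the commit message above, while the layer shape itself is an assumption for illustration:

```python
import torch.nn as nn

# Llama-3.3-70B case from the commit message: the verifier (target) hidden
# size differs from the draft model's hidden size.
draft_hidden_size = 6144
target_hidden_size = 8192

# The fusion layer consumes 3 verifier layers, so its input must be sized by
# the *target* hidden size, while its output matches the draft model.
fusion = nn.Linear(3 * target_hidden_size, draft_hidden_size, bias=False)
print(tuple(fusion.weight.shape))  # (6144, 24576)
```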

@markurtz markurtz left a comment


I'll go through the implementation in detail a bit later, but one quick thing: we need to add unit tests at a minimum.


@MeganEFlynn MeganEFlynn left a comment


Looks good to me. Only question: do we need to handle the case where the embeddings are not saved as part of the draft model's state dict and instead need to be pulled from the verifier?
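One way the fallback raised in this question could look, sketched with a stand-in verifier. `resolve_embeddings`, `TinyVerifier`, and the `embed_tokens.weight` key are hypothetical, not the PR's API:

```python
import torch.nn as nn


class TinyVerifier(nn.Module):
    """Stand-in verifier exposing get_input_embeddings(), as HF models do."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(32, 8)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed


def resolve_embeddings(state_dict, draft_embed, verifier):
    """Hypothetical fallback: use the draft checkpoint's embeddings if
    present, otherwise share the verifier's input embeddings (key name
    assumed)."""
    if "embed_tokens.weight" in state_dict:
        return draft_embed
    return verifier.get_input_embeddings()


verifier = TinyVerifier()
draft_embed = nn.Embedding(32, 8)
chosen = resolve_embeddings({}, draft_embed, verifier)
print(chosen is verifier.get_input_embeddings())  # True
```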


dsikka commented Jul 16, 2025

Landing - @markurtz to follow up with tests

@dsikka dsikka merged commit 5d4ff90 into main Jul 16, 2025
10 checks passed
@rahul-tuli rahul-tuli deleted the eagle-3 branch February 24, 2026 12:37
