llm-inference-bench

Systematic benchmarking of LLM inference optimization techniques, implemented from scratch and compared against production frameworks.

Model: Qwen/Qwen2.5-7B-Instruct
Hardware: NVIDIA RTX 5090 (32 GB VRAM) · CUDA 13.1 · PyTorch 2.9.1


Results Summary

1. KV Cache — from scratch vs naive recomputation

Implemented multi-head KV Cache in pure PyTorch. Each decode step caches K/V from previous tokens instead of recomputing the full sequence.

FLOPs per decode step: O(n²) → O(n). See benchmark/kv_cache.py.
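The idea can be sketched in a few lines of PyTorch. This is a single-head toy version (the repo's `kv_cache.py` is multi-head); all names and dimensions here are illustrative, not the repo's actual API:

```python
import torch

def decode_step(x_new, w_q, w_k, w_v, cache):
    """One decode step: compute Q/K/V for the newest token only,
    append K/V to the cache, attend over all cached keys/values."""
    q = x_new @ w_q                                    # (1, d) query for the new token
    cache["k"] = torch.cat([cache["k"], x_new @ w_k])  # append, don't recompute
    cache["v"] = torch.cat([cache["v"], x_new @ w_v])
    d = cache["k"].shape[-1]
    scores = q @ cache["k"].T / d ** 0.5               # (1, n): O(n) per step
    return torch.softmax(scores, dim=-1) @ cache["v"]

torch.manual_seed(0)
d = 16
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
X = torch.randn(5, d)                                  # 5 token embeddings
cache = {"k": torch.empty(0, d), "v": torch.empty(0, d)}
for t in range(5):
    out = decode_step(X[t:t+1], w_q, w_k, w_v, cache)

# Cached decode matches full recomputation for the last token:
ref = torch.softmax((X[4:5] @ w_q) @ (X @ w_k).T / d ** 0.5, dim=-1) @ (X @ w_v)
assert torch.allclose(out, ref, atol=1e-5)
```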

(figure: KV Cache Speedup)


2. vLLM (PagedAttention + Continuous Batching) vs HuggingFace baseline

| Batch size | HuggingFace (tok/s) | vLLM (tok/s) | Speedup |
|---|---|---|---|
| 1 | 59.74 | 103.81 | 1.74x |
| 2 | 117.97 | 207.98 | 1.76x |
| 4 | 236.70 | 414.22 | 1.75x |
| 8 | 468.45 | 824.40 | 1.76x |

At batch=1, speedup comes from FlashAttention kernels and vLLM's C++ engine eliminating Python-level CPU/GPU round-trips — not PagedAttention or Continuous Batching (which require concurrent requests to take effect). Consistent 1.75x across all batch sizes confirms this is a kernel/framework overhead difference.
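For reference, the vLLM side of the comparison needs only a few lines of its offline API (a sketch, not the repo's exact script — the sampling settings are assumptions, and it requires a CUDA GPU plus the model download):

```python
from vllm import LLM, SamplingParams

# PagedAttention and continuous batching are engine defaults;
# no extra configuration is needed to get them.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", dtype="float16")
params = SamplingParams(temperature=0.0, max_tokens=128)   # assumed settings
outputs = llm.generate(["Explain KV caching."] * 8, params)  # batch of 8 prompts
```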

(figure: Throughput Comparison)


3. Quantization: FP16 / INT8 / NF4 with Roofline Analysis

LLM autoregressive inference at batch=1 is memory-bandwidth-bound: every token generation loads all model weights from GPU DRAM. Quantization reduces weight size → less data transferred → higher throughput ceiling.

RTX 5090 memory bandwidth (GDDR7): ~1792 GB/s
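The roofline column below is just bandwidth divided by bytes moved per token (≈ the weight footprint). A quick check, assuming ~7.6B parameters for Qwen2.5-7B:

```python
BANDWIDTH = 1792e9  # RTX 5090 peak memory bandwidth, bytes/s

def roofline_toks(bytes_per_param, n_params=7.6e9, bw=BANDWIDTH):
    """Memory-bound ceiling: every decoded token must stream all weights."""
    return bw / (n_params * bytes_per_param)

print(round(roofline_toks(2.0), 1))  # FP16 -> 117.9 tok/s
print(round(roofline_toks(1.0), 1))  # INT8 -> 235.8 tok/s
print(round(roofline_toks(0.5), 1))  # NF4  -> 471.6 tok/s
```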

| dtype | Throughput (tok/s) | Roofline ceiling (tok/s) | TTFT (ms) | Peak VRAM | Perplexity (WikiText-2) |
|---|---|---|---|---|---|
| FP16 | 64.82 | 117.9 | 21.81 | 14.22 GB | 7.798 |
| INT8 | 12.30 | 235.8 | 96.10 | 8.20 GB | 7.787 |
| NF4 | 46.54 | 471.6 | 31.34 | 5.35 GB | 8.089 |

Key finding: bitsandbytes INT8 is slower than FP16 despite the smaller model. Runtime dequantization overhead (INT8 → FP16 before each matmul) outweighs the memory bandwidth savings at batch=1. Its real benefit is enabling higher concurrency within a VRAM budget, not single-request throughput.

NF4: 62% VRAM reduction (14.2 → 5.4 GB) with only 3.7% perplexity degradation (7.80 → 8.09), enabling ~2.6x higher theoretical concurrency on the same hardware.
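The NF4 path here is the standard bitsandbytes route via `transformers` (a configuration sketch — the compute dtype is an assumption, and loading requires a CUDA GPU):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Weights stored as 4-bit NF4; matmuls dequantise to fp16 on the fly.
nf4_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # assumed compute dtype
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct", quantization_config=nf4_cfg, device_map="auto"
)
```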

(figure: Quantization Tradeoff)


4. Flash Attention — Triton kernel from scratch

Implemented FlashAttention v1 (Dao et al., NeurIPS 2022) as a custom Triton kernel. FLOPs are identical to naive attention — every speedup comes from HBM memory I/O reduction.

Key idea: tile Q/K/V into SRAM blocks (BLOCK_M=128, BLOCK_N=64) and maintain a running (max, sum) per row for online softmax, so the N×N score matrix never materialises in HBM.
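The online-softmax trick is easier to see outside Triton. Below is a plain-PyTorch sketch for a single query row, processing K/V in blocks while carrying running (max, sum) statistics — the full N-length score row never exists at once (block size and dimensions here are illustrative):

```python
import torch

def flash_attn_row(q, K, V, block_n=64):
    """Online-softmax attention for one query vector q of shape (d,)."""
    d = q.shape[-1]
    m = torch.tensor(float("-inf"))   # running row max
    l = torch.tensor(0.0)             # running sum of exp(scores - m)
    acc = torch.zeros(d)              # running weighted sum of V rows
    for start in range(0, K.shape[0], block_n):
        Kb, Vb = K[start:start + block_n], V[start:start + block_n]
        s = (Kb @ q) / d ** 0.5              # scores for this block only
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)         # rescale old stats to new max
        p = torch.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ Vb
        m = m_new
    return acc / l

torch.manual_seed(0)
N, d = 256, 64
q, K, V = torch.randn(d), torch.randn(N, d), torch.randn(N, d)
ref = torch.softmax((K @ q) / d ** 0.5, dim=0) @ V  # materialises full row
assert torch.allclose(flash_attn_row(q, K, V), ref, atol=1e-5)
```

The Triton kernel applies the same recurrence per tile of queries, with the tiles living in SRAM.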

RTX 5090 · BATCH=1 · N_HEADS=32 · HEAD_DIM=128

| seq_len | Naive (ms) | torch SDPA (ms) | Triton ours (ms) | Triton TFLOPS | HBM naive | HBM flash | Mem reduction |
|---|---|---|---|---|---|---|---|
| 512 | 0.17 | 0.04 | 0.05 | 89 | 0.1 GB | 0.017 GB | 3x |
| 1024 | 0.84 | 0.12 | 0.16 | 105 | 0.2 GB | 0.034 GB | 5x |
| 2048 | 3.34 | 0.47 | 0.70 | 98 | 0.6 GB | 0.067 GB | 9x |
| 4096 | 13.05 | 1.53 | 2.12 | 130 | 2.3 GB | 0.134 GB | 17x |
| 8192 | 51.40 | 5.45 | 7.80 | 141 | 8.9 GB | 0.268 GB | 33x |

Key findings:

  • Our Triton kernel reaches 141 TFLOPS at N=8192, within 1.43x of cuDNN FlashAttn v2
  • HBM traffic reduction grows with sequence length: 33x less memory I/O at N=8192
  • Naive attention at N=8192 loads 8.9 GB from HBM per forward pass; Flash loads only 268 MB
  • Correctness verified: max_err=0.0005 vs naive (fp16 rounding expected)

(figure: Flash Attention)


5. Speculative Decoding — from scratch

Implemented speculative decoding (Leviathan et al., NeurIPS 2023) from scratch.

  • Draft model: Qwen2.5-0.5B-Instruct (~1 GB fp16)
  • Target model: Qwen2.5-7B-Instruct (~15 GB fp16)
  • Algorithm: draft k tokens serially with 0.5B → verify all k in one parallel 7B forward → accept/reject each token with min(1, p_target/p_draft) → sample corrected token on first rejection
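The accept/reject step can be sketched as follows. This is a simplified single-sequence version (hypothetical names; the bonus k+1-th token sampled when all drafts are accepted is omitted for brevity):

```python
import torch

def speculative_accept(draft_tokens, p_draft, p_target):
    """Leviathan-style verification: accept draft token t with probability
    min(1, p_target(t)/p_draft(t)); on first rejection, sample a corrected
    token from the residual max(p_target - p_draft, 0), renormalised.

    p_draft, p_target: (k, vocab) distributions at each draft position.
    """
    accepted = []
    for i, t in enumerate(draft_tokens):
        ratio = p_target[i, t] / p_draft[i, t]
        if torch.rand(()) < ratio:                      # accept (ratio >= 1 always passes)
            accepted.append(int(t))
        else:                                           # reject: corrected sample
            residual = (p_target[i] - p_draft[i]).clamp(min=0)
            accepted.append(int(torch.multinomial(residual / residual.sum(), 1)))
            break
    return accepted
```

This rejection scheme guarantees the output distribution matches sampling from the target model alone.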

RTX 5090 · n_new=200 · temperature=1.0

| k (draft tokens) | tok/s | Speedup vs naive | Acceptance rate |
|---|---|---|---|
| 0 (naive) | 43.41 | 1.00x | — |
| 1 | 32.27 | 0.74x | 71.6% |
| 2 | 33.97 | 0.78x | 53.9% |
| 4 | 33.85 | 0.78x | 42.8% |
| 8 | 23.18 | 0.53x | 22.6% |
| 12 | 19.38 | 0.45x | 18.0% |

Why no speedup here? Both implementations run without KV cache — every decode step recomputes the full O(n²) attention. Speculative decoding's theoretical gain is:

speedup ≈ (k·α + 1) / (k·T_draft/T_target + 1)

This ratio exceeds 1 only when T_draft << T_target and α is high. Without a KV cache, the O(n²) full-sequence forward dominates both models' step cost, collapsing the T_draft/T_target ratio toward 1. The production speedup (2–3x reported in the paper) requires KV-cached inference, where each decode step's attention cost is O(n) rather than O(n²).
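Plugging numbers into the estimate makes the collapse concrete. The cost ratios below are illustrative assumptions, not measurements from this benchmark:

```python
def spec_speedup(k, alpha, cost_ratio):
    """Estimate above: (k*alpha + 1) / (k * T_draft/T_target + 1)."""
    return (k * alpha + 1) / (k * cost_ratio + 1)

# With a KV cache, a 0.5B draft might cost ~1/10 of a 7B step (assumed):
print(round(spec_speedup(4, 0.428, 0.10), 2))  # -> 1.94
# Without one, the full-sequence forward dominates both models,
# pushing the ratio toward 1 and the speedup below 1:
print(round(spec_speedup(4, 0.428, 0.90), 2))  # -> 0.59
```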

Acceptance rate analysis: α decays rapidly with k — from 71.6% at k=1 to 18.0% at k=12. The Qwen2.5 0.5B and 7B share the same tokenizer and architectural family, but the 0.5B is not a distilled version of the 7B, so the draft distribution diverges quickly over multi-step chains.

Implementation notes:

  • Shared vocab handling: 7B has 152,064 tokens, 0.5B has 151,936 — all logits truncated to min(both) = 151,936
  • Draft tokens with index ≥ shared_vocab auto-rejected; corrected token sampled from target
  • Rejection path reuses precomputed draft distributions to avoid extra forward passes
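The vocab-mismatch handling reduces to truncating logits to the shared prefix before the softmax, so draft and target distributions live on the same support (names here are illustrative):

```python
import torch

TARGET_VOCAB, DRAFT_VOCAB = 152_064, 151_936   # Qwen2.5-7B vs 0.5B
SHARED = min(TARGET_VOCAB, DRAFT_VOCAB)

def shared_probs(logits, shared=SHARED):
    """Drop logits above the shared vocab, then renormalise via softmax."""
    return torch.softmax(logits[..., :shared], dim=-1)

t_probs = shared_probs(torch.randn(TARGET_VOCAB))
d_probs = shared_probs(torch.randn(DRAFT_VOCAB))
assert t_probs.shape == d_probs.shape == (SHARED,)
```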

(figure: Speculative Decoding)


Roadmap

| Module | Description | Status |
|---|---|---|
| KV Cache | From-scratch PyTorch implementation + benchmark | ✅ Done |
| vLLM | PagedAttention + Continuous Batching comparison | ✅ Done |
| Quantization | FP16 / INT8 / NF4 with roofline + perplexity | ✅ Done |
| Flash Attention | Custom Triton kernel (FlashAttn v1) | ✅ Done |
| Speculative Decoding | From scratch, 0.5B draft → 7B target, acceptance rate profiling | ✅ Done |
| Tensor Parallelism | Multi-GPU inference | ⏳ Planned |

Structure

llm-inference-bench/
├── benchmark/
│   ├── baseline.py          # HuggingFace naive inference
│   ├── kv_cache.py          # KV Cache from scratch
│   ├── vllm_bench.py        # vLLM comparison
│   ├── quantization.py      # FP16 / INT8 / NF4 benchmark
│   ├── flash_attention.py   # Flash Attention Triton kernel from scratch
│   └── speculative_decoding.py  # Speculative Decoding from scratch
├── scripts/
│   └── compare_all.py       # Generate comparison charts from JSON results
└── results/                 # JSON data + PNG charts

Setup

pip install torch transformers vllm bitsandbytes datasets matplotlib

Each script is self-contained. Run from the benchmark/ directory:

cd benchmark
python baseline.py        # ~5 min
python kv_cache.py        # ~3 min
python vllm_bench.py      # ~5 min
python quantization.py    # ~15 min
python flash_attention.py      # ~5 min (first run compiles Triton kernel)
python speculative_decoding.py # ~30 min (loads two models, sweeps k=1,2,4,8,12)

cd ../scripts
python compare_all.py     # regenerate charts from saved JSON
