Skip to content

insop/kernel-coder

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

9 Commits
 
 
 
 
 
 

Repository files navigation

πŸš€ Kernel Coder: GRPO-Trained Triton Kernel Model

An AI code assistant that helps write Triton kernel code.

Summary:

We introduce Kernel Coder, an AI assistant for writing Triton GPU kernels, trained using a novel GRPO-based reinforcement learning pipeline. By combining format validation and similarity rewards, we optimize the compact Qwen2.5-Coder-3B model on the KernelBook dataset. Our approach aims to match or surpass KerneLllm (Llama 3.1 8B) in generating correct and efficient Triton code.

πŸ’‘ Introduction

We design an RL training pipeline to train a base model for generating Triton Kernel code. Triton is a Python-based DSL for GPU programming. Inspired by DeepSeek-R1-Zero, we implement a GRPO-based RL pipeline to train a base model (Qwen2.5-Coder-3B).

πŸ“‘ Table of Contents

Introduction

🎯 Goal: RL-train the Qwen2.5-Coder-3B base model on a Triton kernel dataset (KernelBook), aiming for competitive performance compared to the SFT-trained KerneLllm (based on Llama-3.1-8B).

πŸ—οΈ Reward Design

We design the reward function with two components:

  1. βœ… Format Checking: Validate correct usage of <thinking> and <answer> tags.
  2. πŸ” Similarity Score: Measure string similarity between generated and ground-truth Triton kernels using Python’s difflib.SequenceMatcher. This idea is inspired by SWE-RL.

πŸ§ͺ Evaluation

We evaluate the generated Triton kernels using KernelBench (triton_backend_v2 branch) on:

  • The base model (Qwen2.5-Coder-3B)
  • The SFT model (KernelLLM)
  • kernel-coder (our model): we will evaluate once the training is complete

🌟 Our Contributions

  • πŸŽ“ DeepSeek R1-Zero style RL pipeline for Triton kernel generation
  • πŸ“Š Reward model design: Combining format and similarity-based rewards

πŸ”­ Next Steps (Post June 1st)

  • πŸ§ͺ Add verifiable rewards: Use KernelBench to check compilation, correctness, and speedup.
  • πŸ”„ Explore knowledge distillation: Distill KerneLllm into a smaller model before applying RL training, then compare with our RL-trained model.

🧠 Model

We apply GRPO training to Qwen2.5-Coder-3B, a compact yet strong code model from the Qwen 2.5 family, balancing performance and compute cost.

πŸ”„ RL Training

Group Relative Policy Optimization (GRPO), proposed by DeepSeek, uses rule-based rewards for math and code tasks. GRPO avoids using a value model, instead estimating the advantage from relative reward rankings across multiple rollouts:

$$ \begin{aligned} \hat{A}_{i,t} = \tilde{r}_i = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})} \end{aligned} $$

This approach improves efficiency by comparing rollout quality relative to the batch.

πŸ§ͺ Task

Generate Triton kernels equivalent or superior to provided PyTorch kernels.

πŸ“Š Evaluation Results

Using KernelBench (Cuda and Triton kernel benchmark), (triton_backend_v2 branch) to evaluate:

Model Compilation Rate (%) Correctness Rate (%)
KernelLLM 77.0% 12.0%
Qwen2.5-Coder-3B (untrained) 29.0% 3.0%
kernel-coder (ours, GRPO-trained) 🚧 TBD 🚧 TBD

We test on label 1 (100 test cases) with temperature 1.0 and top_p 0.97. Preliminary results show the importance of Triton-specific training for compilation and correctness.

πŸ“‚ Code Structure

The codebase consists of two main components:

  1. nano_r1_script.py - Modified for our project and originally from nano-aha-moment
  2. KernelBench - Forked and modified from ScalingIntelligence/KernelBench

Project structure:

kernel-coder/
β”œβ”€β”€ README.md
β”œβ”€β”€ kernel-coder
β”‚   β”œβ”€β”€ nano_r1_script.py # main code
β”‚   └── utils.py
└── scripts
    └── kernelllm.py # helper script from KernelLLM model, https://huggingface.co/facebook/KernelLLM

πŸƒ How To Run

cd kernel-coder # cd to the project root
python kernel-coder/nano_r1_script.py --nproc 8  --max_response_tokens 2048

πŸ™ Acknowledgements

We build on the following resources:

  1. πŸ“Š KernelBench (Cuda and Triton kernel benchmark)
  2. πŸ”₯ KerneLllm (SFT model with KernelBook dataset)
  3. πŸ“š KernelBook (Triton Kernel Dataset)
  4. πŸ§ͺ nano-aha-moment (simple GRPO pipeline)

About

AI code assistant helps to write Kernel code.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages