Trition程序编写:从“Hello CUDA“到“Hello Triton“:向量加法背后的编译黑魔法

写 CUDA Kernel 写了三年,最怕的是什么?不是算法难,是调<<<grid, block>>>那一行永远写不对。

线程索引算错一位,debug 一天。Shared Memory bank conflict 搞不明白,性能掉一半。等到好不容易跑通了,换个 GPU 架构又得重来一遍。

后来同事说:“试试 Triton 吧,一行@triton.jit搞定。”

我当时是不信的。

直到我用 Triton 写完第一个向量加法 Kernel,对比 CUDA 版本的代码量直接腰斩——而且性能居然不输手写 CUDA。这篇文章就来复盘 Triton 程序的完整编写流程,从 API 到实战,把刚上手时最容易踩的坑都给你趟一遍。


一、Triton 到底是什么来头?

Triton 是 OpenAI 搞的开源 GPU 编程语言,定位很明确:比 CUDA 好写,比 PyTorch 灵活

传统 CUDA 编程里,你得手动管线程(thread)、线程束(warp)、线程块(block),每个线程该算哪段数据,索引写错了就是灾难。Triton 直接把这个模型翻了个个——你写代码时假装只处理"一块数据"(tile),编译器负责把这块数据自动拆给几百个线程去并行执行。

这叫tile-based programming model(基于块的编程模型)。核心思路一句话:你关心数据块,Triton 关心线程

整个编译流程是这样走的:

  1. @triton.jit装饰器捕获你函数的AST(抽象语法树),不是直接跑 Python 代码
  2. AST 被转成 Triton IR(MLIR 的自定义方言),里面全是 tile 级别的操作
  3. Triton IR 进一步降到 TritonGPU IR,决定每个 warp 分多少数据、寄存器怎么布局
  4. 最后走 LLVM 生成 PTX,NVIDIA 驱动再转成 SASS 机器码

这一套流程最爽的地方:同一份 Triton 代码,换 GPU 不用改。因为布局优化、warp 分配这些脏活全在编译器里自动完成。


二、核心 API 速览:这五个东西必须先认全

下面这张表是 Triton 编程的"身份证",不记牢后面写代码会不停翻文档。

2.1@triton.jit— 一切的起点

@triton.jitdefmy_kernel(x_ptr,y_ptr,output_ptr,n_elements,BLOCK_SIZE:tl.constexpr):...

关键点:

  • 这不是普通的 Python 装饰器——它不会执行你写的代码,而是把函数体抓成 AST 丢给编译器。
  • tl.constexpr标记的参数是编译期常量。同一个 kernel 用不同BLOCK_SIZE调用,编译器会分别生成两份优化过的机器码——这叫"特化"(specialization),是 Triton 性能不输 CUDA 的核心原因之一。
  • Kernel 函数里不能随便写 Python,只能用tl.load/tl.store/tl.arange这些 Triton DSL 操作。

2.2triton.autotune— 让 GPU 自己挑参数

@triton.autotune(configs=[triton.Config(kwargs={'BLOCK_SIZE':128},num_warps=4),triton.Config(kwargs={'BLOCK_SIZE':1024},num_warps=8),],key=['x_size'])@triton.jitdefkernel(x_ptr,x_size,**META):BLOCK_SIZE=META['BLOCK_SIZE']

这是 Triton 最让我惊艳的功能。你不用猜BLOCK_SIZE设 128 还是 1024 性能好——把候选配置丢进去,Triton 会逐个编译并跑一遍,自动选出最优的那个

几个要注意的坑:

  1. key参数是用来分组缓存的。如果key=['x_size'],当x_size变化时才会重新评估所有配置。设计 key 的时候只放"会影响性能选择"的参数,别把什么都塞进去,否则 autotune 开销爆炸。
  2. autotune 会把 kernel 跑很多遍,如果你 kernel 里会修改全局状态(比如累加计数),必须用reset_to_zero参数指定哪些 tensor 每次跑前归零。
  3. 第一次调用时 autotune 有预热开销,后面命中缓存就快了。

2.3triton.Config— 四个参数决定生死

一个 Config 对象就是一份"内核配置方案",autotune 会逐个尝试。四个核心参数:

参数含义调优建议
num_warps每个 block 分配的 warp 数(1 warp = 32 线程)VI00 用 2-4,A100 用 4-8,H100 可上 8-16
num_stages异步数据预取的流水线深度计算密集型 2-3,访存密集型(如 MatMul)3-5
num_ctasblock cluster 中的 block 数(SM90+ 专属)H100 才需要关注
maxnreg单线程最大寄存器数寄存器溢出时调这个,不是所有平台都支持

最重要的交互num_warpsnum_stages会抢同一块共享内存(shared memory)。warps 越多 → 线程越多 → 每个线程分到的寄存器越少 → 可能触发寄存器溢出(register spilling)。stages 越多 → 预取缓存越大 → 占的 shared memory 越多。加一个就得考虑减另一个,别两个一起拉满。

2.4 Math Ops — 这些算子直接能用

算子说明
tl.abs(x)逐元素绝对值
tl.cdiv(x, div)向上取整除法(算 grid 大小必用)
tl.sqrt(x)快速平方根(硬件近似,比math.sqrt快但精度略低)
tl.softmax(x)Softmax(注意是整块计算,不要自己手写)
tl.cos(x)/tl.sin(x)三角函数

cdiv是最常用的——因为你要根据n_elementsBLOCK_SIZE算出需要多少个 block,公式就是triton.cdiv(n_elements, BLOCK_SIZE)

2.5 Debug Ops — GPU 上的 printf

CUDA 调试痛苦的原因之一:kernel 里打不了断点,只能靠printf。Triton 把 debug 分了两层:

算子阶段用途
tl.static_print(...)编译期打印编译时常量,如BLOCK_SIZE
tl.static_assert(cond)编译期编译时断言,如检查BLOCK_SIZE是 2 的幂
tl.device_print(...)运行期GPU 上实时打印变量值
tl.device_assert(cond)运行期运行时断言,如检查mask范围

static_printstatic_assert非常实用——它们不会产生任何 GPU 指令,只在 JIT 编译时执行,零性能开销。


三、实战:用 Triton 写向量加法

光看 API 没用,直接上代码。

3.1 Kernel 函数

@triton.jitdefadd_kernel(x_ptr,y_ptr,output_ptr,n_elements,BLOCK_SIZE:tl.constexpr):# Step 1: 我是第几个 block?pid=tl.program_id(axis=0)# Step 2: 这个 block 负责的数据起始位置block_start=pid*BLOCK_SIZE# Step 3: 生成这个 block 里的所有偏移量 [0, 1, 2, ..., BLOCK_SIZE-1]offsets=block_start+tl.arange(0,BLOCK_SIZE)# Step 4: 最后一个 block 可能越界,做 maskmask=offsets<n_elements# Step 5: 从全局内存加载x=tl.load(x_ptr+offsets,mask=mask)y=tl.load(y_ptr+offsets,mask=mask)# Step 6: 算!output=x+y# Step 7: 写回全局内存tl.store(output_ptr+offsets,output,mask=mask)

这里解释几个新人容易懵的点:

  • tl.program_id(axis=0):Triton 里没有blockIdx.x这种 CUDA 概念,直接用program_id获取"我这个 block 是第几个"。axis=0 就是一维 grid,axis=1 / axis=2 对应二维 / 三维。
  • tl.arange(0, BLOCK_SIZE):生成一个从 0 到 BLOCK_SIZE-1 的向量。注意这不是 Python 的 range,而是一个 GPU 上的向量,后续所有操作都是按这个向量并行展开的。
  • mask=offsets < n_elements:数据总长度不一定是 BLOCK_SIZE 的整数倍,最后一个 block 会多算一些位置。mask 确保这些"越界"的偏移量不会被真的读写——tl.loadtl.store遇到 mask=False 的位置会直接跳过。
  • 指针运算x_ptr + offsets:Triton 里指针是整型,直接加偏移量就行,不需要&x_ptr[offsets]这种语法。

3.2 封装调用函数

defadd(x:torch.Tensor,y:torch.Tensor):# 分配输出 tensoroutput=torch.empty_like(x)# 安全检查:数据必须在 GPU 上assertx.is_cudaandy.is_cudaandoutput.is_cuda n_elements=output.numel()# 计算 grid:需要多少个 block?grid=lambdameta:(triton.cdiv(n_elements,meta['BLOCK_SIZE']),)# 启动 kernel!add_kernel[grid](x,y,output,n_elements,BLOCK_SIZE=1024)returnoutput

最需要解释的是grid = lambda meta: ...这个写法:

  • meta是一个字典,包含BLOCK_SIZE等编译期常量。这里meta['BLOCK_SIZE']就是 1024。
  • 返回值是一个元组(grid_x, grid_y, grid_z),这里只有一维所以是单元素元组。
  • add_kernel[grid]这种调用语法类似 CUDA 的<<<grid, block>>>,只不过 Triton 的"block 大小"已经在BLOCK_SIZE: tl.constexpr里定义好了,这里只指定 grid。

3.3 运行结果

$ python 01-vector-add.py

输出显示 Triton 计算结果与 PyTorch 原生+算子的最大差异为0.0——完全一致。

性能对比那块更有意思:从 4096 个元素一路测到 1.34 亿个元素,Triton 版本和 PyTorch(底层也是 CUDA)的耗时几乎完全重叠,差距在 1% 以内。这说明用 Triton 写的向量加法,编译出来的机器码质量不输 PyTorch 高度优化的 CUDA kernel


四、踩坑记录:我在 Triton 上栽过的跟头

写几个自己实际遇到、PPT 里不会直接说的坑:

坑 1:BLOCK_SIZE不是越大越好

直觉上 block 越大并行度越高,但 block 太大会导致:① 寄存器不够用,触发 spilling,性能反而暴跌;② shared memory 不够用(如果你的 kernel 用了)。向量加法这种极简单 kernel,1024 是个不错的默认值;复杂 kernel 如矩阵乘法,每维 64-128 更常见。

坑 2:mask 没写对,静默出 bug

tl.loadmask参数如果不传,越界的地址会读到未定义值——GPU 上不会直接 crash,但算出来的结果可能完全对,也可能偶尔错,特别难排查。任何带offsetsload/store都要检查边界

坑 3:autotune 第一次跑很慢

autotune 会逐配置编译+运行,候选配置多的话第一次调用可能要等几十秒甚至几分钟。这正常,因为 Triton 在 JIT 编译。第二次调用命中缓存就秒开了。生产环境建议提前 warmup。

坑 4:num_stages不是越大越好

num_stages增加异步预取的流水线深度,能隐藏访存延迟,但每多一级 stage 就多占一块 shared memory。如果你的 kernel 本身 shared memory 用量就高(比如矩阵乘法里的大块 tile),再加 stages 会爆 shared memory 容量,编译直接失败。


五、小结

用 Triton 写 GPU 程序的体验,打个不恰当的比方:CUDA 像手动挡,每个换挡时机都得自己把握;Triton 像自动挡 + 运动模式,把最烦的线程调度交给编译器,但关键参数(BLOCK_SIZE、num_warps、num_stages)你仍然能调。

回到开头那个向量加法——从 CUDA 迁移到 Triton,代码量减半,性能持平,而且换个 GPU 不用改一行代码。对于大部分"我需要一个自定义 kernel,但不想为线程索引掉头发"的场景,Triton 是目前最好的选择。

本文基于杜玉博老师《Triton程序编写》PPT 整理,图片均为原 PPT 截图。代码示例可在 Triton 官方仓库 找到完整教程。