算子设计

本页解释仓库 Triton kernel 的主要实现思路。

fused_rmsnorm_rope

核心思路是:在寄存器里尽量保留归一化结果,紧接着完成 RoPE,再把最终输出写回。

设计目标:

  • 逐行计算 RMS 统计量,
  • 乘上 RMSNorm 权重,
  • 立刻按 head pair 做旋转,
  • 只写最终输出,不写中间张量。

为什么重要:

  • 未融合路径通常会先把归一化中间结果写回全局显存,
  • 融合路径可以避免这一步额外的 HBM 往返。

fused_gated_mlp

这个 kernel 会对同一块输入同时计算两条投影:

  • gate projection,
  • up projection。

随后对 gate projection 施加激活,并与 up projection 相乘:

output = activation(gate_proj(x)) * up_proj(x)

这样就把投影与激活的工作收敛到一次 launch 中,而不是拆成多个操作。

fp8_gemm

GEMM kernel 使用的是仓库自定义的 FP8 兼容表示:

  • 数据以 uint8 存储,
  • scale 来自显式标量张量,
  • 用 FP32 做累加,
  • 输出走半精度路径。

代码里还采用了 grouped output tile 排布,以改善 cache locality。

分块启发式

当前 Python launcher 主要根据问题规模做启发式 block 选择,而不是每次调用时在线自动调优。

例如:

  • 大矩阵选择更大的 tile,
  • 当 reduction 维度较小时选择更小的 BLOCK_K
  • Gated MLP 当前实现使用较固定的 tile 参数。

这样做的好处是运行路径更小、更稳定,而更复杂的配置搜索则交给通用 autotuner 层。

为什么 reference 实现很重要

每个 kernel 模块都同时保留了 PyTorch reference 实现,这很关键,因为它提供了:

  • 正确性对照基线,
  • 更容易阅读的数学实现,
  • benchmark 验证所需的参考输出。

仓库强调的不只是“快”,而是“可验证地快”。