核心算子
本页记录 triton_ops 当前导出的主要计算接口。
fused_rmsnorm_rope
fused_rmsnorm_rope(
x: torch.Tensor,
weight: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
eps: float = 1e-6,
num_heads: int | None = None,
) -> torch.Tensor
用途:
- 在一次 kernel 启动中完成 RMSNorm 与 RoPE。
- 避免把归一化后的中间结果单独写回 HBM。
输入契约:
x必须是 contiguous 的 CUDA 张量,形状为[batch, seq_len, hidden_dim]。weight必须是 contiguous 的 CUDA 张量,形状为[hidden_dim]。cos与sin必须是形状一致的 contiguous CUDA 张量。- 当前支持的 RoPE cache 形状:
[seq_len, head_dim][1, seq_len, 1, head_dim]
head_dim必须是偶数。- 如果不传
num_heads,会按hidden_dim / head_dim自动推断。
输出:
- 形状与
x相同。 - dtype 与
x相同。
常见错误:
DeviceError:输入不在 CUDA 上。ShapeMismatchError:形状不匹配。UnsupportedDtypeError:dtype 不在支持范围内。
FusedRMSNormRoPE
FusedRMSNormRoPE(hidden_dim: int, head_dim: int, eps: float = 1e-6)
该模块内部持有 RMSNorm 的权重参数,但前向仍然要求显式传入 cos 与 sin:
module = FusedRMSNormRoPE(4096, 64).cuda()
out = module(x, cos, sin)
集成提醒:
- 它不是普通 LayerNorm/RMSNorm 的直接替代品,因为前向契约包含 RoPE 输入。
fused_gated_mlp
fused_gated_mlp(
x: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
activation: Literal["silu", "gelu"] = "silu",
) -> torch.Tensor
当前 Triton kernel 与参考实现都遵循:
output = activation(gate_proj(x)) * up_proj(x)
输入契约:
x:contiguous CUDA 张量,形状[batch, seq_len, hidden_dim]gate_weight:contiguous CUDA 张量,形状[intermediate_dim, hidden_dim]up_weight:形状必须与gate_weight相同activation:只能是"silu"或"gelu"
输出:
- 形状
[batch, seq_len, intermediate_dim] - dtype 与
x相同
重要边界:
- 该 kernel 只覆盖 gated expansion 这一步。
- 完整 Transformer MLP 仍然需要在外部补上下投影与 residual 逻辑。
FusedGatedMLP
FusedGatedMLP(
hidden_dim: int,
intermediate_dim: int,
activation: Literal["silu", "gelu"] = "silu",
)
这个模块内部持有 gate_weight 与 up_weight,前向时会调用 fused_gated_mlp。
fp8_gemm
fp8_gemm(
a: torch.Tensor,
b: torch.Tensor,
a_scale: torch.Tensor | None = None,
b_scale: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor
行为说明:
- 如果
a或b还是浮点张量,函数会先内部调用quantize_fp8。 - 如果输入已经是仓库定义的 FP8 存储格式,则必须同时提供对应 scale。
- 当前维护的运行路径使用基于
uint8的 FP8 兼容表示。
输入契约:
a与b必须是 contiguous 的 CUDA 张量。- 矩阵形状必须分别为
[M, K]与[K, N]。 - 预量化输入需要在同一设备上提供标量 scale 张量。
输出:
- 形状
[M, N] - 实际使用时请优先视为
torch.float16或torch.bfloat16输出路径
实践提醒:
- 校验层允许
torch.float32作为output_dtype,但 Triton 实现的维护重点仍是半精度输出路径。实践中应优先使用float16/bfloat16。
FP8Linear
FP8Linear(in_features: int, out_features: int, bias: bool = False)
行为:
- 内部保留一个可训练的浮点
weight参数。 - 第一次前向时,量化并缓存:
weight_fp8weight_scaleweight_fp8_t(转置后且 contiguous)
- 前向调用时使用
fp8_gemm完成计算。
重要集成提醒:
- 缓存后的 FP8 权重不会在权重更新后自动刷新。
- 因此
FP8Linear更适合推理场景,或者权重稳定的阶段。