示例教程
这里的示例尽量贴合仓库当前真实导出的接口与行为。
Decoder block 骨架
import torch
from triton_ops import FusedRMSNormRoPE, FusedGatedMLP, FP8Linear
class DecoderBlock(torch.nn.Module):
    """Skeleton decoder block that wires the fused primitives together.

    The block runs its input through ``FusedRMSNormRoPE``, then feeds the
    result to three bias-free ``FP8Linear`` projections and to a
    ``FusedGatedMLP``. It does not compute attention itself; it only returns
    the projection outputs and the MLP output.

    Args:
        hidden_dim: Model width; input and output size of every projection.
        num_heads: Number of heads; must be positive and divide
            ``hidden_dim`` evenly.
        intermediate_dim: Inner width handed to the gated MLP.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``hidden_dim``.
    """

    def __init__(self, hidden_dim=4096, num_heads=32, intermediate_dim=11008):
        super().__init__()
        # Floor division would silently truncate for a non-divisible head
        # count, so reject such configurations up front.
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        head_dim = hidden_dim // num_heads

        self.norm = FusedRMSNormRoPE(hidden_dim, head_dim)
        self.q_proj = FP8Linear(hidden_dim, hidden_dim, bias=False)
        self.k_proj = FP8Linear(hidden_dim, hidden_dim, bias=False)
        self.v_proj = FP8Linear(hidden_dim, hidden_dim, bias=False)
        self.mlp = FusedGatedMLP(hidden_dim, intermediate_dim, activation="silu")

    def forward(self, x, cos, sin):
        """Return ``(q, k, v, mlp_out)`` for the input ``x``.

        ``x``, ``cos`` and ``sin`` are forwarded unchanged to ``self.norm``;
        every projection and the MLP consume that same normalized tensor.
        """
        normed = self.norm(x, cos, sin)
        q = self.q_proj(normed)
        k = self.k_proj(normed)
        v = self.v_proj(normed)
        mlp_out = self.mlp(normed)
        return q, k, v, mlp_out
这个骨架只返回 q、k、v 和 MLP 的输出,注意力计算本身并不包含在内。这说明本仓库更适合被当成“优化原语集合”使用,而不是完整的 Attention / Transformer 实现。
用数据模型生成测试输入
import torch
from triton_ops import RMSNormRoPEInput
# Shape parameters of the example problem, kept together so they are easy to tweak.
shape_args = {
    "batch_size": 2,
    "seq_len": 128,
    "hidden_dim": 4096,
    "head_dim": 64,
}
spec = RMSNormRoPEInput.from_shapes(
    **shape_args,
    dtype=torch.float16,
    device="cuda",
)

# Each field of the spec creates its own tensor; the weight is filled with ones.
x = spec.x.create_tensor()
weight = spec.weight.create_tensor(fill_value=1.0)
cos = spec.cos.create_tensor()
sin = spec.sin.create_tensor()
models.py 中的 dataclass 很适合用于测试、示例和 benchmark 脚手架。
使用 BenchmarkSuite 做基准测试
import torch
from triton_ops import BenchmarkSuite
from triton_ops.kernels.rmsnorm_rope import fused_rmsnorm_rope_reference
from triton_ops import fused_rmsnorm_rope
# All benchmark tensors share the same device and dtype.
tensor_opts = {"device": "cuda", "dtype": torch.float16}

suite = BenchmarkSuite(warmup_runs=5, benchmark_runs=20)

x = torch.randn(2, 128, 4096, **tensor_opts)
weight = torch.ones(4096, **tensor_opts)
cos = torch.randn(128, 64, **tensor_opts)
sin = torch.randn(128, 64, **tensor_opts)

# Positional order: kernel under test, reference implementation, a name,
# the problem size, then the kernel inputs.
kernel_inputs = (x, weight, cos, sin)
result = suite.benchmark_kernel(
    fused_rmsnorm_rope,
    fused_rmsnorm_rope_reference,
    "fused_rmsnorm_rope",
    (2, 128, 4096),
    *kernel_inputs,
)

print(result.metrics.latency_ms)
print(suite.generate_report())
对自定义 kernel wrapper 做自动调优
import torch
from triton_ops import TritonAutoTuner
def dummy_kernel(x, BLOCK_SIZE=64, num_warps=4):
    """Return ``x * 2``; a stand-in kernel wrapper for the tuning example.

    ``BLOCK_SIZE`` and ``num_warps`` are accepted so that a tuner can pass
    them as keyword arguments, but this placeholder does not use them.
    """
    return x * 2
# Candidate values for each keyword argument of the kernel wrapper.
search_space = {
    "BLOCK_SIZE": [64, 128],
    "num_warps": [4, 8],
}

tuner = TritonAutoTuner(
    kernel_fn=dummy_kernel,
    config_space=search_space,
    warmup_runs=2,
    benchmark_runs=10,
)

x = torch.randn(1024, device="cuda")
result = tuner.tune(
    x,
    problem_size=(1024,),
    device="cuda:0",
    kernel_type="dummy",
)
print(result.best_config)
TritonAutoTuner 期望被调优的函数能够接受搜索出来的配置参数作为关键字参数。例如上面的 dummy_kernel 就通过 BLOCK_SIZE 和 num_warps 这两个关键字参数接收配置。
FP8 溢出辅助函数
带溢出处理的 helper 位于 kernel 模块中,而不是根包导出列表里:
import torch
from triton_ops.kernels.fp8_quantize import quantize_fp8_with_overflow_handling
# A constant float16 input with every element set to 1000.0.
x = torch.full(
    (1024,),
    1000.0,
    dtype=torch.float16,
    device="cuda",
)

# The helper returns the quantized tensor together with its scale.
q, scale = quantize_fp8_with_overflow_handling(x, max_attempts=3)
print(q.dtype, scale.item())