Examples

These examples stay close to what the repository actually exports today.

Decoder-block skeleton

import torch
from triton_ops import FusedRMSNormRoPE, FusedGatedMLP, FP8Linear

class DecoderBlock(torch.nn.Module):
    def __init__(self, hidden_dim=4096, num_heads=32, intermediate_dim=11008):
        super().__init__()
        head_dim = hidden_dim // num_heads
        self.norm = FusedRMSNormRoPE(hidden_dim, head_dim)
        self.q_proj = FP8Linear(hidden_dim, hidden_dim, bias=False)
        self.k_proj = FP8Linear(hidden_dim, hidden_dim, bias=False)
        self.v_proj = FP8Linear(hidden_dim, hidden_dim, bias=False)
        self.mlp = FusedGatedMLP(hidden_dim, intermediate_dim, activation="silu")

    def forward(self, x, cos, sin):
        normed = self.norm(x, cos, sin)
        q = self.q_proj(normed)
        k = self.k_proj(normed)
        v = self.v_proj(normed)
        mlp_out = self.mlp(normed)
        return q, k, v, mlp_out

This repository does not ship a full attention implementation. Use your existing attention path around the fused normalization, projection, and MLP pieces.

Data models for test input generation

import torch
from triton_ops import RMSNormRoPEInput

spec = RMSNormRoPEInput.from_shapes(
    batch_size=2,
    seq_len=128,
    hidden_dim=4096,
    head_dim=64,
    dtype=torch.float16,
    device="cuda",
)

x = spec.x.create_tensor()
weight = spec.weight.create_tensor(fill_value=1.0)
cos = spec.cos.create_tensor()
sin = spec.sin.create_tensor()

The models.py dataclasses are useful for tests, examples, and benchmark scaffolding.

Benchmarking with BenchmarkSuite

import torch
from triton_ops import BenchmarkSuite
from triton_ops.kernels.rmsnorm_rope import fused_rmsnorm_rope_reference
from triton_ops import fused_rmsnorm_rope

suite = BenchmarkSuite(warmup_runs=5, benchmark_runs=20)

x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.float16)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)
cos = torch.randn(128, 64, device="cuda", dtype=torch.float16)
sin = torch.randn(128, 64, device="cuda", dtype=torch.float16)

result = suite.benchmark_kernel(
    fused_rmsnorm_rope,
    fused_rmsnorm_rope_reference,
    "fused_rmsnorm_rope",
    (2, 128, 4096),
    x,
    weight,
    cos,
    sin,
)

print(result.metrics.latency_ms)
print(suite.generate_report())

Auto-tuning a custom kernel wrapper

import torch
from triton_ops import TritonAutoTuner

def dummy_kernel(x, BLOCK_SIZE=64, num_warps=4):
    return x * 2

tuner = TritonAutoTuner(
    kernel_fn=dummy_kernel,
    config_space={
        "BLOCK_SIZE": [64, 128],
        "num_warps": [4, 8],
    },
    warmup_runs=2,
    benchmark_runs=10,
)

x = torch.randn(1024, device="cuda")
result = tuner.tune(x, problem_size=(1024,), device="cuda:0", kernel_type="dummy")
print(result.best_config)

TritonAutoTuner expects a callable that accepts the searched configuration values as keyword arguments.

FP8 overflow helper

The overflow-handling helper lives in the kernel module rather than the root package export list:

import torch
from triton_ops.kernels.fp8_quantize import quantize_fp8_with_overflow_handling

x = torch.full((1024,), 1000.0, device="cuda", dtype=torch.float16)
q, scale = quantize_fp8_with_overflow_handling(x, max_attempts=3)
print(q.dtype, scale.item())

Next