Core Kernels
This page documents the compute-heavy entry points exported from triton_ops.
fused_rmsnorm_rope
fused_rmsnorm_rope(
    x: torch.Tensor,
    weight: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    eps: float = 1e-6,
    num_heads: int | None = None,
) -> torch.Tensor
Purpose:
- Apply RMSNorm and RoPE in one kernel launch.
- Avoid materializing the normalized intermediate back to HBM.
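For intuition, an unfused reference of the same computation is sketched below. The RMSNorm step follows the usual definition; the RoPE step assumes the rotate-half pairing convention, which may differ from the kernel's actual pairing, so treat this as an illustration rather than the exact kernel semantics.
import torch

def unfused_rmsnorm_rope(x, weight, cos, sin, eps=1e-6, num_heads=None):
    # RMSNorm over the last dimension
    normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
    # Split hidden_dim into heads and broadcast the RoPE cache
    batch, seq_len, hidden_dim = normed.shape
    head_dim = cos.shape[-1]
    heads = num_heads or hidden_dim // head_dim
    h = normed.view(batch, seq_len, heads, head_dim)
    c = cos.view(1, seq_len, 1, head_dim)
    s = sin.view(1, seq_len, 1, head_dim)
    # Rotate-half convention (assumption): pair the two halves of each head
    h1, h2 = h.chunk(2, dim=-1)
    rotated = torch.cat((-h2, h1), dim=-1)
    return (h * c + rotated * s).view(batch, seq_len, hidden_dim)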
Input contract:
- x must be a contiguous CUDA tensor with shape [batch, seq_len, hidden_dim].
- weight must be a contiguous CUDA tensor with shape [hidden_dim].
- cos and sin must be contiguous CUDA tensors with matching shapes.
- Supported RoPE cache shapes are [seq_len, head_dim] and [1, seq_len, 1, head_dim].
- head_dim must be even.
- If num_heads is omitted, the function infers it from hidden_dim / head_dim.
Output:
- Same shape as x.
- Same dtype as x.
Common failures:
- DeviceError when tensors are not on CUDA.
- ShapeMismatchError when shapes are inconsistent.
- UnsupportedDtypeError when the tensors are not floating dtypes accepted by validation.
Example:
import torch
from triton_ops import fused_rmsnorm_rope
x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.float16)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)
cos = torch.randn(128, 64, device="cuda", dtype=torch.float16)
sin = torch.randn(128, 64, device="cuda", dtype=torch.float16)
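# cos/sin use the [seq_len, head_dim] cache layout; num_heads is inferred as 4096 / 64 = 64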
y = fused_rmsnorm_rope(x, weight, cos, sin)
FusedRMSNormRoPE
FusedRMSNormRoPE(hidden_dim: int, head_dim: int, eps: float = 1e-6)
This wrapper owns the RMSNorm weight parameter and still expects external cos and sin tensors at call time:
module = FusedRMSNormRoPE(4096, 64).cuda()
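# x: [batch, seq_len, 4096] CUDA tensor; cos/sin: RoPE caches for head_dim = 64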
out = module(x, cos, sin)
Integration note:
- This is not a drop-in replacement for a standalone LayerNorm or RMSNorm module because RoPE inputs are part of the forward contract.
fused_gated_mlp
fused_gated_mlp(
    x: torch.Tensor,
    gate_weight: torch.Tensor,
    up_weight: torch.Tensor,
    activation: Literal["silu", "gelu"] = "silu",
) -> torch.Tensor
Current formula in both the Triton kernel and the reference implementation:
output = activation(gate_proj(x)) * up_proj(x)
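In plain PyTorch, with the [intermediate_dim, hidden_dim] weight layout described under the input contract below, the same formula can be written as the following sketch (the function name is illustrative):
import torch.nn.functional as F

def unfused_gated_mlp(x, gate_weight, up_weight, activation="silu"):
    act = F.silu if activation == "silu" else F.gelu
    # output = activation(gate_proj(x)) * up_proj(x)
    return act(F.linear(x, gate_weight)) * F.linear(x, up_weight)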
Input contract:
- x: contiguous CUDA tensor with shape [batch, seq_len, hidden_dim]
- gate_weight: contiguous CUDA tensor with shape [intermediate_dim, hidden_dim]
- up_weight: same shape as gate_weight
- activation: "silu" or "gelu"
Output:
- Shape [batch, seq_len, intermediate_dim]
- Same dtype as x
Important boundary:
- This kernel implements the gated expansion stage only.
- A full transformer MLP block still needs the down projection and residual path outside this function.
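As a hedged sketch, a block built around this kernel could look like the following, with the down projection and residual handled in plain PyTorch (down_weight and all shapes here are illustrative):
import torch
import torch.nn.functional as F
from triton_ops import fused_gated_mlp

x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.float16)
gate_weight = torch.randn(11008, 4096, device="cuda", dtype=torch.float16)
up_weight = torch.randn(11008, 4096, device="cuda", dtype=torch.float16)
down_weight = torch.randn(4096, 11008, device="cuda", dtype=torch.float16)

hidden = fused_gated_mlp(x, gate_weight, up_weight, activation="silu")  # [2, 128, 11008]
y = x + F.linear(hidden, down_weight)  # down projection and residual stay outside the kernel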
FusedGatedMLP
FusedGatedMLP(
    hidden_dim: int,
    intermediate_dim: int,
    activation: Literal["silu", "gelu"] = "silu",
)
The module owns gate_weight and up_weight and forwards to fused_gated_mlp.
module = FusedGatedMLP(4096, 11008, activation="silu").cuda().half()
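# x: [batch, seq_len, 4096] float16 CUDA tensor; y: [batch, seq_len, 11008]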
y = module(x)
fp8_gemm
fp8_gemm(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor | None = None,
    b_scale: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor
Behavior:
- If a or b is floating-point, the function quantizes it internally with quantize_fp8.
- If a or b is already in the repository’s FP8 storage format, the matching scale tensor is required.
- The current maintained runtime path uses the repository’s uint8-based FP8 compatibility format.
Input contract:
- a and b must be contiguous CUDA tensors.
- Matrix shapes must be [M, K] and [K, N].
- Pre-quantized inputs require scalar scale tensors on the same device.
Output:
- Shape [M, N]
- In normal usage, torch.float16 or torch.bfloat16
Practical note:
- The validation helper accepts torch.float32 as an output dtype, but the Triton implementation is written around half-precision output paths. Treat float16 and bfloat16 as the maintained choices.
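A minimal usage sketch of the internally quantized path (shapes are illustrative):
import torch
from triton_ops import fp8_gemm

a = torch.randn(128, 4096, device="cuda", dtype=torch.float16)   # [M, K]
b = torch.randn(4096, 2048, device="cuda", dtype=torch.float16)  # [K, N]
# Both inputs are floating-point, so they are quantized internally before the GEMM
c = fp8_gemm(a, b, output_dtype=torch.float16)                   # [M, N]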
FP8Linear
FP8Linear(in_features: int, out_features: int, bias: bool = False)
Behavior:
- Stores a trainable floating-point weight parameter.
- On the first forward pass, quantizes the weight to FP8 and caches:
  - weight_fp8
  - weight_scale
  - weight_fp8_t (transposed, contiguous)
- Uses fp8_gemm for the forward path.
Important integration caveat:
- The cached quantized weight is not automatically refreshed after weight updates.
- That makes FP8Linear a better fit for inference or for phases where weights are stable.
Example:
import torch
from triton_ops import FP8Linear
layer = FP8Linear(4096, 4096).cuda()
x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.float16)
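# The first forward quantizes weight to FP8 and caches weight_fp8, weight_scale, and weight_fp8_t; later calls reuse the cache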
y = layer(x)