API Reference
Complete API documentation for LLM-Speed.
Installation
# Install from source
pip install -e .
# Verify installation
python -c "import cuda_llm_ops; print(cuda_llm_ops.__version__)"
Module Overview
import cuda_llm_ops
# List all available functions
dir(cuda_llm_ops)
# ['flash_attention', 'tiled_attention', 'naive_attention',
# 'gemm', 'tensor_core_gemm', 'tensor_core_gemm_int8', '__version__']
Attention Functions
flash_attention
FlashAttention with O(N) memory complexity, using the online softmax algorithm.
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float = 0.0,
    is_causal: bool = False
) -> torch.Tensor
Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| q | torch.Tensor | Required | Query tensor, shape [batch, heads, seq_len, head_dim] |
| k | torch.Tensor | Required | Key tensor, shape [batch, heads, seq_len, head_dim] |
| v | torch.Tensor | Required | Value tensor, shape [batch, heads, seq_len, head_dim] |
| scale | float | 0.0 | Attention scale factor. If 0.0, uses 1/√head_dim |
| is_causal | bool | False | Enable causal mask for autoregressive models |
Returns
| Type | Description |
|------|-------------|
| torch.Tensor | Output tensor, shape [batch, heads, seq_len, head_dim] |
Raises
| Exception | Condition |
|-----------|-----------|
| RuntimeError | Input tensors not 4D |
| RuntimeError | Input tensors not on CUDA device |
| RuntimeError | Input tensors not contiguous |
| RuntimeError | Shape mismatch between Q, K, V |
| RuntimeError | Unsupported dtype (not float32/float16) |
Examples
import torch
from cuda_llm_ops import flash_attention
# Standard attention
batch, heads, seq_len, head_dim = 2, 8, 512, 64
q = torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
output = flash_attention(q, k, v)
print(output.shape) # torch.Size([2, 8, 512, 64])
# Causal attention (for autoregressive models like GPT)
output_causal = flash_attention(q, k, v, is_causal=True)
# Custom scale factor
output_scaled = flash_attention(q, k, v, scale=0.125) # 1/8 instead of 1/√64
Memory Usage Comparison
| Sequence Length | Standard Attention | FlashAttention | Reduction |
|-----------------|--------------------|----------------|-----------|
| 1024 | 4 MB | 0.25 MB | 94% |
| 2048 | 16 MB | 0.5 MB | 97% |
| 4096 | 64 MB | 1 MB | 98% |
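The savings come from the online softmax trick: keys and values are processed in tiles while a running row-wise maximum and normalizer are carried along, so the full N×N score matrix is never materialized. The sketch below illustrates the idea for a single [seq_len, head_dim] head in plain PyTorch; it is an illustration of the algorithm, not the library's CUDA kernel, and online_softmax_attention is a hypothetical name.
import math
import torch

def online_softmax_attention(q, k, v, tile=128):
    # Illustrative O(N)-memory attention for one [seq_len, head_dim] head
    scale = 1.0 / math.sqrt(q.shape[-1])
    out = torch.zeros_like(q)                        # running (unnormalized) weighted sum of V
    m = torch.full((q.shape[0], 1), float('-inf'))   # running row-wise max of the scores
    l = torch.zeros(q.shape[0], 1)                   # running softmax normalizer
    for start in range(0, k.shape[0], tile):
        k_t, v_t = k[start:start + tile], v[start:start + tile]
        s = (q @ k_t.T) * scale                      # scores for this tile only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                     # tile-local exponentials
        correction = torch.exp(m - m_new)            # rescales previously accumulated results
        l = l * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_t
        m = m_new
    return out / l

# Matches a materialized softmax(Q K^T / sqrt(d)) V up to floating-point error
q, k, v = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / math.sqrt(64), dim=-1) @ v
assert torch.allclose(online_softmax_attention(q, k, v), ref, atol=1e-4)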
tiled_attention
Tiled attention with shared memory optimization. Suitable for medium-length sequences.
def tiled_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float = 0.0,
    is_causal: bool = False
) -> torch.Tensor
Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| q | torch.Tensor | Required | Query tensor, shape [batch, heads, seq_len, head_dim] |
| k | torch.Tensor | Required | Key tensor, shape [batch, heads, seq_len, head_dim] |
| v | torch.Tensor | Required | Value tensor, shape [batch, heads, seq_len, head_dim] |
| scale | float | 0.0 | Attention scale factor. If 0.0, uses 1/√head_dim |
| is_causal | bool | False | Enable causal mask for autoregressive models |
Returns
Same as flash_attention.
Examples
from cuda_llm_ops import tiled_attention
output = tiled_attention(q, k, v)
# With custom scale
output_scaled = tiled_attention(q, k, v, scale=0.1)
Notes
- More efficient than naive_attention for sequences ≥128
- Still stores the attention matrix internally (O(N²) memory)
- Not recommended for sequences >2048 (see the memory estimate sketch below)
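To put the >2048 recommendation in perspective, the helper below (a hypothetical utility, not part of cuda_llm_ops) estimates the size of the intermediate score matrix that tiled_attention and naive_attention allocate:
import torch

def attention_matrix_bytes(batch, heads, seq_len, dtype=torch.float16):
    # Size of the [batch, heads, seq_len, seq_len] score matrix
    return batch * heads * seq_len * seq_len * torch.finfo(dtype).bits // 8

for n in (512, 1024, 2048, 4096):
    mb = attention_matrix_bytes(1, 8, n) / 2**20
    print(f"seq_len={n}: ~{mb:.0f} MB of scores per layer")
# seq_len=4096 with 8 heads already needs ~256 MB per layer just for the scores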
naive_attention
Baseline attention implementation with O(N²) memory complexity. Used primarily for correctness verification.
def naive_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float = 0.0,
    is_causal: bool = False
) -> torch.Tensor
Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| q | torch.Tensor | Required | Query tensor, shape [batch, heads, seq_len, head_dim] |
| k | torch.Tensor | Required | Key tensor, shape [batch, heads, seq_len, head_dim] |
| v | torch.Tensor | Required | Value tensor, shape [batch, heads, seq_len, head_dim] |
| scale | float | 0.0 | Attention scale factor. If 0.0, uses 1/√head_dim |
| is_causal | bool | False | Enable causal mask for autoregressive models |
Returns
Same as flash_attention.
Warning
Memory Alert: This implementation stores the full N×N attention matrix. For long sequences (N > 1024), this may cause out-of-memory errors. Use flash_attention for production workloads.
Examples
from cuda_llm_ops import naive_attention
# Only recommended for short sequences or testing
q = torch.randn(2, 4, 64, 32, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
output = naive_attention(q, k, v)
# Verify correctness against PyTorch reference
reference = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(output, reference, rtol=1e-3, atol=1e-3)
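For reference, naive_attention corresponds to the textbook formulation softmax(Q Kᵀ · scale) V with an optional causal mask. A plain PyTorch sketch of that computation (shown for clarity, not the library's kernel):
import math
import torch

def naive_attention_reference(q, k, v, scale=0.0, is_causal=False):
    # scale=0.0 falls back to 1/sqrt(head_dim), mirroring the documented default
    if scale == 0.0:
        scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale   # [batch, heads, N, N]
    if is_causal:
        n = scores.shape[-1]
        mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=scores.device), diagonal=1)
        scores = scores.masked_fill(mask, float('-inf'))
    return torch.matmul(torch.softmax(scores, dim=-1), v)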
GEMM Functions
gemm
High-performance general matrix multiplication with register tiling.
def gemm(
    a: torch.Tensor,
    b: torch.Tensor,
    alpha: float = 1.0,
    beta: float = 0.0,
    trans_a: bool = False,
    trans_b: bool = False
) -> torch.Tensor
Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| a | torch.Tensor | Required | Matrix A, shape [M, K] (or [K, M] if trans_a=True) |
| b | torch.Tensor | Required | Matrix B, shape [K, N] (or [N, K] if trans_b=True) |
| alpha | float | 1.0 | Scaling factor for A @ B |
| beta | float | 0.0 | Scaling factor for existing C (currently unused) |
| trans_a | bool | False | Transpose matrix A |
| trans_b | bool | False | Transpose matrix B |
Returns
| Type | Description |
|------|-------------|
| torch.Tensor | Output matrix C, shape [M, N] |
Raises
| Exception | Condition |
|-----------|-----------|
| RuntimeError | Input tensors not 2D |
| RuntimeError | Inner dimensions don't match |
| RuntimeError | Tensors not on CUDA or not contiguous |
| RuntimeError | Unsupported dtype |
Examples
import torch
from cuda_llm_ops import gemm
# Standard C = A @ B
M, K, N = 1024, 512, 1024
a = torch.randn(M, K, device='cuda', dtype=torch.float16)
b = torch.randn(K, N, device='cuda', dtype=torch.float16)
c = gemm(a, b)
# With scaling: C = 2.0 * A @ B
c = gemm(a, b, alpha=2.0)
# Handle a transposed operand: a_t has shape [K, M]
a_t = torch.randn(K, M, device='cuda', dtype=torch.float16)
b = torch.randn(K, N, device='cuda', dtype=torch.float16)
# trans_a=True transposes the first operand internally: C = a_t^T @ b, shape [M, N]
c = gemm(a_t, b, trans_a=True)
Layout Equivalence Table
| trans_a | trans_b | Operation |
|---------|---------|-----------|
| False | False | C = A @ B |
| False | True | C = A @ B^T |
| True | False | C = A^T @ B |
| True | True | C = A^T @ B^T |
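Assuming the semantics in the table above, the transpose flags can be sanity-checked against torch.matmul, for example:
import torch
from cuda_llm_ops import gemm

M, K, N = 256, 128, 64
a = torch.randn(K, M, device='cuda', dtype=torch.float16)   # first operand stored transposed
b = torch.randn(N, K, device='cuda', dtype=torch.float16)   # second operand stored transposed

c = gemm(a, b, trans_a=True, trans_b=True)                  # C = a^T @ b^T, shape [M, N]
ref = a.T.contiguous() @ b.T.contiguous()
assert torch.allclose(c, ref.to(c.dtype), rtol=1e-2, atol=1e-2)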
tensor_core_gemm
Tensor Core-accelerated matrix multiplication using the WMMA API.
def tensor_core_gemm(
    a: torch.Tensor,
    b: torch.Tensor,
    alpha: float = 1.0,
    beta: float = 0.0
) -> torch.Tensor
Parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| a | torch.Tensor | Matrix A, shape [M, K], must be float16 |
| b | torch.Tensor | Matrix B, shape [K, N], must be float16 |
| alpha | float | Scaling factor |
| beta | float | Scaling factor for existing C |
Returns
| Type | Description |
|------|-------------|
| torch.Tensor | Output matrix C, shape [M, N], dtype float32 |
Hardware Requirements
- GPU: Volta architecture or newer (SM 7.0+)
- PyTorch: 2.0+
Examples
import torch
from cuda_llm_ops import tensor_core_gemm
a = torch.randn(1024, 512, device='cuda', dtype=torch.float16)
b = torch.randn(512, 1024, device='cuda', dtype=torch.float16)
# Returns FP32 output for higher precision accumulation
c = tensor_core_gemm(a, b)
print(c.dtype) # torch.float32
# Scaling
c = tensor_core_gemm(a, b, alpha=0.5)
Notes
- Align M, K, N to multiples of 16 for optimal performance
- Use float16 input for maximum throughput
- Output precision is float32 for numerical stability (see the correctness check below)
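Because the inputs are FP16 but accumulation happens in FP32, the result should track a full-precision PyTorch reference closely; a quick check under that assumption:
import torch
from cuda_llm_ops import tensor_core_gemm

a = torch.randn(256, 256, device='cuda', dtype=torch.float16)
b = torch.randn(256, 256, device='cuda', dtype=torch.float16)

c = tensor_core_gemm(a, b)                   # float32 output
ref = torch.matmul(a.float(), b.float())     # FP32 reference on the same FP16-rounded values
assert torch.allclose(c, ref, rtol=1e-2, atol=1e-2)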
tensor_core_gemm_int8
INT8 quantized matrix multiplication using Tensor Cores.
def tensor_core_gemm_int8(
    a: torch.Tensor,
    b: torch.Tensor
) -> torch.Tensor
Parameters
| Parameter | Type | Description |
|-----------|------|-------------|
| a | torch.Tensor | Matrix A, shape [M, K], must be int8 |
| b | torch.Tensor | Matrix B, shape [K, N], must be int8 |
Returns
| Type | Description |
|------|-------------|
| torch.Tensor | Output matrix C, shape [M, N], dtype int32 |
Hardware Requirements
- GPU: Turing architecture or newer (SM 7.2+)
- PyTorch: 2.0+
Examples
import torch
from cuda_llm_ops import tensor_core_gemm_int8
# Create INT8 tensors
a = torch.randint(-128, 127, (1024, 512), device='cuda', dtype=torch.int8)
b = torch.randint(-128, 127, (512, 1024), device='cuda', dtype=torch.int8)
# INT32 accumulation for precision
c = tensor_core_gemm_int8(a, b)
print(c.dtype) # torch.int32
# Verify against an int32 reference (integer matmul is computed on CPU here)
reference = torch.matmul(a.to(torch.int32).cpu(), b.to(torch.int32).cpu())
assert torch.equal(c.cpu(), reference)
Checking GPU Compatibility
import torch
capability = torch.cuda.get_device_capability()
if capability[0] > 7 or (capability[0] == 7 and capability[1] >= 2):
    print("INT8 Tensor Core is supported")
else:
    print("Requires Turing+ GPU (SM 7.2+)")
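In practice the int8 operands usually come from quantizing floating-point tensors. A minimal symmetric per-tensor quantization sketch (quantize_symmetric is a hypothetical helper, not part of the package):
import torch
from cuda_llm_ops import tensor_core_gemm_int8

def quantize_symmetric(x: torch.Tensor):
    # Map a float tensor to int8 with a single per-tensor scale
    scale = x.abs().max().clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale

a_fp = torch.randn(1024, 512, device='cuda')
b_fp = torch.randn(512, 1024, device='cuda')
a_q, a_scale = quantize_symmetric(a_fp)
b_q, b_scale = quantize_symmetric(b_fp)

# int32 accumulation, then rescale back to float
c = tensor_core_gemm_int8(a_q, b_q).float() * (a_scale * b_scale)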
Tensor Requirements
Attention Functions
| Requirement | Specification |
|-------------|---------------|
| Dimensions | 4D: [batch, heads, seq_len, head_dim] |
| Device | Must be CUDA (tensor.is_cuda == True) |
| Layout | Contiguous (tensor.is_contiguous() == True) |
| Dtype | float32 or float16 |
| Shape Consistency | Q, K, V shapes must match exactly |
GEMM Functions
| Requirement | Specification |
|-------------|---------------|
| Dimensions | 2D: [M, K] and [K, N] |
| Device | Must be CUDA |
| Layout | Contiguous |
| Dtype | float32, float16, or int8 |
| Dimension Alignment | Inner dimensions must match |
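A small pre-flight check along the lines of these tables can catch most issues before the kernel raises (check_gemm_inputs is a hypothetical helper, not part of the package):
import torch

def check_gemm_inputs(a: torch.Tensor, b: torch.Tensor) -> None:
    # Validate the documented GEMM requirements before calling the kernel
    if a.dim() != 2 or b.dim() != 2:
        raise ValueError("a and b must be 2D matrices")
    if not (a.is_cuda and b.is_cuda):
        raise ValueError("a and b must be on a CUDA device")
    if not (a.is_contiguous() and b.is_contiguous()):
        raise ValueError("a and b must be contiguous (call .contiguous())")
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"inner dimensions must match: {a.shape} vs {b.shape}")
    if a.dtype not in (torch.float32, torch.float16, torch.int8):
        raise ValueError(f"unsupported dtype: {a.dtype}")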
Error Handling
Common Error Messages
# Example: Wrong dimensions
q = torch.randn(64, 32, device='cuda') # 2D instead of 4D
flash_attention(q, k, v)
# RuntimeError: Q must be 4D tensor [batch, heads, seq_len, head_dim]
# Example: CPU tensor
q = torch.randn(2, 4, 64, 32) # CPU tensor
flash_attention(q, k, v)
# RuntimeError: Q must be on CUDA device
# Example: Non-contiguous tensor
q = torch.randn(2, 4, 64, 32, device='cuda').transpose(1, 2)
flash_attention(q, k, v)
# RuntimeError: Q must be contiguous
# Example: Shape mismatch
q = torch.randn(2, 4, 64, 32, device='cuda')
v = torch.randn(2, 4, 128, 32, device='cuda') # Different seq_len
flash_attention(q, k, v)
# RuntimeError: K and V must have same shape
# Example: Unsupported dtype
q = torch.randn(2, 4, 64, 32, device='cuda', dtype=torch.int32)
flash_attention(q, k, v)
# RuntimeError: Only float32 and float16 are supported
Error Handling Pattern
import torch
from cuda_llm_ops import flash_attention
def safe_flash_attention(q, k, v, **kwargs):
    try:
        return flash_attention(q, k, v, **kwargs)
    except RuntimeError as e:
        error_msg = str(e)
        if "must be on CUDA device" in error_msg:
            print("Error: Please move tensors to CUDA using .cuda()")
        elif "must be 4D tensor" in error_msg:
            print("Error: Input shapes must be [batch, heads, seq_len, head_dim]")
        elif "must have same shape" in error_msg:
            print("Error: Q, K, V tensors must have identical shapes")
        else:
            print(f"Error: {error_msg}")
        raise
Memory Optimization
# Use FlashAttention for long sequences
seq_len = 1024
if seq_len >= 512:
    output = flash_attention(q, k, v)   # O(N) memory
else:
    output = naive_attention(q, k, v)   # May be faster for short sequences
Precision Selection
# FP16 for inference (recommended)
q_fp16, k_fp16, v_fp16 = q.half(), k.half(), v.half()
output = flash_attention(q_fp16, k_fp16, v_fp16)
# FP32 for training or when precision is critical
output = flash_attention(q.float(), k.float(), v.float())
# Tensor Core GEMM for FP16 inputs with FP32 accumulation
c = tensor_core_gemm(a.half(), b.half()) # Returns FP32
Optimal Dimensions
# For best Tensor Core performance, use multiples of 16
def round_up_to_16(x):
    return ((x + 15) // 16) * 16
M = round_up_to_16(1000) # 1008
N = round_up_to_16(500) # 512
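When real matrices are not already aligned, one option is to zero-pad them up to the next multiple of 16 and slice the result back; zero rows/columns along K do not change the valid region of C = A @ B (sketch below, pad_to_16 is a hypothetical helper):
import torch
import torch.nn.functional as F
from cuda_llm_ops import tensor_core_gemm

def pad_to_16(x: torch.Tensor) -> torch.Tensor:
    # Zero-pad a 2D matrix so both dimensions are multiples of 16
    rows, cols = x.shape
    return F.pad(x, (0, (-cols) % 16, 0, (-rows) % 16))   # (left, right, top, bottom)

a = torch.randn(1000, 500, device='cuda', dtype=torch.float16)
b = torch.randn(500, 300, device='cuda', dtype=torch.float16)
a_pad, b_pad = pad_to_16(a), pad_to_16(b)                  # [1008, 512] and [512, 304]

c = tensor_core_gemm(a_pad, b_pad)[:1000, :300]            # slice back to the original shape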
Batch Processing
# Process multiple sequences together for better GPU utilization
batch_size = 8 # Adjust based on available memory
q = torch.randn(batch_size, heads, seq_len, head_dim, device='cuda', dtype=torch.float16)
Version Check
import cuda_llm_ops
print(cuda_llm_ops.__version__)
# e.g., "0.3.0"
Support
- Issues: https://github.com/LessUp/llm-speed/issues
- Documentation: https://lessup.github.io/llm-speed/