CuFlash-Attn



A FlashAttention implementation in CUDA C++ from scratch. This is primarily a reference/educational implementation demonstrating the FlashAttention algorithm; for production workloads requiring maximum performance, consider using established libraries like FlashAttention-2.

Features

  • Forward Pass: Efficient attention computation with O(N) memory complexity (FP32 and FP16)
  • Backward Pass: Gradient computation using recomputation strategy (FP32 only; FP16 backward not yet implemented)
  • Causal Masking: Support for autoregressive models
  • Online Softmax: Numerically stable softmax without storing O(N²) attention matrix

Known Limitations

  • FP16 backward pass is not implemented; calling it with half-precision pointers returns UNSUPPORTED_DTYPE
  • head_dim support: Only 32, 64, and 128 are supported
  • High shared memory usage: May require GPUs with extended shared memory support for head_dim=128
  • DIMENSION_MISMATCH error: Currently not actively checked (API does not receive per-tensor shape metadata)

Requirements

  • CUDA Toolkit 11.0+
  • CMake 3.18+
  • C++17 compatible compiler
  • (Optional) PyTorch for comparison tests

Building

cmake --preset default      # Debug build with tests
cmake --build --preset default
ctest --preset default

cmake --preset release      # Optimized build
cmake --build --preset release

Manual build

mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
cmake --build . -j$(nproc)

Build Options

  • BUILD_TESTS=ON/OFF: Build test suite (default: ON)
  • ENABLE_RAPIDCHECK=ON/OFF: Enable RapidCheck property tests (default: OFF)
  • BUILD_SHARED_LIBS=ON/OFF: Build the shared library (default: ON)
  • BUILD_EXAMPLES=ON/OFF: Build the example binary (default: ON)
  • ENABLE_FAST_MATH=ON/OFF: Enable --use_fast_math for CUDA (faster but less precise, default: OFF)

If CMake cannot find CUDA, configure it explicitly:

cmake .. -DCUDAToolkit_ROOT=/usr/local/cuda -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc

Usage

C++ API

#include "flash_attention.h"

// Forward pass
cuflash::FlashAttentionError err = cuflash::flash_attention_forward(
    Q, K, V,           // Input tensors [batch, heads, seq_len, head_dim]
    O, L,              // Output tensor and logsumexp
    batch_size, num_heads, seq_len, head_dim,
    scale,             // Usually 1/sqrt(head_dim)
    causal,            // Enable causal masking
    stream             // CUDA stream (optional)
);

// Backward pass
err = cuflash::flash_attention_backward(
    Q, K, V, O, L, dO, // Inputs and upstream gradient
    dQ, dK, dV,        // Output gradients
    batch_size, num_heads, seq_len, head_dim,
    scale, causal, stream
);

Supported Configurations

  • head_dim: 32, 64, 128
  • Data types: float32; float16 is currently forward-only
  • Causal masking: Optional

Running Tests

ctest --preset default --output-on-failure

GoogleTest is automatically fetched via CMake FetchContent — no manual installation required.

PyTorch Comparison Tests

python tests/test_pytorch_comparison.py

Build the shared library first. Preset builds place artifacts under build/<preset>/, for example build/default/ or build/release/. You can also override the library path with CUFLASH_LIB=/absolute/path/to/libcuflash_attn.so.

Algorithm

This implementation follows the FlashAttention algorithm:

  1. Tiling: Divide Q, K, V into blocks that fit in SRAM
  2. Online Softmax: Compute softmax incrementally without storing full attention matrix
  3. Recomputation: In backward pass, recompute attention weights instead of storing

Memory Complexity

  • Standard Attention: O(N²) for attention matrix
  • FlashAttention: O(N) - only stores output and logsumexp

Project Structure

├── include/
│   └── flash_attention.h      # Public API header
├── src/
│   ├── flash_attention_api.cu     # API implementation
│   ├── flash_attention_forward.cu # Forward kernel
│   ├── flash_attention_backward.cu # Backward kernel
│   ├── flash_attention_fp16.cu    # FP16 support
│   ├── online_softmax.cuh         # Online softmax utilities
│   └── matmul.cuh                 # Matrix multiplication helpers
├── tests/
│   ├── test_forward.cu            # Forward pass tests
│   ├── test_backward.cu           # Backward pass tests
│   ├── test_causal_mask.cu        # Causal masking tests
│   ├── test_online_softmax.cu     # Online softmax tests
│   ├── test_error_handling.cu     # Error handling tests
│   ├── test_dtype.cu              # Data type tests
│   ├── test_numerical_stability.cu # Numerical stability tests
│   └── test_pytorch_comparison.py # PyTorch comparison
├── examples/
│   └── basic_usage.cu             # Usage example
├── CMakeLists.txt
├── CMakePresets.json           # Build presets
└── README.md

Error Handling

cuflash::FlashAttentionError err = cuflash::flash_attention_forward(...);
if (err != cuflash::FlashAttentionError::SUCCESS) {
    std::cerr << cuflash::get_error_string(err) << std::endl;
}

Error Codes

  • SUCCESS: Operation completed successfully
  • INVALID_DIMENSION: Dimension parameters are invalid
  • DIMENSION_MISMATCH: Reserved for richer shape validation APIs; not emitted by the current raw-pointer interface
  • NULL_POINTER: Input or output pointer is null
  • CUDA_ERROR: CUDA runtime error occurred
  • OUT_OF_MEMORY: Insufficient GPU memory
  • UNSUPPORTED_HEAD_DIM: head_dim must be 32, 64, or 128
  • UNSUPPORTED_DTYPE: Data type not supported for this operation

License

MIT License

References

  • Dao, Fu, Ermon, Rudra, Ré. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." NeurIPS 2022.
  • Dao. "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." 2023.