Skip to content

Kernel 5: Tensor Core (WMMA)

Leveraging dedicated matrix multiply-accumulate hardware

Tensor Core Architecture

Hardware Capabilities

GenerationArchitectureOperations/CyclePrecision
Volta (V100)sm_7064 FMAFP16/FP32
Turing (RTX 20)sm_7564 FMAFP16/INT8/INT32
Ampere (A100/RTX 30)sm_80/sm_86256 FMAFP16/BF16/TF32
Ada (RTX 40)sm_89512 FMAFP16/BF16/TF32
Hopper (H100)sm_901024 FMAFP8/FP16/BF16

WMMA Fragment Size

Warp Matrix Multiply Accumulate (WMMA):

Fragment A: 16×16 FP16 matrix (row-major)
Fragment B: 16×16 FP16 matrix (row-major)
Fragment C: 16×16 FP32 matrix (row-major)

          D = A × B + C

Fragment D: 16×16 FP32 matrix

One warp (32 threads) collaborates on one 16×16×16 operation.

Implementation

cpp
// File: src/kernels/tensor_core_sgemm.cuh

#include <mma.h>
using namespace nvcuda::wmma;

// WMMA uses 16×16 tiles
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;

__global__ void sgemm_tensor_core_kernel(
    const half* __restrict__ A,
    const half* __restrict__ B,
    float* __restrict__ C,
    int M, int N, int K)
{
    // Each warp computes one 16×16 output tile
    int warpM = (blockIdx.y * blockDim.y + threadIdx.y) / warpSize;
    int warpN = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
    
    // Check if this warp has work
    if (warpM * WMMA_M >= M || warpN * WMMA_N >= N)
        return;

    // Declare fragments
    fragment<matrix_a, WMMA_M, WMMA_N, WMMA_K, half, row_major> a_frag;
    fragment<matrix_b, WMMA_M, WMMA_N, WMMA_K, half, row_major> b_frag;
    fragment<accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
    
    // Initialize accumulator to zero
    fill_fragment(acc_frag, 0.0f);
    
    // Iterate over K dimension in 16-element chunks
    for (int k = 0; k < K; k += WMMA_K) {
        // Calculate pointers for this tile
        const half* a_ptr = A + warpM * WMMA_M * K + k;
        const half* b_ptr = B + k * N + warpN * WMMA_N;
        
        // Load fragments
        load_matrix_sync(a_frag, a_ptr, K);
        load_matrix_sync(b_frag, b_ptr, N);
        
        // Perform MMA
        mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }
    
    // Store result
    float* c_ptr = C + warpM * WMMA_M * N + warpN * WMMA_N;
    store_matrix_sync(c_ptr, acc_frag, N, mem_row_major);
}

Mixed Precision Considerations

Accuracy Trade-off

FP32 precision: ~7 decimal digits
FP16 precision: ~3 decimal digits

Conversion FP32 → FP16 introduces quantization error.
But FP32 accumulation maintains precision for the sum.

Verification Tolerances

cpp
// Standard FP32 kernels
const float rtol_fp32 = 1e-3f;
const float atol_fp32 = 1e-4f;

// Tensor Core mixed precision
const float rtol_tc = 5e-2f;   // 50× looser
const float atol_tc = 1e-2f;   // 100× looser

When FP16 is Acceptable

Use CaseFP16 OK?
Deep Learning Training✓ Yes
Deep Learning Inference✓ Yes
Scientific Computing⚠️ Check
Financial Calculations✗ No

Fragment Layout Details

Each warp's 32 threads hold the 16×16 matrix collaboratively:

16×16 matrix distributed across 32 threads:

Thread Layout (8 rows × 4 columns of threads):
┌───┬───┬───┬───┐
│ 0 │ 1 │ 2 │ 3 │  Row 0 holds elements at column 0-3
├───┼───┼───┼───┤
│ 4 │ 5 │ 6 │ 7 │  Row 1 holds elements at column 0-3
├───┼───┼───┼───┤
│...│   │   │   │
└───┴───┴───┴───┘

Each thread holds 8 FP16 values (4 rows × 2 columns).

The exact mapping is hardware-defined and managed by WMMA APIs.

Benchmark Scope Note

The commonly quoted 40% of cuBLAS figure in this repository is the WMMA compute-only measurement on Tensor Core-friendly shapes. It is useful as an upper-bound reference for the raw WMMA path, but it is not the same thing as the end-to-end FP32-facing wrapper that includes conversion and fallback behavior.

Read this page together with Benchmark Results and Benchmark Scope before comparing Tensor Core numbers against FP32 kernels.

Optimization Opportunities

Why the compute-only WMMA path still trails cuBLAS:

  1. Multi-level tiling: Warp-level + thread-level tiles
  2. Instruction pipelining: Issue multiple MMAs concurrently
  3. Shared memory staging: Better FP16 data layout
  4. Epilogue fusion: Combine output processing with MMA

For production use, always use cuBLAS or CUTLASS.


MIT Licensed