HPC-AI-Optimization-Lab 1.0.0
High-Performance CUDA Kernels for AI/ML Workloads
Loading...
Searching...
No Matches
fp8_gemm.cu
Go to the documentation of this file.
#include "fp8_gemm.cuh"
#include "../common/cuda_check.cuh"

#include <mma.h>

#include <iostream>
#include <stdexcept>
5
6namespace hpc::cuda13 {
7
// Returns true when the active CUDA device is Hopper (SM90) or newer.
//
// Fix vs. the original: query the *current* device with cudaGetDevice
// instead of hardcoding device 0, so multi-GPU processes that have called
// cudaSetDevice dispatch based on the device they will actually launch on.
// `prop.major >= 9` also accepts post-Hopper architectures, which retain
// FP8 tensor-core support.
bool is_hopper_architecture() {
    int device = 0;
    CUDA_CHECK(cudaGetDevice(&device));
    cudaDeviceProp prop;
    CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
    return prop.major >= 9;
}
14
15using namespace nvcuda;
16
// Tiled FP16 GEMM on tensor cores: C = (scale_a * A) * (scale_b * B).
//
// NOTE(review): despite the "fp8" name, this path operates on __half data;
// scale_a/scale_b mirror the per-tensor scaling an FP8 pipeline would use —
// confirm the intended element type against fp8_gemm.cuh.
//
// Expected launch configuration (see fp8_gemm below):
//   grid  = (ceil(M/128), ceil(N/128), 1)  — blockIdx.x → M tiles, blockIdx.y → N tiles
//   block = (256, 1, 1)                    — 8 warps per block
//   dynamic shared memory: none (all staging buffers are static __shared__;
//   the original declared extern __shared__ while being launched with 0
//   bytes, which was an out-of-bounds write)
//
// Layouts: A row-major MxK, B row-major KxN, C row-major MxN (consistent
// with the fallback kernel's B[k*N + col] indexing).
//
// Each of the 8 warps owns a 16x128 strip of the block's 128x128 output
// tile: warp w computes rows [blockIdx.x*128 + w*16, +16) across all eight
// 16-wide column tiles, accumulating in float for precision. Edge tiles
// (M, N or K not multiples of 16) are staged through zero-padded per-warp
// shared buffers so wmma load/store never touch out-of-bounds memory.
// Requires SM70+ (wmma). Only warp-scope synchronization is used, so warps
// whose row strip is entirely out of range may exit early.
__global__ void fp8_gemm_kernel(const __half* __restrict__ A,
                                const __half* __restrict__ B,
                                __half* __restrict__ C,
                                int M, int N, int K,
                                float scale_a, float scale_b) {
    constexpr int BM = 128;             // block tile rows
    constexpr int BN = 128;             // block tile cols
    constexpr int WT = 16;              // wmma tile dimension (16x16x16 mma)
    constexpr int WARPS = 8;            // 256 threads / 32
    constexpr int COL_TILES = BN / WT;  // column tiles per warp strip

    const int warp_idx = threadIdx.x / 32;
    const int lane_idx = threadIdx.x % 32;

    // Per-warp zero-padded staging for edge tiles, plus a float buffer to
    // downconvert the float accumulator to __half on the way out.
    __shared__ __half stage_a[WARPS][WT * WT];
    __shared__ __half stage_b[WARPS][WT * WT];
    __shared__ float  stage_c[WARPS][WT * WT];

    const int row0 = blockIdx.x * BM + warp_idx * WT;  // first row of this warp's strip
    if (row0 >= M) {
        return;  // uniform per warp; safe because no block-wide barrier follows
    }

    wmma::fragment<wmma::matrix_a, WT, WT, WT, __half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WT, WT, WT, __half, wmma::row_major> b_frag;
    wmma::fragment<wmma::accumulator, WT, WT, WT, float> acc[COL_TILES];

    for (int t = 0; t < COL_TILES; ++t) {
        wmma::fill_fragment(acc[t], 0.0f);  // the original called a nonexistent frag.fill()
    }

    for (int k0 = 0; k0 < K; k0 += WT) {
        // A tile (rows row0..row0+15, cols k0..k0+15): direct global load
        // when fully in bounds, otherwise staged with zero padding.
        if (row0 + WT <= M && k0 + WT <= K) {
            wmma::load_matrix_sync(a_frag, A + row0 * K + k0, K);
        } else {
            for (int e = lane_idx; e < WT * WT; e += 32) {
                const int r = row0 + e / WT;
                const int c = k0 + e % WT;
                stage_a[warp_idx][e] =
                    (r < M && c < K) ? A[r * K + c] : __float2half(0.0f);
            }
            __syncwarp();  // staging writes must land before the collective load
            wmma::load_matrix_sync(a_frag, stage_a[warp_idx], WT);
        }

        for (int t = 0; t < COL_TILES; ++t) {
            const int col0 = blockIdx.y * BN + t * WT;
            if (col0 >= N) {
                break;  // remaining column tiles are also out of range
            }
            if (k0 + WT <= K && col0 + WT <= N) {
                wmma::load_matrix_sync(b_frag, B + k0 * N + col0, N);
            } else {
                for (int e = lane_idx; e < WT * WT; e += 32) {
                    const int r = k0 + e / WT;
                    const int c = col0 + e % WT;
                    stage_b[warp_idx][e] =
                        (r < K && c < N) ? B[r * N + c] : __float2half(0.0f);
                }
                __syncwarp();
                wmma::load_matrix_sync(b_frag, stage_b[warp_idx], WT);
            }
            wmma::mma_sync(acc[t], a_frag, b_frag, acc[t]);
            __syncwarp();  // keep stage_a/stage_b stable until every lane is done
        }
    }

    // Apply both scales once at the end: (sa*A)(sb*B) == sa*sb*(A*B).
    // The original kernel accepted the scales but never used them.
    const float scale = scale_a * scale_b;
    for (int t = 0; t < COL_TILES; ++t) {
        const int col0 = blockIdx.y * BN + t * WT;
        if (col0 >= N) {
            break;
        }
        // Stage through shared memory so edge tiles get element-wise bounds
        // checks and the float accumulator is converted to __half.
        wmma::store_matrix_sync(stage_c[warp_idx], acc[t], WT, wmma::mem_row_major);
        __syncwarp();
        for (int e = lane_idx; e < WT * WT; e += 32) {
            const int r = row0 + e / WT;
            const int c = col0 + e % WT;
            if (r < M && c < N) {
                C[r * N + c] = __float2half(stage_c[warp_idx][e] * scale);
            }
        }
        __syncwarp();  // finish reads of stage_c before the next tile reuses it
    }
}
100
// Host-side entry point for the scaled half-precision GEMM:
// C = (scale_a * A) * (scale_b * B).
//
// Dispatches to the tensor-core kernel when config.use_fp8 is set and the
// active device reports SM90 (Hopper) or newer; otherwise defers to the
// naive fallback. A, B and C must be device pointers to row-major matrices
// of shapes MxK, KxN and MxN. Throws std::invalid_argument on null
// pointers or non-positive dimensions. The launch is asynchronous on
// `stream`; launch-configuration errors surface via CUDA_CHECK_LAST.
void fp8_gemm(const __half* A, const __half* B, __half* C,
              int M, int N, int K,
              const FP8GEMMConfig& config,
              cudaStream_t stream) {
    const bool null_input = (A == nullptr) || (B == nullptr) || (C == nullptr);
    if (null_input) {
        throw std::invalid_argument("fp8_gemm expects non-null A, B, C pointers");
    }
    const bool bad_shape = (M <= 0) || (N <= 0) || (K <= 0);
    if (bad_shape) {
        throw std::invalid_argument("fp8_gemm expects positive M, N, K");
    }

    const bool tensor_core_path = config.use_fp8 && is_hopper_architecture();
    if (!tensor_core_path) {
        fp8_gemm_fallback(A, B, C, M, N, K, config, stream);
        CUDA_CHECK_LAST();
        return;
    }

    constexpr int kTileM = 128;  // must match the kernel's block tile rows
    constexpr int kTileN = 128;  // must match the kernel's block tile cols
    const dim3 threads(256, 1, 1);
    const dim3 blocks((M + kTileM - 1) / kTileM,
                      (N + kTileN - 1) / kTileN,
                      1);
    fp8_gemm_kernel<<<blocks, threads, 0, stream>>>(
        A, B, C, M, N, K, config.scale_a, config.scale_b);
    CUDA_CHECK_LAST();
}
126
// Naive one-thread-per-output-element GEMM used when the tensor-core path
// is unavailable: C[r][c] = sum_k (scale_a*A[r][k]) * (scale_b*B[k][c]).
// A, B, C are row-major device pointers (MxK, KxN, MxN); accumulation is
// in float, the result is stored as __half. Any 2D grid/block covering
// MxN works; out-of-range threads exit at the bounds guard.
__global__ void fp8_gemm_naive_kernel(const __half* __restrict__ A,
                                      const __half* __restrict__ B,
                                      __half* __restrict__ C,
                                      int M, int N, int K,
                                      float scale_a, float scale_b) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) {
        return;
    }
    float sum = 0.0f;
    for (int k = 0; k < K; ++k) {
        const float a_val = __half2float(A[row * K + k]) * scale_a;
        const float b_val = __half2float(B[k * N + col]) * scale_b;
        sum += a_val * b_val;
    }
    C[row * N + col] = __float2half(sum);
}

// Fallback GEMM launcher.
//
// Fix vs. the original: a `__device__` lambda is not a `__global__`
// function and cannot be launched with <<<...>>> — the original code could
// not compile. The naive kernel now lives in the proper __global__
// function above. The launch is asynchronous on `stream`;
// CUDA_CHECK_LAST surfaces launch-configuration errors.
void fp8_gemm_fallback(const __half* A, const __half* B, __half* C,
                       int M, int N, int K,
                       const FP8GEMMConfig& config,
                       cudaStream_t stream) {
    const dim3 block(16, 16);
    const dim3 grid((N + block.x - 1) / block.x,
                    (M + block.y - 1) / block.y);
    fp8_gemm_naive_kernel<<<grid, block, 0, stream>>>(
        A, B, C, M, N, K, config.scale_a, config.scale_b);
    CUDA_CHECK_LAST();
}
159
160} // namespace hpc::cuda13