2#include "../common/cuda_check.cuh"
// Returns true when the currently active CUDA device is Hopper (SM 9.x) or
// newer. Used to gate the tensor-core GEMM path below.
// NOTE(review): this queries the *current* device; callers that juggle
// multiple devices must set the device before calling.
bool is_hopper_architecture() {
    int device = 0;
    CUDA_CHECK(cudaGetDevice(&device));          // was unchecked / undeclared
    cudaDeviceProp prop{};
    CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
    return prop.major >= 9;
}
15using namespace nvcuda;
// Tensor-core GEMM: C = (scale_a * scale_b) * A * B, half in / half out.
// A is M x K row-major, B is K x N row-major, C is M x N row-major.
// Despite the "fp8" name, operands are FP16 (__half) — TODO confirm whether a
// true FP8 path is planned; scales are applied exactly as in the fallback.
//
// Launch contract (matches fp8_gemm below):
//   grid  = (ceil(M/128), ceil(N/128), 1)
//   block = (256, 1, 1)            -> 8 warps per block
//   dynamic shared memory = 0      -> all shared memory is static here
// Requires SM70+ (wmma with __half operands).
//
// Each block computes a 128x128 tile of C. Warp w owns the 16-row strip
// rows [w*16, w*16+16) of the tile and accumulates all eight 16x16 column
// tiles of that strip in float WMMA accumulators.
__global__ void fp8_gemm_kernel(const __half* __restrict__ A,
                                const __half* __restrict__ B,
                                __half* __restrict__ C,
                                int M, int N, int K,
                                float scale_a, float scale_b) {
    constexpr int BM = 128;   // block tile rows of C
    constexpr int BN = 128;   // block tile cols of C
    constexpr int BK = 16;    // K-slab per iteration (one WMMA k-step)
    constexpr int WT = 16;    // WMMA tile dimension
    constexpr int WARPS = (BM / WT);            // 8 warps, one row-strip each
    constexpr int COL_TILES = BN / WT;          // 8 column tiles per strip

    // Static shared memory: 4 KB + 4 KB operand tiles + 8 KB staging = 16 KB.
    __shared__ __half sA[BM][BK];               // A slab: 128 x 16
    __shared__ __half sB[BK][BN];               // B slab:  16 x 128
    __shared__ float  sOut[WARPS][WT][WT];      // per-warp store staging

    const int warp_id  = threadIdx.x / 32;      // expects blockDim.x == 256
    const int lane_id  = threadIdx.x % 32;
    const int row0     = blockIdx.x * BM;       // tile origin in C
    const int col0     = blockIdx.y * BN;

    wmma::fragment<wmma::accumulator, WT, WT, WT, float> acc[COL_TILES];
    for (int j = 0; j < COL_TILES; ++j) {
        wmma::fill_fragment(acc[j], 0.0f);      // member .fill() does not exist
    }

    for (int k0 = 0; k0 < K; k0 += BK) {
        // Cooperative, zero-padded loads so edge tiles are safe for WMMA.
        for (int idx = threadIdx.x; idx < BM * BK; idx += blockDim.x) {
            const int r = idx / BK, c = idx % BK;
            const int gr = row0 + r, gc = k0 + c;
            sA[r][c] = (gr < M && gc < K) ? A[gr * K + gc] : __float2half(0.0f);
        }
        for (int idx = threadIdx.x; idx < BK * BN; idx += blockDim.x) {
            const int r = idx / BN, c = idx % BN;
            const int gr = k0 + r, gc = col0 + c;
            sB[r][c] = (gr < K && gc < N) ? B[gr * N + gc] : __float2half(0.0f);
        }
        __syncthreads();    // loads must complete before any warp consumes them

        // One A fragment per warp; sweep it across the eight B column tiles.
        wmma::fragment<wmma::matrix_a, WT, WT, WT, __half, wmma::row_major> a_frag;
        wmma::load_matrix_sync(a_frag, &sA[warp_id * WT][0], BK);
        for (int j = 0; j < COL_TILES; ++j) {
            // B is row-major in both global and shared memory.
            wmma::fragment<wmma::matrix_b, WT, WT, WT, __half, wmma::row_major> b_frag;
            wmma::load_matrix_sync(b_frag, &sB[0][j * WT], BN);
            wmma::mma_sync(acc[j], a_frag, b_frag, acc[j]);
        }
        __syncthreads();    // tiles are reused next iteration
    }

    // Epilogue: stage each accumulator in shared memory so edge tiles can be
    // written with per-element bounds checks, applying the combined scale.
    const float scale = scale_a * scale_b;
    for (int j = 0; j < COL_TILES; ++j) {
        wmma::store_matrix_sync(&sOut[warp_id][0][0], acc[j], WT,
                                wmma::mem_row_major);
        __syncwarp();       // make the staged tile visible to all lanes
        for (int e = lane_id; e < WT * WT; e += 32) {
            const int r = row0 + warp_id * WT + e / WT;
            const int c = col0 + j * WT + e % WT;
            if (r < M && c < N) {
                C[r * N + c] = __float2half(sOut[warp_id][e / WT][e % WT] * scale);
            }
        }
        __syncwarp();       // staging buffer is reused for the next j
    }
}
// Public entry point: C = (scale_a * scale_b) * A * B with half operands.
// Dispatches to the tensor-core kernel when config.use_fp8 is set and the
// device is Hopper-class, otherwise to the scalar fallback.
//
// A: M x K, B: K x N, C: M x N, all row-major device pointers.
// Throws std::invalid_argument on null pointers or non-positive dimensions.
// The launch is asynchronous on `stream`; callers must synchronize before
// reading C on the host.
void fp8_gemm(const __half* A, const __half* B, __half* C,
              int M, int N, int K,
              const FP8GEMMConfig& config,
              cudaStream_t stream) {
    if (A == nullptr || B == nullptr || C == nullptr) {
        throw std::invalid_argument("fp8_gemm expects non-null A, B, C pointers");
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        throw std::invalid_argument("fp8_gemm expects positive M, N, K");
    }

    if (config.use_fp8 && is_hopper_architecture()) {
        constexpr int BM = 128;   // must match fp8_gemm_kernel's tile size
        constexpr int BN = 128;

        dim3 block(256, 1, 1);    // 8 warps, required by the kernel
        dim3 grid((M + BM - 1) / BM, (N + BN - 1) / BN, 1);

        // Dynamic shared memory is 0: the kernel uses static shared arrays.
        fp8_gemm_kernel<<<grid, block, 0, stream>>>(
            A, B, C, M, N, K, config.scale_a, config.scale_b);
        CUDA_CHECK(cudaGetLastError());   // surface launch-config errors
    } else {
        fp8_gemm_fallback(A, B, C, M, N, K, config, stream);
    }
}
// Scalar reference kernel: one thread per element of C.
// C[r][c] = sum_k (A[r][k]*scale_a) * (B[k][c]*scale_b), accumulated in float.
// A __device__ lambda cannot be launched with <<<...>>> — only __global__
// functions can — so the original lambda is now a proper kernel.
__global__ void fp8_gemm_naive_kernel(const __half* __restrict__ A,
                                      const __half* __restrict__ B,
                                      __half* __restrict__ C,
                                      int M, int N, int K,
                                      float scale_a, float scale_b) {
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) {
        return;   // guard the grid tail
    }
    float sum = 0.0f;
    for (int k = 0; k < K; ++k) {
        const float a_val = __half2float(A[row * K + k]) * scale_a;
        const float b_val = __half2float(B[k * N + col]) * scale_b;
        sum += a_val * b_val;
    }
    C[row * N + col] = __float2half(sum);
}

// Fallback path for non-Hopper devices (or config.use_fp8 == false).
// Same contract as fp8_gemm: row-major half matrices, asynchronous on
// `stream`. Dimension/pointer validation is done by the caller.
void fp8_gemm_fallback(const __half* A, const __half* B, __half* C,
                       int M, int N, int K,
                       const FP8GEMMConfig& config,
                       cudaStream_t stream) {
    dim3 block(16, 16, 1);   // 256 threads, 2D to match the 2D indexing
    dim3 grid((N + block.x - 1) / block.x, (M + block.y - 1) / block.y, 1);

    fp8_gemm_naive_kernel<<<grid, block, 0, stream>>>(
        A, B, C, M, N, K, config.scale_a, config.scale_b);
    CUDA_CHECK(cudaGetLastError());   // surface launch-config errors
}
160} // namespace hpc::cuda13