HPC-AI-Optimization-Lab 1.0.0
High-Performance CUDA Kernels for AI/ML Workloads
Loading...
Searching...
No Matches
fp8_gemm.cuh
Go to the documentation of this file.
1#pragma once
2
3#include <cuda_runtime.h>
4#include <cuda_fp8.h>
5
namespace hpc::cuda13 {

// 8-bit floating-point storage format, matching the FP8 types provided by
// <cuda_fp8.h> (__nv_fp8_e4m3 / __nv_fp8_e5m2):
//   e4m3 -- 4 exponent / 3 mantissa bits (more precision, narrower range)
//   e5m2 -- 5 exponent / 2 mantissa bits (wider range, less precision)
enum class FP8Format {
    e4m3,
    e5m2
};

// Tuning and quantization parameters for the FP8 GEMM entry points below.
// NOTE(review): field semantics are inferred from the names; confirm against
// the kernel implementation in the corresponding .cu file.
struct FP8GEMMConfig {
    int tile_m = 16;                       // tile extent along M
    int tile_n = 16;                       // tile extent along N
    int tile_k = 16;                       // tile extent along K (reduction dim)
    FP8Format format_a = FP8Format::e4m3;  // FP8 encoding for operand A
    FP8Format format_b = FP8Format::e4m3;  // FP8 encoding for operand B
    float scale_a = 1.0f;                  // scale for A -- presumably de/quantization; verify in impl
    float scale_b = 1.0f;                  // scale for B -- presumably de/quantization; verify in impl
    bool use_fp8 = true;                   // if false, presumably the FP8 path is skipped -- verify in impl
};

// Reports whether the active device is Hopper-class (the architecture with
// native FP8 support). Implementation not visible in this header -- TODO
// confirm the exact capability check (e.g. SM90+) in the .cu file.
bool is_hopper_architecture();

// GEMM on __half device buffers: computes C (MxN) from A (MxK) and B (KxN),
// using FP8 internally as directed by `config`. All pointers are assumed to
// be device pointers. `stream = nullptr` launches on the legacy default
// stream, which synchronizes with all other blocking streams.
// NOTE(review): matrix layout (row- vs column-major) and the accumulation
// type are not visible from this declaration -- see the implementation.
void fp8_gemm(const __half* A, const __half* B, __half* C,
              int M, int N, int K,
              const FP8GEMMConfig& config,
              cudaStream_t stream = nullptr);

// Fallback path with the same signature and calling convention as fp8_gemm(),
// presumably for devices without native FP8 support (see
// is_hopper_architecture()). Exact behavioral differences are defined in the
// implementation, not here.
void fp8_gemm_fallback(const __half* A, const __half* B, __half* C,
                       int M, int N, int K,
                       const FP8GEMMConfig& config,
                       cudaStream_t stream = nullptr);

} // namespace hpc::cuda13