HPC-AI-Optimization-Lab 1.0.0
High-Performance CUDA Kernels for AI/ML Workloads
Loading...
Searching...
No Matches
tma.cu
Go to the documentation of this file.
1#include "tma.cuh"
2#include "../common/cuda_check.cuh"
3#include <stdexcept>
4#include <cooperative_groups.h>
5
6namespace hpc::cuda13 {
7
// Returns true when the calling thread's current CUDA device is Hopper
// (compute capability 9.x, SM90) or newer.
//
// Errors from the CUDA runtime API are reported through the project's
// CUDA_CHECK macro (behavior defined in ../common/cuda_check.cuh).
bool is_hopper_architecture() {
    int device = 0;
    // Query the device that is actually current for this thread instead of
    // hard-coding device 0 — the two differ in multi-GPU processes that
    // have called cudaSetDevice().
    CUDA_CHECK(cudaGetDevice(&device));
    cudaDeviceProp prop;
    CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
    return prop.major >= 9;
}
14
15namespace cg = cooperative_groups;
16
// Copies a rows x cols row-major matrix from `src` to `dst` (global to
// global, element-wise).
//
// Launch expectations: blockIdx.y selects the row (gridDim.y >= rows is not
// required — extra rows exit early); threads grid-stride across the columns
// of their row, so the kernel is correct for ANY gridDim.x / blockDim.x.
// No shared memory is required.
//
// NOTE(review): the previous body used `cuda::pipeline::async_load_factory`
// / `async_store_factory` / `make_prologue`, none of which exist in libcu++,
// redeclared `consume`, and issued NUM_CHANNELS-element async copies from
// per-element offsets (overlapping/out-of-bounds writes). This version
// performs the same logical 2D copy with plain global loads/stores.
// True Hopper TMA (cp.async.bulk.tensor) would additionally require a
// host-side CUtensorMap; NUM_CHANNELS is retained only for interface
// compatibility with existing instantiations.
template <typename T, int NUM_CHANNELS>
__global__ void tma_copy_kernel(const T* __restrict__ src,
                                T* __restrict__ dst,
                                int rows, int cols) {
    static_assert(NUM_CHANNELS > 0, "NUM_CHANNELS must be positive");

    const int row = blockIdx.y;
    if (row >= rows) {
        return;
    }

    // size_t row offset avoids int overflow for large rows * cols.
    const T* src_row = src + static_cast<size_t>(row) * cols;
    T* dst_row = dst + static_cast<size_t>(row) * cols;

    const int stride = gridDim.x * blockDim.x;
    for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < cols;
         col += stride) {
        dst_row[col] = src_row[col];
    }
}
43
// Element-wise 2D copy fallback kernel: one thread per (row, column) pair.
//
// Expected launch shape: grid.y == rows and grid.x * blockDim.x >= cols;
// threads mapped past the last row or column simply exit, so uneven grids
// are safe.
template <typename T>
__global__ void async_copy_kernel(const T* __restrict__ src,
                                  T* __restrict__ dst,
                                  int rows, int cols) {
    const int r = blockIdx.y;
    const int c = blockIdx.x * blockDim.x + threadIdx.x;

    // Guard against the partial block at the tail of each row.
    if (r >= rows || c >= cols) {
        return;
    }

    const int idx = r * cols + c;
    dst[idx] = src[idx];
}
55
// Host entry point: copies a rows x cols row-major float matrix on `stream`.
//
// Dispatches to tma_copy_kernel when the config requests TMA and the current
// device is Hopper-class, otherwise to the portable fallback path.
//
// Throws std::invalid_argument on null pointers or non-positive dimensions.
// Launch-configuration errors surface through CUDA_CHECK_LAST.
template <>
void tma_copy_2d<float, 8>(const float* src, float* dst,
                           int rows, int cols,
                           const TMAConfig& config,
                           cudaStream_t stream) {
    if (src == nullptr || dst == nullptr) {
        throw std::invalid_argument("tma_copy_2d expects non-null src and dst pointers");
    }
    if (rows <= 0 || cols <= 0) {
        throw std::invalid_argument("tma_copy_2d expects positive rows and cols");
    }

    if (config.use_tma && is_hopper_architecture()) {
        constexpr int NUM_CHANNELS = 8;
        const dim3 block(128);
        // Size grid.x on the block width, not NUM_CHANNELS: the kernel
        // indexes columns as blockIdx.x * blockDim.x + threadIdx.x, so the
        // previous ceil(cols / NUM_CHANNELS) quantum launched ~16x more
        // blocks than the columns require. The kernel grid-strides, so this
        // choice affects only parallelism, never correctness.
        const dim3 grid((cols + block.x - 1) / block.x, rows);
        // No dynamic shared memory: the kernel does not stage through smem.
        tma_copy_kernel<float, NUM_CHANNELS><<<grid, block, 0, stream>>>(
            src, dst, rows, cols);
    } else {
        tma_copy_2d_fallback(src, dst, rows, cols, stream);
    }
    CUDA_CHECK_LAST();
}
81
// Convenience entry point for the common configuration: the
// single-template-parameter overload forwards unchanged to the
// NUM_CHANNELS == 8 specialization above.
template <>
void tma_copy_2d<float>(const float* src, float* dst,
                        int rows, int cols,
                        const TMAConfig& config,
                        cudaStream_t stream) {
    // Pure delegation — all validation and dispatch live in the target.
    tma_copy_2d<float, 8>(src, dst, rows, cols, config, stream);
}
89
// Portable (non-TMA) copy path, valid on any architecture.
//
// Launches one thread per element: grid.y spans the rows and grid.x spans
// the columns in chunks of `threads_per_block`, rounding up so the partial
// chunk at the end of each row is covered (the kernel bounds-checks).
template <>
void tma_copy_2d_fallback<float>(const float* src, float* dst,
                                 int rows, int cols,
                                 cudaStream_t stream) {
    constexpr unsigned int threads_per_block = 256;
    const unsigned int col_chunks =
        (static_cast<unsigned int>(cols) + threads_per_block - 1) / threads_per_block;

    const dim3 block_dim(threads_per_block);
    const dim3 grid_dim(col_chunks, rows);

    async_copy_kernel<float><<<grid_dim, block_dim, 0, stream>>>(src, dst, rows, cols);
    CUDA_CHECK_LAST();
}
99
100} // namespace hpc::cuda13