2#include "../common/cuda_check.cuh"
4#include <cooperative_groups.h>
// Returns true when the queried device reports compute capability major >= 9
// (SM90 / Hopper or newer), using cudaGetDeviceProperties wrapped in CUDA_CHECK.
// NOTE(review): the lines declaring `prop` and `device` and the closing brace are
// not visible in this excerpt; presumably `device` comes from cudaGetDevice — confirm.
// NOTE(review): device properties are queried on every call; callers on a hot path
// may want to cache the result (cudaDeviceGetAttribute is also cheaper).
8bool is_hopper_architecture() {
11 CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
12 return prop.major >= 9;
// Short alias for the cooperative groups namespace (used for cg::thread_block in
// tma_copy_kernel).
15namespace cg = cooperative_groups;
// Copy kernel intended for the TMA / async-copy path on SM90+.
// Template parameters: T is the element type; NUM_CHANNELS is the number of
// consecutive T elements each in-bounds thread passes to cuda::memcpy_async.
// Launched by the tma_copy_2d<float, 8> host wrapper with a 2D grid (x over column
// groups, y over rows) and sizeof(float) * NUM_CHANNELS * 2 bytes of dynamic
// shared memory.
// NOTE(review): several lines of this kernel are missing from this excerpt
// (remaining parameters, presumably dst/rows/cols; the definition of `row`; the
// closing braces). The comments below describe only the visible lines.
17template <typename T, int NUM_CHANNELS>
18__global__ void tma_copy_kernel(const T* __restrict__ src,
// NOTE(review): `smem`, `tma_load`, `tma_store` and `tile` are declared but not
// referenced in any visible line of this kernel.
21 extern __shared__ char smem[];
// NOTE(review): in libcu++ cuda::pipeline is a class template. I cannot find
// async_load_factory, async_store_factory, prologue, make_prologue or commit in
// the libcu++ documentation; verify these names exist in the toolkit in use.
22 using TMA_LOAD = cuda::pipeline::async_load_factory<NUM_CHANNELS, T>;
23 using TMA_STORE = cuda::pipeline::async_store_factory<NUM_CHANNELS, T>;
25 __shared__ TMA_LOAD tma_load;
26 __shared__ TMA_STORE tma_store;
28 cg::thread_block tile = cg::this_thread_block();
29 cuda::pipeline::prologue consume = cuda::pipeline::make_prologue();
// Column index of this thread's first element, in element units (not in units of
// NUM_CHANNELS-wide groups): consecutive threads start one element apart.
32 int col = blockIdx.x * blockDim.x + threadIdx.x;
// NOTE(review): this guard checks only the first element, but the copy below moves
// NUM_CHANNELS elements starting at `col`. When col + NUM_CHANNELS > cols the copy
// spills into the next row, and past the end of the data on the last row
// (presumably rows * cols elements; confirm).
// Because consecutive threads start one element apart, their NUM_CHANNELS-wide
// ranges also overlap.
// NOTE(review): `row * cols + col` is evaluated in the (non-visible) types of
// row/cols and in int for `col`; very large matrices may overflow. Consider size_t.
34 if (row < rows && col < cols) {
// Copies sizeof(T) * NUM_CHANNELS bytes directly from `src` to `dst` at the same
// offset. No staging through `smem` is visible.
35 cuda::memcpy_async(dst + row * cols + col,
36 src + row * cols + col,
37 sizeof(T) * NUM_CHANNELS, consume);
// NOTE(review): as written this statement declares a variable named `consume` of
// type cuda::pipeline::commit; it does not commit the pending copy. With libcu++'s
// cuda::pipeline the producer side commits via producer_commit().
40 cuda::pipeline::commit consume;
// Element-wise copy kernel used by tma_copy_2d_fallback<float>. Each thread copies
// at most one element: src[row * cols + col_start] -> dst[row * cols + col_start].
// Despite its name, the visible body performs a plain load/store, not an
// asynchronous copy.
// NOTE(review): the template header, the remaining parameters, the definition of
// `row` (presumably blockIdx.y, given the launcher puts rows in grid.y; confirm)
// and the closing braces are not visible in this excerpt.
45__global__ void async_copy_kernel(const T* __restrict__ src,
// Global column index for this thread (1D over x).
49 int col_start = blockIdx.x * blockDim.x + threadIdx.x;
// Bounds guard: the grid's x-extent is rounded up, so tail threads must skip.
51 if (row < rows && col_start < cols) {
52 dst[row * cols + col_start] = src[row * cols + col_start];
// Specialization of tma_copy_2d for float with NUM_CHANNELS = 8. Copies the
// row-major (stride `cols`) data from src to dst on `stream`.
// Uses tma_copy_kernel when config.use_tma is set and is_hopper_architecture()
// returns true; otherwise delegates to tma_copy_2d_fallback.
// Throws std::invalid_argument when src or dst is null, or when rows or cols is
// not positive.
// NOTE(review): the `template <>` line, the rows/cols parameter line, the
// definition of `block`, the `else` line and the closing braces are not visible in
// this excerpt.
57void tma_copy_2d<float, 8>(const float* src, float* dst,
59 const TMAConfig& config,
60 cudaStream_t stream) {
61 if (src == nullptr || dst == nullptr) {
62 throw std::invalid_argument("tma_copy_2d expects non-null src and dst pointers");
64 if (rows <= 0 || cols <= 0) {
65 throw std::invalid_argument("tma_copy_2d expects positive rows and cols");
// NOTE(review): is_hopper_architecture() queries device properties on every call.
68 if (config.use_tma && is_hopper_architecture()) {
69 constexpr int NUM_CHANNELS = 8;
// Grid: one block column per NUM_CHANNELS columns, one block row per matrix row.
// NOTE(review): gridDim.y is limited to 65535, so rows > 65535 fails to launch.
// NOTE(review): the kernel computes col = blockIdx.x * blockDim.x + threadIdx.x in
// element units, so this x-extent only lines up with the kernel's indexing for a
// suitable blockDim.x. Confirm against the (not visible) `block` definition.
71 dim3 grid((cols + NUM_CHANNELS - 1) / NUM_CHANNELS, rows);
// Dynamic shared memory: 2 * NUM_CHANNELS floats. The visible kernel lines do not
// read or write it.
72 size_t smem_size = sizeof(float) * NUM_CHANNELS * 2;
// NOTE(review): no cudaGetLastError() is visible after this launch, so a bad
// launch configuration is not reported here.
74 tma_copy_kernel<float, NUM_CHANNELS><<<grid, block, smem_size, stream>>>(
75 src, dst, rows, cols);
77 tma_copy_2d_fallback(src, dst, rows, cols, stream);
// Specialization of tma_copy_2d for float written with a single template argument;
// it forwards all arguments to tma_copy_2d<float, 8>.
// NOTE(review): if the primary template gives its second parameter a default of 8,
// then tma_copy_2d<float> and tma_copy_2d<float, 8> name the same specialization.
// This definition would then be a redefinition (and self-recursive). Confirm
// against the primary template declaration, which is not visible here.
83void tma_copy_2d<float>(const float* src, float* dst,
85 const TMAConfig& config,
86 cudaStream_t stream) {
87 tma_copy_2d<float, 8>(src, dst, rows, cols, config, stream);
// Specialization of tma_copy_2d_fallback for float. Launches
// async_copy_kernel<float> on `stream` with no dynamic shared memory and a grid of
// ceil(cols / block.x) x rows blocks.
// NOTE(review): the rows/cols parameter line and the definition of `block` are on
// lines not visible in this excerpt.
91void tma_copy_2d_fallback<float>(const float* src, float* dst,
93 cudaStream_t stream) {
// NOTE(review): rows goes into gridDim.y, which is limited to 65535, so larger
// matrices fail to launch. A loop over rows or a grid-stride kernel would lift this.
95 dim3 grid((cols + block.x - 1) / block.x, rows);
// NOTE(review): no cudaGetLastError() is visible after this launch.
96 async_copy_kernel<float><<<grid, block, 0, stream>>>(src, dst, rows, cols);
100} // namespace hpc::cuda13