HPC-AI-Optimization-Lab 1.0.0
High-Performance CUDA Kernels for AI/ML Workloads
Loading...
Searching...
No Matches
cluster.cu
Go to the documentation of this file.
1#include "cluster.cuh"
2#include "../common/cuda_check.cuh"
3#include <stdexcept>
4#include <cooperative_groups/memcpy_async.h>
5
6namespace hpc::cuda13 {
7
// Reports whether the current device (device 0) is Hopper or newer
// (compute capability >= 9.x), i.e. whether thread-block clusters are
// available in hardware.
bool is_hopper_architecture() {
    constexpr int kDevice = 0;
    cudaDeviceProp properties{};
    CUDA_CHECK(cudaGetDeviceProperties(&properties, kDevice));
    return properties.major >= 9;
}
14
15namespace cg = cooperative_groups;
16
// Cluster-cooperative sum reduction. Requires SM90+ when launched with a
// cluster dimension attached; with no cluster configuration it degenerates
// to a 1-block cluster and behaves like a plain per-block reduction.
//
// Launch contract:
//   - 1D grid, 1D blocks; blockDim.x is assumed to be a power of two
//     (the tree reduction below relies on it — TODO confirm with callers).
//   - Dynamic shared memory: blockDim.x * sizeof(float).
//   - *output must be zero-initialized before launch; each cluster
//     contributes exactly one atomicAdd.
template <typename T>
__global__ void cluster_reduce_kernel(const T* __restrict__ input,
                                      T* __restrict__ output,
                                      size_t n) {
    extern __shared__ float smem[];

    cg::cluster_group cluster = cg::this_cluster();
    const unsigned int cluster_rank = cluster.block_rank();
    const unsigned int cluster_blocks = cluster.num_blocks();

    const int tid = threadIdx.x;
    // size_t arithmetic: a 32-bit index would overflow for n >= 2^31.
    const size_t idx =
        static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // One element per thread; pad the grid tail with the identity (0).
    smem[tid] = (idx < n) ? static_cast<float>(input[idx]) : 0.0f;
    __syncthreads();

    // Intra-block tree reduction; the block's partial ends up in smem[0].
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            smem[tid] += smem[tid + s];
        }
        __syncthreads();
    }

    // Make every block's partial visible cluster-wide before any remote
    // (distributed shared memory) read.
    cluster.sync();

    if (cluster_rank == 0 && tid == 0) {
        float cluster_sum = smem[0];
        // Gather peer blocks' partials through distributed shared memory:
        // map_shared_rank() yields a pointer into the peer block's smem.
        // (The original read smem[tid + s * blockDim.x] — out of bounds for
        // this block's allocation — and never used its computed peer rank.)
        for (unsigned int b = 1; b < cluster_blocks; ++b) {
            const float* peer_smem = cluster.map_shared_rank(smem, b);
            cluster_sum += peer_smem[0];
        }
        atomicAdd(output, static_cast<T>(cluster_sum));
    }

    // Keep all blocks (and their shared memory) resident until the remote
    // reads above have completed; without this, peers may exit early.
    cluster.sync();
}
64
// Per-block sum reduction for pre-Hopper devices (no cluster features).
//
// Launch contract:
//   - 1D grid, 1D blocks; blockDim.x is assumed to be a power of two
//     (required by the tree reduction — TODO confirm with callers).
//   - Dynamic shared memory: blockDim.x * sizeof(float).
//   - *output must be zero-initialized before launch; each block
//     contributes one atomicAdd.
template <typename T>
__global__ void cluster_reduce_fallback_kernel(const T* __restrict__ input,
                                               T* __restrict__ output,
                                               size_t n) {
    extern __shared__ float smem[];

    const int tid = threadIdx.x;
    // size_t arithmetic: a 32-bit index would overflow for n >= 2^31.
    const size_t idx =
        static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // One element per thread; pad the grid tail with the identity (0).
    smem[tid] = (idx < n) ? static_cast<float>(input[idx]) : 0.0f;
    __syncthreads();

    // Intra-block tree reduction; the block's partial ends up in smem[0].
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            smem[tid] += smem[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0) {
        atomicAdd(output, static_cast<T>(smem[0]));
    }
}
88
// Host entry point: sum-reduce n floats from `input` into *output on
// `stream`. *output is zeroed asynchronously first, so the final value is
// exactly the sum of the launched reduction.
//
// Dispatches to the cluster kernel on Hopper+ when config.use_cluster is
// set, otherwise to the plain per-block fallback kernel.
//
// Throws std::invalid_argument on null pointers, n == 0, or a zero block
// dimension.
template <>
void cluster_reduce<float>(const float* input, float* output, size_t n,
                           const ClusterConfig& config, cudaStream_t stream) {
    if (input == nullptr || output == nullptr) {
        throw std::invalid_argument("cluster_reduce expects non-null input and output pointers");
    }
    if (n == 0) {
        throw std::invalid_argument("cluster_reduce expects n > 0");
    }
    if (config.block_dims.x == 0) {
        throw std::invalid_argument("cluster_reduce expects config.block_dims.x > 0");
    }

    const unsigned int block_size = config.block_dims.x;
    // Ceil-divide in size_t: the original computed this in int, which
    // truncates/overflows for large n.
    const size_t grid_blocks = (n + block_size - 1) / block_size;
    const unsigned int grid_size = static_cast<unsigned int>(grid_blocks);
    const size_t smem_size = static_cast<size_t>(block_size) * sizeof(float);

    // Kernels accumulate with atomicAdd, so the accumulator must start at 0.
    CUDA_CHECK(cudaMemsetAsync(output, 0, sizeof(float), stream));

    if (config.use_cluster && is_hopper_architecture()) {
        // NOTE(review): launched with plain <<<>>>, so no cluster dimension
        // is attached and the kernel runs as an implicit 1-block cluster.
        // Forming real clusters requires cudaLaunchKernelEx with
        // cudaLaunchAttributeClusterDimension (or __cluster_dims__) — confirm
        // whether ClusterConfig carries the intended cluster shape.
        cluster_reduce_kernel<float><<<grid_size, block_size, smem_size, stream>>>(
            input, output, n);
    } else {
        cluster_reduce_fallback_kernel<float><<<grid_size, block_size, smem_size, stream>>>(
            input, output, n);
    }
    CUDA_CHECK_LAST();
}
117
// Host entry point that always uses the per-block fallback kernel,
// regardless of architecture. Unlike cluster_reduce, this does NOT zero
// *output first — the caller is responsible for initializing the
// accumulator (TODO confirm this asymmetry is intentional).
//
// Throws std::invalid_argument on null pointers, n == 0, or a zero block
// dimension.
template <>
void cluster_reduce_fallback<float>(const float* input, float* output, size_t n,
                                    const ClusterConfig& config, cudaStream_t stream) {
    // Error messages previously blamed "cluster_reduce"; name the actual
    // entry point so diagnostics point at the right call site.
    if (input == nullptr || output == nullptr) {
        throw std::invalid_argument("cluster_reduce_fallback expects non-null input and output pointers");
    }
    if (n == 0) {
        throw std::invalid_argument("cluster_reduce_fallback expects n > 0");
    }
    if (config.block_dims.x == 0) {
        throw std::invalid_argument("cluster_reduce_fallback expects config.block_dims.x > 0");
    }

    const unsigned int block_size = config.block_dims.x;
    // Ceil-divide in size_t: the original computed this in int, which
    // truncates/overflows for large n.
    const size_t grid_blocks = (n + block_size - 1) / block_size;
    const unsigned int grid_size = static_cast<unsigned int>(grid_blocks);
    const size_t smem_size = static_cast<size_t>(block_size) * sizeof(float);

    cluster_reduce_fallback_kernel<float><<<grid_size, block_size, smem_size, stream>>>(
        input, output, n);
    CUDA_CHECK_LAST();
}
139
140} // namespace hpc::cuda13