HPC-AI-Optimization-Lab 1.0.0
High-Performance CUDA Kernels for AI/ML Workloads
Loading...
Searching...
No Matches
conv_winograd.cu
Go to the documentation of this file.
1#include "conv_winograd.cuh"
2#include "conv_implicit_gemm.cuh"
3#include "../common/cuda_check.cuh"
4#include <stdexcept>
5#include <cmath>
6
7namespace hpc::convolution {
8
// B^T: the 4x4 input-transform matrix for Winograd F(2x2, 3x3), row-major.
// The input tile transform is V = B^T d B for a 4x4 input patch d.
__device__ constexpr float winograd_BT[16] = {
    1.0f,  0.0f, -1.0f,  0.0f,
    0.0f,  1.0f,  1.0f,  0.0f,
    0.0f, -1.0f,  1.0f,  0.0f,
    0.0f,  1.0f,  0.0f, -1.0f
};
15
// G: the weight-transform matrix for Winograd F(2x2, 3x3). G is logically
// 4x3 (filter transform U = G g G^T for a 3x3 filter g) but is stored as
// four rows of four floats so rows can be indexed with a stride of 4; the
// fourth entry of each row is padding and is never read (all users loop
// the inner index over 3).
__device__ constexpr float winograd_G[16] = {
    1.0f,  0.0f, 0.0f,  0.0f,
    0.5f,  0.5f, 0.5f,  0.5f,
    0.5f, -0.5f, 0.5f, -0.5f,
    0.0f,  0.0f, 1.0f,  1.0f
};
22
// A^T: the output-transform matrix for Winograd F(2x2, 3x3). A^T is
// logically 2x4 (output tile Y = A^T M A); only rows 0 and 1 are part of
// the transform — rows 2 and 3 are padding kept so the array stays 4x4
// like the other transform matrices, and must not be used.
//
// Fix: element (1, 3) was 0.0f; the F(2,3) output transform is
// A^T = [[1, 1, 1, 0], [0, 1, -1, -1]] (Lavin & Gray), so it must be -1.
__device__ constexpr float winograd_AT[16] = {
    1.0f, 1.0f,  1.0f,  0.0f,
    0.0f, 1.0f, -1.0f, -1.0f,
    0.0f, 0.0f,  0.0f,  0.0f,   // padding, not part of A^T
    0.0f, 0.0f,  0.0f,  0.0f    // padding, not part of A^T
};
29
// Returns element (i, j) of the Winograd F(2x2, 3x3) input transform
// V = B^T d B for a 4x4 input patch d, with 0 <= i, j < 4.
//
// Fix: the input transform uses B^T, but this helper multiplied by
// winograd_AT (the output-transform matrix) by mistake.
__device__ __forceinline__ float winograd_transform_input(float d[4][4], int i, int j) {
    float result = 0.0f;
    for (int ri = 0; ri < 4; ++ri) {
        for (int rj = 0; rj < 4; ++rj) {
            // (B^T d B)[i][j] = sum_{ri,rj} BT[i][ri] * d[ri][rj] * BT[j][rj]
            result += winograd_BT[i * 4 + ri] * d[ri][rj] * winograd_BT[j * 4 + rj];
        }
    }
    return result;
}
39
// Returns element (i, j) of the Winograd F(2x2, 3x3) weight transform
// U = G g G^T for a 3x3 filter g, with 0 <= i, j < 4. Only the first
// three columns of each winograd_G row are read (G is logically 4x3).
__device__ __forceinline__ float winograd_transform_weight(float g[3][3], int i, int j) {
    float acc = 0.0f;
    for (int kr = 0; kr < 3; ++kr) {
        // Hoist the row factor G[i][kr]; the inner sum applies G[j][kc].
        const float gi = winograd_G[i * 4 + kr];
        for (int kc = 0; kc < 3; ++kc) {
            acc += gi * g[kr][kc] * winograd_G[j * 4 + kc];
        }
    }
    return acc;
}
49
// Winograd F(2x2, 3x3) convolution kernel (stride 1, dilation 1, pad 1).
//
// Grid:  gridDim.x must cover ceil(out_h / 2) * ceil(out_w / 2) blocks,
//        one per 2x2 output tile position (matches the launcher's tiling).
// Block: any shape; the flattened threads stride over (batch, out_ch)
//        pairs for the block's tile. The launcher uses dim3(4, 4).
// Shared memory: none required.
//
// Per tile, each input channel c contributes U ⊙ V in the transform
// domain, where V = B^T d B (4x4 input patch, implicit zero-padding of 1)
// and U = G g G^T (3x3 filter); the 2x2 output is Y = A^T (Σ_c U ⊙ V) A.
// Output elements are written (not accumulated), so `output` does not need
// to be zeroed beforehand.
//
// Fixes vs. the previous version: `batch` was used as a batch *index* in
// the input offset (out-of-bounds reads, and only one image processed);
// tile decoding assumed 4-wide tiles while the launcher tiles by 2; the
// filter transform applied only G (with a bogus `j % 3` wrap) instead of
// G g G^T; per-channel products were not accumulated; the output transform
// summed all 16 A^T entries instead of computing the 2x2 A^T M A; the
// output index omitted the batch dimension; and 15 of 16 threads' results
// were discarded (only thread (0,0) wrote, via atomicAdd).
__global__ void winograd_conv_kernel(const float* __restrict__ input,
                                     const float* __restrict__ weight,
                                     float* __restrict__ output,
                                     int batch, int in_ch, int out_ch,
                                     int out_h, int out_w,
                                     int in_h, int in_w) {
    const int tiles_w = (out_w + 1) / 2;   // 2x2 output tiles per row
    const int tiles_h = (out_h + 1) / 2;
    const int tile_idx = blockIdx.x;
    if (tile_idx >= tiles_h * tiles_w) return;

    const int out_row0 = (tile_idx / tiles_w) * 2;  // tile's top-left output coord
    const int out_col0 = (tile_idx % tiles_w) * 2;

    const int lane = threadIdx.y * blockDim.x + threadIdx.x;
    const int nthreads = blockDim.x * blockDim.y;

    // Each thread handles a strided subset of (batch, out_channel) pairs.
    for (int work = lane; work < batch * out_ch; work += nthreads) {
        const int n  = work / out_ch;
        const int oc = work % out_ch;

        // M = Σ_c (G g G^T) ⊙ (B^T d B), accumulated in the transform domain.
        float M[4][4] = {};

        for (int c = 0; c < in_ch; ++c) {
            // Load the 4x4 input patch whose top-left sits at
            // (out_row0 - 1, out_col0 - 1): the kernel assumes a zero
            // padding of 1 on each side (enforced by the launcher).
            float d[4][4];
            for (int dy = 0; dy < 4; ++dy) {
                const int in_row = out_row0 + dy - 1;
                for (int dx = 0; dx < 4; ++dx) {
                    const int in_col = out_col0 + dx - 1;
                    const bool inside = in_row >= 0 && in_row < in_h &&
                                        in_col >= 0 && in_col < in_w;
                    d[dy][dx] = inside
                        ? input[((n * in_ch + c) * in_h + in_row) * in_w + in_col]
                        : 0.0f;
                }
            }

            // V = B^T d B, written out explicitly for F(2,3):
            //   B^T rows: [1 0 -1 0], [0 1 1 0], [0 -1 1 0], [0 1 0 -1].
            float t[4][4];  // t = B^T d (transform down the columns)
            for (int j = 0; j < 4; ++j) {
                t[0][j] = d[0][j] - d[2][j];
                t[1][j] = d[1][j] + d[2][j];
                t[2][j] = d[2][j] - d[1][j];
                t[3][j] = d[1][j] - d[3][j];
            }
            float V[4][4];  // V = t B (same transform across the rows)
            for (int i = 0; i < 4; ++i) {
                V[i][0] = t[i][0] - t[i][2];
                V[i][1] = t[i][1] + t[i][2];
                V[i][2] = t[i][2] - t[i][1];
                V[i][3] = t[i][1] - t[i][3];
            }

            // U = G g G^T with G = [1 0 0; .5 .5 .5; .5 -.5 .5; 0 0 1].
            const float* g = &weight[(oc * in_ch + c) * 9];  // 3x3 filter, row-major
            float u[4][3];  // u = G g (transform down the columns)
            for (int j = 0; j < 3; ++j) {
                const float g0 = g[0 * 3 + j];
                const float g1 = g[1 * 3 + j];
                const float g2 = g[2 * 3 + j];
                u[0][j] = g0;
                u[1][j] = 0.5f * (g0 + g1 + g2);
                u[2][j] = 0.5f * (g0 - g1 + g2);
                u[3][j] = g2;
            }
            // U = u G^T across the rows, fused with the channel accumulation.
            for (int i = 0; i < 4; ++i) {
                const float u0 = u[i][0], u1 = u[i][1], u2 = u[i][2];
                M[i][0] += u0 * V[i][0];
                M[i][1] += 0.5f * (u0 + u1 + u2) * V[i][1];
                M[i][2] += 0.5f * (u0 - u1 + u2) * V[i][2];
                M[i][3] += u2 * V[i][3];
            }
        }

        // Y = A^T M A with A^T = [1 1 1 0; 0 1 -1 -1]: 2x2 output values.
        float r[2][4];  // r = A^T M
        for (int j = 0; j < 4; ++j) {
            r[0][j] = M[0][j] + M[1][j] + M[2][j];
            r[1][j] = M[1][j] - M[2][j] - M[3][j];
        }
        for (int i = 0; i < 2; ++i) {
            const int row = out_row0 + i;
            if (row >= out_h) continue;  // clip tiles on the bottom edge
            const float y0 = r[i][0] + r[i][1] + r[i][2];
            const float y1 = r[i][1] - r[i][2] - r[i][3];
            const int base = ((n * out_ch + oc) * out_h + row) * out_w;
            output[base + out_col0] = y0;
            if (out_col0 + 1 < out_w) {      // clip tiles on the right edge
                output[base + out_col0 + 1] = y1;
            }
        }
    }
}
147
// Launches the Winograd F(2x2, 3x3) convolution when both the config asks
// for it and the parameters match what the kernel implements; every other
// configuration is routed to the implicit-GEMM fallback.
//
// The Winograd kernel hard-codes a 3x3 filter with stride 1, dilation 1,
// and padding 1, so those are now validated here instead of only checking
// the filter size (previously, e.g. a stride-2 convolution would silently
// take the Winograd path and produce wrong results).
//
// Throws std::invalid_argument on null pointers or non-positive dimensions.
void conv2d_winograd(const float* input, const float* weight, float* output,
                     const ConvParams& params,
                     const WinogradConfig& config,
                     cudaStream_t stream) {
    if (input == nullptr || weight == nullptr || output == nullptr) {
        throw std::invalid_argument("conv2d_winograd expects non-null input, weight, and output pointers");
    }
    if (params.batch <= 0 || params.in_channels <= 0 || params.out_channels <= 0) {
        throw std::invalid_argument("conv2d_winograd expects positive batch/channel dimensions");
    }
    if (params.in_height <= 0 || params.in_width <= 0) {
        throw std::invalid_argument("conv2d_winograd expects positive spatial dimensions");
    }

    // F(2x2, 3x3) applies only to 3x3 / stride-1 / dilation-1 / pad-1.
    const bool winograd_applicable =
        params.kernel_h == 3 && params.kernel_w == 3 &&
        params.stride_h == 1 && params.stride_w == 1 &&
        params.dilation_h == 1 && params.dilation_w == 1 &&
        params.pad_h == 1 && params.pad_w == 1;

    if (!config.use_winograd || !winograd_applicable) {
        conv2d_winograd_fallback(input, weight, output, params, stream);
        return;
    }

    const int out_h = (params.in_height + 2 * params.pad_h - params.dilation_h * (params.kernel_h - 1) - 1) / params.stride_h + 1;
    const int out_w = (params.in_width + 2 * params.pad_w - params.dilation_w * (params.kernel_w - 1) - 1) / params.stride_w + 1;

    // One block per 2x2 output tile; the 16 threads of each block stride
    // over (batch, out_channel) pairs inside the kernel.
    const int tiles_h = (out_h + 1) / 2;
    const int tiles_w = (out_w + 1) / 2;
    const int num_tiles = tiles_h * tiles_w;

    dim3 block(4, 4);
    dim3 grid(num_tiles);

    // The kernel keeps all transforms in registers; no dynamic shared
    // memory is needed (the previous 48-float allocation was never used).
    winograd_conv_kernel<<<grid, block, 0, stream>>>(
        input, weight, output,
        params.batch, params.in_channels, params.out_channels,
        out_h, out_w,
        params.in_height, params.in_width);
    CUDA_CHECK_LAST();
}
185
// Fallback path for configurations the Winograd kernel does not handle
// (or when the config disables Winograd): delegates directly to the
// float implicit-GEMM convolution with the same argument semantics.
void conv2d_winograd_fallback(const float* input, const float* weight, float* output,
                              const ConvParams& params,
                              cudaStream_t stream) {
    conv2d_implicit_gemm<float>(input, weight, output, params, stream);
}
191
192} // namespace hpc::convolution