HPC-AI-Optimization-Lab 1.0.0
High-Performance CUDA Kernels for AI/ML Workloads
Loading...
Searching...
No Matches
conv_implicit_gemm.cu
Go to the documentation of this file.
1#include "conv_implicit_gemm.cuh"
2#include "../common/cuda_check.cuh"
3
namespace hpc::convolution {

/**
 * @brief Direct 2D convolution kernel; each thread produces whole output elements.
 *
 * Indexing in the body fixes the layouts:
 *   - input  is addressed as [b][ic][ih][iw]
 *   - weight is addressed as [oc][ic][kh][kw]
 *   - output is addressed as [b][oc][oh][ow]
 *
 * Launch layout: 1D grid of 1D blocks. The kernel walks the flattened output
 * with a grid-stride loop, so any grid size >= 1 covers the whole output.
 * No shared memory is used. Products are accumulated in float and converted
 * to T on store.
 *
 * Taps that fall outside the input (the padding region) contribute zero.
 *
 * @param input       Input tensor (read-only).
 * @param weight      Filter tensor (read-only).
 * @param output      Output tensor; every element in [0, batch*out_c*out_h*out_w) is written.
 * @param batch       Number of images.
 * @param in_c        Input channels.
 * @param out_c       Output channels.
 * @param in_h,in_w   Input spatial size.
 * @param out_h,out_w Output spatial size (computed by the caller).
 * @param k_h,k_w     Filter spatial size.
 * @param stride_h,stride_w     Step between adjacent output positions, in input pixels.
 * @param pad_h,pad_w           Implicit zero padding on each side of the input.
 * @param dilation_h,dilation_w Spacing between filter taps, in input pixels (1 = dense filter).
 */
template <typename T>
__global__ void conv2d_implicit_gemm_kernel(const T* __restrict__ input,
                                            const T* __restrict__ weight,
                                            T* __restrict__ output,
                                            int batch, int in_c, int out_c,
                                            int in_h, int in_w,
                                            int out_h, int out_w,
                                            int k_h, int k_w,
                                            int stride_h, int stride_w,
                                            int pad_h, int pad_w,
                                            int dilation_h = 1, int dilation_w = 1) {
    // All flattened offsets are 64-bit: the element counts of large tensors do
    // not fit in a 32-bit int.
    const long long out_plane = static_cast<long long>(out_h) * out_w;
    const long long total_out = static_cast<long long>(batch) * out_c * out_plane;
    const long long step      = static_cast<long long>(gridDim.x) * blockDim.x;

    for (long long out_idx = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
         out_idx < total_out;
         out_idx += step) {
        // Unflatten out_idx into (b, oc, oh, ow).
        const int ow = static_cast<int>(out_idx % out_w);
        const int oh = static_cast<int>((out_idx / out_w) % out_h);
        const int oc = static_cast<int>((out_idx / out_plane) % out_c);
        const int b  = static_cast<int>(out_idx / (out_plane * out_c));

        // Input coordinate hit by filter tap (0, 0); may be negative inside the padding.
        const int ih_origin = oh * stride_h - pad_h;
        const int iw_origin = ow * stride_w - pad_w;

        float sum = 0.0f;

        for (int ic = 0; ic < in_c; ++ic) {
            const long long in_base = (static_cast<long long>(b) * in_c + ic) * in_h * in_w;
            const long long w_base  = (static_cast<long long>(oc) * in_c + ic) * k_h * k_w;

            for (int kh = 0; kh < k_h; ++kh) {
                const int ih = ih_origin + kh * dilation_h;
                if (ih < 0 || ih >= in_h) continue;  // whole filter row lies in the padding

                for (int kw = 0; kw < k_w; ++kw) {
                    const int iw = iw_origin + kw * dilation_w;
                    if (iw < 0 || iw >= in_w) continue;  // tap lies in the padding

                    const long long in_idx = in_base + static_cast<long long>(ih) * in_w + iw;
                    const long long w_idx  = w_base + kh * k_w + kw;
                    sum += static_cast<float>(input[in_idx]) * static_cast<float>(weight[w_idx]);
                }
            }
        }

        output[out_idx] = static_cast<T>(sum);
    }
}

/**
 * @brief Launches the float convolution kernel on @p stream.
 *
 * The output size per spatial dimension is
 *   (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1.
 * The same dilation is forwarded to the kernel so that the taps it reads match
 * the output size computed here.
 *
 * The launch is asynchronous with respect to the host; only launch errors are
 * checked (via CUDA_CHECK_LAST). If the output is empty -- non-positive batch
 * or channel count, or a dilated filter larger than the padded input -- the
 * function returns without launching anything.
 *
 * @param input  Device pointer to the input tensor.
 * @param weight Device pointer to the filter tensor.
 * @param output Device pointer to the output tensor.
 * @param p      Convolution geometry.
 * @param stream Stream on which the kernel is enqueued.
 */
template <>
void conv2d_implicit_gemm<float>(const float* input, const float* weight, float* output,
                                 const ConvParams& p, cudaStream_t stream) {
    // Distance between the first and last valid top-left filter position.
    // A negative value means the dilated filter does not fit even once; integer
    // division truncates toward zero, so this must be tested before dividing.
    const int span_h = p.in_height + 2 * p.pad_h - p.dilation_h * (p.kernel_h - 1) - 1;
    const int span_w = p.in_width + 2 * p.pad_w - p.dilation_w * (p.kernel_w - 1) - 1;
    if (span_h < 0 || span_w < 0 || p.batch <= 0 || p.out_channels <= 0) {
        return;  // empty output: a zero-block launch would be an invalid configuration
    }

    const int out_h = span_h / p.stride_h + 1;
    const int out_w = span_w / p.stride_w + 1;
    const long long total = static_cast<long long>(p.batch) * p.out_channels * out_h * out_w;

    constexpr int block_size = 256;
    // Maximum gridDim.x; the kernel's grid-stride loop covers any remainder.
    constexpr long long max_grid_x = 2147483647LL;
    const long long wanted_blocks = (total + block_size - 1) / block_size;
    const unsigned int grid_size =
        static_cast<unsigned int>(wanted_blocks < max_grid_x ? wanted_blocks : max_grid_x);

    conv2d_implicit_gemm_kernel<float><<<grid_size, block_size, 0, stream>>>(
        input, weight, output,
        p.batch, p.in_channels, p.out_channels,
        p.in_height, p.in_width, out_h, out_w,
        p.kernel_h, p.kernel_w,
        p.stride_h, p.stride_w,
        p.pad_h, p.pad_w,
        p.dilation_h, p.dilation_w);
    CUDA_CHECK_LAST();
}

} // namespace hpc::convolution