#include "conv_implicit_gemm.cuh"

#include <cstdint>

#include "../common/cuda_check.cuh"
namespace hpc::convolution {

namespace {

// Output extent along one spatial axis:
//   floor((in + 2*pad - dilation*(kernel-1) - 1) / stride) + 1,
// clamped to 0 when the dilated kernel does not fit inside the padded input.
// Assumes stride > 0 (not validated here).
int conv_out_dim(int in, int pad, int dilation, int kernel, int stride) {
    const int numer = in + 2 * pad - dilation * (kernel - 1) - 1;
    // C++ '/' truncates toward zero, so e.g. (-1) / 2 == 0 would wrongly yield an
    // extent of 1. A negative numerator means there is no valid output position.
    return numer < 0 ? 0 : numer / stride + 1;
}

}  // namespace

// Direct 2-D convolution: each loop iteration produces one output element by
// accumulating (in float) over all input channels and kernel taps.
// Despite the "implicit GEMM" name, there is no tiling or shared-memory staging.
//
// Layouts, as implied by the index arithmetic:
//   input  : [batch][in_c][in_h][in_w]
//   weight : [out_c][in_c][k_h][k_w]
//   output : [batch][out_c][out_h][out_w]
// Input positions that fall into the padding contribute zero.
//
// Launch: 1-D grid of 1-D blocks of any size. The grid-stride loop covers all
// batch*out_c*out_h*out_w outputs, so correctness does not depend on the grid
// size. Uses no shared memory.
// dilation_h / dilation_w default to 1 so callers that predate dilation support
// keep their original behavior.
template <typename T>
__global__ void conv2d_implicit_gemm_kernel(const T* __restrict__ input,
                                            const T* __restrict__ weight,
                                            T* __restrict__ output,
                                            int batch, int in_c, int out_c,
                                            int in_h, int in_w,
                                            int out_h, int out_w,
                                            int k_h, int k_w,
                                            int stride_h, int stride_w,
                                            int pad_h, int pad_w,
                                            int dilation_h = 1, int dilation_w = 1) {
    // 64-bit flat indices: batch*channels*H*W can exceed INT_MAX for large tensors.
    const int64_t out_plane = static_cast<int64_t>(out_h) * out_w;
    const int64_t total_out = static_cast<int64_t>(batch) * out_c * out_plane;
    const int64_t in_plane = static_cast<int64_t>(in_h) * in_w;
    const int64_t k_plane = static_cast<int64_t>(k_h) * k_w;
    const int64_t grid_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t out_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         out_idx < total_out; out_idx += grid_stride) {
        // Decompose the flat output index into (b, oc, oh, ow).
        const int ow = static_cast<int>(out_idx % out_w);
        const int oh = static_cast<int>((out_idx / out_w) % out_h);
        const int oc = static_cast<int>((out_idx / out_plane) % out_c);
        const int64_t b = out_idx / (out_plane * out_c);

        // Input coordinate of this output's first kernel tap (may lie in padding).
        const int ih0 = oh * stride_h - pad_h;
        const int iw0 = ow * stride_w - pad_w;

        // Loop-invariant base pointers for this (b, oc).
        const T* in_batch = input + b * in_c * in_plane;
        const T* w_out = weight + static_cast<int64_t>(oc) * in_c * k_plane;

        float sum = 0.0f;
        for (int ic = 0; ic < in_c; ++ic) {
            const T* in_chan = in_batch + ic * in_plane;
            const T* w_chan = w_out + ic * k_plane;
            for (int kh = 0; kh < k_h; ++kh) {
                const int ih = ih0 + kh * dilation_h;
                if (ih < 0 || ih >= in_h) continue;  // whole tap row is padding
                for (int kw = 0; kw < k_w; ++kw) {
                    const int iw = iw0 + kw * dilation_w;
                    if (iw < 0 || iw >= in_w) continue;  // padding column
                    sum += static_cast<float>(in_chan[static_cast<int64_t>(ih) * in_w + iw]) *
                           static_cast<float>(w_chan[kh * k_w + kw]);
                }
            }
        }
        output[out_idx] = static_cast<T>(sum);
    }
}

// Host launcher for the float specialization.
//
// input / weight / output: device pointers laid out as documented on the kernel.
// output must hold batch * out_channels * out_h * out_w elements, where out_h and
// out_w are computed from p by conv_out_dim (dilation-aware, clamped to 0).
// p: uses batch, in_/out_channels, in_height/in_width, kernel_*, stride_*,
//    pad_*, dilation_*.
// stream: the kernel is enqueued on this stream. The function does not
//         synchronize, so output is ready only after the stream completes.
template <>
void conv2d_implicit_gemm<float>(const float* input, const float* weight, float* output,
                                 const ConvParams& p, cudaStream_t stream) {
    const int out_h = conv_out_dim(p.in_height, p.pad_h, p.dilation_h, p.kernel_h, p.stride_h);
    const int out_w = conv_out_dim(p.in_width, p.pad_w, p.dilation_w, p.kernel_w, p.stride_w);
    const int64_t total = static_cast<int64_t>(p.batch) * p.out_channels * out_h * out_w;

    // Nothing to compute. A grid of 0 blocks would be an invalid launch configuration.
    if (total <= 0) return;

    constexpr int block_size = 256;
    // gridDim.x limit (2^31 - 1). The kernel's grid-stride loop covers any remainder.
    constexpr int64_t max_grid_x = 2147483647;
    const int64_t blocks_needed = (total + block_size - 1) / block_size;
    const int grid_size =
        static_cast<int>(blocks_needed < max_grid_x ? blocks_needed : max_grid_x);

    conv2d_implicit_gemm_kernel<float><<<grid_size, block_size, 0, stream>>>(
        input, weight, output,
        p.batch, p.in_channels, p.out_channels,
        p.in_height, p.in_width, out_h, out_w,
        p.kernel_h, p.kernel_w,
        p.stride_h, p.stride_w,
        p.pad_h, p.pad_w,
        p.dilation_h, p.dilation_w);
    // Launch-configuration errors surface only through cudaGetLastError().
    // NOTE(review): CUDA_CHECK is assumed to come from ../common/cuda_check.cuh — confirm.
    CUDA_CHECK(cudaGetLastError());
}

}  // namespace hpc::convolution