reference_grouped_conv_bwd_data.hpp Source File

reference_grouped_conv_bwd_data.hpp Source File#

Composable Kernel: reference_grouped_conv_bwd_data.hpp Source File
reference_grouped_conv_bwd_data.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <cstdlib>
7#include <thread>
8
9#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14template <ck_tile::index_t NDimSpatial,
15 typename InDataType,
16 typename WeiDataType,
17 typename OutDataType>
19 const HostTensor<WeiDataType>& weight,
20 const HostTensor<OutDataType>& output,
21 std::vector<ck_tile::long_index_t> conv_strides,
22 std::vector<ck_tile::long_index_t> conv_dilations,
23 std::vector<ck_tile::long_index_t> in_left_pads,
24 std::vector<ck_tile::long_index_t>)
25{
26 if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
27 weight.get_num_of_dimension() == NDimSpatial + 3 &&
28 output.get_num_of_dimension() == NDimSpatial + 3))
29 {
30
31 printf("%lu %lu %lu",
33 weight.get_num_of_dimension(),
34 output.get_num_of_dimension());
35
36 throw std::runtime_error("wrong! inconsistent dimension");
37 }
38
39 if constexpr(NDimSpatial == 1)
40 {
41 auto func = [&](auto g, auto n, auto c, auto wi) {
42 std::size_t K = weight.get_lengths()[1];
43 std::size_t X = weight.get_lengths()[3];
44
45 std::size_t Wo = output.get_lengths()[3];
46 float v_acc = 0;
47
48 for(std::size_t x = 0; x < X; ++x)
49 {
50 auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
51 static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
52 static_cast<ck_tile::long_index_t>(x * conv_dilations[0]);
53
54 if(w_tmp % conv_strides[0] == 0)
55 {
56 auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
57 static_cast<ck_tile::long_index_t>(conv_strides[0]);
58
59 if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
60 {
61 for(std::size_t k = 0; k < K; ++k)
62 {
63 OutDataType v_out = output(g, n, k, wo);
64 WeiDataType v_wei = weight(g, k, c, x);
65 v_acc += ck_tile::type_convert<float>(v_out) *
67 }
68 }
69 }
70 }
71 InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
72 input(g, n, c, wi) = v_acc_converted;
73 };
74
76 input.get_lengths()[0],
77 input.get_lengths()[1],
78 input.get_lengths()[2],
79 input.get_lengths()[3])(std::thread::hardware_concurrency());
80 }
81 else if constexpr(NDimSpatial == 2)
82 {
83 auto func = [&](auto g, auto n, auto c, auto hi, auto wi) {
84 std::size_t K = weight.get_lengths()[1];
85 std::size_t Y = weight.get_lengths()[3];
86 std::size_t X = weight.get_lengths()[4];
87
88 std::size_t Ho = output.get_lengths()[3];
89 std::size_t Wo = output.get_lengths()[4];
90
91 float v_acc = 0;
92
93 for(std::size_t y = 0; y < Y; ++y)
94 {
95 auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
96 static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
97 static_cast<ck_tile::long_index_t>(y * conv_dilations[0]);
98 if(h_tmp % conv_strides[0] == 0)
99 {
100 auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
101 static_cast<ck_tile::long_index_t>(conv_strides[0]);
102 if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
103 {
104 for(std::size_t x = 0; x < X; ++x)
105 {
106 auto w_tmp = static_cast<ck_tile::long_index_t>(wi) +
107 static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
108 static_cast<ck_tile::long_index_t>(x * conv_dilations[1]);
109 if(w_tmp % conv_strides[1] == 0)
110 {
111 auto wo = static_cast<ck_tile::long_index_t>(w_tmp) /
112 static_cast<ck_tile::long_index_t>(conv_strides[1]);
113
114 if(wo >= 0 && ck_tile::type_convert<std::size_t>(wo) < Wo)
115 {
116 for(std::size_t k = 0; k < K; ++k)
117 {
118 OutDataType v_out = output(g, n, k, ho, wo);
119 WeiDataType v_wei = weight(g, k, c, y, x);
120 v_acc += ck_tile::type_convert<float>(v_out) *
122 }
123 }
124 }
125 }
126 }
127 }
128 }
129 InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
130 input(g, n, c, hi, wi) = v_acc_converted;
131 };
132
134 input.get_lengths()[0],
135 input.get_lengths()[1],
136 input.get_lengths()[2],
137 input.get_lengths()[3],
138 input.get_lengths()[4])(std::thread::hardware_concurrency());
139 }
140 else if constexpr(NDimSpatial == 3)
141 {
142 auto func = [&](auto g, auto n, auto c, auto di, auto hi, auto wi) {
143 std::size_t K = weight.get_lengths()[1];
144 std::size_t Z = weight.get_lengths()[3];
145 std::size_t Y = weight.get_lengths()[4];
146 std::size_t X = weight.get_lengths()[5];
147
148 std::size_t Do = output.get_lengths()[3];
149 std::size_t Ho = output.get_lengths()[4];
150 std::size_t Wo = output.get_lengths()[5];
151
152 float v_acc = 0;
153
154 for(std::size_t z = 0; z < Z; ++z)
155 {
156 auto d_tmp = static_cast<ck_tile::long_index_t>(di) +
157 static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
158 static_cast<ck_tile::long_index_t>(z * conv_dilations[0]);
159 if(d_tmp % conv_strides[0] == 0)
160 {
161 auto do_ = static_cast<ck_tile::long_index_t>(d_tmp) /
162 static_cast<ck_tile::long_index_t>(conv_strides[0]);
163 if(do_ >= 0 && ck_tile::type_convert<std::size_t>(do_) < Do)
164 {
165 for(std::size_t y = 0; y < Y; ++y)
166 {
167 auto h_tmp = static_cast<ck_tile::long_index_t>(hi) +
168 static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
169 static_cast<ck_tile::long_index_t>(y * conv_dilations[1]);
170 if(h_tmp % conv_strides[1] == 0)
171 {
172 auto ho = static_cast<ck_tile::long_index_t>(h_tmp) /
173 static_cast<ck_tile::long_index_t>(conv_strides[1]);
174 if(ho >= 0 && ck_tile::type_convert<std::size_t>(ho) < Ho)
175 {
176 for(std::size_t x = 0; x < X; ++x)
177 {
178 auto w_tmp =
179 static_cast<ck_tile::long_index_t>(wi) +
180 static_cast<ck_tile::long_index_t>(in_left_pads[2]) -
181 static_cast<ck_tile::long_index_t>(x *
182 conv_dilations[2]);
183
184 if(w_tmp % conv_strides[2] == 0)
185 {
186 auto wo =
187 static_cast<ck_tile::long_index_t>(w_tmp) /
188 static_cast<ck_tile::long_index_t>(conv_strides[2]);
189 if(wo >= 0 &&
191 {
192 for(std::size_t k = 0; k < K; ++k)
193 {
194 OutDataType v_out =
195 output(g, n, k, do_, ho, wo);
196 WeiDataType v_wei = weight(g, k, c, z, y, x);
197 v_acc += ck_tile::type_convert<float>(v_out) *
199 }
200 }
201 }
202 }
203 }
204 }
205 }
206 }
207 }
208 }
209 InDataType v_acc_converted = ck_tile::type_convert<InDataType>(v_acc);
210 input(g, n, c, di, hi, wi) = v_acc_converted;
211 };
212
214 input.get_lengths()[0],
215 input.get_lengths()[1],
216 input.get_lengths()[2],
217 input.get_lengths()[3],
218 input.get_lengths()[4],
219 input.get_lengths()[5])(std::thread::hardware_concurrency());
220 }
221 else
222 {
223 throw std::runtime_error(
224 "Ref_conv_bwd_data: number of dimensions must be between 1 and 3.");
225 }
226}
227} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
int64_t long_index_t
Definition integer.hpp:11
CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor< InDataType > &input, const HostTensor< WeiDataType > &weight, const HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >)
Definition reference_grouped_conv_bwd_data.hpp:18
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Definition tile/host/host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition tile/host/host_tensor.hpp:396