device_multiple_reduce_threadwise.hpp Source File

device_multiple_reduce_threadwise.hpp Source File

Composable Kernel: device_multiple_reduce_threadwise.hpp Source File
device_multiple_reduce_threadwise.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
11
16
18
19namespace ck {
20namespace tensor_operation {
21namespace device {
22
// Template header of the device-level operation that performs NumReduction
// reductions of one input tensor. The `struct ... : public ...<` line between
// the template parameter list and the base-class arguments is absent from this
// listing, so the class and base names are not visible here.
23template <index_t NumReduction,
24 typename InDataType,
25 typename AccDataType,
26 typename OutDataTypeTuple,
27 index_t Rank,
28 index_t NumReduceDim,
29 typename ReduceOperation,
30 typename InElementwiseOperationTuple,
31 typename AccElementwiseOperationTuple,
32 bool PropagateNan,
33 index_t BlockSize,
34 index_t MThreadSliceSize,
35 index_t KThreadSliceSize,
36 index_t InSrcVectorDim,
37 index_t InSrcVectorSize,
38 typename OutDstVectorSizeSeq>
40 NumReduceDim,
41 NumReduction,
42 InElementwiseOperationTuple,
43 AccElementwiseOperationTuple>
44{
// Compile-time configuration checks. Tensors of more than 6 dimensions are rejected.
45 static_assert(Rank <= 6, "Bigger Rank size is not supported!");
46
// InSrcVectorDim must be 0 or 1, and InSrcVectorSize must evenly divide the
// thread slice size of the selected dimension (0 -> MThreadSliceSize,
// 1 -> KThreadSliceSize).
47 static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
48 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
49 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
50
// Every per-reduction tuple/sequence must contain exactly NumReduction entries.
51 static_assert(NumReduction == OutDataTypeTuple::Size() &&
52 NumReduction == InElementwiseOperationTuple::Size() &&
53 NumReduction == AccElementwiseOperationTuple::Size() &&
54 NumReduction == OutDstVectorSizeSeq::Size(),
55 "All tuple should have the same size as the number of Reductions!");
56
// Each output vector size must evenly divide MThreadSliceSize.
57 static_assert(sequence_all_of(OutDstVectorSizeSeq{},
58 [](auto vectorSize) {
59 return (MThreadSliceSize % vectorSize == 0);
60 }),
61 "The OutDstVectorSize should completely divide the MThreadSliceSize!");
62
// Number of dimensions that are kept (not reduced).
63 static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
64
// When every dimension is reduced (NumInvariantDim == 0) the output is still
// described with one dimension, and reduceAllDim is set.
65 static constexpr index_t NumInputDim = Rank;
66 static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
67 static constexpr bool reduceAllDim = (NumInvariantDim == 0);
68
// Tile extents used for padding: BlockSize * MThreadSliceSize along M and
// KThreadSliceSize along K.
69 static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
70 static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
71
73 {
// Body of GenerateOutDataTypePointerTuple(); its declaration line and the
// element-count argument of generate_tuple are absent from this listing.
// Each generated element is a null pointer typed as a pointer to the
// cv/ref-stripped I-th type of OutDataTypeTuple. The cross-reference index of
// this listing shows the function being used through decltype to name
// OutDataTypePointerTuple.
// NOTE(review): presumably one element per reduction (NumReduction) -- confirm
// against the original header.
74 return generate_tuple(
75 [&](auto I) {
76 using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
77
78 return static_cast<DataType*>(nullptr);
79 },
81 };
82
84
// Builds the padded two-dimensional (M, K) descriptor of the input tensor.
//   inLengths / inStrides: per-dimension lengths and strides of the input.
// Returns a descriptor whose dimension 0 (invariant) and dimension 1 (reduce)
// are right-padded up to multiples of M_BlockTileSize and K_BlockTileSize.
// Several lines of this function are absent from the listing; the comments
// below describe only what is visible.
85 static auto MakeSrc2dDescriptor(const std::array<index_t, NumInputDim>& inLengths,
86 const std::array<index_t, NumInputDim>& inStrides)
87 {
// Copy the runtime arrays into tuples so they can feed the descriptor helpers.
88 const auto tupleSrcLengths =
89 generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInputDim>{});
90 const auto tupleSrcStrides =
91 generate_tuple([&](auto I) { return inStrides[I]; }, Number<NumInputDim>{});
92
93 const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
94
// Collapse the N-d descriptor to 2-D. With reduceAllDim all dimensions are
// merged into one and then transformed again (the visible fragment passes
// `1` and the merged length; the transform itself is on a missing line --
// presumably an unmerge into (1, total), TODO confirm). Otherwise the first
// NumInvariantDim lengths are merged into M and the following NumReduceDim
// lengths are merged into K.
95 const auto in_grid_desc_m_k = [&]() {
96 if constexpr(reduceAllDim)
97 {
98 const auto one_dim_inDesc = transform_tensor_descriptor(
99 inDesc,
100 make_tuple(make_merge_transform(tupleSrcLengths)),
103
104 return transform_tensor_descriptor(one_dim_inDesc,
106 1, one_dim_inDesc.GetLength(Number<0>{})))),
109 }
110 else
111 {
112 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
114
115 const auto reduceDimLengths = generate_tuple(
116 [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
117 const auto invariantDimLengths =
118 generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
119
121 inDesc,
122 make_tuple(make_merge_transform(invariantDimLengths),
123 make_merge_transform(reduceDimLengths)),
124 make_tuple(InvariantDims{}, ReduceDims{}),
126 }
127 }();
128
129 const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
130 const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
131
// Padding needed to round each extent up to the next multiple of its tile size.
132 const auto inPad_M =
133 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
134 const auto inPad_K =
135 math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
136
137 auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
138 in_grid_desc_m_k,
139 make_tuple(make_right_pad_transform(invariantLength, inPad_M),
140 make_right_pad_transform(reduceLength, inPad_K)),
143
144 return (in_grid_desc_m_k_padded);
145 };
146
// Builds the padded one-dimensional (M) descriptor of one output tensor.
//   outLengths / outStrides: per-dimension lengths and strides of the output.
// All output dimensions are merged into a single dimension, which is then
// right-padded up to a multiple of M_BlockTileSize. The dimension-mapping
// arguments of both transform calls are on lines absent from this listing.
147 static auto MakeDst1dDescriptor(const std::array<index_t, NumOutputDim>& outLengths,
148 const std::array<index_t, NumOutputDim>& outStrides)
149 {
// Copy the runtime arrays into tuples so they can feed the descriptor helpers.
150 const auto tupleDstLengths =
151 generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
152 const auto tupleDstStrides =
153 generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});
154
155 auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
156
157 auto out_grid_desc_m = transform_tensor_descriptor(
158 outDesc,
159 make_tuple(make_merge_transform(tupleDstLengths)),
162
163 const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
164
// Padding needed to round the merged length up to a multiple of M_BlockTileSize.
165 const auto outPad =
166 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
167
168 auto out_grid_desc_m_padded = transform_tensor_descriptor(
169 out_grid_desc_m,
170 make_tuple(make_right_pad_transform(invariantLength, outPad)),
173 return (out_grid_desc_m_padded);
174 };
175
177 {
// Body of GenerateOutGrid1dDescTuple(); its declaration line and the
// element-count argument of generate_tuple are absent from this listing.
// Every element is the result of MakeDst1dDescriptor called with
// value-initialized length/stride arrays; the index I is deliberately unused.
// The cross-reference index of this listing shows the function being used
// through decltype to name OutGridDesc_M_Tuple.
178 return generate_tuple(
179 [&](auto I) {
180 (void)I;
181 return MakeDst1dDescriptor(std::array<index_t, NumOutputDim>{},
182 std::array<index_t, NumOutputDim>{});
183 },
185 };
186
// Type of the padded 2-D input descriptor, deduced from MakeSrc2dDescriptor
// applied to value-initialized length/stride arrays (only the type is used).
187 using InGridDesc_M_K = decltype(MakeSrc2dDescriptor(std::array<index_t, NumInputDim>{},
188 std::array<index_t, NumInputDim>{}));
190
// Argument bundle for one invocation of the operation. It stores the
// output lengths/strides and the element-wise operation tuples, converts the
// alpha/beta scale factors from double to AccDataType, and keeps the typed
// input pointer. A large part of the constructor body and several member
// declarations are on lines absent from this listing (the cross-reference
// index mentions, among others, alpha_values_, beta_values_, out_dev_buffers_,
// in_grid_desc_m_k and out_grid_desc_m_tuple), so only the visible statements
// are described here.
191 struct Argument : public BaseArgument
192 {
193 Argument(const std::array<index_t, NumInputDim>& inLengths,
194 const std::array<index_t, NumInputDim>& inStrides,
195 const std::array<index_t, NumOutputDim>& outLengths,
196 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
197 const std::array<int, NumReduceDim>& reduceDims,
198 const std::array<double, NumReduction>& alphas,
199 const std::array<double, NumReduction>& betas,
200 const void* in_dev,
201 const std::array<void*, NumReduction>& out_dev_buffers,
202 const InElementwiseOperationTuple in_elementwise_op_tuple,
203 const AccElementwiseOperationTuple acc_elementwise_op_tuple)
204 : outLengths_{outLengths},
205 outStridesArray_{outStridesArray},
206 in_elementwise_op_tuple_{in_elementwise_op_tuple},
207 acc_elementwise_op_tuple_{acc_elementwise_op_tuple}
208 {
211
// Narrow the per-reduction scale factors from double to the accumulation type.
212 for(size_t i = 0; i < NumReduction; i++)
213 {
214 alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
215 beta_values_(i) = static_cast<AccDataType>(betas[i]);
216 };
217
// The untyped input pointer is reinterpreted as a pointer to InDataType.
218 in_dev_ = static_cast<const InDataType*>(in_dev);
219
221 [&](auto iR) {
222 using OutDataTypePointer =
225 return static_cast<OutDataType*>(out_dev_buffers[iR]);
226 },
228
231
233
235 [&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); },
237
240 }
241
242 std::array<index_t, NumInputDim> inLengths_;
243 std::array<index_t, NumInputDim> inStrides_;
244
245 std::array<index_t, NumOutputDim> outLengths_;
246 std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray_;
247
250
251 const InDataType* in_dev_;
253
256
257 InElementwiseOperationTuple in_elementwise_op_tuple_;
258 AccElementwiseOperationTuple acc_elementwise_op_tuple_;
259
262
// Passed as the grid dimension when the kernel is launched by Invoker::Run.
263 size_t gridSize;
264 };
265
// Launches the reduction kernel for a prepared Argument.
266 struct Invoker : public BaseInvoker
267 {
// Instantiates the gridwise reduction and the kernel entry point for this
// configuration, launches it through launch_and_time_kernel with a grid of
// arg.gridSize, a block of BlockSize and 0 as the fifth (lds_byte) argument,
// and returns the value reported by launch_and_time_kernel.
// Some template and kernel arguments are on lines absent from this listing.
268 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
269 {
270 using GridwiseMultipleReduce =
272 InDataType,
274 AccDataType,
277 ReduceOperation,
278 InElementwiseOperationTuple,
279 AccElementwiseOperationTuple,
281 PropagateNan,
282 BlockSize,
283 MThreadSliceSize,
284 KThreadSliceSize,
285 InSrcVectorDim,
286 InSrcVectorSize,
287 OutDstVectorSizeSeq>;
288
289 const auto kernel_main =
290 kernel_multiple_reduce_threadwise<GridwiseMultipleReduce,
291 NumReduction,
292 InDataType,
294 AccDataType,
297 InElementwiseOperationTuple,
298 AccElementwiseOperationTuple>;
299
300 float avg_time = 0;
301
302 avg_time += launch_and_time_kernel(stream_config,
303 kernel_main,
304 dim3(arg.gridSize),
305 dim3(BlockSize),
306 0,
311 arg.alpha_values_,
312 arg.in_dev_,
313 arg.beta_values_,
314 arg.out_dev_buffers_);
315
316 return (avg_time);
317 };
318
// Type-erased entry point: downcasts the base argument and forwards to the
// typed overload above.
// NOTE(review): the result of dynamic_cast is dereferenced without a null
// check, so passing a null pointer or an argument of a different type is
// undefined behavior -- consider validating before dereferencing.
319 float Run(const BaseArgument* p_arg,
320 const StreamConfig& stream_config = StreamConfig{}) override
321 {
322 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
323 };
324 };
325
// Reports whether this instance's vector-access configuration is compatible
// with the tensor layout described by p_arg.
// Returns false when a vectorized input or output access would not be valid
// for the given lengths/strides, true otherwise.
// NOTE(review): pArg is used without checking the dynamic_cast result; a null
// p_arg or an argument of another type leads to a null dereference. A
// `pArg == nullptr` guard returning false would make this robust.
326 bool IsSupportedArgument(const BaseArgument* p_arg) override
327 {
328 const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
329
// Input side. When vectorizing along dimension 0 (invariant), there must be at
// least one invariant dimension; the last invariant dimension must have
// stride 1 (unless InSrcVectorSize is 1) and a length divisible by
// InSrcVectorSize. Otherwise the same two checks apply to the last input
// dimension (index Rank - 1).
330 if constexpr(InSrcVectorDim == 0)
331 {
332 if constexpr(NumInvariantDim == 0)
333 {
334 return (false);
335 }
336 else
337 {
338 if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
339 return (false);
340
341 if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0)
342 return (false);
343 };
344 }
345 else
346 {
347 if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
348 return (false);
349
350 if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0)
351 return (false);
352 };
353
// Output side. For every reduction, the last output dimension must have
// stride 1 (unless that reduction's vector size is 1) and a length divisible
// by that reduction's vector size. The loop does not stop early; it only
// records the failure in `valid`.
354 // To improve
355 bool valid = true;
356 static_for<0, NumReduction, 1>{}([&](auto I) {
357 if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 &&
358 OutDstVectorSizeSeq::At(I) != 1)
359 valid = false;
360
361 if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0)
362 valid = false;
363 });
364
365 if(!valid)
366 return (false);
367
368 return (true);
369 };
370
// Factory for the type-erased argument object. Forwards every parameter
// unchanged, in the same order, to the Argument constructor and returns the
// result as an owning pointer to BaseArgument.
// The parameters are taken by value here, whereas the Argument constructor
// takes the arrays by const reference.
371 std::unique_ptr<BaseArgument> MakeArgumentPointer(
372 const std::array<index_t, NumInputDim> inLengths,
373 const std::array<index_t, NumInputDim> inStrides,
374 const std::array<index_t, NumOutputDim> outLengths,
375 const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
376 const std::array<int, NumReduceDim> reduceDims,
377 const std::array<double, NumReduction> alphas,
378 const std::array<double, NumReduction> betas,
379 const void* in_dev,
380 const std::array<void*, NumReduction> out_dev_buffers,
381 const InElementwiseOperationTuple in_elementwise_op_tuple,
382 const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
383 {
384 return std::make_unique<Argument>(inLengths,
385 inStrides,
386 outLengths,
387 outStridesArray,
388 reduceDims,
389 alphas,
390 betas,
391 in_dev,
392 out_dev_buffers,
393 in_elementwise_op_tuple,
394 acc_elementwise_op_tuple);
395 };
396
// Factory for the type-erased invoker: returns a default-constructed Invoker
// as an owning pointer to BaseInvoker.
397 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
398 {
399 return std::make_unique<Invoker>();
400 };
401
// Returns a textual description of this instance's compile-time configuration
// in the form
//   DeviceMultipleReduceThreadwise<BlockSize,M_C<BlockSize>_S<MThreadSliceSize>,
//   K_C1_S<KThreadSliceSize>,InSrcVectorDim_<d>_InSrcVectorSize_<s>,
//   OutDstVectorSize_<v0>_<v1>...>
// (emitted without line breaks), with one "_<v>" suffix per entry of
// OutDstVectorSizeSeq.
402 std::string GetTypeString() const override
403 {
404 auto str = std::stringstream();
405
406 // clang-format off
407 str << "DeviceMultipleReduceThreadwise<" << BlockSize << ",";
408 str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
409 str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
410 str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ",";
411 str << "OutDstVectorSize";
412 static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) {str << "_" << OutDstVectorSizeSeq::At(I); });
413 str << ">";
414 // clang-format on
415
416 return str.str();
417 }
418};
419
420} // namespace device
421} // namespace tensor_operation
422} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__global__ void kernel_multiple_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M_Tuple out_grid_desc_m_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple, Array< AccDataType, NumReduction > alpha_values, const InDataType *const __restrict__ p_in_value_global, Array< AccDataType, NumReduction > beta_values, OutDataTypePointerTuple p_out_value_global_tuple)
Definition gridwise_2d_multiple_reduction_threadwise.hpp:26
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
Definition utility/sequence.hpp:912
Definition ck/stream_config.hpp:10
Definition utility/array.hpp:14
Definition gridwise_2d_multiple_reduction_threadwise.hpp:63
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_multiple_reduce.hpp:25
Definition device_multiple_reduce_threadwise.hpp:192
InGridDesc_M_K in_grid_desc_m_k
Definition device_multiple_reduce_threadwise.hpp:254
long_index_t invariant_total_length
Definition device_multiple_reduce_threadwise.hpp:260
long_index_t reduce_total_length
Definition device_multiple_reduce_threadwise.hpp:261
Array< AccDataType, NumReduction > beta_values_
Definition device_multiple_reduce_threadwise.hpp:249
Array< AccDataType, NumReduction > alpha_values_
Definition device_multiple_reduce_threadwise.hpp:248
const InDataType * in_dev_
Definition device_multiple_reduce_threadwise.hpp:251
std::array< index_t, NumInputDim > inLengths_
Definition device_multiple_reduce_threadwise.hpp:242
std::array< index_t, NumInputDim > inStrides_
Definition device_multiple_reduce_threadwise.hpp:243
size_t gridSize
Definition device_multiple_reduce_threadwise.hpp:263
InElementwiseOperationTuple in_elementwise_op_tuple_
Definition device_multiple_reduce_threadwise.hpp:257
std::array< index_t, NumOutputDim > outLengths_
Definition device_multiple_reduce_threadwise.hpp:245
Argument(const std::array< index_t, NumInputDim > &inLengths, const std::array< index_t, NumInputDim > &inStrides, const std::array< index_t, NumOutputDim > &outLengths, const std::array< std::array< index_t, NumOutputDim >, NumReduction > &outStridesArray, const std::array< int, NumReduceDim > &reduceDims, const std::array< double, NumReduction > &alphas, const std::array< double, NumReduction > &betas, const void *in_dev, const std::array< void *, NumReduction > &out_dev_buffers, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple)
Definition device_multiple_reduce_threadwise.hpp:193
AccElementwiseOperationTuple acc_elementwise_op_tuple_
Definition device_multiple_reduce_threadwise.hpp:258
OutGridDesc_M_Tuple out_grid_desc_m_tuple
Definition device_multiple_reduce_threadwise.hpp:255
std::array< std::array< index_t, NumOutputDim >, NumReduction > outStridesArray_
Definition device_multiple_reduce_threadwise.hpp:246
OutDataTypePointerTuple out_dev_buffers_
Definition device_multiple_reduce_threadwise.hpp:252
Definition device_multiple_reduce_threadwise.hpp:267
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_multiple_reduce_threadwise.hpp:268
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_multiple_reduce_threadwise.hpp:319
Definition device_multiple_reduce_threadwise.hpp:44
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_multiple_reduce_threadwise.hpp:326
static constexpr index_t NumInputDim
Definition device_multiple_reduce_threadwise.hpp:65
static constexpr index_t NumInvariantDim
Definition device_multiple_reduce_threadwise.hpp:63
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_multiple_reduce_threadwise.hpp:397
static auto MakeSrc2dDescriptor(const std::array< index_t, NumInputDim > &inLengths, const std::array< index_t, NumInputDim > &inStrides)
Definition device_multiple_reduce_threadwise.hpp:85
static auto GenerateOutGrid1dDescTuple()
Definition device_multiple_reduce_threadwise.hpp:176
static auto MakeDst1dDescriptor(const std::array< index_t, NumOutputDim > &outLengths, const std::array< index_t, NumOutputDim > &outStrides)
Definition device_multiple_reduce_threadwise.hpp:147
std::string GetTypeString() const override
Definition device_multiple_reduce_threadwise.hpp:402
decltype(GenerateOutDataTypePointerTuple()) OutDataTypePointerTuple
Definition device_multiple_reduce_threadwise.hpp:83
decltype(MakeSrc2dDescriptor(std::array< index_t, NumInputDim >{}, std::array< index_t, NumInputDim >{})) InGridDesc_M_K
Definition device_multiple_reduce_threadwise.hpp:187
static constexpr index_t K_BlockTileSize
Definition device_multiple_reduce_threadwise.hpp:70
static auto GenerateOutDataTypePointerTuple()
Definition device_multiple_reduce_threadwise.hpp:72
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, NumInputDim > inLengths, const std::array< index_t, NumInputDim > inStrides, const std::array< index_t, NumOutputDim > outLengths, const std::array< std::array< index_t, NumOutputDim >, NumReduction > outStridesArray, const std::array< int, NumReduceDim > reduceDims, const std::array< double, NumReduction > alphas, const std::array< double, NumReduction > betas, const void *in_dev, const std::array< void *, NumReduction > out_dev_buffers, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
Definition device_multiple_reduce_threadwise.hpp:371
decltype(GenerateOutGrid1dDescTuple()) OutGridDesc_M_Tuple
Definition device_multiple_reduce_threadwise.hpp:189
static constexpr bool reduceAllDim
Definition device_multiple_reduce_threadwise.hpp:67
static constexpr index_t NumOutputDim
Definition device_multiple_reduce_threadwise.hpp:66
static constexpr index_t M_BlockTileSize
Definition device_multiple_reduce_threadwise.hpp:69