gridwise_multiblock_welford_first_half.hpp Source File

gridwise_multiblock_welford_first_half.hpp Source File

Composable Kernel: gridwise_multiblock_welford_first_half.hpp Source File
gridwise_multiblock_welford_first_half.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/math.hpp"
12
13namespace ck {
14
// Kernel entry point for the first half of the multiblock Welford computation.
// It does no work of its own: all eight arguments are forwarded, unchanged and in the
// same order, to the static Run() of the GridwiseMultiblockWelfordFirstHalf_ template
// argument.
//
// NOTE(review): this listing jumps from embedded line 20 to 22, so the declarator line is
// not shown. The symbol index at the end of the page gives it as
// "__global__ void kernel_multiblock_welford_first_half(" -- confirm against the real header.
//
// Arguments (as used by the Run() shown further down in this file):
//   x_grid_desc_m_k               descriptor used to address p_x
//   mean_var_count_grid_desc_m_g  descriptor used to address all three output pointers
//   get_reduce_count_per_thread   functor called as f(block_local_id, thread_k_cluster_id)
//   num_k_block_tile_iteration    trip count of the per-block tile loop
//   p_x                           read-only input
//   p_welford_mean / p_welford_variance / p_welford_count
//                                 outputs written through the M_G descriptor
15template <typename GridwiseMultiblockWelfordFirstHalf_,
16 typename XDataType,
17 typename MeanVarDataType,
18 typename XGridDesc_M_K,
19 typename MeanVarCountGridDesc_M_G,
20 typename GetReduceCountPerThreadFunctor>
22 const XGridDesc_M_K x_grid_desc_m_k,
23 const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
24 const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
25 index_t num_k_block_tile_iteration,
26 const XDataType* const __restrict__ p_x,
27 MeanVarDataType* const p_welford_mean,
28 MeanVarDataType* const p_welford_variance,
29 int32_t* const p_welford_count)
30{
    // Pure delegation to the gridwise implementation.
31 GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k,
32 mean_var_count_grid_desc_m_g,
33 get_reduce_count_per_thread,
34 num_k_block_tile_iteration,
35 p_x,
36 p_welford_mean,
37 p_welford_variance,
38 p_welford_count);
// NOTE(review): the ';' after the closing brace below is an empty declaration; it is
// redundant and could be dropped in the real header.
39};
40
// Gridwise implementation of the first half of the multiblock Welford mean/variance
// computation. Each thread folds its slice of X into a running mean and variance over
// num_k_block_tile_iteration tiles, and the threads at K-cluster position 0 then write one
// (mean, variance, count) triple per M row at column block_local_id of the M x G outputs.
//
// NOTE(review): the listing jumps from embedded line 53 to 55, so the line naming this
// struct is not shown -- presumably "struct GridwiseMultiblockWelfordFirstHalf"; confirm
// against the real header. Further gaps in the embedded numbering below are lines the
// extraction dropped. They are flagged where they matter and are NOT reconstructed here.
41template <typename XDataType,
42 typename AccDataType,
43 typename MeanVarDataType,
44 typename XGridDesc_M_K,
45 typename MeanVarCountGridDesc_M_G,
46 typename GetReduceCountPerThreadFunctor,
47 index_t BlockSize,
48 index_t MThreadClusterSize,
49 index_t KThreadClusterSize,
50 index_t MThreadSliceSize,
51 index_t KThreadSliceSize,
52 index_t XSrcCountSrcVectorDim,
53 index_t XSrcCountSrcVectorSize>
55{
    // The vector width must evenly divide the per-thread slice length along the vectorized
    // dimension: MThreadSliceSize when the vector dim is 0, KThreadSliceSize when it is 1.
    // Any other vector dim value fails the assertion.
56 static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) ||
57 (XSrcCountSrcVectorDim == 1 &&
58 KThreadSliceSize % XSrcCountSrcVectorSize == 0),
59 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
60
    // True when vectorization runs along dimension 0 (M). Per the symbol index at the end
    // of the page, this flag selects Sequence<1, 0> over Sequence<0, 1> for both
    // ThreadBufferDimAccessOrder and ThreadClusterArrangeOrder.
61 static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0);
62
    // NOTE(review): embedded lines 63-88 are mostly missing from this listing. According to
    // the symbol index they declare ThreadClusterLengths_M_K, ThreadBufferDimAccessOrder,
    // ThreadClusterArrangeOrder, the initializer of thread_cluster_desc,
    // ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M, ThreadwiseWelford, BlockwiseWelford
    // and PassThroughOp. Only fragments survive below -- confirm against the real header.
64
67
70
71 static constexpr auto thread_cluster_desc =
73
78
81
83 BlockSize,
86 false>;
87
89
    // Compile-time index constants for dimension 0 and dimension 1.
90 static constexpr auto I0 = Number<0>{};
91 static constexpr auto I1 = Number<1>{};
92
    // Extent of one block tile along M and along K: cluster size times per-thread slice size.
93 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
94 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
95
96 // clang-format off
97 // First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward.
98 // clang-format on
    // Parameters:
    //   x_grid_desc_m_k               descriptor used to address p_x
    //   mean_var_count_grid_desc_m_g  descriptor used to address the three outputs; its
    //                                 dimension-1 length is the number of blocks per group
    //   get_reduce_count_per_thread   functor called as f(block_local_id, thread_k_cluster_id);
    //                                 its result becomes the thread-wise Welford max_count_
    //   num_k_block_tile_iteration    number of K_BlockTileSize-wide tiles each block consumes
    //   p_x                           read-only input
    //   p_welford_mean / p_welford_variance / p_welford_count
    //                                 outputs, written only by threads with
    //                                 thread_k_cluster_id == 0
99 __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k,
100 const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
101 const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
102 index_t num_k_block_tile_iteration,
103 const XDataType* const __restrict__ p_x,
104 MeanVarDataType* const p_welford_mean,
105 MeanVarDataType* const p_welford_variance,
106 int32_t* const p_welford_count)
107 {
        // Per-thread working buffers: the loaded X slice and the running mean / variance /
        // count. NOTE(review): the type line of each declaration (embedded lines 108, 111,
        // 113, 115) is missing from this listing.
109 x_thread_buf;
110
112 welford_mean_thread_buf;
114 welford_var_thread_buf;
116 welford_count_thread_buf;
117
        // Number of blocks cooperating on one M tile = length of dimension 1 of the M x G
        // descriptor.
118 const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
119
        // Split the flat block id into (group index, position within the group).
120 const index_t thread_local_id = get_thread_local_1d_id();
121 const index_t block_global_id = get_block_1d_id();
122 const index_t blkgroup_id = block_global_id / blkgroup_size;
123 const index_t block_local_id = block_global_id % blkgroup_size;
124
        // Map the flat thread id to its (m, k) coordinate in the thread cluster.
125 const auto thread_cluster_idx =
126 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
127
128 const auto thread_m_cluster_id = thread_cluster_idx[I0];
129 const auto thread_k_cluster_id = thread_cluster_idx[I1];
130
131 using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
132 using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
133
        // Packed descriptors for the per-thread buffers. NOTE(review): the argument lines
        // (embedded 135 and 137) are missing from this listing.
134 constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
136 constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
138
        // Number of K elements covered by one block over the whole tile loop.
139 const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
140
        // X loader. Its source window starts at
        //   M: blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize
        //   K: block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize
        // NOTE(review): embedded line 146 (one template argument) is missing.
141 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
142 AccDataType,
143 XGridDesc_M_K,
144 decltype(thread_buffer_desc_m_k),
145 ThreadBufferLengths_M_K,
147 XSrcCountSrcVectorDim,
148 XSrcCountSrcVectorSize,
149 1,
150 true>(
151 x_grid_desc_m_k,
152 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
153 block_local_id * reduceSizePerBlock +
154 thread_k_cluster_id * KThreadSliceSize));
155
        // Store object shared by the mean and the variance write-back; its destination
        // origin is (this thread's first M row, block_local_id).
        // NOTE(review): the transfer class name and several template arguments (embedded
        // lines 157, 161, 163, 166) are missing from this listing.
156 auto threadwise_welford_mean_var_store =
158 MeanVarDataType,
159 decltype(thread_buffer_desc_m_1),
160 MeanVarCountGridDesc_M_G,
162 ThreadBufferLengths_M_1,
164 0,
165 1,
167 1,
168 true>(
169 mean_var_count_grid_desc_m_g,
170 make_multi_index(blkgroup_id * M_BlockTileSize +
171 thread_m_cluster_id * MThreadSliceSize,
172 block_local_id),
173 PassThroughOp{});
174
        // Same destination origin as above, for the int32_t count output.
        // NOTE(review): embedded lines 176, 180, 182 and 185 are missing.
175 auto threadwise_welford_count_store =
177 int32_t,
178 decltype(thread_buffer_desc_m_1),
179 MeanVarCountGridDesc_M_G,
181 ThreadBufferLengths_M_1,
183 0,
184 1,
186 1,
187 true>(
188 mean_var_count_grid_desc_m_g,
189 make_multi_index(blkgroup_id * M_BlockTileSize +
190 thread_m_cluster_id * MThreadSliceSize,
191 block_local_id),
192 PassThroughOp{});
193
        // Per-iteration window step: stay on the same M rows, advance K by one block tile.
194 constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
195
        // Wrap the raw pointers as global-address-space buffers, each sized by the element
        // space of its descriptor.
196 const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
197 p_x, x_grid_desc_m_k.GetElementSpaceSize());
198
199 auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
200 p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
201
202 auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
203 p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
204
205 auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
206 p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
207
        // Thread-wise accumulator; the caller-supplied functor provides its max_count_ for
        // this (block_local_id, thread_k_cluster_id) pair.
208 auto threadwise_welford = ThreadwiseWelford();
209 threadwise_welford.max_count_ =
210 get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
211
        // Zero the running mean and variance for every index I.
        // NOTE(review): the loop header (embedded line 212) is missing; only the body and the
        // closing "});" survive.
213 welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
214 welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
215 });
216
        // Main loop: load one tile into x_thread_buf, advance the source window by one
        // block tile along K, then fold the loaded values into the running mean/variance.
217 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
218 {
219 threadwise_x_load.Run(x_grid_desc_m_k,
220 x_global_val_buf,
221 thread_buffer_desc_m_k,
222 make_tuple(I0, I0),
223 x_thread_buf);
224
225 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
226 threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf);
227 }
228
        // Per-index finalization: seed each count with the accumulator's cur_count_, then
        // pass the (mean, var, count) triple to a call whose name is not shown.
        // NOTE(review): embedded lines 229, 231 and 234 are missing. The symbol index lists
        // block_sync_lds() and a static BlockwiseWelford Run(mean, var, count), which are
        // presumably the statement guarded by "if constexpr(I > 0)" and the callee on line
        // 235 respectively -- confirm against the real header.
230 if constexpr(I > 0)
232
233 welford_count_thread_buf(I) = threadwise_welford.cur_count_;
235 welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
236 });
237
        // Only threads at K-cluster position 0 write results. The same store object is run
        // twice, once for the mean buffer and once for the variance buffer; the count goes
        // through its own int32_t store.
238 if(thread_k_cluster_id == 0)
239 {
240 threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
241 make_tuple(I0, I0),
242 welford_mean_thread_buf,
243 mean_var_count_grid_desc_m_g,
244 welford_mean_global_val_buf);
245
246 threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1,
247 make_tuple(I0, I0),
248 welford_var_thread_buf,
249 mean_var_count_grid_desc_m_g,
250 welford_var_global_val_buf);
251
252 threadwise_welford_count_store.Run(thread_buffer_desc_m_1,
253 make_tuple(I0, I0),
254 welford_count_thread_buf,
255 mean_var_count_grid_desc_m_g,
256 welford_count_global_val_buf);
        // NOTE(review): the ';' after the closing brace below is a redundant empty statement.
257 };
258 }
259};
260
261} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__global__ void kernel_multiblock_welford_first_half(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x, MeanVarDataType *const p_welford_mean, MeanVarDataType *const p_welford_variance, int32_t *const p_welford_count)
Definition gridwise_multiblock_welford_first_half.hpp:21
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
signed int int32_t
Definition stdint.h:123
static __device__ void Run(AccDataType &mean_value, AccDataType &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_multiblock_welford_first_half.hpp:55
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_multiblock_welford_first_half.hpp:88
static constexpr auto I1
Definition gridwise_multiblock_welford_first_half.hpp:91
static constexpr bool reorder_thread_cluster
Definition gridwise_multiblock_welford_first_half.hpp:61
static constexpr auto I0
Definition gridwise_multiblock_welford_first_half.hpp:90
static constexpr index_t M_BlockTileSize
Definition gridwise_multiblock_welford_first_half.hpp:93
static constexpr index_t K_BlockTileSize
Definition gridwise_multiblock_welford_first_half.hpp:94
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_multiblock_welford_first_half.hpp:65
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_multiblock_welford_first_half.hpp:76
BlockwiseWelford< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, false > BlockwiseWelford
Definition gridwise_multiblock_welford_first_half.hpp:82
static __device__ void Run(const XGridDesc_M_K &x_grid_desc_m_k, const MeanVarCountGridDesc_M_G &mean_var_count_grid_desc_m_g, const GetReduceCountPerThreadFunctor &get_reduce_count_per_thread, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x, MeanVarDataType *const p_welford_mean, MeanVarDataType *const p_welford_variance, int32_t *const p_welford_count)
Definition gridwise_multiblock_welford_first_half.hpp:99
static constexpr auto thread_cluster_desc
Definition gridwise_multiblock_welford_first_half.hpp:71
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_multiblock_welford_first_half.hpp:68
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_multiblock_welford_first_half.hpp:63
ThreadwiseWelford< AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M > ThreadwiseWelford
Definition gridwise_multiblock_welford_first_half.hpp:79
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_multiblock_welford_first_half.hpp:74
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340