epilogue_cshuffle_v3_welford_wmma.hpp Source File

epilogue_cshuffle_v3_welford_wmma.hpp Source File#

Composable Kernel: epilogue_cshuffle_v3_welford_wmma.hpp Source File
epilogue_cshuffle_v3_welford_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
13template <typename DsDataType,
14 typename EDataType,
15 typename AccDataType,
16 typename CShuffleDataType,
17 index_t MPerBlock,
18 index_t NPerBlock,
19 index_t MPerWmma,
20 index_t NPerWmma,
21 index_t MRepeat,
22 index_t NRepeat,
23 index_t CShuffleMRepeatPerShuffle,
24 index_t CShuffleNRepeatPerShuffle,
25 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
26 typename CDEShuffleBlockTransferScalarPerVectors,
27 typename CDEElementwiseOperation,
28 typename ThisThreadBlock,
29 typename BlockwiseGemmPipe,
30 index_t BlockSize>
32 : EpilogueCShuffleBase<DsDataType,
33 EDataType,
34 AccDataType,
35 CShuffleDataType,
36 MPerBlock,
37 NPerBlock,
38 MPerWmma,
39 NPerWmma,
40 MRepeat,
41 NRepeat,
42 CShuffleMRepeatPerShuffle,
43 CShuffleNRepeatPerShuffle,
44 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
45 CDEShuffleBlockTransferScalarPerVectors,
46 CDEElementwiseOperation,
47 ThisThreadBlock,
48 BlockwiseGemmPipe>
49{
51 DsDataType,
52 EDataType,
53 AccDataType,
54 CShuffleDataType,
55 MPerBlock,
56 NPerBlock,
57 MPerWmma,
58 NPerWmma,
59 MRepeat,
60 NRepeat,
61 CShuffleMRepeatPerShuffle,
62 CShuffleNRepeatPerShuffle,
63 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
64 CDEShuffleBlockTransferScalarPerVectors,
65 CDEElementwiseOperation,
67 BlockwiseGemmPipe>;
68
72 using Base::I0;
73 using Base::I1;
74 using Base::I2;
75 using Base::I3;
76 using Base::NumDTensor;
77
// Builds the 2-D (M, N) grid descriptor that, judging by its name and by how the
// constructor calls it, describes the Welford mean / variance output tensors.
//   DoPads            - padding selector type, default-constructed and forwarded
//                       as the last argument of the final call.
//   MPerTile/NPerTile - tile lengths forwarded as make_tuple(MPerTile, NPerTile).
//   M, N              - runtime extents of the tensor.
78 template <typename DoPads, index_t MPerTile, index_t NPerTile>
79 __host__ __device__ static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
80 {
// NOTE(review): this listing jumps from line 81 to line 84, so the expression
// that initialises grid_desc_m_n and the name of the call that receives it are
// not visible here. Presumably a naive (M, N) descriptor passed to
// PadTensorDescriptor - confirm against the original header.
81 const auto grid_desc_m_n =
84 grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
85 }
86
// Builds the 2-D (M, N) grid descriptor for the Welford count tensor.
// Per the in-body comment, the underlying data is a length-N vector that is
// broadcast along M by giving the first dimension a stride of 0.
//   DoPads            - padding selector type, default-constructed and forwarded
//                       as the last argument of the final call.
//   MPerTile/NPerTile - tile lengths forwarded as make_tuple(MPerTile, NPerTile).
//   M, N              - runtime extents of the (broadcast) tensor.
87 template <typename DoPads, index_t MPerTile, index_t NPerTile>
88 __host__ __device__ static auto MakeCountDescriptor_M_N(index_t M, index_t N)
89 {
90 // We will broadcast [N] to [M, N] in this descriptor
91 // Hence, 1st stride is 0
// NOTE(review): this listing jumps from line 92 to line 95, so the expression
// that initialises grid_desc_m_n (where the zero stride would be set) and the
// name of the call that receives it are not visible here - confirm against the
// original header.
92 const auto grid_desc_m_n =
95 grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
96 }
97
// Re-expresses an (M, NBlock) descriptor as a 3-D descriptor which, going by the
// names used here, is indexed as (MBlock, MPerBlock, NBlock).
//   grid_desc_m_n - input descriptor; dimension I0 is read as M and dimension I1
//                   is read as the number of N blocks.
// Returns the transformed descriptor.
98 template <typename GridDescriptor_M_N>
99 __host__ __device__ static constexpr auto
100 MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)
101 {
102 const auto M = grid_desc_m_n.GetLength(I0);
103 const auto NBlock = grid_desc_m_n.GetLength(I1);
// Integer division: any remainder of M / MPerBlock is dropped.
// NOTE(review): presumably M is already padded to a multiple of MPerBlock by
// the descriptor factories in this struct - confirm.
104 const auto MBlock = M / MPerBlock;
105
// NOTE(review): listing lines 108-111 (the transforms and dimension-id
// arguments of transform_tensor_descriptor) are missing from this view.
106 const auto grid_desc_mblock_mperblock_nblock = transform_tensor_descriptor(
107 grid_desc_m_n,
112
113 return grid_desc_mblock_mperblock_nblock;
114 }
115
117 decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(1, 1));
118
120 decltype(MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(1, 1));
121
// Device-side constructor: records the Welford output pointers and the raw N
// extent, then builds the mean/var and count descriptors.
//   p_welford_mean_grid_  - destination pointer for per-row means.
//   p_welford_var_grid_   - destination pointer for per-row variances.
//   p_welford_count_grid_ - destination pointer for the element counts.
//   MRaw_                 - M extent passed to both descriptor factories.
//   NRaw_                 - N extent; stored in NRaw and used to derive the
//                           number of N blocks.
122 __device__ EpilogueWelfordCShuffle(EDataType* p_welford_mean_grid_,
123 EDataType* p_welford_var_grid_,
124 int32_t* p_welford_count_grid_,
125 index_t MRaw_,
126 index_t NRaw_)
127 : p_welford_mean_grid(p_welford_mean_grid_),
128 p_welford_var_grid(p_welford_var_grid_),
129 p_welford_count_grid(p_welford_count_grid_),
130 NRaw(NRaw_)
131 {
// Number of NPerBlock-wide tiles needed to cover NRaw_ (rounded up); this is
// used as the N extent of both descriptors.
132 index_t gemm_nblock = math::integer_divide_ceil(NRaw_, NPerBlock);
133
// NOTE(review): listing lines 134 and 137 (the left-hand sides of the two
// assignments below) are missing from this view; presumably they assign
// gemm_mean_var_grid_desc_m_nblock and gemm_count_grid_desc_m_nblock - confirm.
135 MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(MRaw_, gemm_nblock);
136
138 MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, 1>(MRaw_, gemm_nblock);
139 }
140
// Epilogue for one output tile. As far as the visible code shows, it
//   1. shuffles the accumulator registers through LDS,
//   2. reads LDS and the D tensors and applies the CDE elementwise operation,
//   3. stores E to global memory while keeping a per-thread copy in registers,
//   4. feeds that copy into a per-thread Welford accumulation,
//   5. combines the per-thread results with BlockwiseWelford and writes the
//      mean, variance and count to global memory.
//
//   c_thread_buf   - accumulator buffer used as the source of the VGPR-to-LDS copy.
//   p_ds_grid      - indexable collection of D pointers, wrapped into buffers.
//   p_e_grid       - pointer to the E output.
//   p_shared       - shared-memory pointer, cast to CShuffleDataType* and used as
//                    the LDS shuffle buffer.
//   ds_grid_desc_mblock_mperblock_nblock_nperblock - descriptors of the D tensors.
//   e_grid_desc_mblock_mperblock_nblock_nperblock  - descriptor of the E tensor.
//   cde_element_op - elementwise operation forwarded to the LDS-to-global copy.
//   block_m_id, block_n_id - tile indices; used for the copy setup, the
//                    tail-block test and the destination index of the
//                    mean/var/count writes.
//
// NOTE(review): many listing lines are missing from this view (for example 161,
// 165, 167, 171-172, 174, 177, 181, 187, 199, 220, 224, 234, 246, 250-251, 253,
// 255, 267, 273, 276, 285, 298, 302-303, 339, 342, 348, 363, 374, 424-425, 433,
// 449-451, 454, 458). Comments below describe only what is visible.
141 template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
142 typename CThreadBuf,
143 typename DsGridPointer,
144 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
145 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
146 __device__ void Run(CThreadBuf& c_thread_buf,
147 DsGridPointer p_ds_grid,
148 EDataType* p_e_grid,
149 void* p_shared,
150 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
151 ds_grid_desc_mblock_mperblock_nblock_nperblock,
152 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
153 e_grid_desc_mblock_mperblock_nblock_nperblock,
154 CDEElementwiseOperation& cde_element_op,
155 const index_t& block_m_id,
156 const index_t& block_n_id)
157 {
158 // Vmem buffers
159 const auto ds_grid_buf = generate_tuple(
160 [&](auto i) {
162 p_ds_grid[i],
163 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
164 },
166
168 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
169
170 auto mean_var_grid_desc_mblock_mperblock_nblock =
173
175 p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
176
178 p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
179
180 auto count_grid_desc_mblock_mperblock_nblock =
182 auto welford_count_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
183 p_welford_count_grid, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize());
184
185 // LDS buffer
186 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
188
189 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
190 static_cast<CShuffleDataType*>(p_shared),
191 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
192 .GetElementSpaceSize());
193
194 // tuple of reference to C/Ds tensor buffers (mix LDS and Vmem)
195 const auto c_ds_buf_refs = concat_tuple_of_reference(
196 tie(c_shuffle_block_buf),
197 generate_tie([&](auto i) -> const auto& // return type should be reference
198 { return ds_grid_buf[i]; },
200
201 // Thread transfer Vgpr to LDS
202 auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();
203
204 // Space Filling Curve Vgpr
205 constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};
206
207 // Space Filling Curve Vmem
208 constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};
209
210 // C thread descriptor
211 constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
212 BlockwiseGemmPipe::
213 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
214
215 // tuple of reference to C/Ds tensor descriptors
216 const auto c_ds_desc_refs = concat_tuple_of_reference(
217 tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
218 generate_tie([&](auto i) -> const auto& // return type should be reference
219 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
221
222 // Thread transfer LDS to Vmem
223 auto cde_shuffle_block_copy_lds_and_global =
225 c_ds_desc_refs,
226 e_grid_desc_mblock_mperblock_nblock_nperblock,
227 cde_element_op,
228 block_m_id,
229 block_n_id);
230
231 // Block descriptor
232 constexpr auto
233 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
235
236 // E Vgpr buffer
// Per-thread slice extents after the shuffle: the M (resp. N) size of one
// shuffle pass divided by the thread-cluster length along that dimension.
237 constexpr index_t PostShuffleThreadSliceSize_M =
238 (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma) /
239 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1);
240
241 constexpr index_t PostShuffleThreadSliceSize_N =
242 (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma) /
243 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3);
244
245 constexpr auto PostShuffleThreadSliceSize_M_N =
247
248 // Welford
249 constexpr auto post_shuffle_thread_desc_m_n =
252 Number<1>{},
254
256 post_shuffle_thread_desc_m_n.GetElementSpaceSize());
257
258 using PostShuffleThreadClusterSize_M_N = Sequence<
259 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1),
260 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>;
261
262 constexpr auto post_shuffle_thread_cluster_desc =
263 make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{});
264
265 const auto post_shuffle_thread_cluster_idx =
266 post_shuffle_thread_cluster_desc.CalculateBottomIndex(
268
// Origin of this thread's slice inside one shuffle pass: cluster index scaled
// element-wise by the per-thread slice size.
269 const auto post_shuffle_thread_data_idx_begin =
270 post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N;
271
272 constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed(make_tuple(
274
275 constexpr auto thread_welford_dst_desc_m =
277
278 using ThreadwiseWelford = ThreadwiseWelford<AccDataType,
279 decltype(thread_welford_src_desc_m_k),
280 decltype(thread_welford_dst_desc_m)>;
281
282 using BlockwiseWelford = BlockwiseWelford<AccDataType,
283 BlockSize,
284 PostShuffleThreadClusterSize_M_N,
286 false>;
287
// Number of shuffle passes needed to cover the block tile along M and along N.
288 constexpr int num_shuffleM =
289 MPerBlock / (CShuffleMRepeatPerShuffle * BlockwiseGemmPipe::MWaves * MPerWmma);
290
291 constexpr int num_shuffleN =
292 NPerBlock / (CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma);
293
294 using mean_var_vgpr_type = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
295 thread_welford_dst_desc_m.GetElementSpaceSize()));
296
297 using welford_count_vgpr_type =
299 thread_welford_dst_desc_m.GetElementSpaceSize()));
300
// One Welford accumulator (and one count buffer) per M shuffle pass.
301 Array<ThreadwiseWelford, num_shuffleM> threadwise_welfords;
304 Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;
305
// Default: every N element this thread visits across all N shuffle passes is
// valid. The tail-block branch below lowers this for the last N block.
306 int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
307 const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2);
308
309 // tail block
310 if(block_n_id % nblock == nblock - 1)
311 {
312 constexpr index_t NPerShuffleBlock =
313 CShuffleNRepeatPerShuffle * BlockwiseGemmPipe::NWaves * NPerWmma;
314
// Number of columns in the last N block that fall inside NRaw.
315 int NPerBlockTail = NRaw - NPerBlock * (nblock - 1);
// Exclusive end column of this thread's slice in the first shuffle pass.
316 int thread_max_len =
317 PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
// Count how many of this thread's slices end at or before the tail boundary,
// stepping one shuffle pass (NPerShuffleBlock columns) at a time.
318 int shuffle_step = 0;
319 while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
320 {
321 ++shuffle_step;
322 thread_max_len += NPerShuffleBlock;
323 }
324
// delta = number of in-range columns in the slice the loop stopped at:
// 0 when that slice starts past the boundary, the whole slice when it ends
// before the boundary, otherwise the part up to the boundary.
325 int delta = 0;
326 if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
327 delta = 0;
328 else if(NPerBlockTail > thread_max_len)
329 delta = PostShuffleThreadSliceSize_N;
330 else
331 delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
332
333 max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
334 }
335
336 // Initialize Welford
337 static_for<0, num_shuffleM, 1>{}([&](auto i) {
338 threadwise_welfords(i).max_count_ = max_count;
340 thread_welford_dst_desc_m.GetElementSpaceSize());
341
343 thread_welford_dst_desc_m.GetElementSpaceSize());
344
345 welford_count_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, int32_t>(
346 thread_welford_dst_desc_m.GetElementSpaceSize());
347
349 mean_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
350 var_thread_bufs(i)(j) = type_convert<AccDataType>(0.0f);
351 welford_count_thread_bufs(i)(j) = 0;
352 });
353 });
354
355 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
356
// The register-side and global-side space-filling curves must take the same
// number of steps, since both are advanced by the same access_id below.
357 static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
358
359 // Run CShuffle + Store E + Welford threadwise
// Index of the M shuffle pass currently being processed; selects which
// Welford accumulator receives the data.
// NOTE(review): readfirstlane presumably keeps this index wave-uniform - confirm.
360 int shuffleM_index = __builtin_amdgcn_readfirstlane(0);
361 static_for<0, num_access, 1>{}([&](auto access_id) {
362 // make sure it's safe to read from LDS
364
365 // each thread shuffle data from VGPR to LDS
366 c_thread_copy_vgpr_to_lds.Run(
367 c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
368 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
369 c_thread_buf,
370 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
371 c_shuffle_block_buf);
372
373 // make sure it's safe to write to LDS
375
376 // Read LDS / Vmem + CDE elementwise operation
377 cde_shuffle_block_copy_lds_and_global.RunRead(c_ds_desc_refs, c_ds_buf_refs);
378
379 // Store to Vmem, but keep data in Vgpr for Welford
380 cde_shuffle_block_copy_lds_and_global.RunWriteAndStoreVgpr(
381 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
382 tie(e_grid_buf),
383 tie(post_shuffle_thread_desc_m_n),
384 tie(e_thread_buf));
385
386 if constexpr(access_id < num_access - 1)
387 {
388 constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
389 // move on Ds
// Source slot 0 is the LDS shuffle buffer, so the D tensors are at i + I1.
390 static_for<0, NumDTensor, 1>{}([&](auto i) {
391 cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
392 c_ds_desc_refs, i + I1, cde_global_step);
393 });
394
395 // move on E
396 cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
397 tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
398 }
399
400 // Threadwise welford
401 auto& threadwise_welford = threadwise_welfords(shuffleM_index);
402 auto& mean_thread_buf = mean_thread_bufs(shuffleM_index);
403 auto& var_thread_buf = var_thread_bufs(shuffleM_index);
404
405 threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf);
406
407 if constexpr(access_id < num_access - 1)
408 {
// Advance shuffleM_index by the step's M component (index I1), measured in
// units of the LDS shuffle block's length at dimension I1.
409 constexpr auto de_global_step = sfc_cde_global.GetForwardStep(access_id);
410 constexpr int shuffleMInc =
411 de_global_step[I1] /
412 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
413 I1);
414 shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc);
415 }
416 });
417
418 // Blockwise welford and write out
419 static_for<0, num_shuffleM, 1>{}([&](auto i) {
420 auto& mean_thread_buf = mean_thread_bufs(i);
421 auto& var_thread_buf = var_thread_bufs(i);
422 auto& count_thread_buf = welford_count_thread_bufs(i);
423
// Seed the count with what the per-thread accumulator recorded, then combine
// mean / variance / count via BlockwiseWelford.
426 count_thread_buf(j) = threadwise_welfords(i).cur_count_;
427 BlockwiseWelford::Run(mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
428 });
429
// Only threads whose N cluster index is 0 write the results.
430 if(post_shuffle_thread_cluster_idx[I1] == 0)
431 {
432 constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed(
434
435 constexpr int shuffleMPerBlock =
436 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetLength(
437 I1);
438
// Destination index (mblock, mperblock, nblock) for this thread's rows in
// M shuffle pass i.
439 auto mean_var_count_thread_copy_index = make_multi_index(
440 block_m_id, // mblock
441 shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock
442 block_n_id); // nblock
443
444 auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
445 AccDataType,
446 EDataType,
447 decltype(thread_welford_desc_I_m_I),
448 decltype(mean_var_grid_desc_mblock_mperblock_nblock),
452 1,
453 1,
455 1,
456 true>{mean_var_grid_desc_mblock_mperblock_nblock,
457 mean_var_count_thread_copy_index,
459
460 mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
461 make_tuple(I0, I0, I0),
462 mean_thread_buf,
463 mean_var_grid_desc_mblock_mperblock_nblock,
464 mean_grid_buf); // write mean
465
466 mean_var_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
467 make_tuple(I0, I0, I0),
468 var_thread_buf,
469 mean_var_grid_desc_mblock_mperblock_nblock,
470 var_grid_buf); // write variance
471
472 // Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need
473 // to be written.
474 if(i == 0 && block_m_id == 0 && post_shuffle_thread_cluster_idx[I0] == 0)
475 {
476 auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
477 int32_t,
478 int32_t,
479 decltype(thread_welford_desc_I_m_I),
480 decltype(count_grid_desc_mblock_mperblock_nblock),
484 1,
485 1,
487 1,
488 false>{count_grid_desc_mblock_mperblock_nblock,
489 mean_var_count_thread_copy_index,
491
492 count_thread_copy_vgpr_to_global.Run(thread_welford_desc_I_m_I,
493 make_tuple(I0, I0, I0),
494 count_thread_buf,
495 count_grid_desc_mblock_mperblock_nblock,
496 welford_count_grid_buf); // write count
497 }
498 }
499 });
500 }
501
508};
509
510} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
signed int int32_t
Definition stdint.h:123
Definition utility/array.hpp:14
Definition blockwise_welford.hpp:25
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition epilogue_cshuffle_v3_wmma_base.hpp:29
static constexpr auto I2
Definition epilogue_cshuffle_v3_wmma_base.hpp:32
static constexpr index_t NumDTensor
Definition epilogue_cshuffle_v3_wmma_base.hpp:38
static constexpr auto I0
Definition epilogue_cshuffle_v3_wmma_base.hpp:30
static constexpr auto I3
Definition epilogue_cshuffle_v3_wmma_base.hpp:33
static __device__ auto GetLDSToVmemEpilogueDescriptor(CDsDescRefs &c_ds_desc_refs, EGridDesc &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition epilogue_cshuffle_v3_wmma_base.hpp:204
SpaceFillingCurve< Sequence< MRepeat, 1, 1, NRepeat, 1, 1, BlockwiseGemmPipe::MAccVgprs >, Sequence< 0, 1, 2, 3, 4, 5, 6 >, Sequence< CShuffleMRepeatPerShuffle, 1, 1, CShuffleNRepeatPerShuffle, 1, 1, BlockwiseGemmPipe::MAccVgprs > > SpaceFillingCurveVgpr
Definition epilogue_cshuffle_v3_wmma_base.hpp:42
static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition epilogue_cshuffle_v3_wmma_base.hpp:63
static constexpr auto I1
Definition epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:118
SpaceFillingCurve< Sequence< 1, MPerBlock, 1, NPerBlock >, Sequence< 0, 2, 1, 3 >, Sequence< 1, CShuffleMRepeatPerShuffle *BlockwiseGemmPipe::MWaves *MPerWmma, 1, CShuffleNRepeatPerShuffle *BlockwiseGemmPipe::NWaves *NPerWmma > > SpaceFillingCurveVmem
Definition epilogue_cshuffle_v3_wmma_base.hpp:53
static __device__ constexpr auto GetCShuffleLDSDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:78
__host__ static __device__ auto MakeCountDescriptor_M_N(index_t M, index_t N)
Definition epilogue_cshuffle_v3_welford_wmma.hpp:88
static constexpr auto I2
Definition epilogue_cshuffle_v3_wmma_base.hpp:32
EDataType * p_welford_var_grid
Definition epilogue_cshuffle_v3_welford_wmma.hpp:503
GemmCountGridDesc_M_N gemm_count_grid_desc_m_nblock
Definition epilogue_cshuffle_v3_welford_wmma.hpp:507
EpilogueCShuffleBase< DsDataType, EDataType, AccDataType, CShuffleDataType, MPerBlock, NPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, CDEElementwiseOperation, ThisThreadBlock, BlockwiseGemmPipe > Base
Definition epilogue_cshuffle_v3_welford_wmma.hpp:50
EDataType * p_welford_mean_grid
Definition epilogue_cshuffle_v3_welford_wmma.hpp:502
index_t NRaw
Definition epilogue_cshuffle_v3_welford_wmma.hpp:505
static constexpr auto I0
Definition epilogue_cshuffle_v3_wmma_base.hpp:30
static constexpr auto I3
Definition epilogue_cshuffle_v3_wmma_base.hpp:33
decltype(MakeMeanVarDescriptor_M_N< Sequence< true, false >, MPerBlock, 1 >(1, 1)) GemmMeanVarGridDesc_M_N
Definition epilogue_cshuffle_v3_welford_wmma.hpp:116
__device__ void Run(CThreadBuf &c_thread_buf, DsGridPointer p_ds_grid, EDataType *p_e_grid, void *p_shared, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition epilogue_cshuffle_v3_welford_wmma.hpp:146
int32_t * p_welford_count_grid
Definition epilogue_cshuffle_v3_welford_wmma.hpp:504
__host__ static __device__ auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
Definition epilogue_cshuffle_v3_welford_wmma.hpp:79
__device__ EpilogueWelfordCShuffle(EDataType *p_welford_mean_grid_, EDataType *p_welford_var_grid_, int32_t *p_welford_count_grid_, index_t MRaw_, index_t NRaw_)
Definition epilogue_cshuffle_v3_welford_wmma.hpp:122
static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition epilogue_cshuffle_v3_wmma_base.hpp:63
decltype(MakeCountDescriptor_M_N< Sequence< true, false >, MPerBlock, 1 >(1, 1)) GemmCountGridDesc_M_N
Definition epilogue_cshuffle_v3_welford_wmma.hpp:119
static constexpr auto I1
Definition epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:118
GemmMeanVarGridDesc_M_N gemm_mean_var_grid_desc_m_nblock
Definition epilogue_cshuffle_v3_welford_wmma.hpp:506
__host__ static __device__ constexpr auto MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N &grid_desc_m_n)
Definition epilogue_cshuffle_v3_welford_wmma.hpp:100
static __device__ constexpr auto GetCShuffleLDSDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:78
Definition utility/sequence.hpp:43
Definition thread_group.hpp:12
Definition threadwise_tensor_slice_transfer.hpp:39
Definition threadwise_welford.hpp:18
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340