gridwise_ab_transfer_thread_tiles.hpp Source File

gridwise_ab_transfer_thread_tiles.hpp Source File

Composable Kernel: gridwise_ab_transfer_thread_tiles.hpp Source File
gridwise_ab_transfer_thread_tiles.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace ck {
11
12template <typename ABLayout,
13 typename ABMajorLayout,
14 typename LDSTypeAB,
15 index_t BlockSize,
16 index_t MNPerBlock,
17 index_t KPerBlock,
18 index_t MNPerWmma,
19 index_t ABK1Value,
20 bool UseBlockPaddingAB,
21 bool PermuteAB,
22 typename ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
23 typename ABBlockTransferThreadClusterArrangeOrder,
24 typename ABBlockTransferSrcAccessOrder,
25 index_t ABBlockTransferSrcVectorDim,
26 index_t ABBlockTransferSrcScalarPerVector,
27 index_t ABBlockTransferDstScalarPerVector_ABK1,
28 bool ABThreadTransferSrcResetCoordinateAfterRun>
30{
// Compile-time split of one block tile's K extent: ABK0Number = KPerBlock / ABK1Value,
// ABK1Number = ABK1Value.
// NOTE(review): assumes KPerBlock is a multiple of ABK1Value; the integer division below
// truncates silently otherwise — consider a static_assert.
31 static constexpr auto ABK0Number = Number<KPerBlock / ABK1Value>{};
32 static constexpr auto ABK1Number = Number<ABK1Value>{};
33
// Dimension indices used to address the (ABK0, MN, ABK1) descriptors in this struct.
34 static constexpr auto I0 = Number<0>{};
35 static constexpr auto I1 = Number<1>{};
36 static constexpr auto I2 = Number<2>{};
37
// Packing factor of LDSTypeAB: 2 or 1. It divides the per-element byte size in the
// LdsSize computation of GetBlockDescriptor.
// NOTE(review): the `if constexpr` condition that selects 2 is missing from this listing;
// presumably it tests LDSTypeAB for a packed 4-bit type (e.g. pk_i4_t) — confirm.
38 static constexpr index_t ABPackedSize = []() {
40 return 2;
41 else
42 return 1;
43 }();
44
46
// Builds the global-memory (grid) descriptor of the A/B operand in (ABK0, MN, ABK1) form
// from ab_grid_desc. PadMN / PadK select whether MN is right-padded up to MNPad and/or
// K is right-padded up to KPad.
//
// Parameters:
//   ab_grid_desc - base descriptor of the operand (presumably (MN, K) order, judging by the
//                  intermediate name ab_grid_desc_n_k — TODO confirm)
//   MN, MNPad    - logical and padded MN extents (pad length MNPad - MN)
//   K, KPad      - logical and padded K extents (pad length KPad - K)
//   StrideAB     - read only by the PermuteAB (pre-shuffled) path, where StrideAB / ABK1Value
//                  is taken as the ABK0 extent
//   ABK0         - NOTE(review): not referenced in the visible lines; presumably consumed by
//                  the K-splitting transforms whose arguments are missing from this listing
// Returns the transformed descriptor.
//
// NOTE(review): PermuteAB is only honored in the final (no MN / no K padding) branch; when
// PadMN or PadK is set, the pre-shuffled layout is not applied.
47 template <bool PadMN, bool PadK, typename GridDescriptorBase>
48 __host__ __device__ static auto MakeGridDescriptor(const GridDescriptorBase& ab_grid_desc,
49 index_t MN,
50 index_t MNPad,
51 index_t K,
52 index_t KPad,
53 index_t StrideAB,
54 index_t ABK0)
55 {
56
57 if constexpr(PadMN && PadK)
58 {
59 // pad both MN and K
60 const auto ab_grid_desc_n_k =
61 transform_tensor_descriptor(ab_grid_desc,
63 make_right_pad_transform(K, KPad - K)),
66
// Second step: reshape the padded descriptor into (ABK0, MN, ABK1).
// NOTE(review): the transform arguments are missing from this listing.
67 const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
68 ab_grid_desc_n_k,
73
74 return ab_grid_desc_bk0_n_bk1;
75 }
76 else if constexpr(PadMN && !PadK)
77 {
78 // pad MN, but not K
79 const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
80 ab_grid_desc,
82 make_right_pad_transform(MN, MNPad - MN)),
85
86 return ab_grid_desc_bk0_n_bk1;
87 }
88 else if constexpr(!PadMN && PadK)
89 {
90 // pad K, but not MN
91 const auto ab_grid_desc_n_k = transform_tensor_descriptor(
92 ab_grid_desc,
96
97 const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
98 ab_grid_desc_n_k,
103
104 return ab_grid_desc_bk0_n_bk1;
105 }
106 else
107 {
108 if constexpr(!PermuteAB)
109 {
110 // not pad MN or K
111 const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
112 ab_grid_desc,
117
118 return ab_grid_desc_bk0_n_bk1;
119 }
120 else
121 {
122 // Pre-shuffled Weight
123 // BGlobal[K / KPerBlock, MN, KPerBlock / K1, K1] -> BTile[K / K1, MN, K1]
// ABK01 is the number of ABK1 vectors per KPerBlock. ABK00 is the number of
// KPerBlock chunks along the stored K extent, derived from StrideAB rather than K.
// NOTE(review): assumes StrideAB is a multiple of KPerBlock; the integer divisions
// truncate otherwise.
124 constexpr index_t ABK01 = KPerBlock / ABK1Value;
125 const index_t ABK0_ = StrideAB / ABK1Value;
126 const index_t ABK00 = ABK0_ / ABK01;
127
// Packed 4D view [ABK00, MN, ABK01, ABK1] of the pre-shuffled global buffer.
128 const auto ab_grid_desc_abk00_mn_abk01_abk1_permute =
129 make_naive_tensor_descriptor_packed(make_tuple(ABK00, MN, ABK01, ABK1Value));
130
131 const auto ab_grid_desc_abk0_mn_abk1_permute = transform_tensor_descriptor(
132 ab_grid_desc_abk00_mn_abk01_abk1_permute,
135 make_pass_through_transform(ABK1Value)),
138
139 return ab_grid_desc_abk0_mn_abk1_permute;
140 }
141 }
142 }
143
// Returns the LDS block descriptor in (ABK0, MN, ABK1) form that the blockwise copy writes into.
// Three layouts are visible:
//   1. a padded layout when UseBlockPaddingAB is set;
//   2. a layout using MNLdsLayer plus a permutation step (presumably the XOR-with-modulo
//      transform mentioned in the comment above it);
//   3. a kfold / mpair layout for the remaining case.
// NOTE(review): two parts are missing from this listing: the body of the padded branch and
// the condition of the second branch.
144 __device__ static constexpr auto GetBlockDescriptor()
145 {
146 // A/B matrix tile in LDS memory; destination of the blockwise copy
147 if constexpr(UseBlockPaddingAB)
148 {
149 // Bank conflicts occur when writing the data into LDS, but there is a whole loop
150 // iteration to hide them in v4; this layout may also save VALU work in address computation.
154 }
155 // The XOR tensor transformation requires extra VGPRs, which can cause register spills
156 // in some cases.
158 {
// LdsSize: how many KPerBlock-long rows of LDSTypeAB fit into 32 * 4 = 128 bytes
// (presumably 32 LDS banks x 4 bytes). MNLdsLayer clamps this to at least 1.
159 constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeAB) / ABPackedSize;
160 constexpr auto MNLdsLayer = LdsSize < 1 ? 1 : LdsSize;
161 constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor(
163 Number<MNPerBlock / MNLdsLayer>{},
164 ABK1Number),
166
167 constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor(
168 ab_lds_block_desc,
175
176 constexpr auto ab_lds_block_desc_abk0_mnldslayer_mn_abk1 = transform_tensor_descriptor(
177 ab_lds_block_desc_permuted,
183
184 constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor(
185 ab_lds_block_desc_abk0_mnldslayer_mn_abk1,
192
193 return ab_lds_block_desc_abk0_mn_abk1;
194 }
195 else
196 {
197 // kfold and mpair dimension is not always required.
198 // more dimension in merge_transform increase the difficulty of generating immarg offset
199 // for compiler.
// Write side:
//   MN0          - thread-cluster length along MN (index 1 of the cluster lengths)
//   KThreadWrite - thread-cluster length along ABK0 (index 0 of the cluster lengths)
// Read side:
//   KThreadRead  - 64 / MNPerWmma (presumably 64 = wavefront size — confirm)
// NOTE(review): requires ABK0Number >= KThreadRead; otherwise K0PerThreadRead is 0 and the
// division in KThreadReadPerm below is ill-formed at compile time.
200 constexpr auto MN0 = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I1);
201 constexpr auto MN1 = MNPerBlock / MN0;
202
203 constexpr auto KThreadWrite = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I0);
204 constexpr auto K0PerThreadWrite = ABK0Number / KThreadWrite;
205 constexpr auto KThreadRead = 64 / MNPerWmma;
206 constexpr auto K0PerThreadRead = ABK0Number / KThreadRead;
207
// kfold: how many (ABK1 x MN0)-element slabs fit into 128 bytes; 1 if a single slab
// already exceeds 128 bytes.
// NOTE(review): unlike LdsSize above, ABPackedSize is not applied to sizeof(LDSTypeAB)
// here or in mpair — confirm this is intended for packed types.
208 constexpr auto kfold = (ABK1Number * MN0 * sizeof(LDSTypeAB) > 128)
209 ? 1
210 : 128 / (ABK1Number * MN0 * sizeof(LDSTypeAB));
211 constexpr auto KThreadReadPerm =
212 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
213 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
214 : KThreadRead;
215
216 // 1 <= mpair <= MN0: (ABK1 x MNPerWmma)-element slabs per 128 bytes, clamped to MN0
217 constexpr auto mpair = (ABK1Number * MNPerWmma * sizeof(LDSTypeAB) > 128)
218 ? 1
219 : ((128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB))) > MN0
220 ? MN0
221 : 128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB)));
222
223 constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor_packed(
227 Number<kfold * MN0 / mpair>{},
229 ABK1Number));
230
231 constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor(
232 ab_lds_block_desc,
237 make_tuple(Number<KThreadReadPerm * MN1>{}, Number<kfold * MN0 / mpair>{})),
244
245 constexpr auto ab_lds_block_desc_unmerged = transform_tensor_descriptor(
246 ab_lds_block_desc_permuted,
255 Sequence<1>{},
256 Sequence<2>{},
257 Sequence<3>{},
258 Sequence<4>{},
259 Sequence<5>{}),
261 Sequence<2>{},
264 Sequence<6>{},
265 Sequence<7>{}));
266
267 constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor(
268 ab_lds_block_desc_unmerged,
271 Number<KThreadWrite / kfold / KThreadReadPerm>{},
279
280 return ab_lds_block_desc_abk0_mn_abk1;
281 }
282 }
283
// Creates the blockwise copy object for this block's A/B tile.
//   Source:      grid_descriptor
//   Destination: block_descriptor (presumably the LDS descriptor from GetBlockDescriptor)
//   Applied op:  ab_element_op
// The source window starts at (0, block_mn_id * MNPerBlock, 0).
// The transfer path depends on the number of tensors in ABsDataType:
//   - more than one: a multi-source transfer on the whole grid_descriptor
//     (presumably ThreadGroupTensorSliceTransfer_v7r2, per the "v7r2" comment below);
//   - otherwise: a single-source transfer on grid_descriptor[I0] (presumably v4r1).
// NOTE(review): the transfer class-name lines and some template arguments are missing from
// this listing.
284 template <typename GridDescriptor,
285 typename BlockDescriptor,
286 typename ABsDataType,
287 typename ABElementwiseOperation,
288 index_t GlobalBufferNum>
289 __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
290 BlockDescriptor& block_descriptor,
291 ABElementwiseOperation& ab_element_op,
292 const index_t block_mn_id)
293 {
294 constexpr index_t NumABTensor = ABsDataType::Size();
// MN origin of this block in the grid; readfirstlane makes the value wave-uniform.
295 const index_t mn_block_data_idx_on_grid =
296 __builtin_amdgcn_readfirstlane(block_mn_id * MNPerBlock);
297 // workaround because v7r2 is not as general as v4r1
298 if constexpr(NumABTensor > 1)
299 {
// Each source tensor starts at the same (0, MN origin, 0) window.
300 const auto idx_as_block_begin = generate_tuple(
301 [&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); },
303
306 ABsDataType,
308 GridDescriptor,
309 decltype(tie(block_descriptor)),
310 ABElementwiseOperation,
313 ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
314 ABBlockTransferThreadClusterArrangeOrder,
315 ABBlockTransferSrcAccessOrder,
317 ABBlockTransferSrcVectorDim,
318 2,
319 ABBlockTransferSrcScalarPerVector,
320 ABBlockTransferDstScalarPerVector_ABK1,
323 GlobalBufferNum>{grid_descriptor,
324 idx_as_block_begin,
325 tie(block_descriptor),
327 ab_element_op};
328 }
329 else
330 {
331 // Single source tensor: copy from grid_descriptor[I0] into block_descriptor,
332 // starting at source (0, MN origin, 0) and destination (0, 0, 0).
333 ABElementwiseOperation,
337 ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
338 ABBlockTransferThreadClusterArrangeOrder,
341 decltype(grid_descriptor[I0]),
342 decltype(block_descriptor),
343 ABBlockTransferSrcAccessOrder,
345 ABBlockTransferSrcVectorDim,
346 2,
347 ABBlockTransferSrcScalarPerVector,
348 ABBlockTransferDstScalarPerVector_ABK1,
349 1,
350 1,
351 ABThreadTransferSrcResetCoordinateAfterRun,
352 true,
353 GlobalBufferNum>(grid_descriptor[I0],
354 make_multi_index(0, mn_block_data_idx_on_grid, 0),
355 ab_element_op,
356 block_descriptor,
357 make_multi_index(0, 0, 0),
359 }
360 }
361
// Re-describes the LDS tile returned by GetBlockDescriptor for the WMMA read side,
// mapping ABK0_MN_ABK1 -> ABK0_MNRepeat_MNWaves_KRow_MNPerWmma_ABK1.
// KRow is I2 when __gfx12__ is defined and I1 otherwise.
// NOTE(review): the transform arguments are missing from this listing.
// NOTE(review): this function is __host__ __device__, but __gfx12__ is presumably defined
// only in the device pass for a gfx12 target. A host-side instantiation would therefore
// always use KRow = I1 — confirm that host callers do not depend on KRow.
362 template <index_t MNRepeat, index_t MNWaves>
363 __host__ __device__ static constexpr auto MakeWmmaTileDescriptor()
364 {
365 // This is a block descriptor used to read LDS memory into register
366 // It's defined in a way consistent with the existing implementation to
367 // avoid changes in the pipelines
368 using BlockDesc = decltype(GetBlockDescriptor());
369 // ABK0_MN_ABK1 -> ABK0_MNRepeat_MNWaves_KRow_MNPerWmma_ABK1
370 constexpr auto ABK0 = BlockDesc{}.GetLength(I0);
371 constexpr auto ABK1 = BlockDesc{}.GetLength(I2);
372#ifdef __gfx12__
373 constexpr auto KRow = I2;
374#else
375 constexpr auto KRow = I1;
376#endif
378 BlockDesc{},
385 }
386
387 __device__ static constexpr auto GetBlockStep()
388 {
389 // Grid descriptor step (MoveSrcSliceWindow)
390 return make_multi_index(KPerBlock / ABK1Number, 0, 0);
391 }
392
393 template <typename GridDescriptor>
394 __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc)
395 {
396 // K dimension size. This should always be called with the A matrix grid descriptor
397 // because it doesn't work for B matrix when packed int4 is used
398 return grid_desc.GetLength(I0) * grid_desc.GetLength(I2);
399 }
400};
401
402} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition gridwise_ab_transfer_thread_tiles.hpp:30
__host__ static __device__ constexpr auto MakeWmmaTileDescriptor()
Definition gridwise_ab_transfer_thread_tiles.hpp:363
static __device__ constexpr index_t GetKDimension(const GridDescriptor &grid_desc)
Definition gridwise_ab_transfer_thread_tiles.hpp:394
static constexpr auto I1
Definition gridwise_ab_transfer_thread_tiles.hpp:35
static constexpr auto I0
Definition gridwise_ab_transfer_thread_tiles.hpp:34
static __device__ constexpr auto GetBlockStep()
Definition gridwise_ab_transfer_thread_tiles.hpp:387
static __device__ auto GetBlockTransfer(GridDescriptor &grid_descriptor, BlockDescriptor &block_descriptor, ABElementwiseOperation &ab_element_op, const index_t block_mn_id)
Definition gridwise_ab_transfer_thread_tiles.hpp:289
static constexpr auto ABK1Number
Definition gridwise_ab_transfer_thread_tiles.hpp:32
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_ab_transfer_thread_tiles.hpp:45
static __device__ constexpr auto GetBlockDescriptor()
Definition gridwise_ab_transfer_thread_tiles.hpp:144
static constexpr index_t ABPackedSize
Definition gridwise_ab_transfer_thread_tiles.hpp:38
__host__ static __device__ auto MakeGridDescriptor(const GridDescriptorBase &ab_grid_desc, index_t MN, index_t MNPad, index_t K, index_t KPad, index_t StrideAB, index_t ABK0)
Definition gridwise_ab_transfer_thread_tiles.hpp:48
static constexpr auto I2
Definition gridwise_ab_transfer_thread_tiles.hpp:36
static constexpr auto ABK0Number
Definition gridwise_ab_transfer_thread_tiles.hpp:31
Definition utility/sequence.hpp:43
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7r2.hpp:47
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition data_type.hpp:187
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340