blockwise_gemm_dpp.hpp Source File

blockwise_gemm_dpp.hpp Source File

Composable Kernel: blockwise_gemm_dpp.hpp Source File
blockwise_gemm_dpp.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
22template <index_t BlockSize,
23 typename ABDataType,
24 typename AccDataType,
25 typename AK0MK1BlockDesc,
26 typename BK0NK1BlockDesc,
27 index_t MPerDpp,
28 index_t NPerDpp,
29 index_t MRepeat,
30 index_t NRepeat,
31 index_t KPack>
33{
34 static constexpr auto I0 = Number<0>{};
35 static constexpr auto I1 = Number<1>{};
36 static constexpr auto I2 = Number<2>{};
37 static constexpr auto I3 = Number<3>{};
38
40
41 static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
42 static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
43 static constexpr index_t KPerBlock =
44 BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
45
46 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp);
47 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp);
48 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
49
50 static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
51 static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
52 static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
53 static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
54
56
57 static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp;
58
60 AccDataType,
61 MRepeat * NRepeat,
62 dpp_gemm.GetRegSizePerDpp(),
63 true>
65
// Mutable accessor for the c_thread_buf_ member: a static per-thread buffer of
// MRepeat * NRepeat vectors, each holding dpp_gemm.GetRegSizePerDpp() AccDataType values.
// NOTE(review): presumably callers pass this buffer as Run()'s c_thread_buf argument — confirm at call sites.
66 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
67
68 __device__ static auto GetWaveIdx()
69 {
70 const index_t thread_id = ThisThreadBlock::GetThreadId();
71
72 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
76
77 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
78 }
79
81 {
82 const auto wave_idx = GetWaveIdx();
83 const auto waveId_m = wave_idx[I0];
84 const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M();
85 const auto dpp_a_idx_k = dpp_a_idx[I0];
86 const auto dpp_a_idx_m = dpp_a_idx[I1];
87 return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k);
88 }
89
91 {
92 const auto wave_idx = GetWaveIdx();
93 const auto waveId_n = wave_idx[I1];
94 const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N();
95 const auto dpp_b_idx_k = dpp_b_idx[I0];
96 const auto dpp_b_idx_n = dpp_b_idx[I1];
97 return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k);
98 }
99
100 template <index_t m0, index_t n0>
102 {
103 const auto wave_idx = GetWaveIdx();
104 const auto waveId_m = wave_idx[I0];
105 const auto waveId_n = wave_idx[I1];
106
107 const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk();
108 const auto blk_m_offset = blk_idx[I0];
109 const auto blk_n_offset = blk_idx[I1];
110
111 constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor(
115
116 constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor(
120
121 const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex(
122 make_tuple(m0, waveId_m, blk_m_offset))[I0];
123 const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex(
124 make_tuple(n0, waveId_n, blk_n_offset))[I0];
125
126 return make_tuple(c_thread_m, c_thread_n);
127 }
128
130 {
131 static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
132 BK0NK1BlockDesc::IsKnownAtCompileTime(),
133 "Wrong! Block descriptors should be known at the time of compilation.");
134
135#if defined(__HIP_DEVICE_COMPILE__)
136 // The host wave size can differ from the device one, so this assert could fail during the
137 // host compilation pass; it only matters for device code, hence the __HIP_DEVICE_COMPILE__ guard.
139 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
140#endif
141
142 static_assert(MPerBlock % (MPerDpp * MRepeat) == 0,
143 "Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat.");
144 static_assert(NPerBlock % (NPerDpp * NRepeat) == 0,
145 "Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat.");
146 }
147
148 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2()
149 {
150 constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
151 constexpr auto M = c_m_n_tblk_lens[I0];
152 constexpr auto N = c_m_n_tblk_lens[I1];
153
156 }
157
158 __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2()
159 {
160 constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths();
161 constexpr auto M = c_m_n_tblk_lens[I0];
162 constexpr auto N = c_m_n_tblk_lens[I1];
163
166 }
167
168 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2()
169 {
170 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
176 Number<NPerDpp>{}));
177
178 return c_block_desc_m0_n0_m1_n1_m2_n2;
179 }
180
181 __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2()
182 {
183 constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
190 Number<NPerDpp>{}));
191 return c_block_desc_g_m0_n0_m1_n1_m2_n2;
192 }
193
194 template <typename CGridDesc_M_N>
195 __host__ __device__ static constexpr auto
196 MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n)
197 {
198 const auto M = c_grid_desc_m_n.GetLength(I0);
199 const auto N = c_grid_desc_m_n.GetLength(I1);
200
201 const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
202 c_grid_desc_m_n,
203 make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
204 make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
207
208 return c_grid_desc_m0_n0_m1_n1_m2_n2;
209 }
210
211 template <typename CGridDesc_G_M_N>
212 __host__ __device__ static constexpr auto
213 MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
214 {
215 const auto G = c_grid_desc_g_m_n.GetLength(I0);
216 const auto M = c_grid_desc_g_m_n.GetLength(I1);
217 const auto N = c_grid_desc_g_m_n.GetLength(I2);
218
219 const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
220 c_grid_desc_g_m_n,
222 make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)),
223 make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))),
226
227 return c_grid_desc_g_m0_n0_m1_n1_m2_n2;
228 }
229
241
253
256
257 template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
258 __device__ void Run(const ABlockBuffer& a_block_buf,
259 const BBlockBuffer& b_block_buf,
260 CThreadBuffer& c_thread_buf) const
261 {
263 a_thread_desc_.GetElementSpaceSize());
265 b_thread_desc_.GetElementSpaceSize());
266
267 static_for<0, MRepeat, 1>{}([&](auto m0) {
268 // read A
270 make_tuple(m0, I0, I0, I0),
271 a_block_buf,
273 make_tuple(I0, I0, I0, I0),
274 a_thread_buf);
275
276 static_for<0, NRepeat, 1>{}([&](auto n0) {
277 // read B
279 make_tuple(n0, I0, I0, I0),
280 b_block_buf,
282 make_tuple(I0, I0, I0, I0),
283 b_thread_buf);
284
288
289 static_for<0, KPack, 1>{}([&](auto i) {
290 a_thread_vec.template AsType<ABDataType>()(i) = a_thread_buf
291 [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
292 b_thread_vec.template AsType<ABDataType>()(i) = b_thread_buf
293 [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
294 });
295
296 using dpp_input_type =
297 typename vector_type<ABDataType, dpp_gemm.K1PerDpp>::type;
298
299 constexpr index_t c_offset =
300 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
301
302 dpp_gemm.Run(a_thread_vec.template AsType<dpp_input_type>(),
303 b_thread_vec.template AsType<dpp_input_type>(),
304 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
305 });
306 });
307 });
308 }
309
310 protected:
311 // A[M0, M1, M2, KPerThread]
312 static constexpr auto a_thread_desc_ =
314
315 // B[N0, N1, N2, KPerThread]
316 static constexpr auto b_thread_desc_ =
318
319 // C[M, N, NumRegDpp]
321 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, dpp_gemm.GetRegSizePerDpp()));
322
324 ABDataType,
325 decltype(a_block_desc_m0_m1_m2_k),
326 decltype(a_thread_desc_),
329 3,
330 A_K1,
331 A_K1>;
332
334 ABDataType,
335 decltype(b_block_desc_n0_n1_n2_k),
336 decltype(b_thread_desc_),
339 3,
340 B_K1,
341 B_K1>;
342
345};
346
347} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
static constexpr index_t KPerBlock
Definition blockwise_gemm_dpp.hpp:43
static constexpr index_t NWaves
Definition blockwise_gemm_dpp.hpp:47
__host__ static __device__ constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
Definition blockwise_gemm_dpp.hpp:242
static constexpr index_t B_K1
Definition blockwise_gemm_dpp.hpp:53
BThreadCopy b_thread_copy_
Definition blockwise_gemm_dpp.hpp:344
static constexpr index_t A_K0
Definition blockwise_gemm_dpp.hpp:50
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_dpp.hpp:213
static __device__ auto CalculateAThreadOriginDataIndex_M0_M1_M2_K()
Definition blockwise_gemm_dpp.hpp:80
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_dpp.hpp:39
static constexpr auto c_thread_desc_
Definition blockwise_gemm_dpp.hpp:320
static constexpr auto I2
Definition blockwise_gemm_dpp.hpp:36
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2()
Definition blockwise_gemm_dpp.hpp:181
static constexpr auto I3
Definition blockwise_gemm_dpp.hpp:37
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_dpp.hpp:68
static constexpr auto I1
Definition blockwise_gemm_dpp.hpp:35
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2()
Definition blockwise_gemm_dpp.hpp:168
static constexpr index_t MPerBlock
Definition blockwise_gemm_dpp.hpp:41
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >)
Definition blockwise_gemm_dpp.hpp:101
static constexpr index_t WaveSize
Definition blockwise_gemm_dpp.hpp:48
__host__ static __device__ constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
Definition blockwise_gemm_dpp.hpp:230
static constexpr auto I0
Definition blockwise_gemm_dpp.hpp:34
static constexpr auto b_thread_desc_
Definition blockwise_gemm_dpp.hpp:316
static constexpr auto b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_dpp.hpp:255
static constexpr index_t NPerBlock
Definition blockwise_gemm_dpp.hpp:42
ThreadwiseTensorSliceTransfer_v4< ABDataType, ABDataType, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPerThread >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_dpp.hpp:323
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, AccDataType, MRepeat *NRepeat, dpp_gemm.GetRegSizePerDpp(), true > c_thread_buf_
Definition blockwise_gemm_dpp.hpp:64
static constexpr index_t MWaves
Definition blockwise_gemm_dpp.hpp:46
static constexpr index_t KPerThread
Definition blockwise_gemm_dpp.hpp:57
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_dpp.hpp:196
static __device__ auto CalculateBThreadOriginDataIndex_N0_N1_N2_K()
Definition blockwise_gemm_dpp.hpp:90
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_dpp.hpp:258
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2()
Definition blockwise_gemm_dpp.hpp:148
static constexpr auto a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_dpp.hpp:254
ThreadwiseTensorSliceTransfer_v4< ABDataType, ABDataType, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPerThread >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_dpp.hpp:333
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_dpp.hpp:66
static constexpr auto dpp_gemm
Definition blockwise_gemm_dpp.hpp:55
static constexpr auto a_thread_desc_
Definition blockwise_gemm_dpp.hpp:312
static constexpr index_t B_K0
Definition blockwise_gemm_dpp.hpp:51
__host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2()
Definition blockwise_gemm_dpp.hpp:129
static constexpr index_t A_K1
Definition blockwise_gemm_dpp.hpp:52
AThreadCopy a_thread_copy_
Definition blockwise_gemm_dpp.hpp:343
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2()
Definition blockwise_gemm_dpp.hpp:158
Definition dpp_gemm.hpp:426
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
static __device__ constexpr index_t GetNumOfThread()
Definition thread_group.hpp:15
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10