gridwise_gemm_dpp.hpp Source File

gridwise_gemm_dpp.hpp Source File#

Composable Kernel: gridwise_gemm_dpp.hpp Source File
gridwise_gemm_dpp.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
18
19namespace ck {
20
21template <typename GridwiseGemm, bool HasMainKBlockLoop>
22__global__ void
23#if CK_USE_LAUNCH_BOUNDS
25#endif
26#if CK_USE_WAVES_PER_EU
27 __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
28#endif
29 kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
30{
31#if(defined(__gfx103__) || defined(__gfx11__))
32 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
33
34 const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
35 GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA));
36 const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane(
37 GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB));
38 const auto c_grid_desc_m_n = amd_wave_read_first_lane(
39 GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC));
40
41 GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
42 karg.p_b_grid,
43 karg.p_c_grid,
44 p_shared,
45 a_grid_desc_ak0_m_ak1,
46 b_grid_desc_bk0_n_bk1,
47 c_grid_desc_m_n);
48#else
49 ignore = karg;
50#endif
51}
52
53template <index_t BlockSize,
54 typename ABDataType,
55 typename AccDataType,
56 typename CDataType,
57 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
58 typename ALayout,
59 typename BLayout,
60 typename CLayout,
61 typename AElementwiseOperation,
62 typename BElementwiseOperation,
63 typename CElementwiseOperation,
65 index_t MPerBlock,
66 index_t NPerBlock,
67 index_t KPerBlock,
68 index_t MPerDpp,
69 index_t NPerDpp,
70 index_t AK1Value,
71 index_t BK1Value,
72 index_t MDppPerWave,
73 index_t NDppPerWave,
74 typename ABlockTransferThreadClusterLengths_K0_M_K1,
75 typename ABlockTransferThreadClusterArrangeOrder,
76 typename ABlockTransferSrcAccessOrder,
77 index_t ABlockTransferSrcVectorDim,
78 index_t ABlockTransferSrcScalarPerVector,
79 index_t ABlockTransferDstScalarPerVector_K1,
80 bool AThreadTransferSrcResetCoordinateAfterRun,
81 bool ABlockLdsExtraM,
82 typename BBlockTransferThreadClusterLengths_K0_N_K1,
83 typename BBlockTransferThreadClusterArrangeOrder,
84 typename BBlockTransferSrcAccessOrder,
85 index_t BBlockTransferSrcVectorDim,
86 index_t BBlockTransferSrcScalarPerVector,
87 index_t BBlockTransferDstScalarPerVector_K1,
88 bool BThreadTransferSrcResetCoordinateAfterRun,
89 bool BBlockLdsExtraN,
90 typename CThreadTransferSrcDstAccessOrder,
91 index_t CThreadTransferSrcDstVectorDim,
92 index_t CThreadTransferDstScalarPerVector,
93 index_t NumGemmKPrefetchStage = 1,
96{
97 static constexpr auto I0 = Number<0>{};
98 static constexpr auto I1 = Number<1>{};
99 static constexpr auto I2 = Number<2>{};
100 static constexpr auto I3 = Number<3>{};
101 static constexpr auto I4 = Number<4>{};
102 static constexpr auto I5 = Number<5>{};
103
104 static constexpr auto AK1 = Number<AK1Value>{};
105 static constexpr auto BK1 = Number<BK1Value>{};
106 static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
107 static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
108
109 static constexpr auto max_lds_align = math::lcm(AK1, BK1);
110
112 // return block_id to C matrix tile idx (m0, n0) mapping
114
115 __host__ static auto CalculateGridSize(index_t M, index_t N)
116 {
117 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
118 }
119
120 __host__ static auto CalculateMPadded(index_t M)
121 {
122 return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
123 }
124
125 __host__ static auto CalculateNPadded(index_t N)
126 {
127 return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
128 }
129
130 __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); }
131 __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); }
132
133 // Problem
134 struct Problem
135 {
136 __host__ Problem(index_t M_,
137 index_t N_,
138 index_t K_,
139 index_t StrideA_,
140 index_t StrideB_,
141 index_t StrideC_)
142 : M{M_},
143 N{N_},
144 K{K_},
145 StrideA{StrideA_},
146 StrideB{StrideB_},
147 StrideC{StrideC_},
152 {
153 }
154
155 __host__ void Print() const
156 {
157 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
158 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
159 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
160 << "AK0:" << AK0 << ", " << "BK0:" << BK0 << "}" << std::endl;
161 }
162
173 };
174
175 // Argument
177 {
178 __host__ Argument(const ABDataType* p_a_grid_,
179 const ABDataType* p_b_grid_,
180 CDataType* p_c_grid_,
181 index_t M_,
182 index_t N_,
183 index_t K_,
184 index_t StrideA_,
185 index_t StrideB_,
186 index_t StrideC_)
187 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
188 p_a_grid{p_a_grid_},
189 p_b_grid{p_b_grid_},
190 p_c_grid{p_c_grid_}
191 {
192 }
193
194 const ABDataType* p_a_grid;
195 const ABDataType* p_b_grid;
196 CDataType* p_c_grid;
197 };
198
201
202 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
203 {
204 // A matrix in LDS memory, dst of blockwise copy
205 constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
206 if constexpr(ABlockLdsExtraM)
207 {
211 }
212 else
213 {
216 }
217 }();
218
219 return a_block_desc_ak0_m_ak1;
220 }
221
222 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
223 {
224 // B matrix in LDS memory, dst of blockwise copy
225 constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
226 if constexpr(BBlockLdsExtraN)
227 {
231 }
232 else
233 {
236 }
237 }();
238
239 return b_block_desc_bk0_n_bk1;
240 }
241
242 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
243 {
244 // LDS allocation for A and B: be careful of alignment
245 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
246 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
247
248 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
249 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
250 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
251 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
252
253 return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);
254 }
255
256 __host__ static constexpr bool CheckValidity(const Problem& problem)
257 {
258 static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value,
259 "Wrong! AK1 must be known at the time of compilation.");
260 static_assert(is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
261 "Wrong! BK1 must be known at the time of compilation.");
262
263 static_assert(
264 MPerBlock % (MPerDpp * MDppPerWave) == 0,
265 "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave.");
266 static_assert(
267 NPerBlock % (NPerDpp * NDppPerWave) == 0,
268 "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave.");
269
270 static_assert(
271 KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
272 "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1.");
273
274 static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0,
275 "Invalid tuning parameters! AK1Value must be divisible by "
276 "ABlockTransferDstScalarPerVector_K1");
277
278 static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0,
279 "Invalid tuning parameters! BK1Value must be divisible by "
280 "BBlockTransferDstScalarPerVector_K1");
281
286 {
287 if(!(problem.M % MPerBlock == 0))
288 {
289 return false;
290 }
291 }
292
297 {
298 if(!(problem.N % NPerBlock == 0))
299 {
300 return false;
301 }
302 }
303
305 {
306 if(problem.K % ABlockTransferSrcScalarPerVector != 0)
307 {
308 return false;
309 }
310 }
311 else
312 {
313 if(problem.M % ABlockTransferSrcScalarPerVector != 0)
314 {
315 return false;
316 }
317 }
318
320 {
321 if(problem.N % BBlockTransferSrcScalarPerVector != 0)
322 {
323 return false;
324 }
325 }
326 else
327 {
328 if(problem.K % BBlockTransferSrcScalarPerVector != 0)
329 {
330 return false;
331 }
332 }
333
334 if(problem.K % KPerBlock != 0)
335 {
336 return false;
337 }
338
339 // check gridwise gemm pipeline
340 const auto num_k_loop = problem.K / KPerBlock;
341 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
342 {
343 return false;
344 }
345
346 return true;
347 }
348
349 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
350 {
351 const auto num_loop = K / KPerBlock;
352
353 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
354 }
355
356 template <typename CGridDesc>
357 __host__ __device__ static constexpr auto
358 MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n)
359 {
360 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
361 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
362
363 constexpr index_t KPack = math::max(
365
366 using BlockwiseGemm =
368 ABDataType,
369 AccDataType,
370 decltype(a_block_desc_ak0_m_ak1),
371 decltype(b_block_desc_bk0_n_bk1),
372 MPerDpp,
373 NPerDpp,
374 MDppPerWave,
375 NDppPerWave,
376 KPack>;
377
378 return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
379 }
380
381 static constexpr auto matrix_padder =
383 MPerBlock, NPerBlock, KPerBlock};
384
385 __device__ static auto
387 {
388 const auto a_grid_desc_mraw_kraw = [&]() {
390 {
391 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
392 }
394 {
395 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
396 }
397 }();
398
399 const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
401 a_grid_desc_m_k,
406 }
407
408 __device__ static auto
410 {
411 const auto b_grid_desc_nraw_kraw = [&]() {
413 {
414 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
415 }
417 {
418 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
419 }
420 }();
421
422 const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
424 b_grid_desc_n_k,
426 make_unmerge_transform(make_tuple(BK0, BK1Value))),
429 }
430
431 __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
432 {
433 const auto c_grid_desc_mraw_nraw = [&]() {
435 {
436 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
437 }
439 {
440 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
441 }
442 }();
443
444 return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
445 }
446
447 template <bool HasMainKBlockLoop,
448 typename AGridDesc_AK0_M_AK1,
449 typename BGridDesc_BK0_N_BK1,
450 typename CGridDesc_M_N>
451 __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
452 const ABDataType* __restrict__ p_b_grid,
453 CDataType* __restrict__ p_c_grid,
454 void* __restrict__ p_shared,
455 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
456 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
457 const CGridDesc_M_N& c_grid_desc_m_n)
458 {
459 const auto c_grid_desc_m0_n0_m1_n1_m2_n2 =
461
462 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
463 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
464 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
465 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
467 p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize());
468
469 const AElementwiseOperation a_element_op{};
470 const BElementwiseOperation b_element_op{};
471 const CElementwiseOperation c_element_op{};
472
473 const auto block_2_ctile_map =
474 Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};
475
476 // divide block work by [M, N]
477 const auto block_work_idx =
478 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
479
480 if(!block_2_ctile_map.ValidCTileIndex(
481 block_work_idx,
482 make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0),
483 c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1))))
484 {
485 return;
486 }
487
488 // HACK: this force m/n_block_data_idx_on_grid into SGPR
489 const index_t m_block_data_idx_on_grid =
490 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
491 const index_t n_block_data_idx_on_grid =
492 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
493
494 // A matrix in LDS memory, dst of blockwise copy
495 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
496 // B matrix in LDS memory, dst of blockwise copy
497 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
498
499 auto a_blockwise_copy =
501 AElementwiseOperation,
505 ABlockTransferThreadClusterLengths_K0_M_K1,
506 ABlockTransferThreadClusterArrangeOrder,
507 ABDataType,
508 ABDataType,
509 decltype(a_grid_desc_ak0_m_ak1),
510 decltype(a_block_desc_ak0_m_ak1),
511 ABlockTransferSrcAccessOrder,
513 ABlockTransferSrcVectorDim,
514 2,
515 ABlockTransferSrcScalarPerVector,
516 ABlockTransferDstScalarPerVector_K1,
517 1,
518 1,
519 AThreadTransferSrcResetCoordinateAfterRun,
520 true,
521 NumGemmKPrefetchStage>(
522 a_grid_desc_ak0_m_ak1,
523 make_multi_index(0, m_block_data_idx_on_grid, 0),
524 a_element_op,
525 a_block_desc_ak0_m_ak1,
526 make_multi_index(0, 0, 0),
528
529 auto b_blockwise_copy =
531 BElementwiseOperation,
535 BBlockTransferThreadClusterLengths_K0_N_K1,
536 BBlockTransferThreadClusterArrangeOrder,
537 ABDataType,
538 ABDataType,
539 decltype(b_grid_desc_bk0_n_bk1),
540 decltype(b_block_desc_bk0_n_bk1),
541 BBlockTransferSrcAccessOrder,
543 BBlockTransferSrcVectorDim,
544 2,
545 BBlockTransferSrcScalarPerVector,
546 BBlockTransferDstScalarPerVector_K1,
547 1,
548 1,
549 BThreadTransferSrcResetCoordinateAfterRun,
550 true,
551 NumGemmKPrefetchStage>(
552 b_grid_desc_bk0_n_bk1,
553 make_multi_index(0, n_block_data_idx_on_grid, 0),
554 b_element_op,
555 b_block_desc_bk0_n_bk1,
556 make_multi_index(0, 0, 0),
558
559 // GEMM definition
560 // c_mtx += transpose(a_mtx) * b_mtx
561 // a_mtx[AK0PerBlock, MPerBlock] is in LDS
562 // b_mtx[BK0PerBlock, NPerBlock] is in LDS
563 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
564 // register
565 constexpr index_t KPack = math::max(
567 auto blockwise_gemm =
569 ABDataType,
570 AccDataType,
571 decltype(a_block_desc_ak0_m_ak1),
572 decltype(b_block_desc_bk0_n_bk1),
573 MPerDpp,
574 NPerDpp,
575 MDppPerWave,
576 NDppPerWave,
577 KPack>();
578
579 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
580
581 // LDS allocation for A and B: be careful of alignment
582 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
583 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
584
586 static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
587
589 static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
590 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
591
592 constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0);
593 constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0);
594
595 // gridwise GEMM pipeline
596 const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0);
597 // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock)
598 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock);
599
600 GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
601 a_block_desc_ak0_m_ak1,
602 a_blockwise_copy,
603 a_grid_buf,
604 a_block_buf,
605 a_block_slice_copy_step,
606 b_grid_desc_bk0_n_bk1,
607 b_block_desc_bk0_n_bk1,
608 b_blockwise_copy,
609 b_grid_buf,
610 b_block_buf,
611 b_block_slice_copy_step,
612 blockwise_gemm,
613 c_thread_buf,
614 num_k_block_main_loop);
615
616 // output: register to global memory
617 {
618 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 =
619 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2();
620
621 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
622 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2();
623
624 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
625 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
626 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
627 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
628 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
629 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
630
631 constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
632 constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
633
634 // calculate origin of thread output tensor on global memory
635 // blockwise GEMM c matrix starting index
636 const auto c_thread_mtx_on_block =
637 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
638
639 const index_t m_thread_data_on_grid =
640 m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
641
642 const index_t n_thread_data_on_grid =
643 n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
644
645 const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor(
649
650 const auto m_thread_data_on_grid_idx =
651 m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex(
652 make_multi_index(m_thread_data_on_grid));
653
654 const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
658
659 const auto n_thread_data_on_grid_idx =
660 n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
661 make_multi_index(n_thread_data_on_grid));
662
663 auto c_thread_copy =
665 CDataType,
666 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2),
667 decltype(c_grid_desc_m0_n0_m1_n1_m2_n2),
668 CElementwiseOperation,
670 CThreadTransferSrcDstAccessOrder,
671 CThreadTransferSrcDstVectorDim,
672 CThreadTransferDstScalarPerVector,
673 CGlobalMemoryDataOperation,
674 1,
675 true>{
676 c_grid_desc_m0_n0_m1_n1_m2_n2,
677 make_multi_index(m_thread_data_on_grid_idx[I0],
678 n_thread_data_on_grid_idx[I0],
679 m_thread_data_on_grid_idx[I1],
680 n_thread_data_on_grid_idx[I1],
681 m_thread_data_on_grid_idx[I2],
682 n_thread_data_on_grid_idx[I2]),
683 c_element_op};
684
685 c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2,
686 make_tuple(I0, I0, I0, I0, I0, I0),
687 c_thread_buf,
688 c_grid_desc_m0_n0_m1_n1_m2_n2,
689 c_grid_buf);
690 }
691 }
692};
693
694} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
__global__ void kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_dpp.hpp:29
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition block_to_ctile_map.hpp:261
Definition blockwise_gemm_dpp.hpp:33
static constexpr auto selected_dpp
Definition dpp_gemm.hpp:380
const ABDataType * p_a_grid
Definition gridwise_gemm_dpp.hpp:194
const ABDataType * p_b_grid
Definition gridwise_gemm_dpp.hpp:195
CDataType * p_c_grid
Definition gridwise_gemm_dpp.hpp:196
__host__ Argument(const ABDataType *p_a_grid_, const ABDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition gridwise_gemm_dpp.hpp:178
index_t NPadded
Definition gridwise_gemm_dpp.hpp:170
index_t BK0
Definition gridwise_gemm_dpp.hpp:172
index_t StrideB
Definition gridwise_gemm_dpp.hpp:167
index_t N
Definition gridwise_gemm_dpp.hpp:164
index_t K
Definition gridwise_gemm_dpp.hpp:165
index_t StrideC
Definition gridwise_gemm_dpp.hpp:168
index_t M
Definition gridwise_gemm_dpp.hpp:163
index_t AK0
Definition gridwise_gemm_dpp.hpp:171
index_t MPadded
Definition gridwise_gemm_dpp.hpp:169
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition gridwise_gemm_dpp.hpp:136
__host__ void Print() const
Definition gridwise_gemm_dpp.hpp:155
index_t StrideA
Definition gridwise_gemm_dpp.hpp:166
Definition gridwise_gemm_dpp.hpp:96
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_dpp.hpp:349
static __host__ auto CalculateAK0(index_t K)
Definition gridwise_gemm_dpp.hpp:130
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc &c_grid_desc_m_n)
Definition gridwise_gemm_dpp.hpp:358
static __device__ void Run(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dpp.hpp:451
static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
Definition gridwise_gemm_dpp.hpp:431
static __host__ auto CalculateBK0(index_t K)
Definition gridwise_gemm_dpp.hpp:131
__host__ static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_dpp.hpp:202
static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB)
Definition gridwise_gemm_dpp.hpp:409
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_dpp.hpp:242
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition gridwise_gemm_dpp.hpp:115
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_dpp.hpp:120
static __host__ constexpr bool CheckValidity(const Problem &problem)
Definition gridwise_gemm_dpp.hpp:256
__host__ static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_dpp.hpp:222
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_dpp.hpp:125
static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA)
Definition gridwise_gemm_dpp.hpp:386
Definition utility/sequence.hpp:43
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition threadwise_tensor_slice_transfer.hpp:39
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition is_known_at_compile_time.hpp:14
Definition device_base.hpp:197
Definition matrix_padder.hpp:180
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340