blockwise_gemm_pipeline_xdlops_v2.hpp Source File

blockwise_gemm_pipeline_xdlops_v2.hpp Source File#

Composable Kernel: blockwise_gemm_pipeline_xdlops_v2.hpp Source File
blockwise_gemm_pipeline_xdlops_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Maximum global memory throughput pipeline with >=32KB of data in flight
11// GlobalPrefetchStages: >=2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 0
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::I0;
122 using Base::KRepeat;
123 using Base::xdlops_gemm;
124
136
139
140 using Base::AMmaKStride;
141 using Base::BMmaKStride;
142 using Base::WaveSize;
143
145
146 static constexpr index_t WgpPerCU =
147 (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
149 32768 / WgpPerCU,
150 (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
151 static constexpr index_t PrefetchStages =
154 : 2;
155
156 static constexpr index_t PrefillStages = 1;
158
159 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
160 {
161 return num_loop > PrefetchStages;
162 }
163
164 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
165 {
166 if(num_loop % PrefetchStages == 1)
167 {
168 return TailNumber::One;
169 }
170 else if(num_loop % PrefetchStages == 2)
171 {
172 return TailNumber::Two;
173 }
174 else if(num_loop % PrefetchStages == 3)
175 {
176 return TailNumber::Three;
177 }
178 else if(num_loop % PrefetchStages == 4)
179 {
180 return TailNumber::Four;
181 }
182 else if(num_loop % PrefetchStages == 5)
183 {
184 return TailNumber::Five;
185 }
186 else if(num_loop % PrefetchStages == 6)
187 {
188 return TailNumber::Six;
189 }
190 else if(num_loop % PrefetchStages == 7)
191 {
192 return TailNumber::Seven;
193 }
194 else
195 {
196 return TailNumber::Full;
197 }
198 }
199
200 template <bool HasMainLoop,
201 TailNumber TailNum,
202 typename AGridDesc,
203 typename ABlockDesc,
204 typename ABlockTransfer,
205 typename AGridBuffer,
206 typename ABlockBuffer,
207 typename ABlockTransferStep,
208 typename BGridDesc,
209 typename BBlockDesc,
210 typename BBlockTransfer,
211 typename BGridBuffer,
212 typename BBlockBuffer,
213 typename BBlockTransferStep,
214 typename CThreadBuffer>
215 __device__ void Run(const AGridDesc& a_grid_desc,
216 const ABlockDesc& a_block_desc,
217 ABlockTransfer& a_blockwise_copy,
218 const AGridBuffer& a_grid_buf,
219 ABlockBuffer& a_block_buf,
220 const ABlockTransferStep& a_block_copy_step,
221 const BGridDesc& b_grid_desc,
222 const BBlockDesc& b_block_desc,
223 BBlockTransfer& b_blockwise_copy,
224 const BGridBuffer& b_grid_buf,
225 BBlockBuffer& b_block_buf,
226 const BBlockTransferStep& b_block_copy_step,
227 CThreadBuffer& c_thread_buf,
228 index_t num_loop) const
229 {
231 a_thread_desc_.GetElementSpaceSize());
233 b_thread_desc_.GetElementSpaceSize());
234
235 // Global prefetch 1
236 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
237 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
238
239 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
240 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
241
242 // Initialize C
243 c_thread_buf.Clear();
244
245 // Local prefill 1
246 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
247 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
248
249 // Global prefetch [2, PrefetchStages]
250 static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
251 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
252 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
253
254 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
255 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
256 });
257
258 // main body
259 if constexpr(HasMainLoop)
260 {
261 index_t i = 0;
262 do
263 {
264 static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
265 // -------------------------------------------------------------------------------------------
267 static_for<0, KRepeat, 1>{}([&](auto k) {
268 static_for<0, MRepeat, 1>{}([&](auto m0) {
271 a_block_buf,
273 make_tuple(m0, I0, k, I0),
274 a_thread_buf);
275 });
276 static_for<0, NRepeat, 1>{}([&](auto n0) {
279 b_block_buf,
281 make_tuple(n0, I0, k, I0),
282 b_thread_buf);
283 });
284 });
285
286 static_for<0, KRepeat, 1>{}([&](auto k0) {
287 static_for<0, MRepeat, 1>{}([&](auto m0) {
288 static_for<0, NRepeat, 1>{}([&](auto n0) {
291
292 static_for<0, KPack, 1>{}([&](auto ik) {
293 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
294 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
295 make_tuple(m0, I0, k0, ik))>{}];
296 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
297 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
298 make_tuple(n0, I0, k0, ik))>{}];
299 });
300
301 using mfma_input_type =
303 xdlops_gemm.K1PerXdlops>::type;
304
305 constexpr index_t c_offset =
306 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
307
308 xdlops_gemm.Run(
309 a_thread_vec.template AsType<mfma_input_type>(),
310 b_thread_vec.template AsType<mfma_input_type>(),
311 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
312 });
313 });
314 });
315
317 a_blockwise_copy.RunWrite(
318 a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
319 b_blockwise_copy.RunWrite(
320 b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
321
322 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
323 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
324
325 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
326 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
327 });
328
329 i += PrefetchStages;
330 } while(i < (num_loop - PrefetchStages));
331 }
332
333 // tail
334
335 auto LoopTailFunc = [&](auto tail_num) {
336 static_for<1, tail_num, 1>{}([&](auto iprefetch) {
338 static_for<0, KRepeat, 1>{}([&](auto k) {
339 static_for<0, MRepeat, 1>{}([&](auto m0) {
342 a_block_buf,
344 make_tuple(m0, I0, k, I0),
345 a_thread_buf);
346 });
347 static_for<0, NRepeat, 1>{}([&](auto n0) {
350 b_block_buf,
352 make_tuple(n0, I0, k, I0),
353 b_thread_buf);
354 });
355 });
356
357 static_for<0, KRepeat, 1>{}([&](auto k0) {
358 static_for<0, MRepeat, 1>{}([&](auto m0) {
359 static_for<0, NRepeat, 1>{}([&](auto n0) {
362
363 static_for<0, KPack, 1>{}([&](auto ik) {
364 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
365 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
366 make_tuple(m0, I0, k0, ik))>{}];
367 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
368 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
369 make_tuple(n0, I0, k0, ik))>{}];
370 });
371
372 using mfma_input_type =
374 xdlops_gemm.K1PerXdlops>::type;
375
376 constexpr index_t c_offset =
377 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
378
379 xdlops_gemm.Run(
380 a_thread_vec.template AsType<mfma_input_type>(),
381 b_thread_vec.template AsType<mfma_input_type>(),
382 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
383 });
384 });
385 });
386
388 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
389 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
390 });
391
393 static_for<0, KRepeat, 1>{}([&](auto k) {
394 static_for<0, MRepeat, 1>{}([&](auto m0) {
397 a_block_buf,
399 make_tuple(m0, I0, k, I0),
400 a_thread_buf);
401 });
402 static_for<0, NRepeat, 1>{}([&](auto n0) {
405 b_block_buf,
407 make_tuple(n0, I0, k, I0),
408 b_thread_buf);
409 });
410 });
411
412 static_for<0, KRepeat, 1>{}([&](auto k0) {
413 static_for<0, MRepeat, 1>{}([&](auto m0) {
414 static_for<0, NRepeat, 1>{}([&](auto n0) {
417
418 static_for<0, KPack, 1>{}([&](auto ik) {
419 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
420 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
421 make_tuple(m0, I0, k0, ik))>{}];
422 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
423 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
424 make_tuple(n0, I0, k0, ik))>{}];
425 });
426
427 using mfma_input_type =
428 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
429
430 constexpr index_t c_offset =
431 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
432
433 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
434 b_thread_vec.template AsType<mfma_input_type>(),
435 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
436 });
437 });
438 });
439 };
440
441 if constexpr(TailNum == TailNumber::One)
442 {
444 static_for<0, KRepeat, 1>{}([&](auto k) {
445 static_for<0, MRepeat, 1>{}([&](auto m0) {
448 a_block_buf,
450 make_tuple(m0, I0, k, I0),
451 a_thread_buf);
452 });
453 static_for<0, NRepeat, 1>{}([&](auto n0) {
456 b_block_buf,
458 make_tuple(n0, I0, k, I0),
459 b_thread_buf);
460 });
461 });
462
463 static_for<0, KRepeat, 1>{}([&](auto k0) {
464 static_for<0, MRepeat, 1>{}([&](auto m0) {
465 static_for<0, NRepeat, 1>{}([&](auto n0) {
468
469 static_for<0, KPack, 1>{}([&](auto ik) {
470 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
471 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
472 make_tuple(m0, I0, k0, ik))>{}];
473 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
474 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
475 make_tuple(n0, I0, k0, ik))>{}];
476 });
477
478 using mfma_input_type =
479 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
480
481 constexpr index_t c_offset =
482 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
483
484 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
485 b_thread_vec.template AsType<mfma_input_type>(),
486 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
487 });
488 });
489 });
490 }
491 else if constexpr(TailNum == TailNumber::Two)
492 {
493 LoopTailFunc(Number<2>{});
494 }
495 else if constexpr(TailNum == TailNumber::Three)
496 {
497 LoopTailFunc(Number<3>{});
498 }
499 else if constexpr(TailNum == TailNumber::Four)
500 {
501 LoopTailFunc(Number<4>{});
502 }
503 else if constexpr(TailNum == TailNumber::Five)
504 {
505 LoopTailFunc(Number<5>{});
506 }
507 else if constexpr(TailNum == TailNumber::Six)
508 {
509 LoopTailFunc(Number<6>{});
510 }
511 else if constexpr(TailNum == TailNumber::Seven)
512 {
513 LoopTailFunc(Number<7>{});
514 }
515 else if constexpr(TailNum == TailNumber::Full)
516 {
517 LoopTailFunc(Number<PrefetchStages>{});
518 }
519 }
520
521 protected:
522 using Base::a_thread_copy_;
523 using Base::a_thread_desc_;
524 using Base::b_thread_copy_;
525 using Base::b_thread_desc_;
526 using Base::c_thread_desc_;
527};
528
529template <index_t BlockSize,
530 typename ADataType,
531 typename BDataType,
532 typename ComputeDataType,
533 typename AccDataType,
534 typename ATileDesc,
535 typename BTileDesc,
536 typename AMmaTileDesc,
537 typename BMmaTileDesc,
538 index_t ABlockTransferSrcScalarPerVector,
539 index_t BBlockTransferSrcScalarPerVector,
540 index_t MPerBlock,
541 index_t NPerBlock,
542 index_t KPerBlock,
543 index_t MPerXDL,
544 index_t NPerXDL,
545 index_t MRepeat,
546 index_t NRepeat,
547 index_t KPack
548 // ,bool TransposeC //disable transposec right now...
549 >
551 BlockSize,
552 ADataType,
553 BDataType,
554 ComputeDataType,
555 AccDataType,
556 ATileDesc,
557 BTileDesc,
558 AMmaTileDesc,
559 BMmaTileDesc,
560 ABlockTransferSrcScalarPerVector,
561 BBlockTransferSrcScalarPerVector,
562 MPerBlock,
563 NPerBlock,
564 KPerBlock,
565 MPerXDL,
566 NPerXDL,
567 MRepeat,
568 NRepeat,
569 KPack>
571 ADataType,
572 BDataType,
573 ComputeDataType,
574 AccDataType,
575 ATileDesc,
576 BTileDesc,
577 AMmaTileDesc,
578 BMmaTileDesc,
579 ABlockTransferSrcScalarPerVector,
580 BBlockTransferSrcScalarPerVector,
581 MPerBlock,
582 NPerBlock,
583 KPerBlock,
584 MPerXDL,
585 NPerXDL,
586 MRepeat,
587 NRepeat,
588 KPack>
589
590{
592 ADataType,
593 BDataType,
594 ComputeDataType,
595 AccDataType,
596 ATileDesc,
597 BTileDesc,
598 AMmaTileDesc,
599 BMmaTileDesc,
600 ABlockTransferSrcScalarPerVector,
601 BBlockTransferSrcScalarPerVector,
602 MPerBlock,
603 NPerBlock,
604 KPerBlock,
605 MPerXDL,
606 NPerXDL,
607 MRepeat,
608 NRepeat,
609 KPack>;
610 using Base::A_K1;
611 using Base::B_K1;
612 using Base::I0;
613 using Base::I1;
614 using Base::KPerThread;
615 using Base::xdlops_gemm;
616
628
631 using Base::WaveSize;
632
634
638
639 static constexpr index_t WgpPerCU =
640 (4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
642 32768 / WgpPerCU,
643 (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
644 static constexpr index_t PrefetchStages =
647 : 2;
648
649 static constexpr index_t PrefillStages = 1;
651
652 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
653 {
654 return num_loop > PrefetchStages;
655 }
656
657 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
658 {
659 if(num_loop % PrefetchStages == 1)
660 {
661 return TailNumber::One;
662 }
663 else if(num_loop % PrefetchStages == 2)
664 {
665 return TailNumber::Two;
666 }
667 else if(num_loop % PrefetchStages == 3)
668 {
669 return TailNumber::Three;
670 }
671 else if(num_loop % PrefetchStages == 4)
672 {
673 return TailNumber::Four;
674 }
675 else if(num_loop % PrefetchStages == 5)
676 {
677 return TailNumber::Five;
678 }
679 else if(num_loop % PrefetchStages == 6)
680 {
681 return TailNumber::Six;
682 }
683 else if(num_loop % PrefetchStages == 7)
684 {
685 return TailNumber::Seven;
686 }
687 else
688 {
689 return TailNumber::Full;
690 }
691 }
692
693 template <bool HasMainLoop,
694 TailNumber TailNum,
695 typename AGridDesc,
696 typename ABlockDesc,
697 typename ABlockTransfer,
698 typename AGridBuffer,
699 typename ABlockBuffer,
700 typename ABlockTransferStep,
701 typename BGridDesc,
702 typename BBlockDesc,
703 typename BBlockTransfer,
704 typename BGridBuffer,
705 typename BBlockBuffer,
706 typename BBlockTransferStep,
707 typename CThreadBuffer>
708 __device__ void Run(const AGridDesc& a_grid_desc,
709 const ABlockDesc& a_block_desc,
710 ABlockTransfer& a_blockwise_copy,
711 const AGridBuffer& a_grid_buf,
712 ABlockBuffer& a_block_buf,
713 const ABlockTransferStep& a_block_copy_step,
714 const BGridDesc& b_grid_desc,
715 const BBlockDesc& b_block_desc,
716 BBlockTransfer& b_blockwise_copy,
717 const BGridBuffer& b_grid_buf,
718 BBlockBuffer& b_block_buf,
719 const BBlockTransferStep& b_block_copy_step,
720 CThreadBuffer& c_thread_buf,
721 index_t num_loop) const
722 {
724 a_thread_desc_.GetElementSpaceSize());
726 b_thread_desc_.GetElementSpaceSize());
727
728 // Global prefetch 1
729 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
730 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
731
732 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
733 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
734
735 // Initialize C
736 c_thread_buf.Clear();
737
738 // Local prefill 1
739 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
740 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
741
742 // Global prefetch [2, PrefetchStages]
743 static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
744 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
745 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
746
747 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
748 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
749 });
750
751 // main body
752 if constexpr(HasMainLoop)
753 {
754 index_t i = 0;
755 do
756 {
757 static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
758 // -------------------------------------------------------------------------------------------
760 static_for<0, KRepeat, 1>{}([&](auto k0) {
761 static_for<0, MRepeat, 1>{}([&](auto m0) {
764 a_block_buf,
766 make_tuple(m0, I0, k0, I0),
767 a_thread_buf);
768 });
769 static_for<0, NRepeat, 1>{}([&](auto n0) {
772 b_block_buf,
774 make_tuple(n0, I0, k0, I0),
775 b_thread_buf);
776 });
777 __builtin_amdgcn_sched_barrier(0);
778 // NOTE: Synchronize threads in a workgroup at the start of each MAC
779 // cluster except the first, as we can shorten the non-MAC cluster a bit
780 // and there's no observable negative impact. The desired effect is waves in
781 // a workgroup executing MAC in sync. This avoids some out-of-sync waves
782 // hijacking MAC resources from other workgroups and reducing the chance of
783 // latency hiding by waiting for the rest of the workgroup at the eventual
784 // sync point.
785 if constexpr(k0.value != 0 || KRepeat == 1)
786 {
787 __builtin_amdgcn_s_barrier();
788 __builtin_amdgcn_sched_barrier(0);
789 }
791 static_for<0, MRepeat, 1>{}([&](auto m0) {
792 static_for<0, NRepeat, 1>{}([&](auto n0) {
795
796 static_for<0, KPack, 1>{}([&](auto ik) {
797 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
798 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
799 make_tuple(m0, I0, k0, k_ + ik))>{}];
800 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
801 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
802 make_tuple(n0, I0, k0, k_ + ik))>{}];
803 });
804
805 using mfma_input_type =
807 xdlops_gemm.K1PerXdlops>::type;
808
809 constexpr index_t c_offset =
810 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
811
812 // The block_sync_lds() here performs double duty:
813 // A) safeguard against data hazard, because the barrier from
814 // blockwise_gemm is moved here; B) reduce VMEM FIFO congestion
815 // by applying small delays to different wavefronts. It is
816 // performed near the end of the MAC cluster to minimize the lgkmcnt
817 // penalty.
818 if constexpr(k0.value == KRepeat - 1 &&
819 k_.value == KPerInnerLoop - KPack &&
820 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
821 {
822 __builtin_amdgcn_sched_barrier(0);
824 __builtin_amdgcn_sched_barrier(0);
825 }
826 xdlops_gemm.Run(
827 a_thread_vec.template AsType<mfma_input_type>(),
828 b_thread_vec.template AsType<mfma_input_type>(),
829 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
830 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
831 {
832 __builtin_amdgcn_sched_barrier(0);
833 __builtin_amdgcn_s_setprio(1);
834 __builtin_amdgcn_sched_barrier(0);
835 }
836 });
837 });
838 });
839 __builtin_amdgcn_sched_barrier(0);
840 __builtin_amdgcn_s_setprio(0);
841 __builtin_amdgcn_sched_barrier(0);
842 });
843
844 // block_sync_lds();
845 a_blockwise_copy.RunWrite(
846 a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
847 b_blockwise_copy.RunWrite(
848 b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
849
850 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
851 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
852
853 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
854 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
855 });
856 i += PrefetchStages;
857 } while(i < (num_loop - PrefetchStages));
858 }
859
860 // tail
861
862 auto LoopTailFunc = [&](auto tail_num) {
863 static_for<1, tail_num, 1>{}([&](auto iprefetch) {
865 static_for<0, KRepeat, 1>{}([&](auto k0) {
866 static_for<0, MRepeat, 1>{}([&](auto m0) {
869 a_block_buf,
871 make_tuple(m0, I0, k0, I0),
872 a_thread_buf);
873 });
874 static_for<0, NRepeat, 1>{}([&](auto n0) {
877 b_block_buf,
879 make_tuple(n0, I0, k0, I0),
880 b_thread_buf);
881 });
882
883 __builtin_amdgcn_sched_barrier(0);
884 if constexpr(k0.value != 0 || KRepeat == 1)
885 {
886 __builtin_amdgcn_s_barrier();
887 __builtin_amdgcn_sched_barrier(0);
888 }
890 static_for<0, MRepeat, 1>{}([&](auto m0) {
891 static_for<0, NRepeat, 1>{}([&](auto n0) {
894
895 static_for<0, KPack, 1>{}([&](auto ik) {
896 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
897 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
898 make_tuple(m0, I0, k0, k_ + ik))>{}];
899 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
900 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
901 make_tuple(n0, I0, k0, k_ + ik))>{}];
902 });
903
904 using mfma_input_type =
906 xdlops_gemm.K1PerXdlops>::type;
907
908 constexpr index_t c_offset =
909 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
910
911 if constexpr(k0.value == KRepeat - 1 &&
912 k_.value == KPerInnerLoop - KPack &&
913 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
914 {
915 __builtin_amdgcn_sched_barrier(0);
917 __builtin_amdgcn_sched_barrier(0);
918 }
919 xdlops_gemm.Run(
920 a_thread_vec.template AsType<mfma_input_type>(),
921 b_thread_vec.template AsType<mfma_input_type>(),
922 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
923 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
924 {
925 __builtin_amdgcn_sched_barrier(0);
926 __builtin_amdgcn_s_setprio(1);
927 __builtin_amdgcn_sched_barrier(0);
928 }
929 });
930 });
931 });
932 __builtin_amdgcn_sched_barrier(0);
933 __builtin_amdgcn_s_setprio(0);
934 __builtin_amdgcn_sched_barrier(0);
935 });
936
937 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
938 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
939 });
941 static_for<0, KRepeat, 1>{}([&](auto k0) {
942 static_for<0, MRepeat, 1>{}([&](auto m0) {
945 a_block_buf,
947 make_tuple(m0, I0, k0, I0),
948 a_thread_buf);
949 });
950 static_for<0, NRepeat, 1>{}([&](auto n0) {
953 b_block_buf,
955 make_tuple(n0, I0, k0, I0),
956 b_thread_buf);
957 });
958
959 __builtin_amdgcn_sched_barrier(0);
960 if constexpr(k0.value != 0 || KRepeat == 1)
961 {
962 __builtin_amdgcn_s_barrier();
963 __builtin_amdgcn_sched_barrier(0);
964 }
966 static_for<0, MRepeat, 1>{}([&](auto m0) {
967 static_for<0, NRepeat, 1>{}([&](auto n0) {
970
971 static_for<0, KPack, 1>{}([&](auto ik) {
972 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
973 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
974 make_tuple(m0, I0, k0, k_ + ik))>{}];
975 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
976 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
977 make_tuple(n0, I0, k0, k_ + ik))>{}];
978 });
979
980 using mfma_input_type =
982 xdlops_gemm.K1PerXdlops>::type;
983
984 constexpr index_t c_offset =
985 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
986
987 if constexpr(k0.value == KRepeat - 1 &&
988 k_.value == KPerInnerLoop - KPack &&
989 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
990 {
991 __builtin_amdgcn_sched_barrier(0);
993 __builtin_amdgcn_sched_barrier(0);
994 }
995 xdlops_gemm.Run(
996 a_thread_vec.template AsType<mfma_input_type>(),
997 b_thread_vec.template AsType<mfma_input_type>(),
998 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
999 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
1000 {
1001 __builtin_amdgcn_sched_barrier(0);
1002 __builtin_amdgcn_s_setprio(1);
1003 __builtin_amdgcn_sched_barrier(0);
1004 }
1005 });
1006 });
1007 });
1008 __builtin_amdgcn_sched_barrier(0);
1009 __builtin_amdgcn_s_setprio(0);
1010 __builtin_amdgcn_sched_barrier(0);
1011 });
1012 };
1013
1014 if constexpr(TailNum == TailNumber::One)
1015 {
1017 static_for<0, KRepeat, 1>{}([&](auto k0) {
1018 static_for<0, MRepeat, 1>{}([&](auto m0) {
1021 a_block_buf,
1023 make_tuple(m0, I0, k0, I0),
1024 a_thread_buf);
1025 });
1026 static_for<0, NRepeat, 1>{}([&](auto n0) {
1029 b_block_buf,
1031 make_tuple(n0, I0, k0, I0),
1032 b_thread_buf);
1033 });
1034
1035 __builtin_amdgcn_sched_barrier(0);
1036 if constexpr(k0.value != 0 || KRepeat == 1)
1037 {
1038 __builtin_amdgcn_s_barrier();
1039 __builtin_amdgcn_sched_barrier(0);
1040 }
1042 static_for<0, MRepeat, 1>{}([&](auto m0) {
1043 static_for<0, NRepeat, 1>{}([&](auto n0) {
1046
1047 static_for<0, KPack, 1>{}([&](auto ik) {
1048 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1049 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1050 make_tuple(m0, I0, k0, k_ + ik))>{}];
1051 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1052 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1053 make_tuple(n0, I0, k0, k_ + ik))>{}];
1054 });
1055
1056 using mfma_input_type =
1058 xdlops_gemm.K1PerXdlops>::type;
1059
1060 constexpr index_t c_offset =
1061 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1062
1063 if constexpr(k0.value == KRepeat - 1 &&
1064 k_.value == KPerInnerLoop - KPack &&
1065 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
1066 {
1067 __builtin_amdgcn_sched_barrier(0);
1069 __builtin_amdgcn_sched_barrier(0);
1070 }
1071 xdlops_gemm.Run(
1072 a_thread_vec.template AsType<mfma_input_type>(),
1073 b_thread_vec.template AsType<mfma_input_type>(),
1074 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1075 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
1076 {
1077 __builtin_amdgcn_sched_barrier(0);
1078 __builtin_amdgcn_s_setprio(1);
1079 __builtin_amdgcn_sched_barrier(0);
1080 }
1081 });
1082 });
1083 });
1084 __builtin_amdgcn_sched_barrier(0);
1085 __builtin_amdgcn_s_setprio(0);
1086 __builtin_amdgcn_sched_barrier(0);
1087 });
1088 }
1089 else if constexpr(TailNum == TailNumber::Two)
1090 {
1091 LoopTailFunc(Number<2>{});
1092 }
1093 else if constexpr(TailNum == TailNumber::Three)
1094 {
1095 LoopTailFunc(Number<3>{});
1096 }
1097 else if constexpr(TailNum == TailNumber::Four)
1098 {
1099 LoopTailFunc(Number<4>{});
1100 }
1101 else if constexpr(TailNum == TailNumber::Five)
1102 {
1103 LoopTailFunc(Number<5>{});
1104 }
1105 else if constexpr(TailNum == TailNumber::Six)
1106 {
1107 LoopTailFunc(Number<6>{});
1108 }
1109 else if constexpr(TailNum == TailNumber::Seven)
1110 {
1111 LoopTailFunc(Number<7>{});
1112 }
1113 else if constexpr(TailNum == TailNumber::Full)
1114 {
1115 LoopTailFunc(Number<PrefetchStages>{});
1116 }
1117 }
1118
1119 protected:
1120 // K->M loopover
1126 I1));
1127
1133 I1));
1134
1137 decltype(a_block_desc_m0_m1_m2_k),
1138 decltype(a_thread_desc_),
1141 3,
1142 A_K1,
1143 A_K1>;
1144
1147 decltype(b_block_desc_n0_n1_n2_k),
1148 decltype(b_thread_desc_),
1151 3,
1152 B_K1,
1153 B_K1>;
1154
1157 using Base::c_thread_desc_;
1158};
1159
1160} // namespace ck
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition ck.hpp:209
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:147
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:125
static constexpr index_t WaveSize
Definition blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
static constexpr index_t KPerThread
Definition blockwise_gemm_pipeline_xdlops_base.hpp:63
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataTypeBuf, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPerInnerLoop >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v2.hpp:1135
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v2.hpp:591
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v2.hpp:708
ThreadwiseTensorSliceTransfer_v4< BDataType, ComputeDataTypeBuf, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPerInnerLoop >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v2.hpp:1145
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v2.hpp:215
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v2.hpp:102
Definition blockwise_gemm_pipeline_xdlops_v2.hpp:37
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10