blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp Source File

blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp Source File
blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MScaleBlock,
32 index_t NScaleBlock,
33 index_t KScaleBlock,
34 index_t MPerXDL,
35 index_t NPerXDL,
36 index_t MRepeat,
37 index_t NRepeat,
38 index_t KPacks>
42
43template <index_t BlockSize,
44 typename ADataType,
45 typename BDataType,
46 typename ComputeDataType,
47 typename AccDataType,
48 typename ATileDesc,
49 typename BTileDesc,
50 typename AMmaTileDesc,
51 typename BMmaTileDesc,
52 index_t ABlockTransferSrcScalarPerVector,
53 index_t BBlockTransferSrcScalarPerVector,
54 index_t MPerBlock,
55 index_t NPerBlock,
56 index_t KPerBlock,
57 index_t MScaleBlock,
58 index_t NScaleBlock,
59 index_t KScaleBlock,
60 index_t MPerXDL,
61 index_t NPerXDL,
62 index_t MRepeat,
63 index_t NRepeat,
64 index_t KPack
65 // ,bool TransposeC //disable transposec right now...
66 >
68 BlockSize,
69 ADataType,
70 BDataType,
71 ComputeDataType,
72 AccDataType,
73 ATileDesc,
74 BTileDesc,
75 AMmaTileDesc,
76 BMmaTileDesc,
77 ABlockTransferSrcScalarPerVector,
78 BBlockTransferSrcScalarPerVector,
79 MPerBlock,
80 NPerBlock,
81 KPerBlock,
82 MScaleBlock,
83 NScaleBlock,
84 KScaleBlock,
85 MPerXDL,
86 NPerXDL,
87 MRepeat,
88 NRepeat,
89 KPack>
91 ADataType,
92 BDataType,
93 ComputeDataType,
94 AccDataType,
95 ATileDesc,
96 BTileDesc,
97 AMmaTileDesc,
98 BMmaTileDesc,
99 ABlockTransferSrcScalarPerVector,
100 BBlockTransferSrcScalarPerVector,
101 MPerBlock,
102 NPerBlock,
103 KPerBlock,
104 MPerXDL,
105 NPerXDL,
106 MRepeat,
107 NRepeat,
108 KPack,
109 true>
110
111{
113 ADataType,
114 BDataType,
115 ComputeDataType,
116 AccDataType,
117 ATileDesc,
118 BTileDesc,
119 AMmaTileDesc,
120 BMmaTileDesc,
121 ABlockTransferSrcScalarPerVector,
122 BBlockTransferSrcScalarPerVector,
123 MPerBlock,
124 NPerBlock,
125 KPerBlock,
126 MPerXDL,
127 NPerXDL,
128 MRepeat,
129 NRepeat,
130 KPack,
131 true>;
132 using Base::A_K1;
133 using Base::B_K1;
134 using Base::I0;
135 using Base::I1;
136 using Base::KGroup;
137 using Base::KRepeat;
138 using Base::xdlops_gemm;
139 using typename Base::HotLoopInstList;
140
153
154 using Base::MWaves;
155 using Base::NWaves;
156 using Base::WaveSize;
157
158 static constexpr index_t PrefetchStages = 2;
159 static constexpr index_t PrefillStages = 1;
160 static constexpr index_t GlobalBufferNum = 2;
161
162 template <typename TileDesc_M0_M1_M2_K>
163 __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
164 {
165 constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
166 constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
167 constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
168 constexpr index_t K2 = KPack / KGroup;
169 constexpr index_t K1 = WaveSize / NPerXDL;
170 constexpr index_t K0 = KRepeat * KGroup;
171
173 TileDesc_M0_M1_M2_K{},
181 }
182
183 static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
185
186 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
187 {
188 return num_loop > PrefetchStages;
189 }
190
191 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
192 {
193 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
194 }
195
196 __device__ static constexpr auto HotLoopScheduler()
197 {
198 constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
199 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
200 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
201
202 // B global
204 ignore = i;
205 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
206 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
207 });
208
209 // A global
211 ignore = i;
212 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
213 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
214 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
215 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
216 });
217
218 // A local
219 static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
220 ignore = i;
221 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
222 __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
223 });
224 }
225
226 template <bool HasMainLoop,
227 int NumKBlockPerScale,
228 TailNumber TailNum,
229 typename AGridDesc,
230 typename ABlockDesc,
231 typename ABlockTransfer,
232 typename AGridBuffer,
233 typename ABlockBuffer,
234 typename ABlockTransferStep,
235 typename BGridDesc,
236 typename BBlockDesc,
237 typename BBlockTransfer,
238 typename BGridBuffer,
239 typename BBlockBuffer,
240 typename BBlockTransferStep,
241 typename CScaleThreadDesc,
242 typename CThreadBuffer,
243 typename AScaleGridBuffer,
244 typename AScaleGridDesc,
245 typename AScaleThreadDesc,
246 typename AScaleThreadTransfer,
247 typename AScaleThreadTransferStep,
248 typename BScaleGridBuffer,
249 typename BScaleGridDesc,
250 typename BScaleThreadDesc,
251 typename BScaleThreadTransfer,
252 typename BScaleThreadTransferStep>
253 __device__ void Run(
254 // ABlockCopy
255 const AGridDesc& a_grid_desc,
256 const ABlockDesc& a_block_desc,
257 ABlockTransfer& a_blockwise_copy,
258 const AGridBuffer& a_grid_buf,
259 ABlockBuffer& a_block_buf,
260 const ABlockTransferStep& a_block_copy_step,
261 // BBlockCopy
262 const BGridDesc& b_grid_desc,
263 const BBlockDesc& b_block_desc,
264 BBlockTransfer& b_blockwise_copy,
265 const BGridBuffer& b_grid_buf,
266 BBlockBuffer& b_block_buf,
267 const BBlockTransferStep& b_block_copy_step,
268 // CThread
269 const CScaleThreadDesc& c_scale_thread_desc,
270 CThreadBuffer& c_thread_buf,
271 // AScaleThreadCopy
272 const AScaleGridDesc& a_scale_grid_desc,
273 const AScaleThreadDesc& a_scale_thread_desc,
274 AScaleThreadTransfer& a_scale_thread_copy,
275 const AScaleGridBuffer& a_scale_grid_buf,
276 const AScaleThreadTransferStep& a_scale_thread_copy_step,
277 // BScaleThreadCopy
278 const BScaleGridDesc& b_scale_grid_desc,
279 const BScaleThreadDesc& b_scale_thread_desc,
280 BScaleThreadTransfer& b_scale_thread_copy,
281 const BScaleGridBuffer& b_scale_grid_buf,
282 const BScaleThreadTransferStep& b_scale_thread_copy_step,
283 // num_loop
284 index_t num_loop) const
285 {
286 ignore = b_block_desc;
287 ignore = b_block_buf;
288 // __builtin_amdgcn_sched_barrier(0);
290 a_thread_desc_.GetElementSpaceSize());
292 b_thread_desc_.GetElementSpaceSize());
293
294 StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
295 constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
296
298 a_scale_thread_desc.GetElementSpaceSize());
300 b_scale_thread_desc.GetElementSpaceSize());
302 c_scale_thread_desc.GetElementSpaceSize());
303
304 // Global prefetch A1 B1
305 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
306 b_blockwise_copy.Run(b_grid_desc,
307 b_grid_buf,
309 b_block_origin_idx,
310 b_thread_bufs(I0));
311
312 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
313 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
314
315 static_for<0, MRepeat, 1>{}([&](auto m0) {
316 a_scale_thread_copy.Run(a_scale_grid_desc,
317 a_scale_grid_buf,
318 a_scale_thread_desc,
319 make_tuple(m0, I0),
320 a_scale_thread_buf);
321 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
322 a_scale_thread_copy_step.At(Number<0>{}));
323 });
324
325 if constexpr(NumKBlockPerScale == 1)
326 {
327 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
328 a_scale_thread_copy_step.At(Number<2>{}));
329 }
330 else
331 {
332 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
333 a_scale_thread_copy_step.At(Number<1>{}));
334 }
335
336 b_scale_thread_copy.Run(b_scale_grid_desc,
337 b_scale_grid_buf,
338 b_scale_thread_desc,
339 make_tuple(I0, I0),
340 b_scale_thread_buf);
341
342 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
343
344 __builtin_amdgcn_sched_barrier(0);
345
346 constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
347 constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
348 constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
349
353 constexpr index_t c_offset =
354 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
355 constexpr index_t a_offset =
356 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
357 constexpr index_t b_offset =
358 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
359
360 c_scale_thread_buf(Number<c_offset>{}) =
361 a_scale_thread_buf[Number<a_offset>{}] *
362 b_scale_thread_buf[Number<b_offset>{}];
363 });
364 });
365 });
366
367 // Local prefill A1
368 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
369
370 // Global prefetch A2
371 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
372 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
373
374 static_for<0, MRepeat, 1>{}([&](auto m0) {
375 a_scale_thread_copy.Run(a_scale_grid_desc,
376 a_scale_grid_buf,
377 a_scale_thread_desc,
378 make_tuple(m0, I0),
379 a_scale_thread_buf);
380 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
381 a_scale_thread_copy_step.At(Number<0>{}));
382 });
383
384 if constexpr(NumKBlockPerScale == 1)
385 {
386 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
387 a_scale_thread_copy_step.At(Number<2>{}));
388 }
389 else
390 {
391 a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
392 a_scale_thread_copy_step.At(Number<1>{}));
393 }
394
395 b_scale_thread_copy.Run(b_scale_grid_desc,
396 b_scale_grid_buf,
397 b_scale_thread_desc,
398 make_tuple(I0, I0),
399 b_scale_thread_buf);
400
401 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
402
404 AccDataType,
405 1,
406 xdlops_gemm.GetRegSizePerXdlops(),
407 true>
408 c_thread_buf_per_scale;
409
410 // Local prefetch A1
412 static_for<0, MRepeat, 1>{}([&](auto m0) {
413 static_for<0, KRepeat, 1>{}([&](auto k0) {
414 static_for<0, KGroup, 1>{}([&](auto kg0) {
415 a_thread_copy_.Run(
418 a_block_buf,
421 a_thread_buf);
422 });
423 });
424 });
425
426 // Initialize C
427 c_thread_buf.Clear();
428
429 // __builtin_amdgcn_sched_barrier(0);
430
431 // main body
432 if constexpr(HasMainLoop)
433 {
434 index_t i = 0;
435 do
436 {
437 auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
438 b_blockwise_copy.Run(b_grid_desc,
439 b_grid_buf,
441 b_block_origin_idx,
442 b_thread_bufs(local_read_buf));
443 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
444
446 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
447
448 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
449 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
450
451 static_for<0, MRepeat, 1>{}([&](auto m0) {
452 static_for<0, NRepeat, 1>{}([&](auto n0) {
453 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
454 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
455 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
456 .template AsType<AccDataType>()(Number<t>{}) = 0;
457 });
458 vector_type<AccDataType, 2> c_scale_thread_vec;
459 constexpr index_t cscale_offset =
460 CScaleThreadDesc{}.CalculateOffset(
461 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
462
463 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
464 c_scale_thread_buf[Number<cscale_offset>{}];
465 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
466 c_scale_thread_buf[Number<cscale_offset>{}];
467
468 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
471
472 static_for<0, KPack, 1>{}([&](auto ik) {
473 a_thread_vec.template AsType<ComputeDataType>()(ik) =
474 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
475 make_tuple(m0,
476 I0,
477 I0,
478 kscale0 * KRepeat / num_scale_k_block +
479 k0,
480 I0,
481 ik))>{}];
482 b_thread_vec.template AsType<ComputeDataType>()(ik) =
483 b_thread_bufs[mfma_reg_buf][Number<
484 b_thread_desc_.CalculateOffset(make_tuple(
485 n0,
486 I0,
487 kscale0 * KRepeat / num_scale_k_block + k0,
488 ik))>{}];
489 });
490
491 using mfma_input_type =
492 typename vector_type<ComputeDataType,
493 xdlops_gemm.K1PerXdlops>::type;
494
495 xdlops_gemm.template Run<>(
496 a_thread_vec.template AsType<mfma_input_type>(),
497 b_thread_vec.template AsType<mfma_input_type>(),
498 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
499 });
500 constexpr index_t c_offset =
501 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
502
503 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
504 [&](auto t) {
505 using pk_fma_type =
507
508 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
509 .template AsType<pk_fma_type>()(t) =
510 __builtin_elementwise_fma(
511 c_thread_buf_per_scale
512 .GetVectorTypeReference(Number<0>{})
513 .template AsType<pk_fma_type>()[t],
514 c_scale_thread_vec
515 .template AsType<pk_fma_type>()[Number<0>{}],
516 c_thread_buf
517 .GetVectorTypeReference(Number<c_offset>{})
518 .template AsType<pk_fma_type>()[t]);
519 });
520 });
521 });
522 });
523
525
526 static_for<0, MRepeat, 1>{}([&](auto m0) {
527 static_for<0, KRepeat, 1>{}([&](auto k0) {
528 static_for<0, KGroup, 1>{}([&](auto kg0) {
529 a_thread_copy_.Run(
532 a_block_buf,
535 a_thread_buf);
536 });
537 });
538 });
539
541 __builtin_amdgcn_sched_barrier(0);
542
543 static_for<0, MRepeat, 1>{}([&](auto m0) {
546 constexpr index_t c_offset =
547 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
548 constexpr index_t a_offset =
549 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
550 constexpr index_t b_offset =
551 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
552
553 c_scale_thread_buf(Number<c_offset>{}) =
554 a_scale_thread_buf[Number<a_offset>{}] *
555 b_scale_thread_buf[Number<b_offset>{}];
556 });
557 });
558 });
559
560 static_for<0, MRepeat, 1>{}([&](auto m0) {
561 a_scale_thread_copy.Run(a_scale_grid_desc,
562 a_scale_grid_buf,
563 a_scale_thread_desc,
564 make_tuple(m0, I0),
565 a_scale_thread_buf);
566 a_scale_thread_copy.MoveSrcSliceWindow(
567 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
568 });
569
570 if constexpr(NumKBlockPerScale == 1)
571 {
572 a_scale_thread_copy.MoveSrcSliceWindow(
573 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
574 }
575 else
576 {
577 a_scale_thread_copy.MoveSrcSliceWindow(
578 a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
579 }
580
581 b_scale_thread_copy.Run(b_scale_grid_desc,
582 b_scale_grid_buf,
583 b_scale_thread_desc,
584 make_tuple(I0, I0),
585 b_scale_thread_buf);
586
587 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
588 b_scale_thread_copy_step);
589 };
590
591 LoopFunc(I0, I1);
592 LoopFunc(I1, I0);
593
594 i += 2;
595 } while(i < (num_loop - 2));
596 }
597
598 // tail
599 if constexpr(TailNum == TailNumber::Even)
600 {
601 b_blockwise_copy.Run(b_grid_desc,
602 b_grid_buf,
604 b_block_origin_idx,
605 b_thread_bufs(I1));
607 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
608
609 static_for<0, MRepeat, 1>{}([&](auto m0) {
610 static_for<0, NRepeat, 1>{}([&](auto n0) {
611 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
612 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
613 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
614 .template AsType<AccDataType>()(Number<t>{}) = 0;
615 });
616 vector_type<AccDataType, 2> c_scale_thread_vec;
617 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
618 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
619
620 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
621 c_scale_thread_buf[Number<cscale_offset>{}];
622 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
623 c_scale_thread_buf[Number<cscale_offset>{}];
624
625 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
628
629 static_for<0, KPack, 1>{}([&](auto ik) {
630 a_thread_vec.template AsType<ComputeDataType>()(ik) =
631 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
632 make_tuple(m0,
633 I0,
634 I0,
635 kscale0 * KRepeat / num_scale_k_block + k0,
636 I0,
637 ik))>{}];
638 b_thread_vec.template AsType<ComputeDataType>()(ik) =
639 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
640 make_tuple(n0,
641 I0,
642 kscale0 * KRepeat / num_scale_k_block + k0,
643 ik))>{}];
644 });
645
646 using mfma_input_type =
647 typename vector_type<ComputeDataType,
648 xdlops_gemm.K1PerXdlops>::type;
649
650 xdlops_gemm.template Run<>(
651 a_thread_vec.template AsType<mfma_input_type>(),
652 b_thread_vec.template AsType<mfma_input_type>(),
653 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
654 });
655 constexpr index_t c_offset =
656 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
657
658 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
659 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
660
661 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
662 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
663 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
664 .template AsType<pk_fma_type>()[t],
665 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
666 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
667 .template AsType<pk_fma_type>()[t]);
668 });
669 });
670 });
671 });
672
673 static_for<0, MRepeat, 1>{}([&](auto m0) {
676 constexpr index_t c_offset =
677 CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
678 constexpr index_t a_offset =
679 AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
680 constexpr index_t b_offset =
681 BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
682
683 c_scale_thread_buf(Number<c_offset>{}) =
684 a_scale_thread_buf[Number<a_offset>{}] *
685 b_scale_thread_buf[Number<b_offset>{}];
686 });
687 });
688 });
689
691
692 static_for<0, MRepeat, 1>{}([&](auto m0) {
693 static_for<0, KRepeat, 1>{}([&](auto k0) {
694 static_for<0, KGroup, 1>{}([&](auto kg0) {
695 a_thread_copy_.Run(
698 a_block_buf,
701 a_thread_buf);
702 });
703 });
704 });
705
706 static_for<0, MRepeat, 1>{}([&](auto m0) {
707 static_for<0, NRepeat, 1>{}([&](auto n0) {
708 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
709 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
710 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
711 .template AsType<AccDataType>()(Number<t>{}) = 0;
712 });
713 vector_type<AccDataType, 2> c_scale_thread_vec;
714 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
715 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
716
717 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
718 c_scale_thread_buf[Number<cscale_offset>{}];
719 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
720 c_scale_thread_buf[Number<cscale_offset>{}];
721
722 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
725
726 static_for<0, KPack, 1>{}([&](auto ik) {
727 a_thread_vec.template AsType<ComputeDataType>()(ik) =
728 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
729 make_tuple(m0,
730 I0,
731 I0,
732 kscale0 * KRepeat / num_scale_k_block + k0,
733 I0,
734 ik))>{}];
735 b_thread_vec.template AsType<ComputeDataType>()(ik) =
736 b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
737 make_tuple(n0,
738 I0,
739 kscale0 * KRepeat / num_scale_k_block + k0,
740 ik))>{}];
741 });
742
743 using mfma_input_type =
744 typename vector_type<ComputeDataType,
745 xdlops_gemm.K1PerXdlops>::type;
746
747 xdlops_gemm.template Run<>(
748 a_thread_vec.template AsType<mfma_input_type>(),
749 b_thread_vec.template AsType<mfma_input_type>(),
750 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
751 });
752 constexpr index_t c_offset =
753 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
754
755 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
756 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
757
758 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
759 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
760 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
761 .template AsType<pk_fma_type>()[t],
762 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
763 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
764 .template AsType<pk_fma_type>()[t]);
765 });
766 });
767 });
768 });
769 }
770 else if constexpr(TailNum == TailNumber::Odd)
771 {
772 static_for<0, MRepeat, 1>{}([&](auto m0) {
773 static_for<0, NRepeat, 1>{}([&](auto n0) {
774 static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
775 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
776 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
777 .template AsType<AccDataType>()(Number<t>{}) = 0;
778 });
779 vector_type<AccDataType, 2> c_scale_thread_vec;
780 constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
781 make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
782
783 c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
784 c_scale_thread_buf[Number<cscale_offset>{}];
785 c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
786 c_scale_thread_buf[Number<cscale_offset>{}];
787
788 static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
791
792 static_for<0, KPack, 1>{}([&](auto ik) {
793 a_thread_vec.template AsType<ComputeDataType>()(ik) =
794 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
795 make_tuple(m0,
796 I0,
797 I0,
798 kscale0 * KRepeat / num_scale_k_block + k0,
799 I0,
800 ik))>{}];
801 b_thread_vec.template AsType<ComputeDataType>()(ik) =
802 b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
803 make_tuple(n0,
804 I0,
805 kscale0 * KRepeat / num_scale_k_block + k0,
806 ik))>{}];
807 });
808
809 using mfma_input_type =
810 typename vector_type<ComputeDataType,
811 xdlops_gemm.K1PerXdlops>::type;
812
813 xdlops_gemm.template Run<>(
814 a_thread_vec.template AsType<mfma_input_type>(),
815 b_thread_vec.template AsType<mfma_input_type>(),
816 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
817 });
818
819 constexpr index_t c_offset =
820 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
821
822 static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
823 using pk_fma_type = typename vector_type<AccDataType, 2>::type;
824
825 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
826 .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
827 c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
828 .template AsType<pk_fma_type>()[t],
829 c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
830 c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
831 .template AsType<pk_fma_type>()[t]);
832 });
833 });
834 });
835 });
836 }
837 }
838
839 protected:
840 // MRepeat MWave MLane KRepeat KLane KPack
841 // KRepeat -> MRepeat -> MWave -> KLane -> MLane -> KPack
844
846 ComputeDataType,
848 decltype(a_thread_desc_),
849 Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
851 5,
852 A_K1,
853 A_K1>;
854
856
859
860 static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
861
863};
864
865} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Vgpr
Definition amd_address_space.hpp:20
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack, true > Base
Definition blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp:112
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, const CScaleThreadDesc &c_scale_thread_desc, CThreadBuffer &c_thread_buf, const AScaleGridDesc &a_scale_grid_desc, const AScaleThreadDesc &a_scale_thread_desc, AScaleThreadTransfer &a_scale_thread_copy, const AScaleGridBuffer &a_scale_grid_buf, const AScaleThreadTransferStep &a_scale_thread_copy_step, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp:253
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, 1, KPack/KGroup >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp:845
Definition blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp:40
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10