block_fmha_bwd_pipeline_trload_default_policy.hpp Source File

block_fmha_bwd_pipeline_trload_default_policy.hpp Source File

Composable Kernel: block_fmha_bwd_pipeline_trload_default_policy.hpp Source File
block_fmha_bwd_pipeline_trload_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
6
8
9namespace ck_tile {
10
12{
13 template <typename Problem>
14 CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
15 {
16 using GemmProblem =
17 BlockGemmProblem<typename Problem::QDataType,
18 typename Problem::KDataType,
19 typename Problem::AccDataType,
20 Problem::kBlockSize,
21 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
22 Problem::BlockFmhaShape::kN0,
23 Problem::BlockFmhaShape::kK0>,
24 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
25 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
26
27 constexpr auto SwizzleA = false;
28 using WarpGemm = WarpGemmDispatcher< //
29 typename Problem::QDataType,
30 typename Problem::KDataType,
31 typename Problem::AccDataType,
32 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
33 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
34 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
35 false,
36 SwizzleA>;
37
38 using BlockGemmPolicy =
39 BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
40 typename Problem::KDataType,
41 typename Problem::AccDataType,
42 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
43 WarpGemm>;
44
45 return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy, /* TransposeC */ true>{};
46 }
47
48 template <typename Problem>
53
54 template <typename Problem>
56 {
57 using GemmProblem =
58 BlockGemmProblem<typename Problem::OGradDataType,
59 typename Problem::VDataType,
60 typename Problem::AccDataType,
61 Problem::kBlockSize,
62 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
63 Problem::BlockFmhaShape::kN0,
64 Problem::BlockFmhaShape::kK2>,
65 typename Problem::BlockFmhaShape::Gemm2BlockWarps,
66 typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
67
68 constexpr auto SwizzleA = false;
69 using WarpGemm = WarpGemmDispatcher< //
70 typename Problem::OGradDataType,
71 typename Problem::VDataType,
72 typename Problem::AccDataType,
73 Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
74 Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
75 Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
76 false,
77 SwizzleA>;
78
79 using BlockGemmPolicy =
80 BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
81 typename Problem::VDataType,
82 typename Problem::AccDataType,
83 typename Problem::BlockFmhaShape::Gemm2BlockWarps,
84 WarpGemm>;
85
86 return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy, /* TransposeC */ true>{};
87 }
88
89 template <typename Problem>
94
95 template <typename Problem>
97 {
98 using BlockFmhaShape = typename Problem::BlockFmhaShape;
99 using GemmProblem = BlockGemmProblem<
100 typename Problem::GemmDataType,
101 typename Problem::KDataType,
102 typename Problem::AccDataType,
103 Problem::kBlockSize,
106 typename BlockFmhaShape::Gemm4BlockWarps,
107 typename BlockFmhaShape::Gemm4WarpTile>>;
108
109 using WarpGemm = WarpGemmDispatcher< //
110 typename Problem::GemmDataType,
111 typename Problem::KDataType,
112 typename Problem::AccDataType,
113 BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
114 BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
115 BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
116 false,
117 false,
118 false,
119 (Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}) == 32)
120 ? WGAttrNumAccessEnum ::Double
121 : WGAttrNumAccessEnum ::Single>;
122
123 using BlockGemmPolicy =
124 BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
125 typename Problem::KDataType,
126 typename Problem::AccDataType,
127 typename BlockFmhaShape::Gemm4BlockWarps,
128 WarpGemm>;
129
131 }
132
133 // these are for global load
134 template <typename Problem, typename T>
135 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentX() noexcept
136 {
137 return 16 / sizeof(T);
138 }
139 template <typename Problem>
144 template <typename Problem>
149 template <typename Problem>
154 template <typename Problem>
159 template <typename Problem>
164 template <typename Problem>
169
170 template <typename Problem>
175
176 template <typename Problem>
181
182 // these are for load_tr_b64
183 template <typename T>
184 CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentX() noexcept
185 {
186 return 8 / sizeof(T);
187 }
188 template <typename Problem>
193
194 template <typename Problem>
199
200 template <typename Problem>
202 {
203 constexpr index_t kBlockSize = Problem::kBlockSize;
204 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
205 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
206
207 constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
208
209 return total_pixels / GetAlignmentBias<Problem>();
210 }
211
212 template <typename Problem>
214 {
216 return 16 / sizeof(AccDataType);
217 }
218
219 template <typename Problem>
221 {
223 }
224
225 // It is found that alignment of 8x dwordx4 can avoid bank conflicts for both transposed and
226 // non-transposed load
227 static constexpr index_t WarpAlignmentBytes = 128;
228
229 // As load_lds requires contiguous LDS write, we need to transform the distribution of DRAM for
230 // reading
231 template <typename T, typename TensorView>
232 CK_TILE_HOST_DEVICE static constexpr auto TransformXDramTensorView(const TensorView& naive_view)
233 {
234 if constexpr(std::is_same_v<TensorView, ck_tile::null_tensor_view>)
235 {
236 return naive_view;
237 }
238 else
239 {
240 const auto transformed_desc =
241 TransformXDramDescriptor<T>(naive_view.get_tensor_descriptor());
242 return tensor_view<typename TensorView::buffer_view,
243 remove_cvref_t<decltype(transformed_desc)>,
244 TensorView::DstInMemOp>{naive_view.buf_, transformed_desc};
245 }
246 }
247 template <typename T, typename... TD_TS>
248 CK_TILE_HOST_DEVICE static constexpr auto
250 {
251 using from_desc_t = tensor_descriptor<TD_TS...>;
252
253 constexpr auto ndims = from_desc_t::get_num_of_dimension();
254 static_assert(ndims == 2, "XDram descriptor must have 2 dimensions");
255 const auto Rows = from_desc.get_length(number<0>{});
256 // constexpr auto Cols = 128;
257 // assert(from_desc.get_length(number<1>{}) == 128);
258 const auto Cols = from_desc.get_length(number<1>{});
259
260 constexpr index_t Dwordx4Bytes = 16;
261 constexpr index_t K2 = Dwordx4Bytes / sizeof(T);
262 constexpr index_t K1 = WarpAlignmentBytes / Dwordx4Bytes;
263 const index_t K0 = Cols / K1;
264 const auto ColLens = make_tuple(K0, number<K1>{}, number<K2>{});
265
266 const auto desc_tmp1 = transform_tensor_descriptor(
267 from_desc,
271
272 const auto desc_tmp2 = transform_tensor_descriptor(
273 desc_tmp1,
279
281 desc_tmp2,
286 }
287
288 template <typename Problem, typename T, index_t RowsPerBlock, index_t ColsPerBlock>
290 {
291 constexpr index_t kBlockSize = Problem::kBlockSize;
292 constexpr index_t kWarps = kBlockSize / get_warp_size();
293
294 constexpr index_t K3 = GetAlignmentK<Problem>(); // 8
295 constexpr index_t K2 = WarpAlignmentBytes / sizeof(T) / K3; // 8
296 constexpr index_t K_remain = ColsPerBlock / K2 / K3;
297 constexpr index_t K1 = min(kWarps, K_remain);
298 constexpr index_t K0 = K_remain / K1;
299 static_assert((K0 * K1 * K2 * K3 == ColsPerBlock) &&
300 K2 * K3 * sizeof(T) == WarpAlignmentBytes,
301 "ColsPerBlock notdivisible");
302
303 constexpr index_t N2 = get_warp_size() / K2; // 8
304 constexpr index_t N1 = max(1, kWarps / K1);
305 constexpr index_t N0 = RowsPerBlock / N1 / N2;
306 static_assert((N0 * N1 * N2 == RowsPerBlock) && (K1 * N1 == kWarps) &&
307 (K2 * N2 == get_warp_size()),
308 "RowsPerBlock not divisible");
309
313 tuple<sequence<2, 1>, sequence<1, 2>>, // K1 N1, N2 K2
315 sequence<1, 2, 2>, // N0 K0 K3
317 }
318
319 template <typename Problem>
321 {
322 return MakeXDramTileDistribution<Problem,
323 typename Problem::KDataType,
324 Problem::BlockFmhaShape::kN0,
325 Problem::BlockFmhaShape::kQKHeaddim>();
326 }
327
328 template <typename Problem>
330 {
331 return MakeXDramTileDistribution<Problem,
332 typename Problem::VDataType,
333 Problem::BlockFmhaShape::kN0,
334 Problem::BlockFmhaShape::kVHeaddim>();
335 }
336
337 template <typename Problem>
339 {
340 return MakeXDramTileDistribution<Problem,
341 typename Problem::QDataType,
342 Problem::BlockFmhaShape::kM0,
343 Problem::BlockFmhaShape::kQKHeaddim>();
344 }
345
346 template <typename Problem>
348 {
349 return MakeXDramTileDistribution<Problem,
350 typename Problem::OGradDataType,
351 Problem::BlockFmhaShape::kM0,
352 Problem::BlockFmhaShape::kVHeaddim>();
353 }
354
355 template <typename Problem>
357 {
359 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
360 constexpr index_t MWarp = config.template at<1>();
361 constexpr index_t NWarp = config.template at<2>();
362
363 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
364
365 constexpr index_t N0 = MWarp * NWarp;
366
367 constexpr index_t M1 = kMPerBlock;
368 constexpr index_t M0 = get_warp_size() / M1;
369 static_assert(M1 <= get_warp_size() && get_warp_size() % M1 == 0,
370 "M1 must be a factor of warp size");
371
378 sequence<1>>{});
379 }
380
381 template <typename Problem>
386
387 template <typename DataType, index_t MPerBlock, index_t KPerBlock>
389 {
390 constexpr index_t K1 = 16 / sizeof(DataType);
391 constexpr index_t K0 = KPerBlock / K1;
392 constexpr index_t M2 = 1;
393 constexpr index_t M1 = get_warp_size();
394 constexpr index_t M0 = MPerBlock / M1;
395
403 }
404
405 template <typename Problem>
407 {
409
410 constexpr index_t kBlockSize = Problem::kBlockSize;
411 constexpr index_t kKPerBlock = Problem::kVHeaddim;
412
414 }
415
416 template <typename Problem>
418 {
420
421 constexpr index_t kBlockSize = Problem::kBlockSize;
422 constexpr index_t kKPerBlock = Problem::kVHeaddim;
423
425 }
426
427 template <typename Problem>
429 {
431
432 constexpr index_t kBlockSize = Problem::kBlockSize;
433 constexpr index_t kMPerBlock = Problem::kM0;
434 constexpr index_t kKPerBlock = Problem::kQKHeaddim;
435
436 constexpr index_t K1 = 16 / sizeof(AccDataType);
437 constexpr index_t K0 = kKPerBlock / K1;
438
439 constexpr index_t M2 = get_warp_size() / K0;
440 constexpr index_t M1 = kBlockSize / get_warp_size();
441 constexpr index_t M0 = kMPerBlock / (M1 * M2);
442
450 }
451
452 template <typename Problem>
454 {
456
457 constexpr index_t kBlockSize = Problem::kBlockSize;
458 constexpr index_t kMPerBlock = Problem::kM0;
459 constexpr index_t kKPerBlock = Problem::kQKHeaddim;
460
461 constexpr index_t K1 = 16 / sizeof(AccDataType);
462 constexpr index_t K0 = kKPerBlock / K1;
463
464 constexpr index_t M2 = get_warp_size() / K0;
465 constexpr index_t M1 = kBlockSize / get_warp_size();
466 constexpr index_t M0 = kMPerBlock / (M1 * M2);
467
474 sequence<0, 1>>{});
475 }
476
477 template <typename Problem>
482
483 template <typename Problem>
488
489 template <typename Problem>
491 {
493 using WarpGemm = typename BlockGemm::WarpGemm;
494
495 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
496 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
497
498 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
499 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
500
501 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
502 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
503
504 constexpr auto kt_block_outer_dstr_encoding = tile_distribution_encoding<
511
512 constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
513 kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
514
515 auto output =
517 decltype(kt_block_dstr_encode),
518 typename Problem::KDataType>::TransposedDstrEncode{});
519 return output;
520 }
521
522 // lds write descriptor used together with block_sync_lds (transformed dram descriptor)
523 template <typename T, index_t MNPerBlock, index_t KPerBlock>
538 template <typename Problem>
540 {
541 return MakeXLdsWriteBlockDescriptor<typename Problem::KDataType,
542 Problem::BlockFmhaShape::kN0,
543 Problem::BlockFmhaShape::kQKHeaddim>();
544 }
545 template <typename Problem>
547 {
548 return MakeXLdsWriteBlockDescriptor<typename Problem::VDataType,
549 Problem::BlockFmhaShape::kN0,
550 Problem::BlockFmhaShape::kVHeaddim>();
551 }
552 template <typename Problem>
554 {
555 return MakeXLdsWriteBlockDescriptor<typename Problem::QDataType,
556 Problem::BlockFmhaShape::kM0,
557 Problem::BlockFmhaShape::kQKHeaddim>();
558 }
559 template <typename Problem>
561 {
562 return MakeXLdsWriteBlockDescriptor<typename Problem::OGradDataType,
563 Problem::BlockFmhaShape::kM0,
564 Problem::BlockFmhaShape::kQKHeaddim>();
565 }
566 template <typename Problem>
571
572 template <typename Problem, bool Transposed = false>
574 {
575 // SGrad should be of the same distr as Gemm2 OGradV's output (i.e. PGrad)
577 using WarpGemm = typename BlockGemm::WarpGemm;
578
579 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
580 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
581
582 constexpr index_t M2 = WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
583 constexpr index_t M1 = WarpGemm::WarpGemmAttribute::Impl::kCMLane;
584 static_assert(WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane == 1, "kCM0PerLane must be 1");
585 constexpr index_t M0 = kMPerBlock / (M1 * M2);
586
587 constexpr index_t N1 = WarpGemm::WarpGemmAttribute::Impl::kCNLane;
588 constexpr index_t N0 = kNPerBlock / N1;
589
590 constexpr auto desc_0 = make_naive_tensor_descriptor_packed(
592
593 constexpr index_t M1_0 = 2, M1_1 = 2;
594 constexpr index_t N1_0 = 2, N1_1 = 8;
595 static_assert(M1_0 * M1_1 == M1, "M1_0 * M1_1 must equal M1");
596 static_assert(N1_0 * N1_1 == N1, "N1_0 * N1_1 must equal N1");
597
598 constexpr auto desc_1 = transform_tensor_descriptor(
599 desc_0,
608 constexpr auto desc_2 = transform_tensor_descriptor(
609 desc_1,
617 sequence<1>{},
619 sequence<3>{},
620 sequence<5>{},
621 sequence<6>{}),
623 sequence<1>{},
625 sequence<3>{},
626 sequence<5>{},
627 sequence<6>{}));
628
629 constexpr auto top_dims = []() {
630 if constexpr(Transposed)
632 else
634 }();
636 desc_2,
642 top_dims);
643 }
644
645 template <typename T, index_t MNPerBlock, index_t KPerBlock>
647 {
648 const auto Dwordx4Bytes = 16;
649 const auto K2 = Dwordx4Bytes / sizeof(T);
650 const auto K1 = WarpAlignmentBytes / Dwordx4Bytes;
651 const auto K0 = KPerBlock / (K1 * K2);
652
653 constexpr auto desc_0 = make_naive_tensor_descriptor_packed(
655 constexpr auto desc_1 = transform_tensor_descriptor(
656 desc_0,
663 desc_1,
669 }
670 template <typename Problem>
672 {
673 return MakeXLdsReadBlockDescriptor<typename Problem::KDataType,
674 Problem::BlockFmhaShape::kN0,
675 Problem::BlockFmhaShape::kQKHeaddim>();
676 }
677 template <typename Problem>
679 {
680 return MakeXLdsReadBlockDescriptor<typename Problem::VDataType,
681 Problem::BlockFmhaShape::kN0,
682 Problem::BlockFmhaShape::kVHeaddim>();
683 }
684 template <typename Problem>
686 {
687 return MakeXLdsReadBlockDescriptor<typename Problem::QDataType,
688 Problem::BlockFmhaShape::kM0,
689 Problem::BlockFmhaShape::kQKHeaddim>();
690 }
691 template <typename Problem>
693 {
694 return MakeXLdsReadBlockDescriptor<typename Problem::OGradDataType,
695 Problem::BlockFmhaShape::kM0,
696 Problem::BlockFmhaShape::kQKHeaddim>();
697 }
698
699 template <typename Problem>
701 {
703 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
704 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
705
706 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
707 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
708
709 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
710 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
711
712 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
713 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
714
715 constexpr auto q_block_outer_dstr_encoding =
722
723 constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
724 q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
725
726 constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
727
728 return q_block_dstr;
729 }
730
731 template <typename Problem>
733 {
735 using WarpGemm = typename BlockGemm::WarpGemm;
736
737 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
738 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
739
740 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
741 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
742
743 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
744 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
745
746 constexpr auto qt_block_outer_dstr_encoding =
753
754 constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
755 qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
756
758 decltype(qt_block_dstr_encode),
759 typename Problem::QDataType>::TransposedDstrEncode{});
760 }
761
762 template <typename Problem>
764 {
766 using WarpGemm = typename BlockGemm::WarpGemm;
767
768 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
769 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
770
771 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
772 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
773
774 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
775 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
776
777 constexpr auto dst_block_outer_dstr_encoding =
784
785 constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
786 dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
787
788 constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode);
789
790 return dst_block_dstr;
791 }
792
793 template <typename Problem>
795 {
796 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
798 constexpr index_t kMPack = 16 / sizeof(LSEDType);
799
800 constexpr auto lsed_lds_block_desc =
804 number<1>{});
805
806 return lsed_lds_block_desc;
807 }
808
809 template <typename Problem>
811 {
813 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
814 using WG = remove_cvref_t<decltype(config.template at<0>())>;
815 constexpr index_t MWarp = config.template at<1>();
816 constexpr index_t NWarp = config.template at<2>();
817
818 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
819
820 constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane;
821 constexpr index_t N0 = NWarp;
822
823 // M4 *2 and M2 /2 when swizzle mode enabled
824 constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2;
825 // constexpr index_t SwizzleConfig = 1;
826 constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig;
827 constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
828 constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig;
829 constexpr index_t M1 = MWarp;
830 constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
831
839 }
840
841 template <typename Problem>
843 {
845 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
846 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
847
848 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
849 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
850
851 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
852 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
853
854 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
855 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
856
857 constexpr auto do_block_outer_dstr_encoding =
864
865 constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
866 do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
867
868 constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode);
869
870 return do_block_dstr;
871 }
872
873 template <typename Problem>
875 {
877 using WarpGemm = typename BlockGemm::WarpGemm;
878
879 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
880 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
881
882 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
883 // constexpr index_t kNPerBlock = 32;
884 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
885
886 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
887 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
888
889 constexpr auto dot_block_outer_dstr_encoding =
896
897 constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
898 dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
899 // CK_PRINT<typename WarpGemm::BWarpDstrEncoding>();
900 // CK_PRINT<decltype(dot_block_dstr_encode)>();
901
904 decltype(dot_block_dstr_encode),
905 typename Problem::OGradDataType>::TransposedDstrEncode{});
906 }
907
908 template <typename Problem>
910 {
912 using WarpGemm = typename BlockGemm::WarpGemm;
913
914 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
915 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
916
917 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
918 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
919
920 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
921 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
922
923 constexpr auto pt_block_outer_dstr_encoding =
930
931 constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
932 pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
933
934 constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode);
935
936 return pt_block_dstr;
937 }
938
939 template <typename Problem>
941 {
943 using WarpGemm = typename BlockGemm::WarpGemm;
944
945 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
946 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
947
948 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
949 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK4;
950
951 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
952 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
953
954 constexpr auto ds_block_outer_dstr_encoding =
961
962 constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
963 ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
964
967 decltype(ds_block_dstr_encode),
968 typename Problem::GemmDataType>::TransposedDstrEncode{});
969 }
970
971 template <typename Problem>
976
977 template <typename BlockGemm>
979 {
980 using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
981 return c_block_tensor_type::get_tile_distribution();
982 }
983
984 template <typename Problem>
986 {
987 return sizeof(typename Problem::QDataType) *
988 MakeQLdsWriteBlockDescriptor<Problem>().get_element_space_size();
989 }
990
991 template <typename Problem>
993 {
994 return sizeof(typename Problem::KDataType) *
995 MakeKLdsWriteBlockDescriptor<Problem>().get_element_space_size();
996 }
997
998 template <typename Problem>
1000 {
1001 return static_cast<index_t>(max( //
1002 sizeof(int) * get_warp_size(),
1003 sizeof(typename Problem::LSEDataType) *
1004 MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size()));
1005 }
1006
1007 template <typename Problem>
1009 {
1010 return GetSmemSizeLSE<Problem>();
1011 }
1012
1013 template <typename Problem>
1015 {
1016 return sizeof(typename Problem::VDataType) *
1017 MakeVLdsWriteBlockDescriptor<Problem>().get_element_space_size();
1018 }
1019
1020 template <typename Problem>
1022 {
1023 return sizeof(typename Problem::OGradDataType) *
1024 MakeOGradLdsWriteBlockDescriptor<Problem>().get_element_space_size();
1025 }
1026
1027 template <typename Problem>
1029 {
1030 return sizeof(typename Problem::GemmDataType) *
1031 MakeSGradLdsBlockDescriptor<Problem>().get_element_space_size();
1032 }
1033
1034 template <typename Problem>
1036 {
1037 if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1038 return sizeof(typename Problem::BiasDataType) *
1039 MakeBiasLdsBlockDescriptor<Problem>().get_element_space_size();
1040 else
1041 return 0;
1042 }
1043
1044 template <typename Problem>
1046 {
1047 constexpr index_t smem_size_q = GetSmemSizeQ<Problem>();
1048 constexpr index_t smem_size_lse = GetSmemSizeLSE<Problem>();
1049 constexpr index_t smem_size_k = GetSmemSizeK<Problem>();
1050 constexpr index_t smem_size_v = GetSmemSizeV<Problem>();
1051 constexpr index_t smem_size_do = GetSmemSizeOGrad<Problem>();
1052 constexpr index_t smem_size_d = GetSmemSizeD<Problem>();
1053 constexpr index_t smem_size_ds = GetSmemSizeSGrad<Problem>();
1054 constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
1055
1056 constexpr index_t smem_size_stage0 = smem_size_k + smem_size_v;
1057 constexpr index_t smem_size_stage1 = smem_size_q * 2 + smem_size_do * 2 +
1058 smem_size_lse * 2 + smem_size_d * 2 +
1059 max(smem_size_bias, smem_size_ds);
1060 return max(smem_size_stage0, smem_size_stage1);
1061 }
1062
1063 template <typename Problem>
1065 {
1066 static constexpr index_t kBlockSize = Problem::kBlockSize;
1067 static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
1068 static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
1069 static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
1070 static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
1071 static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
1072 static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
1073 static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
1074
1075 static constexpr index_t WarpGemmM =
1076 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
1077 static constexpr index_t WarpGemmN =
1078 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
1079 static constexpr index_t WarpGemmK =
1080 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{});
1081 static constexpr index_t Gemm4MWarp =
1082 Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
1083 static constexpr index_t Gemm4NWarp =
1084 Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
1085
1086 static constexpr index_t blockWarps = kBlockSize / get_warp_size();
1087 using GemmDataType = typename Problem::GemmDataType;
1088
1089 // Compute
1090 static constexpr index_t Gemm0MFMA =
1091 kM0 * kN0 * kK0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK);
1092 static constexpr index_t Gemm1MFMA =
1093 kN0 * kVHeaddim * kM0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK);
1094 static constexpr index_t Gemm2MFMA =
1095 kM0 * kN0 * kK2 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK);
1096 static constexpr index_t Gemm3MFMA =
1097 kN0 * kQKHeaddim * kM0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK);
1098 static constexpr index_t Gemm4MFMA =
1099 kM0 * kQKHeaddim * kN0 / (blockWarps * WarpGemmM * WarpGemmN * WarpGemmK);
1100
1101 // VMEM
1102 static constexpr index_t Q_VMEM_READ =
1103 kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
1104 static constexpr index_t OGrad_VMEM_READ =
1105 kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
1106 static constexpr index_t LSE_VMEM_READ = 1;
1107 static constexpr index_t D_VMEM_READ = 1;
1108
1109 static constexpr index_t DQ_VMEM_WRITE = kM0 * kQKHeaddim / kBlockSize; // atomic add
1110
1111 // LDS Read
1112 static constexpr index_t OGradT_LDS_READ =
1114 static constexpr index_t QT_LDS_READ =
1115 kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
1116 static constexpr index_t SGradT_LDS_READ_P1 =
1117 kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetTransposedAlignmentX<GemmDataType>();
1118 static constexpr index_t SGradT_LDS_READ_P2 =
1119 kM0 * kN0 / (get_warp_size() * Gemm4MWarp) / GetTransposedAlignmentX<GemmDataType>() -
1120 SGradT_LDS_READ_P1;
1121 static constexpr index_t Q_LDS_READ =
1122 kM0 * kK0 / get_warp_size() / GetAlignmentQ<Problem>();
1123 static constexpr index_t LSE_LDS_READ = kM0 / (4 * 4);
1124 static constexpr index_t D_LDS_READ = LSE_LDS_READ;
1125 static constexpr index_t OGrad_LDS_READ =
1126 kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
1127
1128 // LDS Write
1129 static constexpr index_t Q_LDS_WRITE =
1130 kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ<Problem>();
1131 static constexpr index_t QT_LDS_WRITE =
1132 kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ<Problem>();
1133 static constexpr index_t OGrad_LDS_WRITE =
1134 kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
1135 static constexpr index_t OGradT_LDS_WRITE =
1136 kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
1137 static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
1138
1139 public:
1140 static constexpr index_t TOTAL_VMEM_READ =
1141 Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ + DQ_VMEM_WRITE;
1142
1143 CK_TILE_DEVICE static constexpr void SchedulerGemm0()
1144 {
1145 // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
1146 // Comp: Q x K
1147 constexpr index_t VMEM_READ_INST =
1148 Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
1149 constexpr index_t MFMA_INST = Gemm0MFMA;
1150 constexpr index_t LDS_READ_INST = OGradT_LDS_READ + LSE_LDS_READ + D_LDS_READ;
1151
1152 constexpr index_t lcm_inst = lcm(VMEM_READ_INST, MFMA_INST, LDS_READ_INST);
1153 static_for<0, lcm_inst, 1>{}([&](auto i) {
1154 if constexpr(i % (lcm_inst / VMEM_READ_INST) == 0)
1155 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
1156 if constexpr(i % (lcm_inst / MFMA_INST) == 0)
1157 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
1158 if constexpr(i % (lcm_inst / LDS_READ_INST) == 0)
1159 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
1160 });
1161 }
1162
1163 CK_TILE_DEVICE static constexpr void SchedulerGemm12()
1164 {
1165 // Mem: Q^T LDS load
1166 // Comp: PT x OGrad
1167 constexpr index_t LDS_READ_INST = QT_LDS_READ;
1168 constexpr index_t MFMA_INST = Gemm1MFMA + Gemm2MFMA;
1169
1170 constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST);
1171 static_for<0, lcm_inst, 1>{}([&](auto i) {
1172 if constexpr(i % (lcm_inst / MFMA_INST) == 0)
1173 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
1174 if constexpr(i % (lcm_inst / LDS_READ_INST) == 0)
1175 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // VMEM read
1176 });
1177 }
1178
1179 CK_TILE_DEVICE static constexpr void SchedulerGemm3()
1180 {
1181 // Mem: LSE/D LDS store, SGradT LDS store, SGrad, Q, LSE LDS load.
1182 // Comp: SGradT x QT
1183 constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
1184 constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ;
1185 constexpr index_t MFMA_INST = Gemm3MFMA;
1186
1187 constexpr index_t lds_rw_inst = LDS_WRITE_INST + LDS_READ_INST;
1188 constexpr index_t lcm_inst = lcm(MFMA_INST, lds_rw_inst);
1189
1190 static_for<0, lcm_inst, 1>{}([&](auto i) {
1191 if constexpr(i % (lcm_inst / MFMA_INST) == 0)
1192 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
1193 if constexpr(i % (lcm_inst / lds_rw_inst) == 0)
1194 {
1195 if constexpr(i / (lcm_inst / lds_rw_inst) < LDS_WRITE_INST)
1196 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
1197 else
1198 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS Read
1199 }
1200 });
1201 }
1202
1203 CK_TILE_DEVICE static constexpr void SchedulerGemm4()
1204 {
1205 // Mem: SGrad, OGrad, D LDS load.
1206 // Comp: SGrad x KT
1207 constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ;
1208 constexpr index_t MFMA_INST = Gemm4MFMA;
1209
1210 constexpr index_t lcm_inst = lcm(MFMA_INST, LDS_READ_INST);
1211 static_for<0, lcm_inst, 1>{}([&](auto i) {
1212 if constexpr(i % (lcm_inst / MFMA_INST) == 0)
1213 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
1214 if constexpr(i % (lcm_inst / LDS_READ_INST) == 0)
1215 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
1216 });
1217 }
1218 };
1219};
1220
1221} // namespace ck_tile
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1065
static CK_TILE_DEVICE constexpr void SchedulerGemm12()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1163
static CK_TILE_DEVICE constexpr void SchedulerGemm3()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1179
static constexpr index_t TOTAL_VMEM_READ
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1140
static CK_TILE_DEVICE constexpr void SchedulerGemm4()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1203
static CK_TILE_DEVICE constexpr void SchedulerGemm0()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1143
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:371
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, true > InputTileDistributionTraits
Definition load_tile_transpose.hpp:343
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y)
Definition tile/core/numeric/math.hpp:314
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
static CK_TILE_HOST_DEVICE constexpr auto MakeVRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1071
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1798
static CK_TILE_HOST_DEVICE constexpr auto GetSGradTQTBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:138
static CK_TILE_DEVICE constexpr auto GetPTOGradTBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:66
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:614
static CK_TILE_HOST_DEVICE constexpr auto MakeKRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1013
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledBiasTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1776
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:12
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1045
static CK_TILE_HOST_DEVICE constexpr auto MakePreOGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:417
static constexpr index_t WarpAlignmentBytes
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:227
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:560
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:356
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentBias()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:165
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:338
static CK_TILE_HOST_DEVICE constexpr auto MakeXDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:289
static CK_TILE_HOST_DEVICE constexpr auto TransformXDramTensorView(const TensorView &naive_view)
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:232
static CK_TILE_HOST_DEVICE constexpr auto MakePreODramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:406
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeBias()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1035
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:763
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:145
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentOGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:160
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:567
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasSTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:978
static CK_TILE_HOST_DEVICE constexpr auto MakePostQGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:453
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentO()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:155
static CK_TILE_HOST_DEVICE constexpr auto GetPTOGradTBlockGemm()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:49
static CK_TILE_HOST_DEVICE constexpr auto MakeKRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:478
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:539
static CK_TILE_HOST_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:329
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:671
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:14
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentOGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:195
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentBias()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:201
static CK_TILE_HOST_DEVICE constexpr auto MakePostQGradAccDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:428
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:678
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentX() noexcept
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:135
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:347
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentPostQGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:220
static CK_TILE_HOST_DEVICE constexpr auto GetSGradTQTBlockGemm()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:90
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeSGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1028
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:382
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentQ() noexcept
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:189
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentX() noexcept
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:184
static CK_TILE_HOST_DEVICE constexpr auto MakeKTRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:490
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentPostQGradAcc()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:213
static CK_TILE_HOST_DEVICE constexpr auto GetSGradKTBlockGemm()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:96
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledBiasTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:972
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:810
static CK_TILE_HOST_DEVICE constexpr auto MakeQTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:732
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:140
static CK_TILE_HOST_DEVICE constexpr auto MakeXLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:646
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeLSE()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:999
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeK()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:992
static CK_TILE_DEVICE constexpr auto MakePTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:909
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeQ()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:985
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:546
static CK_TILE_HOST_DEVICE constexpr auto MakeQRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:700
static CK_TILE_DEVICE constexpr auto MakeOGradTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:874
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:794
static CK_TILE_HOST_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:320
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeV()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1014
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:553
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentVGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:177
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:940
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentKGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:171
static CK_TILE_HOST_DEVICE constexpr auto MakeVRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:484
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeOGrad()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1021
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:692
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:150
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:573
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:842
static CK_TILE_HOST_DEVICE constexpr auto MakeXLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:524
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeD()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:1008
static CK_TILE_HOST_DEVICE constexpr auto TransformXDramDescriptor(const tensor_descriptor< TD_TS... > &from_desc)
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:249
static CK_TILE_HOST_DEVICE constexpr auto GetOGradVBlockGemm()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:55
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:685
static CK_TILE_HOST_DEVICE constexpr auto MakePreXDramTileDistribution()
Definition block_fmha_bwd_pipeline_trload_default_policy.hpp:388
Definition block_gemm_areg_breg_creg_v1_custom_policy.hpp:16
Definition block_gemm_areg_breg_creg_v1.hpp:18
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile/core/tensor/tensor_descriptor.hpp:34
CK_TILE_HOST_DEVICE constexpr auto get_length(number< IDim > idim) const
Definition tile/core/tensor/tensor_descriptor.hpp:86
Definition tensor_view.hpp:41
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192