block_fmha_pipeline_qr_ks_vs_async_trload.hpp Source File

block_fmha_pipeline_qr_ks_vs_async_trload.hpp Source File#

Composable Kernel: block_fmha_pipeline_qr_ks_vs_async_trload.hpp Source File
block_fmha_pipeline_qr_ks_vs_async_trload.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12
13// This pipeline keeps the Q, K and V tiles all located in LDS
14template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
16{
17 static constexpr auto I0 = number<0>{};
18 static constexpr auto I1 = number<1>{};
19
35
38 static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
39 static_assert(kQLoadOnce == Policy::QLoadOnce);
40 static constexpr bool kKLoadOnce = BlockFmhaShape::kM0 >= 64;
41
42 static constexpr index_t kBlockSize = Problem::kBlockSize;
43
44 static constexpr index_t kM0 = BlockFmhaShape::kM0;
45 static constexpr index_t kN0 = BlockFmhaShape::kN0;
46 static constexpr index_t kK0 = BlockFmhaShape::kK0;
47 static constexpr index_t kN1 = BlockFmhaShape::kN1;
48 static constexpr index_t kK1 = BlockFmhaShape::kK1;
49 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
50 static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
51 static constexpr index_t kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(I1);
52 static constexpr index_t kNXdl = BlockFmhaShape::Gemm0WarpTile::at(I1);
53
54 static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
55
56 // static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
57 // Problem::kPadHeadDimV == true);
58
59 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
60 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
61 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
62 static constexpr bool kPadHeadDimQ =
63 Problem::kPadHeadDimQ; // support multiple of vector(like 8x)
64 static constexpr bool kPadHeadDimV =
65 Problem::kPadHeadDimV; // support multiple of vector(like 8x)
66
67 static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
68 static constexpr bool kHasDropout = Problem::kHasDropout;
69 static constexpr auto BiasEnum = Problem::BiasEnum;
70 static constexpr bool kStoreLSE = Problem::kStoreLSE;
71 static constexpr bool kHasUnevenSplits = true;
72
73 static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
74 (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
77
78 // last-dimension vector length used to create the tensor view (and to decide the buffer_load
79 // vector length) together with the tensor distribution. The tensor distribution should be able to override this
80 static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
81 static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
// Alignment used for V, computed once at compile time by an immediately-invoked lambda:
// - VLayout is RowMajor: always the policy's GetAlignmentV value;
// - any other VLayout: 1 when kPadSeqLenK is set, otherwise the policy's GetAlignmentV value.
82 static constexpr index_t kAlignmentV = []() {
83 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
84 return Policy::template GetAlignmentV<Problem>();
85 else
86 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
87 }();
88
89 static constexpr index_t kAlignmentOacc = Policy::template GetAlignmentO<Problem>();
90
91 static constexpr index_t kAlignmentBias =
92 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
93
// Value of kBlockPerCu, computed once at compile time by an immediately-invoked lambda.
// Problem::kBlockPerCu is used as-is unless it is -1; in that case a default is chosen
// from kQKHeaddim:
//   <= 32  -> 2
//   <= 64  -> 3
//   <= 128 -> 1 with ELEMENTWISE_BIAS or kM0 >= 256, otherwise 2
//   larger -> 1
94 static constexpr index_t kBlockPerCu = []() {
95 if constexpr(Problem::kBlockPerCu != -1)
96 return Problem::kBlockPerCu;
97 else
98 {
99 if constexpr(kQKHeaddim <= 32)
100 {
101 return 2;
102 }
103 else if constexpr(kQKHeaddim <= 64)
104 {
105 return 3;
106 }
107 else if constexpr(kQKHeaddim <= 128)
108 {
109 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || kM0 >= 256)
110 return 1;
111 else
112 return 2;
113 }
114 else if constexpr(kQKHeaddim <= 256)
115 {
116 return 1;
117 }
118 else
119 {
// NOTE(review): same result as the <= 256 branch above; kept as a separate fallback.
120 return 1;
121 }
122 }
123 }();
124
125 static constexpr const char* name = "qr_async_trload";
126
128 {
129 return Policy::template GetSmemSize<Problem>();
130 }
131
132 // Decode
133 template <typename QDramBlockWindowTmp,
134 typename KDramBlockWindowTmp,
135 typename VDramBlockWindowTmp,
136 typename BiasDramBlockWindowTmp,
137 typename LSEaccDramBlockWindowTmp,
138 typename PositionEncoding>
140 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
141 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
142 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
143 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
144 LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
145 FmhaMask mask,
146 PositionEncoding position_encoding,
147 float scale_s,
148 void* smem_ptr) const
149 {
150 static_assert(
151 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
152 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
153 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
154 "wrong!");
155
156 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] &&
157 kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] &&
158 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] &&
159 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] &&
160 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] &&
161 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] &&
162 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] &&
163 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1],
164 "wrong!");
165 ignore = bias_dram_block_window_tmp;
166 ignore = position_encoding;
167 // Block GEMM
168 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
169 constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
170
171 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
172 auto s_acc = SaccBlockTileType{};
173
174 // reduction function for softmax
175 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
176 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
177
178 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
179
180 auto o_acc = OaccBlockTileType{};
181
182 // infer Sacc, S, P, M, L, Oacc type
183 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(o_acc));
184
185 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
186 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
187
188 // init M, L
189 auto m = MLBlockTileType{};
190 auto l = MLBlockTileType{};
191
192 clear_tile(o_acc);
194 clear_tile(l);
195
196 const auto q_origin = q_dram_block_window_tmp.get_window_origin();
197 const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
198 mask.GetTileRangeAlongX(q_origin.at(I0), number<kM0>{}, number<kN0>{});
199
200 // check early exit if no work to do
201 if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
202 {
203 const index_t logical_num_total_loop =
204 integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
205 if(logical_num_total_loop <= 0)
206 {
207 if constexpr(kStoreLSE)
208 {
209 auto lse_acc =
210 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
211
213
214 store_tile(lse_acc_dram_window_tmp, lse_acc);
215 }
216
217 // Note: here o_acc is all cleared, return it
218 // Note: q loaded but no fence, ignore it.
219 return o_acc;
220 }
221 }
222
223 // Q tile in LDS
224 auto q_dram_window = make_tile_window(
225 q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem>());
226
227 auto q_lds_write_view = make_tensor_view<address_space_enum::lds>(
228 static_cast<QDataType*>(smem_ptr), Policy::template MakeQLdsBlockDescriptor<Problem>());
229
230 auto q_lds_read_view = make_tensor_view<address_space_enum::lds>(
231 static_cast<QDataType*>(smem_ptr),
232 Policy::template MakeQLdsBlockDescriptor<Problem, true>());
233
234 auto q_lds_store_window =
235 make_tile_window(q_lds_write_view,
236 Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
237 {0, 0});
238
239 auto q_lds_read_window =
240 make_tile_window(q_lds_read_view,
241 Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
242 {0, 0},
243 Policy::template MakeQRegTileDistribution<Problem>());
244
245 async_load_tile(q_lds_store_window, q_dram_window);
246
247 // K tile in LDS
248 const index_t physical_seqlen_k_start = logical_seqlen_k_start;
249 const index_t physical_seqlen_k_end = logical_seqlen_k_end;
250 // make sure the first tile is completely located in page-block (page-block size should be
251 // divisible by kN0)
252 // relationship between each *_start variables: aligned_physical_seqlen_k_start <=
253 // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
254 const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start;
255
256 auto k_dram_window =
257 make_tile_window(k_dram_block_window_tmp,
258 {physical_seqlen_k_start, 0},
259 Policy::template MakeKDramTileDistribution<Problem>());
260
261 auto k_lds_write_view = make_tensor_view<address_space_enum::lds>(
262 static_cast<KDataType*>(smem_ptr), Policy::template MakeKLdsBlockDescriptor<Problem>());
263 auto k_lds_read_view = make_tensor_view<address_space_enum::lds>(
264 static_cast<KDataType*>(smem_ptr),
265 Policy::template MakeKLdsBlockDescriptor<Problem, false, true>());
266
267 auto k_lds_write_window =
268 make_tile_window(k_lds_write_view,
269 Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(),
270 {0, 0});
271 auto k_lds_read_window =
272 make_tile_window(k_lds_read_view,
274 {0, 0},
275 Policy::template MakeKRegTileDistribution<Problem>());
276
277 // S tile in LDS
279 reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
280 Policy::template GetSmemSizeK<Problem>()),
281 Policy::template MakeSLdsBlockDescriptor<Problem>());
282 auto s_write_lds_window = make_tile_window(
283 s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
284 auto s_read_lds_window =
285 make_tile_window(s_lds,
286 Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
287 {0, 0},
288 Policy::template MakeSRegTileDistribution<Problem>());
289
290 // V tile in LDS
291 auto v_dram_window =
292 make_tile_window(v_dram_block_window_tmp,
293 {physical_seqlen_k_start, 0},
294 Policy::template MakeVDramTileDistribution<Problem>());
295
296 auto v_lds_write_view = make_tensor_view<address_space_enum::lds>(
297 reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
298 Policy::template GetSmemSizeK<Problem>() +
299 Policy::template GetSmemSizeS<Problem>()),
300 Policy::template MakeVLdsBlockDescriptor<Problem>());
301 auto v_lds_read_view = make_tensor_view<address_space_enum::lds>(
302 reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
303 Policy::template GetSmemSizeK<Problem>() +
304 Policy::template GetSmemSizeS<Problem>()),
305 Policy::template MakeVLdsBlockDescriptor<Problem, true>());
306 auto v_lds_write_window =
307 make_tile_window(v_lds_write_view,
308 Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
309 {0, 0});
310
311 auto v_lds_read_window =
312 make_tile_window(v_lds_read_view,
314 {0, 0},
315 Policy::template MakeVRegTileDistribution<Problem>());
316
318 auto q_tile = load_tile(q_lds_read_window);
319
320 const index_t num_total_loop =
321 integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
322
323 index_t i_total_loops = 0;
324 constexpr index_t k0_loops = kQKHeaddim / kK0;
325 constexpr index_t k1_loops = kN0 / kK1;
326
327 static_assert(1 <= k0_loops);
328 static_assert(1 <= k1_loops);
329
331 async_load_tile(k_lds_write_window, k_dram_window);
332
333 constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
334 constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
335
336 do
337 {
339 async_load_tile(v_lds_write_window, v_dram_window); // prefetch load v tile
340
341 // move V tile windows
342 move_tile_window(v_dram_window, {kN0, 0});
343
344 // STAGE 1, QK gemm
345 clear_tile(s_acc); // initialize C
346
347 if constexpr(1 < k0_loops)
348 {
349 static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
350 if constexpr(i_k0 == 0)
351 {
353 }
354 else
355 {
357 }
358
359 auto k_tile = load_tile(k_lds_read_window);
360
361 gemm_0(s_acc,
362 get_slice_tile(q_tile,
364 sequence<kM0, (i_k0 + 1) * kK0>{}),
365 k_tile);
366
367 // loop over along the [K]ey head dimension
368 move_tile_window(k_dram_window, {0, kK0});
370 async_load_tile(k_lds_write_window, k_dram_window);
371 });
372 // move back to the origin
373 move_tile_window(k_dram_window, {0, -kK0 * (k0_loops - 1)});
374 }
375
376 if constexpr(k0_loops == 1)
377 {
379 }
380 else
381 {
383 }
384
385 auto k_tile = load_tile(k_lds_read_window);
386
387 gemm_0(s_acc,
388 get_slice_tile(q_tile,
389 sequence<0, (k0_loops - 1) * kK0>{},
391 k_tile);
392
393 if constexpr(kHasUnevenSplits)
394 {
395 if(i_total_loops == (num_total_loop - 1))
396 {
397 const auto k_origin =
398 make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0);
399 set_tile_if(s_acc,
401 [&,
402 physical_seqlen_k_start_ = physical_seqlen_k_start,
403 physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
404 const auto col = k_origin.at(I0) + tile_idx.at(I1);
405
406 {
407 return physical_seqlen_k_end_ <= col;
408 }
409 });
410 }
411 }
412
413 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
414 {
415 const auto k_origin = make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0);
416
417 bool need_perpixel_check =
418 mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number<kM0>{}, number<kN0>{});
419 if(need_perpixel_check)
420 {
422 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
423 const auto row = q_origin.at(I0) + tile_idx.at(I0);
424 const auto col = k_origin.at(I0) + tile_idx.at(I1);
425 return mask.IsOutOfBound(row, col);
426 });
427 }
428 }
429
430 // move K tile windows after current status checked
431 // prefetch next-tile along [K]ey sequence length dimension
432 move_tile_window(k_dram_window, {kN0, 0});
433
435 async_load_tile(k_lds_write_window, k_dram_window);
436
437 // Gemm1
438 auto s_new = [&]() {
439 if constexpr(kNWarp > 1)
440 {
441 auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
442
443 store_tile(s_write_lds_window, s);
445 return load_tile(s_read_lds_window);
446 }
447 else
448 {
449 return cast_tile<SMPLComputeDataType>(s_acc); // S{j}
450 }
451 }();
452
454 s_new,
455 sequence<1>{},
456 f_max,
457 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
458 // Setting CrossWarp to false would trigger a better strategy on gfx950, but causes a
459 // performance regression because of un-coexecutable packed math, so it is left disabled for now
461 m_local, f_max, bool_constant<false>{} /*, bool_constant<false>{}*/);
462
463 const auto m_old = m; // m{j-1}
465 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
466
468 s_new.get_tile_distribution()); // Pcompute{j}
469
470 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
474 FmhaMask::IsMasking)
475 {
478 : raw_m;
479 }
480 else
481 {
482 return raw_m;
483 }
484 };
485
486 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
487 sweep_tile_span(p_spans[I0], [&](auto idx0) {
488 constexpr auto i_idx = make_tuple(idx0);
489 auto row_max = scale_s * get_validated_m(m[i_idx]);
490 sweep_tile_span(p_spans[I1], [&](auto idx1) {
491 constexpr auto i_j_idx = make_tuple(idx0, idx1);
494 {
495 p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
496 }
497 else
498 {
499 if constexpr(kHasLogitsSoftCap)
500 {
501 p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
502 }
503 else
504 {
505 p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
506 }
507 }
508 });
509 });
510
512 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
513
515 rowsum_p, f_sum, bool_constant<false>{} /*, bool_constant<false>{}*/);
516
518 Policy::template MakePRegTileDistribution<Problem>());
519 p_tile.get_thread_buffer() = cast_tile<PDataType>(p_compute).get_thread_buffer();
520
521 // l{j}, Oacc{j}
522 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
523 sweep_tile_span(o_spans[I0], [&](auto idx0) {
524 constexpr auto i_idx = make_tuple(idx0);
525 const auto tmp = [&]() {
528 {
529 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
530 }
531 else
532 {
533 if constexpr(kHasLogitsSoftCap)
534 {
535 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
536 }
537 else
538 {
539 auto row_max = scale_s * get_validated_m(m[i_idx]);
540 return exp2(scale_s * m_old[i_idx] - row_max);
541 }
542 }
543 }();
544 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
545 sweep_tile_span(o_spans[I1], [&](auto idx1) {
546 constexpr auto i_j_idx = make_tuple(idx0, idx1);
547
548 o_acc(i_j_idx) *= tmp;
549 });
550 });
551
553
554 auto v_tile = load_tile_transpose(v_lds_read_window);
555
556 if constexpr(1 < k1_loops)
557 {
558 static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
559 gemm_1(o_acc,
560 get_slice_tile(p_tile,
562 sequence<kM0, (i_k1 + 1) * kK1>{}),
563 v_tile);
564
565 // loop over along the [V]alue Sequence length
566 move_tile_window(v_lds_read_window, {kK1, 0});
567 v_tile = load_tile_transpose(v_lds_read_window);
568 });
569 // move back to the origin
570 move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});
571 }
572
573 gemm_1(o_acc,
574 get_slice_tile(p_tile,
575 sequence<0, (k1_loops - 1) * kK1>{},
577 v_tile);
578
579 } while(++i_total_loops < num_total_loop);
580
581 if constexpr(kStoreLSE)
582 {
583 // store lse acc
584 auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
585
586 constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
587 sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) {
588 constexpr auto i_idx = make_tuple(idx0);
591 {
592 lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
593 }
594 else
595 {
596 if constexpr(kHasLogitsSoftCap)
597 {
598 lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
599 }
600 else
601 {
602 lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
603 }
604 }
605 });
606
607 store_tile(lse_acc_dram_window_tmp, lse_acc);
608 }
609
610 // finally, O
611 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
612
613 sweep_tile_span(o_spans[I0], [&](auto idx0) {
614 constexpr auto i_idx = make_tuple(idx0);
615 const auto tmp = [&]() {
617 FmhaMask::IsMasking)
618 {
619 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
620 }
621 else
622 return 1 / l[i_idx];
623 }();
624 sweep_tile_span(o_spans[I1], [&](auto idx1) {
625 constexpr auto i_j_idx = make_tuple(idx0, idx1);
626 o_acc(i_j_idx) *= tmp;
627 });
628 });
629
630 return o_acc;
631 }
632
633 // Prefill, double lds
634 template <typename QDramBlockWindowTmp,
635 typename KDramBlockWindowTmp,
636 typename VDramBlockWindowTmp,
637 typename BiasDramBlockWindowTmp,
638 typename LSEaccDramBlockWindowTmp,
639 typename PositionEncoding>
641 operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile
642 const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile
643 const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile
644 const BiasDramBlockWindowTmp& __restrict__ bias_dram_block_window_tmp, // M0*N0 tile
645 LSEaccDramBlockWindowTmp& __restrict__ lse_acc_dram_window_tmp, // M0*1 tile
646 FmhaMask mask,
647 PositionEncoding position_encoding,
648 float scale_s,
649 void* __restrict__ smem_ptrk0,
650 void* __restrict__ smem_ptrk1,
651 void* __restrict__ smem_ptrv0,
652 void* __restrict__ smem_ptrv1) const
653 {
654 static_assert(
655 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
656 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
657 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
658 "wrong!");
659
660 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[I0] &&
661 kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[I1] &&
662 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[I0] &&
663 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[I1] &&
664 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[I0] &&
665 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[I1] &&
666 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I0] &&
667 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[I1],
668 "wrong!");
669 ignore = bias_dram_block_window_tmp;
670 ignore = position_encoding;
671
672 // Block GEMM
673 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
674 constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
675
676 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
677 auto s_acc = SaccBlockTileType{};
678
679 // reduction function for softmax
680 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
681 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
682
683 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
684
685 auto o_acc = OaccBlockTileType{};
686
687 // infer Sacc, S, P, M, L, Oacc type
688 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(o_acc));
689
690 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
691 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
692
693 // init M, L
694 auto m = MLBlockTileType{};
695 auto l = MLBlockTileType{};
696
697 clear_tile(o_acc);
699 clear_tile(l);
700
701 const auto q_origin = q_dram_block_window_tmp.get_window_origin();
702 const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
703 mask.GetTileRangeAlongX(q_origin.at(I0), number<kM0>{}, number<kN0>{});
704
705 // check early exit if no work to do
706 if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
707 {
708 const index_t logical_num_total_loop =
709 integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
710 if(logical_num_total_loop <= 0)
711 {
712 if constexpr(kStoreLSE)
713 {
714 auto lse_acc =
715 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
716
718
719 store_tile(lse_acc_dram_window_tmp, lse_acc);
720 }
721
722 // Note: here o_acc is all cleared, return it
723 // Note: q loaded but no fence, ignore it.
724 return o_acc;
725 }
726 }
727
728 // Q tile in LDS
729 auto q_dram_window = make_tile_window(
730 q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem>());
731
732 auto q_lds_write_view = make_tensor_view<address_space_enum::lds>(
733 static_cast<QDataType*>(smem_ptrk0),
734 Policy::template MakeQLdsBlockDescriptor<Problem>());
735
736 auto q_lds_read_view = make_tensor_view<address_space_enum::lds>(
737 static_cast<QDataType*>(smem_ptrk0),
738 Policy::template MakeQLdsBlockDescriptor<Problem, true>());
739
740 auto q_lds_store_window =
741 make_tile_window(q_lds_write_view,
742 Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
743 {0, 0});
744
745 auto q_lds_read_window =
746 make_tile_window(q_lds_read_view,
747 Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
748 {0, 0},
749 Policy::template MakeQRegTileDistribution<Problem>());
750
751 async_load_tile(q_lds_store_window, q_dram_window);
753 auto q_tile = load_tile(q_lds_read_window);
754
755 // K tile in LDS
756 const index_t physical_seqlen_k_start = logical_seqlen_k_start;
757 const index_t physical_seqlen_k_end = logical_seqlen_k_end;
758 // make sure the first tile is completely located in page-block (page-block size should be
759 // divisible by kN0)
760 // relationship between each *_start variables: aligned_physical_seqlen_k_start <=
761 // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
762 const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start;
763
764 auto k_dram_window =
765 make_tile_window(k_dram_block_window_tmp,
766 {physical_seqlen_k_start, 0},
767 Policy::template MakeKDramTileDistribution<Problem, true>());
768
769 auto k_lds_write_view = make_tensor_view<address_space_enum::lds>(
770 static_cast<KDataType* __restrict__>(smem_ptrk0),
771 Policy::template MakeKLdsBlockDescriptor<Problem, true>());
772
773 auto k_lds_read_view = make_tensor_view<address_space_enum::lds>(
774 static_cast<KDataType* __restrict__>(smem_ptrk0),
775 Policy::template MakeKLdsBlockDescriptor<Problem, true, true>());
776
777 auto k_lds_write_window =
778 make_tile_window(k_lds_write_view,
779 Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(),
780 {0, 0});
781
782 auto k_lds_read_window =
783 make_tile_window(k_lds_read_view,
785 {0, 0},
786 Policy::template MakeKRegTileDistribution<Problem>());
787
788 // S tile in LDS
790 reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptrk0) +
791 Policy::template GetSmemSizeK<Problem>()),
792 Policy::template MakeSLdsBlockDescriptor<Problem>());
793 auto s_write_lds_window = make_tile_window(
794 s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
795 auto s_read_lds_window =
796 make_tile_window(s_lds,
797 Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
798 {0, 0},
799 Policy::template MakeSRegTileDistribution<Problem>());
800
801 // V tile in LDS
802 auto v_dram_window =
803 make_tile_window(v_dram_block_window_tmp,
804 {physical_seqlen_k_start, 0},
805 Policy::template MakeVDramTileDistribution<Problem>());
806
807 auto v_lds_write_view = make_tensor_view<address_space_enum::lds>(
808 reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
809 Policy::template MakeVLdsBlockDescriptor<Problem>());
810
811 auto v_lds_read_view = make_tensor_view<address_space_enum::lds>(
812 reinterpret_cast<VDataType* __restrict__>(static_cast<char*>(smem_ptrv0)),
813 Policy::template MakeVLdsBlockDescriptor<Problem, true>());
814
815 auto v_lds_write_window =
816 make_tile_window(v_lds_write_view,
817 Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
818 {0, 0});
819
820 auto v_lds_read_window =
821 make_tile_window(v_lds_read_view,
823 {0, 0},
824 Policy::template MakeVRegTileDistribution<Problem>());
825
826 // block_sync_lds_direct_load<0>();
827 // auto q_tile = load_tile(q_lds_read_window);
828
829 const index_t num_total_loop =
830 integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
831
832 index_t i_total_loops = 0;
833 constexpr index_t k0_loops = kQKHeaddim / kK0;
834 constexpr index_t k1_loops = kN0 / kK1;
835
836 static_assert(1 <= k0_loops);
837 static_assert(1 <= k1_loops);
839 async_load_tile(k_lds_write_window, k_dram_window);
840 async_load_tile(v_lds_write_window, v_dram_window);
841
842 move_tile_window(k_dram_window, {kN0, 0});
843 k_lds_write_window.set_bottom_tensor_view_data_ptr(
844 static_cast<KDataType* __restrict__>(smem_ptrk1));
845 async_load_tile(k_lds_write_window, k_dram_window);
846
847 constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
848 constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
849
850 constexpr index_t k_lds_insts = k_lds_read_window.get_num_of_access();
851 constexpr index_t v_lds_insts = v_lds_read_window.get_num_of_access();
852
854 auto k_tile = load_tile(k_lds_read_window);
855
856 __builtin_amdgcn_sched_barrier(0);
857
858 auto mainloop = [&](KDataType* __restrict__ k_lds_write_ptr,
859 KDataType* __restrict__ k_lds_read_ptr,
860 KDataType* __restrict__ v_lds_write_ptr,
861 KDataType* __restrict__ v_lds_read_ptr) {
862 // move V tile windows
864 move_tile_window(v_dram_window, {kN0, 0});
865 v_lds_write_window.set_bottom_tensor_view_data_ptr(v_lds_write_ptr);
866 async_load_tile(v_lds_write_window, v_dram_window);
867
868 // STAGE 1, QK gemm
869 clear_tile(s_acc); // initialize C
870
871 if constexpr(1 < k0_loops)
872 {
873 static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
874 // loop over along the [K]ey head dimension
875 move_tile_window(k_lds_read_window, {0, kK0});
876 auto k_tile_switch = load_tile(k_lds_read_window);
877
878 gemm_0(s_acc,
879 get_slice_tile(q_tile,
881 sequence<kM0, (i_k0 + 1) * kK0>{}),
882 k_tile);
883
884 k_tile = k_tile_switch;
885 });
886 // move back to the origin
887 move_tile_window(k_lds_read_window, {0, -kK0 * (k0_loops - 1)});
888 }
889
890 gemm_0(s_acc,
891 get_slice_tile(q_tile,
892 sequence<0, (k0_loops - 1) * kK0>{},
894 k_tile);
895
897 v_lds_read_window.set_bottom_tensor_view_data_ptr(v_lds_read_ptr);
898 auto v_tile = load_tile_transpose(v_lds_read_window);
899
900 if constexpr(kHasUnevenSplits)
901 {
902 if(i_total_loops == (num_total_loop - 1))
903 {
904 const auto k_origin =
905 make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0);
906 set_tile_if(s_acc,
908 [&,
909 physical_seqlen_k_start_ = physical_seqlen_k_start,
910 physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
911 const auto col = k_origin.at(I0) + tile_idx.at(I1);
912
913 {
914 return physical_seqlen_k_end_ <= col;
915 }
916 });
917 }
918 }
919
920 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
921 {
922 const auto k_origin = make_tuple(kN0 * i_total_loops + physical_seqlen_k_start, 0);
923
924 bool need_perpixel_check =
925 mask.IsEdgeTile(q_origin.at(I0), k_origin.at(I0), number<kM0>{}, number<kN0>{});
926 if(need_perpixel_check)
927 {
929 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
930 const auto row = q_origin.at(I0) + tile_idx.at(I0);
931 const auto col = k_origin.at(I0) + tile_idx.at(I1);
932 return mask.IsOutOfBound(row, col);
933 });
934 }
935 }
936
937 // Gemm1
938 auto s_new = [&]() {
939 if constexpr(kNWarp > 1)
940 {
941 auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
942
943 store_tile(s_write_lds_window, s);
945 return load_tile(s_read_lds_window);
946 }
947 else
948 {
949 return cast_tile<SMPLComputeDataType>(s_acc); // S{j}
950 }
951 }();
952
954 s_new,
955 sequence<1>{},
956 f_max,
957 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
959 m_local, f_max, bool_constant<false>{} /*, bool_constant<false>{}*/);
960
961 static_for<0, 12, 1>{}([&](auto i) {
962 ignore = i;
963 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
964 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
965 });
966
967 static_for<0, 4, 1>{}([&](auto i) {
968 ignore = i;
969 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
970 __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
971 });
972
973 const auto m_old = m; // m{j-1}
975 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
976
978 s_new.get_tile_distribution()); // Pcompute{j}
979
980 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
984 FmhaMask::IsMasking)
985 {
988 : raw_m;
989 }
990 else
991 {
992 return raw_m;
993 }
994 };
995
996 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
997 sweep_tile_span(p_spans[I0], [&](auto idx0) {
998 constexpr auto i_idx = make_tuple(idx0);
999 auto row_max = scale_s * get_validated_m(m[i_idx]);
1000 sweep_tile_span(p_spans[I1], [&](auto idx1) {
1001 constexpr auto i_j_idx = make_tuple(idx0, idx1);
1004 {
1005 p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
1006 }
1007 else
1008 {
1009 if constexpr(kHasLogitsSoftCap)
1010 {
1011 p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
1012 }
1013 else
1014 {
1015 p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
1016 }
1017 }
1018 });
1019 });
1020
1022 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
1023
1025 rowsum_p, f_sum, bool_constant<false>{} /*, bool_constant<false>{}*/);
1026
1028 Policy::template MakePRegTileDistribution<Problem>());
1029 p_tile.get_thread_buffer() = cast_tile<PDataType>(p_compute).get_thread_buffer();
1030
1031 // l{j}, Oacc{j}
1032 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
1033 sweep_tile_span(o_spans[I0], [&](auto idx0) {
1034 constexpr auto i_idx = make_tuple(idx0);
1035 const auto tmp = [&]() {
1038 {
1039 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
1040 }
1041 else
1042 {
1043 if constexpr(kHasLogitsSoftCap)
1044 {
1045 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
1046 }
1047 else
1048 {
1049 auto row_max = scale_s * get_validated_m(m[i_idx]);
1050 return exp2(scale_s * m_old[i_idx] - row_max);
1051 }
1052 }
1053 }();
1054 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
1055 sweep_tile_span(o_spans[I1], [&](auto idx1) {
1056 constexpr auto i_j_idx = make_tuple(idx0, idx1);
1057
1058 o_acc(i_j_idx) *= tmp;
1059 });
1060 });
1061
1063 move_tile_window(k_dram_window, {kN0, 0});
1064 k_lds_write_window.set_bottom_tensor_view_data_ptr(k_lds_write_ptr);
1065 async_load_tile(k_lds_write_window, k_dram_window);
1066
1067 if constexpr(1 < k1_loops)
1068 {
1069 static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
1070 // loop along the [V]alue sequence length
1071 move_tile_window(v_lds_read_window, {kK1, 0});
1072 auto v_tile_switch = load_tile_transpose(v_lds_read_window);
1073
1074 gemm_1(o_acc,
1075 get_slice_tile(p_tile,
1077 sequence<kM0, (i_k1 + 1) * kK1>{}),
1078 v_tile);
1079
1080 v_tile = v_tile_switch;
1081 });
1082 // move back to the origin
1083 move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0});
1084 }
1085
1086 gemm_1(o_acc,
1087 get_slice_tile(p_tile,
1088 sequence<0, (k1_loops - 1) * kK1>{},
1090 v_tile);
1091
1093 k_lds_read_window.set_bottom_tensor_view_data_ptr(k_lds_read_ptr);
1094 k_tile = load_tile(k_lds_read_window);
1095
1096 static_for<0, 12, 1>{}([&](auto i) {
1097 ignore = i;
1098 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
1099 __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS_READ
1100 });
1101
1102 static_for<0, 4, 1>{}([&](auto i) {
1103 ignore = i;
1104 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
1105 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
1106 });
1107 }; // mainloop
1108
1109 do
1110 {
1111 bool is_even_loop = i_total_loops % 2 == 0;
1112 auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
1113 : static_cast<KDataType* __restrict__>(smem_ptrk1);
1114 auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
1115 : static_cast<KDataType* __restrict__>(smem_ptrk0);
1116 auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
1117 : static_cast<VDataType* __restrict__>(smem_ptrv0);
1118 auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
1119 : static_cast<VDataType* __restrict__>(smem_ptrv1);
1120 mainloop(k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
1121 i_total_loops++;
1122 } while(i_total_loops < num_total_loop);
1123
1124 if constexpr(kStoreLSE)
1125 {
1126 // store lse acc
1127 auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
1128
1129 constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
1130 sweep_tile_span(lse_acc_spans[I0], [&, m_ = m, l_ = l](auto idx0) {
1131 constexpr auto i_idx = make_tuple(idx0);
1134 {
1135 lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
1136 }
1137 else
1138 {
1139 if constexpr(kHasLogitsSoftCap)
1140 {
1141 lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
1142 }
1143 else
1144 {
1145 lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
1146 }
1147 }
1148 });
1149
1150 store_tile(lse_acc_dram_window_tmp, lse_acc);
1151 }
1152
1153 // finally, O
1154 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
1155
1156 sweep_tile_span(o_spans[I0], [&](auto idx0) {
1157 constexpr auto i_idx = make_tuple(idx0);
1158 const auto tmp = [&]() {
1160 FmhaMask::IsMasking)
1161 {
1162 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
1163 }
1164 else
1165 return 1 / l[i_idx];
1166 }();
1167 sweep_tile_span(o_spans[I1], [&](auto idx1) {
1168 constexpr auto i_j_idx = make_tuple(idx0, idx1);
1169 o_acc(i_j_idx) *= tmp;
1170 });
1171 });
1172
1173 return o_acc;
1174 }
1175};
1176
1177} // namespace ck_tile
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:119
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds_direct_load()
Definition arch.hpp:288
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile dist...
Definition load_tile_transpose.hpp:403
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:16
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:27
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:64
static constexpr index_t kN0
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:45
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:24
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:22
static constexpr bool kHasDropout
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:68
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:62
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:26
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:61
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:23
static constexpr index_t kN1
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:47
static constexpr index_t kAlignmentQ
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:80
static constexpr index_t kNXdl
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:52
static constexpr index_t kAlignmentK
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:81
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:94
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:49
static constexpr index_t kNWarp
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:51
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:25
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:70
static constexpr bool kQLoadOnce
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:38
static constexpr index_t kAlignmentV
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:82
static constexpr bool kHasUnevenSplits
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:71
static constexpr index_t kSubQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:50
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:28
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &__restrict__ q_dram_block_window_tmp, const KDramBlockWindowTmp &__restrict__ k_dram_block_window_tmp, const VDramBlockWindowTmp &__restrict__ v_dram_block_window_tmp, const BiasDramBlockWindowTmp &__restrict__ bias_dram_block_window_tmp, LSEaccDramBlockWindowTmp &__restrict__ lse_acc_dram_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void *__restrict__ smem_ptrk0, void *__restrict__ smem_ptrk1, void *__restrict__ smem_ptrv0, void *__restrict__ smem_ptrv1) const
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:641
static constexpr index_t kM0
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:44
static constexpr index_t kAlignmentOacc
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:89
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:34
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:29
remove_cvref_t< Problem_ > Problem
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:20
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:37
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:30
static constexpr index_t kK1
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:48
static constexpr bool kKLoadOnce
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:40
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:67
static constexpr index_t kK0
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:46
static constexpr auto I1
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:18
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:127
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:31
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:60
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:33
remove_cvref_t< Policy_ > Policy
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:21
static constexpr const char * name
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:125
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:42
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:32
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, LSEaccDramBlockWindowTmp &lse_acc_dram_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void *smem_ptr) const
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:140
static constexpr index_t kAlignmentBias
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:91
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:36
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:69
static constexpr auto I0
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:17
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:59
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469