fmha_fwd_v3_kernel.hpp Source File

fmha_fwd_v3_kernel.hpp Source File#

Composable Kernel: fmha_fwd_v3_kernel.hpp Source File
fmha_fwd_v3_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10#include <type_traits>
11#include <utility>
12
13namespace ck_tile {
14
15template <typename FmhaPipeline_, typename EpiloguePipeline_>
17{
// Launch parameters forwarded from the attention pipeline; kBlockPerCu must be positive.
20 static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
21 static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
22 static_assert(kBlockPerCu > 0);
23
30
// Compile-time switches forwarded from the pipeline. kIsGroupMode selects which Kargs
// struct and which MakeKargs() overload are active; kStoreLSE controls whether the LSE
// arguments are carried in Kargs and whether an LSE window is created by the kernel.
31 static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
32 static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
33 static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
34 static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
35 static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
36 static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
37
// Taken from the mask type; when true the mask arguments are carried in Kargs.
39 static constexpr bool kHasMask = FmhaMask::IsMasking;
40
41 template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
42 // arg
44 {
45 };
46
47 // kargs use aggregate initialization, so no constructor is provided
48 // use inheritance to minimize karg size
49 // users need to call the MakeKargs() function to create kargs.
78
86
93
96 std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
97 std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
98 {
103
104 // Optional cumulative sequence length pointers for batch mode
105 // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
106 const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
107 const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
108 };
109
112 std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
113 std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>
114 {
118
119 // Optional cumulative padded sequence starts (including PAD tokens)
120 // Used solely to compute memory offsets when sequences are physically padded.
121 const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1]
122 const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1]
123 };
124
125 using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
126
127 template <bool Cond = !kIsGroupMode>
128 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
129 MakeKargs(const void* q_ptr,
130 const void* k_ptr,
131 const void* v_ptr,
132 void* lse_ptr,
133 void* o_ptr,
134 ck_tile::index_t seqlen_q,
135 ck_tile::index_t seqlen_k,
136 ck_tile::index_t hdim_q,
137 ck_tile::index_t hdim_v,
138 ck_tile::index_t num_head_q,
139 ck_tile::index_t nhead_ratio_qk,
140 float scale_s,
141 ck_tile::index_t stride_q,
142 ck_tile::index_t stride_k,
143 ck_tile::index_t stride_v,
144 ck_tile::index_t stride_o,
145 ck_tile::index_t nhead_stride_q,
146 ck_tile::index_t nhead_stride_k,
147 ck_tile::index_t nhead_stride_v,
148 ck_tile::index_t nhead_stride_lse,
149 ck_tile::index_t nhead_stride_o,
150 ck_tile::index_t batch_stride_q,
151 ck_tile::index_t batch_stride_k,
152 ck_tile::index_t batch_stride_v,
153 ck_tile::index_t batch_stride_lse,
154 ck_tile::index_t batch_stride_o,
155 ck_tile::index_t window_size_left,
156 ck_tile::index_t window_size_right,
157 ck_tile::index_t mask_type,
158 ck_tile::index_t remap_opt,
159 const ck_tile::index_t* cu_seqlen_q_ptr = nullptr,
160 const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr)
161 {
162 Kargs kargs{{q_ptr,
163 k_ptr,
164 v_ptr,
165 o_ptr,
166 seqlen_q,
167 seqlen_k,
168 hdim_q,
169 hdim_v,
170 num_head_q,
171 nhead_ratio_qk,
172 static_cast<float>(scale_s * ck_tile::log2e_v<>),
173 stride_q,
174 stride_k,
175 stride_v,
176 stride_o,
177 nhead_stride_q,
178 nhead_stride_k,
179 nhead_stride_v,
180 nhead_stride_o}, // args for common karg
181 {}, // placeholder for mask
182 {}, // placeholder for lse
183 batch_stride_q,
184 batch_stride_k,
185 batch_stride_v,
186 batch_stride_o};
187
188 if constexpr(kHasMask)
189 {
190 kargs.window_size_left = window_size_left;
191 kargs.window_size_right = window_size_right;
192 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
193 kargs.remap_opt = remap_opt;
194 }
195 if constexpr(kStoreLSE)
196 {
197 kargs.lse_ptr = lse_ptr;
198 kargs.nhead_stride_lse = nhead_stride_lse;
199 kargs.batch_stride_lse = batch_stride_lse;
200 }
201
202 kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr;
203 kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr;
204 return kargs;
205 }
206
207 template <bool Cond = kIsGroupMode>
208 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
209 MakeKargs(const void* q_ptr,
210 const void* k_ptr,
211 const void* v_ptr,
212 void* lse_ptr,
213 void* o_ptr,
214 const void* seqstart_q_ptr,
215 const void* seqstart_k_ptr,
216 const void* seqlen_k_ptr,
217 ck_tile::index_t hdim_q,
218 ck_tile::index_t hdim_v,
219 ck_tile::index_t num_head_q,
220 ck_tile::index_t nhead_ratio_qk,
221 float scale_s,
222 ck_tile::index_t stride_q,
223 ck_tile::index_t stride_k,
224 ck_tile::index_t stride_v,
225 ck_tile::index_t stride_o,
226 ck_tile::index_t nhead_stride_q,
227 ck_tile::index_t nhead_stride_k,
228 ck_tile::index_t nhead_stride_v,
229 ck_tile::index_t nhead_stride_lse,
230 ck_tile::index_t nhead_stride_o,
231 ck_tile::index_t window_size_left,
232 ck_tile::index_t window_size_right,
233 ck_tile::index_t mask_type,
234 ck_tile::index_t remap_opt,
235 const void* seqstart_padded_q_ptr = nullptr,
236 const void* seqstart_padded_k_ptr = nullptr)
237 {
238 Kargs kargs{{q_ptr,
239 k_ptr,
240 v_ptr,
241 o_ptr,
242 -1, // seqlen will be updated by another pointer
243 -1, //
244 hdim_q,
245 hdim_v,
246 num_head_q,
247 nhead_ratio_qk,
248 static_cast<float>(scale_s * ck_tile::log2e_v<>),
249 stride_q,
250 stride_k,
251 stride_v,
252 stride_o,
253 nhead_stride_q,
254 nhead_stride_k,
255 nhead_stride_v,
256 nhead_stride_o}, // args for common karg
257 {}, // placeholder for mask
258 {}, // placeholder for lse
259 reinterpret_cast<const int32_t*>(seqstart_q_ptr),
260 reinterpret_cast<const int32_t*>(seqstart_k_ptr),
261 reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
262
263 if constexpr(kHasMask)
264 {
265 kargs.window_size_left = window_size_left;
266 kargs.window_size_right = window_size_right;
267 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
268 kargs.remap_opt = remap_opt;
269 }
270 if constexpr(kStoreLSE)
271 {
272 kargs.lse_ptr = lse_ptr;
273 kargs.nhead_stride_lse = nhead_stride_lse;
274 }
275
276 kargs.seqstart_padded_q_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_q_ptr);
277 kargs.seqstart_padded_k_ptr = reinterpret_cast<const int32_t*>(seqstart_padded_k_ptr);
278 return kargs;
279 }
280
281 CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
282 ck_tile::index_t nhead_,
283 ck_tile::index_t seqlen_q_,
284 ck_tile::index_t hdim_v_)
285 {
286 // TODO: this may need tuning
287 if constexpr(kHasMask)
288 {
289 return dim3(nhead_,
290 ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
291 ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
292 batch_size_);
293 }
294 else
295 {
296 return dim3(nhead_,
297 ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
298 ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
299 batch_size_);
300 }
301 }
302
303 CK_TILE_DEVICE static constexpr auto
304 RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
305 {
306 if(remap_option < 1)
307 {
308 return make_tuple(static_cast<int32_t>(gridDim.x - tg_idx - 1), tg_idy);
309 }
310
311 int32_t remapped_tg_idx = tg_idx;
312 int32_t remapped_tg_idy = tg_idy;
313
314 if(remap_option == 2)
315 { // special remapping
316 int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx;
317 int32_t tmp1 = tmp0 & 0x7;
318
319 remapped_tg_idx = tmp0 >> 3;
320 remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1;
321 }
322 else
323 { // normal remapping
324 int32_t cus_per_xdim_per_xcc = gridDim.x >> 3;
325 int32_t tgs_cu_id = remapped_tg_idx >> 3;
326
327 if(tgs_cu_id < cus_per_xdim_per_xcc)
328 {
329 int32_t tgs_xcc_id = remapped_tg_idx & 0x7;
330 int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id;
331
332 remapped_tg_idx = new_tg_idx;
333 }
334 }
335
336 return make_tuple(remapped_tg_idx, remapped_tg_idy);
337 }
338
339 CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&)
340 {
341 using namespace ck_tile;
342
343 // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v,
344 // FmhaPipeline::kN1);
345
346 // assume that num_tile_n1 is always 1
347 if constexpr(kHasMask)
348 {
349 const index_t i_nhead = blockIdx.x;
350 const index_t i_block = blockIdx.y;
351 const index_t i_batch = blockIdx.z;
352
353 return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch);
354 }
355 else
356 {
357 const index_t i_nhead = blockIdx.x;
358 const index_t i_block = blockIdx.y;
359 const index_t i_batch = blockIdx.z;
360
361 return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch);
362 }
363 }
364
365 CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
366
368 {
369 return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
370 }
371
373 {
374 using namespace ck_tile;
375
376 // allocate LDS
377 __shared__ char smem_ptr[GetSmemSize()];
378
379 // divide problem
380 const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
381
382 const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
383 const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
384
385 long_index_t batch_offset_q = 0;
386 long_index_t batch_offset_k = 0;
387 long_index_t batch_offset_v = 0;
388 long_index_t batch_offset_lse = 0;
389 long_index_t batch_offset_o = 0;
390
391 if constexpr(kIsGroupMode)
392 {
393 // get starting offset for each batch
394 const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch];
395 const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch];
396
397 const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr
398 ? kargs.seqstart_padded_q_ptr[i_batch]
399 : query_start_unpadded;
400 const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr
401 ? kargs.seqstart_padded_k_ptr[i_batch]
402 : key_start_unpadded;
403
404 batch_offset_q = query_start_padded * kargs.stride_q;
405 batch_offset_k = key_start_padded * kargs.stride_k;
406 batch_offset_v = key_start_padded * kargs.stride_v;
407
408 if constexpr(kStoreLSE)
409 {
410 // LSE layout is [nhead, total_seqlen], index by unpadded start
411 batch_offset_lse = query_start_unpadded;
412 }
413 batch_offset_o = query_start_padded * kargs.stride_o;
414
415 // get real # queries & # keys under group mode
416 const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
417 kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
418
419 // # of required blocks differs between groups, so terminate unnecessary blocks
420 // early
421 if(kargs.seqlen_q <= i_m0)
422 {
423 return;
424 }
425
426 if(kargs.seqlen_k_ptr != nullptr)
427 {
428 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
429 }
430 else
431 {
432 const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
433 kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
434 }
435 }
436 else
437 {
438 batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
439 batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
440 batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
441 if constexpr(kStoreLSE)
442 {
443 batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
444 }
445 batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
446
447 // If cumulative seqlen pointers are provided, override per-batch effective lengths
448 if(kargs.cu_seqlen_q_ptr != nullptr)
449 {
450 kargs.seqlen_q =
451 kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
452 }
453 if(kargs.cu_seqlen_kv_ptr != nullptr)
454 {
455 kargs.seqlen_k =
456 kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch];
457 }
458 }
459
460 // for simplicity, apply the batch offset by modifying the pointers directly
461 const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
462 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
463 batch_offset_q;
464 const KDataType* k_ptr =
465 reinterpret_cast<const KDataType*>(kargs.k_ptr) +
466 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
467 batch_offset_k;
468 const VDataType* v_ptr =
469 reinterpret_cast<const VDataType*>(kargs.v_ptr) +
470 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
471 batch_offset_v;
472 ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
473 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
474 batch_offset_o;
475
476 // Q/K/V DRAM and DRAM window
477 const auto q_dram = [&]() {
479 q_ptr,
480 make_tuple(kargs.seqlen_q, kargs.hdim_q),
481 make_tuple(kargs.stride_q, 1),
483 number<1>{});
484
485 return pad_tensor_view(
486 q_dram_naive,
489 }();
490 const auto k_dram = [&]() {
492 k_ptr,
493 make_tuple(kargs.seqlen_k, kargs.hdim_q),
494 make_tuple(kargs.stride_k, 1),
496 number<1>{});
497
498 return pad_tensor_view(
499 k_dram_naive,
502 }();
503 const auto v_dram = [&]() {
505 v_ptr,
506 make_tuple(kargs.seqlen_k, kargs.hdim_v),
507 make_tuple(kargs.stride_v, 1),
509 number<1>{});
510
511 return pad_tensor_view(
512 v_dram_naive,
515 }();
516
517 auto q_dram_window = make_tile_window(
518 q_dram,
520 {i_m0, 0});
521
522 auto k_dram_window = make_tile_window(
524
525 auto v_dram_window =
526 make_tile_window(v_dram,
528 {0, i_n1});
529
530 // lse
531 auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
532 constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
533 if constexpr(kStoreLSE)
534 {
535 LSEDataType* lse_ptr =
536 reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
537 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
538
539 const auto lse_dram = [&]() {
540 const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
541 lse_ptr,
542 make_tuple(kargs.seqlen_q),
543 make_tuple(1),
544 number<1>{},
545 number<1>{});
546
547 return pad_tensor_view(
548 lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
549 }();
550
551 return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
552 }
553 else
554 {
555 return make_null_tile_window(lse_dram_window_lengths);
556 }
557 }();
558
559 FmhaMask mask = [&]() {
560 if constexpr(kHasMask)
562 kargs.window_size_left,
563 kargs.window_size_right,
564 kargs.seqlen_q,
565 kargs.seqlen_k,
567 else
568 return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
569 }();
570
571 auto o_acc_tile = [&]() {
572 return FmhaPipeline{}(q_dram_window,
573 k_dram_window,
574 v_dram_window,
575 lse_dram_window,
576 mask,
577 kargs.scale_s,
578 smem_ptr);
579 }();
580
581 // O DRAM and O DRAM window
582 auto o_dram = [&]() {
584 o_ptr,
585 make_tuple(kargs.seqlen_q, kargs.hdim_v),
586 make_tuple(kargs.stride_o, 1),
588 number<1>{});
589
590 return pad_tensor_view(
591 o_dram_naive,
594 }();
595
596 auto o_dram_window =
597 make_tile_window(o_dram,
599 {i_m0, i_n1});
600
601 EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
602 }
603};
604} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition block_masking.hpp:632
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
GenericAttentionMaskEnum
Definition block_masking.hpp:11
@ MASK_FROM_TOP_LEFT
Definition block_masking.hpp:15
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition fmha_fwd_v3_kernel.hpp:98
const ck_tile::index_t * cu_seqlen_kv_ptr
Definition fmha_fwd_v3_kernel.hpp:107
ck_tile::index_t batch_stride_o
Definition fmha_fwd_v3_kernel.hpp:102
ck_tile::index_t batch_stride_q
Definition fmha_fwd_v3_kernel.hpp:99
ck_tile::index_t batch_stride_k
Definition fmha_fwd_v3_kernel.hpp:100
ck_tile::index_t batch_stride_v
Definition fmha_fwd_v3_kernel.hpp:101
const ck_tile::index_t * cu_seqlen_q_ptr
Definition fmha_fwd_v3_kernel.hpp:106
Definition fmha_fwd_v3_kernel.hpp:51
ck_tile::index_t nhead_stride_v
Definition fmha_fwd_v3_kernel.hpp:75
ck_tile::index_t nhead_stride_q
Definition fmha_fwd_v3_kernel.hpp:73
ck_tile::index_t nhead_ratio_qk
Definition fmha_fwd_v3_kernel.hpp:65
ck_tile::index_t stride_o
Definition fmha_fwd_v3_kernel.hpp:71
ck_tile::index_t seqlen_k
Definition fmha_fwd_v3_kernel.hpp:58
ck_tile::index_t nhead_stride_o
Definition fmha_fwd_v3_kernel.hpp:76
ck_tile::index_t hdim_q
Definition fmha_fwd_v3_kernel.hpp:59
ck_tile::index_t seqlen_q
Definition fmha_fwd_v3_kernel.hpp:57
ck_tile::index_t stride_v
Definition fmha_fwd_v3_kernel.hpp:70
const void * q_ptr
Definition fmha_fwd_v3_kernel.hpp:52
float scale_s
Definition fmha_fwd_v3_kernel.hpp:66
const void * v_ptr
Definition fmha_fwd_v3_kernel.hpp:54
void * o_ptr
Definition fmha_fwd_v3_kernel.hpp:55
ck_tile::index_t stride_q
Definition fmha_fwd_v3_kernel.hpp:68
ck_tile::index_t num_head_q
Definition fmha_fwd_v3_kernel.hpp:62
const void * k_ptr
Definition fmha_fwd_v3_kernel.hpp:53
ck_tile::index_t hdim_v
Definition fmha_fwd_v3_kernel.hpp:60
ck_tile::index_t nhead_stride_k
Definition fmha_fwd_v3_kernel.hpp:74
ck_tile::index_t stride_k
Definition fmha_fwd_v3_kernel.hpp:69
Definition fmha_fwd_v3_kernel.hpp:88
void * lse_ptr
Definition fmha_fwd_v3_kernel.hpp:89
ck_tile::index_t nhead_stride_lse
Definition fmha_fwd_v3_kernel.hpp:90
ck_tile::index_t batch_stride_lse
Definition fmha_fwd_v3_kernel.hpp:91
Definition fmha_fwd_v3_kernel.hpp:44
Definition fmha_fwd_v3_kernel.hpp:114
const int32_t * seqstart_padded_q_ptr
Definition fmha_fwd_v3_kernel.hpp:121
const int32_t * seqlen_k_ptr
Definition fmha_fwd_v3_kernel.hpp:117
const int32_t * seqstart_padded_k_ptr
Definition fmha_fwd_v3_kernel.hpp:122
const int32_t * seqstart_k_ptr
Definition fmha_fwd_v3_kernel.hpp:116
const int32_t * seqstart_q_ptr
Definition fmha_fwd_v3_kernel.hpp:115
Definition fmha_fwd_v3_kernel.hpp:80
ck_tile::index_t window_size_left
Definition fmha_fwd_v3_kernel.hpp:82
ck_tile::index_t remap_opt
Definition fmha_fwd_v3_kernel.hpp:84
ck_tile::index_t window_size_right
Definition fmha_fwd_v3_kernel.hpp:82
ck_tile::GenericAttentionMaskEnum mask_type
Definition fmha_fwd_v3_kernel.hpp:83
Definition fmha_fwd_v3_kernel.hpp:17
static constexpr bool kPadSeqLenQ
Definition fmha_fwd_v3_kernel.hpp:32
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition fmha_fwd_v3_kernel.hpp:24
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_fwd_v3_kernel.hpp:372
static constexpr bool kPadSeqLenK
Definition fmha_fwd_v3_kernel.hpp:33
static CK_TILE_DEVICE constexpr auto GetTileIndex(const Kargs &)
Definition fmha_fwd_v3_kernel.hpp:339
static constexpr ck_tile::index_t kBlockSize
Definition fmha_fwd_v3_kernel.hpp:20
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition fmha_fwd_v3_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition fmha_fwd_v3_kernel.hpp:25
static CK_TILE_DEVICE constexpr auto RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option)
Definition fmha_fwd_v3_kernel.hpp:304
static constexpr bool kPadHeadDimV
Definition fmha_fwd_v3_kernel.hpp:35
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_fwd_v3_kernel.hpp:18
static constexpr bool kHasMask
Definition fmha_fwd_v3_kernel.hpp:39
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition fmha_fwd_v3_kernel.hpp:125
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const void *seqstart_padded_q_ptr=nullptr, const void *seqstart_padded_k_ptr=nullptr)
Definition fmha_fwd_v3_kernel.hpp:209
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t remap_opt, const ck_tile::index_t *cu_seqlen_q_ptr=nullptr, const ck_tile::index_t *cu_seqlen_kv_ptr=nullptr)
Definition fmha_fwd_v3_kernel.hpp:129
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition fmha_fwd_v3_kernel.hpp:19
static constexpr bool kPadHeadDimQ
Definition fmha_fwd_v3_kernel.hpp:34
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition fmha_fwd_v3_kernel.hpp:26
static CK_TILE_HOST constexpr auto BlockSize()
Definition fmha_fwd_v3_kernel.hpp:365
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_fwd_v3_kernel.hpp:21
static constexpr bool kStoreLSE
Definition fmha_fwd_v3_kernel.hpp:36
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition fmha_fwd_v3_kernel.hpp:29
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_fwd_v3_kernel.hpp:367
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
Definition fmha_fwd_v3_kernel.hpp:281
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition fmha_fwd_v3_kernel.hpp:27
static constexpr bool kIsGroupMode
Definition fmha_fwd_v3_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition fmha_fwd_v3_kernel.hpp:38
Definition tile/core/container/sequence.hpp:49