FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

Composable Kernel: ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

#include <fmha_fwd_splitkv_kernel.hpp>

Classes

struct  t2s
struct  t2s< float >
struct  t2s< ck_tile::fp16_t >
struct  t2s< ck_tile::bf16_t >
struct  t2s< ck_tile::fp8_t >
struct  t2s< ck_tile::bf8_t >
struct  EmptyKargs
struct  CommonKargs
struct  LogitsSoftCapKargs
struct  CommonBiasKargs
struct  BatchModeBiasKargs
struct  AlibiKargs
struct  MaskKargs
struct  Fp8StaticQuantKargs
struct  CommonPageBlockTableKargs
struct  GroupModePageBlockTableKargs
struct  CacheBatchIdxKargs
struct  BatchModeKargs
struct  GroupModeKargs
struct  BlockIndices

Public Types

using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>
using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>
using ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>

Public Member Functions

CK_TILE_DEVICE void operator() (Kargs kargs) const

Static Public Member Functions

static CK_TILE_HOST std::string GetName ()
template<bool Cond = !kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
template<bool Cond = kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
static CK_TILE_HOST constexpr auto GridSize (ck_tile::index_t batch_size, ck_tile::index_t nhead_q, ck_tile::index_t nhead_kv, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
static CK_TILE_DEVICE constexpr auto GetTileIndex (const Kargs &kargs)
static CK_TILE_HOST dim3 BlockSize ()
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()

Static Public Attributes

static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV
static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ
static constexpr bool kHasMask = FmhaMask::IsMasking

Member Typedef Documentation

◆ AttentionVariant

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>

◆ BiasDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>

◆ EpiloguePipeline

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>

◆ FmhaMask

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>

◆ FmhaPipeline

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>

◆ Kargs

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>

◆ KDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>

◆ LSEDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>

◆ OaccDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>

◆ ODataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::ODataType = remove_cvref_t<typename FmhaPipeline::ODataType>

◆ QDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>

◆ SaccDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>

◆ VDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>

◆ VLayout

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>

Member Function Documentation

◆ BlockSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST dim3 ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::BlockSize ( )
inline static

◆ GetName()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST std::string ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::GetName ( )
inline static

◆ GetSmemSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::GetSmemSize ( )
inline static constexpr

◆ GetTileIndex()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE constexpr auto ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::GetTileIndex ( const Kargs & kargs)
inline static constexpr

◆ GridSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr auto ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::GridSize ( ck_tile::index_t batch_size,
ck_tile::index_t nhead_q,
ck_tile::index_t nhead_kv,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits )
inline static constexpr

◆ MakeKargs() [1/2]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = !kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
const void * bias_ptr,
void * lse_acc_ptr,
void * o_acc_ptr,
ck_tile::index_t batch,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
const void * seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
ck_tile::index_t num_splits,
const void * block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
const void * cache_batch_idx,
float scale_s,
float scale_p,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_o_acc,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type )
inline static constexpr

◆ MakeKargs() [2/2]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
const void * bias_ptr,
void * lse_acc_ptr,
void * o_acc_ptr,
ck_tile::index_t batch,
const void * seqstart_q_ptr,
const void * seqstart_k_ptr,
const void * seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
ck_tile::index_t num_splits,
const void * block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
bool is_gappy,
float scale_s,
float scale_p,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_o_acc,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type )
inline static constexpr

◆ operator()()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE void ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::operator() ( Kargs kargs) const
inline

FIXME: Before C++20, capturing structured binding variables in a lambda is not supported. Remove the following copy capture of 'i_nhead' once the code is compiled as C++20 or later.

Member Data Documentation

◆ BiasEnum

template<typename FmhaPipeline_, typename EpiloguePipeline_>
auto ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::BiasEnum = FmhaPipeline::BiasEnum
static constexpr

◆ kBlockPerCu

template<typename FmhaPipeline_, typename EpiloguePipeline_>
ck_tile::index_t ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr

◆ kBlockPerCuInput

template<typename FmhaPipeline_, typename EpiloguePipeline_>
ck_tile::index_t ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu
static constexpr

◆ kBlockSize

template<typename FmhaPipeline_, typename EpiloguePipeline_>
ck_tile::index_t ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockSize = FmhaPipeline::kBlockSize
static constexpr

◆ kDoFp8StaticQuant

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant
static constexpr

◆ kHasLogitsSoftCap

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap
static constexpr

◆ kHasMask

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kHasMask = FmhaMask::IsMasking
static constexpr

◆ kIsGroupMode

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr

◆ kIsPagedKV

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV
static constexpr

◆ kMergeNumHeadGroupsSeqLenQ

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kMergeNumHeadGroupsSeqLenQ
static constexpr
Initial value: = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ

◆ kPadHeadDimQ

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ
static constexpr

◆ kPadHeadDimV

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr

◆ kPadSeqLenK

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenK = FmhaPipeline::kPadSeqLenK
static constexpr

◆ kPadSeqLenQ

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ
static constexpr

◆ kStoreLSE

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdSplitKVKernel< FmhaPipeline_, EpiloguePipeline_ >::kStoreLSE = FmhaPipeline::kStoreLSE
static constexpr

The documentation for this struct was generated from the following file: fmha_fwd_splitkv_kernel.hpp