13template <
typename GridwiseSparseEmbedding,
16 typename GammaDataType,
17 typename BetaDataType,
21 typename EmbElementwiseOperation,
23#if CK_USE_LAUNCH_BOUNDS
28 const ck::Array<EmbType*, NumEmbeddings> p_embs,
29 const ck::Array<IndexType*, NumEmbeddings> p_indexes,
30 const GammaDataType* p_gamma,
31 const BetaDataType* p_beta,
32 const OutGridDesc out_grid_desc,
33 const AccDataType epsilon,
34 const EmbElementwiseOperation emb_elementwise_op)
36 GridwiseSparseEmbedding::Run(
37 p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op);
40template <
typename EmbType,
42 typename GammaDataType,
43 typename BetaDataType,
47 typename EmbElementwiseOperation,
64 static_assert(BlockSize == RowClusterSize * DimClusterSize,
65 "Invalid cluster distribution within block");
66 static_assert(RowClusterSize %
WaveSize == 0,
"need to be wavewise");
68 static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0,
"");
69 static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0,
"");
71 static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize);
72 static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize);
92 __device__
static void Run(OutType* p_out,
95 const GammaDataType* p_gamma,
96 const BetaDataType* p_beta,
98 const AccDataType epsilon,
99 const EmbElementwiseOperation emb_elementwise_op)
104 constexpr auto thread_cluster_desc =
107 const auto thread_cluster_idx =
108 thread_cluster_desc.CalculateBottomIndex(
make_multi_index(thread_local_id));
110 const auto thread_dim_cluster_id = thread_cluster_idx[
I0];
111 const auto thread_row_cluster_id = thread_cluster_idx[
I1];
113 const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id /
WaveSize);
115 const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize;
118 threadwise_welford.max_count_ =
RowSubBlocks * RowVectorSize;
120 constexpr auto thread_buf_size =
124 constexpr auto mean_var_buf_size =
DimSubBlocks * DimThreadSize;
125 constexpr auto mean_var_buf_desc =
127 constexpr auto gamma_beta_buf_size =
RowSubBlocks * RowVectorSize;
128 constexpr auto gamma_beta_buf_desc =
147 auto load_current_sub_row = [&](
auto i_dim_sub_,
auto i_row_sub_) {
149 auto emb_a = emb_vectors[0];
150 using src_vector_t =
typename decltype(emb_a)::type;
152 constexpr auto current_dim = i_dim_sub_ *
DimPerSubBlock + i_dim_vec_;
154 auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
155 sizeof(EmbType) * RowVectorSize;
159 __amdgpu_buffer_rsrc_t emb_res =
161 index * RowPerBlock);
162 emb_vectors(i_embedding_).template AsType<src_vector_t>()(
I0) =
167 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
168 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
172 emb_vectors[i_embedding_].
template AsType<EmbType>()[i_row_vec_]);
178 auto accumulate_current_sub_row = [&](
auto i_dim_sub_,
auto i_row_sub_) {
181 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
182 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
184 [&](
auto i_embedding_) ->
const auto& {
191 unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
196 auto threadwise_welford_sub_row = [&](
auto i_dim_sub_,
auto i_row_sub_) {
199 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
200 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
201 constexpr auto mean_var_offset =
202 mean_var_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_));
204 threadwise_welford.cur_count_++;
212 auto threadwise_normalize_store_out = [&](
auto i_dim_sub_,
auto i_row_sub_) {
213 __amdgpu_buffer_rsrc_t out_res =
217 using dst_vector_t =
typename decltype(out_vector)::type;
219 constexpr auto mean_var_offset =
220 mean_var_buf_desc.CalculateOffset(
make_tuple(i_dim_sub_, i_dim_vec_));
224 constexpr auto register_offset = thread_buf_desc.CalculateOffset(
225 make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
226 constexpr auto gamma_beta_offset =
227 gamma_beta_buf_desc.CalculateOffset(
make_tuple(i_row_sub_, i_row_vec_));
238 index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
239 sizeof(OutType) * RowVectorSize;
242 out_vector.template AsType<dst_vector_t>()[
Number<0>{}],
253 index_bufs(i_embedding_)(i_idx_) =
254 p_indexes[i_embedding_][index_start + i_idx_.value];
263 index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
264 sizeof(GammaDataType) * RowVectorSize;
265 index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
266 sizeof(BetaDataType) * RowVectorSize;
268 __amdgpu_buffer_rsrc_t gamma_res =
270 __amdgpu_buffer_rsrc_t beta_res =
273 gamma_vector.template AsType<typename decltype(gamma_vector)::type>()(
I0) =
275 gamma_res, thread_offset_gamma, 0);
276 beta_vector.template AsType<typename decltype(beta_vector)::type>()(
I0) =
280 constexpr auto offset =
281 gamma_beta_buf_desc.CalculateOffset(
make_tuple(i_row_sub_, i_row_vec_));
298 load_current_sub_row(i_dim_sub,
Number<0>{});
300 load_current_sub_row(i_dim_sub,
Number<1>{} + i_row);
301 accumulate_current_sub_row(i_dim_sub, i_row);
302 threadwise_welford_sub_row(i_dim_sub, i_row);
312 mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
317 [&](
auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); });
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
__device__ void amd_buffer_store_impl(const typename vector_type< T, N >::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:544
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_sparse_embeddings_forward_layernorm(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > p_embs, const ck::Array< IndexType *, NumEmbeddings > p_indexes, const GammaDataType *p_gamma, const BetaDataType *p_beta, const OutGridDesc out_grid_desc, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:26
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ vector_type< T, N >::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset)
Definition utility/amd_buffer_addressing.hpp:419
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_with_default_range_new(T *p_wave)
Definition utility/amd_buffer_addressing_builtins.hpp:66
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
Definition utility/array.hpp:14
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:57
static constexpr auto I0
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:58
static constexpr auto RowPerSubBlock
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:76
static constexpr auto RowSubBlocks
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:72
static constexpr auto DimSubBlocks
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:71
ck::GridwiseSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, decltype(MakeOutputDescriptor(1, 1)), EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings >::BlockwiseWelford BlockwiseWelford< AccDataType, BlockSize, ThreadClusterLength, Sequence< 0, 1 > > BlockwiseWelford
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:89
static constexpr auto DimPerSubBlock
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:75
ck::GridwiseSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, decltype(MakeOutputDescriptor(1, 1)), EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings >::ThreadwiseWolfordDesc2D decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< DimSubBlocks *DimThreadSize >{}, Number< RowSubBlocks *RowVectorSize >{}))) ThreadwiseWolfordDesc2D
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:78
static constexpr auto I1
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:59
ck::GridwiseSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, decltype(MakeOutputDescriptor(1, 1)), EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings >::ThreadwiseWolfordDescReduce decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< DimSubBlocks *DimThreadSize >{}))) ThreadwiseWolfordDescReduce
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:81
static constexpr auto I2
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:60
static __device__ void Run(OutType *p_out, const ck::Array< EmbType *, NumEmbeddings > p_embs, const ck::Array< IndexType *, NumEmbeddings > p_indexes, const GammaDataType *p_gamma, const BetaDataType *p_beta, const OutGridDesc, const AccDataType epsilon, const EmbElementwiseOperation emb_elementwise_op)
Definition gridwise_sparse_embeddings_forward_layernorm_builtins.hpp:92
ck::GridwiseSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, decltype(MakeOutputDescriptor(1, 1)), EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings >::ThreadClusterLength Sequence< DimClusterSize, RowClusterSize > ThreadClusterLength
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:87
static constexpr index_t WaveSize
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:62
ck::GridwiseSparseEmbeddingsForwardLayernorm< EmbType, IndexType, GammaDataType, BetaDataType, AccDataType, OutType, decltype(MakeOutputDescriptor(1, 1)), EmbElementwiseOperation, BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize, NumEmbeddings >::ThreadwiseWelford ThreadwiseWelford< AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce > ThreadwiseWelford
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:84
static constexpr auto I3
Definition gridwise_sparse_embeddings_forward_layernorm.hpp:61
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition functional2.hpp:33