24template <
typename GridwiseGemm,
27 typename ReducePtrsGlobal,
28 typename AElementwiseOperation,
29 typename BElementwiseOperation,
30 typename CElementwiseOperation,
31 typename ReduceInElementwiseOperations,
32 typename ReduceAccElementwiseOperations,
33 typename AGridDesc_AK0_M_AK1,
34 typename BGridDesc_BK0_N_BK1,
35 typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
36 typename ReduceGridDescriptor_MBlock_MPerBlock,
37 typename ComputeBasePrtOfBatch,
38 typename Block2CTileMap,
39 bool HasMainK0BlockLoop>
41#if CK_USE_LAUNCH_BOUNDS
45 const FloatAB* __restrict__ p_a_grid,
46 const FloatAB* __restrict__ p_b_grid,
47 FloatC* __restrict__ p_c_grid,
48 ReducePtrsGlobal p_reduces_grid,
50 const AElementwiseOperation a_element_op,
51 const BElementwiseOperation b_element_op,
52 const CElementwiseOperation c_element_op,
53 const ReduceInElementwiseOperations reduce_in_element_ops,
54 const ReduceAccElementwiseOperations reduce_out_element_ops,
55 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
56 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
57 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
58 c_grid_desc_mblock_mperblock_nblock_nperblock,
59 const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
60 const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
61 const Block2CTileMap block_2_ctile_map)
63#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
64 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
66 const index_t num_blocks_per_batch =
67 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
69 __builtin_amdgcn_readfirstlane(
get_block_1d_id() / num_blocks_per_batch);
71 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
72 static_cast<long_index_t>(compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
73 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
74 static_cast<long_index_t>(compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
75 const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
76 static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
78 static_for<0, p_reduces_grid.Size(), 1>{}([&](
auto In) {
79 const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
80 static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
81 p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset;
84 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
86 GridwiseGemm::template Run<HasMainK0BlockLoop>(
87 p_a_grid + a_batch_offset,
88 p_b_grid + b_batch_offset,
89 p_c_grid + c_batch_offset,
95 reduce_in_element_ops,
96 reduce_out_element_ops,
97 a_grid_desc_ak0_m_ak1,
98 b_grid_desc_bk0_n_bk1,
99 c_grid_desc_mblock_mperblock_nblock_nperblock,
100 reduce_grid_desc_mblock_mperblock,
112 ignore = reduce_in_element_ops;
113 ignore = reduce_out_element_ops;
114 ignore = a_grid_desc_ak0_m_ak1;
115 ignore = b_grid_desc_bk0_n_bk1;
116 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
117 ignore = reduce_grid_desc_mblock_mperblock;
118 ignore = compute_base_ptr_of_batch_;
119 ignore = block_2_ctile_map;
126template <
typename ALayout,
132 typename GemmAccDataType,
133 typename CShuffleDataType,
134 typename ReduceAccDataType,
135 typename ReducePtrsGlobal,
136 typename AElementwiseOperation,
137 typename BElementwiseOperation,
138 typename CElementwiseOperation,
139 typename ReduceOperations,
140 typename ReduceInElementwiseOperations,
141 typename ReduceAccElementwiseOperations,
142 typename ReduceGlobalMemoryDataOperation,
155 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
156 typename ABlockTransferThreadClusterArrangeOrder,
157 typename ABlockTransferSrcAccessOrder,
158 index_t ABlockTransferSrcVectorDim,
159 index_t ABlockTransferSrcScalarPerVector,
160 index_t ABlockTransferDstScalarPerVector_AK1,
161 bool ABlockLdsExtraM,
162 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
163 typename BBlockTransferThreadClusterArrangeOrder,
164 typename BBlockTransferSrcAccessOrder,
165 index_t BBlockTransferSrcVectorDim,
166 index_t BBlockTransferSrcScalarPerVector,
167 index_t BBlockTransferDstScalarPerVector_BK1,
168 bool BBlockLdsExtraN,
169 index_t CShuffleMXdlPerWavePerShuffle,
170 index_t CShuffleNXdlPerWavePerShuffle,
171 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
172 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
173 typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
174 index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
175 index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
190 const auto a_grid_desc_mraw_kraw = [&]() {
206 const auto MPad = M - MRaw;
207 const auto KPad = K - KRaw;
213 assert(K % AK1 == 0);
215 const auto AK0 = K / AK1;
217 const auto a_grid_desc_m_k =
224 const auto a_grid_desc_ak0_m_ak1 =
231 return a_grid_desc_ak0_m_ak1;
237 assert(KRaw % AK1 == 0);
239 const auto AK0 = KRaw / AK1;
241 const auto a_grid_desc_ak0_m_ak1 =
248 return a_grid_desc_ak0_m_ak1;
254 assert(K % AK1 == 0);
256 const auto AK0 = K / AK1;
259 a_grid_desc_mraw_kraw,
264 const auto a_grid_desc_ak0_m_ak1 =
271 return a_grid_desc_ak0_m_ak1;
276 assert(KRaw % AK1 == 0);
278 const auto AK0 = KRaw / AK1;
280 const auto a_grid_desc_ak0_m_ak1 =
287 return a_grid_desc_ak0_m_ak1;
293 const auto b_grid_desc_nraw_kraw = [&]() {
309 const auto NPad = N - NRaw;
310 const auto KPad = K - KRaw;
316 assert(K % BK1 == 0);
318 const auto BK0 = K / BK1;
320 const auto b_grid_desc_n_k =
327 const auto b_grid_desc_bk0_n_bk1 =
334 return b_grid_desc_bk0_n_bk1;
340 assert(KRaw % BK1 == 0);
342 const auto BK0 = KRaw / BK1;
344 const auto b_grid_desc_bk0_n_bk1 =
351 return b_grid_desc_bk0_n_bk1;
357 assert(K % BK1 == 0);
359 const auto BK0 = K / BK1;
362 b_grid_desc_nraw_kraw,
367 const auto b_grid_desc_bk0_n_bk1 =
374 return b_grid_desc_bk0_n_bk1;
379 assert(KRaw % BK1 == 0);
381 const auto BK0 = KRaw / BK1;
383 const auto b_grid_desc_bk0_n_bk1 =
390 return b_grid_desc_bk0_n_bk1;
396 const auto c_grid_desc_mraw_nraw = [&]() {
412 const auto MPad = M - MRaw;
413 const auto NPad = N - NRaw;
430 c_grid_desc_mraw_nraw,
440 c_grid_desc_mraw_nraw,
448 return c_grid_desc_mraw_nraw;
458 const auto MPad = M - MRaw;
474 return d_grid_desc_mraw;
489 : BatchStrideA_(BatchStrideA),
490 BatchStrideB_(BatchStrideB),
491 BatchStrideC_(BatchStrideC),
492 BatchStrideD_(BatchStrideD)
498 return g_idx *
static_cast<long_index_t>(BatchStrideA_);
503 return g_idx *
static_cast<long_index_t>(BatchStrideB_);
508 return g_idx *
static_cast<long_index_t>(BatchStrideC_);
517 return g_idx *
static_cast<long_index_t>(BatchStrideD_);
528 template <index_t NXdlPerWave_>
536 AElementwiseOperation,
537 BElementwiseOperation,
538 CElementwiseOperation,
540 ReduceInElementwiseOperations,
541 ReduceAccElementwiseOperations,
543 ReduceGlobalMemoryDataOperation,
548 NumGemmKPrefetchStage,
559 ABlockTransferThreadClusterLengths_AK0_M_AK1,
560 ABlockTransferThreadClusterArrangeOrder,
561 ABlockTransferSrcAccessOrder,
562 ABlockTransferSrcVectorDim,
563 ABlockTransferSrcScalarPerVector,
564 ABlockTransferDstScalarPerVector_AK1,
567 BBlockTransferThreadClusterLengths_BK0_N_BK1,
568 BBlockTransferThreadClusterArrangeOrder,
569 BBlockTransferSrcAccessOrder,
570 BBlockTransferSrcVectorDim,
571 BBlockTransferSrcScalarPerVector,
572 BBlockTransferDstScalarPerVector_BK1,
575 CShuffleMXdlPerWavePerShuffle,
576 CShuffleNXdlPerWavePerShuffle,
577 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
578 CShuffleBlockTransferScalarPerVector_NPerBlock,
579 CReduceThreadClusterLengths_MPerBlock_NPerBlock,
580 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
581 CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
590 const BDataType* p_b_grid,
592 ReducePtrsGlobal p_reduces_grid,
599 AElementwiseOperation a_element_op,
600 BElementwiseOperation b_element_op,
601 CElementwiseOperation c_element_op,
602 ReduceInElementwiseOperations reduce_in_element_ops,
603 ReduceAccElementwiseOperations reduce_out_element_ops,
652 template <
typename Gr
idwiseGemm>
660 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
663 auto c_grid_desc_mblock_mperblock_nblock_nperblock =
664 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
667 auto reduce_grid_desc_mblock_mperblock =
673 std::cout <<
"arg.Batch_ = " << arg.
Batch_ << std::endl;
675 std::cout <<
"arg.a_grid_desc_ak0_m_ak1_{"
680 std::cout <<
"arg.b_grid_desc_bk0_n_bk1_{"
688 std::cout <<
"arg.reduce_grid_desc_m_{ "
698 float elapsed_time = 0.0f;
699 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
706 AElementwiseOperation,
707 BElementwiseOperation,
708 CElementwiseOperation,
709 ReduceInElementwiseOperations,
710 ReduceAccElementwiseOperations,
713 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
714 typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
715 ComputeBasePtrOfStridedBatch,
716 typename GridwiseGemm::DefaultBlock2CTileMap,
736 c_grid_desc_mblock_mperblock_nblock_nperblock,
737 reduce_grid_desc_mblock_mperblock,
748 AElementwiseOperation,
749 BElementwiseOperation,
750 CElementwiseOperation,
751 ReduceInElementwiseOperations,
752 ReduceAccElementwiseOperations,
755 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
756 typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
757 ComputeBasePtrOfStridedBatch,
758 typename GridwiseGemm::DefaultBlock2CTileMap,
778 c_grid_desc_mblock_mperblock_nblock_nperblock,
779 reduce_grid_desc_mblock_mperblock,
793 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
835 auto casted_p_arg =
dynamic_cast<const Argument*
>(p_arg);
836 if(casted_p_arg ==
nullptr)
846 static constexpr int NumReduce = ReduceOperations::Size();
850 std::array<const void*, 0> p_ds,
852 std::array<void*, NumReduce> p_reduces,
859 std::array<ck::index_t, 0> StrideDs,
860 std::array<void*, 3> gemm_element_ops,
861 std::array<void*, 0> d_element_ops,
862 std::array<void*, NumReduce> reduce_in_element_op,
863 std::array<void*, NumReduce> reduce_out_element_op,
873 auto tmp = ReducePtrsGlobal{}[I];
875 return static_cast<T*
>(p_reduces[I]);
879 ReduceInElementwiseOperations reduce_in_element_ops =
generate_tuple(
881 auto tmp = ReduceInElementwiseOperations{}[I];
883 return *(
static_cast<T*
>(reduce_in_element_op[I]));
886 ReduceAccElementwiseOperations reduce_out_element_ops =
generate_tuple(
888 auto tmp = ReduceAccElementwiseOperations{}[I];
890 return *(
static_cast<T*
>(reduce_out_element_op[I]));
894 AElementwiseOperation a_element_op =
895 *(
static_cast<AElementwiseOperation*
>(gemm_element_ops[0]));
896 BElementwiseOperation b_element_op =
897 *(
static_cast<BElementwiseOperation*
>(gemm_element_ops[1]));
898 CElementwiseOperation c_element_op =
899 *(
static_cast<CElementwiseOperation*
>(gemm_element_ops[2]));
901 return Argument{
static_cast<const ADataType*
>(p_a),
902 static_cast<const BDataType*
>(p_b),
903 static_cast<CDataType*
>(p_c),
914 reduce_in_element_ops,
915 reduce_out_element_ops,
922 std::unique_ptr<BaseArgument>
926 std::array<const void*, 0> p_ds,
928 std::array<void*, NumReduce> p_reduces,
935 std::array<ck::index_t, 0> StrideDs,
936 std::array<void*, 3> gemm_element_ops,
937 std::array<void*, 0> d_element_ops,
938 std::array<void*, NumReduce> reduce_in_element_op,
939 std::array<void*, NumReduce> reduce_out_element_op,
949 auto tmp = ReducePtrsGlobal{}[I];
951 return static_cast<T*
>(p_reduces[I]);
955 ReduceInElementwiseOperations reduce_in_element_ops =
generate_tuple(
957 auto tmp = ReduceInElementwiseOperations{}[I];
959 return *(
static_cast<T*
>(reduce_in_element_op[I]));
962 ReduceAccElementwiseOperations reduce_out_element_ops =
generate_tuple(
964 auto tmp = ReduceAccElementwiseOperations{}[I];
966 return *(
static_cast<T*
>(reduce_out_element_op[I]));
970 AElementwiseOperation a_element_op =
971 *(
static_cast<AElementwiseOperation*
>(gemm_element_ops[0]));
972 BElementwiseOperation b_element_op =
973 *(
static_cast<BElementwiseOperation*
>(gemm_element_ops[1]));
974 CElementwiseOperation c_element_op =
975 *(
static_cast<CElementwiseOperation*
>(gemm_element_ops[2]));
977 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
978 static_cast<const BDataType*
>(p_b),
979 static_cast<CDataType*
>(p_c),
990 reduce_in_element_ops,
991 reduce_out_element_ops,
998 return std::make_unique<Invoker>(
Invoker{});
1004 auto str = std::stringstream();
1007 str <<
"DeviceBatchedGemmReduce_Xdl_CShuffle"
1009 << BlockSize <<
", "
1010 << MPerBlock <<
", "
1011 << NPerBlock <<
", "
1012 << KPerBlock <<
", "
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
__global__ void kernel_batched_gemm_reduce_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, ReducePtrsGlobal p_reduces_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const ReduceInElementwiseOperations reduce_in_element_ops, const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:44
Definition convolution_backward_data_specialization.hpp:7
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:152
ck::GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::DefaultBlock2CTileMap remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:347
ck::GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:251
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:588
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:638
ReduceGridDesc_M reduce_grid_desc_m_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:637
BElementwiseOperation b_element_op_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:641
ReducePtrsGlobal p_reduces_grid_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:632
index_t Batch_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:633
CGridDesc_M_N c_grid_desc_m_n_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:636
CDataType * p_c_grid_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:631
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:639
CElementwiseOperation c_element_op_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:642
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:635
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, ReducePtrsGlobal p_reduces_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ReduceInElementwiseOperations reduce_in_element_ops, ReduceAccElementwiseOperations reduce_out_element_ops, index_t Batch)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:589
const ADataType * p_a_grid_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:629
ReduceAccElementwiseOperations reduce_out_element_ops_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:644
ReduceInElementwiseOperations reduce_in_element_ops_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:643
const BDataType * p_b_grid_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:630
AElementwiseOperation a_element_op_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:640
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:634
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:484
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:501
ComputeBasePtrOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideD)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:485
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:496
__host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx, Number< I > reduction_idx) const
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:512
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:506
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:649
DeviceOp::Argument Argument
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:650
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:790
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:653
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:178
static auto MakeInvoker()
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:919
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:291
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:797
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:188
static constexpr auto I0
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:184
std::string GetTypeString() const override
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:1002
static constexpr int NumReduce
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:846
static auto MakeArgument(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 0 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 0 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 0 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op, index_t Batch)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:847
decltype(MakeReduceGridDescriptor_M(1)) ReduceGridDesc_M
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:481
GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched > GridwiseGemmBase
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:529
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:394
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:478
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:803
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:584
DeviceBatchedGemmReduce_Xdl_CShuffle DeviceOp
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:179
static constexpr auto I2
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:186
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:583
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:479
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:996
static auto MakeReduceGridDescriptor_M(index_t MRaw)
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:453
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:181
static constexpr auto NXdlPerWave32
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:182
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 0 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 0 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 0 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op, index_t Batch=1) override
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:923
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:480
static constexpr auto I1
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:185
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_reduce_xdl_cshuffle.hpp:833
Definition device_gemm_reduce.hpp:17
#define CK_ENV(name)
Definition utility/env.hpp:129