32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CElementwiseOperation,
44 typename M1N1ThreadClusterM1Xs,
45 typename M1N1ThreadClusterN1Xs,
46 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
47 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
48 typename ABlockTransferThreadClusterArrangeOrder,
49 typename ABlockTransferSrcAccessOrder,
50 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
51 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
52 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
53 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
54 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
55 typename BBlockTransferThreadClusterArrangeOrder,
56 typename BBlockTransferSrcAccessOrder,
57 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
58 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
59 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
60 typename CThreadTransferSrcDstAccessOrder,
61 index_t CThreadTransferSrcDstVectorDim,
62 index_t CThreadTransferDstScalarPerVector,
74 AElementwiseOperation,
75 BElementwiseOperation,
76 CElementwiseOperation>
94 const auto a_grid_desc_m_k = [&]() {
107 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
133 const auto b_grid_desc_k_n = [&]() {
146 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
168 const auto c_grid_desc_m_n = [&]() {
181 const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
182 const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
222 M1N1ThreadClusterM1Xs,
223 M1N1ThreadClusterN1Xs,
224 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
225 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
226 ABlockTransferThreadClusterArrangeOrder,
227 ABlockTransferSrcAccessOrder,
228 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
229 ABlockTransferSrcVectorTensorContiguousDimOrder,
230 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
231 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
232 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
233 BBlockTransferThreadClusterArrangeOrder,
234 BBlockTransferSrcAccessOrder,
235 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
236 BBlockTransferSrcVectorTensorContiguousDimOrder,
237 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
238 CThreadTransferSrcDstAccessOrder,
239 CThreadTransferSrcDstVectorDim,
240 CThreadTransferDstScalarPerVector>;
255 const BDataType* p_b_grid,
265 AElementwiseOperation a_element_op,
266 BElementwiseOperation b_element_op,
267 CElementwiseOperation c_element_op)
340 std::cout <<
"arg.a_grid_desc_k0_m0_m1_k1_{"
345 std::cout <<
"arg.b_grid_desc_k0_n0_n1_k1_{"
357 throw std::runtime_error(
358 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting");
366 const bool has_double_tail_k_block_loop =
371 if(has_main_k_block_loop && has_double_tail_k_block_loop)
397 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
423 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
483 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
499 constexpr auto A_K_vec_length =
500 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(
I0) *
501 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(
I3);
502 if(arg.
K_raw_ % A_K_vec_length != 0)
509 constexpr auto A_M_vec_lenght =
510 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(
I1) *
511 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(
I2);
512 if(arg.
M_raw_ % A_M_vec_lenght != 0)
520 constexpr auto B_N_vec_lenght =
521 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(
I1) *
522 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(
I2);
523 if(arg.
N_raw_ % B_N_vec_lenght != 0)
530 constexpr auto B_K_vec_length =
531 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(
I0) *
532 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(
I3);
533 if(arg.
K_raw_ % B_K_vec_length != 0)
555 const BDataType* p_b,
563 AElementwiseOperation a_element_op,
564 BElementwiseOperation b_element_op,
565 CElementwiseOperation c_element_op)
595 AElementwiseOperation a_element_op,
596 BElementwiseOperation b_element_op,
597 CElementwiseOperation c_element_op)
override
599 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
600 static_cast<const BDataType*
>(p_b),
601 static_cast<CDataType*
>(p_c),
618 return std::make_unique<Invoker>(
Invoker{});
624 auto str = std::stringstream();
627 str <<
"DeviceGemmDl"
632 << K0PerBlock <<
", "
634 << M1PerThread <<
", "
635 << N1PerThread <<
", "
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MNPadding
Definition gemm_specialization.hpp:17
Definition convolution_backward_data_specialization.hpp:7
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
__global__ void kernel_gemm_dl_v1r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_dl_v1r3.hpp:33
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_dl_v1r3.hpp:93
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11 __host__ static __device__ constexpr auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:208
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasMainKBlockLoop __host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:153
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateGridSize __host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition gridwise_gemm_dl_v1r3.hpp:146
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:129
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::CalculateHasDoubleTailKBlockLoop __host__ static __device__ constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition gridwise_gemm_dl_v1r3.hpp:160
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeDefaultBlock2CTileMap __host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_dl_v1r3.hpp:241
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeAGridDescriptor_K0_M0_M1_K1 __host__ static __device__ constexpr auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition gridwise_gemm_dl_v1r3.hpp:168
ck::GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector >::MakeBGridDescriptor_K0_N0_N1_K1 __host__ static __device__ constexpr auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition gridwise_gemm_dl_v1r3.hpp:188
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition device_base.hpp:197
Definition device_gemm_dl.hpp:253
index_t M_raw_
Definition device_gemm_dl.hpp:321
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_gemm_dl.hpp:307
CGridDesc_M_N c_grid_desc_m_n_
Definition device_gemm_dl.hpp:309
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition device_gemm_dl.hpp:312
index_t M01_
Definition device_gemm_dl.hpp:318
index_t N01_
Definition device_gemm_dl.hpp:319
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_gemm_dl.hpp:313
index_t K_raw_
Definition device_gemm_dl.hpp:323
CDataType * p_c_grid_
Definition device_gemm_dl.hpp:305
index_t N_raw_
Definition device_gemm_dl.hpp:322
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_gemm_dl.hpp:308
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_dl.hpp:254
const BDataType * p_b_grid_
Definition device_gemm_dl.hpp:304
AElementwiseOperation a_element_op_
Definition device_gemm_dl.hpp:326
BElementwiseOperation b_element_op_
Definition device_gemm_dl.hpp:327
DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_dl.hpp:315
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition device_gemm_dl.hpp:311
CElementwiseOperation c_element_op_
Definition device_gemm_dl.hpp:328
const ADataType * p_a_grid_
Definition device_gemm_dl.hpp:303
Definition device_gemm_dl.hpp:333
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_dl.hpp:336
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_dl.hpp:480
DeviceGemmDl::Argument Argument
Definition device_gemm_dl.hpp:334
Definition device_gemm_dl.hpp:78
static constexpr auto I0
Definition device_gemm_dl.hpp:79
static constexpr auto I2
Definition device_gemm_dl.hpp:81
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition device_gemm_dl.hpp:246
virtual std::string GetTypeString() const override
Definition device_gemm_dl.hpp:622
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_dl.hpp:549
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)) BGridDesc_K0_N_K1
Definition device_gemm_dl.hpp:202
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_dl.hpp:493
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition device_gemm_dl.hpp:242
static constexpr auto I3
Definition device_gemm_dl.hpp:82
static auto MakeInvoker()
Definition device_gemm_dl.hpp:583
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)) AGridDesc_K0_M_K1
Definition device_gemm_dl.hpp:201
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
Definition device_gemm_dl.hpp:127
GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_gemm_dl.hpp:206
static constexpr auto I5
Definition device_gemm_dl.hpp:84
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_gemm_dl.hpp:203
static constexpr auto I1
Definition device_gemm_dl.hpp:80
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_dl.hpp:487
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_dl.hpp:554
static constexpr auto I4
Definition device_gemm_dl.hpp:83
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_gemm_dl.hpp:248
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition device_gemm_dl.hpp:244
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_dl.hpp:616
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
Definition device_gemm_dl.hpp:166
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
Definition device_gemm_dl.hpp:88
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_dl.hpp:586
static constexpr auto K1Number
Definition device_gemm_dl.hpp:86
Definition device_gemm.hpp:22
#define CK_ENV(name)
Definition utility/env.hpp:129