DeviceBatchedGemmMultipleDGemmMultipleD< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation > Struct Template Reference

DeviceBatchedGemmMultipleDGemmMultipleD< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation > Struct Template Reference

Composable Kernel: ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation > Struct Template Reference
ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation > Struct Template Reference [abstract]

#include <device_batched_gemm_multiple_d_gemm_multiple_d.hpp>

Inheritance diagram for ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation >:
ck::tensor_operation::device::BaseOperator ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, Acc0DataType, D0sDataType, B1DataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, PadGemm0M, PadGemm0N, PadGemm0K, PadGemm1N, PadGemm1K, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >

Public Member Functions

virtual std::unique_ptr< BaseArgument > MakeArgumentPointer (const void *p_a0, const void *p_b0, std::array< const void *, NumD0Tensor > p_d0s, const void *p_b1, std::array< const void *, NumD1Tensor > p_d1s, void *p_e1, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t O, ck::index_t Batch, ck::index_t StrideA0, ck::index_t StrideB0, std::array< ck::index_t, NumD0Tensor > StrideD0s, ck::index_t StrideB1, std::array< ck::index_t, NumD1Tensor > StrideD1s, ck::index_t StrideE1, ck::index_t BatchStrideA0, ck::index_t BatchStrideB0, std::array< ck::index_t, NumD0Tensor > BatchStrideD0s, ck::index_t BatchStrideB1, std::array< ck::index_t, NumD1Tensor > BatchStrideD1s, ck::index_t BatchStrideE1, A0ElementwiseOperation a0_element_op, B0ElementwiseOperation b0_element_op, CDE0ElementwiseOperation cde0_element_op, B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer ()=0
Public Member Functions inherited from ck::tensor_operation::device::BaseOperator
 BaseOperator ()=default
 BaseOperator (const BaseOperator &)=default
BaseOperator & operator= (const BaseOperator &)=default
virtual bool IsSupportedArgument (const BaseArgument *)
virtual std::string GetTypeString () const
virtual std::string GetInstanceString () const
virtual std::string GetTypeIdName () const
virtual std::optional< std::string > GetObjectName () const
virtual std::optional< std::string > GetTemplateInfo () const
virtual std::string GetTypeIdHashCode () const
virtual size_t GetWorkSpaceSize (const BaseArgument *) const
virtual void SetWorkSpacePointer (BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
virtual ~BaseOperator ()

Static Public Attributes

static constexpr index_t NumD0Tensor = D0sDataType::Size()
static constexpr index_t NumD1Tensor = D1sDataType::Size()

Member Function Documentation

◆ MakeArgumentPointer()

template<typename A0Layout, typename B0Layout, typename D0sLayout, typename B1Layout, typename D1sLayout, typename E1Layout, typename A0DataType, typename B0DataType, typename D0sDataType, typename B1DataType, typename D1sDataType, typename E1DataType, typename A0ElementwiseOperation, typename B0ElementwiseOperation, typename CDE0ElementwiseOperation, typename B1ElementwiseOperation, typename CDE1ElementwiseOperation>
virtual std::unique_ptr< BaseArgument > ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation >::MakeArgumentPointer ( const void * p_a0,
const void * p_b0,
std::array< const void *, NumD0Tensor > p_d0s,
const void * p_b1,
std::array< const void *, NumD1Tensor > p_d1s,
void * p_e1,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t O,
ck::index_t Batch,
ck::index_t StrideA0,
ck::index_t StrideB0,
std::array< ck::index_t, NumD0Tensor > StrideD0s,
ck::index_t StrideB1,
std::array< ck::index_t, NumD1Tensor > StrideD1s,
ck::index_t StrideE1,
ck::index_t BatchStrideA0,
ck::index_t BatchStrideB0,
std::array< ck::index_t, NumD0Tensor > BatchStrideD0s,
ck::index_t BatchStrideB1,
std::array< ck::index_t, NumD1Tensor > BatchStrideD1s,
ck::index_t BatchStrideE1,
A0ElementwiseOperation a0_element_op,
B0ElementwiseOperation b0_element_op,
CDE0ElementwiseOperation cde0_element_op,
B1ElementwiseOperation b1_element_op,
CDE1ElementwiseOperation cde1_element_op )
pure virtual

Implemented in ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, Acc0DataType, D0sDataType, B1DataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, PadGemm0M, PadGemm0N, PadGemm0K, PadGemm1N, PadGemm1K, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >.

◆ MakeInvokerPointer()

template<typename A0Layout, typename B0Layout, typename D0sLayout, typename B1Layout, typename D1sLayout, typename E1Layout, typename A0DataType, typename B0DataType, typename D0sDataType, typename B1DataType, typename D1sDataType, typename E1DataType, typename A0ElementwiseOperation, typename B0ElementwiseOperation, typename CDE0ElementwiseOperation, typename B1ElementwiseOperation, typename CDE1ElementwiseOperation>
virtual std::unique_ptr< BaseInvoker > ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation >::MakeInvokerPointer ( )
pure virtual

Implemented in ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, Acc0DataType, D0sDataType, B1DataType, Acc1DataType, C1ShuffleDataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation, PadGemm0M, PadGemm0N, PadGemm0K, PadGemm1N, PadGemm1K, NumGemm0KPrefetchStage, BlockSize, Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, A0K1, B0K1, B1K1, Gemm0MPerXdl, Gemm0NPerXdl, Gemm0MXdlPerWave, Gemm0NXdlPerWave, Gemm1NXdlPerWave, A0BlockTransferThreadClusterLengths_AK0_M_AK1, A0BlockTransferThreadClusterArrangeOrder, A0BlockTransferSrcAccessOrder, A0BlockTransferSrcVectorDim, A0BlockTransferSrcScalarPerVector, A0BlockTransferDstScalarPerVector_AK1, A0BlockLdsExtraM, B0BlockTransferThreadClusterLengths_BK0_N_BK1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_BK1, B0BlockLdsExtraN, CDE0BlockTransferSrcVectorDim, CDE0BlockTransferSrcScalaerPerVector, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, B1BlockLdsExtraN, C1ShuffleMXdlPerWavePerShuffle, C1ShuffleGemm0NXdlPerWavePerShuffle, CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >.

Member Data Documentation

◆ NumD0Tensor

template<typename A0Layout, typename B0Layout, typename D0sLayout, typename B1Layout, typename D1sLayout, typename E1Layout, typename A0DataType, typename B0DataType, typename D0sDataType, typename B1DataType, typename D1sDataType, typename E1DataType, typename A0ElementwiseOperation, typename B0ElementwiseOperation, typename CDE0ElementwiseOperation, typename B1ElementwiseOperation, typename CDE1ElementwiseOperation>
index_t ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation >::NumD0Tensor = D0sDataType::Size()
static constexpr

◆ NumD1Tensor

template<typename A0Layout, typename B0Layout, typename D0sLayout, typename B1Layout, typename D1sLayout, typename E1Layout, typename A0DataType, typename B0DataType, typename D0sDataType, typename B1DataType, typename D1sDataType, typename E1DataType, typename A0ElementwiseOperation, typename B0ElementwiseOperation, typename CDE0ElementwiseOperation, typename B1ElementwiseOperation, typename CDE1ElementwiseOperation>
index_t ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD< A0Layout, B0Layout, D0sLayout, B1Layout, D1sLayout, E1Layout, A0DataType, B0DataType, D0sDataType, B1DataType, D1sDataType, E1DataType, A0ElementwiseOperation, B0ElementwiseOperation, CDE0ElementwiseOperation, B1ElementwiseOperation, CDE1ElementwiseOperation >::NumD1Tensor = D1sDataType::Size()
static constexpr

The documentation for this struct was generated from the following file: