device_grouped_gemm_splitk.hpp Source File

device_grouped_gemm_splitk.hpp Source File

Composable Kernel: device_grouped_gemm_splitk.hpp Source File
device_grouped_gemm_splitk.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include "device_grouped_gemm.hpp"

7namespace ck {
8namespace tensor_operation {
9namespace device {
10
11template <typename ALayout,
12 typename BLayout,
13 typename DsLayout,
14 typename ELayout,
15 typename ADataType,
16 typename BDataType,
17 typename DsDataType,
18 typename EDataType,
19 typename AElementwiseOperation,
20 typename BElementwiseOperation,
21 typename CElementwiseOperation>
23 BLayout,
24 DsLayout,
25 ELayout,
26 ADataType,
27 BDataType,
28 DsDataType,
29 EDataType,
30 AElementwiseOperation,
31 BElementwiseOperation,
32 CElementwiseOperation>
33{
34 //----------------------------------------------------------------------------------------------
40 virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
41 //----------------------------------------------------------------------------------------------
47 virtual void SetKBatch(BaseArgument* p_arg, index_t kbatch) const
48 {
49 this->SetKBatchSize(p_arg, kbatch);
50 };
51};
52
53} // namespace device
54} // namespace tensor_operation
55} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_base.hpp:197
Definition device_grouped_gemm.hpp:99
Definition device_grouped_gemm_splitk.hpp:33
virtual void SetKBatchSize(BaseArgument *p_arg, index_t kbatch) const =0
Sets the k batch size.
virtual void SetKBatch(BaseArgument *p_arg, index_t kbatch) const
Sets the k batch size by forwarding to SetKBatchSize().
Definition device_grouped_gemm_splitk.hpp:47