gemm_kernel.hpp Source File

gemm_kernel.hpp Source File

Composable Kernel: gemm_kernel.hpp Source File
gemm_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <string>
8
9#include "ck_tile/core.hpp"
17
18namespace ck_tile {
19
29{
31 CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
32 const void* b_ptr_,
33 void* e_ptr_,
34 index_t k_batch_,
35 index_t M_,
36 index_t N_,
37 index_t K_,
38 index_t stride_A_,
39 index_t stride_B_,
40 index_t stride_E_)
41 : a_ptr(a_ptr_),
42 b_ptr(b_ptr_),
43 e_ptr(e_ptr_),
44 M(M_),
45 N(N_),
46 K(K_),
47 stride_A(stride_A_),
48 stride_B(stride_B_),
49 stride_E(stride_E_),
50 k_batch(k_batch_)
51 {
52 }
53
54 const void* a_ptr;
55 const void* b_ptr;
56 union
57 {
58 void* e_ptr;
59 void* c_ptr;
60 };
61
67
68 union
69 {
72 };
73
75};
76
77template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
79{
84
88
93
98
100 static_assert(
102 "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
103
105 static_assert(
107 "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
108
112 "C/CLayout and C/EDataType must be scalars.");
113
114 static constexpr index_t NumATensor = 1;
115 static constexpr index_t NumBTensor = 1;
117
118 CK_TILE_HOST static auto GetName() -> const std::string
119 {
121 }
122
123 CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
124 {
125 return UniversalGemmKernel::GridSize(M, N, KBatch);
126 }
127
128 CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
129 {
131 }
132
133 CK_TILE_HOST static constexpr auto BlockSize() -> dim3
134 {
136 }
137
138 CK_TILE_HOST static constexpr auto MakeKernelArgs(const GemmHostArgs& hostArgs) ->
140 {
145 {hostArgs.a_ptr},
146 {hostArgs.b_ptr},
147 {/*hostArgs.ds_ptr*/},
148 hostArgs.e_ptr,
149 hostArgs.k_batch,
150 hostArgs.M,
151 hostArgs.N,
152 hostArgs.K,
153 {hostArgs.stride_A},
154 {hostArgs.stride_B},
155 {/*hostArgs.stride_Ds*/},
156 hostArgs.stride_E));
157 }
158
159 CK_TILE_HOST static auto
161 {
163 }
164
166 {
167 UniversalGemmKernel{}.template operator()(kargs);
168 }
169};
170} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
int32_t index_t
Definition integer.hpp:9
The GEMM kernel host arguments.
Definition gemm_kernel.hpp:29
CK_TILE_HOST GemmHostArgs()=default
void * c_ptr
Definition gemm_kernel.hpp:59
index_t stride_E
Definition gemm_kernel.hpp:70
index_t stride_B
Definition gemm_kernel.hpp:66
index_t stride_C
Definition gemm_kernel.hpp:71
void * e_ptr
Definition gemm_kernel.hpp:58
index_t K
Definition gemm_kernel.hpp:64
index_t M
Definition gemm_kernel.hpp:62
CK_TILE_HOST GemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_E_)
Definition gemm_kernel.hpp:31
index_t stride_A
Definition gemm_kernel.hpp:65
const void * a_ptr
Definition gemm_kernel.hpp:54
const void * b_ptr
Definition gemm_kernel.hpp:55
index_t N
Definition gemm_kernel.hpp:63
index_t k_batch
Definition gemm_kernel.hpp:74
Definition gemm_kernel.hpp:79
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition gemm_kernel.hpp:97
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, E and D.
Definition gemm_kernel.hpp:95
static CK_TILE_HOST constexpr auto BlockSize() -> dim3
Definition gemm_kernel.hpp:133
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, E and D.
Definition gemm_kernel.hpp:90
static constexpr index_t NumBTensor
Definition gemm_kernel.hpp:115
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Alias for the UniversalGemmKernel instantiation that this kernel delegates to (e.g. for GridSize and operator()) to execute all necessary functions.
Definition gemm_kernel.hpp:82
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition gemm_kernel.hpp:96
static constexpr index_t kBlockSize
Definition gemm_kernel.hpp:116
CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void
Definition gemm_kernel.hpp:165
static constexpr index_t NumATensor
ALayout and ADataType are expected to be scalars, not a tuple.
Definition gemm_kernel.hpp:114
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Definition gemm_kernel.hpp:128
static CK_TILE_HOST auto IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs &kargs) -> bool
Definition gemm_kernel.hpp:160
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition gemm_kernel.hpp:86
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition gemm_kernel.hpp:92
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition gemm_kernel.hpp:91
static CK_TILE_HOST constexpr auto MakeKernelArgs(const GemmHostArgs &hostArgs) -> typename UniversalGemmKernel::KernelArgs
Definition gemm_kernel.hpp:138
static CK_TILE_HOST constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
Definition gemm_kernel.hpp:123
static CK_TILE_HOST auto GetName() -> const std::string
Definition gemm_kernel.hpp:118
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition gemm_kernel.hpp:85
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition gemm_kernel.hpp:87
The Universal GEMM kernel host arguments.
Definition universal_gemm_kernel.hpp:32
static CK_TILE_HOST const std::string GetName()
Definition universal_gemm_kernel.hpp:260
static CK_TILE_HOST constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
Definition universal_gemm_kernel.hpp:267
static CK_TILE_HOST auto BlockSize()
Definition universal_gemm_kernel.hpp:290
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition universal_gemm_kernel.hpp:278
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition universal_gemm_kernel.hpp:373
static CK_TILE_HOST constexpr KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition universal_gemm_kernel.hpp:303
static constexpr index_t kBlockSize
Definition universal_gemm_kernel.hpp:202
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition universal_gemm_kernel.hpp:257
Definition ck_tile/host/stream_config.hpp:30