gemm_pipeline_ag_bg_cr_comp_v5.hpp Source File

gemm_pipeline_ag_bg_cr_comp_v5.hpp Source File

Composable Kernel: gemm_pipeline_ag_bg_cr_comp_v5.hpp Source File
gemm_pipeline_ag_bg_cr_comp_v5.hpp
Go to the documentation of this file.
1// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2// SPDX-License-Identifier: MIT
3
4#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11// A Tile Window: global memory
12// B Tile Window: global memory
13// C Distributed Tensor: register
14
// Base helper for the CompV5 GEMM pipeline: exposes fixed stage counts and a few static
// queries that depend only on Problem.
// NOTE(review): this rendered listing omits the struct header (listing line 16),
// GetBlockLoopTailNum (listing lines 25-29) and the TailHandler body (listing line 34);
// consult the real header before relying on them.
15template <typename Problem>
17{
// Stage/buffer counts are all fixed at 1 for this pipeline.
18 static constexpr index_t PrefetchStages = 1;
19 static constexpr index_t PrefillStages = 1;
20 static constexpr index_t GlobalBufferNum = 1;
21
// Forwards Problem::TransposeC unchanged.
22 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
23
// Always returns true; the index_t argument is ignored.
24 CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
25
30
// Takes a callable plus an unnamed bool and an unnamed TailNumber.
// NOTE(review): the single body line is missing from this listing, so what is done with
// run_func cannot be confirmed from here.
31 template <typename RunFunction>
32 CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
33 {
35 }
36};
37
38template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV5DefaultPolicy>
40{
43
49
52
56
59
62
63 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
64
66 using I0 = number<0>;
67 using I1 = number<1>;
68 using I2 = number<2>;
69
70 static constexpr index_t BlockSize = Problem::kBlockSize;
71
72 static constexpr index_t MPerBlock = BlockGemmShape::kM;
73 static constexpr index_t NPerBlock = BlockGemmShape::kN;
74 static constexpr index_t KPerBlock = BlockGemmShape::kK;
75
// Returns Policy::GetVectorSizeA<Problem, IsWave32Host>().
// IsWave32Host defaults to false and is passed through to the policy unchanged.
76 template <bool IsWave32Host = false>
77 static constexpr index_t GetVectorSizeA()
78 {
79 return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
80 }
// Returns Policy::GetVectorSizeB<Problem, IsWave32Host>().
// IsWave32Host defaults to false and is passed through to the policy unchanged.
81 template <bool IsWave32Host = false>
82 static constexpr index_t GetVectorSizeB()
83 {
84 return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
85 }
// Returns Policy::GetVectorSizeC<Problem>(); unlike the A/B variants there is no
// IsWave32Host parameter.
86 static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
87
88 static constexpr bool kPadM = Problem::kPadM;
89 static constexpr bool kPadN = Problem::kPadN;
90 static constexpr bool kPadK = Problem::kPadK;
91
92 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
93 static constexpr index_t Preshuffle = Problem::Preshuffle;
94
95 static constexpr bool HasHotLoop = Problem::HasHotLoop;
96 static constexpr auto TailNum = Problem::TailNum;
97 static constexpr auto Scheduler = Problem::Scheduler;
98
99 static constexpr index_t NumWarps = BlockGemmShape::NumWarps;
100 static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{});
101
// Host-only: builds a name by joining, with '_', the literal "pipeline_AgBgCrCompV5",
// BlockSize and further fields, ending with the padding flags joined by 'x'.
// NOTE(review): listing line 106 (between the two concat lines) is missing here, so the
// middle fields of the name are not visible.
// NOTE(review): the return type is `const std::string` by value; a const-qualified
// by-value return prevents callers from moving out of the result — consider dropping the
// const if no caller depends on it.
102 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
103 {
104 // clang-format off
105 return concat('_', "pipeline_AgBgCrCompV5", BlockSize,
107 concat('x', kPadM, kPadN, kPadK));
108 // clang-format on
109 }
110
112 {
// Delegates to Policy::GetSmemSize<Problem>().
// NOTE(review): the signature line (listing line 111) is absent from this listing; the
// page's tooltip gives it as
// `static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()`.
113 return Policy::template GetSmemSize<Problem>();
114 }
115
// Returns Policy::IsTransposeC<Problem>().
116 CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
117 {
118 return Policy::template IsTransposeC<Problem>();
119 }
120
// Primary template, parameterized on the scheduler kind; its body is empty.
// NOTE(review): the declaration line (listing line 122) is missing from this listing.
121 template <GemmPipelineScheduler Scheduler>
123 {
124 };
125
// Explicit specialization of PipelineImpl holding the actual pipeline body.
// NOTE(review): the specialization header (listing line 127) and listing line 129 are
// missing from this listing; the page's tooltips give line 129 as `PipelineImplBase Base`.
//
// Overview of operator() as visible here:
//  * Each warp keeps an `operation_id`. When it is 0 the warp runs MemoryOpsStep
//    (load A/B with the element functions, advance the DRAM windows, LocalPrefill to the
//    LDS windows, LocalPrefetch into register tiles); otherwise it runs ComputeStep
//    (block_gemm into that warp's accumulator).
//  * operation_id is advanced modulo NumWaveGroups once per loop iteration, so warps that
//    start with different ids run the two steps out of phase.
//  * At the end warp 1 stores its accumulator through c_window_0 and warp 0 loads it back
//    and adds it to its own accumulator, which is the returned tile.
// NOTE(review): listing lines 132, 138, 155, 213, 216, 246-250, 253, 289, 301, 342, 355,
// 361 and 367 are missing. block_sync_lds appears in this page's tooltip list but in no
// visible line, so some of the missing lines are presumably synchronization — confirm
// against the real header.
126 template <>
128 {
130
131 template <bool HasHotLoop,
133 typename AsDramBlockWindowTmp,
134 typename AElementFunction,
135 typename BsDramBlockWindowTmp,
136 typename BElementFunction,
137 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
139 bool>* = nullptr>
140 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
141 const AElementFunction& a_element_func,
142 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
143 const BElementFunction& b_element_func,
144 index_t num_loop,
145 void* __restrict__ p_smem_0) const
146 {
// The A/B arguments are tuples of windows; element 0 of each tuple supplies the window
// type used for the checks below.
147 using ADramBlockWindowTmp =
148 remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
149 using BDramBlockWindowTmp =
150 remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
151
// The window element types must match the pipeline's ADataType / BDataType.
152 static_assert(
153 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
154 std::is_same_v<BDataType,
156 "Data Type conflict on A and B matrix input data type.");
157
// KPerBlock must be a multiple of (NumWarps / 2) * KTileSize.
158 static_assert(
159 KPerBlock % ((NumWarps / 2) * KTileSize) == 0,
160 "Ping Pong Warps, TileSize and Block Size for K dimensions does not match.");
161
162 constexpr bool is_a_col_major =
163 std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
164 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
165
// Window lengths are expected as [K, M] for column-major A and [M, K] otherwise;
// [K, N] for row-major B and [N, K] otherwise.
166 static_assert(is_a_col_major
167 ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
168 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
169 : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
170 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
171 "A block window has incorrect lengths for defined ALayout!");
172 static_assert(is_b_row_major
173 ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
174 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
175 : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
176 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
177 "B block window has incorrect lengths for defined BLayout!");
178
// warp_id selects this warp's accumulator/tiles; operation_id is the warp's starting
// phase. Both come from get_warp_id(), but only operation_id goes through
// amd_wave_read_first_lane.
179 index_t warp_id = get_warp_id();
180 index_t operation_id =
181 amd_wave_read_first_lane(get_warp_id()); // 0 - Memory read, 1 - block-gemm
182
// Warp 0 starts at the window origin; every other warp starts offset by KPerBlock in
// the second window dimension.
// NOTE(review): the offset is (0, KPerBlock) for both layouts, while the static_asserts
// above put K in dimension 0 for column-major A / row-major B and the window step below
// advances dimension 0 in that case — confirm the offset axis is intended for those
// layouts.
183 auto a_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock);
184 auto b_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock);
185
// A/B LDS tensor views are built on top of p_smem_0.
186 auto tensor_views =
187 Base::GetABLdsTensorViews(static_cast<void*>(static_cast<char*>(p_smem_0)));
188 auto& a_lds_block = tensor_views.get(number<0>{});
189 auto& b_lds_block = tensor_views.get(number<1>{});
190
191 constexpr auto a_lds_load_tile_distr =
192 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
193 constexpr auto b_lds_load_tile_distr =
194 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
195
// GetAWindows / GetBWindows each return three windows, unpacked here as: the DRAM copy
// window, the LDS window written by LocalPrefill, and the LDS window read by
// LocalPrefetch.
196 auto a_windows = Base::GetAWindows(
197 a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr, a_offset);
198 auto& a_copy_dram_window = a_windows.get(number<0>{});
199 auto& a_copy_lds_window = a_windows.get(number<1>{});
200 auto& a_lds_window = a_windows.get(number<2>{});
201
202 auto b_windows = Base::GetBWindows(
203 b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr, b_offset);
204 auto& b_copy_dram_window = b_windows.get(number<0>{});
205 auto& b_copy_lds_window = b_windows.get(number<1>{});
206 auto& b_lds_window = b_windows.get(number<2>{});
207
// Each memory step advances the DRAM window by KPerBlock * NumWarps.
// NOTE(review): the second branch of each conditional (listing lines 213 and 216) is
// missing from this listing.
208 // DRAM window steps.
209 using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
210 using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
211 constexpr ADramTileWindowStep a_dram_tile_window_step =
212 is_a_col_major ? make_array(KPerBlock * NumWarps, 0)
214 constexpr BDramTileWindowStep b_dram_tile_window_step =
215 is_b_row_major ? make_array(KPerBlock * NumWarps, 0)
217
218 constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution(
219 BlockGemm::MakeABlockDistributionEncode())){};
220 constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution(
221 BlockGemm::MakeBBlockDistributionEncode())){};
222
// Two A/B tile pairs: the *_0 pair is used when the step index is 0, the *_1 pair
// otherwise (see MemoryOpsStep / ComputeStep).
223 using AGemmTile = decltype(make_static_distributed_tensor<ADataType>(AGemmTileDistr));
224 using BGemmTile = decltype(make_static_distributed_tensor<BDataType>(BGemmTileDistr));
225 AGemmTile a_tile_0, a_tile_1;
226 BGemmTile b_tile_0, b_tile_1;
227
228 // Register tile for A and B.
229 using ABlockTileDistr =
230 decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
231 using BBlockTileDistr =
232 decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
233 using ABlockTile =
234 decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
235 using BBlockTile =
236 decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
237 ABlockTile elementwise_As_res;
238 BBlockTile elementwise_Bs_res;
239
240 // Block GEMM
241 auto block_gemm = BlockGemm();
242 auto c_block_tile_0 = block_gemm.MakeCBlockTile();
243 auto c_block_tile_1 = block_gemm.MakeCBlockTile();
244
// The C staging view is built on the same p_smem_0 base pointer that backs the A/B LDS
// views above.
// NOTE(review): most of the c_lds_block_0 / c_window_0 construction (listing lines
// 246-250 and 253) is missing from this listing. Because the buffer is shared with A/B,
// the final store must not race with outstanding LDS reads — confirm the synchronization
// in the real header.
245 CDataType* __restrict__ p_c_lds = static_cast<CDataType*>(p_smem_0);
246 auto c_lds_block_0 =
251 number<1>{});
252 auto c_window_0 = make_tile_window(c_lds_block_0,
254 {0, 0},
255 c_block_tile_1.get_tile_distribution());
256
// Warp 0 zeroes c_block_tile_0; every other warp zeroes c_block_tile_1. The other tile
// of each warp is not zeroed here.
257 // initialize C
258 if(warp_id == 0)
259 {
260 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0);
261 }
262 else
263 {
264 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_1);
265 }
266
// MemoryOpsStep(idx): load A and B through the element functions, advance both DRAM
// windows, write the (optionally transposed) tiles via LocalPrefill, then LocalPrefetch
// into the tile pair chosen by idx (0 -> *_0, otherwise *_1).
267 // define ping, pong steps here as lambda functions.
268 auto MemoryOpsStep = [&](auto idx) {
269 // Memory read half here.
270
271 // Load tile — during value loading, an elementwise function is executed for each
272 // A0, A1, … AN. The values A0, A1, … AN are read by the same thread.
273 elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
274
275 // Move each A — the enhanced function move_tile_window is executed, which takes a
276 // tuple as input.
277 move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
278
279 // Load tile — during value loading, an elementwise function is executed for each
280 // B0, B1, … BN. The values B0, B1, … BN are read by the same thread.
281 elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
282
283 // Move each B — the enhanced function move_tile_window is executed, which takes a
284 // tuple as input.
285 move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
286
// Column-major A is transposed into a shuffled register tile before the LDS write.
// NOTE(review): the declaration of a_shuffle_tmp (listing line 289) is missing here.
287 if constexpr(is_a_col_major)
288 {
290 Policy::template MakeShuffledARegTileDistribution<Problem>());
291 transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
292 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
293 }
294 else
295 {
296 Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
297 }
298
// Row-major B gets the same transpose treatment.
// NOTE(review): the declaration of b_shuffle_tmp (listing line 301) is missing here.
299 if constexpr(is_b_row_major)
300 {
302 Policy::template MakeShuffledBRegTileDistribution<Problem>());
303 transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
304 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
305 }
306 else
307 {
308 Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
309 }
310
// NOTE(review): no barrier is visible between the LocalPrefill writes above and the
// LocalPrefetch reads below — confirm whether one is needed or handled elsewhere.
311 if(idx == 0)
312 {
313 Base::LocalPrefetch(a_tile_0, a_lds_window);
314 Base::LocalPrefetch(b_tile_0, b_lds_window);
315 }
316 else
317 {
318 Base::LocalPrefetch(a_tile_1, a_lds_window);
319 Base::LocalPrefetch(b_tile_1, b_lds_window);
320 }
321 };
322
// ComputeStep(idx): one block_gemm into the accumulator and tile pair chosen by idx
// (0 -> *_0, otherwise *_1).
323 auto ComputeStep = [&](auto idx) {
324 if(idx == 0)
325 {
326 block_gemm(c_block_tile_0, a_tile_0, b_tile_0);
327 }
328 else
329 {
330 block_gemm(c_block_tile_1, a_tile_1, b_tile_1);
331 }
332 };
333
// Prologue: only warps whose starting operation_id is 0 perform a memory step here.
334 if(operation_id == 0)
335 {
336 MemoryOpsStep(warp_id);
337 }
338
// Main loop: runs num_loop - 1 times; each iteration advances operation_id and then
// performs exactly one memory step or one compute step.
// NOTE(review): listing line 342 (first statement of the loop body) and line 355
// (right after the loop) are missing from this listing.
339 index_t num_compute_steps = amd_wave_read_first_lane(num_loop);
340 while(num_compute_steps > 1)
341 {
343 operation_id = (operation_id + 1) % NumWaveGroups;
344
345 if(operation_id == 0)
346 {
347 MemoryOpsStep(warp_id);
348 }
349 else
350 {
351 ComputeStep(warp_id);
352 }
353 num_compute_steps -= 1;
354 }
356
// Epilogue: a warp whose last operation was a memory step still has to consume the
// tile it just prefetched.
357 if(operation_id == 0)
358 {
359 ComputeStep(warp_id);
360 }
362
// Cross-warp reduction: warp 1 publishes its accumulator through c_window_0 ...
363 if(warp_id == 1)
364 {
365 store_tile(c_window_0, c_block_tile_1);
366 }
368
// ... and warp 0 reads it back into its (otherwise unused) c_block_tile_1 and adds it
// element by element into c_block_tile_0.
// NOTE(review): only warp 1's partial result is added; if more than two warps can run
// this path their c_block_tile_1 contributions are not visibly reduced — confirm the
// supported warp count.
369 if(warp_id == 0)
370 {
371 load_tile(c_block_tile_1, c_window_0);
372
373 constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans();
374 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
375 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
376 auto idx2 = make_tuple(idx0, idx1);
377 c_block_tile_0(idx2) += c_block_tile_1(idx2);
378 });
379 });
380 }
// All warps return c_block_tile_0. Only warp 0's copy holds the combined sum; on other
// warps this tile was neither zeroed nor accumulated into in the visible code.
381 return c_block_tile_0;
382 }
383 };
384
385 public:
// Tuple-of-windows entry point with explicit element functions.
// Forwards every argument unchanged to PipelineImpl<Scheduler>'s operator(), instantiated
// with this class's HasHotLoop and TailNum.
// Enabled when AsDramBlockWindowTmp is detected as a tuple.
// NOTE(review): the second half of the enable_if condition (listing line 391) is missing
// from this listing.
386 template <typename AsDramBlockWindowTmp,
387 typename BsDramBlockWindowTmp,
388 typename AElementFunction,
389 typename BElementFunction,
390 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
392 bool>* = nullptr>
393 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
394 const AElementFunction& a_element_func,
395 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
396 const BElementFunction& b_element_func,
397 index_t num_loop,
398 void* p_smem_0) const
399 {
400 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
401 a_dram_block_window_tmp,
402 a_element_func,
403 b_dram_block_window_tmp,
404 b_element_func,
405 num_loop,
406 p_smem_0);
407 }
408
// Tuple-of-windows entry point without element functions.
// Calls PipelineImpl<Scheduler>'s operator() with pass-through lambdas that simply assign
// the loaded A/B value to the output element (`e = a`, `e = b`).
// Enabled when AsDramBlockWindowTmp is detected as a tuple.
// NOTE(review): the second half of the enable_if condition (listing line 412) is missing
// from this listing.
409 template <typename AsDramBlockWindowTmp,
410 typename BsDramBlockWindowTmp,
411 typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
413 bool>* = nullptr>
414 CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
415 const BsDramBlockWindowTmp& b_dram_block_window_tmp,
416 const index_t num_loop,
417 void* __restrict__ p_smem_0) const
418 {
419 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
420 a_dram_block_window_tmp,
421 [](auto& e, const ADataType& a) { e = a; },
422 b_dram_block_window_tmp,
423 [](auto& e, const BDataType& b) { e = b; },
424 num_loop,
425 p_smem_0);
426 }
427
// Single-window entry point with explicit element functions.
// Wraps the A and B windows in one-element tuples (ck_tile::make_tuple) and re-dispatches
// to operator() with the same element functions, num_loop and p_smem_0.
// Enabled when ADramBlockWindowTmp is NOT detected as a tuple.
// NOTE(review): the second half of the enable_if condition (listing line 433) is missing
// from this listing.
428 template <typename ADramBlockWindowTmp,
429 typename BDramBlockWindowTmp,
430 typename AElementFunction,
431 typename BElementFunction,
432 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
434 bool>* = nullptr>
435 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
436 const AElementFunction& a_element_func,
437 const BDramBlockWindowTmp& b_dram_block_window_tmp,
438 const BElementFunction& b_element_func,
439 index_t num_loop,
440 void* p_smem_0) const
441 {
442 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
443 a_element_func,
444 ck_tile::make_tuple(b_dram_block_window_tmp),
445 b_element_func,
446 num_loop,
447 p_smem_0);
448 }
449
// Single-window entry point without element functions.
// Wraps the A and B windows in one-element tuples (ck_tile::make_tuple) and re-dispatches
// to operator() with num_loop and p_smem_0.
// Enabled when ADramBlockWindowTmp is NOT detected as a tuple.
// NOTE(review): the second half of the enable_if condition (listing line 453) is missing
// from this listing.
450 template <typename ADramBlockWindowTmp,
451 typename BDramBlockWindowTmp,
452 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
454 bool>* = nullptr>
455 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
456 const BDramBlockWindowTmp& b_dram_block_window_tmp,
457 const index_t num_loop,
458 void* __restrict__ p_smem_0) const
459 {
460 return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
461 ck_tile::make_tuple(b_dram_block_window_tmp),
462 num_loop,
463 p_smem_0);
464 }
465};
466
467} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Empty
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:36
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access >={}, bool_constant< oob_conditional_check >={})
Load tile with elementwise function.
Definition load_tile.hpp:41
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition transpose_tile.hpp:195
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:17
static CK_TILE_HOST_DEVICE constexpr bool BlockHasHotloop(index_t)
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:24
static CK_TILE_HOST_DEVICE constexpr TailNumber GetBlockLoopTailNum(index_t)
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:26
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool, TailNumber)
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:32
static constexpr index_t PrefillStages
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:19
static CK_TILE_HOST_DEVICE constexpr auto TransposeC()
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:22
static constexpr index_t GlobalBufferNum
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:20
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:18
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *__restrict__ p_smem_0) const
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:140
PipelineImplBase Base
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:129
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:123
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:40
static constexpr bool DoubleSmemBuffer
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:92
static constexpr index_t NumWarps
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:99
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem_0) const
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:393
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:72
number< 2 > I2
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:68
static CK_TILE_HOST const std::string GetName()
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:102
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:46
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:47
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, const index_t num_loop, void *__restrict__ p_smem_0) const
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:455
static constexpr index_t NumWaveGroups
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:63
remove_cvref_t< std::tuple_element_t< 0, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:61
remove_cvref_t< typename Problem::BsLayoutTuple > BsLayout
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:54
static constexpr bool kPadM
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:88
static constexpr index_t GetVectorSizeB()
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:82
remove_cvref_t< std::tuple_element_t< 0, BsLayout > > BLayout
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:58
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:73
remove_cvref_t< typename Problem::AsLayoutTuple > AsLayout
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:53
remove_cvref_t< typename Problem::AsDataTypeTuple > AsDataType
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:44
static constexpr auto Scheduler
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:97
number< 1 > I1
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:67
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem_0) const
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:435
static CK_TILE_HOST_DEVICE constexpr auto IsTransposeC()
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:116
remove_cvref_t< typename Problem::AElementWise > AElementWise
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:50
static constexpr index_t GetVectorSizeA()
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:77
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:111
static constexpr index_t KTileSize
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:100
static constexpr bool HasHotLoop
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:95
static constexpr index_t GetVectorSizeC()
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:86
static constexpr auto TailNum
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:96
static constexpr index_t Preshuffle
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:93
remove_cvref_t< typename Problem::BsDataTypeTuple > BsDataType
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:45
remove_cvref_t< std::tuple_element_t< 0, AsDataType > > ADataType
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:60
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp &a_dram_block_window_tmp, const BsDramBlockWindowTmp &b_dram_block_window_tmp, const index_t num_loop, void *__restrict__ p_smem_0) const
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:414
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:55
static constexpr index_t BlockSize
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:70
static constexpr bool kPadK
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:90
static constexpr bool kPadN
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:89
number< 0 > I0
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:66
remove_cvref_t< decltype(Policy::template GetBlockGemm< Problem >())> BlockGemm
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:65
BaseGemmPipelineAgBgCrCompV5< Problem > Base
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:41
GemmPipelineAgBgCrImplBase< Problem, Policy > PipelineImplBase
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:42
remove_cvref_t< typename Problem::BElementWise > BElementWise
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:51
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:74
remove_cvref_t< std::tuple_element_t< 0, AsLayout > > ALayout
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:57
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_pipeline_ag_bg_cr_comp_v5.hpp:48
Definition gemm_pipeline_ag_bg_cr_base.hpp:13
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp &b_dram_block_window_tmp, const BLdsTensorView &b_lds_block_view, const BLdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:225
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataType > > BDataType
Definition gemm_pipeline_ag_bg_cr_base.hpp:22
CK_TILE_DEVICE auto GetABLdsTensorViews(void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:83
static constexpr index_t NPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:26
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile &dst_block_tile, const SrcTileWindow &lds_tile_window, bool_constant< LoadTranspose >={}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:73
static constexpr index_t MPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:25
CK_TILE_DEVICE void LocalPrefill(DstTileWindow &lds_tile_window, const SrcBlockTile &src_block_tile, const ElementFunction &element_func) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:57
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp &a_dram_block_window_tmp, const ALdsTensorView &a_lds_block_view, const ALdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:190
static constexpr index_t KPerBlock
Definition gemm_pipeline_ag_bg_cr_base.hpp:27
Definition tile/core/numeric/integral_constant.hpp:30