streamk_gemm_kernel.hpp Source File

streamk_gemm_kernel.hpp Source File#

Composable Kernel: streamk_gemm_kernel.hpp Source File
streamk_gemm_kernel.hpp
Go to the documentation of this file.
1// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2// SPDX-License-Identifier: MIT
3
4#pragma once
5
9
10namespace ck_tile {
11namespace reboot {
12
21{
22 CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
23 const void* b_ptr_,
24 void* c_ptr_,
25 index_t M_,
26 index_t N_,
27 index_t K_,
28 index_t stride_A_,
29 index_t stride_B_,
30 index_t stride_C_,
31 StreamKReductionStrategy reduction_strategy_)
32 : UniversalGemmHostArgs<>({a_ptr_},
33 {b_ptr_},
34 {/*ds_ptr*/},
35 c_ptr_,
36 /*k_batch_ =*/1,
37 M_,
38 N_,
39 K_,
40 {stride_A_},
41 {stride_B_},
42 {/*stride_Ds_*/},
43 stride_C_),
44 reduction_strategy{reduction_strategy_}
45 {
46 }
47
49};
50
55// The main kernel functions are the operator() functions. There is one for Persistent
56// and one for Non-Persistent data parallel sections of the Stream-K algorithm.
57//
58// Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and
59// `StreamKGemm()`. `BaseGemm()` computes offsets into the A,B,C tensors, then calls
60// `RunGemm()` which runs the GEMM pipeline and epilogue. `StreamKGemm()` performs the
61// main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`.
62template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
64{
69
72
73 using TilePartitioner = TilePartitioner_;
74 using GemmPipeline = GemmPipeline_;
75 using EpiloguePipeline = EpiloguePipeline_;
76
77 static_assert(
78 TilePartitioner::PERSISTENT == PersistentDP,
79 "Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
80
82 using ALayout = typename GemmPipeline::ALayout;
83 using BLayout = typename GemmPipeline::BLayout;
84 using CLayout = typename GemmPipeline::CLayout;
85
87 using ADataType = typename GemmPipeline::ADataType;
88 using BDataType = typename GemmPipeline::BDataType;
89 using CDataType = typename EpiloguePipeline::ODataType;
90 using AccDataType = typename EpiloguePipeline::AccDataType;
91
92 template <typename T>
94
97 "ALayout and ADataType must be scalars.");
98
101 "BLayout and BDataType must be scalars.");
102
105 "CLayout and CDataType must be scalars.");
106
108 {
111 host_args.bs_ptr,
112 host_args.ds_ptr,
113 host_args.e_ptr,
114 host_args.M,
115 host_args.N,
116 host_args.K,
117 host_args.stride_As,
118 host_args.stride_Bs,
119 host_args.stride_Ds,
120 host_args.stride_E,
121 host_args.k_batch},
123 // The workspace pointer is set to nullptr because we must first
124 // instantiate the TilePartitioner to get the necessary size
125 workspace_ptr{nullptr},
126 tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
127
128 {
129 }
130
139 };
140
143
144 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
145 {
146 // clang-format off
147 using P_ = GemmPipeline;
148 using WarpTile = typename P_::BlockGemmShape::WarpTile;
149
150 return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
151 concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
152 concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
153 concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
154 concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
155 // clang-format on
156 }
157
160 CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
161 {
162 return tile_partitioner.grid_size();
163 }
164
169 CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
170 {
172 }
173
174 CK_TILE_HOST static constexpr auto BlockSize() -> dim3
175 {
177 }
178
186 CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
187 int num_cu = NumCU(),
188 int occupancy = Occupancy())
189 {
190 const index_t grid = num_cu * occupancy;
191
192 return StreamKKernelArgs{host_args, grid};
193 }
194
195 template <bool UseDefaultScheduler = true>
196 CK_TILE_DEVICE static void
197 RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
198 const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
199 const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
200 CDataType* c_ptr,
201 void* smem_ptr_0,
202 const typename UniversalGemmKernel::KernelArgs& kargs,
203 const index_t num_loop,
204 const index_t block_idx_m,
205 const index_t block_idx_n,
206 const index_t k_size)
207 {
208 // Create Gemm tensor views, pad views and tile windows
209 const auto& gemm_tensor_views_tuple =
210 UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
211 as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
212
213 const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
214 auto gemm_tile_windows =
215 UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
216
217 // Run GEMM cooperatively by whole workgroup.
218 const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
219 const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
220 const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
221
222 // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
223 // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
224 // case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
225 // tail_num.
226 const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
227 const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
228
229 const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
230 bs_block_window[UniversalGemmKernel::I0],
231 num_loop,
232 has_hot_loop,
233 tail_num,
234 smem_ptr_0);
235
236 if(UseDefaultScheduler || (get_warp_id() == 0))
237 {
238 // Run Epilogue Pipeline
239 auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
240
241 EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
242 }
243 }
244
245 CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
246 {
248 }
249
252 CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
253 {
254 return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
255 }
256
259 CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
260 {
261 kargs.workspace_ptr = workspace_ptr;
262 }
263
274 CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs,
275 index_t tile_idx,
276 index_t num_loop,
277 index_t i_k_a,
278 index_t i_k_b,
279 index_t k_size,
280 void* smem_ptr_0) const
281 {
282 const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
283 index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
284 index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
285
286 const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
287 const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
288 CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
289
290 // Run the GEMM pipeline and Epilogue.
291 RunGemm(
292 {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
293 }
294
301 CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
302 index_t cta_idx) const
303 {
304 auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
305 workgroup_barrier sk_flags(sk_flags_ptr);
306 sk_flags.wait_set(0, 1, cta_idx);
307 }
308
314 CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
315 {
316 auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
317 workgroup_barrier sk_flags(sk_flags_ptr);
318 sk_flags.wait_eq(1, cta_idx);
319 }
320
326 template <typename OAccTile>
327 CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
328 const OAccTile& in_block_tile) const
329 {
330 using BlockType = remove_cvref_t<decltype(in_out_block_tile)>;
331 constexpr auto o_spans = BlockType::get_distributed_spans();
332 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
333 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
334 constexpr auto idx = make_tuple(idx0, idx1);
335 in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
336 });
337 });
338 }
339
347 template <typename DataType, typename OAccTileDist>
348 CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
349 index_t cta_idx,
350 const OAccTileDist& c_block_tile_dist) const
351 {
352 const auto c_block_tile_buffer_size =
353 TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
354 void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
355 kargs.tile_partitioner.get_flags_buffer_size() +
356 cta_idx * c_block_tile_buffer_size;
357
358 const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
359 static_cast<DataType*>(partial_buffer_ptr),
361 make_tuple(TilePartitioner::NPerBlock, 1),
362 number<GemmPipeline::GetVectorSizeC()>{},
363 number<1>{});
364
365 auto partial_tile_window = make_tile_window(
366 partial_tensor_view,
368 {0, 0},
369 c_block_tile_dist);
370
371 return load_tile(partial_tile_window);
372 }
373
380 template <typename OAccTile>
381 CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
382 index_t cta_idx,
383 const OAccTile& c_block_tile) const
384 {
385 const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
386 TilePartitioner::NPerBlock *
387 sizeof(typename OAccTile::DataType);
388 void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
389 kargs.tile_partitioner.get_flags_buffer_size() +
390 cta_idx * c_block_tile_buffer_size;
391
392 const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
393 static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
395 make_tuple(TilePartitioner::NPerBlock, 1),
396 number<GemmPipeline::GetVectorSizeC()>{},
397 number<1>{});
398
399 auto partial_tile_window = make_tile_window(
400 partial_tensor_view,
402 {0, 0});
403
404 store_tile(partial_tile_window, c_block_tile);
405 }
406
414 CK_TILE_DEVICE void
415 StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
416 {
417 index_t iter_start, iter_end;
418 kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
419
420 while(iter_start < iter_end)
421 {
422 // Get the 1D tile index in the C tensor that this workgroup will work in for this
423 // iteration of the loop.
424 index_t tile_idx =
425 amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
426
427 // Get the start and end boundaries for the current tile.
428 index_t tile_iter_start, tile_iter_end;
429 kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
430
431 // Get the start and end iteration within the current tile for the workgroup.
432 index_t local_iter_start = amd_wave_read_first_lane(
433 kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
434 index_t local_iter_end =
435 amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
436 tile_iter_start, iter_end, tile_iter_end));
437
438 // Get the iteration length.
439 index_t num_loop_sk = local_iter_end - local_iter_start;
440
441 // Determine the total size along the K dimension the workgroup is using in this
442 // iteration (used to construct tensor views).
443 index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
444
445 // Get the K offsets for the A and B tensors
446 auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
447 local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]);
448
449 if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic)
450 {
451 BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
452 }
453 else
454 {
455 const auto c_macro_tile_idx =
456 kargs.tile_partitioner.get_output_tile_index(tile_idx);
457 index_t i_m =
458 c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
459 index_t i_n =
460 c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
461
462 const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
463 const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
464 CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
465
466 // Create Gemm tensor views, pad views and tile windows
467 const auto& gemm_tensor_views_tuple =
468 UniversalGemmKernel::template MakeGemmTensorViews<
469 EpiloguePipeline::MemoryOperation>(
470 {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
471
472 const auto& gemm_pad_views =
473 UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
474 auto gemm_tile_windows =
475 UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
476
477 // Run GEMM cooperatively by whole workgroup.
478 const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
479 const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
480 const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
481
482 // Since num_loop can vary per WG and per iteration of the Stream-K while loop,
483 // we compute has_hot_loop and tail_num here. This is a similar pattern used by
484 // grouped GEMM. In this case, we call the GemmPipeline's operator() function
485 // that takes both has_hot_loop and tail_num.
486 const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
487 const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
488
489 const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
490 bs_block_window[UniversalGemmKernel::I0],
491 num_loop_sk,
492 has_hot_loop,
493 tail_num,
494 smem_ptr_0);
495
496 auto tile_started = iter_start == tile_iter_start;
497 auto tile_ended = iter_end >= tile_iter_end;
498 if(!tile_started)
499 {
500 StorePartial(kargs, cta_idx, c_block_tile);
501 // Ensure device-wide visibility of partial results stored in global memory
502 // before signaling completion. __threadfence() guarantees that all global
503 // memory writes by this thread are visible to other threads on the device.
504 __threadfence(); // send signal when the store is done
505 SignalStorePartialDone(kargs, cta_idx);
506 }
507 else
508 {
509 auto accum_block_tile = c_block_tile;
510 if(!tile_ended)
511 {
512 const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
513 const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
514 const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
515 int accum_iters = local_iter_end - local_iter_start;
516 int next_cta = cta_idx + 1;
517
518 while(accum_iters < iter_per_tile)
519 {
520 WaitStorePartialDone(kargs, next_cta);
521
522 using BlockType = remove_cvref_t<decltype(c_block_tile)>;
524 accum_block_tile,
526 kargs, next_cta, c_block_tile.get_tile_distribution()));
527
528 accum_iters += iter_per_cta + (next_cta < extra_iters);
529 ++next_cta;
530 }
531 }
532
533 auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
535 c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
536 }
537 }
538
539 // Prepare for next Stream-K loop iteration.
540 iter_start = tile_iter_end;
542 }
543 }
544
552 template <bool U = PersistentDP>
553 CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
554 {
555 // Allocate LDS
556 __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
557
558 index_t block_idx = ck_tile::get_block_1d_id();
559 index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
560 index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
561 bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
562
563 // Check if at the data parallel section
564 if(is_dp_ctas)
565 {
566 BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
567 }
568 else
569 {
570 // Stream-K
571 StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
572 }
573 }
574
583 template <bool U = PersistentDP>
584 CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
585 {
586 // Allocate LDS
587 __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
588
589 index_t block_idx = ck_tile::get_block_1d_id();
590 index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
591
592 // Data-parallel section
593 for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
594 tile_idx += kargs.tile_partitioner.get_grid())
595 {
596 BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
597 }
598
599 // Stream-K section
600 StreamKGemm(kargs, block_idx, smem_ptr_0);
601 }
602
603 private:
610 template <typename ALayout, typename BLayout>
612 GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
613 {
614 index_t stride_offset_a;
615 index_t stride_offset_b;
616 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
617 {
618 stride_offset_a = stride_a;
619 }
620 else
621 {
622 stride_offset_a = 1;
623 }
624
625 if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
626 {
627 stride_offset_b = stride_b;
628 }
629 else
630 {
631 stride_offset_b = 1;
632 }
633
634 index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
635
636 return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
637 }
638
639 CK_TILE_HOST static int NumCU()
640 {
641 hipDeviceProp_t dev_prop;
642 hipDevice_t dev;
643 hip_check_error(hipGetDevice(&dev));
644 hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
645 int num_cu = dev_prop.multiProcessorCount;
646
647 return num_cu;
648 }
649
654 CK_TILE_HOST static int Occupancy()
655 {
656 int occupancy;
657
658 // Since occupancy of 1 is valid for stream k, we set min_block_per_cu to 1
659 constexpr int min_block_per_cu = 1;
661
663 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
664
665 return occupancy;
666 }
667};
668} // namespace reboot
669
678{
679 CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
680 const void* b_ptr_,
681 void* c_ptr_,
682 index_t M_,
683 index_t N_,
684 index_t K_,
685 index_t stride_A_,
686 index_t stride_B_,
687 index_t stride_C_,
688 StreamKReductionStrategy reduction_strategy_,
689 uint32_t num_sk_blocks_ = 0xffffffff)
690 : UniversalGemmHostArgs<>({a_ptr_},
691 {b_ptr_},
692 {/*ds_ptr*/},
693 c_ptr_,
694 /*k_batch_ =*/1,
695 M_,
696 N_,
697 K_,
698 {stride_A_},
699 {stride_B_},
700 {/*stride_Ds_*/},
701 stride_C_),
702 reduction_strategy{reduction_strategy_},
703 num_sk_blocks{num_sk_blocks_}
704 {
705 }
706
709};
710
711template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
713{
718
720
724
729
734
738 "ALayout and ADataType must be scalars.");
739
743 "BLayout and BDataType must be scalars.");
744
748 "CLayout and CDataType must be scalars.");
749
763
766
767 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
768 {
769 // clang-format off
770 using P_ = GemmPipeline;
771 using WarpTile = typename P_::BlockGemmShape::WarpTile;
772
773 return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
774 concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
775 concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
776 concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
777 concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
778 // clang-format on
779 }
780
783 CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
784 {
785 return tile_partitioner.GridSize();
786 }
787
792 CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
793 {
795 }
796
797 CK_TILE_HOST static constexpr auto BlockSize() -> dim3
798 {
800 }
801
809 CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
810 int num_cu = NumCU(),
811 int occupancy = Occupancy())
812 {
813 return StreamKKernelArgs{{host_args.as_ptr,
814 host_args.bs_ptr,
815 host_args.ds_ptr,
816 host_args.e_ptr,
817 host_args.M,
818 host_args.N,
819 host_args.K,
820 host_args.stride_As,
821 host_args.stride_Bs,
822 host_args.stride_Ds,
823 host_args.stride_E,
824 host_args.k_batch},
825 host_args.reduction_strategy,
826 host_args.num_sk_blocks,
827 // The workspace pointer is set to nullptr because we must first
828 // instantiate the TilePartitioner to get the necessary size
829 /*workspace_ptr =*/nullptr,
830 TilePartitioner{static_cast<uint32_t>(host_args.M),
831 static_cast<uint32_t>(host_args.N),
832 static_cast<uint32_t>(host_args.K),
833 static_cast<uint32_t>(num_cu),
834 static_cast<uint32_t>(occupancy),
835 host_args.num_sk_blocks}};
836 }
837
838 template <bool UseDefaultScheduler = true>
839 CK_TILE_DEVICE static void
840 RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
841 const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
842 const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
843 CDataType* c_ptr,
844 void* smem_ptr_0,
845 const typename UniversalGemmKernel::KernelArgs& kargs,
846 const index_t num_loop,
847 const index_t block_idx_m,
848 const index_t block_idx_n,
849 const index_t k_size)
850 {
851 // Create Gemm tensor views, pad views and tile windows
852 const auto& gemm_tensor_views_tuple =
853 UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
854 as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
855
856 const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
857 auto gemm_tile_windows =
858 UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
859
860 // Run GEMM cooperatively by whole workgroup.
861 const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
862 const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
863 const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
864
865 // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
866 // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
867 // case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
868 // tail_num.
869 const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
870 const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
871
872 const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
873 bs_block_window[UniversalGemmKernel::I0],
874 num_loop,
875 has_hot_loop,
876 tail_num,
877 smem_ptr_0);
878
879 if(UseDefaultScheduler || (get_warp_id() == 0))
880 {
881 // Run Epilogue Pipeline
882 auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
883
884 EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
885 }
886 }
887
888 CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
889 {
890 if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
891 {
892 if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
893 {
894 CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
895 }
896 return false;
897 }
899 }
900
903 CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
904 {
905 // For reduction, we need to determine the amount of device space for acculumation
906 // results and semaphores.
907 if(kargs.reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
908 {
909 return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
910 }
911
912 // Otherwise, no additional space is needed since blocks atomically store their results.
913 return 0;
914 }
915
918 CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
919 {
920 kargs.workspace_ptr = workspace_ptr;
921 }
922
924 CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
925 {
926 // Allocate LDS
927 __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
928
930
931 bool is_padding_block =
932 amd_wave_read_first_lane(block_idx >= kargs.tile_partitioner.sk_num_blocks &&
933 block_idx < kargs.tile_partitioner.dp_start_block_idx);
934
935 // Padding blocks make it such that the DP blocks are aligned with the number of CUs; they
936 // should not partake in the GEMM
937 if(is_padding_block)
938 return;
939
940 // Determine the K offset of the first and final macro tile in the A and B tensors along the
941 // K dimension.
942 uint32_t iter_start, iter_end;
943 kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end);
944
945 // Main Stream-K loop
946 while(true)
947 {
948 // Determine the number of macro tiles in A and B this WG is responsible for in the
949 // current C macro tile.
950 uint32_t current_iter_length = amd_wave_read_first_lane(
951 kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end));
952
953 // Determine the 1D tile_idx and the iter_offset for this WG.
954 // The tile_idx is the 1D macro tile index in the C tensor.
955 // The iter_offset is the starting macro tile index in the K dimension for the WG in the
956 // current iteration of the while loop.
957 uint32_t tile_idx, iter_offset;
958 kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset);
959
960 // Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx)
961 auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx);
962
963 // Get the offsets in A, B, C tensors.
964 index_t i_m = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I0] *
965 TilePartitioner::MPerBlock);
966 index_t i_n = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I1] *
967 TilePartitioner::NPerBlock);
968 auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
969 static_cast<index_t>(iter_offset), kargs.stride_As[0], kargs.stride_Bs[0]);
970
971 // Determine the total size along the K dimension the WG is using in this iteration
972 // (used to construct tensor views).
973 index_t k_size = static_cast<index_t>(current_iter_length * TilePartitioner::KPerBlock);
974
975 // Update pointer offsets for A, B, and C.
976 const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
977 const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
978 CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
979
980 // Run the GEMM pipeline and Epilogue.
981 RunGemm({a_ptr},
982 {b_ptr},
983 {/*ds_ptr*/},
984 c_ptr,
985 smem_ptr_0,
986 kargs,
987 current_iter_length,
988 i_m,
989 i_n,
990 k_size);
991
992 // Prepare for next Stream-K loop iteration.
993 iter_start += current_iter_length;
994 if(iter_end <= iter_start)
995 break;
997 }
998 }
999
1000 private:
1007 template <typename ALayout, typename BLayout>
1009 GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
1010 {
1011 index_t stride_offset_a;
1012 index_t stride_offset_b;
1013 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
1014 {
1015 stride_offset_a = stride_a;
1016 }
1017 else
1018 {
1019 stride_offset_a = 1;
1020 }
1021
1022 if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1023 {
1024 stride_offset_b = stride_b;
1025 }
1026 else
1027 {
1028 stride_offset_b = 1;
1029 }
1030
1031 index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
1032
1033 return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
1034 }
1035
1036 CK_TILE_HOST static int NumCU()
1037 {
1038 hipDeviceProp_t dev_prop;
1039 hipDevice_t dev;
1040 hip_check_error(hipGetDevice(&dev));
1041 hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
1042 int num_cu = dev_prop.multiProcessorCount;
1043
1044 return num_cu;
1045 }
1046
1051 CK_TILE_HOST static int Occupancy()
1052 {
1053 int occupancy;
1054
1055 // Since occupancy of 1 is valid for stream k, we set min_block_per_cu to 1
1056 constexpr int min_block_per_cu = 1;
1058
1060 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
1061
1062 return occupancy;
1063 }
1064};
1065
1066} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition streamk_gemm_kernel.hpp:11
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
bool EnvIsEnabled(EnvVar)
Definition tile/core/utility/env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition tile/core/utility/env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
__global__ void kentry(Args... args)
Definition tile/host/kernel_launch.hpp:22
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE index_t get_block_1d_id()
Definition arch.hpp:98
StreamKReductionStrategy
Definition streamk_common.hpp:10
@ Atomic
Definition streamk_common.hpp:11
@ Reduction
Definition streamk_common.hpp:12
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition tile/host/hip_check_error.hpp:13
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
unsigned int uint32_t
Definition stdint.h:126
The Stream K GEMM kernel host arguments.
Definition streamk_gemm_kernel.hpp:678
uint32_t num_sk_blocks
Definition streamk_gemm_kernel.hpp:708
ck_tile::StreamKReductionStrategy reduction_strategy
Definition streamk_gemm_kernel.hpp:707
CK_TILE_HOST StreamKHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_, StreamKReductionStrategy reduction_strategy_, uint32_t num_sk_blocks_=0xffffffff)
Definition streamk_gemm_kernel.hpp:679
ALayout and ADataType are expected to be scalars, not a tuple.
Definition streamk_gemm_kernel.hpp:751
StreamKReductionStrategy reduction_strategy
The strategy used by work groups to compute final results in C tensor.
Definition streamk_gemm_kernel.hpp:753
uint32_t num_sk_blocks
The number of stream k blocks.
Definition streamk_gemm_kernel.hpp:755
void * workspace_ptr
A pointer to a buffer in device memory for accumulating partial results via the reduction strategy.
Definition streamk_gemm_kernel.hpp:758
TilePartitioner tile_partitioner
An instance of the TilePartitioner class for assisting with mapping workgroups to the C tensor.
Definition streamk_gemm_kernel.hpp:761
Definition streamk_gemm_kernel.hpp:713
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition streamk_gemm_kernel.hpp:716
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Specify the layout configurations for A, B, and C.
Definition streamk_gemm_kernel.hpp:726
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Specify the data type configurations for A, B, and C.
Definition streamk_gemm_kernel.hpp:731
static CK_TILE_HOST auto GridSize(const TilePartitioner &tile_partitioner) -> dim3
Compute the grid size for the Stream K kernel using the tile_partitioner.
Definition streamk_gemm_kernel.hpp:783
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition streamk_gemm_kernel.hpp:727
static CK_TILE_HOST StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs &host_args, int num_cu=NumCU(), int occupancy=Occupancy())
Constructs kernel arguments for the Stream-K kernel.
Definition streamk_gemm_kernel.hpp:809
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition streamk_gemm_kernel.hpp:733
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition streamk_gemm_kernel.hpp:721
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition streamk_gemm_kernel.hpp:723
static CK_TILE_HOST constexpr auto BlockSize() -> dim3
Definition streamk_gemm_kernel.hpp:797
StreamKKernel< TilePartitioner, GemmPipeline, EpiloguePipeline > Kernel
Definition streamk_gemm_kernel.hpp:765
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, UniversalGemmKernel::NumATensor > &as_ptr, const std::array< const BDataType *, UniversalGemmKernel::NumBTensor > &bs_ptr, const std::array< const void *, UniversalGemmKernel::NumDTensor > &ds_ptr, CDataType *c_ptr, void *smem_ptr_0, const typename UniversalGemmKernel::KernelArgs &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t k_size)
Definition streamk_gemm_kernel.hpp:840
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition streamk_gemm_kernel.hpp:792
static CK_TILE_HOST void SetWorkSpacePointer(StreamKKernelArgs &kargs, void *workspace_ptr)
Sets the kargs' current workspace_ptr to the given workspace_ptr.
Definition streamk_gemm_kernel.hpp:918
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition streamk_gemm_kernel.hpp:732
static constexpr index_t kBlockSize
Definition streamk_gemm_kernel.hpp:719
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition streamk_gemm_kernel.hpp:722
static CK_TILE_HOST const std::string GetName()
Definition streamk_gemm_kernel.hpp:767
static CK_TILE_HOST bool IsSupportedArgument(const StreamKKernelArgs &kargs)
Definition streamk_gemm_kernel.hpp:888
CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
Entry point for the Stream-K Kernel, performing the main Stream-K loop.
Definition streamk_gemm_kernel.hpp:924
static CK_TILE_HOST uint32_t GetWorkSpaceSize(const StreamKKernelArgs &kargs)
Computes the buffer size needed to store accumulation results for Stream K.
Definition streamk_gemm_kernel.hpp:903
StreamKKernelArgs KernelArgs
Definition streamk_gemm_kernel.hpp:764
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition streamk_gemm_kernel.hpp:728
The Universal GEMM kernel host arguments.
Definition universal_gemm_kernel.hpp:32
const std::array< index_t, NumDTensor > stride_Ds
Definition universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition universal_gemm_kernel.hpp:72
CK_TILE_HOST UniversalGemmHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition universal_gemm_kernel.hpp:33
index_t K
Definition universal_gemm_kernel.hpp:70
void * e_ptr
Definition universal_gemm_kernel.hpp:65
index_t M
Definition universal_gemm_kernel.hpp:68
const std::array< const void *, NumDTensor > ds_ptr
Definition universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition universal_gemm_kernel.hpp:71
index_t N
Definition universal_gemm_kernel.hpp:69
index_t stride_E
Definition universal_gemm_kernel.hpp:76
const std::array< const void *, NumBTensor > bs_ptr
Definition universal_gemm_kernel.hpp:61
index_t k_batch
Definition universal_gemm_kernel.hpp:80
The GEMM kernel device arguments.
Definition universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:94
std::array< index_t, NumBTensor > stride_Bs
The distance between consecutive elements of non-contiguous dimension (in memory) of Bs tensor.
Definition universal_gemm_kernel.hpp:106
const std::array< const void *, NumDTensor > ds_ptr
The Ds input tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:92
std::array< index_t, NumATensor > stride_As
The distance between consecutive elements of non-contiguous dimension (in memory) of As tensor.
Definition universal_gemm_kernel.hpp:103
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:88
index_t k_batch
Definition universal_gemm_kernel.hpp:113
index_t N
GEMM's N dimension size.
Definition universal_gemm_kernel.hpp:98
index_t stride_E
The distance between consecutive elements of non-contiguous dimension (in memory) of E tensor.
Definition universal_gemm_kernel.hpp:112
index_t K
GEMM's K dimension size.
Definition universal_gemm_kernel.hpp:100
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition universal_gemm_kernel.hpp:90
std::array< index_t, NumDTensor > stride_Ds
The distance between consecutive elements of non-contiguous dimension (in memory) of Ds tensor.
Definition universal_gemm_kernel.hpp:109
index_t M
GEMM's M dimension size.
Definition universal_gemm_kernel.hpp:96
static constexpr auto I2
Definition universal_gemm_kernel.hpp:238
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition universal_gemm_kernel.hpp:853
static constexpr auto I3
Definition universal_gemm_kernel.hpp:239
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition universal_gemm_kernel.hpp:754
static constexpr bool PersistentKernel
Definition universal_gemm_kernel.hpp:217
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition universal_gemm_kernel.hpp:319
static constexpr auto I1
Definition universal_gemm_kernel.hpp:237
static CK_TILE_HOST auto BlockSize()
Definition universal_gemm_kernel.hpp:290
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition universal_gemm_kernel.hpp:278
static constexpr auto I0
Definition universal_gemm_kernel.hpp:236
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition universal_gemm_kernel.hpp:373
static constexpr index_t kBlockSize
Definition universal_gemm_kernel.hpp:202
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition universal_gemm_kernel.hpp:257
The Stream K GEMM kernel host arguments.
Definition streamk_gemm_kernel.hpp:21
ck_tile::StreamKReductionStrategy reduction_strategy
Definition streamk_gemm_kernel.hpp:48
CK_TILE_HOST StreamKHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_, StreamKReductionStrategy reduction_strategy_)
Definition streamk_gemm_kernel.hpp:22
ALayout and ADataType are expected to be scalars, not a tuple.
Definition streamk_gemm_kernel.hpp:108
TilePartitioner tile_partitioner
An instance of the TilePartitioner class for assisting with mapping workgroups to the C tensor.
Definition streamk_gemm_kernel.hpp:138
StreamKReductionStrategy reduction_strategy
The strategy used by work groups to compute final results in C tensor.
Definition streamk_gemm_kernel.hpp:132
void * workspace_ptr
A pointer to a buffer in device memory for accumulating partial results via the reduction strategy.
Definition streamk_gemm_kernel.hpp:135
StreamKKernelArgs(const StreamKHostArgs &host_args, index_t grid)
Definition streamk_gemm_kernel.hpp:109
The Stream K GEMM kernel class.
Definition streamk_gemm_kernel.hpp:64
typename GemmPipeline::ALayout ALayout
Specify the layout configurations for A, B, and C.
Definition streamk_gemm_kernel.hpp:82
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs &kargs, index_t cta_idx) const
Waits for the thread block (cta_idx) to complete storing its partial results.
Definition streamk_gemm_kernel.hpp:314
static CK_TILE_HOST auto GridSize(const TilePartitioner &tile_partitioner) -> dim3
Compute the grid size for the Stream K kernel using the tile_partitioner.
Definition streamk_gemm_kernel.hpp:160
static constexpr bool is_tuple_v
Definition streamk_gemm_kernel.hpp:93
CK_TILE_DEVICE void AddBlockTile(OAccTile &in_out_block_tile, const OAccTile &in_block_tile) const
Adds the values of a block tile to an output block tile.
Definition streamk_gemm_kernel.hpp:327
static constexpr bool PersistentDP
Definition streamk_gemm_kernel.hpp:71
EpiloguePipeline_ EpiloguePipeline
Definition streamk_gemm_kernel.hpp:75
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition streamk_gemm_kernel.hpp:169
typename GemmPipeline::BDataType BDataType
Definition streamk_gemm_kernel.hpp:88
CK_TILE_DEVICE std::enable_if_t< U > operator()(StreamKKernelArgs kargs) const
Entry point for the Stream-K Kernel with persistent DP.
Definition streamk_gemm_kernel.hpp:584
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs &kargs, index_t cta_idx, const OAccTileDist &c_block_tile_dist) const
Loads a partial block tile from the workspace buffer.
Definition streamk_gemm_kernel.hpp:348
StreamKKernelArgs KernelArgs
Definition streamk_gemm_kernel.hpp:141
CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs &kargs, index_t tile_idx, index_t num_loop, index_t i_k_a, index_t i_k_b, index_t k_size, void *smem_ptr_0) const
Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
Definition streamk_gemm_kernel.hpp:274
static CK_TILE_HOST const std::string GetName()
Definition streamk_gemm_kernel.hpp:144
typename EpiloguePipeline::AccDataType AccDataType
Definition streamk_gemm_kernel.hpp:90
static CK_TILE_HOST bool IsSupportedArgument(const StreamKKernelArgs &kargs)
Definition streamk_gemm_kernel.hpp:245
UniversalGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > UniversalGemmKernel
Inject the UniversalGemmKernel base class to support execution of all necessary functions.
Definition streamk_gemm_kernel.hpp:67
static CK_TILE_HOST constexpr auto BlockSize() -> dim3
Definition streamk_gemm_kernel.hpp:174
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs &kargs, index_t cta_idx, const OAccTile &c_block_tile) const
Stores a partial block tile to the workspace buffer.
Definition streamk_gemm_kernel.hpp:381
static constexpr index_t kBlockSize
Definition streamk_gemm_kernel.hpp:70
static CK_TILE_HOST uint32_t GetWorkSpaceSize(const StreamKKernelArgs &kargs)
Computes the buffer size needed to store accumulation results for Stream K.
Definition streamk_gemm_kernel.hpp:252
static CK_TILE_HOST StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs &host_args, int num_cu=NumCU(), int occupancy=Occupancy())
Constructs kernel arguments for the Stream-K kernel.
Definition streamk_gemm_kernel.hpp:186
CK_TILE_DEVICE void StreamKGemm(StreamKKernelArgs &kargs, index_t cta_idx, void *smem_ptr_0) const
Runs the main Stream-K algorithm.
Definition streamk_gemm_kernel.hpp:415
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, UniversalGemmKernel::NumATensor > &as_ptr, const std::array< const BDataType *, UniversalGemmKernel::NumBTensor > &bs_ptr, const std::array< const void *, UniversalGemmKernel::NumDTensor > &ds_ptr, CDataType *c_ptr, void *smem_ptr_0, const typename UniversalGemmKernel::KernelArgs &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t k_size)
Definition streamk_gemm_kernel.hpp:197
static CK_TILE_HOST void SetWorkSpacePointer(StreamKKernelArgs &kargs, void *workspace_ptr)
Sets the kargs' current workspace_ptr to the given workspace_ptr.
Definition streamk_gemm_kernel.hpp:259
typename GemmPipeline::ADataType ADataType
Specify the data type configurations for A, B, and C.
Definition streamk_gemm_kernel.hpp:87
typename GemmPipeline::BLayout BLayout
Definition streamk_gemm_kernel.hpp:83
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs &kargs, index_t cta_idx) const
Signals that the current thread block (CTA) has completed storing its partial results.
Definition streamk_gemm_kernel.hpp:301
CK_TILE_DEVICE std::enable_if_t<!U > operator()(StreamKKernelArgs kargs) const
Entry point for the Stream-K Kernel with non-persistent DP.
Definition streamk_gemm_kernel.hpp:553
typename EpiloguePipeline::ODataType CDataType
Definition streamk_gemm_kernel.hpp:89
GemmPipeline_ GemmPipeline
Definition streamk_gemm_kernel.hpp:74
TilePartitioner_ TilePartitioner
Definition streamk_gemm_kernel.hpp:73
StreamKKernel< TilePartitioner, GemmPipeline, EpiloguePipeline > Kernel
Definition streamk_gemm_kernel.hpp:142
typename GemmPipeline::CLayout CLayout
Definition streamk_gemm_kernel.hpp:84
Definition ck_tile/host/stream_config.hpp:30
Definition tile/core/container/tuple.hpp:192
Definition tile/core/arch/workgroup_barrier.hpp:12
CK_TILE_DEVICE void wait_eq(uint32_t value, uint32_t offset=0)
Definition tile/core/arch/workgroup_barrier.hpp:20
CK_TILE_DEVICE void wait_set(uint32_t compare, uint32_t value, uint32_t offset=0)
Definition tile/core/arch/workgroup_barrier.hpp:38
#define CK_TILE_ENV(name)
Definition tile/core/utility/env.hpp:145