streamk_gemm_tile_partitioner.hpp Source File

streamk_gemm_tile_partitioner.hpp Source File

Composable Kernel: streamk_gemm_tile_partitioner.hpp Source File
streamk_gemm_tile_partitioner.hpp
Go to the documentation of this file.
1// Copyright © Advanced Micro Devices, Inc., or its affiliates.
2// SPDX-License-Identifier: MIT
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
21template <typename BlockGemmShapeType,
24{
25 using BlockGemmShape = BlockGemmShapeType;
26
27 static constexpr index_t MPerBlock = BlockGemmShape::kM;
28 static constexpr index_t NPerBlock = BlockGemmShape::kN;
29 static constexpr index_t KPerBlock = BlockGemmShape::kK;
30 static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
31
33
40 CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;
41
48
49 public:
63 get_iter_boundaries(index_t& iter_start, index_t& iter_end, index_t cta_idx) const noexcept;
64
71 CK_TILE_DEVICE index_t get_tile_index(index_t iter_start) const noexcept;
72
82 CK_TILE_DEVICE void get_tile_boundaries(index_t& tile_iter_start,
83 index_t& tile_iter_end,
84 index_t tile_idx) const noexcept;
85
96 index_t tile_iter_start) noexcept;
97
108 get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
109
116 CK_TILE_DEVICE auto
117 get_output_tile_index(index_t tile_idx) const noexcept -> tuple<index_t, index_t>;
118
125 CK_TILE_HOST_DEVICE index_t get_workspace_size(index_t acc_element_bytes) const noexcept;
126
131
136 CK_TILE_HOST_DEVICE index_t get_grid() const noexcept;
137
143
148
152 CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept;
153
158
164
170
177
182
186 CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
187
192
193 protected:
197
198 private:
203 index_t full_tiles_ = 1;
204 index_t sk_tiles_;
205 index_t sk_ctas_;
206 index_t total_sk_iters_;
207 index_t iters_per_tile_;
208 index_t iters_per_sk_cta_;
209 index_t extra_iters_;
210 index_t total_dp_iters_;
211 index_t n_;
212};
213
227template <typename BlockGemmShapeType,
228 StreamKReductionStrategy ReductionStrategyType,
229 bool Persistent>
231
243template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
244struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>
245 : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>
246{
250 ck_tile::index_t grid);
251
252 public:
253 static constexpr bool PERSISTENT = true;
261 CK_TILE_HOST auto grid_size() const noexcept -> dim3;
262
267
273
274 protected:
277};
278
290template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
291struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>
292 : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>
293{
297 ck_tile::index_t grid);
298
299 public:
300 static constexpr bool PERSISTENT = false;
308 CK_TILE_HOST auto grid_size() const noexcept -> dim3;
309
313 CK_TILE_HOST_DEVICE index_t get_dp_ctas() const noexcept;
314
319
324
325 protected:
329};
330
331} // namespace ck_tile
332
333#include "streamk_gemm_tile_partitioner_impl.hpp"
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
StreamKReductionStrategy
Definition streamk_common.hpp:10
@ Atomic
Definition streamk_common.hpp:11
int32_t index_t
Definition integer.hpp:9
index_t sk_start_block_idx_
Definition streamk_gemm_tile_partitioner.hpp:328
index_t dp_ctas_
Definition streamk_gemm_tile_partitioner.hpp:326
CK_TILE_HOST auto grid_size() const noexcept -> dim3
Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent case,...
Definition streamk_gemm_tile_partitioner_impl.hpp:303
CK_TILE_HOST_DEVICE index_t get_dp_start_block_idx() const noexcept
Returns the starting DP workgroup index, which is always zero.
Definition streamk_gemm_tile_partitioner_impl.hpp:320
static constexpr bool PERSISTENT
Definition streamk_gemm_tile_partitioner.hpp:300
StreamKTilePartitioner_v2(ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
Definition streamk_gemm_tile_partitioner_impl.hpp:290
index_t dp_start_block_idx_
Definition streamk_gemm_tile_partitioner.hpp:327
CK_TILE_HOST_DEVICE index_t get_sk_start_block_idx() const noexcept
Returns the index of the first Stream-K workgroup. It is set to the number of dp_tiles_.
Definition streamk_gemm_tile_partitioner_impl.hpp:328
CK_TILE_HOST_DEVICE index_t get_dp_ctas() const noexcept
Returns the total number of DP workgroups.
Definition streamk_gemm_tile_partitioner_impl.hpp:311
CK_TILE_HOST auto grid_size() const noexcept -> dim3
Calculates the launching grid size for the Stream-K kernel. In the Persistent case,...
Definition streamk_gemm_tile_partitioner_impl.hpp:258
StreamKTilePartitioner_v2(ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
Definition streamk_gemm_tile_partitioner_impl.hpp:246
static constexpr bool PERSISTENT
Definition streamk_gemm_tile_partitioner.hpp:253
index_t dp_tiles_per_cta_
Definition streamk_gemm_tile_partitioner.hpp:275
CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept
Returns the total number of DP tiles left over when dp_tiles_ is not evenly divisible by grid_.
Definition streamk_gemm_tile_partitioner_impl.hpp:281
index_t extra_dp_tiles_
Definition streamk_gemm_tile_partitioner.hpp:276
CK_TILE_HOST_DEVICE index_t get_dp_tiles_per_cta() const noexcept
Returns the total number of DP tiles per workgroup.
Definition streamk_gemm_tile_partitioner_impl.hpp:273
Primary template for the derived Stream-K tile partitioner struct (specialized for the persistent and non-persistent cases).
Definition streamk_gemm_tile_partitioner.hpp:230
CK_TILE_HOST_DEVICE index_t get_sk_tiles() const noexcept
Returns the number of tiles in the C tensor that will use the Stream-K approach.
Definition streamk_gemm_tile_partitioner_impl.hpp:158
CK_TILE_HOST_DEVICE index_t get_total_dp_iters() const noexcept
Returns the total number of DP iterations.
Definition streamk_gemm_tile_partitioner_impl.hpp:204
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept
Calculates the total space needed for the flags buffer.
Definition streamk_gemm_tile_partitioner_impl.hpp:57
CK_TILE_HOST_DEVICE index_t get_num_tiles() const noexcept
Returns the number of macro tiles in the C tensor.
Definition streamk_gemm_tile_partitioner_impl.hpp:136
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept
Calculates the total space needed for the partials buffer.
Definition streamk_gemm_tile_partitioner_impl.hpp:49
CK_TILE_DEVICE void get_iter_boundaries(index_t &iter_start, index_t &iter_end, index_t cta_idx) const noexcept
Calculates the start and end iterations for the given cta_idx.
Definition streamk_gemm_tile_partitioner_impl.hpp:65
CK_TILE_DEVICE void get_tile_boundaries(index_t &tile_iter_start, index_t &tile_iter_end, index_t tile_idx) const noexcept
Calculates the starting and ending tile boundaries for the given 1D tile index.
Definition streamk_gemm_tile_partitioner_impl.hpp:83
CK_TILE_DEVICE auto get_output_tile_index(index_t tile_idx) const noexcept -> tuple< index_t, index_t >
Calculates the workgroup's 2D tile index in the C tensor from the given 1D tile index.
Definition streamk_gemm_tile_partitioner_impl.hpp:108
index_t grid_
Definition streamk_gemm_tile_partitioner.hpp:195
CK_TILE_HOST_DEVICE index_t get_extra_iters() const noexcept
Returns the remainder resulting from total_sk_iters_ divided by sk_ctas_. When this is non-zero,...
Definition streamk_gemm_tile_partitioner_impl.hpp:196
static constexpr index_t KPerBlock
Definition streamk_gemm_tile_partitioner.hpp:29
static CK_TILE_DEVICE index_t get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept
Calculates the workgroup's non-inclusive end iteration that is local to a tile.
Definition streamk_gemm_tile_partitioner_impl.hpp:100
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept
Returns the maximum number of active workgroups; this is assumed to be the number of CUs * occupancy.
Definition streamk_gemm_tile_partitioner_impl.hpp:144
static constexpr StreamKReductionStrategy ReductionStrategy
Definition streamk_gemm_tile_partitioner.hpp:30
CK_TILE_DEVICE index_t get_tile_index(index_t iter_start) const noexcept
Calculates the 1D tile index in the C tensor for a workgroup.
Definition streamk_gemm_tile_partitioner_impl.hpp:75
CK_TILE_HOST_DEVICE index_t get_total_sk_iters() const noexcept
Returns the total number of Stream-K iterations.
Definition streamk_gemm_tile_partitioner_impl.hpp:172
CK_TILE_HOST_DEVICE index_t get_dp_tiles() const noexcept
Returns the number of tiles in the C tensor that will use the data-parallel (DP) approach.
Definition streamk_gemm_tile_partitioner_impl.hpp:151
BlockGemmShapeType BlockGemmShape
Definition streamk_gemm_tile_partitioner.hpp:25
CK_TILE_HOST_DEVICE index_t get_iters_per_tile() const noexcept
Returns the total number of iterations per tile in the C tensor. In other words, this is the total nu...
Definition streamk_gemm_tile_partitioner_impl.hpp:180
CK_TILE_HOST_DEVICE index_t get_n() const noexcept
Returns the n dimension for the GEMM problem.
Definition streamk_gemm_tile_partitioner_impl.hpp:212
static constexpr index_t NPerBlock
Definition streamk_gemm_tile_partitioner.hpp:28
static constexpr index_t MPerBlock
Definition streamk_gemm_tile_partitioner.hpp:27
index_t num_tiles_
Definition streamk_gemm_tile_partitioner.hpp:194
CK_TILE_HOST_DEVICE index_t get_workspace_size(index_t acc_element_bytes) const noexcept
Calculates the total space needed for the partials and flags buffers.
Definition streamk_gemm_tile_partitioner_impl.hpp:120
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid)
Definition streamk_gemm_tile_partitioner_impl.hpp:8
CK_TILE_HOST_DEVICE index_t get_iters_per_sk_cta() const noexcept
Returns the total number of Stream-K iterations for each sk_cta. This is the lower bound (i....
Definition streamk_gemm_tile_partitioner_impl.hpp:188
CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept
Returns the number of workgroups that will participate in Stream-K processing of the sk_tiles_.
Definition streamk_gemm_tile_partitioner_impl.hpp:165
static CK_TILE_DEVICE index_t get_local_iter(index_t iter_start, index_t tile_iter_start) noexcept
Calculates the workgroup's starting iteration that is local to a tile.
Definition streamk_gemm_tile_partitioner_impl.hpp:92
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept
Returns an estimate of the number of workgroups writing to the same macro tile in C.
Definition streamk_gemm_tile_partitioner_impl.hpp:219
index_t dp_tiles_
Definition streamk_gemm_tile_partitioner.hpp:196
Definition tile/core/container/tuple.hpp:192