StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM > Struct Template Reference
Stream-K tile partitioner that dynamically balances work across workgroups.
#include <gemm_tile_partitioner.hpp>
Public Types | |
| using | BlockGemmShape = BlockGemmShapeType |
Public Member Functions | |
| CK_TILE_HOST_DEVICE | StreamKTilePartitioner () noexcept=delete |
| CK_TILE_HOST_DEVICE | StreamKTilePartitioner (uint32_t M, uint32_t N, uint32_t K, uint32_t num_cu, uint32_t occupancy, uint32_t sk_blocks=0xffffffff) noexcept |
| Construct Stream-K tile partitioner with problem dimensions. | |
| CK_TILE_HOST auto | GridSize () const noexcept -> dim3 |
| Calculate optimal grid size for Stream-K. | |
| CK_TILE_DEVICE auto | GetOutputTileIndex (uint32_t tile_idx) const noexcept -> tuple< uint32_t, uint32_t > |
| Get output tile index for standard 2D mapping (compatibility). | |
| CK_TILE_DEVICE void | GetBlockItr (uint32_t block_idx, uint32_t &iter_start, uint32_t &iter_end) const noexcept |
| Get work range for a given block ID. | |
| CK_TILE_HOST_DEVICE uint32_t | GetSkTotalIters () const noexcept |
| Get total number of iterations for sk tiles. | |
| CK_TILE_HOST_DEVICE uint32_t | GetSkTiles () const noexcept |
| Get total number of sk tiles. | |
| CK_TILE_DEVICE uint32_t | GetCurrentIterLength (uint32_t iter_start, uint32_t iter_end) const noexcept |
| Get length of loop iterations for stream-k loop. | |
| CK_TILE_DEVICE uint32_t | GetTileIdx (uint32_t iter) const noexcept |
| Get index of tile during a specified iteration. | |
| CK_TILE_DEVICE void | GetTileIdxWithOffset (uint32_t iter, uint32_t &tile_idx, uint32_t &iter_offset) const noexcept |
| Get the index of the tile, and the iteration offset within that tile, for a specified iteration. | |
| CK_TILE_HOST_DEVICE uint32_t | GetWorkSpaceSizeForAcc (uint32_t acc_element_bytes) const noexcept |
| Calculates the buffer space needed for accumulation. | |
| CK_TILE_HOST_DEVICE uint32_t | GetWorkSpaceSizeForSemaphore () const noexcept |
| Calculates the buffer space needed for the semaphore. | |
| CK_TILE_HOST_DEVICE uint32_t | GetWorkSpaceSize (uint32_t acc_element_bytes) const noexcept |
| Calculates the total buffer space needed for accumulation and the semaphore. | |
| CK_TILE_HOST_DEVICE uint32_t | GetTileIntersections (uint32_t tiles_, const mdiv &equiv_tiles_) const noexcept |
| Get location of intersection of tiles for reduction. | |
| CK_TILE_HOST_DEVICE uint32_t | GetTilesCoverSkBlock (uint32_t num_sk_blocks_, uint32_t iters_per_sk_block_) const noexcept |
| Calculate the number of tiles needed for the number of sk blocks. | |
| CK_TILE_HOST_DEVICE uint32_t | GetTotalAccBuffers () const noexcept |
| Calculate the total number of accumulation buffers required for Stream-K. | |
| CK_TILE_DEVICE uint32_t | GetAccBufferOffsetFromTile (uint32_t tile_idx_) const noexcept |
| Calculate offset based on tile index for big/little tiles. | |
| CK_TILE_DEVICE uint32_t | GetAccBufferOffsetFromBlock (uint32_t block_idx_) const noexcept |
| Calculate the offset based on the block index for big/little Stream-K blocks. | |
| CK_TILE_HOST_DEVICE uint32_t | GetNumTileM () const noexcept |
| CK_TILE_HOST_DEVICE uint32_t | GetNumTileN () const noexcept |
| CK_TILE_HOST_DEVICE uint32_t | GetNumTileK () const noexcept |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE auto | GetLoopNum (uint32_t K) noexcept -> uint32_t |
| Calculate number of loop iterations over K dimension for given work unit. | |
Static Public Attributes | |
| static constexpr uint32_t | MPerBlock = BlockGemmShape::kM |
| static constexpr uint32_t | NPerBlock = BlockGemmShape::kN |
| static constexpr uint32_t | KPerBlock = BlockGemmShape::kK |
Detailed Description
struct ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >
Stream-K tile partitioner that dynamically balances work across workgroups.
This partitioner maps workgroups to tiles in the C tensor for the Stream-K algorithm. Stream-K decomposes the GEMM problem into smaller work units and distributes them more evenly across the available blocks, which improves load balancing, especially when the K dimension is large.
- Template Parameters
  - BlockGemmShapeType: A class providing basic GEMM parameters.
  - ReductionStrategy: A class that defines the reduction strategy for the results in the C tensor.
  - TileSwizzleSubM: A value that defines the size of the swizzle group along the M dimension, where the swizzle group denotes consecutive tiles down a column. For instance, a swizzle group of 8 means that tiles 0, 1, ..., 7 map to tiles [0,0], [1,0], ..., [7,0] in the C tensor.
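The swizzle-group mapping described for TileSwizzleSubM can be illustrated with a small host-side sketch. This is not the ck_tile implementation: `swizzled_tile_coord` and its parameters are hypothetical, and the behavior past the first column of a group (continuing with the next column of the same group) is an assumption that the reference text does not confirm.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical helper, not part of ck_tile: maps a linear tile index to
// (m, n) tile coordinates using swizzle groups that are `sub_m` tiles tall.
inline std::pair<uint32_t, uint32_t> swizzled_tile_coord(uint32_t tile_idx,
                                                         uint32_t m_tiles,
                                                         uint32_t n_tiles,
                                                         uint32_t sub_m)
{
    assert(sub_m > 0 && tile_idx < m_tiles * n_tiles);
    const uint32_t tiles_per_group = sub_m * n_tiles;
    const uint32_t group           = tile_idx / tiles_per_group;
    const uint32_t in_group        = tile_idx % tiles_per_group;
    const uint32_t first_m         = group * sub_m;
    // The last group may hold fewer than sub_m rows.
    const uint32_t rows = std::min(sub_m, m_tiles - first_m);
    // Walk down a column of the group first, then move to the next column
    // (assumed ordering).
    return {first_m + in_group % rows, in_group / rows};
}
```

With `sub_m = 8`, indices 0 through 7 land on [0,0] through [7,0], which matches the example in the parameter description.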
Member Typedef Documentation
◆ BlockGemmShape
| using ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >::BlockGemmShape = BlockGemmShapeType |
Constructor & Destructor Documentation
◆ StreamKTilePartitioner() [1/2]
|
delete noexcept |
◆ StreamKTilePartitioner() [2/2]
|
inline noexcept |
Construct Stream-K tile partitioner with problem dimensions.
Member Function Documentation
◆ GetAccBufferOffsetFromBlock()
|
inline noexcept |
Calculate the offset based on the block index for big/little Stream-K blocks.
◆ GetAccBufferOffsetFromTile()
|
inline noexcept |
Calculate offset based on tile index for big/little tiles.
◆ GetBlockItr()
|
inline noexcept |
Get work range for a given block ID.
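One plausible way to derive such a work range is to split the total Stream-K iterations so that some blocks ("big") take one more iteration than the rest ("little"). The sketch below assumes that scheme. `SkSplit`, `make_split`, and `block_iter_range` are hypothetical names, the field names only mirror the documented data members, and the actual arithmetic in ck_tile may differ.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for part of the partitioner state.
struct SkSplit
{
    uint32_t sk_num_blocks;         // number of Stream-K blocks
    uint32_t sk_num_big_blocks;     // blocks that take one extra iteration
    uint32_t k_iters_per_big_block; // iterations handled by a big block
};

// Split total_iters across num_blocks as evenly as possible.
inline SkSplit make_split(uint32_t total_iters, uint32_t num_blocks)
{
    assert(num_blocks > 0);
    const uint32_t base = total_iters / num_blocks;
    const uint32_t rem  = total_iters % num_blocks;
    // The first `rem` blocks are big and run base + 1 iterations each.
    return {num_blocks, rem, base + 1};
}

// Half-open range [iter_start, iter_end) owned by block_idx.
inline void block_iter_range(const SkSplit& s,
                             uint32_t block_idx,
                             uint32_t& iter_start,
                             uint32_t& iter_end)
{
    assert(block_idx < s.sk_num_blocks);
    const uint32_t little = s.k_iters_per_big_block - 1;
    if(block_idx < s.sk_num_big_blocks)
    {
        iter_start = block_idx * s.k_iters_per_big_block;
        iter_end   = iter_start + s.k_iters_per_big_block;
    }
    else
    {
        // Skip past all big blocks, then index into the little blocks.
        iter_start = s.sk_num_big_blocks * s.k_iters_per_big_block +
                     (block_idx - s.sk_num_big_blocks) * little;
        iter_end = iter_start + little;
    }
}
```

For 10 iterations over 4 blocks this yields the ranges [0,3), [3,6), [6,8), and [8,10), so no block differs from another by more than one iteration.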
◆ GetCurrentIterLength()
|
inline noexcept |
Get length of loop iterations for stream-k loop.
◆ GetLoopNum()
|
inline static noexcept |
Calculate number of loop iterations over K dimension for given work unit.
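A loop count of this kind is typically a ceiling division of K by the per-block K extent. The sketch below assumes that reading. `loop_num` is a hypothetical free function and `k_per_block` stands in for KPerBlock.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not the ck_tile member: number of K iterations
// needed to cover K with steps of k_per_block.
inline uint32_t loop_num(uint32_t K, uint32_t k_per_block)
{
    assert(k_per_block > 0);
    // Round up so that a partial trailing step still counts as one iteration.
    return (K + k_per_block - 1) / k_per_block;
}
```

Under this assumption, K = 256 with a step of 64 needs 4 iterations, and K = 257 needs 5.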
◆ GetNumTileK()
|
inline noexcept |
◆ GetNumTileM()
|
inline noexcept |
◆ GetNumTileN()
|
inline noexcept |
◆ GetOutputTileIndex()
|
inline noexcept |
Get output tile index for standard 2D mapping (compatibility).
◆ GetSkTiles()
|
inline noexcept |
Get total number of sk tiles.
◆ GetSkTotalIters()
|
inline noexcept |
Get total number of iterations for sk tiles.
◆ GetTileIdx()
|
inline noexcept |
Get index of tile during a specified iteration.
◆ GetTileIdxWithOffset()
|
inline noexcept |
Get the index of the tile, and the iteration offset within that tile, for a specified iteration.
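If every tile spans the same number of K iterations, a global iteration number decomposes into a tile index and an offset inside that tile by division and remainder. The sketch below assumes that layout. `tile_idx_with_offset` is a hypothetical function, and plain integer division is used in place of the `mdiv` helper that the struct stores in `k_iters_per_tile`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: split a global iteration number into the tile it
// belongs to and the iteration offset inside that tile.
inline void tile_idx_with_offset(uint32_t iter,
                                 uint32_t k_iters_per_tile,
                                 uint32_t& tile_idx,
                                 uint32_t& iter_offset)
{
    assert(k_iters_per_tile > 0);
    tile_idx    = iter / k_iters_per_tile; // which output tile
    iter_offset = iter % k_iters_per_tile; // how far into its K loop
}
```

For example, with 4 iterations per tile, iteration 10 falls in tile 2 at offset 2.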
◆ GetTileIntersections()
|
inline noexcept |
Get location of intersection of tiles for reduction.
◆ GetTilesCoverSkBlock()
|
inline noexcept |
Calculate the number of tiles needed for the number of sk blocks.
◆ GetTotalAccBuffers()
|
inline noexcept |
Calculate the total number of accumulation buffers required for Stream-K.
◆ GetWorkSpaceSize()
|
inline noexcept |
Calculates the total buffer space needed for accumulation and the semaphore.
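The relationship between the three workspace queries can be sketched as a sum of two parts. Everything below is an assumption made for illustration: that each accumulation buffer holds one MPerBlock x NPerBlock tile, that one `uint32_t` semaphore is kept per Stream-K tile, and that no alignment padding is added. The function names are hypothetical and the reference text does not specify the actual formulas.

```cpp
#include <cassert>
#include <cstdint>

// Assumed: one partial-result tile per accumulation buffer.
inline uint32_t workspace_for_acc(uint32_t m_per_block,
                                  uint32_t n_per_block,
                                  uint32_t total_acc_buffers,
                                  uint32_t acc_element_bytes)
{
    assert(acc_element_bytes > 0);
    return m_per_block * n_per_block * total_acc_buffers * acc_element_bytes;
}

// Assumed: one 32-bit flag per Stream-K tile.
inline uint32_t workspace_for_semaphore(uint32_t sk_tiles)
{
    return sk_tiles * static_cast<uint32_t>(sizeof(uint32_t));
}

// Total workspace is the accumulation space plus the semaphore space.
inline uint32_t workspace_size(uint32_t m_per_block,
                               uint32_t n_per_block,
                               uint32_t total_acc_buffers,
                               uint32_t acc_element_bytes,
                               uint32_t sk_tiles)
{
    return workspace_for_acc(m_per_block, n_per_block, total_acc_buffers, acc_element_bytes) +
           workspace_for_semaphore(sk_tiles);
}
```

A caller would allocate a single device buffer of the total size and place the semaphore region after the accumulation region, or the other way around, depending on the layout the kernel expects.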
◆ GetWorkSpaceSizeForAcc()
|
inline noexcept |
Calculates the buffer space needed for accumulation.
◆ GetWorkSpaceSizeForSemaphore()
|
inline noexcept |
Calculates the buffer space needed for the semaphore.
◆ GridSize()
|
inline noexcept |
Calculate optimal grid size for Stream-K.
Member Data Documentation
◆ dp_start_block_idx
| uint32_t ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >::dp_start_block_idx |
◆ equiv_tiles_big
| mdiv ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >::equiv_tiles_big |
◆ equiv_tiles_little
| mdiv ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >::equiv_tiles_little |
◆ k_iters_per_big_block
| uint32_t ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >::k_iters_per_big_block |
◆ k_iters_per_tile
| mdiv ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >::k_iters_per_tile |
◆ KPerBlock
|
static constexpr |
◆ MPerBlock
|
static constexpr |
◆ n_tiles
| mdiv2 ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >::n_tiles |
◆ NPerBlock
|
static constexpr |
◆ reduction_start_block_idx
| uint32_t ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >::reduction_start_block_idx |
◆ sk_num_big_blocks
| uint32_t ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >::sk_num_big_blocks |
◆ sk_num_blocks
| uint32_t ck_tile::StreamKTilePartitioner< BlockGemmShapeType, ReductionStrategy, TileSwizzleSubM >::sk_num_blocks |
The documentation for this struct was generated from the following file: