topk_softmax_warp_per_row_policy.hpp Source File

topk_softmax_warp_per_row_policy.hpp Source File

Composable Kernel: topk_softmax_warp_per_row_policy.hpp Source File
topk_softmax_warp_per_row_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
13{
14 template <typename Problem>
16 {
17 // TODO: Y dim must have one dim that is not reduced
21 tuple<sequence<Problem::IssuesPerCol,
22 Problem::WarpsPerBlock,
23 Problem::RowsPerWarpPerColIssue>,
29 }
30
31 template <typename Problem>
33 {
36 tuple<sequence<Problem::IssuesPerCol,
37 Problem::WarpsPerBlock,
38 Problem::RowsPerWarpPerColIssue>,
39 sequence<1>>, // each row write out single element
44 }
45
46 template <typename Problem>
47 CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
48 {
51 }
52
53 template <typename Problem>
54 CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
55 {
56 using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
57 typename Problem::IndexType,
58 Problem::LanesPerRow>;
59 // Note: replicate is LanesPerRow
61 }
62};
63} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
Definition block_softmax_2d.hpp:21
Definition block_softmax_2d_problem.hpp:12
Definition block_topk_stream_2d.hpp:17
Definition block_topk_stream_2d_problem.hpp:17
Definition topk_softmax_warp_per_row_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto MakeInputDistribution()
Definition topk_softmax_warp_per_row_policy.hpp:15
static CK_TILE_HOST_DEVICE constexpr auto GetSoftmax()
Definition topk_softmax_warp_per_row_policy.hpp:47
static CK_TILE_HOST_DEVICE constexpr auto GetTopk()
Definition topk_softmax_warp_per_row_policy.hpp:54
static CK_TILE_HOST_DEVICE constexpr auto MakeOutputDistribution()
Definition topk_softmax_warp_per_row_policy.hpp:32
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192