smoothquant_pipeline_one_pass.hpp Source File

smoothquant_pipeline_one_pass.hpp Source File

Composable Kernel: smoothquant_pipeline_one_pass.hpp Source File
smoothquant_pipeline_one_pass.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <string>
9#include <type_traits>
10
11namespace ck_tile {
12
13template <typename Problem_, typename Policy_ = SmoothquantPipelineDefaultPolicy>
15{
18
24
25 static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
26 static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
27 static constexpr bool kPadN = Problem::kPadN;
28 static constexpr bool UseMax3 = true; // TODO - Move to trait
29
30 static constexpr const char* name = []() {
31 if constexpr(kNeedCrossWarpSync)
32 return "bpr_op"; // block per row
33 else
34 return "wpr_op"; // warp per row
35 }();
36
38 {
39 return Policy::template GetSmemSize<Problem>();
40 }
41
42 template <typename XWindow,
43 typename SmoothScaleWindow,
44 typename QYWindow,
45 typename YScaleWindow>
46 CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
47 const SmoothScaleWindow& smscale_window_,
48 YScaleWindow& yscale_window,
49 QYWindow& qy_window,
51 void* smem) const
52 {
53 auto x_window =
54 make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
55 auto smscale_window = make_tile_window(
56 smscale_window_, Policy::template MakeSmoothScaleBlockTileDistribution<Problem>());
57
58 auto reduce_absmax_func = ReduceOp::AbsMax{};
59 auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) {
60 float rtn;
61 asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
62 : "=v"(rtn)
63 : "v"(acc_), "v"(v_0_), "v"(v_1_));
64 return rtn;
65 };
66 auto reduce_max_func = ReduceOp::Max{};
67
68 auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
69 auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
70 auto block_reduce2d_cross_warp_sync =
71 Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
72
73 const auto x = load_tile(x_window);
74 const auto smscale = load_tile(smscale_window);
75 auto y = tile_elementwise_in(
76 [&](const auto& a, const auto& b) {
78 },
79 x,
80 smscale);
81
82 // compute absmax, cross-lane->cross-warp
83 auto absmax = [&]() {
84 constexpr auto x_size_per_row =
85 x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
86 if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
87 x_size_per_row % 2 == 0)
88 {
89 return block_reduce2d(y,
90 reduce_absmax_func.GetIdentityValue<ComputeDataType>(),
91 reduce_absmax3_func,
93 }
94 else
95 {
96 return block_reduce2d(
97 y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
98 }
99 }();
100 block_reduce2d_sync(absmax, reduce_max_func);
101 block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
102
103 // ex: yscale = absmax / 127 if int8
104 auto yscale = tile_elementwise_in(
105 [&](const auto& v_) {
107 },
108 absmax);
109 store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));
110
111 // quantize y to qy
112 auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
113 sweep_tile(qy, [&](auto idx) {
114 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
115 auto qy_ = y[idx] / yscale[i_idx];
117 });
118 store_tile(qy_window, qy);
119 }
120};
121} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
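const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a (spurious cross-reference: in this file `a` is only the first parameter of the lambda passed to tile_elementwise_in over x and smscale, not this GenericPointer declaration)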
Definition pointer.h:1517
Definition reduce_operator.hpp:101
Definition reduce_operator.hpp:65
Definition smoothquant_pipeline_one_pass.hpp:15
static constexpr const char * name
Definition smoothquant_pipeline_one_pass.hpp:30
static constexpr bool UseMax3
Definition smoothquant_pipeline_one_pass.hpp:28
CK_TILE_DEVICE auto operator()(const XWindow &x_window_, const SmoothScaleWindow &smscale_window_, YScaleWindow &yscale_window, QYWindow &qy_window, ck_tile::index_t, void *smem) const
Definition smoothquant_pipeline_one_pass.hpp:46
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition smoothquant_pipeline_one_pass.hpp:19
ck_tile::remove_cvref_t< Problem_ > Problem
Definition smoothquant_pipeline_one_pass.hpp:16
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition smoothquant_pipeline_one_pass.hpp:37
static constexpr bool kPadN
Definition smoothquant_pipeline_one_pass.hpp:27
ck_tile::remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition smoothquant_pipeline_one_pass.hpp:23
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition smoothquant_pipeline_one_pass.hpp:21
ck_tile::remove_cvref_t< Policy_ > Policy
Definition smoothquant_pipeline_one_pass.hpp:17
static constexpr bool kNeedCrossWarpSync
Definition smoothquant_pipeline_one_pass.hpp:25
ck_tile::remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition smoothquant_pipeline_one_pass.hpp:20
static constexpr bool kPadM
Definition smoothquant_pipeline_one_pass.hpp:26
ck_tile::remove_cvref_t< typename Problem::QYDataType > QYDataType
Definition smoothquant_pipeline_one_pass.hpp:22
static CK_TILE_HOST_DEVICE constexpr T max()
Definition tile/core/numeric/numeric.hpp:26
Definition unary_element_function.hpp:56
Definition tile/core/container/sequence.hpp:49