rmsnorm2d_fwd_pipeline_two_pass.hpp Source File

rmsnorm2d_fwd_pipeline_two_pass.hpp Source File

Composable Kernel: rmsnorm2d_fwd_pipeline_two_pass.hpp Source File
rmsnorm2d_fwd_pipeline_two_pass.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
10
11namespace ck_tile {
12
13template <typename Problem_, typename Policy_ = Rmsnorm2dFwdPipelineDefaultPolicy>
15{
18
24
27
28 static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, ck_tile::null_type>;
29 static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
30
31 static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
32 static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM
33 static constexpr bool kPadN = Problem::Traits::kPadN;
34 static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
35 static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
36
37 static constexpr const char* name = []() {
38 if constexpr(kNeedCrossWarpSync)
39 return "bpr_tp"; // block per row
40 else
41 return "wpr_tp"; // warp per row
42 }();
43
45 {
46 return Policy::template GetSmemSize<Problem>();
47 }
48
49 template <typename XWindow,
50 typename XResidualWindow,
51 typename GammaWindow,
52 typename YWindow,
53 typename YResidualWindow,
54 typename InvRmsWindow,
55 typename SmoothScaleWindow,
56 typename YScaleWindow,
57 typename UnquantYWindow,
58 typename Epilogue>
59 CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
60 const XResidualWindow& x_residual_window_,
61 const GammaWindow& gamma_window_,
62 YWindow& y_window,
63 const YResidualWindow& y_residual_window_,
64 InvRmsWindow& inv_rms_window,
65 const SmoothScaleWindow& /*sm_scale_window_*/,
66 YScaleWindow& /*y_scale_window*/,
67 UnquantYWindow& /*unquant_y_window*/,
68 ComputeDataType epsilon,
69 ck_tile::index_t row_size,
70 void* smem,
71 Epilogue) const
72 {
73 auto x_window =
74 make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
75 auto gamma_window = make_tile_window(
76 gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());
77 auto x_residual_window = make_tile_window(
78 x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
79 auto y_residual_window = make_tile_window(
80 y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
81
82 // Problem::BlockShape
83 static constexpr index_t Block_N = Problem::BlockShape::Block_N;
84 index_t num_n_tile_iteration =
86
87 auto reduce_square_sum_func = ReduceOp::SquareAdd{};
88 auto reduce_sum_func = ReduceOp::Add{};
89 auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
90 auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
91 auto block_reduce2d_cross_warp_sync =
92 Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
93
94 using ComputeTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
95 auto square_sum = block_reduce2d.template MakeYBlockTile<ComputeTensorType>();
96 set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());
97
98 for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
99 {
100 auto x = load_tile(x_window);
101 auto x_resi = load_tile(x_residual_window);
102
103 move_tile_window(x_window, {0, Block_N});
104 move_tile_window(x_residual_window, {0, Block_N});
105
106 auto acc = cast_tile<ComputeDataType>(x);
109 {
110 sweep_tile(x_resi, [&](auto idx) {
111 // compute x = x_resi + x
112 acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
113 });
115 {
116 store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
117 move_tile_window(y_residual_window, {0, Block_N});
118 }
119 }
120
121 block_reduce2d(acc, square_sum, reduce_square_sum_func);
122 }
123
124 block_reduce2d_sync(square_sum, reduce_sum_func);
125 block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);
126
127 // compute inv-rms
128 auto inv_rms = tile_elementwise_in(
129 [&](const auto& v_) {
130 return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
131 },
132 square_sum);
133
134 if constexpr(kSaveInvRms)
135 store_tile(inv_rms_window, cast_tile<InvRmsDataType>(inv_rms));
136
137 // reverse read x to reuse cache
138 ck_tile::index_t stride_to_right_most_window =
139 row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
140
142 {
143 move_tile_window(y_residual_window, {0, -Block_N});
144 }
145 else
146 {
147 move_tile_window(x_window, {0, -Block_N});
148 move_tile_window(x_residual_window, {0, -Block_N});
149 }
150 move_tile_window(gamma_window, {stride_to_right_most_window});
151 move_tile_window(y_window, {0, stride_to_right_most_window});
152
153 // rmsnorm computation
154 for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN)
155 {
157 decltype(load_tile(x_window))::get_tile_distribution());
158
160 {
161 acc = cast_tile<ComputeDataType>(load_tile(y_residual_window));
162 move_tile_window(y_residual_window, {0, -Block_N});
163 }
164 else
165 {
166 acc = cast_tile<ComputeDataType>(load_tile(x_window));
167 move_tile_window(x_window, {0, -Block_N});
168
170 {
171 auto x_resi = load_tile(x_residual_window);
172 sweep_tile(x_resi, [&](auto idx) {
173 // compute x = x_resi + x
174 acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
175 });
176 move_tile_window(x_residual_window, {0, -Block_N});
177 }
178 }
179
180 // load gamma (TODO: support no gamma?)
181 const auto gamma = load_tile(gamma_window);
182
183 // rmsnorm computation
185 decltype(load_tile(x_window))::get_tile_distribution());
186 sweep_tile(rmsn, [&, inv_rms_ = inv_rms](auto idx) {
187 constexpr auto i_idx = make_tuple(idx[number<0>{}]);
188 constexpr auto j_idx = make_tuple(idx[number<1>{}]);
189
190 const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
191
192 auto rmsn_ = acc(idx) * inv_rms_[i_idx] * gamma_;
193
194 rmsn(idx) = rmsn_;
195 });
196
198 Epilogue{}(y_window, rmsn, nullptr);
199
200 move_tile_window(gamma_window, {-Block_N});
201 move_tile_window(y_window, {0, -Block_N});
202 }
203 }
204};
205} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
@ NO_SWEEP
Definition rmsnorm2d_fwd_traits.hpp:28
@ PRE_ADD_STORE
Definition rmsnorm2d_fwd_traits.hpp:14
@ PRE_ADD
Definition rmsnorm2d_fwd_traits.hpp:16
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F &f, UnpacksPerXDim={})
Definition sweep_tile.hpp:231
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition bfloat16.hpp:413
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition reduce_operator.hpp:14
Definition reduce_operator.hpp:40
ck_tile::Rmsnorm2dFwdPipelineTwoPass
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:15
static constexpr auto kFusedAdd
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:34
static constexpr bool kPadM
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:32
static constexpr bool kPadN
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:33
ck_tile::remove_cvref_t< typename Problem::XDataType > XDataType
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:19
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:21
CK_TILE_DEVICE auto operator()(const XWindow &x_window_, const XResidualWindow &x_residual_window_, const GammaWindow &gamma_window_, YWindow &y_window, const YResidualWindow &y_residual_window_, InvRmsWindow &inv_rms_window, const SmoothScaleWindow &, YScaleWindow &, UnquantYWindow &, ComputeDataType epsilon, ck_tile::index_t row_size, void *smem, Epilogue) const
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:59
ck_tile::remove_cvref_t< Policy_ > Policy
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:17
static constexpr const char * name
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:37
static constexpr bool kNeedCrossWarpSync
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:31
ck_tile::remove_cvref_t< Problem_ > Problem
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:16
ck_tile::remove_cvref_t< typename Problem::YDataType > YDataType
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:22
static constexpr auto kFusedQuant
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:35
XDataType XResidualDataType
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:25
static constexpr bool kSaveInvRms
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:29
ck_tile::remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:20
XDataType YResidualDataType
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:26
static constexpr bool kHasGamma
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:28
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:44
ck_tile::remove_cvref_t< typename Problem::InvRmsDataType > InvRmsDataType
Definition rmsnorm2d_fwd_pipeline_two_pass.hpp:23