reduce_operator.hpp Source File

reduce_operator.hpp Source File

Composable Kernel: reduce_operator.hpp Source File
reduce_operator.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck_tile {
10
11namespace ReduceOp {
12// y = ReduceOp(y, x);
// Sum reduction functor: operator()(y, x) returns y + x, where y is the running
// accumulator and x is the incoming element.
13struct Add
14{
    // Returns the value the accumulator starts from: 0.0f converted to T.
    // NOTE(review): listing line 16 (the function signature) is missing from this
    // extract; the cross-reference list at the end of the page gives it as
    // `static CK_TILE_HOST_DEVICE constexpr T GetIdentityValue()`.
15 template <typename T>
17 {
18 return type_convert<T>(0.0f);
19 };
20
    // Overload selected when T is float, double, int32_t or int8_t:
    // adds directly in T.
21 template <typename T,
22 typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
23 CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
24 {
25 return y + x;
26 }
27
    // Overload selected when T is half_t, bf16_t, fp8_t or bf8_t: both operands
    // are converted to float, added there, and the sum is converted back to T.
    // NOTE(review): y is taken as non-const `T&` here (the overload above takes
    // `const T&`) although it is only read. A const or temporary accumulator
    // cannot bind to it. The two templates otherwise differ only in a default
    // template argument, so the differing parameter type may be what keeps them
    // distinct declarations -- confirm before changing it.
28 template <typename T,
29 typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
30 CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
31 {
32 float y_ = type_convert<float>(y);
33 float x_ = type_convert<float>(x);
34
35 return type_convert<T>(y_ + x_);
36 }
37};
38
// Sum-of-squares reduction functor: operator()(y, x) returns y + x * x.
// NOTE(review): the struct header on listing line 39 is missing from this
// extract, so the type's name is not visible here.
40{
    // Returns the value the accumulator starts from: 0.0f converted to T.
    // NOTE(review): listing line 42 (the function signature) is missing from this
    // extract; the cross-reference list at the end of the page gives it as
    // `static CK_TILE_HOST_DEVICE constexpr T GetIdentityValue()`.
41 template <typename T>
43 {
44 return type_convert<T>(0.0f);
45 };
46
    // Overload selected when T is float, double, int32_t or int8_t:
    // squares x and accumulates directly in T.
47 template <typename T,
48 typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
49 CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
50 {
51 return y + (x * x);
52 }
53
    // Overload selected when T is half_t, bf16_t, fp8_t or bf8_t: both operands
    // are converted to float, the square and the sum are computed in float, and
    // the result is converted back to T.
    // NOTE(review): y is taken as non-const `T&` here (the overload above takes
    // `const T&`) although it is only read; a const or temporary accumulator
    // cannot bind to it -- confirm whether this is intentional.
54 template <typename T,
55 typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
56 CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
57 {
58 float y_ = type_convert<float>(y);
59 float x_ = type_convert<float>(x);
60 return type_convert<T>(y_ + (x_ * x_));
61 }
62};
63
// Maximum reduction functor: operator()(y, x) returns max(y, x), where y is the
// running accumulator and x is the incoming element.
// NOTE(review): the std::enable_if_t conditions of all three templates below are
// cut off in this extract (listing lines 69-70, 78 and 88 are missing), so the
// set of accepted types cannot be read from here.
64struct Max
65{
    // Returns the value the accumulator starts from: numeric<T>::lowest().
    // NOTE(review): the signature itself is among the missing lines; the
    // cross-reference list at the end of the page gives it as
    // `static CK_TILE_HOST_DEVICE constexpr T GetIdentityValue()`.
66 template <
67 typename T,
68 typename = std::enable_if_t<
71 {
72 return numeric<T>::lowest();
73 };
74
    // Plain reduction step: returns the larger of y and x.
75 template <
76 typename T,
77 typename = std::enable_if_t<
79 CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
80 {
81 return max(y, x);
82 }
83
84 // Overload with changed flag for index tracking
    // `changed` is set to true only when x is strictly greater than y; it is
    // never written otherwise, so the caller must initialise it before the call.
    // A tie (x == y) therefore leaves `changed` untouched.
85 template <
86 typename T,
87 typename = std::enable_if_t<
89 CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
90 {
91 T new_max = max(y, x);
92 if(x > y)
93 {
94 changed = true;
95 }
96 return new_max;
97 }
98};
99
// Absolute-maximum reduction functor: operator()(y, x) returns max(y, abs(x)).
// Only the incoming element x goes through abs(); the accumulator y is used as is.
// NOTE(review): the std::enable_if_t conditions of all three templates below are
// cut off in this extract (listing lines 105-106, 114 and 124 are missing), so
// the set of accepted types cannot be read from here.
100struct AbsMax
101{
    // Returns the value the accumulator starts from: numeric<T>::zero().
    // NOTE(review): the signature itself is among the missing lines; the
    // cross-reference list at the end of the page gives it as
    // `static CK_TILE_HOST_DEVICE constexpr T GetIdentityValue()`.
102 template <
103 typename T,
104 typename = std::enable_if_t<
107 {
108 return numeric<T>::zero();
109 };
110
    // Plain reduction step: returns the larger of y and abs(x).
111 template <
112 typename T,
113 typename = std::enable_if_t<
115 CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
116 {
117 return max(y, abs(x));
118 }
119
120 // Overload with changed flag for index tracking
    // `changed` is set to true only when abs(x) is strictly greater than y; it is
    // never written otherwise, so the caller must initialise it before the call.
    // A tie (abs(x) == y) therefore leaves `changed` untouched.
121 template <
122 typename T,
123 typename = std::enable_if_t<
125 CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
126 {
127 T new_max = max(y, abs(x));
128 if(abs(x) > y)
129 {
130 changed = true;
131 }
132 return new_max;
133 }
134};
135
136} // namespace ReduceOp
137} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition reduce_operator.hpp:11
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition bfloat16.hpp:400
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Definition reduce_operator.hpp:101
static CK_TILE_HOST_DEVICE constexpr T GetIdentityValue()
Definition reduce_operator.hpp:106
CK_TILE_HOST_DEVICE constexpr T operator()(const T &y, const T x) const
Definition reduce_operator.hpp:115
CK_TILE_HOST_DEVICE constexpr T operator()(const T &y, const T x, bool &changed) const
Definition reduce_operator.hpp:125
Definition reduce_operator.hpp:14
CK_TILE_HOST_DEVICE constexpr T operator()(const T &y, const T x) const
Definition reduce_operator.hpp:23
CK_TILE_HOST_DEVICE constexpr T operator()(T &y, T x) const
Definition reduce_operator.hpp:30
static CK_TILE_HOST_DEVICE constexpr T GetIdentityValue()
Definition reduce_operator.hpp:16
Definition reduce_operator.hpp:65
CK_TILE_HOST_DEVICE constexpr T operator()(const T &y, const T x) const
Definition reduce_operator.hpp:79
static CK_TILE_HOST_DEVICE constexpr T GetIdentityValue()
Definition reduce_operator.hpp:70
CK_TILE_HOST_DEVICE constexpr T operator()(const T &y, const T x, bool &changed) const
Definition reduce_operator.hpp:89
Definition reduce_operator.hpp:40
CK_TILE_HOST_DEVICE constexpr T operator()(const T &y, const T x) const
Definition reduce_operator.hpp:49
CK_TILE_HOST_DEVICE constexpr T operator()(T &y, T x) const
Definition reduce_operator.hpp:56
static CK_TILE_HOST_DEVICE constexpr T GetIdentityValue()
Definition reduce_operator.hpp:42
Definition type_traits.hpp:115
static CK_TILE_HOST_DEVICE constexpr T lowest()
Definition tile/core/numeric/numeric.hpp:23
static CK_TILE_HOST_DEVICE constexpr T zero()
Definition tile/core/numeric/numeric.hpp:58