type_convert.hpp Source File

type_convert.hpp Source File

Composable Kernel: type_convert.hpp Source File
tile/core/numeric/type_convert.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <stdint.h>
7#include <tuple>
8#include <type_traits>
15
16namespace ck_tile {
17// With CK_TILE_USE_CUSTOM_DATA_TYPE set, type_convert is a single template that returns static_cast<Y>(x); otherwise the #else branch supplies the const/non-const overloads and the per-type specializations.
18#if CK_TILE_USE_CUSTOM_DATA_TYPE
19template <typename Y, typename X>
21{
22 return static_cast<Y>(x);
23}
24#else
25// Convert X to Y when neither X nor Y is const-qualified: a plain static_cast; reference types are rejected by the static_assert.
26template <typename Y,
27 typename X,
28 std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
30{
31 static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
32 return static_cast<Y>(x);
33}
34
35// Convert X to Y when X or Y is const-qualified: strip const from both, convert with type_convert on the non-const pair, then cast the result to Y. Reference types are rejected.
36template <typename Y,
37 typename X,
38 std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
40{
41 static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
42 // Re-dispatch on the const-stripped pair so the non-const overload (including any explicit specialization for that pair) does the conversion.
43 using non_const_y = std::remove_const_t<Y>;
44 using non_const_x = std::remove_const_t<X>;
45 return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
46}
47// Defines type_convert<dtype_, stype_>(stype_ x) as an explicit specialization returning sname_##_to_##dname_(x), e.g. (float, float, fp16_t, fp16) calls fp16_to_float(x).
48#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
49 template <> \
50 CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
51 { \
52 return sname_##_to_##dname_(x); \
53 }
54// fp16_t / bf16_t / fp8_t / bf8_t -> float; each line expands to a call such as fp16_to_float(x).
55CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
56CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
57CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
58CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
59// float -> fp16_t / bf16_t / fp8_t / bf8_t
60CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
61CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
62CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
63CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
64// int8_t <-> float
65CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
66CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
67#undef CK_TILE_TYPE_CONVERT // undefined here so the helper macro does not leak into later code or includers
68
69} // namespace ck_tile
70
72
73namespace ck_tile {
74// Primary template, declared without a body here; the definitions visible in this header are the explicit specializations emitted by CK_TILE_SCALED_TYPE_CONVERT.
75template <typename Y, typename X>
76CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale);
77// Defines scaled_type_convert<dtype_, stype_>(x, scale) returning sname_##_to_##dname_(x, scale), and type_convert<dtype_, stype_>(x) as the same call with scale 1.f.
78#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
79 template <> \
80 CK_TILE_HOST_DEVICE constexpr dtype_ scaled_type_convert<dtype_, stype_>(stype_ x, \
81 float scale) \
82 { \
83 return sname_##_to_##dname_(x, scale); \
84 } \
85 template <> \
86 CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
87 { \
88 return sname_##_to_##dname_(x, 1.f); \
89 }
90// Scaled conversion specializations; pk_fp4_t <-> float are among them.
97CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
98CK_TILE_SCALED_TYPE_CONVERT(float, float, pk_fp4_t, pk_fp4)
103#undef CK_TILE_SCALED_TYPE_CONVERT // undefined here so the helper macro does not leak into later code or includers
104
105#endif
106
107} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
_Float16 fp16x2_t
Definition half.hpp:385
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
bfloat16_t bf16x2_t
Definition pk_fp4.hpp:24
pk_float4_e2m1_t pk_fp4_t
Definition pk_fp4.hpp:151
CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale)
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
float fp32x2_t
Definition pk_fp4.hpp:22
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
signed char int8_t
Definition stdint.h:121
#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_)
Definition tile/core/numeric/type_convert.hpp:48
#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_)
Definition tile/core/numeric/type_convert.hpp:78