amd_wmma.hpp Source File

amd_wmma.hpp Source File

Composable Kernel: amd_wmma.hpp Source File
amd_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef CK_AMD_WMMA_HPP
5#define CK_AMD_WMMA_HPP
6
8#include "data_type.hpp"
9// TODO: Add arch limitation
10namespace ck {
11
// Family switch for gfx11: __gfx11__ becomes defined as soon as any one of the
// per-target macros tested below is defined.
#if defined(__gfx11_generic__) || defined(__gfx1100__) || defined(__gfx1101__) || \
    defined(__gfx1102__) || defined(__gfx1103__)
#define __gfx11__
#endif
16
// Family switch for gfx12: __gfx12__ becomes defined as soon as any one of the
// per-target macros tested below is defined.
#if defined(__gfx12_generic__) || defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif
20
21/********************************WAVE32 MODE***********************************************/
22
23// src: fp16, dst: fp32
24template <index_t MPerWave, index_t NPerWave>
26
27template <>
29{
30 template <class FloatC>
31 __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
32 {
33 // * Inline assembly need to elimate the duplicated data load, compiler won't help you
34 // delete them.
35 // amd_assembly_wmma_f32_16x16x16_f16_w32(
36 // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
37#if defined(__gfx11__)
38 reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
39 reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
40#else
41 ignore = reg_a;
42 ignore = reg_b;
43 ignore = reg_c;
44#endif
45 }
46};
47
48// src: bf16, dst: fp32
49template <index_t MPerWave, index_t NPerWave>
51
52template <>
54{
55 template <class FloatC>
56 __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
57 {
58#if defined(__gfx11__)
59 reg_c.template AsType<float8_t>()(Number<0>{}) =
60 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
61 reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
62#else
63 ignore = reg_a;
64 ignore = reg_b;
65 ignore = reg_c;
66#endif
67 }
68};
69
70// src: fp16, dst: fp16
71template <index_t MPerWave, index_t NPerWave, index_t Opsel>
73
74template <index_t Opsel>
76{
77 template <class FloatC>
78 __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
79 {
80 // opsel usage
81 // false: D0.[0:15] = result
82 // true : D0.[16:31]= result
83#if defined(__gfx11__)
84 reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
85 reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
86#else
87 ignore = reg_a;
88 ignore = reg_b;
89 ignore = reg_c;
90#endif
91 }
92};
93
94// src: bf16, dst: bf16
95template <index_t MPerWave, index_t NPerWave, index_t Opsel>
97
98template <index_t Opsel>
100{
101 template <class FloatC>
102 __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
103 {
104 // opsel usage
105 // false: D0.[0:15] = result
106 // true : D0.[16:31]= result
107#if defined(__gfx11__)
108 reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
109 __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
110 reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
111#else
112 ignore = reg_a;
113 ignore = reg_b;
114 ignore = reg_c;
115#endif
116 }
117};
118
119// src: iu8, dst: i32
120template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
122
123template <bool neg_a, bool neg_b, bool clamp>
124struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
125{
126 template <class FloatC>
127 __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
128 {
129#if defined(__gfx11__)
130 reg_c.template AsType<int32x8_t>()(Number<0>{}) =
131 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
132 neg_a,
133 bit_cast<int32x4_t>(reg_a),
134 neg_b,
135 bit_cast<int32x4_t>(reg_b),
136 reg_c.template AsType<int32x8_t>()[Number<0>{}],
137 clamp);
138#else
139 ignore = reg_a;
140 ignore = reg_b;
141 ignore = reg_c;
142#endif
143 }
144};
145
146/********************************WAVE64 MODE***********************************************/
147
148template <index_t MPerWave, index_t NPerWave>
150
151template <>
153{
154 template <class FloatC>
155 __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
156 {
157#if defined(__gfx11__)
158 reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
159 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
160#else
161 ignore = reg_a;
162 ignore = reg_b;
163 ignore = reg_c;
164#endif
165 }
166};
167
168// src: bf16, dst: fp32
169template <index_t MPerWave, index_t NPerWave>
171
172template <>
174{
175 template <class FloatC>
176 __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
177 {
178#if defined(__gfx11__)
179 reg_c.template AsType<float4_t>()(Number<0>{}) =
180 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
181 reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
182#else
183 ignore = reg_a;
184 ignore = reg_b;
185 ignore = reg_c;
186#endif
187 }
188};
189
190// src: fp16, dst: fp16
191template <index_t MPerWave, index_t NPerWave, index_t Opsel>
193
194template <index_t Opsel>
196{
197 template <class FloatC>
198 __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
199 {
200 // opsel usage
201 // false: D0.[0:15] = result
202 // true : D0.[16:31]= result
203#if defined(__gfx11__)
204 reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
205 reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
206#else
207 ignore = reg_a;
208 ignore = reg_b;
209 ignore = reg_c;
210#endif
211 }
212};
213
214// src: bf16, dst: bf16
215template <index_t MPerWave, index_t NPerWave, index_t Opsel>
217
218template <index_t Opsel>
220{
221 template <class FloatC>
222 __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
223 {
224 // opsel usage
225 // false: D0.[0:15] = result
226 // true : D0.[16:31]= result
227#if defined(__gfx11__)
228 reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
229 __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
230 reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
231#else
232 ignore = reg_a;
233 ignore = reg_b;
234 ignore = reg_c;
235#endif
236 }
237};
238
239// src: iu8, dst: i32
240template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
242
243template <bool neg_a, bool neg_b, bool clamp>
244struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
245{
246 template <class FloatC>
247 __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
248 {
249#if defined(__gfx11__)
250 reg_c.template AsType<int32x4_t>()(Number<0>{}) =
251 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
252 neg_a,
253 bit_cast<int32x4_t>(reg_a),
254 neg_b,
255 bit_cast<int32x4_t>(reg_b),
256 reg_c.template AsType<int32x4_t>()[Number<0>{}],
257 clamp);
258#else
259 ignore = reg_a;
260 ignore = reg_b;
261 ignore = reg_c;
262#endif
263 }
264};
265
266// gfx12
267/********************************WAVE32 MODE***********************************************/
268
269// src: fp16, dst: fp32
270template <index_t MPerWave, index_t NPerWave>
272
273template <>
275{
276 template <class FloatC>
277 __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
278 {
279 // * Inline assembly need to elimate the duplicated data load, compiler won't help you
280 // delete them.
281 // amd_assembly_wmma_f32_16x16x16_f16_w32(
282 // reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
283#if defined(__gfx12__)
284 reg_c.template AsType<float8_t>()(Number<0>{}) =
285 __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
286 reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
287#else
288 ignore = reg_a;
289 ignore = reg_b;
290 ignore = reg_c;
291#endif
292 }
293};
294
295// src: bf16, dst: fp32
296template <index_t MPerWave, index_t NPerWave>
298
299template <>
301{
302 template <class FloatC>
303 __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
304 {
305#if defined(__gfx12__)
306 reg_c.template AsType<float8_t>()(Number<0>{}) =
307 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
308 reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
309#else
310 ignore = reg_a;
311 ignore = reg_b;
312 ignore = reg_c;
313#endif
314 }
315};
316
317// src: iu8, dst: i32
318template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
320
321template <bool neg_a, bool neg_b, bool clamp>
322struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp>
323{
324 template <class FloatC>
325 __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
326 {
327#if defined(__gfx12__)
328 reg_c.template AsType<int32x8_t>()(Number<0>{}) =
329 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
330 neg_a,
331 bit_cast<int32x2_t>(reg_a),
332 neg_b,
333 bit_cast<int32x2_t>(reg_b),
334 reg_c.template AsType<int32x8_t>()[Number<0>{}],
335 clamp);
336#else
337 ignore = reg_a;
338 ignore = reg_b;
339 ignore = reg_c;
340#endif
341 }
342};
343
344// src: f8, f8, dst: fp32
345template <index_t MPerWave, index_t NPerWave>
347
348template <>
350{
351 template <class FloatC>
352 __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
353 {
354#if defined(__gfx12__)
355 reg_c.template AsType<float8_t>()(Number<0>{}) =
356 __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
357 bit_cast<int32x2_t>(reg_a),
358 bit_cast<int32x2_t>(reg_b),
359 reg_c.template AsType<float8_t>()[Number<0>{}]);
360#else
361 ignore = reg_a;
362 ignore = reg_b;
363 ignore = reg_c;
364#endif
365 }
366};
367
368// src: f8, bf8, dst: fp32
369template <index_t MPerWave, index_t NPerWave>
371
372template <>
374{
375 template <class FloatC>
376 __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
377 {
378#if defined(__gfx12__)
379 reg_c.template AsType<float8_t>()(Number<0>{}) =
380 __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
381 bit_cast<int32x2_t>(reg_a),
382 bit_cast<int32x2_t>(reg_b),
383 reg_c.template AsType<float8_t>()[Number<0>{}]);
384#else
385 ignore = reg_a;
386 ignore = reg_b;
387 ignore = reg_c;
388#endif
389 }
390};
391
392// src: bf8, f8, dst: fp32
393template <index_t MPerWave, index_t NPerWave>
395
396template <>
398{
399 template <class FloatC>
400 __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
401 {
402#if defined(__gfx12__)
403 reg_c.template AsType<float8_t>()(Number<0>{}) =
404 __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
405 bit_cast<int32x2_t>(reg_a),
406 bit_cast<int32x2_t>(reg_b),
407 reg_c.template AsType<float8_t>()[Number<0>{}]);
408#else
409 ignore = reg_a;
410 ignore = reg_b;
411 ignore = reg_c;
412#endif
413 }
414};
415
416// src: bf8, bf8, dst: fp32
417template <index_t MPerWave, index_t NPerWave>
419
420template <>
422{
423 template <class FloatC>
424 __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
425 {
426#if defined(__gfx12__)
427 reg_c.template AsType<float8_t>()(Number<0>{}) =
428 __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
429 bit_cast<int32x2_t>(reg_a),
430 bit_cast<int32x2_t>(reg_b),
431 reg_c.template AsType<float8_t>()[Number<0>{}]);
432#else
433 ignore = reg_a;
434 ignore = reg_b;
435 ignore = reg_c;
436#endif
437 }
438};
439
440} // namespace ck
441#endif
Definition ck.hpp:268
typename vector_type< int8_t, 8 >::type int8x8_t
Definition dtype_vector.hpp:2178
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition dtype_vector.hpp:2162
typename vector_type< half_t, 16 >::type half16_t
Definition dtype_vector.hpp:2156
integral_constant< index_t, N > Number
Definition number.hpp:12
typename vector_type< half_t, 8 >::type half8_t
Definition dtype_vector.hpp:2155
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
typename vector_type< int8_t, 16 >::type int8x16_t
Definition dtype_vector.hpp:2179
typename vector_type< bhalf_t, 16 >::type bhalf16_t
Definition dtype_vector.hpp:2163
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:102
Definition amd_wmma.hpp:96
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:222
Definition amd_wmma.hpp:216
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:78
Definition amd_wmma.hpp:72
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:198
Definition amd_wmma.hpp:192
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:56
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:303
Definition amd_wmma.hpp:50
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:176
Definition amd_wmma.hpp:170
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:424
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:400
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:31
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:277
Definition amd_wmma.hpp:25
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:155
Definition amd_wmma.hpp:149
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:376
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:352
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:127
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:325
Definition amd_wmma.hpp:121
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:247
Definition amd_wmma.hpp:241