buffer_view.hpp Source File

buffer_view.hpp Source File#

Composable Kernel: buffer_view.hpp Source File
buffer_view.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
19
20namespace ck_tile {
21
22// T may be scalar or vector
23// X may be scalar or vector
24// T and X have same scalar type
25// X contains multiple T
26// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
27// transforms of tensor_view/Tensor
28// FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split
29// buffer_view definition for different memory address space (Global/GenericLds/Vgpr)
30template <address_space_enum BufferAddressSpace,
31 typename T,
32 typename BufferSizeType,
33 bool InvalidElementUseNumericalZeroValue,
36
37// Address Space: generic
38// T may be scalar or vector
39// X may be scalar or vector
40// T and X have same scalar type
41// X contains multiple T
42// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
43// transforms of tensor_view/Tensor
44template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
46 T,
47 BufferSizeType,
48 InvalidElementUseNumericalZeroValue,
50{
51 using type = T;
52
53 T* p_data_ = nullptr;
54 BufferSizeType buffer_size_;
56
61
62 CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
63 : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
64 {
65 }
66
67 CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
68 BufferSizeType buffer_size,
69 T invalid_element_value)
70 : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
71 {
72 }
73
75
80
81 // i is offset of T
82 // FIXME: doesn't do is_valid check
83 CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
84
85 // i is offset of T
86 // FIXME: doesn't do is_valid check
87 CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
88
89 // i is offset of T, not X. i should be aligned to X
90 template <typename X,
91 bool oob_conditional_check = true,
92 typename std::enable_if<
93 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
94 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
95 bool>::type = false>
96 CK_TILE_DEVICE constexpr auto get(index_t i,
97 index_t linear_offset,
98 bool is_valid_element,
100 {
101 // X contains multiple T
102 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
103
104 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
105
106 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
107 "wrong! X should contain multiple T");
108
109 if(is_valid_element)
110 {
111#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
112 X tmp;
113
114 __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
115
116 return tmp;
117#else
118 return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
119#endif
120 }
121 else
122 {
123 if constexpr(InvalidElementUseNumericalZeroValue)
124 {
125 return X{numeric<remove_cvref_t<T>>::zero()};
126 }
127 else
128 {
129 return X{invalid_element_value_};
130 }
131 }
132 }
133
134 /*
135 In the generic address space, the buffer view does not support the transpose instruction.
136 Any attempt to use it is reported as a compilation error.
137 */
138 template <typename X,
139 bool oob_conditional_check = true,
140 typename std::enable_if<
141 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
142 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
143 bool>::type = false>
145 index_t linear_offset,
146 bool is_valid_element,
148 {
149 static_assert(false, "Error: transpose load not supported in global memory space.");
150 ignore = i;
151 ignore = linear_offset;
152 ignore = is_valid_element;
153 return;
154 }
155
156 // i is offset of T, not X. i should be aligned to X
157 template <memory_operation_enum Op,
158 typename X,
159 typename std::enable_if<
160 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
161 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
162 bool>::type = false>
163 CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
164 {
165 if constexpr(Op == memory_operation_enum::set)
166 {
167 this->template set<X>(i, linear_offset, is_valid_element, x);
168 }
169 // FIXME: remove memory_operation_enum::add
170 else if constexpr(Op == memory_operation_enum::add)
171 {
172 auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
173 this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
174 }
175 }
176
177 // i is offset of T, not X. i should be aligned to X
178 template <typename X,
179 typename std::enable_if<
180 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
181 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
182 bool>::type = false>
183 CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
184 {
185 // X contains multiple T
186 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
187
188 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
189
190 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
191 "wrong! X should contain multiple T");
192
193 if(is_valid_element)
194 {
195#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
196 X tmp = x;
197
198 __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
199#else
200 *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
201#endif
202 }
203 }
204
205 // FIXME: remove
206 CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
207
208 // FIXME: remove
209 CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
210};
211
212// Address Space: Global
213// T may be scalar or vector
214// X may be scalar or vector
215// T and X have same scalar type
216// X contains multiple T
217// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
218// transforms of tensor_view/Tensor
219template <typename T,
220 typename BufferSizeType,
221 bool InvalidElementUseNumericalZeroValue,
224 T,
225 BufferSizeType,
226 InvalidElementUseNumericalZeroValue,
227 Coherence>
228{
229 using type = T;
230
231 T* p_data_ = nullptr;
232 BufferSizeType buffer_size_;
235
237
242
243 CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
244 : p_data_{p_data},
245 buffer_size_{buffer_size / PackedSize},
248 {
249 }
250
251 CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
252 BufferSizeType buffer_size,
253 T invalid_element_value)
254 : p_data_{p_data},
255 buffer_size_{buffer_size / PackedSize},
257 invalid_element_value_{invalid_element_value}
258 {
259 }
260
261 // this is intentionally non-constexpr (it will call some intrinsic internally)
262 // Must call for buffers that need *_raw load/store
267
272
273 // i is offset of T
274 // FIXME: doesn't do is_valid check
275 CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
276
277 // i is offset of T
278 // FIXME: doesn't do is_valid check
279 CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
280
281 // i is offset of T, not X. i should be aligned to X
282 template <typename X,
283 bool oob_conditional_check = true,
284 typename std::enable_if<
285 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
286 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
287 bool>::type = false>
288 CK_TILE_DEVICE constexpr auto get(index_t i,
289 index_t linear_offset,
290 bool is_valid_element,
292 {
293 // X contains multiple T
294 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
295
296 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
297
298 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
299 "wrong! X should contain multiple T");
300
301#if CK_TILE_USE_AMD_BUFFER_LOAD
302 bool constexpr use_amd_buffer_addressing = true;
303#else
304 bool constexpr use_amd_buffer_addressing = false;
305#endif
306
307 if constexpr(use_amd_buffer_addressing)
308 {
309 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
310
311 if constexpr(InvalidElementUseNumericalZeroValue)
312 {
314 t_per_x,
315 Coherence,
316 oob_conditional_check>(
317 p_data_, i + linear_offset, is_valid_element, buffer_size_);
318 }
319 else
320 {
323 t_per_x,
324 Coherence,
325 oob_conditional_check>(p_data_,
326 i + linear_offset,
327 is_valid_element,
328 buffer_size_,
329 invalid_element_value_);
330 }
331 }
332 else
333 {
334 if(is_valid_element)
335 {
336#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
337 X tmp;
338
339 __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
340
341 return tmp;
342#else
343 return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
344#endif
345 }
346 else
347 {
348 if constexpr(InvalidElementUseNumericalZeroValue)
349 {
350 return X{numeric<remove_cvref_t<T>>::zero()};
351 }
352 else
353 {
354 return X{invalid_element_value_};
355 }
356 }
357 }
358 }
359
360 /*
361 In the global memory address space, the buffer view does not support the transpose
362 instruction. Any attempt to use it is reported as a compilation error.
363 */
364 template <typename X,
365 bool oob_conditional_check = true,
366 typename std::enable_if<
367 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
368 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
369 bool>::type = false>
371 index_t linear_offset,
372 bool is_valid_element,
374 {
375 static_assert(false, "Error: transpose load not supported in global memory space.");
376 ignore = i;
377 ignore = linear_offset;
378 ignore = is_valid_element;
379 return;
380 }
381
382 // i is offset of T, not X. i should be aligned to X
383 template <typename X,
384 bool oob_conditional_check = true,
385 bool pre_nop = false,
386 typename std::enable_if<
387 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
388 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
389 bool>::type = false>
391 index_t v_offset,
392 index_t i_offset,
393 bool is_valid_element,
394 bool_constant<pre_nop> = {}) const
395 {
396 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
397
398 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
399
400 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
401 "wrong! X should contain multiple T");
402
403 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
404
405 amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
406 dst, cached_buf_res_, v_offset, i_offset, is_valid_element, bool_constant<pre_nop>{});
407 }
408
409 // i is offset of T, not X. i should be aligned to X
410 template <typename X,
411 bool oob_conditional_check = true,
412 typename std::enable_if<
413 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
414 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
415 bool>::type = false>
417 index_t i,
418 index_t linear_offset,
419 bool is_valid_element,
421 {
422 // X is vector of T
423 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
424 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
425
426 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
427 "wrong! X should contain multiple T");
428
429 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
430 const int32x4_t src_wave_buffer_resource =
431 make_wave_buffer_resource(p_data_, (buffer_size_) * sizeof(type));
432
434 smem,
435 src_wave_buffer_resource,
436 i,
437 linear_offset,
438 is_valid_element,
440 }
441
442 // i is offset of T, not X. i should be aligned to X
443 template <typename X,
444 bool pre_nop = false,
445 typename std::enable_if<
446 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
447 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
448 bool>::type = false>
450 index_t i,
451 index_t linear_offset,
452 bool /*is_valid_element*/,
453 bool_constant<pre_nop> = {}) const
454 {
455 // X is vector of T
456 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
457 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
458
459 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
460 "wrong! X should contain multiple T");
461
462 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
463
465 smem, cached_buf_res_, i, linear_offset, bool_constant<pre_nop>{});
466 }
467
468 // i is offset of T, not X. i should be aligned to X
469 template <memory_operation_enum Op,
470 typename X,
471 bool oob_conditional_check = true,
472 typename std::enable_if<
473 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
474 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
475 bool>::type = false>
477 index_t linear_offset,
478 bool is_valid_element,
479 const X& x,
481 {
482 if constexpr(Op == memory_operation_enum::set)
483 {
484 this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
485 }
486 else if constexpr(Op == memory_operation_enum::atomic_add)
487 {
489 i, linear_offset, is_valid_element, x);
490 }
491 else if constexpr(Op == memory_operation_enum::atomic_max)
492 {
494 i, linear_offset, is_valid_element, x);
495 }
496 // FIXME: remove memory_operation_enum::add
497 else if constexpr(Op == memory_operation_enum::add)
498 {
499 auto tmp =
500 this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
501 this->template set<X, oob_conditional_check>(
502 i, linear_offset, is_valid_element, x + tmp);
503 // tmp += x;
504 // this->template set<X>(i, is_valid_element, tmp);
505 }
506 }
507
508 // i is offset of T, not X. i should be aligned to X
509 template <memory_operation_enum Op,
510 typename X,
511 bool oob_conditional_check = true,
512 bool pre_nop = false,
513 typename std::enable_if<
514 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
515 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
516 bool>::type = false>
518 index_t linear_offset,
519 bool is_valid_element,
520 const X& x,
523 {
524 if constexpr(Op == memory_operation_enum::set)
525 {
526 this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
527 }
528 else if constexpr(Op == memory_operation_enum::atomic_add)
529 {
530 this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
531 i, linear_offset, is_valid_element, x);
532 }
533 else if constexpr(Op == memory_operation_enum::atomic_max)
534 {
535 // this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
536 }
537 }
538
539 // i is offset of T, not X. i should be aligned to X
540 template <typename X,
541 bool oob_conditional_check = true,
542 typename std::enable_if<
543 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
544 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
545 bool>::type = false>
546 CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
547 {
548 // X contains multiple T
549 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
550
551 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
552
553 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
554 "wrong! X should contain multiple T");
555
556#if CK_TILE_USE_AMD_BUFFER_STORE
557 bool constexpr use_amd_buffer_addressing = true;
558#else
559 bool constexpr use_amd_buffer_addressing = false;
560#endif
561
562 if constexpr(use_amd_buffer_addressing)
563 {
564 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
565
566 amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
567 x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
568 }
569 else
570 {
571 if(is_valid_element)
572 {
573#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
574 X tmp = x;
575
576 __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
577#else
578 *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
579#endif
580 }
581 }
582 }
583
584 // i is offset of T, not X. i should be aligned to X
585 template <typename X,
586 bool oob_conditional_check = true,
587 typename std::enable_if<
588 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
589 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
590 bool>::type = false>
591 CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
592 {
593 // X contains multiple T
594 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
595
596 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
597
598 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
599 "wrong! X should contain multiple T");
600
601 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
602 amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
603 x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
604 }
605
606 template <typename X,
607 bool oob_conditional_check = true,
608 typename std::enable_if<
609 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
610 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
611 bool>::type = false>
612 CK_TILE_DEVICE void
613 atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
614 {
615 using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
616
617 // X contains multiple T
618 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
619
620 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
621
622 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
623 "wrong! X should contain multiple T");
624
625 static_assert(get_address_space() == address_space_enum::global, "only support global mem");
626
627#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
628 bool constexpr use_amd_buffer_addressing =
629 std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
630 std::is_same_v<remove_cvref_t<scalar_t>, float> ||
631 (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
632#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
633 ||
634 (std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
635#endif
636 ;
637#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
638 bool constexpr use_amd_buffer_addressing =
639 std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
640#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
641 bool constexpr use_amd_buffer_addressing =
642 std::is_same_v<remove_cvref_t<scalar_t>, float> ||
643 (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
644#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
645 ||
646 (std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
647#endif
648 ;
649#else
650 bool constexpr use_amd_buffer_addressing = false;
651#endif
652
653 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
654
655 if constexpr(use_amd_buffer_addressing)
656 {
658 x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
659 }
660 else
661 {
662 if(is_valid_element)
663 {
664 atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
665 }
666 }
667 }
668
669 template <typename X,
670 bool oob_conditional_check = true,
671 bool pre_nop = true,
672 typename std::enable_if<
673 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
674 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
675 bool>::type = false>
676 CK_TILE_DEVICE void
677 atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
678 {
679 // using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
680
681 // X contains multiple T
682 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
683
684 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
685
686 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
687 "wrong! X should contain multiple T");
688
689 static_assert(get_address_space() == address_space_enum::global, "only support global mem");
690
691 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
692
694 t_per_x,
695 Coherence,
696 oob_conditional_check,
697 pre_nop>(
698 x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
699 }
700
701 template <typename X,
702 bool oob_conditional_check = true,
703 typename std::enable_if<
704 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
705 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
706 bool>::type = false>
707 CK_TILE_DEVICE void
708 atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
709 {
710 // X contains multiple T
711 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
712
713 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
714
715 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
716 "wrong! X should contain multiple T");
717
718 static_assert(get_address_space() == address_space_enum::global, "only support global mem");
719
720#if CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
721 using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
722 bool constexpr use_amd_buffer_addressing = std::is_same_v<remove_cvref_t<scalar_t>, double>;
723#else
724 bool constexpr use_amd_buffer_addressing = false;
725#endif
726
727 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
728
729 if constexpr(use_amd_buffer_addressing)
730 {
732 x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
733 }
734 else if(is_valid_element)
735 {
736 atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
737 }
738 }
739
740 // FIXME: remove
741 CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
742
743 // FIXME: remove
744 CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
745};
746
747// Address Space: LDS
748// T may be scalar or vector
749// X may be scalar or vector
750// T and X have same scalar type
751// X contains multiple T
752// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
753// transforms of tensor_view/Tensor
754template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
756 T,
757 BufferSizeType,
758 InvalidElementUseNumericalZeroValue,
760{
761 using type = T;
762
763 T* p_data_ = nullptr;
764 BufferSizeType buffer_size_;
766
771
772 CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
773 : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
774 {
775 }
776
777 CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
778 BufferSizeType buffer_size,
779 T invalid_element_value)
780 : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
781 {
782 }
783
785
790
791 // i is offset of T
792 // FIXME: doesn't do is_valid check
793 CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
794
795 // i is offset of T
796 // FIXME: doesn't do is_valid check
797 CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
798
799 // i is offset of T, not X. i should be aligned to X
800 template <typename X,
801 bool oob_conditional_check = true,
802 typename std::enable_if<
803 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
804 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
805 bool>::type = false>
806 CK_TILE_DEVICE constexpr auto get(index_t i,
807 index_t linear_offset,
808 bool is_valid_element,
810 {
811 // X contains multiple T
812 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
813
814 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
815
816 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
817 "wrong! X should contain multiple T");
818
819 if(is_valid_element)
820 {
821#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
822 X tmp;
823
824 __builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
825
826 return tmp;
827#else
829 scalar_per_t_vector * scalar_per_x_vector>;
830 // using buf_t = ushort __attribute__((ext_vector_type(8)));
831 auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
832 return bit_cast<X>(rtn);
833#endif
834 }
835 else
836 {
837 if constexpr(InvalidElementUseNumericalZeroValue)
838 {
839 return X{numeric<remove_cvref_t<T>>::zero()};
840 }
841 else
842 {
843 return X{invalid_element_value_};
844 }
845 }
846 }
847
848 // i is offset of T, not X. i should be aligned to X
849 template <typename X,
850 bool oob_conditional_check = true,
851 bool pre_nop = false,
852 typename std::enable_if<
853 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
854 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
855 bool>::type = false>
857 index_t v_offset,
858 index_t i_offset,
859 bool /*is_valid_element*/,
860 bool_constant<pre_nop> = {}) const
861 {
862 smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
863 }
864
865 template <typename X,
866 typename std::enable_if<
867 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
868 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
869 bool>::type = false>
870 CK_TILE_DEVICE constexpr auto transpose_get([[maybe_unused]] index_t i,
871 [[maybe_unused]] index_t linear_offset,
872 bool is_valid_element) const
873 {
874 // X contains multiple T
875 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
876
877 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
878
879 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
880 "wrong! X should contain multiple T");
881
882 if(is_valid_element)
883 {
884#if defined(__gfx950__)
885 constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
886 return amd_transpose_load_to_vgpr<remove_cvref_t<T>, t_per_x>(p_data_ + i +
887 linear_offset);
888#else
889 return X{numeric<remove_cvref_t<T>>::zero()};
890#endif
891 }
892 else
893 {
894 if constexpr(InvalidElementUseNumericalZeroValue)
895 {
896 return X{numeric<remove_cvref_t<T>>::zero()};
897 }
898 else
899 {
900 return X{invalid_element_value_};
901 }
902 }
903 }
904
905 // i is offset of T, not X. i should be aligned to X
906 template <memory_operation_enum Op,
907 typename X,
908 typename std::enable_if<
909 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
910 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
911 bool>::type = false>
912 CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
913 {
914 if constexpr(Op == memory_operation_enum::set)
915 {
916 this->template set<X>(i, linear_offset, is_valid_element, x);
917 }
918 // FIXME: remove memory_operation_enum::add
919 else if constexpr(Op == memory_operation_enum::add)
920 {
921 auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
922 this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
923 }
924 }
925
926 // i is offset of T, not X. i should be aligned to X
927 template <typename X,
928 typename std::enable_if<
929 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
930 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
931 bool>::type = false>
932 CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
933 {
934 // X contains multiple T
935 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
936
937 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
938
939 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
940 "wrong! X should contain multiple T");
941
942#if CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
943 bool constexpr workaround_int8_ds_write_issue = true;
944#else
945 bool constexpr workaround_int8_ds_write_issue = false;
946#endif
947
948 i += linear_offset; // simplicity
949 if constexpr(std::is_same_v<typename vector_traits<remove_cvref_t<T>>::scalar_type,
950 int8_t> &&
951 workaround_int8_ds_write_issue)
952 {
953 if(is_valid_element)
954 {
955 // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
956 // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
957 // ds_write_b128
958 // TODO: remove this after compiler fix
959 // clang-format off
960 static_assert(
961 (std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8_t>) ||
962 (std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
963 (std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
964 (std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
965 (std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
966 (std::is_same_v<remove_cvref_t<T>, int8x4_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
967 (std::is_same_v<remove_cvref_t<T>, int8x8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
968 (std::is_same_v<remove_cvref_t<T>, int8x16_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
969 // int8 on thread buffer
970 (std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
971 (std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
972 (std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
973 (std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
974 (std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
975 // ext_vector_type for pk_int4 must use int8_t as type
976 (std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
977 (std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
978 (std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
979 (std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
980 (std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
981 (std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
982 (std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
983 (std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
984 "wrong! not implemented for this combination, please add "
985 "implementation");
986 // clang-format on
987
988 if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
989 std::is_same_v<remove_cvref_t<X>, int8_t>) ||
990 (std::is_same_v<remove_cvref_t<T>, int8_t> &&
991 std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
992 (std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
993 std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>))
994 {
995 // HACK: cast pointer of x is bad
996 // TODO: remove this after compiler fix
999 }
1000 else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1001 std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
1002 (std::is_same_v<remove_cvref_t<T>, int8_t> &&
1003 std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
1004 (std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
1005 std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>))
1006 {
1007 // HACK: cast pointer of x is bad
1008 // TODO: remove this after compiler fix
1011 }
1012 else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1013 std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
1014 (std::is_same_v<remove_cvref_t<T>, int8_t> &&
1015 std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
1016 (std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
1017 std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
1018 {
1019 // HACK: cast pointer of x is bad
1020 // TODO: remove this after compiler fix
1023 }
1024 else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1025 std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
1026 (std::is_same_v<remove_cvref_t<T>, int8_t> &&
1027 std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
1028 (std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
1029 std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
1030 {
1031 // HACK: cast pointer of x is bad
1032 // TODO: remove this after compiler fix
1035 }
1036 else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
1037 std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
1038 (std::is_same_v<remove_cvref_t<T>, int8_t> &&
1039 std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
1040 (std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
1041 std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
1042 {
1043 // HACK: cast pointer of x is bad
1044 // TODO: remove this after compiler fix
1047 }
1048 else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
1049 std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
1050 (std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
1051 std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>))
1052 {
1053 // HACK: cast pointer of x is bad
1054 // TODO: remove this after compiler fix
1057 }
1058 else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
1059 std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
1060 (std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
1061 std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>))
1062 {
1063 // HACK: cast pointer of x is bad
1064 // TODO: remove this after compiler fix
1067 }
1068 else if constexpr((std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
1069 std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
1070 (std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
1071 std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
1072 {
1073 // HACK: cast pointer of x is bad
1074 // TODO: remove this after compiler fix
1077 }
1078 }
1079 }
1080 else
1081 {
1082 if(is_valid_element)
1083 {
1084#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
1085 X tmp = x;
1086
1087 __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
1088#else
1090 scalar_per_t_vector * scalar_per_x_vector>;
1091
1092 *c_style_pointer_cast<buf_t*>(&p_data_[i]) = reinterpret_cast<const buf_t&>(x);
1093#endif
1094 }
1095 }
1096 }
1097
1098 // FIXME: remove
1099 CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
1100
1101 // FIXME: remove
1102 CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
1103};
1104
1105// Address Space: Vgpr
1106// T may be scalar or vector
1107// X may be scalar or vector
1108// T and X have same scalar type
1109// X contains multiple T
1110// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
1111// transforms of tensor_view/Tensor
1112template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
1114 T,
1115 BufferSizeType,
1116 InvalidElementUseNumericalZeroValue,
1118{
1119 using type = T;
1120
1121 T* p_data_ = nullptr;
1122 BufferSizeType buffer_size_;
1124
1129
1130 CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
1131 : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
1132 {
1133 }
1134
1135 CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
1136 BufferSizeType buffer_size,
1137 T invalid_element_value)
1138 : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
1139 {
1140 }
1141
1143
1148
1149 // i is offset of T
1150 // FIXME: doesn't do is_valid check
1151 CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }
1152
1153 // i is offset of T
1154 // FIXME: doesn't do is_valid check
1155 CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }
1156
1157 // i is offset of T, not X. i should be aligned to X
1158 template <typename X,
1159 bool oob_conditional_check = true,
1160 typename std::enable_if<
1161 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
1162 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
1163 bool>::type = false>
1164 CK_TILE_DEVICE constexpr auto get(index_t i,
1165 index_t /*linear_offset*/,
1166 bool is_valid_element,
1168 {
1169 // X contains multiple T
1170 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
1171
1172 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
1173
1174 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
1175 "wrong! X should contain multiple T");
1176
1177 if(is_valid_element)
1178 {
1179#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
1180 X tmp;
1181
1182 __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
1183
1184 return tmp;
1185#else
1186 return *c_style_pointer_cast<const X*>(&p_data_[i]);
1187#endif
1188 }
1189 else
1190 {
1191 if constexpr(InvalidElementUseNumericalZeroValue)
1192 {
1193 return X{numeric<remove_cvref_t<T>>::zero()};
1194 }
1195 else
1196 {
1197 return X{invalid_element_value_};
1198 }
1199 }
1200 }
1201
1202 // i is offset of T, not X. i should be aligned to X
1203 template <memory_operation_enum Op,
1204 typename X,
1205 typename std::enable_if<
1206 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
1207 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
1208 bool>::type = false>
1209 CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
1210 {
1211 if constexpr(Op == memory_operation_enum::set)
1212 {
1213 this->template set<X>(i, linear_offset, is_valid_element, x);
1214 }
1215 // FIXME: remove memory_operation_enum::add
1216 else if constexpr(Op == memory_operation_enum::add)
1217 {
1218 auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
1219 this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
1220 }
1221 }
1222
1223 // i is offset of T, not X. i should be aligned to X
1224 template <typename X,
1225 typename std::enable_if<
1226 std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
1227 typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
1228 bool>::type = false>
1229 CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
1230 {
1231 // X contains multiple T
1232 constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
1233
1234 constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
1235
1236 static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
1237 "wrong! X should contain multiple T");
1238
1239 if(is_valid_element)
1240 {
1241#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
1242 X tmp = x;
1243
1244 __builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
1245#else
1246 *c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
1247#endif
1248 }
1249 }
1250
1251 // FIXME: remove
1252 CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }
1253
1254 // FIXME: remove
1255 CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
1256};
1257
// Factory overload taking only the data pointer and the buffer size; T and
// BufferSizeType are deducible from the arguments, BufferAddressSpace must be
// given explicitly.
// NOTE(review): the embedded listing numbers jump 1258 -> 1260 and 1263 -> 1265,
// so one template parameter and the statement inside the braces are not
// visible in this text. Restore them from the upstream header before relying
// on this definition.
1258template <address_space_enum BufferAddressSpace,
1260 typename T,
1261 typename BufferSizeType>
1262CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size)
1263{
1265}
1266
// Factory overload that additionally takes invalid_element_value (presumably
// the value a view hands back for invalid elements -- TODO confirm against the
// buffer_view constructors). The enable_if below admits this overload only when
// T and X are the same type once cv/ref qualifiers are removed.
// NOTE(review): the embedded listing numbers jump 1267 -> 1269 and 1276 -> 1278,
// so one template parameter and the head of the return statement are not
// visible in this text; only its brace-initialiser tail
// (p, buffer_size, invalid_element_value) survives. Restore the missing lines
// from the upstream header before relying on this definition.
1267template <address_space_enum BufferAddressSpace,
1269 typename T,
1270 typename BufferSizeType,
1271 typename X,
1272 typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
1273 bool>::type = false>
1274CK_TILE_HOST_DEVICE constexpr auto
1275make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size, X invalid_element_value)
1276{
1278 p, buffer_size, invalid_element_value};
1279}
1280
1281// Generalized print function for all buffer_view variants
1282template <address_space_enum BufferAddressSpace,
1283 typename T,
1284 typename BufferSizeType,
1285 bool InvalidElementUseNumericalZeroValue,
1286 amd_buffer_coherence_enum Coherence>
1287CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
1288 T,
1289 BufferSizeType,
1290 InvalidElementUseNumericalZeroValue,
1291 Coherence>& bv)
1292{
1293 printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
1294 address_space_to_string(BufferAddressSpace),
1295 static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
1296 print(bv.buffer_size_);
1297 printf(", invalid_element_value_: ");
1298 print(bv.invalid_element_value_);
1299 printf("}");
1300}
1301
1302} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_LDS_ADDR
Definition config.hpp:58
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
int8_t int8x8_t
Definition vector_type.hpp:192
CK_TILE_DEVICE void atomic_add(X *p_dst, const X &x)
int8_t int8x2_t
Definition pk_int4.hpp:103
CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer< T, N > &src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition tile/core/arch/amd_buffer_addressing.hpp:2874
int8_t pk_int4x16_t
Definition vector_type.hpp:249
int8_t pk_int4x4_t
Definition vector_type.hpp:247
_Float16 half_t
Definition half.hpp:111
CK_TILE_DEVICE void atomic_add_g(T *p_dst, const thread_buffer< T, N > &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:358
memory_operation_enum
Definition arch.hpp:56
@ add
Definition arch.hpp:60
@ atomic_add
Definition arch.hpp:58
@ atomic_max
Definition arch.hpp:59
@ set
Definition arch.hpp:57
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
int8_t int8_t
Definition int8.hpp:20
ushort bfloat16_t
Definition bfloat16.hpp:111
CK_TILE_DEVICE void amd_buffer_store(const thread_buffer< T, N > &src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition tile/core/arch/amd_buffer_addressing.hpp:2738
amd_buffer_coherence_enum
Definition tile/core/arch/amd_buffer_addressing.hpp:1404
@ coherence_default
Definition tile/core/arch/amd_buffer_addressing.hpp:1405
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T *smem, const int32x4_t src_wave_buffer_resource, index_t src_thread_element_offset, index_t src_linear_element_offset, bool is_valid_element, bool_constant< oob_conditional_check >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:2711
int32_t int32x4_t
Definition vector_type.hpp:155
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T *smem, const T *p_src_wave, index_t src_thread_element_offset, index_t src_linear_element_offset, index_t src_element_space_size, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:2663
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer< T, N > &src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const index_t dst_linear_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:2835
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
int8_t int8x16_t
Definition vector_type.hpp:193
CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
Definition type_traits.hpp:104
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer< T, N > &dst, const T *p_src_wave, index_t src_thread_element_offset, index_t src_linear_element_offset, index_t src_element_space_size, index_t is_valid_element=0, bool_constant< pre_nop >={})
Definition tile/core/arch/amd_buffer_addressing.hpp:2605
CK_TILE_DEVICE thread_buffer< T, N > amd_buffer_load_invalid_element_return_zero(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size)
Definition tile/core/arch/amd_buffer_addressing.hpp:2542
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer< T, N > &src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const index_t dst_linear_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition tile/core/arch/amd_buffer_addressing.hpp:2779
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void *ptr, uint32_t size=0xffffffff, ForceSGPR={})
Definition tile/core/arch/amd_buffer_addressing.hpp:97
address_space_enum
Definition arch.hpp:46
@ generic
Definition arch.hpp:47
@ global
Definition arch.hpp:48
@ lds
Definition arch.hpp:49
@ vgpr
Definition arch.hpp:52
CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer< T, N > &src_thread_data, T *p_dst_wave, const index_t dst_thread_element_offset, const bool dst_thread_element_valid, const index_t dst_element_space_size)
Definition tile/core/arch/amd_buffer_addressing.hpp:2805
int8_t int8x4_t
Definition vector_type.hpp:191
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE thread_buffer< T, N > amd_buffer_load_invalid_element_return_customized_value(const T *p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size, T customized_value)
Definition tile/core/arch/amd_buffer_addressing.hpp:2580
int8_t pk_int4x8_t
Definition vector_type.hpp:248
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T *__restrict__ p, BufferSizeType buffer_size)
Definition buffer_view.hpp:1262
CK_TILE_DEVICE void atomic_max_g(T *p_dst, const thread_buffer< T, N > &x)
Definition tile/core/arch/generic_memory_space_atomic.hpp:479
CK_TILE_HOST_DEVICE constexpr const char * address_space_to_string(address_space_enum addr_space)
Helper function to convert address space enum to string.
Definition arch.hpp:338
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:183
CK_TILE_DEVICE constexpr auto transpose_get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition buffer_view.hpp:144
CK_TILE_DEVICE constexpr auto get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition buffer_view.hpp:96
CK_TILE_HOST_DEVICE constexpr buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition buffer_view.hpp:62
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:163
CK_TILE_HOST_DEVICE constexpr buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition buffer_view.hpp:67
CK_TILE_DEVICE constexpr auto transpose_get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition buffer_view.hpp:370
CK_TILE_HOST_DEVICE constexpr buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition buffer_view.hpp:251
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x, bool_constant< oob_conditional_check >={})
Definition buffer_view.hpp:476
CK_TILE_DEVICE void update_raw(index_t i, index_t linear_offset, bool is_valid_element, const X &x, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition buffer_view.hpp:517
static CK_TILE_DEVICE constexpr address_space_enum get_address_space()
Definition buffer_view.hpp:268
CK_TILE_HOST_DEVICE constexpr buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition buffer_view.hpp:243
CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t< T > *smem, index_t i, index_t linear_offset, bool, bool_constant< pre_nop >={}) const
Definition buffer_view.hpp:449
CK_TILE_DEVICE void atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:613
CK_TILE_DEVICE constexpr auto get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition buffer_view.hpp:288
CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t< T > *smem, index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition buffer_view.hpp:416
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t< X > &dst, index_t v_offset, index_t i_offset, bool is_valid_element, bool_constant< pre_nop >={}) const
Definition buffer_view.hpp:390
CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:591
CK_TILE_DEVICE constexpr const T & operator[](index_t i) const
Definition buffer_view.hpp:275
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:546
CK_TILE_DEVICE void atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:708
CK_TILE_DEVICE void atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:677
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t< X > &dst, index_t v_offset, index_t i_offset, bool, bool_constant< pre_nop >={}) const
Definition buffer_view.hpp:856
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:932
CK_TILE_DEVICE constexpr auto transpose_get(index_t i, index_t linear_offset, bool is_valid_element) const
Definition buffer_view.hpp:870
CK_TILE_DEVICE constexpr auto get(index_t i, index_t linear_offset, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition buffer_view.hpp:806
CK_TILE_HOST_DEVICE constexpr buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition buffer_view.hpp:777
CK_TILE_HOST_DEVICE constexpr buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition buffer_view.hpp:772
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:912
CK_TILE_DEVICE constexpr auto get(index_t i, index_t, bool is_valid_element, bool_constant< oob_conditional_check >={}) const
Definition buffer_view.hpp:1164
CK_TILE_HOST_DEVICE constexpr buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size)
Definition buffer_view.hpp:1130
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:1229
CK_TILE_HOST_DEVICE constexpr buffer_view(T *__restrict__ p_data, BufferSizeType buffer_size, T invalid_element_value)
Definition buffer_view.hpp:1135
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X &x)
Definition buffer_view.hpp:1209
Definition buffer_view.hpp:35
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/numeric/numeric.hpp:18
Definition pk_int4.hpp:21
Definition tile/core/arch/amd_buffer_addressing.hpp:895
Definition tile/core/utility/debug.hpp:67
Definition vector_type.hpp:90