layout_utils.hpp Source File

layout_utils.hpp Source File#

Composable Kernel: layout_utils.hpp Source File
layout_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/ck.hpp"
7
14
19
20// Disable from doxygen docs generation
22namespace ck {
23namespace wrapper {
25
26// Disable from doxygen docs generation
28// forward declaration
29template <typename Shape, typename UnrolledDescriptorType>
30struct Layout;
31
/**
 * \brief Detection expression naming the type returned by `IsTuple()` when
 *        called on an lvalue of `Candidate`.
 *
 * The alias is only well-formed when that call compiles, so it is intended to
 * be probed through a detection idiom (this file uses
 * `is_detected<is_tuple, X>`).
 */
template <typename Candidate>
using is_tuple = decltype(std::declval<Candidate&>().IsTuple());
34
35namespace {
36namespace detail {
43template <typename... Ts>
44__host__ __device__ constexpr static auto
45GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
46{
47 const auto unrolled_shape = UnrollNestedTuple(shape);
48 return generate_tuple(
49 [&](auto i) {
50 if constexpr(i.value == 0)
51 {
52 return Number<1>{};
53 }
54 else
55 {
56 return TupleReduce<Number<0>{}.value, i.value>([](auto x, auto y) { return x * y; },
57 unrolled_shape);
58 }
59 },
60 Number<decltype(unrolled_shape)::Size()>{});
61}
62
70template <typename LayoutShape, typename LayoutStrides>
71__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape,
72 const LayoutStrides& strides)
73{
74 const auto unrolled_shape = UnrollNestedTuple(shape);
75 if constexpr(is_same_v<LayoutStrides, Tuple<>>)
76 {
77 // if not passed, then generate
78 const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
79 static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
80 "Size of strides and shape are not consistent.");
81 return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
82 }
83 else
84 {
85 const auto unrolled_strides = UnrollNestedTuple(strides);
86 static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
87 "Size of strides and shape are not consistent.");
88 return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
89 }
90}
91} // namespace detail
92} // namespace
93
95
96// make_*
104template <typename Shape, typename Strides>
105__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
106{
107 using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
109 detail::MakeUnrolledDescriptor(shape, strides));
110}
111
119template <typename Shape>
120__host__ __device__ constexpr auto make_layout(const Shape& shape)
121{
122 using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
124 detail::MakeUnrolledDescriptor(shape, Tuple<>{}));
125}
126// Layout helpers
127// get
135template <typename T>
136__host__ __device__ T constexpr get(const T& dim)
137{
138 return dim;
139}
140
148template <index_t idx, typename... Dims>
149__host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
150{
151 return tuple.At(Number<idx>{});
152}
153
161template <index_t idx, typename Shape, typename UnrolledDesc>
162__host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout)
163{
164 const auto& shape = layout.GetShape();
165 const auto new_shape = get<idx>(shape);
166 static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
167 "Shape of sub layout must be tuple");
168
169 constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
170 constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
171 constexpr auto shape_offset = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();
172
173 const auto unrolled_shape = UnrollNestedTuple(shape);
174 const auto transforms = generate_tuple(
175 [&](auto i) {
176 // Compare Idx with shape
177 if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
178 {
179 // Remove dimension
180 return make_freeze_transform(Number<0>{});
181 }
182 else
183 {
184 return make_pass_through_transform(unrolled_shape.At(i));
185 }
186 },
187 Number<old_shape_dims>{});
188
189 const auto lower_dims =
190 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
191 const auto upper_dims = generate_tuple(
192 [&](auto i) {
193 if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
194 return Sequence<>{};
195
196 else
197 {
198 return Sequence<i.value - shape_offset>{};
199 }
200 },
201 Number<old_shape_dims>{});
202
203 const auto& flatten_desc = layout.GetUnrolledDescriptor();
204 auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
205 return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
206}
207
215template <index_t Idx, index_t... Idxs, typename T>
216__host__ __device__ constexpr auto get(const T& elem)
217{
218 return get<Idxs...>(get<Idx>(elem));
219}
220
221// size
229template <typename T>
230__host__ __device__ T constexpr size(const T& dim)
231{
232 return dim;
233}
234
242template <index_t idx, typename Shape, typename UnrolledDescriptorType>
243__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
244{
245 return layout.template GetLength<idx>();
246}
247
254template <typename... ShapeDims>
255__host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
256{
257 const auto unrolled_shape = UnrollNestedTuple(shape);
258 return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
259 unrolled_shape);
260}
261
268template <typename Shape, typename UnrolledDescriptorType>
269__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
270{
271 return layout.GetLengths();
272}
273
281template <index_t idx, typename... Ts>
282__host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
283{
284 return size(tuple.At(Number<idx>{}));
285}
286
295template <index_t Idx, index_t... Idxs, typename T>
296__host__ __device__ constexpr auto size(const T& elem)
297{
298 return size(get<Idx, Idxs...>(elem));
299}
300
301// rank
308template <typename Shape, typename UnrolledDescriptorType>
309__host__ __device__ constexpr auto
311{
312 return Shape::Size();
313}
314
322template <typename... Dims>
323__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& tuple)
324{
325 return Tuple<Dims...>::Size();
326}
327
335template <index_t IDim>
336__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
337{
338 return 1;
339}
340
348__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; }
349
357template <index_t... Idxs, typename T>
358__host__ __device__ constexpr auto rank(const T& elem)
359{
360 return rank(get<Idxs...>(elem));
361}
362
363// depth
370template <typename Shape, typename UnrolledDescriptorType>
371__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
372{
373 const auto& shape = layout.GetShape();
374 return TupleDepth(shape);
375}
376
383template <typename... Dims>
384__host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
385{
386 return TupleDepth(tuple);
387}
388
396template <index_t IDim>
397__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
398{
399 return 0;
400}
401
409__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; }
410
418template <index_t... Idxs, typename T>
419__host__ __device__ constexpr auto depth(const T& elem)
420{
421 return depth(get<Idxs...>(elem));
422}
423
430template <typename LayoutType>
431__host__ __device__ constexpr const auto& shape(const LayoutType& layout)
432{
433 return layout.GetShape();
434}
435
436// pad
445template <typename Shape, typename UnrolledDesc, typename TileLengths>
446__host__ __device__ constexpr auto pad(const Layout<Shape, UnrolledDesc>& layout,
447 const TileLengths& tile_lengths)
448{
449 auto& unrolled_desc = layout.GetUnrolledDescriptor();
450 // Generate sequence with ones to mark that all dims will be padded
451 constexpr auto do_pads_seq =
452 generate_sequence_v2([](auto) { return Number<1>{}; }, Number<Shape::Size()>{});
453 // Create descriptor with padding
454 auto padded_desc =
455 tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
456 // Generate padded shape
457 const auto padded_shape = generate_tuple(
458 [&](auto i) { return padded_desc.GetLength(Number<i>{}); }, Number<TileLengths::Size()>{});
459 // Create layout
460 return Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
461}
462
463// unmerge
473template <index_t Idx, typename Shape, typename UnrolledDesc, typename NewLengths, typename NewIdxs>
474__host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& layout,
475 const NewLengths& new_lengths,
476 [[maybe_unused]] const NewIdxs& new_indexes)
477{
478 const auto& layout_shape = shape(layout);
479 auto& unrolled_desc = layout.GetUnrolledDescriptor();
480 constexpr auto dims = Shape::Size();
481 // Generate transforms
482 const auto transforms = generate_tuple(
483 [&](auto i) {
484 if constexpr(i == Idx)
485 {
486 return make_unmerge_transform(new_lengths);
487 }
488 else
489 {
490 return make_pass_through_transform(layout_shape.At(i));
491 }
492 },
493 Number<dims>{});
494
495 constexpr auto lower_dims =
496 generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
497 constexpr auto upper_dims = generate_tuple(
498 [&](auto i) {
499 if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
500 {
501 constexpr auto idxs_tuple = tuple_element_t<i.value, NewIdxs>{};
502 return to_sequence(idxs_tuple);
503 }
504 else
505 {
506 constexpr index_t index = tuple_element_t<i.value, NewIdxs>{};
507 return Sequence<index>{};
508 }
509 },
510 Number<dims>{});
511
512 const auto unmerged_desc =
513 transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims);
514 const auto unmerged_shape =
515 generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number<i>{}); },
516 Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});
517
518 // Create layout
519 return Layout<decltype(unmerged_shape), decltype(unmerged_desc)>(unmerged_shape, unmerged_desc);
520}
521
522} // namespace wrapper
523} // namespace ck
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition helper.hpp:70
__host__ __device__ constexpr const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition layout_utils.hpp:431
__host__ __device__ constexpr auto depth(const Layout< Shape, UnrolledDescriptorType > &layout)
Get depth of the layout shape (return 0 if scalar).
Definition layout_utils.hpp:371
__host__ __device__ constexpr auto unmerge(const Layout< Shape, UnrolledDesc > &layout, const NewLengths &new_lengths, const NewIdxs &new_indexes)
Unmerge selected dim in layout.
Definition layout_utils.hpp:474
__host__ __device__ constexpr auto make_layout(const Shape &shape, const Strides &strides)
Make layout function.
Definition layout_utils.hpp:105
__host__ __device__ constexpr auto rank(const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition layout_utils.hpp:310
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<> &element)
Definition tuple_helper.hpp:120
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24
__host__ __device__ constexpr const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition tensor_utils.hpp:162