layout.hpp Source File

layout.hpp Source File

Composable Kernel: layout.hpp Source File
layout.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8// Disable from doxygen docs generation
10namespace ck {
11namespace wrapper {
13
22template <typename Shape, typename UnrolledDescriptorType>
23struct Layout
24{
25 // Disable from doxygen docs generation
27 private:
28 static constexpr auto I0 = Number<0>{};
29 static constexpr auto I1 = Number<1>{};
30
37 template <typename... Ts>
38 __host__ __device__ constexpr static auto
39 GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
40 {
41 return generate_tuple(
42 [&](auto) {
43 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
44 {
45 // runtime layout
46 return index_t(0);
47 }
48 else
49 {
50 // compiletime layout
51 return I0;
52 }
53 },
54 Number<Tuple<Ts...>::Size()>{});
55 }
56
66 template <typename Idx, typename... Ts>
67 __host__ __device__ constexpr static auto
68 GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
69 {
70 if constexpr(Idx::value == 0)
71 {
72 if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
73 {
74 // Return Sequence for the first tuple
75 constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
76 tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
77 using LowerDimsSequence =
78 typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
79 return LowerDimsSequence::Reverse();
80 }
81 else
82 {
83 // Return first element
84 return Sequence<0>{};
85 }
86 }
87 else
88 {
89 // Get previous element using recurence (in compile-time)
90 using PreviousSeqT = decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
91 const auto next_seq_val = PreviousSeqT::At(I0) + 1;
92 if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
93 {
94 constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
95 tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
96 using LowerDimsSequence =
97 typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
98 type;
99 return LowerDimsSequence::Reverse();
100 }
101 else
102 {
103 return Sequence<next_seq_val>{};
104 }
105 }
106 }
107
119 template <typename... ShapeDims, typename... IdxDims>
120 __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
121 const Tuple<IdxDims...>& idx)
122 {
123 if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
124 {
125 // Index unrolled to flatten, return shape
126 return shape;
127 }
128 else
129 {
130 // Iterate over shape tuple elements:
131 // 1. If corresponding idx element is tuple then return (will be unrolled)
132 // 2. If no, pack in tuple. It will be restored during unroll.
133 auto aligned_shape = generate_tuple(
134 [&](auto i) {
135 if constexpr(is_detected<is_tuple,
136 tuple_element_t<i, Tuple<IdxDims...>>>::value)
137 {
138 return shape.At(i);
139 }
140 else
141 {
142 return make_tuple(shape.At(i));
143 }
144 },
145 Number<Tuple<IdxDims...>::Size()>{});
146
147 // Unroll and process next step
148 return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape),
149 UnrollNestedTuple<0, 1>(idx));
150 }
151 }
152
160 template <typename... ShapeDims, typename DescriptorToMerge>
161 __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
162 const DescriptorToMerge& desc)
163 {
164 // Reverse each element in tuple
165 const auto merge_elems = TupleReverse(UnrollNestedTuple(shape));
166 // Generate reverted indexes (column major traverse)
167 using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type;
168 const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
169 const auto upper_dims = make_tuple(Sequence<0>{});
170 // Merge to 1d
171 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
172 {
173 return transform_tensor_descriptor(
174 desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
175 }
176 else
177 {
178 // If the descriptor is known at the compilation time,
179 // use `make_merge_transform_v1_carry_check` because it doesn't use
180 // memcpy.
181 return transform_tensor_descriptor(
182 desc,
183 make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
184 lower_dims,
185 upper_dims);
186 }
187 }
188
201 template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
202 __host__ __device__ constexpr static auto
203 CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
204 [[maybe_unused]] const Tuple<IdxDims...>& idxs,
205 DescriptorToMerge& desc)
206 {
207 const auto transforms = generate_tuple(
208 [&](auto i) {
209 // Compare Idx with shape
210 if constexpr(is_detected<is_tuple,
211 tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
212 !is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value)
213 {
214 // If shape element is tuple and idx element is Number, then merge
215 // Unroll and reverse tuple to traverse column-major
216 const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
217 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
218 {
219 return make_merge_transform(merge_elems);
220 }
221 else
222 {
223 // If the descriptor is known at the compilation time,
224 // use `make_merge_transform_v1_carry_check` because
225 // it doesn't use memcpy.
226 return make_merge_transform_v1_carry_check(merge_elems);
227 }
228 }
229 else
230 {
231 // If shape element is integer and idx element is tuple, passed idx is wrong
232 static_assert(
233 !(!is_detected<is_tuple, tuple_element_t<i, Tuple<ShapeDims...>>>::value &&
234 is_detected<is_tuple, tuple_element_t<i, Tuple<IdxDims...>>>::value),
235 "Wrong Idx for layout()");
236 // If shape element has the same type as idx element, then pass through
237 return make_pass_through_transform(shape.At(i));
238 }
239 },
240 Number<Tuple<ShapeDims...>::Size()>{});
241
242 const auto lower_dims =
243 generate_tuple([&](auto i) { return GenerateLowerDim<Number<i>>(shape); },
244 Number<Tuple<ShapeDims...>::Size()>{});
245 const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
246 Number<Tuple<ShapeDims...>::Size()>{});
247
248 return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
249 }
250
251 using Descriptor1dType =
252 remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;
253 using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;
255
256 public:
257 using LayoutShape = Shape;
258 using LayoutUnrolledDescriptorType = UnrolledDescriptorType;
259
268 template <typename... ShapeDims, typename... IdxDims>
269 __host__ __device__ constexpr static auto
270 TransformDesc(const Tuple<ShapeDims...>& shape,
271 const Tuple<IdxDims...>& idxs,
272 const UnrolledDescriptorType& naive_descriptor)
273 {
274 if constexpr(Tuple<IdxDims...>::Size() == I1)
275 {
276 // 1d idx path
277 return MakeMerge1d(shape, naive_descriptor);
278 }
279 else
280 {
281 // Merge nested shape dims
282 // Example idx: (1, 1), 1, 1
283 // Example shape: (2, (2, 2)), 2, (2, 2)
284 // Merged shape: (2, 4), 2, 4
285 static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
286 "Idx rank and Shape rank must be the same (except 1d).");
287 // Unroll while IdxDims is nested
288 const auto aligned_shape = AlignShapeToIdx(shape, idxs);
289 // Transform correct form of shape
290 return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
291 }
292 }
293
294 using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
295 Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;
296
297 __host__ __device__ constexpr auto GetElementSpaceSize() const
298 {
299 return unrolled_descriptor_.GetElementSpaceSize();
300 }
301
302 __host__ __device__ Layout() = delete;
303
310 __host__ __device__ constexpr Layout(const Shape& shape,
311 const UnrolledDescriptorType& unnested_descriptor)
312 : unrolled_descriptor_(unnested_descriptor), shape_(shape)
313 {
314 // Construct if runtime mode
315 if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
316 {
317 descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
318 merged_nests_descriptor_ =
319 TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
320 }
321 }
322
329 template <typename Idxs>
330 __host__ __device__ constexpr index_t operator()() const
331 {
332 static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
333 "Compiletime operator used on runtime layout.");
334 using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
335 using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
336 return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
337 }
338
345 template <typename... Ts>
346 __host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
347 {
348 if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == 1)
349 {
350 // if 1d access
351 return descriptor_1d_.CalculateOffset(Idx);
352 }
353 else if constexpr(!IsNestedTuple(Tuple<Ts...>{}) && Tuple<Ts...>::Size() == Shape::Size())
354 {
355 // if Shape::Size() access (merged nested shapes)
356 return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx));
357 }
358 else
359 {
360 // Custom index, need to transform descriptor
361 const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
362 return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
363 }
364 }
365
372 template <index_t IDim>
373 __host__ __device__ constexpr auto GetLength() const
374 {
375 const auto elem = shape_.At(Number<IDim>{});
376 if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
377 {
378 const auto unrolled_element = UnrollNestedTuple(elem);
379 return TupleReduce<I0.value, unrolled_element.Size()>(
380 [](auto x, auto y) { return x * y; }, unrolled_element);
381 }
382 else
383 {
384 return elem;
385 }
386 }
387
393 __host__ __device__ constexpr auto GetLengths() const
394 {
395 const auto unrolled_shape = UnrollNestedTuple(shape_);
396 return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
397 unrolled_shape);
398 }
399
405 __host__ __device__ constexpr const Shape& GetShape() const { return shape_; }
406
412 __host__ __device__ constexpr auto GetDefaultLengthsTuple() const
413 {
414 return generate_tuple([&](auto i) { return GetLength<i>(); }, Number<Shape::Size()>{});
415 }
416
422 __host__ __device__ constexpr auto GetDefaultStartIdxs() const
423 {
424 return GenerateDefaultIdxsTuple(shape_);
425 }
426
436 __host__ __device__ constexpr const MergedNestsDescriptorType&
438 {
439 return merged_nests_descriptor_;
440 }
441
449 __host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const
450 {
451 return descriptor_1d_;
452 }
453
461 __host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const
462 {
463 return unrolled_descriptor_;
464 }
465
466 // Disable from doxygen docs generation
468 private:
469 // All dimensions are unrolled
470 UnrolledDescriptorType unrolled_descriptor_;
471 // 1D descriptor
472 Descriptor1dType descriptor_1d_;
473 // All nesting are merged
474 MergedNestsDescriptorType merged_nests_descriptor_;
475 // Example, shape: ((2, 2), 2)
476 // UnrolledDescriptorType lengths: (2, 2, 2)
477 // Descriptor1dType lengths: (8)
478 // MergedNestsDescriptorType lengths: (4, 2)
479 const Shape shape_;
481};
482
483} // namespace wrapper
484} // namespace ck
__host__ __device__ constexpr const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition layout_utils.hpp:431
Definition ck.hpp:268
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
__host__ __device__ Layout()=delete
__host__ __device__ constexpr auto GetLength() const
Length getter (product if tuple).
Definition layout.hpp:373
Shape LayoutShape
Definition layout.hpp:257
__host__ __device__ constexpr auto GetDefaultLengthsTuple() const
Get default lengths (tuple filled with Shape length elements).
Definition layout.hpp:412
__host__ __device__ constexpr const Descriptor1dType & Get1DDescriptor() const
Get descriptor with all dimensions merged (1D). Example, shape: ((2, 2), 2) Descriptor lengths: (8).
Definition layout.hpp:449
__host__ __device__ constexpr const UnrolledDescriptorType & GetUnrolledDescriptor() const
Get unnested descriptor (with unrolled dims). Example, shape: ((2, 2), 2) Descriptor lengths: (2, 2, 2).
Definition layout.hpp:461
remove_cvref_t< decltype(TransformDesc( Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))> MergedNestsDescriptorType
Definition layout.hpp:294
UnrolledDescriptorType LayoutUnrolledDescriptorType
Definition layout.hpp:258
__host__ __device__ constexpr index_t operator()() const
Returns real offset to element at compile time.
Definition layout.hpp:330
__host__ __device__ static constexpr auto TransformDesc(const Tuple< ShapeDims... > &shape, const Tuple< IdxDims... > &idxs, const UnrolledDescriptorType &naive_descriptor)
Transform descriptor to align to passed indexes.
Definition layout.hpp:270
__host__ __device__ constexpr auto GetDefaultStartIdxs() const
Get default start idx (tuple filled with 0s of the same size as Shape).
Definition layout.hpp:422
__host__ __device__ constexpr const MergedNestsDescriptorType & GetMergedNestingDescriptor() const
Get descriptor with all nested dimensions merged. Example, shape: ((2, 2), 2) Descriptor lengths: (4, 2).
Definition layout.hpp:437
__host__ __device__ index_t operator()(const Tuple< Ts... > &Idx) const
Returns real offset to element at runtime.
Definition layout.hpp:346
__host__ __device__ constexpr auto GetLengths() const
Layout size getter (product of shape).
Definition layout.hpp:393
__host__ __device__ constexpr Layout(const Shape &shape, const UnrolledDescriptorType &unnested_descriptor)
Layout constructor.
Definition layout.hpp:310
__host__ __device__ constexpr auto GetElementSpaceSize() const
Definition layout.hpp:297
__host__ __device__ constexpr const Shape & GetShape() const
Shape getter.
Definition layout.hpp:405