tuple_helper.hpp Source File

tuple_helper.hpp Source File#

Composable Kernel: tuple_helper.hpp Source File
tuple_helper.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "functional4.hpp"
7#include "tuple.hpp"
8#ifndef CK_CODE_GEN_RTC
9#include "is_detected.hpp"
10#endif
11
12namespace ck {
13
14template <typename F, index_t... ids>
15__host__ __device__ constexpr auto generate_tuple_for(F&& f, Sequence<ids...>)
16{
17 return make_tuple(f(Number<ids>{})...);
18}
19
20template <typename F, index_t N>
21__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
22{
24}
25
26template <typename F, index_t N>
27__host__ __device__ constexpr auto generate_tuple(F&& f, LongNumber<N>)
28{
29 return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); },
31}
32
33template <typename F, index_t N>
34__host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
35{
36 return unpack([&f](auto&&... xs) { return tie(f(xs)...); },
38}
39
40// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue)
41template <typename... X, typename... Y>
42__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
43 const Tuple<Y&...>& ty)
44{
45 return unpack2(
46 [&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
47 tx,
48 ty);
49}
50
51template <typename... X, typename... Y>
52__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
53{
54 return unpack2(
55 [&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
56 tx,
57 ty);
58}
59
60// Support any number of tuples to concat (also 1)
61template <typename... X>
62__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
63{
64 return tx;
65}
66
67template <typename... X, typename... Tuples>
68__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
69{
70 return concat_tuple(tx, concat_tuple(tuples...));
71}
72
73namespace detail {
74
75template <typename F, typename X, index_t... Is>
76__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
77{
78 return make_tuple(f(x.At(Number<Is>{}))...);
79}
80
81template <typename F, typename X, typename Y, index_t... Is>
82__host__ __device__ constexpr auto
83transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
84{
85 return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
86}
87
88template <typename F, typename X, typename Y, typename Z, index_t... Is>
89__host__ __device__ constexpr auto
90transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
91{
92 return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
93}
94
95} // namespace detail
96
97template <typename F, typename X>
98__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
99{
101 f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
102}
103
104template <typename F, typename X, typename Y>
105__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
106{
108 f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
109}
110
111template <typename F, typename X, typename Y, typename Z>
112__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
113{
115 f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
116}
117
118// By default unroll to the flatten
119template <index_t Depth = 0, index_t MaxDepth = -1>
120__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
121{
122 return element;
123}
124
125template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
126__host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
127{
128 return make_tuple(element);
129}
130
131template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
132__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
133{
134 if constexpr(Depth == MaxDepth)
135 {
136 return tuple;
137 }
138 else
139 {
140 return unpack(
141 [&](auto&&... ts) {
143 },
144 tuple);
145 }
146}
147
148template <typename... Ts>
149__host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
150{
151 return generate_tuple(
152 [&](auto i) {
153 using Idx = Number<Tuple<Ts...>::Size() - i - 1>;
154 return tuple.At(Idx{});
155 },
157}
158
159// Reduce tuple values in specific range using Function
160template <index_t Idx, index_t End, typename F, typename... Ts>
161__host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
162{
163 static_assert(Idx < End, "Wrong parameters for TupleReduce");
164 if constexpr(Idx + 1 == End)
165 {
166 return tuple.At(Number<Idx>{});
167 }
168 else
169 {
170 return f(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(f, tuple));
171 }
172}
173
174#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
175template <typename T>
176using is_tuple = decltype(ck::declval<T&>().IsTuple());
177#endif
178
179template <typename... Ts>
180__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
181{
182#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
183 return (is_detected<is_tuple, Ts>::value || ...);
184#endif
185}
186
187template <index_t depth = 0, typename T>
188__host__ __device__ constexpr auto TupleDepth(const T&)
189{
190 return depth;
191}
192
193template <index_t depth = 0, typename... Ts>
194__host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
195{
196 return math::max(TupleDepth<depth + 1>(Ts{})...);
197}
198
199template <index_t from, index_t to, typename... Ts>
200__host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
201{
202 return generate_tuple(
203 [&](auto i) {
204 using Idx = Number<from + i>;
205 return tuple.At(Idx{});
206 },
207 Number<to - from>{});
208}
209
210} // namespace ck
__host__ __device__ constexpr auto depth(const Layout< Shape, UnrolledDescriptorType > &layout)
Get depth of the layout shape (return 0 if scalar).
Definition layout_utils.hpp:371
__host__ __device__ constexpr auto transform_tuples_impl(F f, const X &x, Sequence< Is... >)
Definition tuple_helper.hpp:76
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition ck.hpp:268
__host__ __device__ constexpr auto generate_tuple_for(F &&f, Sequence< ids... >)
Definition tuple_helper.hpp:15
integral_constant< long_index_t, N > LongNumber
Definition number.hpp:15
__host__ __device__ constexpr auto concat_tuple(const Tuple< X... > &tx, const Tuple< Y... > &ty)
Definition tuple_helper.hpp:52
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<> &element)
Definition tuple_helper.hpp:120
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition tuple_helper.hpp:176
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition is_detected.hpp:34
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto TupleReverse(const Tuple< Ts... > &tuple)
Definition tuple_helper.hpp:149
__host__ __device__ constexpr auto unpack(F &&f, X &&x)
Definition functional4.hpp:46
__host__ __device__ constexpr auto transform_tuples(F f, const X &x)
Definition tuple_helper.hpp:98
typename __make_integer_seq< impl::__integer_sequence, index_t, N >::seq_type make_index_sequence
Definition utility/sequence.hpp:199
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto TupleReduce(F &&f, const Tuple< Ts... > &tuple)
Definition tuple_helper.hpp:161
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto TupleSlice(const Tuple< Ts... > &tuple)
Definition tuple_helper.hpp:200
__host__ __device__ constexpr auto TupleDepth(const T &)
Definition tuple_helper.hpp:188
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__host__ __device__ constexpr auto IsNestedTuple(const Tuple< Ts... > &)
Definition tuple_helper.hpp:180
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:186
Definition utility/tuple.hpp:117
__host__ __device__ constexpr const auto & At(Number< I >) const
Definition utility/tuple.hpp:141
Definition utility/sequence.hpp:256
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271