vector_type.hpp Source File

vector_type.hpp Source File

Composable Kernel: vector_type.hpp Source File
vector_type.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
17namespace ck_tile {
18
19// this structure is used to pick up the <base> type inside
20// using xxx = <base> __attribute__((ext_vector_type(N)));
21// because clang only allow native type + bool in this term (custom type will fail)
22// overload this structure to let proper <base> type
23
24template <typename T>
26{
28};
29
30// we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay
31// basic type to construct a ext_vector_type you must be very careful using this, or will have lot
32// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
33// have compiler error
34namespace impl {
35
36template <typename T_, index_t N_, typename = void>
38
39template <typename T_, index_t N_>
40struct ext_vector<T_, N_, std::enable_if_t<!std::is_class_v<typename native_t<T_>::type>>>
41{
42 static constexpr index_t N = N_;
43 // struct type is not supported for ext_vector
45 static_assert(!std::is_class_v<value_type>);
46 using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
47};
48
49template <typename T_, index_t N_>
50struct ext_vector<T_, N_, std::enable_if_t<std::is_class_v<typename native_t<T_>::type>>>
51{
52 static constexpr index_t N = N_;
53 // struct type is not supported for ext_vector
55 static_assert(!std::is_class_v<value_type>);
56 using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
57};
58
59template <typename V_, index_t Vs_, index_t N_>
60struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
61 N_,
62 std::enable_if_t<!std::is_class_v<typename native_t<V_>::type>>>
63{
64 static constexpr index_t N = Vs_ * N_;
66 static_assert(!std::is_class_v<value_type>);
67 using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
68};
69
70template <typename V_, index_t Vs_, index_t N_>
71struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))),
72 N_,
73 std::enable_if_t<std::is_class_v<typename native_t<V_>::type>>>
74{
75 static constexpr index_t N = Vs_ * N_;
76 using value_type = typename native_t<remove_cvref_t<V_>>::type::type;
77 static_assert(!std::is_class_v<value_type>);
78 using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
79};
80
81} // namespace impl
82
83template <typename T, index_t N>
85
86// by default, any type will result in a vector_size=1 with scalar_type=T traits.
87// ... unless we have other vector_traits specialization
88template <typename T, typename = void>
90{
92 std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>,
93 int8_t,
94 std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t> ||
95 std::is_same_v<remove_cvref_t<T>, e8m0_t>,
96 uint8_t,
98 static constexpr index_t vector_size = 1;
99};
100
101// specialization for ext_vector_type()
102template <typename T, index_t N>
103struct vector_traits<T __attribute__((ext_vector_type(N))), void>
104{
105 using scalar_type = std::conditional_t<
106 std::is_same_v<T, pk_int4_t>,
107 int8_t,
108 std::conditional_t<std::is_same_v<T, pk_fp4_t> || std::is_same_v<remove_cvref_t<T>, e8m0_t>,
109 uint8_t,
110 T>>;
111 static constexpr index_t vector_size = N;
112};
113
114template <typename X, typename Y>
115using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
116 typename vector_traits<remove_cvref_t<Y>>::scalar_type>;
117
// Pre-defined ext_vector_type aliases follow.
// Attention: an alias is not a distinct type, so two of these names may denote exactly the same
// vector type (e.g. pk_int4x2_t and int8x2_t are both two int8_t lanes); they cannot be used to
// tell overloads or specializations apart.

// ---- fp64 ----
using fp64_t   = double;
using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4)));

// ---- fp32 ----
using fp32_t    = float;
using fp32x2_t  = float __attribute__((ext_vector_type(2)));
using fp32x4_t  = float __attribute__((ext_vector_type(4)));
using fp32x8_t  = float __attribute__((ext_vector_type(8)));
using fp32x16_t = float __attribute__((ext_vector_type(16)));
using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));
133
134// fp16
135// using fp16_t = ...
136using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
137using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
138using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
139using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
140using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
141using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
142
143// bf16
144// using bf16_t = ...
145using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
146using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4)));
147using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8)));
148using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
149using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
150using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
151
// ---- integer vectors ----
// The scalar types (int32_t, uint32_t, int16_t, uint16_t, int8_t, uint8_t) are not declared
// here; each family below provides 2/4/8/16/32/64-lane vectors.

// i32
using int32x2_t  = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t  = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t  = int32_t __attribute__((ext_vector_type(8)));
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));

// u32
using uint32x2_t  = uint32_t __attribute__((ext_vector_type(2)));
using uint32x4_t  = uint32_t __attribute__((ext_vector_type(4)));
using uint32x8_t  = uint32_t __attribute__((ext_vector_type(8)));
using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));

// i16
using int16x2_t  = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t  = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t  = int16_t __attribute__((ext_vector_type(8)));
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));

// u16
using uint16x2_t  = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t  = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t  = uint16_t __attribute__((ext_vector_type(8)));
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));

// i8
using int8x2_t  = int8_t __attribute__((ext_vector_type(2)));
using int8x4_t  = int8_t __attribute__((ext_vector_type(4)));
using int8x8_t  = int8_t __attribute__((ext_vector_type(8)));
using int8x16_t = int8_t __attribute__((ext_vector_type(16)));
using int8x32_t = int8_t __attribute__((ext_vector_type(32)));
using int8x64_t = int8_t __attribute__((ext_vector_type(64)));

// u8
using uint8x2_t  = uint8_t __attribute__((ext_vector_type(2)));
using uint8x4_t  = uint8_t __attribute__((ext_vector_type(4)));
using uint8x8_t  = uint8_t __attribute__((ext_vector_type(8)));
using uint8x16_t = uint8_t __attribute__((ext_vector_type(16)));
using uint8x32_t = uint8_t __attribute__((ext_vector_type(32)));
using uint8x64_t = uint8_t __attribute__((ext_vector_type(64)));
205
206#if CK_TILE_USE_CUSTOM_DATA_TYPE
207// f8
208// using fp8_t
209using fp8x2_t = fp8_raw_t __attribute__((ext_vector_type(2)));
210using fp8x4_t = fp8_raw_t __attribute__((ext_vector_type(4)));
211using fp8x8_t = fp8_raw_t __attribute__((ext_vector_type(8)));
212using fp8x16_t = fp8_raw_t __attribute__((ext_vector_type(16)));
213using fp8x32_t = fp8_raw_t __attribute__((ext_vector_type(32)));
214using fp8x64_t = fp8_raw_t __attribute__((ext_vector_type(64)));
215
216// bf8
217// using bf8_t
218using bf8x2_t = bf8_raw_t __attribute__((ext_vector_type(2)));
219using bf8x4_t = bf8_raw_t __attribute__((ext_vector_type(4)));
220using bf8x8_t = bf8_raw_t __attribute__((ext_vector_type(8)));
221using bf8x16_t = bf8_raw_t __attribute__((ext_vector_type(16)));
222using bf8x32_t = bf8_raw_t __attribute__((ext_vector_type(32)));
223using bf8x64_t = bf8_raw_t __attribute__((ext_vector_type(64)));
224#else
225// f8
226// using fp8_t
227using fp8x2_t = fp8_t __attribute__((ext_vector_type(2)));
228using fp8x4_t = fp8_t __attribute__((ext_vector_type(4)));
229using fp8x8_t = fp8_t __attribute__((ext_vector_type(8)));
230using fp8x16_t = fp8_t __attribute__((ext_vector_type(16)));
231using fp8x32_t = fp8_t __attribute__((ext_vector_type(32)));
232using fp8x64_t = fp8_t __attribute__((ext_vector_type(64)));
233
234// bf8
235// using bf8_t
236using bf8x2_t = bf8_t __attribute__((ext_vector_type(2)));
237using bf8x4_t = bf8_t __attribute__((ext_vector_type(4)));
238using bf8x8_t = bf8_t __attribute__((ext_vector_type(8)));
239using bf8x16_t = bf8_t __attribute__((ext_vector_type(16)));
240using bf8x32_t = bf8_t __attribute__((ext_vector_type(32)));
241using bf8x64_t = bf8_t __attribute__((ext_vector_type(64)));
242#endif
243
// ---- pk_int4 ----
// Vectors of pk_int4_t are stored as raw int8_t lanes (vector_traits likewise maps
// pk_int4_t to int8_t).
using pk_int4x2_t  = int8_t __attribute__((ext_vector_type(2)));
using pk_int4x4_t  = int8_t __attribute__((ext_vector_type(4)));
using pk_int4x8_t  = int8_t __attribute__((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute__((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute__((ext_vector_type(32)));

// ---- pk_fp4 ----
// Vectors of pk_fp4_t are stored as raw uint8_t lanes (vector_traits likewise maps pk_fp4_t
// to uint8_t). Spelled with the canonical __attribute__ keyword like every other alias.
using pk_fp4x2_t  = uint8_t __attribute__((ext_vector_type(2)));
using pk_fp4x4_t  = uint8_t __attribute__((ext_vector_type(4)));
using pk_fp4x8_t  = uint8_t __attribute__((ext_vector_type(8)));
using pk_fp4x16_t = uint8_t __attribute__((ext_vector_type(16)));
using pk_fp4x32_t = uint8_t __attribute__((ext_vector_type(32)));
257} // namespace ck_tile
Definition tile/core/arch/amd_buffer_addressing.hpp:110
Definition tile/core/algorithm/cluster_descriptor.hpp:13
_Float16 fp16x2_t
Definition half.hpp:385
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
uint8_t uint8x64_t
Definition vector_type.hpp:204
int8_t int8x8_t
Definition vector_type.hpp:192
fp8_t fp8x4_t
Definition vector_type.hpp:228
int8_t int8x2_t
Definition pk_int4.hpp:103
uint8_t __attribute((ext_vector_type(2))) pk_fp4x2_t
Definition vector_type.hpp:252
int8_t pk_int4x16_t
Definition vector_type.hpp:249
bfloat16_t bf16x64_t
Definition vector_type.hpp:150
int8_t int8x64_t
Definition vector_type.hpp:195
int8_t pk_int4x4_t
Definition vector_type.hpp:247
int8_t int8x32_t
Definition vector_type.hpp:194
uint8_t __attribute((ext_vector_type(4))) pk_fp4x4_t
Definition vector_type.hpp:253
uint16_t uint16x2_t
Definition vector_type.hpp:181
float fp32x32_t
Definition vector_type.hpp:131
int16_t int16x4_t
Definition vector_type.hpp:173
uint32_t uint32x8_t
Definition vector_type.hpp:165
uint8_t uint8x4_t
Definition vector_type.hpp:200
_Float16 fp16x64_t
Definition vector_type.hpp:141
int8_t int8_t
Definition int8.hpp:20
int32_t int32x8_t
Definition vector_type.hpp:156
ushort bfloat16_t
Definition bfloat16.hpp:111
fp8_t fp8x32_t
Definition vector_type.hpp:231
_BitInt(8) fp8_t
Definition float8.hpp:204
bfloat16_t bf16x16_t
Definition vector_type.hpp:148
int32_t int32x4_t
Definition vector_type.hpp:155
int16_t int16x16_t
Definition vector_type.hpp:175
bfloat16_t bf16x2_t
Definition pk_fp4.hpp:24
double fp64x2_t
Definition vector_type.hpp:122
float fp32x64_t
Definition vector_type.hpp:132
int16_t int16x8_t
Definition vector_type.hpp:174
double fp64_t
Definition vector_type.hpp:121
uint32_t uint32x4_t
Definition vector_type.hpp:164
bf8_t bf8x2_t
Definition vector_type.hpp:236
double fp64x4_t
Definition vector_type.hpp:123
_Float16 fp16x4_t
Definition vector_type.hpp:137
uint8_t __attribute((ext_vector_type(16))) pk_fp4x16_t
Definition vector_type.hpp:255
uint16_t uint16x16_t
Definition vector_type.hpp:184
uint16_t uint16x64_t
Definition vector_type.hpp:186
uint8_t fp8_raw_t
Definition float8.hpp:205
pk_float4_e2m1_t pk_fp4_t
Definition pk_fp4.hpp:151
bf8_t bf8x16_t
Definition vector_type.hpp:239
bf8_t bf8x32_t
Definition vector_type.hpp:240
int8_t int8x16_t
Definition vector_type.hpp:193
uint8_t uint8x16_t
Definition vector_type.hpp:202
int16_t int16x64_t
Definition vector_type.hpp:177
bfloat16_t bf16x32_t
Definition vector_type.hpp:149
bf8_t bf8x8_t
Definition vector_type.hpp:238
uint8_t __attribute((ext_vector_type(32))) pk_fp4x32_t
Definition vector_type.hpp:256
bfloat16_t bf16x4_t
Definition vector_type.hpp:146
int32_t int32_t
Definition integer.hpp:10
int32_t int32x16_t
Definition vector_type.hpp:157
_Float16 fp16x8_t
Definition vector_type.hpp:138
e8m0_bexp_t e8m0_t
Definition tile/core/numeric/e8m0.hpp:49
fp8_t fp8x2_t
Definition vector_type.hpp:227
bfloat16_t bf16x8_t
Definition vector_type.hpp:147
int8_t pk_int4x2_t
Definition vector_type.hpp:246
uint8_t uint8x2_t
Definition vector_type.hpp:199
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
fp8_t fp8x16_t
Definition vector_type.hpp:230
uint8_t __attribute((ext_vector_type(8))) pk_fp4x8_t
Definition vector_type.hpp:254
uint8_t bf8_raw_t
Definition float8.hpp:207
uint32_t uint32x64_t
Definition vector_type.hpp:168
uint32_t uint32x32_t
Definition vector_type.hpp:167
uint32_t uint32x16_t
Definition vector_type.hpp:166
float fp32x4_t
Definition vector_type.hpp:128
float fp32x8_t
Definition vector_type.hpp:129
int16_t int16x32_t
Definition vector_type.hpp:176
uint16_t uint16x8_t
Definition vector_type.hpp:183
uint16_t uint16x4_t
Definition vector_type.hpp:182
float fp32x2_t
Definition pk_fp4.hpp:22
bf8_t bf8x4_t
Definition vector_type.hpp:237
int8_t int8x4_t
Definition vector_type.hpp:191
std::is_same< typename vector_traits< remove_cvref_t< X > >::scalar_type, typename vector_traits< remove_cvref_t< Y > >::scalar_type > has_same_scalar_type
Definition vector_type.hpp:115
_Float16 fp16x16_t
Definition vector_type.hpp:139
_Float16 fp16x32_t
Definition vector_type.hpp:140
fp8_t fp8x8_t
Definition vector_type.hpp:229
int32_t index_t
Definition integer.hpp:9
uint16_t uint16x32_t
Definition vector_type.hpp:185
int32_t int32x64_t
Definition vector_type.hpp:159
int32_t int32x2_t
Definition vector_type.hpp:154
int8_t pk_int4x8_t
Definition vector_type.hpp:248
bf8_t bf8x64_t
Definition vector_type.hpp:241
float fp32_t
Definition pk_fp4.hpp:21
float fp32x16_t
Definition vector_type.hpp:130
int16_t int16x2_t
Definition vector_type.hpp:172
uint8_t uint8x8_t
Definition vector_type.hpp:201
uint8_t uint8x32_t
Definition vector_type.hpp:203
uint32_t uint32x2_t
Definition vector_type.hpp:163
fp8_t fp8x64_t
Definition vector_type.hpp:232
int8_t pk_int4x32_t
Definition vector_type.hpp:250
int32_t int32x32_t
Definition vector_type.hpp:158
STL namespace.
signed short int16_t
Definition stdint.h:122
unsigned short uint16_t
Definition stdint.h:125
unsigned int uint32_t
Definition stdint.h:126
unsigned char uint8_t
Definition stdint.h:124
typename native_t< remove_cvref_t< V_ > >::type::type value_type
Definition vector_type.hpp:76
typename native_t< remove_cvref_t< V_ > >::type value_type
Definition vector_type.hpp:65
Definition vector_type.hpp:37
Definition vector_type.hpp:26
remove_cvref_t< T > type
Definition vector_type.hpp:27
Definition pk_int4.hpp:21
static constexpr index_t vector_size
Definition vector_type.hpp:111
std::conditional_t< std::is_same_v< T, pk_int4_t >, int8_t, std::conditional_t< std::is_same_v< T, pk_fp4_t >||std::is_same_v< remove_cvref_t< T >, e8m0_t >, uint8_t, T > > scalar_type
Definition vector_type.hpp:105
Definition vector_type.hpp:90
static constexpr index_t vector_size
Definition vector_type.hpp:98
std::conditional_t< std::is_same_v< remove_cvref_t< T >, pk_int4_t >, int8_t, std::conditional_t< std::is_same_v< remove_cvref_t< T >, pk_fp4_t >|| std::is_same_v< remove_cvref_t< T >, e8m0_t >, uint8_t, remove_cvref_t< T > > > scalar_type
Definition vector_type.hpp:91