type_convert.hpp Source File

type_convert.hpp Source File#

Composable Kernel: type_convert.hpp Source File
utility/type_convert.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12#include "ck/utility/array.hpp"
14#include "ck/utility/type.hpp"
15
16namespace ck {
17// Define the common macro for MI300 models
18#if defined(__gfx942__) || defined(__gfx950__)
19#define __gfx94__
20#endif
21
22namespace {
23namespace details {
24
25[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
26{
27 half2_t vector_res;
28
29 vector_res.x = x.x + y.x;
30 vector_res.y = x.y + y.y;
31
32 return vector_res;
33}
34
35[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
36{
37 return amd_assembly_pk_add_f16(x, y);
38}
39} // namespace details
40} // namespace
41
42#if defined(__gfx950__)
43inline __device__ bhalf_t static_cast_float_to_bf16(float x)
44{
45 union
46 {
47 uint16_t uint16;
48 __bf16 bf16;
49 } out;
50 out.bf16 = static_cast<__bf16>(x);
51 return out.uint16;
52}
53#endif
54
55// Declare a template function for bf16 conversion using RTN
56template <typename Y, typename X>
57__host__ __device__ constexpr Y bf16_convert_rtn(X x);
58
59// Convert fp32 to bf16 with RTN if higher precision is needed
60template <>
61inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
62{
63#if defined(__gfx950__)
64 return static_cast_float_to_bf16(x);
65#else
66 // Nan check
67 if(x != x)
68 {
69 return uint16_t(0x7FC0);
70 }
71
72 union
73 {
74 float fp32;
75 uint32_t int32;
76 } u = {x};
77
78 const uint32_t first_bf16_mantisa_bit = ((u.int32 >> 16) & 1);
79 constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1);
80
81 return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16);
82#endif
83}
84
85// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
86template <>
87inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
88{
89 float x_fp32 = static_cast<float>(x);
90
91 return bf16_convert_rtn<bhalf_t>(x_fp32);
92}
93
94// Convert X to Y, both X and Y are non-const data types.
95template <typename Y,
96 typename X,
97 ck::enable_if_t<!(ck::is_const_v<Y> || ck::is_const_v<X>), bool> = false>
98__host__ __device__ constexpr Y type_convert(X x)
99{
100 static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
101
102 return static_cast<Y>(x);
103}
104
105// Convert X to Y, either X or Y is a const data type.
106template <typename Y,
107 typename X,
108 ck::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
109__host__ __device__ constexpr Y type_convert(X x)
110{
111 static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
112
113 using NonConstY = ck::remove_const_t<Y>;
114 using NonConstX = ck::remove_const_t<X>;
115 return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
116}
117
118// convert bfp16 to fp32
119template <>
120inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
121{
122 union
123 {
124 uint32_t int32;
125 float fp32;
126 } u = {uint32_t(x) << 16};
127
128 return u.fp32;
129}
130
131// convert fp32 to bfp16, round to nearest even
132template <>
133inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
134{
135#if CK_USE_RNE_BF16_CONVERSION
137#else
138 return uint16_t(static_cast<uint32_t>(x) >> 16);
139#endif
140}
141
142// convert bfp16 to fp16 via fp32
143template <>
144inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
145{
146 float x_fp32 = type_convert<float>(x);
147
148 return static_cast<half_t>(x_fp32);
149}
150
151// convert fp16 to bfp16 via fp32
152template <>
153inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
154{
155 float x_fp32 = static_cast<float>(x);
156
157 return type_convert<bhalf_t>(x_fp32);
158}
159
160// convert bfp16 to int8 via fp32
161template <>
162inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
163{
164 float x_fp32 = type_convert<float>(x);
165
166 return static_cast<int8_t>(x_fp32);
167}
168
169// convert int8 to bfp16 via fp32
170template <>
171inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
172{
173 float x_fp32 = static_cast<float>(x);
174
175 return type_convert<bhalf_t>(x_fp32);
176}
177
178template <>
179inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
180{
182}
183
184template <>
185inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
186{
188}
189
190template <typename Y, enable_if_t<is_same_v<Y, ck::tf32_t>, bool> = false>
191inline __host__ __device__ constexpr float type_convert(float x)
192{
193 union
194 {
195 float fp32;
196 uint32_t int32;
197 } u = {x};
198
199 u.int32 = u.int32 & 0xffffe000;
200 return u.fp32;
201}
202
203// Convert X to Y
204template <typename Y, typename X>
205__host__ __device__ constexpr Y type_convert_sp(X x)
206{
207 static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
208 return static_cast<Y>(x);
209}
210
211template <>
212inline __host__ __device__ constexpr int type_convert_sp<int, float>(float x)
213{
214 union
215 {
216 float fp32;
217 int int32;
218 } u = {x};
219
220 return u.int32;
221}
222
223template <>
224inline __host__ __device__ constexpr float type_convert_sp<float, int>(int x)
225{
226 union
227 {
228 int int32;
229 float fp32;
230 } u = {x};
231
232 return u.fp32;
233}
234
235template <>
236inline __host__ __device__ constexpr int type_convert_sp<int, half_t>(half_t x)
237{
238 union
239 {
240 half_t fp16;
241 int int32;
242 } u = {x};
243
244 return u.int32;
245}
246
247template <>
248inline __host__ __device__ constexpr half_t type_convert_sp<half_t, int>(int x)
249{
250 union
251 {
252 int int32;
253 half_t fp16;
254 } u = {x};
255
256 return u.fp16;
257}
258
259template <>
260inline __host__ __device__ constexpr int type_convert_sp<int, f8_t>(f8_t x)
261{
262 union
263 {
264 f8_t fp8;
265 int int32;
266 } u = {x};
267
268 return u.int32;
269}
270
271template <>
272inline __host__ __device__ constexpr f8_t type_convert_sp<f8_t, int>(int x)
273{
274 union
275 {
276 int int32;
277 f8_t fp8;
278 } u = {x};
279
280 return u.fp8;
281}
282
283template <>
284inline __host__ __device__ constexpr int type_convert_sp<int, bhalf_t>(bhalf_t x)
285{
286 union
287 {
288 bhalf_t fp16;
289 int int32;
290 } u = {x};
291
292 return u.int32;
293}
294
295template <>
296inline __host__ __device__ constexpr bhalf_t type_convert_sp<bhalf_t, int>(int x)
297{
298 union
299 {
300 int int32;
301 bhalf_t fp16;
302 } u = {x};
303
304 return u.fp16;
305}
306
307template <>
308inline __host__ __device__ constexpr bhalf_t type_convert_sp<bhalf_t, float>(float x)
309{
310 return type_convert<bhalf_t>(x);
311}
312
313template <>
314inline __host__ __device__ constexpr half_t type_convert_sp<half_t, float>(float x)
315{
316 return type_convert<half_t>(x);
317}
318// Declare a template function for fp8 conversion using SR
319template <typename Y, typename X>
320__host__ __device__ constexpr Y f8_convert_sr(X x);
321
322// convert fp32 to fp8 with stochastic rounding
323template <>
324inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
325{
326#if defined(__gfx950__)
327 // use HW clock for stochastic input multiply by incremented thread id
328 uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
330#else
331 constexpr int seed = 1254739;
332#ifndef CK_CODE_GEN_RTC
333 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
334#else
335 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
336#endif // #ifndef CK_CODE_GEN_RTC
337#endif // #if defined(__gfx950__)
338#if defined(__gfx94__)
339 union
340 {
341 float fval;
342 uint32_t i32val;
343 uint8_t i8val[4]; // not endian independent
344 } val;
345 val.fval = x;
346 uint32_t ival = 0;
347 const float max_fp8 = 240.0f;
348 // if x is not +/- infinity or nan
350 // clip float value
351 val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
352 ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
353 val.i32val = ival;
354 return f8_fnuz_t{val.i8val[0]}; // little endian
355#else
356 constexpr bool negative_zero_nan = true;
357 constexpr bool clip = true;
359 return utils::
361 x, rng);
362#endif
363}
364
365// convert fp16 to fp8 with stochastic rounding
366template <>
368{
369#if defined(__gfx94__)
370 // convert to float and use native conversion
372#else
373 constexpr bool negative_zero_nan = true;
374 constexpr bool clip = true;
376 constexpr int seed = 1254739;
377#ifndef CK_CODE_GEN_RTC
378 uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
379#else
380 uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
381#endif
383 f8_fnuz_t,
384 negative_zero_nan,
385 clip,
386 (rm == f8_rounding_mode::stochastic)>(x, rng);
387#endif
388}
389
390// convert fp32 to bf8 with stochastic rounding
391template <>
392inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
393{
394#if defined(__gfx950__)
395 // use HW clock for stochastic input multiply by incremented thread id
396 uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
398#else
399 constexpr int seed = 1254739;
400#ifndef CK_CODE_GEN_RTC
401 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
402#else
403 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
404#endif // #ifndef CK_CODE_GEN_RTC
405#endif // #if defined(__gfx950__)
406#if defined(__gfx94__)
407 union
408 {
409 float fval;
410 uint32_t i32val;
411 uint8_t i8val[4]; // not endian independent
412 } val;
413 val.fval = x;
414 uint32_t ival = 0;
415 const float max_bf8 = 57344.0f;
416 // if x is not +/- infinity or nan
418 // clip float value
419 val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
420 ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
421 val.i32val = ival;
422 return bf8_fnuz_t{val.i8val[0]}; // little endian
423#else
424 constexpr bool negative_zero_nan = true;
425 constexpr bool clip = true;
427 return utils::cast_to_f8<float,
429 negative_zero_nan,
430 clip,
431 (rm == f8_rounding_mode::stochastic)>(x, rng);
432#endif
433}
434
435// convert fp16 to bf8 with stochastic rounding
436template <>
438{
439#if defined(__gfx94__)
440 // convert to float and use native conversion
442#else
443 constexpr bool negative_zero_nan = true;
444 constexpr bool clip = true;
446 constexpr int seed = 1254739;
447#ifndef CK_CODE_GEN_RTC
448 uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
449#else
450 uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x);
451#endif
454 negative_zero_nan,
455 clip,
456 (rm == f8_rounding_mode::stochastic)>(x, rng);
457#endif
458}
459
466template <>
467inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
468{
469 return f8_ocp_t{
470 fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
471 x)};
472}
473
481template <>
483{
484 return f8x2_ocp_t{
485 fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
486 x)};
487}
488
495template <>
496inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
497{
498 return bf8_ocp_t{fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret,
500 true>(x)};
501}
502
510template <>
512{
513 return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret,
515 true>(x)};
516}
517
524template <>
526{
527 return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
529 true>(x)};
530}
531
539template <>
541{
542 return f8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
544 true>(x)};
545}
546
553template <>
555{
556 return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret,
558 true>(x)};
559}
560
568template <>
570{
571 return bf8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret,
573 true>(x)};
574}
575
582template <>
584{
585 return f8_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret,
587 true>(x)};
588}
589
597template <>
599{
600 return f8x2_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret,
602 true>(x)};
603}
604
611template <>
613{
614 return bf8_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret,
616 true>(x)};
617}
618
626template <>
628{
629 return bf8x2_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret,
631 true>(x)};
632}
633
634// Declare a template function for fp8 conversion using RNE
635template <typename Y, typename X>
636__host__ __device__ constexpr Y f8_convert_rne(X x);
637
638// convert fp32 to fp8 with rounding to nearest even
639template <>
640inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
641{
642#if defined(__gfx94__)
643 union
644 {
645 float fval;
646 uint32_t i32val;
647 uint8_t i8val[4]; // not endian independent
648 } val;
649 val.fval = x;
650 uint32_t ival = 0;
651 const float max_fp8 = 240.0f;
652 // if x is not +/- infinity or nan
654 // clip float value
655 val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
656 ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
657 val.i32val = ival;
658 return f8_fnuz_t{val.i8val[0]};
659#else
660 constexpr bool negative_zero_nan = true;
661 constexpr bool clip = true;
663 constexpr uint32_t rng = 0;
664 return utils::
666 x, rng);
667#endif
668}
669
670// convert fp16 to fp8 with rounding to nearest even
671template <>
673{
674#if defined(__gfx94__)
675 // convert to float and use native conversion
677#else
678 constexpr bool negative_zero_nan = true;
679 constexpr bool clip = true;
681 constexpr uint32_t rng = 0;
683 f8_fnuz_t,
684 negative_zero_nan,
685 clip,
686 (rm == f8_rounding_mode::stochastic)>(x, rng);
687#endif
688}
689
690// convert fp32 to bf8 with rounding to nearest even
691template <>
692inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
693{
694#if defined(__gfx94__)
695 union
696 {
697 float fval;
698 uint32_t i32val;
699 uint8_t i8val[4]; // not endian independent
700 } val;
701 val.fval = x;
702 uint32_t ival = 0;
703 const float max_bf8 = 57344.0f;
704 // if x is not +/- infinity or nan
706 // clip float value
707 val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
708 ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
709 val.i32val = ival;
710 return bf8_fnuz_t{val.i8val[0]};
711#else
712 constexpr bool negative_zero_nan = true;
713 constexpr bool clip = true;
715 constexpr uint32_t rng = 0;
716 return utils::cast_to_f8<float,
718 negative_zero_nan,
719 clip,
720 (rm == f8_rounding_mode::stochastic)>(x, rng);
721#endif
722}
723
724// convert fp16 to bf8 with rounding to nearest even
725template <>
727{
728#if defined(__gfx94__)
729 // convert to float and use native conversion
731#else
732 constexpr bool negative_zero_nan = true;
733 constexpr bool clip = true;
735 constexpr uint32_t rng = 0;
738 negative_zero_nan,
739 clip,
740 (rm == f8_rounding_mode::stochastic)>(x, rng);
741#endif
742}
743
750template <>
751inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
752{
753 return f8_ocp_t{
754 fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
755}
756
764template <>
766{
767 return f8x2_ocp_t{
768 fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
769}
770
777template <>
778inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, float>(float x)
779{
780 return bf8_ocp_t{
781 fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x)};
782}
783
791template <>
793{
794 return bf8x2_ocp_t{
795 fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x)};
796}
797
804template <>
806{
807 return f8_ocp_t{
808 fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
809}
810
818template <>
820{
821 return f8x2_ocp_t{
822 fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
823}
824
831template <>
833{
834 return bf8_ocp_t{
835 fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
836 x)};
837}
838
846template <>
848{
849 return bf8x2_ocp_t{
850 fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
851 x)};
852}
853
860template <>
862{
863 return f8_ocp_t{
864 fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
865}
866
874template <>
876{
877 return f8x2_ocp_t{
878 fp8_impl::cvt_bhalf_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
879}
880
887template <>
889{
890 return bf8_ocp_t{
891 fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
892 x)};
893}
894
902template <>
904{
905 return bf8x2_ocp_t{
906 fp8_impl::cvt_bhalf_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
907 x)};
908}
909
910// convert fp32 to fp8
911template <>
912inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
913{
914#if CK_USE_SR_F8_CONVERSION
915 return f8_convert_sr<f8_fnuz_t>(x);
916#else
918#endif
919}
920
921// convert fp8 to fp32
922template <>
923inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
924{
925#if defined(__gfx94__)
926 float fval;
927 uint32_t i32val = static_cast<uint32_t>(static_cast<uint8_t>(x));
928 fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
929 // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
930 return fval;
931#else
932 constexpr bool negative_zero_nan = true;
934#endif
935}
936
937template <>
939{
940#if defined(__gfx94__)
941 const auto i16val = bit_cast<uint16_t>(x);
942 return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
943#else
944 constexpr bool negative_zero_nan = true;
945 const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
946 vector_type<float, 2> f32x2_v;
947 f32x2_v.template AsType<float>()(Number<0>{}) =
949 f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
950 f32x2_v.template AsType<float>()(Number<1>{}) =
952 f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
953 return f32x2_v.template AsType<float2_t>()[Number<0>{}];
954#endif
955}
956
963template <>
964inline __host__ __device__ float type_convert<float, f8_ocp_t>(f8_ocp_t x)
965{
966#if CK_OCP_FP8_CVT_FAST_PATH
967 union
968 {
969 unsigned int i32val;
970 fp8_storage_t i8val[4];
971 } val;
972 val.i8val[0] = x.data;
973 return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
974#else
975 return fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(x.data);
976#endif
977}
978
985template <>
987{
988#if CK_OCP_FP8_CVT_FAST_PATH
989// __builtin_amdgcn_cvt_pk_f32_fp8 can produce incorrect results due to a compiler issue.
990// TODO: Enable when SWDEV-532959 is fixed.
991#if defined(__gfx12__)
992 return float2_t{__builtin_amdgcn_cvt_f32_fp8(bit_cast<uint16_t>(x), 0),
993 __builtin_amdgcn_cvt_f32_fp8(bit_cast<uint16_t>(x), 1)};
994#else
995 return __builtin_amdgcn_cvt_pk_f32_fp8(bit_cast<uint16_t>(x), false);
996#endif
997#else
998 return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
999 x.AsType<fp8_storage_t>()[Number<0>{}]),
1000 fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
1001 x.AsType<fp8_storage_t>()[Number<1>{}])};
1002#endif
1003}
1004
1011template <>
1013{
1014#if defined(__gfx950__)
1015 union
1016 {
1017 uint16_t i16val;
1018 fp8_storage_t i8val[2];
1019 } input;
1020 input.i8val[0] = x.data;
1021
1022 union
1023 {
1024 half2_t half_vec;
1025 half_t half_arr[2];
1026 } output;
1027 output.half_vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.i16val, /*scale*/ 1.f, 0);
1028
1029 return output.half_arr[0];
1030#else
1031 return fp8_impl::cast_from_f8<half_t, f8_ocp_t::wm, f8_ocp_t::we, false>(x.data);
1032#endif
1033}
1034
1041template <>
1043{
1044#if defined(__gfx950__)
1045 return __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
1046#else
1047 return half2_t{type_convert<half_t>(float(x.AsType<f8_ocp_t>()[Number<0>{}])),
1048 type_convert<half_t>(float(x.AsType<f8_ocp_t>()[Number<1>{}]))};
1049#endif
1050}
1051
1058template <>
1060{
1061#if defined(__gfx950__)
1062 union
1063 {
1064 uint16_t i16val;
1065 fp8_storage_t i8val[2];
1066 } input;
1067 input.i8val[0] = x.data;
1068
1069 union
1070 {
1071 bhalf2_t bhalf_vec;
1072 bhalf_t bhalf_arr[2];
1073 } output;
1074 output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.i16val, /*scale*/ 1.f, 0);
1075
1076 return output.bhalf_arr[0];
1077#else
1078 return type_convert<bhalf_t>(
1079 fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(x.data));
1080#endif
1081}
1082
1089template <>
1091{
1092#if defined(__gfx950__)
1093 return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
1094#else
1095 return bhalf2_t{type_convert<bhalf_t>(float(x.AsType<f8_ocp_t>()[Number<0>{}])),
1096 type_convert<bhalf_t>(float(x.AsType<f8_ocp_t>()[Number<1>{}]))};
1097#endif
1098}
1099
1106template <>
1107inline __host__ __device__ float type_convert<float, bf8_ocp_t>(bf8_ocp_t x)
1108{
1109#if CK_OCP_FP8_CVT_FAST_PATH
1110 union
1111 {
1112 unsigned int i32val;
1113 fp8_storage_t i8val[4];
1114 } val;
1115 val.i8val[0] = x.data;
1116 return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
1117#else
1118 return fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(x.data);
1119#endif
1120}
1121
1128template <>
1130{
1131#if CK_OCP_FP8_CVT_FAST_PATH
1132// __builtin_amdgcn_cvt_pk_f32_bf8 can produce incorrect results due to a compiler issue.
1133// TODO: Enable when SWDEV-532959 is fixed.
1134#if defined(__gfx12__)
1135 return float2_t{__builtin_amdgcn_cvt_f32_bf8(bit_cast<uint16_t>(x), 0),
1136 __builtin_amdgcn_cvt_f32_bf8(bit_cast<uint16_t>(x), 1)};
1137#else
1138 return __builtin_amdgcn_cvt_pk_f32_bf8(bit_cast<uint16_t>(x), false);
1139#endif
1140#else
1141 return float2_t{fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(
1142 x.AsType<fp8_storage_t>()[Number<0>{}]),
1143 fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(
1144 x.AsType<fp8_storage_t>()[Number<1>{}])};
1145#endif
1146}
1147
1154template <>
1156{
1157#if defined(__gfx950__)
1158 union
1159 {
1160 uint16_t i16val;
1161 fp8_storage_t i8val[2];
1162 } val;
1163 val.i8val[0] = x.data;
1164 return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(val.i16val, /*scale*/ 1.f, 0)[0];
1165#else
1166 return fp8_impl::cast_from_f8<half_t, bf8_ocp_t::wm, bf8_ocp_t::we, false>(x.data);
1167#endif
1168}
1169
1176template <>
1178{
1179#if defined(__gfx950__)
1180 return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
1181#else
1182 return half2_t{type_convert<half_t>(float(x.AsType<bf8_ocp_t>()[Number<0>{}])),
1183 type_convert<half_t>(float(x.AsType<bf8_ocp_t>()[Number<1>{}]))};
1184#endif
1185}
1186
1193template <>
1195{
1196#if defined(__gfx950__)
1197 union
1198 {
1199 uint16_t i16val;
1200 fp8_storage_t i8val[2];
1201 } input;
1202 input.i8val[0] = x.data;
1203
1204 union
1205 {
1206 bhalf2_t bhalf_vec;
1207 bhalf_t bhalf_arr[2];
1208 } output;
1209 output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, /*scale*/ 1.f, 0);
1210
1211 return output.bhalf_arr[0];
1212#else
1213 return type_convert<bhalf_t>(
1214 fp8_impl::cast_from_f8<float, bf8_ocp_t::wm, bf8_ocp_t::we, false>(x.data));
1215#endif
1216}
1217
1224template <>
1226{
1227#if defined(__gfx950__)
1228 return __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(bit_cast<uint16_t>(x), /*scale*/ 1.f, 0);
1229#else
1230 return bhalf2_t{type_convert<bhalf_t>(float(x.AsType<bf8_ocp_t>()[Number<0>{}])),
1231 type_convert<bhalf_t>(float(x.AsType<bf8_ocp_t>()[Number<1>{}]))};
1232#endif
1233}
1234
1235template <>
1237{
1239
1240 float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
1241 float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
1242
1243#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
1244 float2_t res = {x_h, x_l};
1245#elif
1246 float2_t res = {x_l, x_h};
1247#endif
1248 return res;
1249}
1250
1251template <>
1253{
1255#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
1256 uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
1257#else
1258 uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
1259#endif
1260
1261 const int EX = 0x64006400;
1262 const int SUB = 0xE408E408; //-8
1263
1264 int lo = i4s | EX;
1265
1266 return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
1267}
1268
1269template <>
1271{
1273
1274 float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
1275 float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
1276
1277#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
1279#else
1281#endif
1282
1283 return res;
1284}
1285
1286template <>
1288{
1289
1290 const vector_type<float, 2> f32x2_v(x);
1291 const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
1292 f32x2_v.template AsType<float>()[Number<1>{}]);
1293 return bit_cast<half2_t>(y);
1294}
1295
1296// convert fp16 to fp8
1297template <>
1299{
1300#if CK_USE_SR_F8_CONVERSION
1301 return f8_convert_sr<f8_fnuz_t>(x);
1302#else
1303 return f8_convert_rne<f8_fnuz_t>(x);
1304#endif
1305}
1306
1313template <>
1315{
1316#if CK_USE_SR_F8_CONVERSION
1317 return f8_convert_sr<f8_ocp_t>(x);
1318#else
1319 return f8_convert_rne<f8_ocp_t>(x);
1320#endif
1321}
1322
1329template <>
1331{
1332#if CK_USE_SR_F8_CONVERSION
1333 return f8_convert_sr<bf8_ocp_t>(x);
1334#else
1335 return f8_convert_rne<bf8_ocp_t>(x);
1336#endif
1337}
1338
1339// convert fp8 to fp16
1340template <>
1342{
1343#if defined(__gfx94__)
1344 // use native conversion to float and convert to fp16
1346#else
1347 constexpr bool negative_zero_nan = true;
1349#endif
1350}
1351
1352// convert fp32 to bf8
1353template <>
1354inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
1355{
1356#if CK_USE_SR_F8_CONVERSION
1357 return f8_convert_sr<bf8_fnuz_t>(x);
1358#else
1360#endif
1361}
1362
1369template <>
1370inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
1371{
1372#if CK_USE_SR_F8_CONVERSION
1373 return f8_convert_sr<f8_ocp_t>(x);
1374#else
1375 return f8_convert_rne<f8_ocp_t>(x);
1376#endif
1377}
1378
1385template <>
1386inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
1387{
1388#if CK_USE_SR_F8_CONVERSION
1389 return f8_convert_sr<bf8_ocp_t>(x);
1390#else
1391 return f8_convert_rne<bf8_ocp_t>(x);
1392#endif
1393}
1394
1401template <>
1403{
1404#if CK_USE_SR_F8_CONVERSION
1405 return f8_convert_sr<f8_ocp_t>(x);
1406#else
1407 return f8_convert_rne<f8_ocp_t>(x);
1408#endif
1409}
1410
1417template <>
1419{
1420#if CK_USE_SR_F8_CONVERSION
1421 return f8_convert_sr<bf8_ocp_t>(x);
1422#else
1423 return f8_convert_rne<bf8_ocp_t>(x);
1424#endif
1425}
1426
1427// convert bf8 to fp32
1428template <>
1429inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
1430{
1431#if defined(__gfx94__)
1432 float fval;
1433 uint32_t i32val = static_cast<uint32_t>(static_cast<uint8_t>(x));
1434 fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
1435 // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
1436 return fval;
1437#else
1438 constexpr bool negative_zero_nan = true;
1440#endif
1441}
1442
1443// convert fp16 to bf8
1444template <>
1446{
1447#if CK_USE_SR_F8_CONVERSION
1448 return f8_convert_sr<bf8_fnuz_t>(x);
1449#else
1451#endif
1452}
1453
1454// convert bf8 to fp16
1455template <>
1457{
1458#if defined(__gfx94__)
1459 // use native conversion to float and convert to fp16
1461#else
1462 constexpr bool negative_zero_nan = true;
1464#endif
1465}
1466#ifndef CK_CODE_GEN_RTC
1467// convert fp32 to fp4 with rounding to nearest even
1468inline __host__ __device__ f4_t f4_convert_rne(float x, float scale = 1.0f)
1469{
1470#if defined(__gfx950__)
1471 union
1472 {
1473 uint32_t bitwise;
1474 f4_t f4_array[4];
1475 } value{0};
1476 value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x, x, scale, 0);
1477 return value.f4_array[0];
1478#else
1479 return utils::sat_convert_to_type<f4_t>(x / scale);
1480#endif
1481}
1482
1483// convert vector of 2 fp32 to vector of 2 fp4 with rne
1484inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
1485{
1486#if defined(__gfx950__)
1487 union
1488 {
1489 uint32_t bitwise;
1490 f4x2_t f4x2_array[4];
1491 } value{0};
1492 value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
1493 return value.f4x2_array[0];
1494#else
1495 union
1496 {
1497 uint32_t bitwise;
1498 f4x2_t f4x2_array[4];
1499 } value{0};
1500 uint8_t l = utils::sat_convert_to_type<f4_t>(x[0] / scale);
1501 uint8_t h = utils::sat_convert_to_type<f4_t>(x[1] / scale);
1502 value.bitwise = (h << 4) | l;
1503 return value.f4x2_array[0];
1504#endif
1505}
1506
1507// convert vector of 32 fp32 to vector of 32 fp4 with rne
1508inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0f)
1509{
1510#if defined(__gfx950__)
1511 union
1512 {
1513 __uint128_t bitwise;
1514 f4x2_t f4x2_array[16];
1515 f4x32_t f4x32_array;
1516 } f4_values{}, tmp_values{};
1517
1518 ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
1519 tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
1520 tmp_values.bitwise, x[2 * idx], x[2 * idx + 1], scale, 0);
1521 f4_values.f4x2_array[idx] = tmp_values.f4x2_array[0];
1522 });
1523
1524 return f4_values.f4x32_array;
1525#else
1526 union
1527 {
1528 __uint128_t bitwise;
1529 f4x2_t f4x2_array[16];
1530 f4x32_t f4x32_array;
1531 } f4_values{};
1532
1533 f4_t tmp;
1534
1535 ck::static_for<0, 32, 1>{}([&](auto idx) {
1536 tmp = utils::sat_convert_to_type<f4_t>(x[static_cast<int>(idx)] / scale);
1537 f4_values.bitwise <<= 4;
1538 f4_values.bitwise |= tmp;
1539 });
1540
1541 return f4_values.f4x32_array;
1542#endif
1543}
1544
1545// convert fp32 to fp4 with stochastic rounding
1546inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f)
1547{
1548#if defined(__gfx950__)
1549 // use HW clock for stochastic input multiply by incremented thread id
1550 uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1551 (get_thread_global_1d_id() + 1));
1552 union
1553 {
1554 uint32_t bitwise;
1555 f4_t f4_array[4];
1556 } value{0};
1557 union
1558 {
1559 float float_array[2];
1560 float2_t float2_array;
1561 } float_values{{x}};
1562
1563 value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
1564 value.bitwise, float_values.float2_array, rng, scale, 0);
1565 return value.f4_array[0];
1566#else
1567 constexpr int seed = 1254739;
1568#ifndef CK_CODE_GEN_RTC
1569 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
1570#else
1571 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
1572#endif
1573 return utils::sat_convert_to_type_sr<f4_t>(x / scale, rng);
1574#endif
1575}
1576
1577// convert vector of 2 fp32 to vector of 2 fp4 with sr
1578inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
1579{
1580#if defined(__gfx950__)
1581 // use HW clock for stochastic input multiply by incremented thread id
1582 uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1583 (get_thread_global_1d_id() + 1));
1584 union
1585 {
1586 uint32_t bitwise;
1587 f4x2_t f4x2_array[4];
1588 } value{0};
1589 value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
1590 return value.f4x2_array[0];
1591#else
1592 constexpr int seed = 1254739;
1593#ifndef CK_CODE_GEN_RTC
1594 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
1595#else
1596 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
1597#endif
1598 union
1599 {
1600 uint32_t bitwise;
1601 f4x2_t f4x2_array[4];
1602 } value{0};
1603 uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
1604 uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
1605 value.bitwise = (h << 4) | l;
1606 return value.f4x2_array[0];
1607#endif
1608}
1609
1610// convert vector of 32 fp32 to vector of 32 fp4 with sr
1611inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
1612{
1613#if defined(__gfx950__)
1614 // use HW clock for stochastic input multiply by incremented thread id
1615 uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1616 (get_thread_global_1d_id() + 1));
1617 union
1618 {
1619 __uint128_t bitwise;
1620 f4x2_t f4x2_array[16];
1621 f4x32_t f4x32_array;
1622 } f4_values{0};
1623 union
1624 {
1625 float2_t floatx2_array[16];
1626 float32_t floatx32_array;
1627 } float_values{{0}};
1628 float_values.floatx32_array = x;
1629
1630 ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
1631 f4_values.f4x2_array[idx] = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
1632 f4_values.bitwise, float_values.floatx2_array[idx], rng, scale, 0);
1633 });
1634
1635 return f4_values.f4x32_array;
1636#else
1637 constexpr int seed = 1254739;
1638#ifndef CK_CODE_GEN_RTC
1639 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
1640#else
1641 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x[0]);
1642#endif
1643 union
1644 {
1645 __uint128_t bitwise;
1646 f4x2_t f4x2_array[16];
1647 f4x32_t f4x32_array;
1648 } f4_values{0};
1649
1650 f4_t tmp;
1651
1652 ck::static_for<0, 32, 1>{}([&](auto idx) {
1653 tmp = utils::sat_convert_to_type_sr<f4_t>(x[static_cast<int>(idx)] / scale, rng);
1654 f4_values.bitwise <<= 4;
1655 f4_values.bitwise |= tmp;
1656 });
1657
1658 return f4_values.f4x32_array;
1659#endif
1660}
1661
1662// convert fp32 to fp4
1663template <>
1664inline __host__ __device__ f4_t type_convert<f4_t, float>(float x)
1665{
1666#if CK_USE_SR_F4_CONVERSION
1667 return f4_convert_sr(x);
1668#else
1669 return f4_convert_rne(x);
1670#endif
1671}
1672
1673// convert vector of 2 fp32 to vector of 2 fp4
1674template <>
1676{
1677#if CK_USE_SR_F4_CONVERSION
1678 return f4_convert_sr(x);
1679#else
1680 return f4_convert_rne(x);
1681#endif
1682}
1683template <>
1685{
1686 return static_cast<f4x2_pk_t>(type_convert<f4x2_t>(x));
1687}
1688
1689// convert vector of 32 fp32 to vector of 32 fp4
1690template <>
1692{
1693#if CK_USE_SR_F4_CONVERSION
1694 return f4_convert_sr(x);
1695#else
1696 return f4_convert_rne(x);
1697#endif
1698}
1699
1700// convert fp4 to fp32
// Specialization converting a single fp4 value to fp32.
1701template <>
1702inline __host__ __device__ float type_convert<float, f4_t>(f4_t x)
1703{
1704#if defined(__gfx950__)
// The conversion builtin yields a pair of floats; the union lets the first
// lane be read back as a scalar.
1705 union
1706 {
1707 float float_array[2];
1708 float2_t float2_array;
1709 } float_values{};
// Unit scale: the value is converted without rescaling.
1710 float scale = 1.0f;
1711 float_values.float2_array = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(x, scale, 0);
1712 return float_values.float_array[0];
1713#else
// NOTE(review): the fallback return statement (listing line 1714) is absent
// here, so this branch of a non-void function has no return -- restore it from
// the upstream header (presumably a utils::to_float<f4_t> call; confirm).
1715#endif
1716}
1717
1718// convert vector of 2 fp4 to vector of 2 fp32
// Specialization converting a packed pair of fp4 values to 2 fp32 values.
// NOTE(review): the declarator (listing line 1720) is absent here; the
// cross-reference index lists it as
// `float2_t type_convert< float2_t, f4x2_t >(f4x2_t x)`.
1719template <>
1721{
1722#if defined(__gfx950__)
// Place the packed pair in the first slot of a 32-bit word, which is the
// operand form the builtin is called with.
1723 union
1724 {
1725 uint32_t bitwise;
1726 f4x2_t f4x2_array[4];
1727 } value{};
1728 value.f4x2_array[0] = x;
// Unit scale: the values are converted without rescaling.
1729 float scale = 1.0f;
1730 return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
1731#else
// NOTE(review): listing lines 1733 and 1735 (the start of each element's
// conversion call) are absent here; the surviving fragments show the two
// halves being unpacked with Number<0> and Number<1>.
1732 float2_t ret{
1734 x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})),
1736 x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}))};
1737 return ret;
1738#endif
1739}
1740
1741// convert vector of 32 fp4 to vector of 32 fp32
// Specialization converting a vector of 32 fp4 values to 32 fp32 values.
// NOTE(review): the declarator (listing line 1743) is absent here; the
// cross-reference index lists it as
// `float32_t type_convert< float32_t, f4x32_t >(f4x32_t x)`.
1742template <>
1744{
1745#if defined(__gfx950__)
// View the 32 packed values as 16 pairs so they can be converted two at a time.
1746 union
1747 {
1748 f4x32_t f4x32_array;
1749 f4x2_t fp4x2[16];
1750 } value{x};
1751 float2_t op;
1752 float32_t ret;
// Unit scale: the values are converted without rescaling.
1753 float scale = 1.0f;
1754
// Each step converts one pair and writes it to two adjacent output lanes.
1755 ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
1756 op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0);
1757 ret[2 * idx] = op[0];
1758 ret[2 * idx + 1] = op[1];
1759 });
1760
1761 return ret;
1762#else
1763 union
1764 {
1765 float32_t float32_array;
1766 float float_array[32];
1767 } float_values{};
// Reinterpret the input bits as 16 packed pairs.
1768 union
1769 {
1770 __uint128_t bitwise;
1771 f4x2_t f4x2_array[16];
1772 f4x32_t f4x32_array;
1773 } f4_values{bit_cast<__uint128_t>(x)};
1774
// NOTE(review): listing lines 1777 and 1782 (an argument line of each
// utils::to_float<f4_t> call) are absent here -- restore from upstream.
1775 ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
1776 float_values.float_array[2 * idx] = utils::to_float<f4_t>(
1778 f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
1779 Number<0>{}));
1780
1781 float_values.float_array[2 * idx + 1] = utils::to_float<f4_t>(
1783 f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
1784 Number<1>{}));
1785 });
1786
1787 return float_values.float32_array;
1788#endif
1789}
1790
1801inline __host__ __device__ f6_t f6_convert_rne(float x, float scale = 1.0f)
1802{
1803#if defined(__gfx950__)
1804 float16_t in1{x};
1805 float16_t in2{};
1806
1807 union
1808 {
1809 f6x32_t f6_vector;
1810 f6_t f6_array[32];
1811 } out{};
1812
1813 out.f6_vector = f6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(in1, in2, scale)};
1814
1815 return out.f6_array[0];
1816#else
1817 return utils::sat_convert_to_type<f6_t>(x / scale);
1818#endif
1819}
1820
1831inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0f)
1832{
1833#if defined(__gfx950__)
1834 float16_t* in1 = reinterpret_cast<float16_t*>(&x);
1835 float16_t* in2 = reinterpret_cast<float16_t*>(&x + 16);
1836 return f6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_fp6_f32(*in1, *in2, scale)};
1837#else
1838 union
1839 {
1840 float32_t float_vector;
1841 float float_array[32];
1842 } in{x};
1843
1844 using array_type = uint8_t __attribute__((ext_vector_type(32)));
1845 array_type uint8_array;
1846
1847 // collect the 6-bit values into an array
1848 ck::static_for<0, 32, 1>{}([&](auto i) {
1849 uint8_array[static_cast<index_t>(i)] =
1850 utils::sat_convert_to_type<f6_t>(in.float_array[i] / scale);
1851 });
1852 return f6x32_t{f6x32_pk_t{uint8_array}};
1853#endif
1854}
1855
1866inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f)
1867{
1868#if defined(__gfx950__)
1869 // use HW clock for stochastic input multiply by incremented thread id
1870 uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1871 (get_thread_global_1d_id() + 1));
1872 union
1873 {
1874 float32_t float_vector;
1875 float float_array[32];
1876 } in{x};
1877
1878 union
1879 {
1880 f6x32_t f6_vector;
1881 f6_t f6_array[32];
1882 } out{};
1883
1884 out.f6_vector =
1885 f6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(in.float_vector, rng, scale)};
1886
1887 return out.f6_array[0];
1888#else
1889 constexpr int seed = 1254739;
1890#ifndef CK_CODE_GEN_RTC
1891 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
1892#else
1893 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
1894#endif
1895 return utils::sat_convert_to_type_sr<f6_t>(x / scale, rng);
1896#endif
1897}
1898
1909inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f)
1910{
1911#if defined(__gfx950__)
1912 // use HW clock for stochastic input multiply by incremented thread id
1913 uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
1914 (get_thread_global_1d_id() + 1));
1915 return f6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale)};
1916#else
1917 constexpr int seed = 1254739;
1918 union
1919 {
1920 float32_t float_vector;
1921 float float_array[32];
1922 } float_values{x};
1923#ifndef CK_CODE_GEN_RTC
1924 uint32_t rng =
1925 prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
1926#else
1927 uint32_t rng =
1928 prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
1929#endif
1930
1931 union
1932 {
1933 float32_t float_vector;
1934 float float_array[32];
1935 } in{x};
1936
1937 union
1938 {
1939 f6x32_t f6_vector;
1940 f6_t f6_array[32];
1941 } out{};
1942
1943 ck::static_for<0, 32, 1>{}([&](auto i) {
1944 out.f6_array[i] = utils::sat_convert_to_type_sr<f6_t>(in.float_array[i] / scale, rng);
1945 });
1946
1947 return out.f6_vector;
1948#endif
1949}
1950
1962template <>
1963inline __host__ __device__ f6_t type_convert<f6_t, float>(float x)
1964{
1965#if CK_USE_SR_F6_CONVERSION
1966 return f6_convert_sr(x);
1967#else
1968 return f6_convert_rne(x);
1969#endif
1970}
1971
1983template <>
1985{
1986#if CK_USE_SR_F6_CONVERSION
1987 return f6_convert_sr(x);
1988#else
1989 return f6_convert_rne(x);
1990#endif
1991}
1992
1993template <>
1995{
1996 return static_cast<f6x32_pk_t>(type_convert<f6x32_t>(x));
1997}
1998
1999template <>
2001{
2002
2003 union
2004 {
2005 float16_t v16x2[2];
2006 float32_t v32;
2007 } in{{x, x}};
2008
2009 union
2010 {
2011 f6x32_t v32;
2012 f6x16_t v16x2[2];
2013 } out{};
2014
2015#if CK_USE_SR_F6_CONVERSION
2016 out.v32 = f6_convert_sr(in.v32);
2017#else
2018 out.v32 = f6_convert_rne(in.v32);
2019#endif
2020
2021 return out.v16x2[0];
2022}
2023
2024template <>
2026{
2027 return static_cast<f6x16_pk_t>(type_convert<f6x16_t>(x));
2028}
2029
// Specialization converting a single f6_t value to float.
2039template <>
2040inline __host__ __device__ float type_convert<float, f6_t>(f6_t x)
2041{
2042#if defined(__gfx950__)
// The builtin converts 32 values at once, so x is placed in element 0 of a
// 32-element input and element 0 of the result is returned.
2043 union
2044 {
2045 f6_t f6_array[32];
2046 f6x32_t f6_vector;
2047 } in{{x}};
2048
2049 union
2050 {
2051 float32_t float_vector;
2052 float float_array[32];
2053 } out{};
2054
// NOTE(review): listing line 2057 (the remaining builtin argument and the
// closing of the call) is absent here -- restore from upstream.
2055 out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
2056 in.f6_vector.template AsType<f6x32_t::data_t>()[Number<0>{}],
2058 return out.float_array[0];
2059#else
// NOTE(review): the fallback return statement (listing line 2060) is absent
// here, so this branch has no return -- restore from upstream.
2061#endif
2062}
2063
// Specialization converting a vector of 32 f6_t values to 32 floats.
// NOTE(review): the declarator (listing line 2074) is absent here; the
// cross-reference index lists it as
// `float32_t type_convert< float32_t, f6x32_t >(f6x32_t x)`.
2073template <>
2075{
2076#if defined(__gfx950__)
// NOTE(review): listing line 2079 (the remaining builtin argument and the
// closing of the call) is absent here -- restore from upstream.
2077 return __builtin_amdgcn_cvt_scalef32_pk32_f32_fp6(
2078 x.template AsType<f6x32_t::data_t>()[Number<0>{}],
2080#else
// Element-wise views of the input and output vectors.
2081 union
2082 {
2083 f6x32_t f6_vector;
2084 f6_t f6_array[32];
2085 } in{x};
2086
2087 union
2088 {
2089 float32_t float_vector;
2090 float float_array[32];
2091 } out{};
2092
// NOTE(review): listing line 2095 (the right-hand side of the per-element
// assignment) is absent here -- restore from upstream.
2093 ck::static_for<0, 32, 1>{}([&](auto i) {
2094 out.float_array[i] =
2096 });
2097
2098 return out.float_vector;
2099#endif
2100}
2101
2102template <>
2104{
2105 union
2106 {
2107 f6x16_t v16x2[2];
2108 f6x32_t v32;
2109 } in{{x, x}};
2110
2111 union
2112 {
2113 float16_t v16x2[2];
2114 float32_t v32;
2115 } out{};
2116
2117 out.v32 = type_convert<float32_t>(in.v32);
2118 return out.v16x2[0];
2119}
2120
2121template <>
2123{
2124 return type_convert<float16_t>(static_cast<f6x16_t>(x));
2125}
2126
2137inline __host__ __device__ bf6_t bf6_convert_rne(float x, float scale = 1.0f)
2138{
2139#if defined(__gfx950__)
2140 float16_t in1{x};
2141 float16_t in2{};
2142
2143 union
2144 {
2145 bf6x32_t bf6_vector;
2146 bf6_t bf6_array[32];
2147 } out{};
2148
2149 out.bf6_vector = bf6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(in1, in2, scale)};
2150
2151 return out.bf6_array[0];
2152#else
2153 return utils::sat_convert_to_type<bf6_t>(x / scale);
2154#endif
2155}
2156
2168inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1.0f)
2169{
2170#if defined(__gfx950__)
2171 float16_t* in1 = reinterpret_cast<float16_t*>(&x);
2172 float16_t* in2 = reinterpret_cast<float16_t*>(&x + 16);
2173 return bf6x32_t{__builtin_amdgcn_cvt_scalef32_2xpk16_bf6_f32(*in1, *in2, scale)};
2174#else
2175 union
2176 {
2177 float32_t float_vector;
2178 float float_array[32];
2179 } in{x};
2180
2181 using array_type = uint8_t __attribute__((ext_vector_type(32)));
2182 array_type uint8_array;
2183
2184 // collect the 6-bit values into an array
2185 ck::static_for<0, 32, 1>{}([&](auto i) {
2186 uint8_array[static_cast<index_t>(i)] =
2187 utils::sat_convert_to_type<bf6_t>(in.float_array[i] / scale);
2188 });
2189 return bf6x32_t{bf6x32_pk_t{uint8_array}};
2190#endif
2191}
2192
2204inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f)
2205{
2206#if defined(__gfx950__)
2207 // use HW clock for stochastic input multiply by incremented thread id
2208 uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
2209 (get_thread_global_1d_id() + 1));
2210 union
2211 {
2212 float32_t float_vector;
2213 float float_array[32];
2214 } in{x};
2215
2216 union
2217 {
2218 bf6x32_t bf6_vector;
2219 bf6_t bf6_array[32];
2220 } out{};
2221
2222 out.bf6_vector =
2223 bf6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(in.float_vector, rng, scale)};
2224
2225 return out.bf6_array[0];
2226#else
2227 constexpr int seed = 1254739;
2228#ifndef CK_CODE_GEN_RTC
2229 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
2230#else
2231 uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x);
2232#endif
2233 return utils::sat_convert_to_type_sr<bf6_t>(x / scale, rng);
2234#endif
2235}
2236
2249inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f)
2250{
2251#if defined(__gfx950__)
2252 // use HW clock for stochastic input multiply by incremented thread id
2253 uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
2254 (get_thread_global_1d_id() + 1));
2255 return bf6x32_t{__builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale)};
2256#else
2257 constexpr int seed = 1254739;
2258 union
2259 {
2260 float32_t float_vector;
2261 float float_array[32];
2262 } float_values{x};
2263#ifndef CK_CODE_GEN_RTC
2264 uint32_t rng =
2265 prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), float_values.float_array[0]);
2266#else
2267 uint32_t rng =
2268 prand_generator<float, seed>(reinterpret_cast<size_t>(&x), float_values.float_array[0]);
2269#endif
2270 union
2271 {
2272 float32_t float_vector;
2273 float float_array[32];
2274 } in{x};
2275
2276 union
2277 {
2278 bf6x32_t bf6_vector;
2279 bf6_t bf6_array[32];
2280 } out{};
2281
2282 ck::static_for<0, 32, 1>{}([&](auto i) {
2283 out.bf6_array[i] = utils::sat_convert_to_type_sr<bf6_t>(in.float_array[i] / scale, rng);
2284 });
2285
2286 return out.bf6_vector;
2287#endif
2288}
2289
2299template <>
2300inline __host__ __device__ bf6_t type_convert<bf6_t, float>(float x)
2301{
2302#if CK_USE_SR_F6_CONVERSION
2303 return bf6_convert_sr(x);
2304#else
2305 return bf6_convert_rne(x);
2306#endif
2307}
2308
2318template <>
2320{
2321#if CK_USE_SR_F6_CONVERSION
2322 return bf6_convert_sr(x);
2323#else
2324 return bf6_convert_rne(x);
2325#endif
2326}
2327
2328template <>
2330{
2331 return static_cast<bf6x32_pk_t>(type_convert<bf6x32_t>(x));
2332}
2333
2334template <>
2336{
2337
2338 union
2339 {
2340 float16_t v16x2[2];
2341 float32_t v32;
2342 } in{{x, x}};
2343
2344 union
2345 {
2346 bf6x32_t v32;
2347 bf6x16_t v16x2[2];
2348 } out{};
2349
2350#if CK_USE_SR_F6_CONVERSION
2351 out.v32 = bf6_convert_sr(in.v32);
2352#else
2353 out.v32 = bf6_convert_rne(in.v32);
2354#endif
2355
2356 return out.v16x2[0];
2357}
2358
2359template <>
2361{
2362 return static_cast<bf6x16_pk_t>(type_convert<bf6x16_t>(x));
2363}
2364
// Specialization converting a single bf6_t value to float.
2374template <>
2375inline __host__ __device__ float type_convert<float, bf6_t>(bf6_t x)
2376{
2377#if defined(__gfx950__)
// The builtin converts 32 values at once, so x is placed in element 0 of a
// 32-element input and element 0 of the result is returned.
2378 union
2379 {
2380 bf6_t bf6_array[32];
2381 bf6x32_t bf6_vector;
2382 } in{{x}};
2383
2384 union
2385 {
2386 float32_t float_vector;
2387 float float_array[32];
2388 } out{};
2389
// NOTE(review): listing line 2392 (the remaining builtin argument and the
// closing of the call) is absent here -- restore from upstream.
2390 out.float_vector = __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
2391 in.bf6_vector.template AsType<bf6x32_t::data_t>()[Number<0>{}],
2393 return out.float_array[0];
2394#else
// NOTE(review): the fallback return statement (listing line 2395) is absent
// here, so this branch has no return -- restore from upstream.
2396#endif
2397}
2398
// Specialization converting a vector of 32 bf6_t values to 32 floats.
// NOTE(review): the declarator (listing line 2410) is absent here; the
// cross-reference index lists it as
// `float32_t type_convert< float32_t, bf6x32_t >(bf6x32_t x)`.
2409template <>
2411{
2412#if defined(__gfx950__)
// NOTE(review): listing line 2415 (the remaining builtin argument and the
// closing of the call) is absent here -- restore from upstream.
2413 return __builtin_amdgcn_cvt_scalef32_pk32_f32_bf6(
2414 x.template AsType<bf6x32_t::data_t>()[Number<0>{}],
2416#else
// Element-wise views of the input and output vectors.
2417 union
2418 {
2419 bf6x32_t bf6_vector;
2420 bf6_t bf6_array[32];
2421 } in{x};
2422
2423 union
2424 {
2425 float32_t float_vector;
2426 float float_array[32];
2427 } out{};
2428
// NOTE(review): listing line 2431 (the right-hand side of the per-element
// assignment) is absent here -- restore from upstream.
2429 ck::static_for<0, 32, 1>{}([&](auto i) {
2430 out.float_array[i] =
2432 });
2433
2434 return out.float_vector;
2435#endif
2436}
2437
2438template <>
2440{
2441 union
2442 {
2443 bf6x16_t v16x2[2];
2444 bf6x32_t v32;
2445 } in{{x, x}};
2446
2447 union
2448 {
2449 float16_t v16x2[2];
2450 float32_t v32;
2451 } out{};
2452
2453 out.v32 = type_convert<float32_t>(in.v32);
2454 return out.v16x2[0];
2455}
2456
2457template <>
2459{
2460 return type_convert<float16_t>(static_cast<bf6x16_t>(x));
2461}
2462
2463#endif
2464#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
2465template <typename Y, typename X, size_t NumElems>
2466inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
2467 const std::array<X, NumElems>& x)
2468{
2469 for(size_t i = 0; i < NumElems; i++)
2470 {
2471 y[i] = type_convert<Y>(x[i]);
2472 }
2473}
2474#endif
2475
2476template <typename Y, typename X, index_t NumElems>
2477inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
2478{
2479 for(size_t i = 0; i < NumElems; i++)
2480 {
2481 y[i] = type_convert<Y>(x[i]);
2482 }
2483}
2484
2485} // namespace ck
Definition utility/type_convert.hpp:23
__host__ __device__ T sat_convert_to_type(float value)
__host__ __device__ f6_t sat_convert_to_type_sr< f6_t >(float value, uint32_t seed)
Converts a float to f6_t with saturation and stochastic rounding.
Definition mxf6_utils.hpp:267
__host__ __device__ bf6_t sat_convert_to_type_sr< bf6_t >(float value, uint32_t seed)
Converts a float to bf6_t with saturation and stochastic rounding.
Definition mxf6_utils.hpp:302
__host__ __device__ f6_t sat_convert_to_type< f6_t >(float value)
Converts a float to f6_t with saturation.
Definition mxf6_utils.hpp:191
__host__ __device__ f4_t sat_convert_to_type_sr< f4_t >(float value, uint32_t seed)
Definition mxf4_utils.hpp:84
__host__ __device__ bf6_t sat_convert_to_type< bf6_t >(float value)
Converts a float to bf6_t with saturation.
Definition mxf6_utils.hpp:229
__host__ __device__ Y cast_from_f8(X x)
Definition f8_utils.hpp:284
__host__ __device__ float to_float< bf6_t >(e8m0_bexp_t const scale, bf6_t const data)
Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
Definition mxf6_utils.hpp:165
__host__ __device__ Y cast_to_f8(X x, uint32_t rng)
Definition f8_utils.hpp:273
__host__ __device__ float to_float< f4_t >(e8m0_bexp_t const scale, f4_t const data)
Definition mxf4_utils.hpp:40
__host__ __device__ float to_float< f6_t >(e8m0_bexp_t const scale, f6_t const data)
Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
Definition mxf6_utils.hpp:139
__host__ __device__ f4_t sat_convert_to_type< f4_t >(float value)
Definition mxf4_utils.hpp:56
Definition ck.hpp:268
__host__ __device__ bhalf2_t type_convert< bhalf2_t, f8x2_ocp_t >(f8x2_ocp_t x)
Converts a vector of 2 f8_ocp_t values to a vector of 2 bhalf_t values.
Definition utility/type_convert.hpp:1090
__host__ __device__ float16_t type_convert< float16_t, bf6x16_t >(bf6x16_t x)
Definition utility/type_convert.hpp:2439
typename vector_type< float, 16 >::type float16_t
Definition dtype_vector.hpp:2148
__host__ __device__ bf8x2_ocp_t f8_convert_rne< bf8x2_ocp_t, bhalf2_t >(bhalf2_t x)
Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (bf8_ocp_t) using rounding to neare...
Definition utility/type_convert.hpp:903
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr int type_convert_sp< int, f8_t >(f8_t x)
Definition utility/type_convert.hpp:260
__host__ __device__ float32_t type_convert< float32_t, f4x32_t >(f4x32_t x)
Definition utility/type_convert.hpp:1743
f6_pk_t< f6_t, 16 > f6x16_pk_t
Definition data_type.hpp:180
__host__ __device__ f8x2_ocp_t f8_convert_rne< f8x2_ocp_t, float2_t >(float2_t x)
Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using rounding to nearest...
Definition utility/type_convert.hpp:765
__host__ __device__ constexpr int type_convert_sp< int, half_t >(half_t x)
Definition utility/type_convert.hpp:236
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
__host__ __device__ bf6x32_t type_convert< bf6x32_t, float32_t >(float32_t x)
Specializes vector of 32 float-to-bf6_t conversion.
Definition utility/type_convert.hpp:2319
__host__ __device__ bf8x2_ocp_t f8_convert_rne< bf8x2_ocp_t, half2_t >(half2_t x)
Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using rounding to neares...
Definition utility/type_convert.hpp:847
__device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
Definition amd_inline_asm.hpp:35
__host__ __device__ constexpr Y type_convert_sp(X x)
Definition utility/type_convert.hpp:205
__host__ __device__ bf8_fnuz_t f8_convert_rne< bf8_fnuz_t, half_t >(half_t x)
Definition utility/type_convert.hpp:726
__host__ __device__ constexpr bf8_ocp_t type_convert< bf8_ocp_t, int >(int x)
Definition utility/type_convert.hpp:185
__host__ __device__ constexpr int8_t type_convert< int8_t, bhalf_t >(bhalf_t x)
Definition utility/type_convert.hpp:162
__host__ __device__ f8_ocp_t type_convert< f8_ocp_t, half_t >(half_t x)
Converts a half_t value to a f8_ocp_t value with rounding determined by a flag.
Definition utility/type_convert.hpp:1314
int32_t index_t
Definition ck.hpp:299
__host__ __device__ half_t type_convert< half_t, bf8_ocp_t >(bf8_ocp_t x)
Converts a bf8_ocp_t value to a half_t value.
Definition utility/type_convert.hpp:1155
__host__ __device__ constexpr f8_t type_convert_sp< f8_t, int >(int x)
Definition utility/type_convert.hpp:272
__host__ __device__ f6_t f6_convert_rne(float x, float scale=1.0f)
Converts a float to a 6-bit float type (f6_t) using round-to-nearest-even.
Definition utility/type_convert.hpp:1801
__host__ __device__ f8_ocp_t f8_convert_rne< f8_ocp_t, half_t >(half_t x)
Converts a half_t to a 8-bit float type (f8_ocp_t) using rounding to nearest/even.
Definition utility/type_convert.hpp:805
__host__ __device__ bf8_fnuz_t type_convert< bf8_fnuz_t, half_t >(half_t x)
Definition utility/type_convert.hpp:1445
__host__ __device__ f8x2_ocp_t f8_convert_rne< f8x2_ocp_t, half2_t >(half2_t x)
Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using rounding to nearest...
Definition utility/type_convert.hpp:819
__host__ __device__ half_t type_convert< half_t, f8_fnuz_t >(f8_fnuz_t x)
Definition utility/type_convert.hpp:1341
__host__ __device__ f8_ocp_t f8_convert_rne< f8_ocp_t, float >(float x)
Converts a float to a 8-bit float type (f8_ocp_t) using rounding to nearest/even.
Definition utility/type_convert.hpp:751
__host__ __device__ constexpr bhalf_t type_convert< bhalf_t, float >(float x)
Definition utility/type_convert.hpp:133
__host__ __device__ f4_t type_convert< f4_t, float >(float x)
Definition utility/type_convert.hpp:1664
__host__ __device__ bf6_t bf6_convert_sr(float x, float scale=1.0f)
Converts a float to the 6-bit BF6 type using stochastic rounding.
Definition utility/type_convert.hpp:2204
__host__ __device__ bf8_fnuz_t f8_convert_sr< bf8_fnuz_t, float >(float x)
Definition utility/type_convert.hpp:392
__host__ __device__ bf8_ocp_t type_convert< bf8_ocp_t, bhalf_t >(bhalf_t x)
Converts a bhalf_t value to a bf8_ocp_t value with rounding determined by a flag.
Definition utility/type_convert.hpp:1418
__host__ __device__ bf6_t type_convert< bf6_t, float >(float x)
Specializes float-to-bf6_t conversion.
Definition utility/type_convert.hpp:2300
__host__ __device__ constexpr bhalf_t type_convert_sp< bhalf_t, int >(int x)
Definition utility/type_convert.hpp:296
typename vector_type< f8_fnuz_t, 2 >::type f8x2_fnuz_t
Definition dtype_vector.hpp:2184
__host__ __device__ bf6_t bf6_convert_rne(float x, float scale=1.0f)
Converts a float to the 6-bit BF6 type using round-to-nearest-even.
Definition utility/type_convert.hpp:2137
__host__ __device__ float16_t type_convert< float16_t, f6x16_t >(f6x16_t x)
Definition utility/type_convert.hpp:2103
__host__ __device__ constexpr Y bf16_convert_rtn(X x)
__host__ __device__ bhalf2_t type_convert< bhalf2_t, bf8x2_ocp_t >(bf8x2_ocp_t x)
Converts a vector of 2 bf8_ocp_t values to a vector of 2 bhalf_t values.
Definition utility/type_convert.hpp:1225
__host__ __device__ f4x32_t type_convert< f4x32_t, float32_t >(float32_t x)
Definition utility/type_convert.hpp:1691
__host__ __device__ float type_convert< float, bf8_fnuz_t >(bf8_fnuz_t x)
Definition utility/type_convert.hpp:1429
__host__ __device__ f4_t f4_convert_rne(float x, float scale=1.0f)
Definition utility/type_convert.hpp:1468
__host__ __device__ f8_fnuz_t type_convert< f8_fnuz_t, float >(float x)
Definition utility/type_convert.hpp:912
__host__ __device__ constexpr Y f8_convert_rne(X x)
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ f6_t type_convert< f6_t, float >(float x)
Specializes the type conversion template for converting a float into the 6-bit float type (f6_t).
Definition utility/type_convert.hpp:1963
__host__ __device__ bf6x16_t type_convert< bf6x16_t, float16_t >(float16_t x)
Definition utility/type_convert.hpp:2335
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ half2_t type_convert< half2_t, f8x2_ocp_t >(f8x2_ocp_t x)
Converts a vector of 2 f8_ocp_t values to a vector of 2 half_t values.
Definition utility/type_convert.hpp:1042
__host__ __device__ bf8_ocp_t f8_convert_sr< bf8_ocp_t, float >(float x)
Converts a float to a 8-bit float type (bf8_ocp_t) using stochastic rounding.
Definition utility/type_convert.hpp:496
__host__ __device__ f8x2_ocp_t f8_convert_sr< f8x2_ocp_t, float2_t >(float2_t x)
Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using stochastic rounding...
Definition utility/type_convert.hpp:482
__host__ __device__ f8_ocp_t f8_convert_sr< f8_ocp_t, bhalf_t >(bhalf_t x)
Converts a bhalf_t to a 8-bit float type (f8_ocp_t) using stochastic rounding.
Definition utility/type_convert.hpp:583
f8_rounding_mode
Definition f8_utils.hpp:14
@ stochastic
Definition f8_utils.hpp:16
@ standard
Definition f8_utils.hpp:15
__host__ __device__ f8x2_ocp_t f8_convert_sr< f8x2_ocp_t, bhalf2_t >(bhalf2_t x)
Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (f8_ocp_t) using stochastic roundin...
Definition utility/type_convert.hpp:598
typename vector_type< bf8_ocp_t, 2 >::type bf8x2_ocp_t
Definition dtype_vector.hpp:2208
__host__ __device__ bf8_ocp_t f8_convert_sr< bf8_ocp_t, half_t >(half_t x)
Converts a half_t to an 8-bit float type (bf8_ocp_t) using stochastic rounding.
Definition utility/type_convert.hpp:554
__host__ __device__ float16_t type_convert< float16_t, f6x16_pk_t >(f6x16_pk_t x)
Definition utility/type_convert.hpp:2122
__host__ __device__ bf6x32_pk_t type_convert< bf6x32_pk_t, float32_t >(float32_t x)
Definition utility/type_convert.hpp:2329
__host__ __device__ bf8x2_ocp_t f8_convert_sr< bf8x2_ocp_t, half2_t >(half2_t x)
Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using stochastic roundin...
Definition utility/type_convert.hpp:569
__host__ __device__ constexpr Y f8_convert_sr(X x)
__host__ __device__ f8x2_ocp_t f8_convert_sr< f8x2_ocp_t, half2_t >(half2_t x)
Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using stochastic rounding...
Definition utility/type_convert.hpp:540
__host__ __device__ bf8_fnuz_t type_convert< bf8_fnuz_t, float >(float x)
Definition utility/type_convert.hpp:1354
__host__ __device__ bhalf_t type_convert< bhalf_t, bf8_ocp_t >(bf8_ocp_t x)
Converts a bf8_ocp_t value to a bhalf_t value.
Definition utility/type_convert.hpp:1194
__host__ __device__ constexpr half_t type_convert_sp< half_t, float >(float x)
Definition utility/type_convert.hpp:314
__host__ __device__ float2_t type_convert< float2_t, f4x2_t >(f4x2_t x)
Definition utility/type_convert.hpp:1720
__host__ __device__ constexpr auto unpack(F &&f, X &&x)
Definition functional4.hpp:46
__host__ __device__ float32_t type_convert< float32_t, f6x32_t >(f6x32_t x)
Specializes the type conversion template for converting the vector of 32 6-bit float types (f6x32_t) ...
Definition utility/type_convert.hpp:2074
__host__ __device__ f8_ocp_t type_convert< f8_ocp_t, bhalf_t >(bhalf_t x)
Converts a bhalf_t value to a f8_ocp_t value with rounding determined by a flag.
Definition utility/type_convert.hpp:1402
__host__ __device__ f8_fnuz_t f8_convert_rne< f8_fnuz_t, half_t >(half_t x)
Definition utility/type_convert.hpp:672
f6_pk_t< bf6_t, 32 > bf6x32_pk_t
Definition data_type.hpp:183
__host__ __device__ f6x32_pk_t type_convert< f6x32_pk_t, float32_t >(float32_t x)
Definition utility/type_convert.hpp:1994
typename vector_type< float, 2 >::type float2_t
Definition dtype_vector.hpp:2145
typename vector_type< f4x2_pk_t, 1 >::type f4x2_t
Definition dtype_vector.hpp:2258
__host__ __device__ f6x16_t type_convert< f6x16_t, float16_t >(float16_t x)
Definition utility/type_convert.hpp:2000
__host__ __device__ float type_convert< float, f4_t >(f4_t x)
Definition utility/type_convert.hpp:1702
__host__ __device__ f8_fnuz_t f8_convert_sr< f8_fnuz_t, half_t >(half_t x)
Definition utility/type_convert.hpp:367
unsigned _BitInt(4) f4_t
Definition data_type.hpp:33
__host__ __device__ bf8_ocp_t type_convert< bf8_ocp_t, half_t >(half_t x)
Converts a half_t value to a bf8_ocp_t value with rounding determined by a flag.
Definition utility/type_convert.hpp:1330
__host__ __device__ constexpr int type_convert_sp< int, bhalf_t >(bhalf_t x)
Definition utility/type_convert.hpp:284
__host__ __device__ f4_t f4_convert_sr(float x, float scale=1.0f)
Definition utility/type_convert.hpp:1546
_BitInt(6) f6_t
Definition data_type.hpp:34
__device__ index_t get_thread_global_1d_id()
Definition get_id.hpp:43
__host__ __device__ bf8_ocp_t type_convert< bf8_ocp_t, float >(float x)
Converts a float value to a bf8_ocp_t value with rounding determined by a flag.
Definition utility/type_convert.hpp:1386
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr float type_convert< float, bhalf_t >(bhalf_t x)
Definition utility/type_convert.hpp:120
typename vector_type< bf6x32_pk_t, 1 >::type bf6x32_t
Definition dtype_vector.hpp:2273
f6_pk_t< f6_t, 32 > f6x32_pk_t
Definition data_type.hpp:181
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed=seed_t)
Definition random_gen.hpp:19
__host__ __device__ f8x2_ocp_t f8_convert_rne< f8x2_ocp_t, bhalf2_t >(bhalf2_t x)
Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (f8_ocp_t) using rounding to neares...
Definition utility/type_convert.hpp:875
__host__ __device__ bf8_ocp_t f8_convert_rne< bf8_ocp_t, bhalf_t >(bhalf_t x)
Converts a bhalf_t to an 8-bit float type (bf8_ocp_t) using rounding to nearest/even.
Definition utility/type_convert.hpp:888
__host__ __device__ constexpr bhalf_t type_convert< bhalf_t, half_t >(half_t x)
Definition utility/type_convert.hpp:153
__host__ __device__ f4x2_t type_convert< f4x2_t, float2_t >(float2_t x)
Definition utility/type_convert.hpp:1675
__host__ __device__ constexpr bhalf_t bf16_convert_rtn< bhalf_t, float >(float x)
Definition utility/type_convert.hpp:61
__host__ __device__ half_t type_convert< half_t, bf8_fnuz_t >(bf8_fnuz_t x)
Definition utility/type_convert.hpp:1456
typename vector_type< f8_ocp_t, 2 >::type f8x2_ocp_t
Definition dtype_vector.hpp:2200
__host__ __device__ bf8_ocp_t f8_convert_rne< bf8_ocp_t, half_t >(half_t x)
Converts a half_t to an 8-bit float type (bf8_ocp_t) using rounding to nearest/even.
Definition utility/type_convert.hpp:832
__host__ __device__ bhalf2_t type_convert< bhalf2_t, pk_i4_t >(pk_i4_t x)
Definition utility/type_convert.hpp:1270
__host__ __device__ f4x2_pk_t type_convert< f4x2_pk_t, float2_t >(float2_t x)
Definition utility/type_convert.hpp:1684
__host__ __device__ float type_convert< float, bf6_t >(bf6_t x)
Specializes the type conversion template for converting a bf6_t value to float.
Definition utility/type_convert.hpp:2375
__host__ __device__ f6x16_pk_t type_convert< f6x16_pk_t, float16_t >(float16_t x)
Definition utility/type_convert.hpp:2025
__host__ __device__ constexpr float type_convert_sp< float, int >(int x)
Definition utility/type_convert.hpp:224
typename vector_type< half_t, 2 >::type half2_t
Definition dtype_vector.hpp:2153
__host__ __device__ bf8_fnuz_t f8_convert_rne< bf8_fnuz_t, float >(float x)
Definition utility/type_convert.hpp:692
__host__ __device__ float type_convert< float, f8_fnuz_t >(f8_fnuz_t x)
Definition utility/type_convert.hpp:923
__host__ __device__ constexpr bhalf_t type_convert_sp< bhalf_t, float >(float x)
Definition utility/type_convert.hpp:308
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
typename vector_type< float, 32 >::type float32_t
Definition dtype_vector.hpp:2149
unsigned _BitInt(6) bf6_t
Definition data_type.hpp:35
__host__ __device__ half_t type_convert< half_t, f8_ocp_t >(f8_ocp_t x)
Converts a f8_ocp_t value to a half_t value.
Definition utility/type_convert.hpp:1012
typename vector_type< f6x32_pk_t, 1 >::type f6x32_t
Definition dtype_vector.hpp:2268
typename vector_type< bhalf_t, 2 >::type bhalf2_t
Definition dtype_vector.hpp:2160
__host__ __device__ float type_convert< float, bf8_ocp_t >(bf8_ocp_t x)
Converts a bf8_ocp_t value to a float value.
Definition utility/type_convert.hpp:1107
__host__ __device__ half2_t type_convert< half2_t, pk_i4_t >(pk_i4_t x)
Definition utility/type_convert.hpp:1252
__host__ __device__ half2_t type_convert< half2_t, bf8x2_ocp_t >(bf8x2_ocp_t x)
Converts a vector of 2 bf8_ocp_t values to a vector of 2 half_t values.
Definition utility/type_convert.hpp:1177
__host__ __device__ float2_t type_convert< float2_t, bf8x2_ocp_t >(bf8x2_ocp_t x)
Converts a vector of 2 bf8_ocp_t values to a vector of 2 float values.
Definition utility/type_convert.hpp:1129
__host__ __device__ constexpr half_t type_convert_sp< half_t, int >(int x)
Definition utility/type_convert.hpp:248
__host__ __device__ float2_t type_convert< float2_t, f8x2_ocp_t >(f8x2_ocp_t x)
Converts a vector of 2 f8_ocp_t values to a vector of 2 float values.
Definition utility/type_convert.hpp:986
__host__ __device__ float2_t type_convert< float2_t, pk_i4_t >(pk_i4_t x)
Definition utility/type_convert.hpp:1236
__host__ __device__ constexpr f8_ocp_t type_convert< f8_ocp_t, int >(int x)
Definition utility/type_convert.hpp:179
__host__ __device__ constexpr int type_convert_sp< int, float >(float x)
Definition utility/type_convert.hpp:212
__host__ __device__ float32_t type_convert< float32_t, bf6x32_t >(bf6x32_t x)
Specializes the type conversion template for converting a vector of 32 bf6_t values to a vector of 32 floats.
Definition utility/type_convert.hpp:2410
__host__ __device__ bf8_ocp_t f8_convert_sr< bf8_ocp_t, bhalf_t >(bhalf_t x)
Converts a bhalf_t to an 8-bit float type (bf8_ocp_t) using stochastic rounding.
Definition utility/type_convert.hpp:612
__host__ __device__ f8_fnuz_t f8_convert_rne< f8_fnuz_t, float >(float x)
Definition utility/type_convert.hpp:640
__host__ __device__ constexpr half_t type_convert< half_t, bhalf_t >(bhalf_t x)
Definition utility/type_convert.hpp:144
__host__ __device__ f8_fnuz_t type_convert< f8_fnuz_t, half_t >(half_t x)
Definition utility/type_convert.hpp:1298
__host__ __device__ float type_convert< float, f8_ocp_t >(f8_ocp_t x)
Converts a f8_ocp_t value to a float value.
Definition utility/type_convert.hpp:964
__host__ __device__ bf8x2_ocp_t f8_convert_rne< bf8x2_ocp_t, float2_t >(float2_t x)
Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using rounding to nearest/even.
Definition utility/type_convert.hpp:792
__host__ __device__ f8_ocp_t f8_convert_rne< f8_ocp_t, bhalf_t >(bhalf_t x)
Converts a bhalf_t to an 8-bit float type (f8_ocp_t) using rounding to nearest/even.
Definition utility/type_convert.hpp:861
typename vector_type< f4x2_pk_t, 16 >::type f4x32_t
Definition dtype_vector.hpp:2262
__host__ __device__ float type_convert< float, f6_t >(f6_t x)
Specializes the type conversion template for converting the 6-bit float type (f6_t) to float.
Definition utility/type_convert.hpp:2040
__host__ __device__ f8_ocp_t type_convert< f8_ocp_t, float >(float x)
Converts a float value to a f8_ocp_t value with rounding determined by a flag.
Definition utility/type_convert.hpp:1370
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__host__ __device__ bf8x2_ocp_t f8_convert_sr< bf8x2_ocp_t, bhalf2_t >(bhalf2_t x)
Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (bf8_ocp_t) using stochastic rounding.
Definition utility/type_convert.hpp:627
__host__ __device__ bf6x16_pk_t type_convert< bf6x16_pk_t, float16_t >(float16_t x)
Definition utility/type_convert.hpp:2360
__host__ __device__ f6x32_t type_convert< f6x32_t, float32_t >(float32_t x)
Specializes the type conversion template for converting a vector of 32 floats into the vector of 32 6-bit float types (f6x32_t).
Definition utility/type_convert.hpp:1984
__host__ __device__ half2_t type_convert< half2_t, float2_t >(float2_t x)
Definition utility/type_convert.hpp:1287
typename vector_type< bf6x16_pk_t, 1 >::type bf6x16_t
Definition dtype_vector.hpp:2271
__host__ __device__ f6_t f6_convert_sr(float x, float scale=1.0f)
Converts a float to the 6-bit floating-point type (f6_t) using stochastic rounding.
Definition utility/type_convert.hpp:1866
__host__ __device__ bf8_ocp_t f8_convert_rne< bf8_ocp_t, float >(float x)
Converts a float to an 8-bit float type (bf8_ocp_t) using rounding to nearest/even.
Definition utility/type_convert.hpp:778
__host__ __device__ bhalf_t type_convert< bhalf_t, f8_ocp_t >(f8_ocp_t x)
Converts a f8_ocp_t value to a bhalf_t value.
Definition utility/type_convert.hpp:1059
__host__ __device__ f8_ocp_t f8_convert_sr< f8_ocp_t, half_t >(half_t x)
Converts a half_t to an 8-bit float type (f8_ocp_t) using stochastic rounding.
Definition utility/type_convert.hpp:525
__host__ __device__ constexpr bhalf_t bf16_convert_rtn< bhalf_t, half_t >(half_t x)
Definition utility/type_convert.hpp:87
__host__ __device__ bf8_fnuz_t f8_convert_sr< bf8_fnuz_t, half_t >(half_t x)
Definition utility/type_convert.hpp:437
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
__host__ __device__ f8_fnuz_t f8_convert_sr< f8_fnuz_t, float >(float x)
Definition utility/type_convert.hpp:324
f6_pk_t< bf6_t, 16 > bf6x16_pk_t
Definition data_type.hpp:182
__host__ __device__ float16_t type_convert< float16_t, bf6x16_pk_t >(bf6x16_pk_t x)
Definition utility/type_convert.hpp:2458
__host__ __device__ bf8x2_ocp_t f8_convert_sr< bf8x2_ocp_t, float2_t >(float2_t x)
Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using stochastic rounding.
Definition utility/type_convert.hpp:511
__host__ __device__ constexpr bhalf_t type_convert< bhalf_t, int8_t >(int8_t x)
Definition utility/type_convert.hpp:171
typename vector_type< f6x16_pk_t, 1 >::type f6x16_t
Definition dtype_vector.hpp:2266
__host__ __device__ f8_ocp_t f8_convert_sr< f8_ocp_t, float >(float x)
Converts a float to an 8-bit float type (f8_ocp_t) using stochastic rounding.
Definition utility/type_convert.hpp:467
unsigned char fp8_storage_t
Definition amd_ck_fp8.hpp:64
__host__ __device__ float2_t type_convert< float2_t, f8x2_fnuz_t >(f8x2_fnuz_t x)
Definition utility/type_convert.hpp:938
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
unsigned short uint16_t
Definition stdint.h:125
_W64 unsigned int uintptr_t
Definition stdint.h:164
unsigned int uint32_t
Definition stdint.h:126
unsigned char uint8_t
Definition stdint.h:124
signed char int8_t
Definition stdint.h:121
Definition utility/array.hpp:14
Definition numeric_limits.hpp:309
Definition numeric_utils.hpp:10
Definition amd_ck_fp8.hpp:49
Definition amd_ck_fp8.hpp:369
data_type data
Definition amd_ck_fp8.hpp:371
static constexpr ck_fp8_interpretation_t default_interpret
Definition amd_ck_fp8.hpp:374
static constexpr ck_saturation_t default_saturation
Definition amd_ck_fp8.hpp:373
Definition data_type.hpp:42
Definition amd_ck_fp8.hpp:36
Definition amd_ck_fp8.hpp:323
data_type data
Definition amd_ck_fp8.hpp:325
static constexpr ck_fp8_interpretation_t default_interpret
Definition amd_ck_fp8.hpp:328
static constexpr ck_saturation_t default_saturation
Definition amd_ck_fp8.hpp:327
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition dtype_vector.hpp:10