mxf8_utils.hpp Source File

mxf8_utils.hpp Source File#

Composable Kernel: mxf8_utils.hpp Source File
mxf8_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
6
// CK_MX_FP8_CVT_FAST_PATH is 1 only when compiling device code for gfx950 and 0
// otherwise. The code below uses it to choose between the
// __builtin_amdgcn_cvt_scalef32_* based conversions and the generic cast_to_f8
// based ones.
7#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
8#define CK_MX_FP8_CVT_FAST_PATH 1
9#else
10#define CK_MX_FP8_CVT_FAST_PATH 0
11#endif
12
13namespace ck {
14
15namespace fp8_impl {
16#if CK_MX_FP8_CVT_FAST_PATH
17template <ck_fp8_interpretation_t interpret>
18static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
19{
20 union
21 {
22 unsigned int i32val;
23 unsigned char i8val[4];
24 } val;
25 val.i8val[0] = v;
26
27 static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
29 "Only OCP interpretations are supported");
30
31 if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
32 {
33 return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
34 }
35 else
36 {
37 return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
38 }
39}
40
41template <ck_fp8_interpretation_t interpret>
42static __device__ float2_t cast_to_f32_from_f8_scaled(float scale, fp8x2_storage_t v)
43{
44 const auto i16val = bit_cast<uint16_t>(v);
45
46 static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
48 "Only OCP interpretations are supported");
49
50 if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
51 {
52 return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
53 }
54 else
55 {
56 return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
57 }
58}
59
60template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
61static __device__ fp8_storage_t cast_to_f8_from_f32_scaled(float v,
62 unsigned int rng = 0,
63 float scale = 1.0f)
64{
65 fp8_storage_t i8data;
66 union
67 {
68 float fval;
69 unsigned int i32val;
70 } val;
71
72 union
73 {
74 uint32_t ival;
75 vector_type<int16_t, 2>::type v2i16;
76 fp8_storage_t v4i8[4];
77 } ret{};
78
79 // unsigned int ival = 0;
80 val.fval = v;
81
82 if constexpr(stochastic_rounding)
83 {
84 ret.ival =
86 ? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
87 : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
88
89 i8data = ret.v4i8[0];
90 }
91 else
92 {
93 // RNE CVT
94 // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
95 // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
96 if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
97 {
98 // If fval / scale > max fp8, returns Nan
99 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
100 val.fval,
101 val.fval,
102 scale,
103 /*dst_lo_hi_sel*/ false);
104 }
105 else
106 {
107 // If fval / scale > max bf8, returns Inf
108 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
109 val.fval,
110 val.fval,
111 scale,
112 /*dst_lo_hi_sel*/ false);
113 }
114
115 i8data = ret.v4i8[0];
116 }
117 return i8data;
118}
119
120template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
121static __device__ fp8x2_storage_t cast_to_f8_from_f32_scaled(float2_t v,
122 unsigned int rng = 0,
123 float scale = 1.0f)
124{
125
126 union
127 {
128 uint32_t ival;
129 vector_type<int16_t, 2>::type v2i16;
130 StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
131 } ret{};
132
133 if constexpr(stochastic_rounding)
134 {
135 fp8x2_storage_t f8x2;
136 if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
137 {
138 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
139 f8x2[0] = ret.v2f8x2(Number<0>{})[0];
140 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
141 f8x2[1] = ret.v2f8x2(Number<0>{})[0];
142 }
143 else
144 {
145 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
146 f8x2[0] = ret.v2f8x2(Number<0>{})[0];
147 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
148 f8x2[1] = ret.v2f8x2(Number<0>{})[0];
149 }
150 return f8x2;
151 }
152 else
153 {
154 // RNE CVT
155 // llvm.amdgcn.cvt.scalef32.pk.fp8.f32
156 // v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
157 if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
158 {
159 // If fval / scale > max fp8, returns Nan
160 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
161 v[0],
162 v[1],
163 scale,
164 /*dst_lo_hi_sel*/ false);
165 }
166 else
167 {
168 // If fval / scale > max bf8, returns Inf
169 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
170 v[0],
171 v[1],
172 scale,
173 /*dst_lo_hi_sel*/ false);
174 }
175
176 return ret.v2f8x2(Number<0>{});
177 }
178}
179
180#endif // CK_MX_FP8_CVT_FAST_PATH
181
182#if CK_MX_FP8_CVT_FAST_PATH
193template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
194__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
195{
196 __is_interpret_supported(interp);
197 uint32_t rng = 0;
198 if constexpr(stochastic_rounding)
199 {
200 // use HW clock for stochastic input multiply by incremented thread id
201 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
203 }
204 return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
205}
206
217template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
218__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
219 float scale)
220{
221 __is_interpret_supported(interp);
222 uint32_t rng = 0;
223 if constexpr(stochastic_rounding)
224 {
225 // use HW clock for stochastic input multiply by incremented thread id
226 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
228 }
229 return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
230}
231
232#else
233
244template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
245__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
246{
247
248 static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
250 "Only OCP interpretations are supported");
251
252 uint32_t rng = 0;
253 if constexpr(stochastic_rounding)
254 {
255 constexpr int seed = 1254739;
256 rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
257 }
258
259 if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
260 {
261 return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
262 }
263 else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
264 {
265 return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
266 }
267 else
268 {
269 __hip_assert(false && "FP8 type is not supported by current target device");
270 return 0;
271 }
272}
273
284template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
285__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
286 float scale)
287{
288
289 static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
291 "Only OCP interpretations are supported");
292
293 uint32_t rng = 0;
294 if constexpr(stochastic_rounding)
295 {
296 constexpr int seed = 1254739;
297 rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
298 }
299
300 if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
301 {
302 return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
303 cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
304 }
305 else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
306 {
307 return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
308 cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
309 }
310 else
311 {
312 __hip_assert(false && "FP8 type is not supported by current target device");
313 return 0;
314 }
315}
316
317#endif // CK_MX_FP8_CVT_FAST_PATH
318
319} // namespace fp8_impl
320
// Primary templates for the scaled fp32 -> 8-bit float conversions. Only the
// declarations appear here; behaviour is supplied by the explicit
// specializations for each (Y, X) pair.
//   Y     - destination type (scalar or vector of f8_ocp_t / bf8_ocp_t)
//   X     - source type (float or a float vector)
//   x     - value(s) to convert
//   scale - scale passed through to fp8_impl::cvt_float_to_fp8_scaled
321// Declare a template function for fp8 conversion using SR
322template <typename Y, typename X>
323__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);
324
325// Declare a template function for fp8 conversion using RNE
326template <typename Y, typename X>
327__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);
328
329// convert fp32 to fp8 with rounding to nearest even
330template <>
331inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
332{
333 return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
334}
335
336// convert fp32 to bf8 with rounding to nearest even
337template <>
338inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
339{
340 return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
341}
342
343// convert fp32x2 to fp8x2 with rounding to nearest even
344template <>
346 float scale)
347{
348 return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
349}
350
351// convert fp32x2 to bf8x2 with rounding to nearest even
352template <>
354 float scale)
355{
356 return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
357}
358
359// convert fp32x16 to fp8x16 with rounding to nearest even
360template <>
362 float scale)
363{
364 union
365 {
366 float16_t float_1x16;
367 float2_t float_2x8[8];
368 } in{x};
369
370 union
371 {
372 f8x16_ocp_t fp8_1x16;
373 f8x2_ocp_t fp8_2x8[8];
374 } out{};
375
377 [&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });
378
379 return out.fp8_1x16;
380}
381
382// convert fp32x16 to bf8x16 with rounding to nearest even
383template <>
385 float scale)
386{
387 union
388 {
389 float16_t float_1x16;
390 float2_t float_2x8[8];
391 } in{x};
392
393 union
394 {
395 bf8x16_ocp_t bf8_1x16;
396 bf8x2_ocp_t bf8_2x8[8];
397 } out{};
398
400 [&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });
401
402 return out.bf8_1x16;
403}
404
405// convert fp32x32 to fp8x32 with rounding to nearest even
406template <>
408 float scale)
409{
410 union
411 {
412 float32_t float_1x32;
413 float16_t float_16x2[2];
414 } in{x};
415
416 union
417 {
418 f8x32_ocp_t fp8_1x32;
419 f8x16_ocp_t fp8_16x2[2];
420 } out{};
421
423 [&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });
424
425 return out.fp8_1x32;
426}
427
428// convert fp32x32 to bf8x32 with rounding to nearest even
429template <>
431 float scale)
432{
433 union
434 {
435 float32_t float_1x32;
436 float16_t float_16x2[2];
437 } in{x};
438
439 union
440 {
441 bf8x32_ocp_t bf8_1x32;
442 bf8x16_ocp_t bf8_16x2[2];
443 } out{};
444
446 [&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });
447
448 return out.bf8_1x32;
449}
450
451// convert fp32 to fp8 with stochastic rounding
452template <>
453inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
454{
455 return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
456}
457
458// convert fp32 to bf8 with stochastic rounding
459template <>
460inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
461{
462 return bf8_ocp_t{
463 fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
464}
465
466// convert fp32x2 to fp8x2 with stochastic rounding
467template <>
468inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x, float scale)
469{
470 return f8x2_ocp_t{
471 fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
472}
473
474// convert fp32x2 to bf8x2 with stochastic rounding
475template <>
477 float scale)
478{
479 return bf8x2_ocp_t{
480 fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
481}
482
483// convert fp32x16 to fp8x16 with stochastic rounding
484template <>
486 float scale)
487{
488 union
489 {
490 float16_t float_1x16;
491 float2_t float_2x8[8];
492 } in{x};
493
494 union
495 {
496 f8x16_ocp_t fp8_1x16;
497 f8x2_ocp_t fp8_2x8[8];
498 } out{};
499
501 [&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });
502
503 return out.fp8_1x16;
504}
505
506// convert fp32x16 to bf8x16 with stochastic rounding
507template <>
509 float scale)
510{
511 union
512 {
513 float16_t float_1x16;
514 float2_t float_2x8[8];
515 } in{x};
516
517 union
518 {
519 bf8x16_ocp_t bf8_1x16;
520 bf8x2_ocp_t bf8_2x8[8];
521 } out{};
522
524 [&](auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });
525
526 return out.bf8_1x16;
527}
528
529// convert fp32x32 to fp8x32 with stochastic rounding
530template <>
532 float scale)
533{
534 union
535 {
536 float32_t float_1x32;
537 float16_t float_16x2[2];
538 } in{x};
539
540 union
541 {
542 f8x32_ocp_t fp8_1x32;
543 f8x16_ocp_t fp8_16x2[2];
544 } out{};
545
547 [&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });
548
549 return out.fp8_1x32;
550}
551
552// convert fp32x32 to bf8x32 with stochastic rounding
553template <>
555 float scale)
556{
557 union
558 {
559 float32_t float_1x32;
560 float16_t float_16x2[2];
561 } in{x};
562
563 union
564 {
565 bf8x32_ocp_t bf8_1x32;
566 bf8x16_ocp_t bf8_16x2[2];
567 } out{};
568
570 [&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });
571
572 return out.bf8_1x32;
573}
574
575} // namespace ck
float float2_t
Definition amd_ck_fp8.hpp:92
fp8_storage_t fp8x2_storage_t
Definition amd_ck_fp8.hpp:88
Definition ck.hpp:268
typename vector_type< float, 16 >::type float16_t
Definition dtype_vector.hpp:2148
__host__ __device__ f8x16_ocp_t mxf8_convert_sr< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition mxf8_utils.hpp:485
__host__ __device__ f8x2_ocp_t mxf8_convert_rne< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition mxf8_utils.hpp:345
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale)
__host__ __device__ f8_ocp_t mxf8_convert_rne< f8_ocp_t, float >(float x, float scale)
Definition mxf8_utils.hpp:331
__host__ __device__ bf8_ocp_t mxf8_convert_sr< bf8_ocp_t, float >(float x, float scale)
Definition mxf8_utils.hpp:460
@ CK_E4M3_OCP
Definition amd_ck_fp8.hpp:71
@ CK_E5M2_OCP
Definition amd_ck_fp8.hpp:72
typename vector_type< f8_ocp_t, 32 >::type f8x32_ocp_t
Definition dtype_vector.hpp:2204
__host__ __device__ f8_ocp_t mxf8_convert_sr< f8_ocp_t, float >(float x, float scale)
Definition mxf8_utils.hpp:453
integral_constant< index_t, N > Number
Definition number.hpp:12
typename vector_type< bf8_ocp_t, 32 >::type bf8x32_ocp_t
Definition dtype_vector.hpp:2212
__host__ __device__ bf8x32_ocp_t mxf8_convert_rne< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition mxf8_utils.hpp:430
typename vector_type< bf8_ocp_t, 2 >::type bf8x2_ocp_t
Definition dtype_vector.hpp:2208
__host__ __device__ bf8_ocp_t mxf8_convert_rne< bf8_ocp_t, float >(float x, float scale)
Definition mxf8_utils.hpp:338
typename vector_type< float, 2 >::type float2_t
Definition dtype_vector.hpp:2145
__host__ __device__ bf8x16_ocp_t mxf8_convert_sr< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition mxf8_utils.hpp:508
__device__ index_t get_thread_global_1d_id()
Definition get_id.hpp:43
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed=seed_t)
Definition random_gen.hpp:19
typename vector_type< f8_ocp_t, 2 >::type f8x2_ocp_t
Definition dtype_vector.hpp:2200
typename vector_type< float, 32 >::type float32_t
Definition dtype_vector.hpp:2149
typename vector_type< f8_ocp_t, 16 >::type f8x16_ocp_t
Definition dtype_vector.hpp:2203
__host__ __device__ bf8x16_ocp_t mxf8_convert_rne< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition mxf8_utils.hpp:384
__host__ __device__ f8x16_ocp_t mxf8_convert_rne< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition mxf8_utils.hpp:361
__host__ __device__ bf8x32_ocp_t mxf8_convert_sr< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition mxf8_utils.hpp:554
__host__ __device__ f8x32_ocp_t mxf8_convert_sr< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition mxf8_utils.hpp:531
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale)
typename vector_type< bf8_ocp_t, 16 >::type bf8x16_ocp_t
Definition dtype_vector.hpp:2211
__host__ __device__ bf8x2_ocp_t mxf8_convert_sr< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition mxf8_utils.hpp:476
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__host__ __device__ f8x32_ocp_t mxf8_convert_rne< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition mxf8_utils.hpp:407
__host__ __device__ bf8x2_ocp_t mxf8_convert_rne< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition mxf8_utils.hpp:353
__host__ __device__ f8x2_ocp_t mxf8_convert_sr< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition mxf8_utils.hpp:468
unsigned char fp8_storage_t
Definition amd_ck_fp8.hpp:64
_W64 unsigned int uintptr_t
Definition stdint.h:164
unsigned int uint32_t
Definition stdint.h:126
Definition amd_ck_fp8.hpp:369
Definition amd_ck_fp8.hpp:323
Definition functional2.hpp:33