smoothquant_kernel.hpp Source File

smoothquant_kernel.hpp Source File

Composable Kernel: smoothquant_kernel.hpp Source File
smoothquant_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// host side args
// Host-side argument bundle; MakeKargs() copies these fields one-for-one into the
// device-side Kargs.
// NOTE(review): this listing skips source line 12 (presumably the
// `struct SmoothquantHostArgs` header, which the reference index aliases as `Hargs`)
// and source lines 20-21, which the reference index lists as `index_t m` and `index_t n`.
13{
14 const void* p_x; // [m, n], input, fp16/bf16
15 const void* p_smscale; // [1, n], input, columnwise scale, fp32
16
17 void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_smscale)
18 void* p_qy; // [m, n], output, p_x * p_smscale / p_yscale
19
22 index_t x_stride; // input row_stride
23 index_t y_stride; // output row_stride
24};
25
26// TODO: Extract some types into a wrapper class
// Kernel wrapper template, parameterised on the pipeline type.
// NOTE(review): this listing skips source lines 28, 30 and 33-37. Per the reference
// index they hold the `Pipeline` alias (remove_cvref_t<Pipeline_>) and the element-type
// aliases XDataType, SmoothScaleDataType, ComputeDataType, YScaleDataType and QYDataType,
// all taken from Problem. The struct header itself (line 28) is not shown anywhere.
27template <typename Pipeline_>
29{
31 using Problem = typename Pipeline::Problem;
32

38
// Per-block tile extents, forwarded from Problem::BlockShape.
39 static constexpr index_t Block_M = Problem::BlockShape::Block_M;
40 static constexpr index_t Block_N = Problem::BlockShape::Block_N;
// kPadM is hard-coded to false; kPadN and kTwoPass are forwarded from Problem.
41 static constexpr bool kPadM = false; // always no need to pad along M
42 static constexpr bool kPadN = Problem::kPadN;
43 static constexpr bool kTwoPass = Problem::kTwoPass;
44
// Further BlockShape parameters re-exported as members of this kernel type.
45 static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
46 static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
47 static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
48 static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
49
// Compile-time integer constants 0 and 1.
50 static constexpr auto I0 = number<0>{};
51 static constexpr auto I1 = number<1>{};
52
// Kernel-side argument struct, filled by MakeKargs() and passed by value to operator().
// NOTE(review): this listing skips source lines 61-62, which the reference index lists
// as `index_t m` and `index_t n`. MakeKargs() initialises the members positionally, so
// the declaration order here must match its initializer list.
53 struct Kargs
54 {
55 const void* p_x;
56 const void* p_smscale;
57
58 void* p_yscale;
59 void* p_qy;
60
63 index_t x_stride; // input row_stride
64 index_t y_stride; // out row_stride
65 };
67
// Builds the kernel-side Kargs from the host-side arguments.
// The fields are copied unchanged via positional aggregate initialization, in the order
// p_x, p_smscale, p_yscale, p_qy, m, n, x_stride, y_stride, so this list must stay in
// sync with the member order of Kargs.
68 CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
69 {
70 return Kargs{hargs.p_x,
71 hargs.p_smscale,
72 hargs.p_yscale,
73 hargs.p_qy,
74 hargs.m,
75 hargs.n,
76 hargs.x_stride,
77 hargs.y_stride};
78 }
79
// Returns the launch grid as a dim3 with only its first component given:
// ceil(hargs.m / Block_M). operator() derives its starting row as
// get_block_id() * Block_M, so each block covers one Block_M-row slab.
80 CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
81 {
82 return dim3(integer_divide_ceil(hargs.m, Block_M));
83 }
84
// Returns the block size reported by Problem::BlockShape, selecting
// GetBlockSize<true>() when is_wave32() returns true and GetBlockSize<false>() otherwise.
85 CK_TILE_HOST static constexpr auto BlockSize()
86 {
87 return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
88 : Problem::BlockShape::template GetBlockSize<false>();
89 }
90
// Type-to-string trait: maps an element type to the short tag that GetName() embeds in
// the kernel name. The primary template is declared but never defined, so using
// t2s<T>::name for a type without a specialization below fails to compile.
91 // clang-format off
92 template <typename T> struct t2s;
93 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
94 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
95 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
96 template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
97 template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
98 // clang-format on
99
100 // in bytes
// Forwards to Pipeline::GetSmemSize(); operator() uses the result as the length of its
// `__shared__ char` scratch array.
101 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
102
// Builds a descriptive kernel name of the form
//   smoothquant_fwd_<xtype>_<Block_M>x<Block_N>_<WarpPerBlock_M>x<WarpPerBlock_N>_
//   <Warp_M>x<Warp_N>_<Vector_M>x<Vector_N>_<Pipeline::name><suffix>
// where <xtype> is t2s<XDataType>::name and <suffix> is "_pn" when kPadN is set followed
// by "_2p" when kTwoPass is set (either may be absent).
103 CK_TILE_HOST static std::string GetName()
104 {
105 // clang-format off
106 using S_ = typename Problem::BlockShape;
// `surfix` (sic, "suffix"): optional feature tags, built by an immediately-invoked lambda.
107 auto surfix = [&] () {
108 std::string n;
109 if (kPadN) n += "_pn";
110 if (kTwoPass) n += "_2p";
111 return n; }();
112
// Short-lived helper macros for the concatenation below; both are #undef'd before the
// function ends.
// NOTE(review): names starting with an underscore followed by an uppercase letter are
// reserved to the implementation.
113 #define _SS_ std::string
114 #define _TS_ std::to_string
115 return _SS_("smoothquant_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
116 _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
117 _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
118 _SS_(Pipeline::name) + surfix;
119 #undef _SS_
120 #undef _TS_
121 // clang-format on
122 }
123
// Kernel entry point. Per the reference index, source line 124 declares
// `CK_TILE_DEVICE void operator()(Kargs kargs) const`; that line is absent from this listing.
// NOTE(review): this listing also drops source lines 129, 133, 137, 143, 147, 151, 157,
// 164, 170, 174 and 178, so the tensor-view construction and padding calls below are
// shown only in part. Check them against the real header before relying on this text.
125 {
// First row handled by this block: block id times the per-block row count.
126 const auto iM = get_block_id() * Block_M;
127
// Read window over x: lengths (m, n), strides (x_stride, 1); tile Block_M x Block_N
// with origin {iM, 0}.
128 const auto x_window = [&]() {
130 static_cast<const XDataType*>(kargs.p_x),
131 make_tuple(kargs.m, kargs.n),
132 make_tuple(kargs.x_stride, 1),
134 number<1>{});
135
136 const auto tmp2_ = pad_tensor_view(
138 return make_tile_window(
139 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
140 }();
141
// Read window over the smooth scale: length n, stride 1; tile of Block_N elements with
// origin {0}, i.e. the same starting column for every block.
142 const auto smscale_window = [&]() {
144 static_cast<const SmoothScaleDataType*>(kargs.p_smscale),
145 make_tuple(kargs.n),
146 make_tuple(1),
148 number<1>{});
149
150 const auto tmp2_ =
152
153 return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
154 }();
155
// Window over the y-scale buffer (non-const pointer): length m, stride 1; tile of
// Block_M elements with origin {iM}.
156 auto yscale_window = [&]() {
158 static_cast<YScaleDataType*>(kargs.p_yscale),
159 make_tuple(kargs.m),
160 make_tuple(1),
161 number<1>{});
162
163 const auto tmp2_ =
165
166 return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
167 }();
168
// Window over the quantized-output buffer (non-const pointer): lengths (m, n),
// strides (y_stride, 1); tile Block_M x Block_N with origin {iM, 0}.
169 auto qy_window = [&]() {
171 static_cast<QYDataType*>(kargs.p_qy),
172 make_tuple(kargs.m, kargs.n),
173 make_tuple(kargs.y_stride, 1),
175 number<1>{});
176
177 auto tmp2_ = pad_tensor_view(
179 return make_tile_window(
180 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
181 }();
182
// Shared-memory scratch, GetSmemSize() bytes, handed to the pipeline.
183 __shared__ char smem[GetSmemSize()];
184
// Run the pipeline on the four windows, passing the column count and the scratch buffer.
185 Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.n, smem);
186 }
187};
188
189} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
CK_TILE_DEVICE index_t get_block_id()
Definition arch.hpp:119
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition smoothquant_kernel.hpp:54
index_t m
Definition smoothquant_kernel.hpp:61
index_t n
Definition smoothquant_kernel.hpp:62
index_t x_stride
Definition smoothquant_kernel.hpp:63
const void * p_x
Definition smoothquant_kernel.hpp:55
void * p_qy
Definition smoothquant_kernel.hpp:59
const void * p_smscale
Definition smoothquant_kernel.hpp:56
void * p_yscale
Definition smoothquant_kernel.hpp:58
index_t y_stride
Definition smoothquant_kernel.hpp:64
static constexpr const char * name
Definition smoothquant_kernel.hpp:95
static constexpr const char * name
Definition smoothquant_kernel.hpp:97
static constexpr const char * name
Definition smoothquant_kernel.hpp:94
static constexpr const char * name
Definition smoothquant_kernel.hpp:96
static constexpr const char * name
Definition smoothquant_kernel.hpp:93
Definition smoothquant_kernel.hpp:92
Definition smoothquant_kernel.hpp:13
index_t y_stride
Definition smoothquant_kernel.hpp:23
const void * p_smscale
Definition smoothquant_kernel.hpp:15
void * p_qy
Definition smoothquant_kernel.hpp:18
index_t x_stride
Definition smoothquant_kernel.hpp:22
void * p_yscale
Definition smoothquant_kernel.hpp:17
index_t m
Definition smoothquant_kernel.hpp:20
index_t n
Definition smoothquant_kernel.hpp:21
const void * p_x
Definition smoothquant_kernel.hpp:14
Definition smoothquant_kernel.hpp:29
static constexpr index_t Block_M
Definition smoothquant_kernel.hpp:39
static constexpr index_t kBlockSize
Definition smoothquant_kernel.hpp:48
remove_cvref_t< Pipeline_ > Pipeline
Definition smoothquant_kernel.hpp:30
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition smoothquant_kernel.hpp:101
static constexpr auto I1
Definition smoothquant_kernel.hpp:51
static constexpr index_t Repeat_N
Definition smoothquant_kernel.hpp:47
static CK_TILE_HOST constexpr auto BlockSize()
Definition smoothquant_kernel.hpp:85
static CK_TILE_HOST constexpr Kargs MakeKargs(const Hargs &hargs)
Definition smoothquant_kernel.hpp:68
static CK_TILE_HOST constexpr auto GridSize(const Hargs &hargs)
Definition smoothquant_kernel.hpp:80
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition smoothquant_kernel.hpp:34
static constexpr bool kPadM
Definition smoothquant_kernel.hpp:41
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition smoothquant_kernel.hpp:124
static CK_TILE_HOST std::string GetName()
Definition smoothquant_kernel.hpp:103
static constexpr index_t ThreadPerWarp_N
Definition smoothquant_kernel.hpp:45
static constexpr bool kTwoPass
Definition smoothquant_kernel.hpp:43
static constexpr auto I0
Definition smoothquant_kernel.hpp:50
remove_cvref_t< typename Problem::XDataType > XDataType
Definition smoothquant_kernel.hpp:33
static constexpr index_t Vector_N
Definition smoothquant_kernel.hpp:46
remove_cvref_t< typename Problem::QYDataType > QYDataType
Definition smoothquant_kernel.hpp:37
SmoothquantHostArgs Hargs
Definition smoothquant_kernel.hpp:66
static constexpr bool kPadN
Definition smoothquant_kernel.hpp:42
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition smoothquant_kernel.hpp:35
static constexpr index_t Block_N
Definition smoothquant_kernel.hpp:40
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition smoothquant_kernel.hpp:36
typename Pipeline::Problem Problem
Definition smoothquant_kernel.hpp:31
Definition tile/core/container/sequence.hpp:49