quantization_operation.hpp Source File

quantization_operation.hpp Source File#

Composable Kernel: quantization_operation.hpp Source File
quantization_operation.hpp
Go to the documentation of this file.
1#pragma once
2
4// #include "ck/utility/get_id.hpp"
5
6namespace ck {
7namespace tensor_operation {
8namespace element_wise {
9
10// Y = Sy * Qy
11// W = Sw * Qw
12// X = Sx * Qx
13// B = Sb * Qb = Sw * Sx * Qb
14// Where X, W, Y are float32, Qx, Qw, Qy are int8
15// Sx, Sw, Sy are the scales of x, w, y (float32), which are calculated from the quantization range
16// Qb is int32; the scale of B is chosen as Sw * Sx for convenience
17
18// Y = W @ X, where @ is convolution or matrix multiplication
19// Sy * Qy = Sw * Qw @ Sx * Qx
20// Qy = [(Sw*Sx)/Sy] * Qw @ Qx
21
22// For activation functions that are piecewise linear, such as relu, leaky relu, etc.
23// Activation(Sy * Qy) = Sy * Activation(Qy)
// Requantization functor: activation, then multiplication by one scale, then clamp.
// Every operator() applies activationOp_, multiplies the result by requantScale_ and clamps
// the product to [-128, 127]; the overloads differ only in input/output types.
// NOTE(review): the struct declaration (listing line 25) and the member declarations
// (listing lines 66-67; the symbol index at the end of this page names them
// `float requantScale_` and `Activation activationOp_`) are missing from this extract.
24template <typename Activation>
26{
27 static constexpr const char* name = "Activation_Mul_Clamp";
28
29 // Convolution + Activation (piecewise linear function)
30 // If the activation is a piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
// NOTE(review): the identity above requires a positively homogeneous activation and Sy > 0
// (relu, leaky relu); a clipped variant such as relu6 would not satisfy it - confirm per use.
31 // Z = Activation(Y) = Activation(W @ X)
32 // Sz * Qz = Activation(Sy * Qy)
33 // Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx)
34
35 // requantScale_ = Sw * Sx / Sz
// Stores the scale and a copy of the activation functor. The constructor carries no
// __host__/__device__ qualifier and performs no validation of requantScale.
36 Activation_Mul_Clamp(float requantScale, Activation activationOp)
37 : requantScale_(requantScale), activationOp_(activationOp)
38 {
39 }
40
// int32 input x -> int8 output y: convert to float, apply the activation in place,
// scale, clamp to [-128, 127], then convert to int8_t.
41 __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
42 {
43 float y_fp32 = ck::type_convert<float>(x);
44 activationOp_(y_fp32, y_fp32);
45 y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
46 y = ck::type_convert<int8_t>(y_fp32);
47 }
48
// Same computation as the int8 overload, but the clamped value is written as int32_t.
// NOTE(review): this overload is __device__ only, unlike the int8 overload - confirm intent.
49 __device__ constexpr void operator()(int32_t& y, const int32_t& x) const
50 {
51 // CAUTION - We might type_convert to int8 in threadwise copy
52 // eg. GridwiseGemmDlMultipleD_km_kn_mn
53 float y_fp32 = ck::type_convert<float>(x);
54 activationOp_(y_fp32, y_fp32);
55 y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
56 y = ck::type_convert<int32_t>(y_fp32);
57 }
58
// Float-in / float-out overload (__host__ only): no type conversion is performed; the
// activation writes into y, which is then scaled and clamped to [-128, 127].
59 __host__ constexpr void operator()(float& y, const float& x) const
60 {
61 // CAUTION - We might float in & float out in reference code
62 activationOp_(y, x);
63 y = math::clamp(requantScale_ * y, -128.f, 127.f);
64 }
65
68};
69
70// For activation functions that are not piecewise linear, such as TanH, Sigmoid, etc.
71// If an activation is not a piecewise linear function
72// then Activation(Sy * Qy) != Sy * Activation(Qy)
// Requantization functor: scale the accumulator, apply the activation, scale again, clamp.
// NOTE(review): the struct declaration (listing line 74) and two member declarations
// (listing lines 97 and 99; the symbol index at the end of this page names them
// `float scale_z_inv_` and `Activation activationOp_`) are missing from this extract.
73template <typename Activation>
75{
76 static constexpr const char* name = "Mul_Activation_Mul_Clamp";
77
78 // Convolution + Activation (non piecewise linear function)
79 // Z = Activation(Y) = Activation(W @ X)
80 // Sz * Qz = Activation(Sy * Qy)
81 // Qz = S1 * Activation[Sacc * (Qw @ Qx)]
82 // Where S1 = 1 / Sz, Sacc = Sw * Sx
// Stores both scales and a copy of the activation functor. The constructor carries no
// __host__/__device__ qualifier and performs no validation of the scales.
83 Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
84 : scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
85 {
86 }
87
// int32 input x -> int8 output y: convert to float, multiply by scaleAcc_, apply the
// activation in place, multiply by scale_z_inv_, clamp to [-128, 127], convert to int8_t.
88 __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
89 {
90 float y_fp32 = ck::type_convert<float>(x);
91 y_fp32 = scaleAcc_ * y_fp32;
92 activationOp_(y_fp32, y_fp32);
93 y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
94 y = ck::type_convert<int8_t>(y_fp32);
95 }
96
// Scale applied to the accumulator before the activation (Sacc in the derivation above).
98 float scaleAcc_;
100};
101
102// Conv per-channel quantization + activation function that is piecewise linear, such as
103// relu, leaky relu, etc.
104// Activation(Sy * Qy) = Sy * Activation(Qy)
// Requantization functor whose scale is supplied per call rather than stored:
// activation, then multiplication by the requantScale argument, then clamp to [-128, 127].
// NOTE(review): the struct declaration (listing line 106) and the member declaration
// (listing line 132; the symbol index at the end of this page names it
// `Activation activationOp_`) are missing from this extract.
105template <typename Activation>
107{
108 static constexpr const char* name = "Activation_Mul2_Clamp";
109
// Stores a copy of the activation functor; carries no __host__/__device__ qualifier.
110 Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
111
// int32 input x -> int8 output y: convert to float, apply the activation in place,
// multiply by the caller-provided requantScale, clamp, then convert to int8_t.
112 __host__ __device__ constexpr void
113 operator()(int8_t& y, const int32_t& x, const float& requantScale) const
114 {
115 float y_fp32 = ck::type_convert<float>(x);
116 activationOp_(y_fp32, y_fp32);
117 y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
118 y = ck::type_convert<int8_t>(y_fp32);
119 }
120
// Same computation as the int8 overload, but the clamped value is written as int32_t.
// NOTE(review): this overload is __device__ only, unlike the int8 overload - confirm intent.
121 __device__ constexpr void
122 operator()(int32_t& y, const int32_t& x, const float& requantScale) const
123 {
124 // CAUTION - We might type_convert to int8 in threadwise copy
125 // eg. GridwiseGemmDlMultipleD_km_kn_mn
126 float y_fp32 = ck::type_convert<float>(x);
127 activationOp_(y_fp32, y_fp32);
128 y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
129 y = ck::type_convert<int32_t>(y_fp32);
130 }
131
133};
134
135// For activation functions that are piecewise linear, such as relu, leaky relu, etc.
136// Activation(Sy * Qy) = Sy * Activation(Qy)
// Requantization functor with bias: add the int32 bias, apply the activation, multiply by
// requantScale_, clamp to [-128, 127].
// NOTE(review): the struct declaration (listing line 138) and the member declarations
// (listing lines 177-178; the symbol index at the end of this page names them
// `float requantScale_` and `Activation activationOp_`) are missing from this extract.
137template <typename Activation>
139{
140 static constexpr const char* name = "Add_Activation_Mul_Clamp";
141
142 // Convolution + bias
143 // Let Bias = B = Sw * Sx * Qb
144 // Where Qb is int32
145 // Y = W @ X + B
146 // Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb
147 // Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb)
148
149 // For activation, Z = Activation(Y)
150 // Sz * Qz = Activation(Sy * Qy)
151 // Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb)
// Stores the scale and a copy of the activation functor. The constructor carries no
// __host__/__device__ qualifier and performs no validation of requantScale.
152 Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
153 : requantScale_(requantScale), activationOp_(activationOp)
154 {
155 }
156
// int32 inputs x and bias -> int8 output y. The sum x + bias is formed in int32 before the
// conversion to float (NOTE(review): integer overflow of that sum is not guarded here).
157 __host__ __device__ constexpr void
158 operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
159 {
160 float y_fp32 = ck::type_convert<float>(x + bias);
161 activationOp_(y_fp32, y_fp32);
162 y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
163 y = ck::type_convert<int8_t>(y_fp32);
164 }
165
// Same computation as the int8 overload, but the clamped value is written as int32_t.
166 __host__ __device__ constexpr void
167 operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
168 {
169 // CAUTION - We might type_convert to int8 in threadwise copy
170 // eg. GridwiseGemmDlMultipleD_km_kn_mn
171 float y_fp32 = ck::type_convert<float>(x + bias);
172 activationOp_(y_fp32, y_fp32);
173 y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f);
174 y = ck::type_convert<int32_t>(y_fp32);
175 }
176
179};
180
181// Conv per-channel quantization + activation function that is piecewise linear, such as
182// relu, leaky relu, etc.
// Requantization functor with bias and a per-call scale: add the int32 bias, apply the
// activation, multiply by the requantScale argument, clamp to [-128, 127].
// NOTE(review): the struct declaration (listing line 184) and the member declaration
// (listing line 210; the symbol index at the end of this page names it
// `Activation activationOp_`) are missing from this extract.
183template <typename Activation>
185{
186 static constexpr const char* name = "Add_Activation_Mul2_Clamp";
187
// Stores a copy of the activation functor; carries no __host__/__device__ qualifier.
188 Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
189
// int32 inputs x and bias -> int8 output y. The sum x + bias is formed in int32 before the
// conversion to float (NOTE(review): integer overflow of that sum is not guarded here).
190 __host__ __device__ constexpr void
191 operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
192 {
193 float y_fp32 = ck::type_convert<float>(x + bias);
194 activationOp_(y_fp32, y_fp32);
195 y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
196 y = ck::type_convert<int8_t>(y_fp32);
197 }
198
// Same computation as the int8 overload, but the clamped value is written as int32_t.
199 __host__ __device__ constexpr void
200 operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const
201 {
202 // CAUTION - We might type_convert to int8 in threadwise copy
203 // eg. GridwiseGemmDlMultipleD_km_kn_mn
204 float y_fp32 = ck::type_convert<float>(x + bias);
205 activationOp_(y_fp32, y_fp32);
206 y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f);
207 y = ck::type_convert<int32_t>(y_fp32);
208 }
209
211};
212
213// For activation functions that are not piecewise linear, such as TanH, Sigmoid, etc.
214// If an activation is not a piecewise linear function
215// then Activation(Sy * Qy) != Sy * Activation(Qy)
// Requantization functor with bias for activations that do not commute with scaling:
// add the int32 bias, multiply by scaleAcc_, apply the activation, multiply by
// scale_z_inv_, clamp to [-128, 127].
// NOTE(review): the struct declaration (listing line 217) and the member declarations
// (listing lines 253-255; the symbol index at the end of this page names them
// `float scale_z_inv_`, `float scaleAcc_` and `Activation activationOp_`) are missing
// from this extract.
216template <typename Activation>
218{
219 static constexpr const char* name = "Add_Mul_Activation_Mul_Clamp";
220
221 // Convolution + Activation (non piecewise linear function)
222 // Z = Activation(Y) = Activation(W @ X + B)
223 // Sz * Qz = Activation(Sy * Qy)
224 // Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)]
225 // Where S1 = 1 / Sz, Sacc = Sw * Sx
// Stores both scales and a copy of the activation functor. The constructor carries no
// __host__/__device__ qualifier and performs no validation of the scales.
226 Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
227 : scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
228 {
229 }
230
// int32 inputs x and bias -> int8 output y. The sum x + bias is formed in int32 before the
// conversion to float (NOTE(review): integer overflow of that sum is not guarded here).
231 __host__ __device__ constexpr void
232 operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
233 {
234 float y_fp32 = ck::type_convert<float>(x + bias);
235 y_fp32 = scaleAcc_ * y_fp32;
236 activationOp_(y_fp32, y_fp32);
237 y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
238 y = ck::type_convert<int8_t>(y_fp32);
239 }
240
// Same computation as the int8 overload, but the clamped value is written as int32_t.
241 __host__ __device__ constexpr void
242 operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
243 {
244 // CAUTION - We might type_convert to int8 in threadwise copy
245 // eg. GridwiseGemmDlMultipleD_km_kn_mn
246 float y_fp32 = ck::type_convert<float>(x + bias);
247 y_fp32 = scaleAcc_ * y_fp32;
248 activationOp_(y_fp32, y_fp32);
249 y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
250 y = ck::type_convert<int32_t>(y_fp32);
251 }
252
256};
257
258// Conv per-channel quantization + activation function that is not piecewise linear,
259// such as TanH, Sigmoid, etc.
260// If an activation is not a piecewise linear function
261// then Activation(Sy *Qy) != Sy * Activation(Qy)
// Requantization functor with bias and a per-call accumulator scale: add the int32 bias,
// multiply by the scaleAcc argument, apply the activation, multiply by scale_z_inv_,
// clamp to [-128, 127].
// NOTE(review): the struct declaration (listing line 263) and the member declarations
// (listing lines 294-295; the symbol index at the end of this page names them
// `float scale_z_inv_` and `Activation activationOp_`) are missing from this extract.
262template <typename Activation>
264{
265 static constexpr const char* name = "Add_Mul2_Activation_Mul_Clamp";
266
// Stores the output-side scale and a copy of the activation functor. The constructor
// carries no __host__/__device__ qualifier and performs no validation of scale_z_inv.
267 Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
268 : scale_z_inv_(scale_z_inv), activationOp_(activationOp)
269 {
270 }
271
// int32 inputs x and bias -> int8 output y. The sum x + bias is formed in int32 before the
// conversion to float (NOTE(review): integer overflow of that sum is not guarded here).
272 __host__ __device__ constexpr void
273 operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
274 {
275 float y_fp32 = ck::type_convert<float>(x + bias);
276 y_fp32 = scaleAcc * y_fp32;
277 activationOp_(y_fp32, y_fp32);
278 y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
279 y = ck::type_convert<int8_t>(y_fp32);
280 }
281
// Same computation as the int8 overload, but the clamped value is written as int32_t.
282 __host__ __device__ constexpr void
283 operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
284 {
285 // CAUTION - We might type_convert to int8 in threadwise copy
286 // eg. GridwiseGemmDlMultipleD_km_kn_mn
287 float y_fp32 = ck::type_convert<float>(x + bias);
288 y_fp32 = scaleAcc * y_fp32;
289 activationOp_(y_fp32, y_fp32);
290 y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
291 y = ck::type_convert<int32_t>(y_fp32);
292 }
293
296};
297
298} // namespace element_wise
299} // namespace tensor_operation
300} // namespace ck
__host__ __device__ constexpr T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition utility/math.hpp:148
Definition binary_element_wise_operation.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
Activation
Definition gridwise_moe_gemm.hpp:31
signed int int32_t
Definition stdint.h:123
signed char int8_t
Definition stdint.h:121
Activation activationOp_
Definition quantization_operation.hpp:132
static constexpr const char * name
Definition quantization_operation.hpp:108
__host__ __device__ constexpr void operator()(int8_t &y, const int32_t &x, const float &requantScale) const
Definition quantization_operation.hpp:113
__device__ constexpr void operator()(int32_t &y, const int32_t &x, const float &requantScale) const
Definition quantization_operation.hpp:122
Activation_Mul2_Clamp(Activation activationOp)
Definition quantization_operation.hpp:110
Activation activationOp_
Definition quantization_operation.hpp:67
float requantScale_
Definition quantization_operation.hpp:66
__device__ constexpr void operator()(int32_t &y, const int32_t &x) const
Definition quantization_operation.hpp:49
__host__ constexpr void operator()(float &y, const float &x) const
Definition quantization_operation.hpp:59
__host__ __device__ constexpr void operator()(int8_t &y, const int32_t &x) const
Definition quantization_operation.hpp:41
Activation_Mul_Clamp(float requantScale, Activation activationOp)
Definition quantization_operation.hpp:36
static constexpr const char * name
Definition quantization_operation.hpp:27
__host__ __device__ constexpr void operator()(int32_t &y, const int32_t &x, const int32_t &bias, const float &requantScale) const
Definition quantization_operation.hpp:200
static constexpr const char * name
Definition quantization_operation.hpp:186
__host__ __device__ constexpr void operator()(int8_t &y, const int32_t &x, const int32_t &bias, const float &requantScale) const
Definition quantization_operation.hpp:191
Add_Activation_Mul2_Clamp(Activation activationOp)
Definition quantization_operation.hpp:188
Activation activationOp_
Definition quantization_operation.hpp:210
Activation activationOp_
Definition quantization_operation.hpp:178
float requantScale_
Definition quantization_operation.hpp:177
__host__ __device__ constexpr void operator()(int32_t &y, const int32_t &x, const int32_t &bias) const
Definition quantization_operation.hpp:167
Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
Definition quantization_operation.hpp:152
__host__ __device__ constexpr void operator()(int8_t &y, const int32_t &x, const int32_t &bias) const
Definition quantization_operation.hpp:158
static constexpr const char * name
Definition quantization_operation.hpp:140
Activation activationOp_
Definition quantization_operation.hpp:295
static constexpr const char * name
Definition quantization_operation.hpp:265
float scale_z_inv_
Definition quantization_operation.hpp:294
__host__ __device__ constexpr void operator()(int8_t &y, const int32_t &x, const int32_t &bias, const float &scaleAcc) const
Definition quantization_operation.hpp:273
Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
Definition quantization_operation.hpp:267
__host__ __device__ constexpr void operator()(int32_t &y, const int32_t &x, const int32_t &bias, const float &scaleAcc) const
Definition quantization_operation.hpp:283
float scaleAcc_
Definition quantization_operation.hpp:254
Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
Definition quantization_operation.hpp:226
float scale_z_inv_
Definition quantization_operation.hpp:253
__host__ __device__ constexpr void operator()(int8_t &y, const int32_t &x, const int32_t &bias) const
Definition quantization_operation.hpp:232
Activation activationOp_
Definition quantization_operation.hpp:255
static constexpr const char * name
Definition quantization_operation.hpp:219
__host__ __device__ constexpr void operator()(int32_t &y, const int32_t &x, const int32_t &bias) const
Definition quantization_operation.hpp:242
__host__ __device__ constexpr void operator()(int8_t &y, const int32_t &x) const
Definition quantization_operation.hpp:88
float scaleAcc_
Definition quantization_operation.hpp:98
static constexpr const char * name
Definition quantization_operation.hpp:76
Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
Definition quantization_operation.hpp:83
float scale_z_inv_
Definition quantization_operation.hpp:97
Activation activationOp_
Definition quantization_operation.hpp:99