blockwise_gemm_dl_v2r3.hpp Source File

blockwise_gemm_dl_v2r3.hpp Source File#

Composable Kernel: blockwise_gemm_dl_v2r3.hpp Source File
blockwise_gemm_dl_v2r3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
13// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
14// A and B are visible to the whole block; C is distributed among the threads
15// Assume:
16//   1. A:
17//     1. ABlockDesc_BK0_BM_BK1 is known at compile-time
18//     2. ABlockBuffer is DynamicBuffer
19//   2. B:
20//     1. BBlockDesc_BK0_BN_BK1 is known at compile-time
21//     2. BBlockBuffer is DynamicBuffer
22//   3. C:
23//     1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
24//     2. CThreadBuffer is StaticBuffer
25// Also assume:
26//   BM10BN10ThreadClusterBM10Xs::Size() == 2 and BM10BN10ThreadClusterBN10Xs::Size() == 2
27//   BM0 == 2 and BN0 == 2. It will do 2x2 pipelined read and fma (ABBA optimization)
28template <index_t BlockSize,
29 typename FloatA,
30 typename FloatB,
31 typename FloatC,
32 typename ABlockDesc_BK0_BM_BK1,
33 typename BBlockDesc_BK0_BN_BK1,
34 index_t BM1PerThreadBM11,
35 index_t BN1PerThreadBN11,
36 index_t BK0PerThread,
37 typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
38 // BM10BN10ThreadClusterBM101, ...>
39 typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
40 // BM10BN10ThreadClusterBN101, ...>
41 index_t AThreadCopyScalarPerVector_BM11,
42 index_t BThreadCopyScalarPerVector_BN11,
43 typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
44 BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
45 bool>::type = false>
47{
51
// Compile-time index constants (Number<0> .. Number<3>), used below to pick
// descriptor dimensions and tuple/sequence elements.
52 static constexpr auto I0 = Number<0>{};
53 static constexpr auto I1 = Number<1>{};
54 static constexpr auto I2 = Number<2>{};
55 static constexpr auto I3 = Number<3>{};
56
// Block-tile extents read from the input descriptors: dimension 0 is BK0,
// dimension 2 is BK1, and dimension 1 is BM (from A) or BN (from B).
57 static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
58 static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
59 static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
60 static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
61
// First entries of the two thread-cluster sequences.
62 static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
63 static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];
64
// Second entries of the two thread-cluster sequences.
65 static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
66 static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];
67
// Per-thread lengths; plain aliases of the template parameters.
68 static constexpr index_t BM11 = BM1PerThreadBM11;
69 static constexpr index_t BN11 = BN1PerThreadBN11;
70
// BM1 / BN1: product of both cluster entries and the per-thread length.
71 static constexpr index_t BM1 = BM100 * BM101 * BM11;
72 static constexpr index_t BN1 = BN100 * BN101 * BN11;
73
// BM0 / BN0: how many BM1- / BN1-sized pieces fit into BM / BN. This is an
// integer division; exact divisibility is enforced by a static_assert
// (BM % BM1 == 0 && BN % BN1 == 0) elsewhere in this file.
74 static constexpr index_t BM0 = BM / BM1;
75 static constexpr index_t BN0 = BN / BN1;
76
// Builds and returns a new descriptor for the A block tile by applying
// transform_tensor_descriptor to the given [BK0, BM, BK1] descriptor.
// NOTE(review): the transform arguments (listing lines 82-86) are missing from
// this copy. Going by the function name, BM is presumably split into
// [BM0, BM1] while BK0 and BK1 pass through - confirm against the original header.
77 __host__ __device__ static constexpr auto
78 MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
79 {
80 const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
81 a_block_desc_bk0_bm_bk1,
87
88 return a_block_bk0_bm0_bm1_bk1;
89 }
90
// Builds and returns a new descriptor for the B block tile by applying
// transform_tensor_descriptor to the given [BK0, BN, BK1] descriptor.
// NOTE(review): the transform arguments (listing lines 96-100) are missing from
// this copy. Going by the function name, BN is presumably split into
// [BN0, BN1] while BK0 and BK1 pass through - confirm against the original header.
91 __host__ __device__ static constexpr auto
92 MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
93 {
94 const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
95 b_block_desc_bk0_bn_bk1,
101
102 return b_block_desc_bk0_bn0_bn1_bk1;
103 }
104
// Returns a compile-time tensor adaptor for the C block whose upper (input)
// index is the 8-d decomposition listed below and whose lower (output) index
// is [BM, BN].
// NOTE(review): the line carrying the function name (listing line 106; the
// page index names it MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN)
// and the adaptor construction (listing lines 111-117) are missing from this copy.
105 __host__ __device__ static constexpr auto
107 {
108 // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
109 // lower: [BM, BN]
110 constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
118
119 return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
120 }
121
// Returns a compile-time tensor adaptor for the C block whose upper (input)
// index is the 8-d decomposition listed below and whose lower (output) index
// is [BM0, BM1, BN0, BN1].
// NOTE(review): the line carrying the function name (listing line 123; the
// page index names it
// MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1)
// and the adaptor construction (listing lines 128-136) are missing from this copy.
122 __host__ __device__ static constexpr auto
124 {
125 // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
126 // lower: [BM0, BM1, BN0, BN1]
127 constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
137
138 return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
139 }
140
// NOTE(review): the return statement (listing line 143) is missing from this
// copy. Going by the name, this presumably returns the lengths of one thread's
// C tile in [BM0, BM1, BN0, BN1] order - confirm against the original header.
141 __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
142 {
144 }
145
// A block descriptor, built once at compile time from a default-constructed
// ABlockDesc_BK0_BM_BK1. Run() passes it as the source descriptor of every
// a_thread_copy_.Run call.
146 static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
147 MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
148
// B block descriptor, built the same way from BBlockDesc_BK0_BN_BK1. Run()
// passes it as the source descriptor of every b_thread_copy_.Run call.
149 static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
150 MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
151
152 public:
// Constructor: member-initializer list and body.
// NOTE(review): the constructor's signature line (listing line 153) and the
// argument passed to CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1 (listing
// line 155) are missing from this copy.
//
// Stores this thread's origin and seeds both thread copies with it.
154 : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
156 a_thread_copy_{
// A copy origin: components 0 and 1 of the thread origin, with 0 for the
// first and last coordinates.
157 make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
158 b_thread_copy_{
// B copy origin: components 2 and 3 of the thread origin, with 0 for the
// first and last coordinates.
159 make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
160 {
// The body only checks compile-time preconditions; there is no runtime work.
161 static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
162 BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
163 "wrong! Desc should be known at compile-time");
164
// BlockSize must equal the product of the four thread-cluster entries.
165 static_assert(BlockSize == BM101 * BM100 * BN101 * BN100,
166 "wrong! blocksize and cluster size not consistent");
167
// BM / BN must be exact multiples of BM1 / BN1, so BM0 / BN0 are exact.
168 static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
169
// A and B must have the same length in dimension 0 (BK0).
170 static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
171 BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
172 "wrong! K dimension not consistent");
173
// Only I0 and I1 of each cluster sequence are read (BM100/BM101, BN100/BN101).
174 // TODO: remove this restriction
175 static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
176 BM10BN10ThreadClusterBN10Xs::Size() == 2,
177 "wrong!");
178
179 // TODO: remove this restriction
// NOTE(review): the first assert below is implied by the second one
// (BM0 == 2 && BN0 == 2) and could be dropped.
180 static_assert(BM0 == 2, "wrong");
181 static_assert(BM0 == 2 && BN0 == 2, "wrong");
182 }
183
// Body of CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1.
// NOTE(review): the signature (listing line 184) is missing from this copy;
// the page index gives it as
//   static __device__ CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
// The right-hand side of adaptor0 (listing line 189) and the arguments of
// adaptor1 (listing lines 194-201) are missing as well.
//
// Chains two adaptors and evaluates the bottom index for the upper index
// (thread_id, 0, 0, 0, 0). Per the comments below, that upper index is
// [Tid, BM0, BM11, BN0, BN11] and the bottom index is [BM0, BM1, BN0, BN1],
// so the result is the given thread's origin with all in-thread offsets zero.
185 {
186 // lower: [BM0, BM1, BN0, BN1]
187 // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
188 constexpr auto adaptor0 =
190
191 // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
192 // upper: [Tid, BM0, BM11, BN0, BN11]
193 constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
202
203 constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
204
205 return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
206 }
207
// Accumulates this thread's share of C into c_thread_buf from the A and B
// block tiles, in a 2x2 interleaved order: copies of A/B sub-tiles into the
// thread buffers alternate with contractions on sub-tiles that are already
// there.
//
// Parameters:
//   (unnamed)    - C thread descriptor; only its type is used (in the
//                  static_asserts and as a template argument of the contraction).
//   a_block_buf  - source buffer for every a_thread_copy_.Run call (read only).
//   b_block_buf  - source buffer for every b_thread_copy_.Run call (read only).
//   c_thread_buf - accumulator, updated by every threadwise_contraction.Run call.
//
// Sub-tile naming used in the comments below: A_sub_i / B_sub_j live at index
// i / j of the second coordinate of the thread buffers; C_sub_ij lives at
// (i, 0, j, 0) of the C thread tile.
208 template <typename CThreadDesc_BM0_BM11_BN0_BN11,
209 typename ABlockBuffer,
210 typename BBlockBuffer,
211 typename CThreadBuffer>
212 __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
213 const ABlockBuffer& a_block_buf,
214 const BBlockBuffer& b_block_buf,
215 CThreadBuffer& c_thread_buf) const
216 {
217 static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
218 "wrong! Desc should be known at compile-time");
219
// The hand-unrolled schedule below is written for exactly two sub-tiles per
// side, so the C descriptor must have length BM0 in dim 0 and BN0 in dim 2.
220 // TODO: remove this restriction
221 static_assert(BM0 == 2 && BN0 == 2 &&
222 CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
223 CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
224 "wrong");
225
// NOTE(review): the declarations of a_thread_buf and b_thread_buf (listing
// lines 226 and 228) are missing from this copy; only their size arguments,
// the element-space sizes of the thread descriptors, are visible.
227 a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
229 b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
230
// NOTE(review): the contraction type name (listing line 232) and its trailing
// template arguments (listing lines 239-241) are missing from this copy.
231 constexpr auto threadwise_contraction =
233 FloatA,
234 FloatB,
235 FloatC,
236 decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
237 decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
238 CThreadDesc_BM0_BM11_BN0_BN11,
242
// Prologue: load all four sub-tiles at source K offset I0, then run the two
// contractions that use A_sub_0.
243 // read A_sub_0
244 a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
245 make_tuple(I0, I0, I0, I0),
246 a_block_buf,
247 a_thread_desc_bk0_bm0_bm1_bk1_,
248 make_tuple(I0, I0, I0, I0),
249 a_thread_buf);
250
251 // read B_sub_0
252 b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
253 make_tuple(I0, I0, I0, I0),
254 b_block_buf,
255 b_thread_desc_bk0_bn0_bn1_bk1_,
256 make_tuple(I0, I0, I0, I0),
257 b_thread_buf);
258
259 // read B_sub_1
260 b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
261 make_tuple(I0, I1, I0, I0),
262 b_block_buf,
263 b_thread_desc_bk0_bn0_bn1_bk1_,
264 make_tuple(I0, I1, I0, I0),
265 b_thread_buf);
266
267 // read A_sub_1
268 a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
269 make_tuple(I0, I1, I0, I0),
270 a_block_buf,
271 a_thread_desc_bk0_bm0_bm1_bk1_,
272 make_tuple(I0, I1, I0, I0),
273 a_thread_buf);
274
275 // C_sub_00 += transpose(A_sub_0) * B_sub_0
276 threadwise_contraction.Run(a_thread_buf,
277 make_tuple(I0, I0, I0, I0),
278 b_thread_buf,
279 make_tuple(I0, I0, I0, I0),
280 c_thread_buf,
281 make_tuple(I0, I0, I0, I0));
282
283 // C_sub_01 += transpose(A_sub_0) * B_sub_1
284 threadwise_contraction.Run(a_thread_buf,
285 make_tuple(I0, I0, I0, I0),
286 b_thread_buf,
287 make_tuple(I0, I1, I0, I0),
288 c_thread_buf,
289 make_tuple(I0, I0, I1, I0));
290
// NOTE(review): the loop header (listing line 292) is missing from this copy.
// The body below uses `bk0` as the source K offset and is closed by `});`.
//
// Each iteration finishes the previous step's two A_sub_1 contractions while
// reloading the thread buffers from offset bk0, then runs the two A_sub_0
// contractions on the freshly loaded data. Every sub-tile slot is overwritten
// only after the last contraction that reads its previous contents.
291 // loop over rest of bk0
293 // read A_sub_0
294 a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
295 make_tuple(bk0, I0, I0, I0),
296 a_block_buf,
297 a_thread_desc_bk0_bm0_bm1_bk1_,
298 make_tuple(I0, I0, I0, I0),
299 a_thread_buf);
300
301 // C_sub_10 += transpose(A_sub_1) * B_sub_0
302 threadwise_contraction.Run(a_thread_buf,
303 make_tuple(I0, I1, I0, I0),
304 b_thread_buf,
305 make_tuple(I0, I0, I0, I0),
306 c_thread_buf,
307 make_tuple(I1, I0, I0, I0));
308
309 // read B_sub_0
310 b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
311 make_tuple(bk0, I0, I0, I0),
312 b_block_buf,
313 b_thread_desc_bk0_bn0_bn1_bk1_,
314 make_tuple(I0, I0, I0, I0),
315 b_thread_buf);
316
317 // C_sub_11 += transpose(A_sub_1) * B_sub_1
318 threadwise_contraction.Run(a_thread_buf,
319 make_tuple(I0, I1, I0, I0),
320 b_thread_buf,
321 make_tuple(I0, I1, I0, I0),
322 c_thread_buf,
323 make_tuple(I1, I0, I1, I0));
324
325 // read B_sub_1
326 b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
327 make_tuple(bk0, I1, I0, I0),
328 b_block_buf,
329 b_thread_desc_bk0_bn0_bn1_bk1_,
330 make_tuple(I0, I1, I0, I0),
331 b_thread_buf);
332
333 // read A_sub_1
334 a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
335 make_tuple(bk0, I1, I0, I0),
336 a_block_buf,
337 a_thread_desc_bk0_bm0_bm1_bk1_,
338 make_tuple(I0, I1, I0, I0),
339 a_thread_buf);
340
341 // C_sub_00 += transpose(A_sub_0) * B_sub_0
342 threadwise_contraction.Run(a_thread_buf,
343 make_tuple(I0, I0, I0, I0),
344 b_thread_buf,
345 make_tuple(I0, I0, I0, I0),
346 c_thread_buf,
347 make_tuple(I0, I0, I0, I0));
348
349 // C_sub_01 += transpose(A_sub_0) * B_sub_1
350 threadwise_contraction.Run(a_thread_buf,
351 make_tuple(I0, I0, I0, I0),
352 b_thread_buf,
353 make_tuple(I0, I1, I0, I0),
354 c_thread_buf,
355 make_tuple(I0, I0, I1, I0));
356 });
357
// Epilogue: the two A_sub_1 contractions for the most recently loaded data
// are still outstanding (from the prologue if the loop body never ran).
358 // C_sub_10 += transpose(A_sub_1) * B_sub_0
359 threadwise_contraction.Run(a_thread_buf,
360 make_tuple(I0, I1, I0, I0),
361 b_thread_buf,
362 make_tuple(I0, I0, I0, I0),
363 c_thread_buf,
364 make_tuple(I1, I0, I0, I0));
365
366 // C_sub_11 += transpose(A_sub_1) * B_sub_1
367 threadwise_contraction.Run(a_thread_buf,
368 make_tuple(I0, I1, I0, I0),
369 b_thread_buf,
370 make_tuple(I0, I1, I0, I0),
371 c_thread_buf,
372 make_tuple(I1, I0, I1, I0));
373 }
374
375 private:
// Per-thread descriptors for the A and B thread buffers. Run() uses them as
// the destination descriptors of the thread copies and to size the buffers.
// NOTE(review): both initializers (listing lines 378-379 and 383-384) are
// missing from this copy.
376 // A[BK0, BM0, BM1, BK1]
377 static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
380
381 // B[BK0, BN0, BN1, BK1]
382 static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
385
// Copy type used by Run() to move A from the block buffer into the thread buffer.
// NOTE(review): one template argument (listing line 389) is missing from this
// copy. The two leading FloatA arguments are presumably the source and
// destination element types - confirm against ThreadwiseTensorSliceTransfer_v4r1.
// AThreadCopyScalarPerVector_BM11 is not referenced on any visible line.
386 using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
387 FloatA,
388 FloatA,
390 decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
391 Sequence<BK0PerThread, 1, BM1PerThreadBM11, BK1>, // SliceLengths
392 Sequence<0, 1, 2, 3>, // DimAccessOrder
393 Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
394 Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
395
// Same as AThreadCopy, for B (FloatB, BN1PerThreadBN11).
// NOTE(review): one template argument (listing line 399) is missing from this
// copy. BThreadCopyScalarPerVector_BN11 is not referenced on any visible line.
396 using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
397 FloatB,
398 FloatB,
400 decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
401 Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
402 Sequence<0, 1, 2, 3>, // DimAccessOrder
403 Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths
404 Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
405
// This thread's origin as returned by CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1;
// set once in the constructor and used there to seed the two copies.
406 CIndex c_thread_origin_data_idx_;
407
// Copy objects constructed with this thread's A / B origin; Run() invokes them.
408 AThreadCopy a_thread_copy_;
409 BThreadCopy b_thread_copy_;
410};
411
412} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition tensor_description/tensor_adaptor.hpp:245
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
static constexpr auto I0
Definition blockwise_gemm_dl_v2r3.hpp:52
static constexpr index_t BK0
Definition blockwise_gemm_dl_v2r3.hpp:57
static constexpr auto I2
Definition blockwise_gemm_dl_v2r3.hpp:54
MultiIndex< 4 > CIndex
Definition blockwise_gemm_dl_v2r3.hpp:50
static constexpr index_t BM0
Definition blockwise_gemm_dl_v2r3.hpp:74
__device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
Definition blockwise_gemm_dl_v2r3.hpp:153
static constexpr index_t BM101
Definition blockwise_gemm_dl_v2r3.hpp:65
static constexpr index_t BM
Definition blockwise_gemm_dl_v2r3.hpp:59
static constexpr index_t BK1
Definition blockwise_gemm_dl_v2r3.hpp:58
static constexpr index_t BN101
Definition blockwise_gemm_dl_v2r3.hpp:66
static __device__ CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
Definition blockwise_gemm_dl_v2r3.hpp:184
static constexpr index_t BM1
Definition blockwise_gemm_dl_v2r3.hpp:71
__host__ static __device__ constexpr auto MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1 &b_block_desc_bk0_bn_bk1)
Definition blockwise_gemm_dl_v2r3.hpp:92
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11 &, const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_dl_v2r3.hpp:212
static constexpr index_t BN1
Definition blockwise_gemm_dl_v2r3.hpp:72
static constexpr index_t BN0
Definition blockwise_gemm_dl_v2r3.hpp:75
static constexpr auto I1
Definition blockwise_gemm_dl_v2r3.hpp:53
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_
Definition blockwise_gemm_dl_v2r3.hpp:146
__host__ static __device__ constexpr auto MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
Definition blockwise_gemm_dl_v2r3.hpp:106
MultiIndex< 3 > AIndex
Definition blockwise_gemm_dl_v2r3.hpp:48
__host__ static __device__ constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
Definition blockwise_gemm_dl_v2r3.hpp:141
__host__ static __device__ constexpr auto MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
Definition blockwise_gemm_dl_v2r3.hpp:123
static constexpr index_t BN11
Definition blockwise_gemm_dl_v2r3.hpp:69
static constexpr index_t BN100
Definition blockwise_gemm_dl_v2r3.hpp:63
static constexpr index_t BM11
Definition blockwise_gemm_dl_v2r3.hpp:68
static constexpr index_t BN
Definition blockwise_gemm_dl_v2r3.hpp:60
__host__ static __device__ constexpr auto MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1 &a_block_desc_bk0_bm_bk1)
Definition blockwise_gemm_dl_v2r3.hpp:78
MultiIndex< 3 > BIndex
Definition blockwise_gemm_dl_v2r3.hpp:49
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_
Definition blockwise_gemm_dl_v2r3.hpp:149
static constexpr auto I3
Definition blockwise_gemm_dl_v2r3.hpp:55
static constexpr index_t BM100
Definition blockwise_gemm_dl_v2r3.hpp:62
Definition utility/sequence.hpp:43
Definition functional2.hpp:33