blockwise_gemm_pipeline_wmmaops_base.hpp Source File

blockwise_gemm_pipeline_wmmaops_base.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_wmmaops_base.hpp Source File
blockwise_gemm_pipeline_wmmaops_base.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
15template <index_t BlockSize,
16 typename ADataType,
17 typename BDataType,
18 typename ComputeTypeA,
19 typename ComputeTypeB,
20 typename AccDataType,
21 typename AWmmaTileDesc,
22 typename BWmmaTileDesc,
23 index_t ABlockTransferSrcScalarPerVector,
24 index_t BBlockTransferSrcScalarPerVector,
25 index_t MPerBlock,
26 index_t NPerBlock,
27 index_t KPerBlock,
28 index_t MPerWmma,
29 index_t NPerWmma,
30 index_t MRepeat,
31 index_t NRepeat,
32 index_t KPack,
33 bool TransposeC = false>
35{
36 static constexpr auto I0 = Number<0>{};
37 static constexpr auto I1 = Number<1>{};
38 static constexpr auto I2 = Number<2>{};
39 static constexpr auto I3 = Number<3>{};
40 static constexpr auto I5 = Number<5>{};
41
43
44 static constexpr index_t WaveSize = 32;
45
46 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
47 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
48
49#if defined(__gfx12__)
50 static constexpr index_t A_KRow = 2;
51 static constexpr index_t B_KRow = 2;
52#else
53 static constexpr index_t A_KRow = 1;
54 static constexpr index_t B_KRow = 1;
55#endif
56
57 static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
58 static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
59
60 static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
61 static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
62
65
66 static constexpr index_t KRepeat = KPerBlock / KPack;
67
68 static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
69
72 MPerBlock,
73 NPerBlock,
74 KPerBlock,
75 ABlockTransferSrcScalarPerVector,
76 BBlockTransferSrcScalarPerVector,
77 A_K1,
78 B_K1,
79 A_K1,
80 B_K1,
81 MRepeat,
82 NRepeat,
83 MPerWmma,
84 NPerWmma,
85 wmma_gemm.wmma_instr.k_per_wmma>;
86
88 AccDataType,
89 MRepeat * NRepeat,
90 wmma_gemm.GetRegSizePerWmma(),
91 true>
93
94 struct Empty
95 {
96 __device__ Empty() {};
97 template <index_t NBuffer>
98 __device__ void GlobalLoad(bool cond)
99 {
100 ignore = NBuffer;
101 ignore = cond;
102 }
103 };
104
105 template <index_t ScaleSliceSizeN,
106 index_t ScaleSliceSizeK,
108 index_t ScaleBlockK,
109 index_t NumberOfBuffers,
110 typename GridDesc,
111 typename ThreadCopy,
112 typename GridBuffer,
113 typename ThreadStaticBuffer,
114 typename BScaleThreadDesc>
115 struct BScale
116 {
117 __device__ BScale(GridDesc b_scale_grid_desc_,
118 ThreadCopy b_scale_thread_copy_,
119 GridBuffer b_scale_grid_buf_)
120 : b_scale_thread_copy(b_scale_thread_copy_),
121 b_scale_grid_desc(b_scale_grid_desc_),
122 b_scale_grid_buf(b_scale_grid_buf_) {};
123
124 static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
126
127 static constexpr auto b_scale_thread_desc = BScaleThreadDesc{};
128
129 static constexpr auto b_scale_thread_copy_step =
130 make_tuple(make_multi_index(NWaves * NPerWmma, 0),
131 make_multi_index(-NPerBlock, 0),
132 make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK));
133
134 template <index_t NBuffer>
135 __device__ void GlobalLoad(bool cond)
136 {
137 static_for<0, NRepeat, 1>{}([&](auto n0) {
141 make_tuple(n0, Number<0>{}),
143
144 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
146 });
147
148 if(cond)
149 {
150 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
152 }
153 else
154 {
155 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
157 }
158 }
159
164 };
165
166 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
167
168 __device__ static auto GetWaveIdx()
169 {
170 const index_t thread_id = ThisThreadBlock::GetThreadId();
171
172 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
176
177 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
178 }
179
180 __device__ static auto CalculateAThreadOriginDataIndex()
181 {
182 const auto wave_idx = GetWaveIdx();
183
184 const auto waveId_m = wave_idx[I0];
185
186 const auto wmma_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
187
188#if defined(__gfx12__)
189 const auto wmma_krow = wmma_gemm.GetSubGroupId();
190#else
191 const auto wmma_krow = 0;
192#endif
193
194 // |KRepeat |MRepeat|MWave |KRow |MLane |KPack
195 return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
196 }
197
198 __device__ static auto CalculateBThreadOriginDataIndex()
199 {
200 const auto wave_idx = GetWaveIdx();
201
202 const auto waveId_n = wave_idx[I1];
203
204 const auto wmma_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
205
206#if defined(__gfx12__)
207 const auto wmma_krow = wmma_gemm.GetSubGroupId();
208#else
209 const auto wmma_krow = 0;
210#endif
211
212 // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
213 return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
214 }
215
216 template <index_t m0, index_t n0>
218 {
219 const auto wave_idx = GetWaveIdx();
220
221 const auto waveId_m = wave_idx[I0];
222 const auto waveId_n = wave_idx[I1];
223
224 const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
225
226 constexpr auto mrepeat_mwave_mperwmma_to_m_adaptor = make_single_stage_tensor_adaptor(
230
231 constexpr auto nrepeat_nwave_nperwmma_to_n_adaptor = make_single_stage_tensor_adaptor(
235
236 const index_t c_thread_m = mrepeat_mwave_mperwmma_to_m_adaptor.CalculateBottomIndex(
237 make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
238 const index_t c_thread_n = nrepeat_nwave_nperwmma_to_n_adaptor.CalculateBottomIndex(
239 make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
240
241 return make_tuple(c_thread_m, c_thread_n);
242 }
243
245
263 __host__ __device__
266 : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
267 {
268 static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
269 BWmmaTileDesc::IsKnownAtCompileTime(),
270 "wrong! Desc should be known at compile-time");
271
273 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize");
274
275 static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 &&
276 NPerBlock % (NPerWmma * NRepeat) == 0,
277 "wrong!");
278 }
279
280 // transposed WMMA output C' = B' * A'
281 __host__ __device__ static constexpr auto
283 {
284 constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
285 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
286
287 constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
288
290 // |MRepeat |MWave |MSubGroup |NRepeat |NWave
291 // |NThreadPerSubGroup |MAccVgprs
292 make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
293 }
294
295 static constexpr auto MAccVgprs =
296 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()[I2];
297
298 __host__ __device__ static constexpr auto
300 {
301 constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
302 wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
303
304 constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
306 // |MRepeat |MWave |MSubGroup |NRepeat |NWave
307 // |NThreadPerSubGroup |MAccVgprs
309 make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
310 Number<NRepeat>{} * MAccVgprs * AccStride,
311 Number<NRepeat>{} * MAccVgprs * AccStride,
312 MAccVgprs * AccStride,
313 MAccVgprs * AccStride,
314 MAccVgprs * AccStride,
315 AccStride));
316 }
317
318 __host__ __device__ static constexpr auto
320 {
321 constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
328
329 return wmma_gemm
330 .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
331 c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
332 }
333
334 // Describe how data allocated in thread copy src buffer
335 // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
336 static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1;
337 static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1;
338
339 protected:
340 static constexpr auto a_thread_desc_ =
344 I1,
345 I1,
346 Number<A_K1>{}),
348 Number<KPack / A_KRow>{},
349 Number<KPack / A_KRow * MRepeat>{},
350 I0,
351 I0,
352 I1));
353
354 static constexpr auto b_thread_desc_ =
358 I1,
359 I1,
360 Number<B_K1>{}),
362 Number<KPack / B_KRow>{},
363 Number<KPack / B_KRow * NRepeat>{},
364 I0,
365 I0,
366 I1));
367
368 // C[M, N, NumRegWmma]
370 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
371
374 ComputeTypeA,
376 decltype(a_thread_desc_),
377 Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
379 5,
380 A_K1,
381 A_K1>;
382
385 ComputeTypeB,
387 decltype(b_thread_desc_),
388 Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
390 5,
391 B_K1,
392 B_K1>;
393
396};
397
398} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
static constexpr index_t num_scale_krepeat
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:125
GridBuffer b_scale_grid_buf
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:162
__device__ BScale(GridDesc b_scale_grid_desc_, ThreadCopy b_scale_thread_copy_, GridBuffer b_scale_grid_buf_)
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:117
StaticallyIndexedArray< ThreadStaticBuffer, Number< NumberOfBuffers >{}> b_scale_thread_bufs
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:163
static constexpr auto b_scale_thread_copy_step
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:129
static constexpr index_t num_scale_k_block
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:124
static constexpr auto b_scale_thread_desc
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:127
__device__ void GlobalLoad(bool cond)
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:135
ThreadCopy b_scale_thread_copy
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:160
GridDesc b_scale_grid_desc
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:161
__device__ Empty()
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:96
__device__ void GlobalLoad(bool cond)
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:98
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), Sequence< KPack/A_K1/A_KRow, 1, 1, 1, 1, A_K1 >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:372
ck::BlockwiseGemmWmmaops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerWmma, NPerWmma, wmma_gemm.wmma_instr.k_per_wmma > HotLoopInstList
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:70
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:166
__host__ static __device__ constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:319
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >)
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:217
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:198
__host__ static __device__ constexpr auto GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:282
__host__ __device__ BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin=CalculateAThreadOriginDataIndex(), Tuple6 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmWmmaops_pipeline_base.
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:264
ThreadwiseTensorSliceTransfer_v4< BDataType, ComputeTypeB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), Sequence< KPack/B_K1/B_KRow, 1, 1, 1, 1, B_K1 >, Sequence< 0, 1, 2, 3, 4, 5 >, 5, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:383
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:168
__host__ static __device__ constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:299
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:42
decltype(CalculateAThreadOriginDataIndex()) Tuple6
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:244
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:180
Definition blockwise_gemm_pipeline_wmmaops.hpp:26
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
static __device__ constexpr index_t GetNumOfThread()
Definition thread_group.hpp:15
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition wmma_gemm.hpp:663
Definition functional2.hpp:33