blockwise_gemm_mx_pipeline_xdlops_base.hpp Source File

blockwise_gemm_mx_pipeline_xdlops_base.hpp Source File

Composable Kernel: blockwise_gemm_mx_pipeline_xdlops_base.hpp Source File
blockwise_gemm_mx_pipeline_xdlops_base.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14template <index_t BlockSize,
15 typename ADataType,
16 typename BDataType,
17 typename ATileDesc,
18 typename BTileDesc,
19 typename AMmaTileDesc,
20 typename BMmaTileDesc,
21 index_t ABlockTransferSrcScalarPerVector,
22 index_t BBlockTransferSrcScalarPerVector,
23 index_t MPerBlock,
24 index_t NPerBlock,
25 index_t KPerBlock,
26 index_t MPerXDL,
27 index_t NPerXDL,
28 index_t MRepeat,
29 index_t NRepeat,
30 index_t KPack,
31 bool TransposeC = false>
33{
34 using ComputeTypeA = ADataType;
35 using ComputeTypeB = BDataType;
36 using AccType = float; // for now only support V_MFMA_SCALE_F32
37
40
41 static constexpr auto I0 = Number<0>{};
42 static constexpr auto I1 = Number<1>{};
43 static constexpr auto I2 = Number<2>{};
44 static constexpr auto I3 = Number<3>{};
45
47
48 // WaveSize is derived from BlockSize and the wave grid (not hardcoded), since HIP-provided "WarpSize" would return 32 on RDNA GPUs.
49 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); // waves along M
50 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); // waves along N
51 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; // threads per wave; NOTE(review): integer division, assumes BlockSize % (MWaves * NWaves) == 0
52
53 static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
54 static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
55 static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); // A K1 is always read from dimension 2
56 // B K1: read from dimension 3 when BTileDesc has 4 dimensions, otherwise from dimension 2.
57 static constexpr index_t B_K1 =
58 BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
59
60 static constexpr auto xdlops_gemm = XdlopsGemm<ComputeTypeA,
61 MPerXDL,
62 NPerXDL,
63 KPack * APackedSize,
65 TransposeC,
66 true>{};
67
68 static constexpr index_t AMmaKStride = KPack; // A and B MMA K strides are both set to KPack
69 static constexpr index_t BMmaKStride = KPack;
70
71 // Per-thread K chunk stored into registers: 16 elements when APackedSize == 1 (e.g. FP8),
72 // e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47];
73 // otherwise 32 / APackedSize, for packed FP6/FP4 types.
74 static constexpr index_t KThreadChunk = (APackedSize == 1) ? 16 : 32 / APackedSize;
75
76 static_assert(APackedSize == BPackedSize, "APackedSize must be equal to BPackedSize for now"); // KThreadChunk above is derived from APackedSize only and is reused for B
77
78 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; // NOTE(review): divisibility of KPerBlock by K0PerXdlops is not checked here
79 static constexpr index_t KRepeat = KPerThread / KPack; // exact, given the KPerThread % KPack static_assert below
80 static constexpr index_t KPerInnerLoop = KPack;
81
82 // Number of XDL tiles packed together along M, N and K; hardcoded to 2 for a better 8-bit access pattern
83
84 static constexpr index_t MXdlPack = 2;
85 static constexpr index_t NXdlPack = 2;
86 static constexpr index_t KXdlPack = 2;
87
89 BlockSize,
90 MPerBlock,
91 NPerBlock,
92 KPerBlock,
93 ABlockTransferSrcScalarPerVector,
94 BBlockTransferSrcScalarPerVector,
95 A_K1,
96 B_K1,
97 A_K1,
98 B_K1,
99 MRepeat,
100 NRepeat,
101 MPerXDL,
102 NPerXDL,
103 xdlops_gemm.KPerXdlops,
105
106 static_assert(KPerThread % KPack == 0, // keeps KRepeat = KPerThread / KPack free of truncation
107 "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
108
110 AccType,
111 MRepeat * NRepeat,
112 xdlops_gemm.GetRegSizePerXdlops(),
113 true>
115
116 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
117
118 __device__ static auto GetWaveIdx()
119 {
120 const index_t thread_id = ThisThreadBlock::GetThreadId();
121
122 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
126
127 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
128 }
129
130 __device__ static auto CalculateAThreadOriginDataIndex()
131 {
132 const auto wave_idx = GetWaveIdx();
133
134 const auto waveId_m = wave_idx[I0];
135
136 const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
137
138 return make_tuple(0, waveId_m, 0, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
139 }
140
141 __device__ static auto CalculateBThreadOriginDataIndex()
142 {
143 const auto wave_idx = GetWaveIdx();
144
145 const auto waveId_n = wave_idx[I1];
146
147 const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
148
149 return make_tuple(0, waveId_n, 0, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
150 }
151
152 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
153 __device__ static auto
155 {
156 const auto wave_idx = GetWaveIdx();
157
158 const auto waveId_m = wave_idx[I0];
159 const auto waveId_n = wave_idx[I1];
160
161 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
162
163 constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
168
169 constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
174
175 // We pack 2 mfma in M/N direction, so we need to divide by 2
176 const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
177 make_tuple(m0 / MXdlPack, waveId_m, m0 % MXdlPack, blk_idx[I0]))[I0];
178 const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
179 make_tuple(n0 / NXdlPack, waveId_n, n0 % NXdlPack, blk_idx[I1]))[I0];
180
181 return make_tuple(c_thread_m, c_thread_n);
182 }
183
185
203 __host__ __device__
206 : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
207 {
208#if defined(__HIP_DEVICE_COMPILE__)
209 static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
210 "wrong! Desc should be known at compile-time");
212 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
213
214 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
215 "wrong!");
216#endif
217 }
218
219 // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
220 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
221 {
222 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
223
224 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
225 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
226 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
227 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
228
230 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
231 }
232
233 // XDL output supporting C_xdl = A_xdl * B_xdl
234 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
235 {
236 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
237
238 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
239 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
240 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
241 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
242
244 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
245 }
246
247 // XDL output supporting C_xdl = A_xdl * B_xdl, packed mfma
248 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()
249 {
250 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
251
252 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
253 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
254 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
255 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
256
258 Number<NRepeat / NXdlPack>{},
259 I1,
260 I1,
263 M0,
264 M1,
265 M2,
266 N));
267 }
268
269 __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
270 {
271 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
272
273 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
274 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
275 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
276 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
277
279 make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
280 }
281
282 // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
283 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
284 {
285 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
291 Number<NPerXDL>{}));
292
293 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
294 }
295
296 // XDL output supporting C_xdl = A_xdl * B_xdl
297 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
298 {
299 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
305 Number<NPerXDL>{}));
306
307 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
308 }
309
310 // XDL output supporting C_xdl = A_xdl * B_xdl_packed mfma
311 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()
312 {
313 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
315 Number<NRepeat / NXdlPack>{},
321 Number<NPerXDL>{}));
322
323 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
324 c_block_desc_m0_n0_m1_n1_m2_n2);
325 }
326
327 __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
328 {
329 constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
336 Number<NPerXDL>{}));
337
338 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
339 c_block_desc_g_m0_n0_m1_n1_m2_n2);
340 }
341
342 template <typename CGridDesc_M_N>
343 __host__ __device__ static constexpr auto
344 MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
345 {
346 const auto M = c_grid_desc_m_n.GetLength(I0);
347 const auto N = c_grid_desc_m_n.GetLength(I1);
348
349 const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
350 c_grid_desc_m_n,
351 make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
352 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
355
356 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
357 }
358
359 template <typename CGridDesc_G_M_N>
360 __host__ __device__ static constexpr auto
361 MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
362 {
363 const auto G = c_grid_desc_g_m_n.GetLength(I0);
364 const auto M = c_grid_desc_g_m_n.GetLength(I1);
365 const auto N = c_grid_desc_g_m_n.GetLength(I2);
366
367 const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
368 c_grid_desc_g_m_n,
370 make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
371 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
374
375 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
376 c_grid_desc_g_m0_n0_m1_n1_m2_n2);
377 }
378
379 __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
380
381 static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_m3_k; // default-constructed; the constructor static_asserts (device compile) that AMmaTileDesc is known at compile time
382 static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_n3_k; // same for BMmaTileDesc
383
384 protected:
385 // M1, N1 as double buffer index
386 // Read buffer + Compute buffer
387 // A[M0, M1, M2, KPack]
390
391 // B[N0, N1, N2, KPack]
394
395 // C[M, N, NumRegXdlops]
396 static constexpr auto c_thread_desc_ =
398 Number<NRepeat / NXdlPack>{},
401 xdlops_gemm.GetRegSizePerXdlops()));
402
406 decltype(a_thread_desc_),
409 4,
410 A_K1,
411 A_K1>;
412
416 decltype(b_thread_desc_),
419 4,
420 B_K1,
421 B_K1>;
422
425};
426
427} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr index_t packed_size_v
Definition data_type.hpp:411
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:118
float AccType
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:36
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops,(packed_size_v< ComputeTypeA > > 1||packed_size_v< ComputeTypeB > > 1)> HotLoopInstList
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:88
ADataType ComputeTypeA
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:34
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:344
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:220
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:269
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:297
decltype(CalculateAThreadOriginDataIndex()) Tuple5
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:184
ThreadwiseTensorSliceTransfer_v4< BDataType, ComputeTypeB, decltype(b_block_desc_n0_n1_n2_n3_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, 1, KThreadChunk >, Sequence< 0, 1, 2, 3, 4 >, 4, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:413
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:248
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeTypeA, decltype(a_block_desc_m0_m1_m2_m3_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, 1, KThreadChunk >, Sequence< 0, 1, 2, 3, 4 >, 4, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:403
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:154
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:283
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:361
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:141
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:46
__host__ static __device__ constexpr auto GetCThreadDesc()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:379
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:311
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:234
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:116
BDataType ComputeTypeB
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:35
__host__ __device__ BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin=CalculateAThreadOriginDataIndex(), Tuple5 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_mx_pipeline_base.
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:204
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:327
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_mx_pipeline_xdlops_base.hpp:130
Definition blockwise_gemm_pipeline_xdlops.hpp:34
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:75
static __device__ constexpr index_t GetNumOfThread()
Definition thread_group.hpp:15
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition xdlops_gemm.hpp:1821