gridwise_gemm_xdl_waveletmodel_cshuffle.hpp Source File

gridwise_gemm_xdl_waveletmodel_cshuffle.hpp Source File#

Composable Kernel: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp Source File
gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
18namespace ck {
19
20template <typename ABDataType,
21 typename FloatGemmAcc,
22 typename EDataTypeShuffle,
23 typename EDataType,
24 typename AElementwiseOperation,
25 typename BElementwiseOperation,
26 typename EElementwiseOperation,
27 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
28 typename AGridDesc_M_K,
29 typename BGridDesc_N_K,
30 typename EGridDesc_M_N,
31 index_t NumGemmKPrefetchStage,
32 index_t TileLoadThreadGroupSize,
33 index_t TileMathThreadGroupSize,
34 index_t MPerBlock,
35 index_t NPerBlock,
36 index_t KPerBlock,
37 index_t AK1Value,
38 index_t BK1Value,
39 index_t MPerXdl,
40 index_t NPerXdl,
41 index_t MXdlPerWave,
42 index_t NXdlPerWave,
43 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
46 index_t ABlockTransferSrcVectorDim,
47 index_t ABlockTransferSrcScalarPerVector,
48 index_t ABlockTransferDstScalarPerVector_AK1,
49 bool AThreadTransferSrcResetCoordinateAfterRun,
50 index_t ABlockLdsExtraM,
51 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
52 typename BBlockTransferThreadClusterArrangeOrder,
53 typename BBlockTransferSrcAccessOrder,
54 index_t BBlockTransferSrcVectorDim,
55 index_t BBlockTransferSrcScalarPerVector,
56 index_t BBlockTransferDstScalarPerVector_BK1,
57 bool BThreadTransferSrcResetCoordinateAfterRun,
58 index_t BBlockLdsExtraN,
59 index_t CShuffleMXdlPerWavePerShuffle,
60 index_t CShuffleNXdlPerWavePerShuffle,
61 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
62 index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
64{
65
66 static constexpr auto I0 = Number<0>{};
67 static constexpr auto I1 = Number<1>{};
68 static constexpr auto I2 = Number<2>{};
69 static constexpr auto I3 = Number<3>{};
70 static constexpr auto I4 = Number<4>{};
71 static constexpr auto I5 = Number<5>{};
72 static constexpr auto I6 = Number<6>{};
73 static constexpr auto I7 = Number<7>{};
74
75 // K1 should be Number<...>
76 static constexpr auto AK1 = Number<AK1Value>{};
77 static constexpr auto BK1 = Number<BK1Value>{};
78 static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
79 static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
80 static constexpr auto BlockSize = math::max(TileLoadThreadGroupSize, TileMathThreadGroupSize);
81
83 {
84 __device__ static constexpr index_t GetNumOfThread() { return TileLoadThreadGroupSize; }
85
86 __device__ static constexpr bool IsBelong()
87 {
88 return (get_thread_local_1d_id() >= TileLoadThreadGroupSize);
89 }
90
91 __device__ static index_t GetThreadId()
92 {
93 return get_thread_local_1d_id() - TileMathThreadGroupSize;
94 }
95 };
96
98 {
99 __device__ static constexpr index_t GetNumOfThread() { return TileMathThreadGroupSize; }
100
101 __device__ static constexpr bool IsBelong()
102 {
103 return get_thread_local_1d_id() < TileMathThreadGroupSize;
104 }
105
106 __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
107 };
108
110
111 // load and math+store Wave pipelines.
112 // TODO: build pipelines blocks scheduling parallel tasks
115
116 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
117 {
118 // A matrix in LDS memory, dst of blockwise copy
122 }
123
124 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
125 {
126 // B matrix in LDS memory, dst of blockwise copy
130 }
131
132 __host__ __device__ static constexpr auto
134 {
135 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
136 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
137
138 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
142 I1,
144
145 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
146 }
147
148 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
149 {
150 // LDS allocation for A and B: be careful of alignment
151 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
152 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
153
154 // lds max alignment
155 constexpr auto max_lds_align = math::lcm(AK1, BK1);
156
157 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
158 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
159
160 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
161 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
162
163 // LDS allocation for C shuffle in LDS
164 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
166
167 constexpr auto c_block_size =
168 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
169
170 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
171 sizeof(ABDataType),
172 c_block_size * sizeof(EDataTypeShuffle));
173 }
174
175 template <
176 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
177 __device__ static bool constexpr IsValidCompilationParameter()
178 {
179 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
180 BlockSize,
181 MPerBlock,
182 NPerBlock,
183 MPerXdl,
184 NPerXdl,
185 MXdlPerWave,
186 NXdlPerWave,
187 EDataType,
188 CGlobalMemoryDataOperation>();
189 }
190
191 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
192 template <typename Block2ETileMap>
193 __host__ __device__ static constexpr bool
194 CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
195 const BGridDesc_N_K& b_grid_desc_n_k,
196 const EGridDesc_M_N& e_grid_desc_m_n,
197 const Block2ETileMap& /*block_2_etile_map*/)
198 {
199 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
200 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
201 "Invalid tuning param!");
202
203 const auto M = a_grid_desc_m_k.GetLength(I0);
204 const auto N = b_grid_desc_n_k.GetLength(I0);
205 const auto K = a_grid_desc_m_k.GetLength(I1);
206
207 // check consistency of desc
208 if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
209 K == b_grid_desc_n_k.GetLength(I1)))
210 {
211 return false;
212 }
213
214 // check tile size
215 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
216 {
217 return false;
218 }
219
220 // check gridwise gemm pipeline
221 const auto num_k_loop = K / KPerBlock;
222
223 if(!GridwiseGemmMath::IsSupported(num_k_loop))
224 {
225 return false;
226 }
227
228 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
229
230 // check tensor size: cannot be larger than 2GB each
231 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
232
233 if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
234 b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
235 e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
236 {
237 return false;
238 }
239
240 return true;
241 }
242
243 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
244 {
245 const index_t num_loop = K / KPerBlock;
246
247 return GridwiseGemmMath::CalculateHasMainLoop(num_loop);
248 }
249
250 // return block_id to E matrix tile idx (m0, n0) mapping
251 __host__ __device__ static constexpr auto
252 MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
253 {
254 const auto M = e_grid_desc_m_n.GetLength(I0);
255 const auto N = e_grid_desc_m_n.GetLength(I1);
256
257 constexpr auto M1 = Number<MPerBlock>{};
258 constexpr auto N1 = Number<NPerBlock>{};
259
260 const auto M0 = M / M1;
261 const auto N0 = N / N1;
262
263 constexpr auto M01 = I1;
264 constexpr auto N01 = I1;
265
266 const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
272
273 const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
275 make_tuple(make_merge_transform(make_tuple(M0, N0, M01, N01))),
278
279 const auto cblockid_to_m0_n0_block_cluster_adaptor =
280 chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
281 cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
282
283 return cblockid_to_m0_n0_block_cluster_adaptor;
284 }
285
286 __host__ __device__ static constexpr index_t
287 CalculateGridSize(const EGridDesc_M_N& e_grid_desc_m_n)
288 {
289 const auto M = e_grid_desc_m_n.GetLength(I0);
290 const auto N = e_grid_desc_m_n.GetLength(I1);
291
292 const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
293
294 return grid_size;
295 }
296
297 // A desc for source in blockwise copy
298 __host__ __device__ static constexpr auto
299 MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
300 {
301 const auto M = a_grid_desc_m_k.GetLength(I0);
302 const auto K = a_grid_desc_m_k.GetLength(I1);
303
304 const auto AK0 = K / AK1;
305
306 return transform_tensor_descriptor(a_grid_desc_m_k,
311 }
312
313 // B desc for source in blockwise copy
314 __host__ __device__ static constexpr auto
315 MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
316 {
317 const auto N = b_grid_desc_n_k.GetLength(I0);
318 const auto K = b_grid_desc_n_k.GetLength(I1);
319
320 const auto BK0 = K / BK1;
321
322 return transform_tensor_descriptor(b_grid_desc_n_k,
327 }
328
329 // E desc for destination in blockwise copy
330 template <typename EGridDescriptor_M_N>
331 __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
332 const EGridDescriptor_M_N& e_grid_desc_m_n)
333 {
334 const auto M = e_grid_desc_m_n.GetLength(I0);
335 const auto N = e_grid_desc_m_n.GetLength(I1);
336
337 const auto MBlock = M / MPerBlock;
338 const auto NBlock = N / NPerBlock;
339
340 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
341 e_grid_desc_m_n,
346
347 return e_grid_desc_mblock_mperblock_nblock_nperblock;
348 }
349
352 EGridDesc_M_N{}))>;
353
355 remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
356
357 template <bool HasMainKBlockLoop,
358 typename AGridDesc_AK0_M_AK1,
359 typename BGridDesc_BK0_N_BK1,
360 typename Block2ETileMap>
361 __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
362 const ABDataType* __restrict__ p_b_grid,
363 EDataType* __restrict__ p_e_grid,
364 void* __restrict__ p_shared,
365 const AElementwiseOperation& a_element_op,
366 const BElementwiseOperation& b_element_op,
367 const EElementwiseOperation& e_element_op,
368 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
369 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
371 e_grid_desc_mblock_mperblock_nblock_nperblock,
372 const Block2ETileMap& block_2_etile_map)
373 {
374 // build loadWave and MathWave pipelines
375 // loadWave and MathWave synchronized through LDS
376
377 // A matrix in LDS memory, dst of blockwise copy
378 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
379
380 // B matrix in LDS memory, dst of blockwise copy
381 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
382
383 // lds max alignment
384 constexpr auto max_lds_align = math::lcm(AK1, BK1);
385
386 // LDS allocation for A and B: be careful of alignment
387 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
388 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
389
391 static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
392
394 static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
395 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
396
397 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
398 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
399
400 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
401 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
402 KPerBlock);
403
404 // divide block work by [M, N]
405 const auto block_work_idx =
406 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
407
408 // HACK: this force m/n_block_data_idx_on_grid into SGPR
409 const index_t m_block_data_idx_on_grid =
410 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
411
412 const index_t n_block_data_idx_on_grid =
413 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
414
416 {
417
418 // LoadWave
419 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
420 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
421 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
422 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
423
424 // A matrix blockwise copy
425 auto a_blockwise_copy =
426 ThreadGroupTensorSliceTransfer_v4r1<TileLoadThreadGroup,
427 AElementwiseOperation,
431 ABlockTransferThreadClusterLengths_AK0_M_AK1,
432 ABlockTransferThreadClusterArrangeOrder,
433 ABDataType,
434 ABDataType,
435 decltype(a_grid_desc_ak0_m_ak1),
436 decltype(a_block_desc_ak0_m_ak1),
437 ABlockTransferSrcAccessOrder,
439 ABlockTransferSrcVectorDim,
440 2,
441 ABlockTransferSrcScalarPerVector,
442 ABlockTransferDstScalarPerVector_AK1,
443 1,
444 1,
445 AThreadTransferSrcResetCoordinateAfterRun,
446 true,
447 NumGemmKPrefetchStage>(
448 a_grid_desc_ak0_m_ak1,
449 make_multi_index(0, m_block_data_idx_on_grid, 0),
450 a_element_op,
451 a_block_desc_ak0_m_ak1,
452 make_multi_index(0, 0, 0),
454
455 // B matrix blockwise copy
456 auto b_blockwise_copy =
457 ThreadGroupTensorSliceTransfer_v4r1<TileLoadThreadGroup,
458 BElementwiseOperation,
462 BBlockTransferThreadClusterLengths_BK0_N_BK1,
463 BBlockTransferThreadClusterArrangeOrder,
464 ABDataType,
465 ABDataType,
466 decltype(b_grid_desc_bk0_n_bk1),
467 decltype(b_block_desc_bk0_n_bk1),
468 BBlockTransferSrcAccessOrder,
470 BBlockTransferSrcVectorDim,
471 2,
472 BBlockTransferSrcScalarPerVector,
473 BBlockTransferDstScalarPerVector_BK1,
474 1,
475 1,
476 BThreadTransferSrcResetCoordinateAfterRun,
477 true,
478 NumGemmKPrefetchStage>(
479 b_grid_desc_bk0_n_bk1,
480 make_multi_index(0, n_block_data_idx_on_grid, 0),
481 b_element_op,
482 b_block_desc_bk0_n_bk1,
483 make_multi_index(0, 0, 0),
485
486 GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
487 a_grid_desc_ak0_m_ak1,
488 a_block_desc_ak0_m_ak1,
489 a_blockwise_copy,
490 a_grid_buf,
491 a_block_buf,
492 a_block_slice_copy_step,
493 b_grid_desc_bk0_n_bk1,
494 b_block_desc_bk0_n_bk1,
495 b_blockwise_copy,
496 b_grid_buf,
497 b_block_buf,
498 b_block_slice_copy_step,
499 num_k_block_main_loop);
500
503 }
505 {
506 // branch early for math wave
507 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
508 constexpr bool is_single_rate_mfma =
510 lcm_AK1_BK1 <= 4) ||
511 (is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
513 lcm_AK1_BK1 < 32))
514 ? true
515 : false;
516 constexpr auto is_scale_mfma = false;
517 constexpr index_t KPack =
518 math::max(lcm_AK1_BK1,
519 MfmaSelector<ABDataType,
520 MPerXdl,
521 NPerXdl,
522 ABDataType,
523 is_single_rate_mfma,
524 is_scale_mfma>::selected_mfma.k_per_blk);
525
527 TileMathThreadGroupSize,
528 ABDataType,
529 ABDataType,
530 FloatGemmAcc,
531 decltype(a_block_desc_ak0_m_ak1),
532 decltype(b_block_desc_bk0_n_bk1),
533 MPerXdl,
534 NPerXdl,
535 MXdlPerWave,
536 NXdlPerWave,
537 KPack>{};
538
539 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
541 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
542
543 // TODO re-architect LDS+math stages
544 // Writing data to GMEM: only math wave is doing the work in cshuffle
545 GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
546 a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
547
548 // GEMM definition
549 // c_mtx += transpose(a_mtx) * b_mtx
550 // a_mtx[K0PerBlock, MPerBlock] is in LDS
551 // b_mtx[K0PerBlock, NPerBlock] is in LDS
552 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
553 // register
554 // sanity check
555
556 // shuffle C and write out
557 {
558 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
559 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
560 "wrong!");
561
562 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
563 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
564
565 // TODO: hacky, fix it!
566 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
567 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
568
569 // TODO: hacky, fix it!
570 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
571 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
572 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
573
574 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
575 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
576 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
577 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
578 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
579 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
580 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
581 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
582
583 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
585
586 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
587 static_cast<EDataTypeShuffle*>(p_shared),
588 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
589
590 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
591 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
595 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
596 M1, // M1 = MWave
597 M2, // M2 * M3 * M4 = MPerXdl
598 M3,
599 M4)),
602 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
603 N1, // N1 = NWave
604 N2))), // N2 = NPerXdl
608 Sequence<>{},
610
611 // calculate origin of thread output tensor on global memory
612 // blockwise GEMM c matrix starting index
613 const auto c_thread_mtx_on_block =
614 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
615
616 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
617 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
618
619 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
621 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
624
625 const auto m_thread_data_on_block_idx =
626 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
627 make_multi_index(m_thread_data_on_block));
628
629 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
634
635 const auto n_thread_data_on_block_idx =
636 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
637 make_multi_index(n_thread_data_on_block));
638
639 // shuffle: threadwise copy C from VGPR to LDS
640 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
641 FloatGemmAcc,
642 EDataTypeShuffle,
643 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
644 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
646 Sequence<CShuffleMXdlPerWavePerShuffle,
647 CShuffleNXdlPerWavePerShuffle,
648 I1,
649 I1,
650 M2,
651 I1,
652 M4,
653 I1>,
655 7,
656 1,
658 1,
659 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
661 0,
662 m_thread_data_on_block_idx[I1],
663 n_thread_data_on_block_idx[I1],
664 m_thread_data_on_block_idx[I2],
665 m_thread_data_on_block_idx[I3],
666 m_thread_data_on_block_idx[I4],
667 n_thread_data_on_block_idx[I2]),
669
670 // shuffle: blockwise copy C from LDS to global
671 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
673 EElementwiseOperation, // ElementwiseOperation,
674 CGlobalMemoryDataOperation, // DstInMemOp,
675 Sequence<1,
676 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
677 1,
678 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
679 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
680 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
681 EDataTypeShuffle, // typename SrcData,
682 EDataType, // typename DstData,
683 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
684 decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
685 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
686 3, // index_t VectorDim,
687 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
688 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
689 false> // bool ThreadTransferDstResetCoordinateAfterRun>
690 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
691 make_multi_index(0, 0, 0, 0),
692 e_grid_desc_mblock_mperblock_nblock_nperblock,
693 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
694 e_element_op};
695
696 // space filling curve for threadwise C in VGPR
697 constexpr auto sfc_c_vgpr =
700 Sequence<CShuffleMXdlPerWavePerShuffle,
701 CShuffleNXdlPerWavePerShuffle,
702 1,
703 1,
704 M2,
705 1,
706 M4,
707 1>>{};
708
709 // space filling curve for shuffled blockwise C in global mem
710 constexpr auto sfc_c_global =
713 Sequence<1,
714 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
715 1,
716 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
717
718 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
719
720 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
721
722 // Different way of getting coalesced writes:
723 // We can get rid of doing cshuffle. Instead of reading A rows in contiguous manner
724 // do it interleaved, then mfma can have nice c-mat layout as below:
725 //
726 // TODO
727 // We do not need to do LDS swizzle to align global writes writing cache lines:
728 // v_mfma cmat, amat, bmat, cmat - c-mat register layout is 1xN
729 // elements (N is vertical or strided
730 // dimension)
731 // v_mfma cmat, bmat, amat, cmat - c-mat register layout is Mx1
732 // elements (M is coalescing
733 // dimension) by enumerating M index in
734 // amat, bmat you can align cmat
735 // register(s) to contiguous M elements
736 // for example
737 // 1st mfma instruction output space : 0 4 8 12 16 ....
738 // 2nd mfma instruction output space : 1 5 9 13 17 ....
739 // 3rd mfma instruction output space : 2 6 10 14 18 ....
740 // 4th mfma instruction output space : 3 7 11 15 19 ....
741 // you can pack 4 registers output space into 2WORD and do global write
742 // (no LDS swizzling required)
743
744 static_for<0, num_access, 1>{}([&](auto access_id) {
745 // make sure it's safe to write to LDS
747
748 // each thread write its data from VGPR to LDS
749 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
750 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
751 c_thread_buf,
752 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
753 c_shuffle_block_buf);
754 // make sure it's safe to read from LDS
756
757 // each block copy its data from LDS to global
758 c_shuffle_block_copy_lds_to_global.Run(
759 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
760 c_shuffle_block_buf,
761 e_grid_desc_mblock_mperblock_nblock_nperblock,
762 c_grid_buf);
763
764 if constexpr(access_id < num_access - 1)
765 {
766 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
767
768 // move on C
769 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
770 e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
771 }
772 });
773 }
774 }
775 }
776};
777
778} // namespace ck
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition tensor_description/tensor_adaptor.hpp:245
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition blockwise_gemm_smfmac_xdlops.hpp:44
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_smfmac_xdlops.hpp:78
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:83
static __device__ index_t GetThreadId()
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:91
static __device__ constexpr bool IsBelong()
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:86
static __device__ constexpr index_t GetNumOfThread()
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:84
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:98
static __device__ index_t GetThreadId()
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:106
static __device__ constexpr index_t GetNumOfThread()
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:99
static __device__ constexpr bool IsBelong()
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:101
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:64
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &e_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:361
Definition gridwise_gemm_waveletmodel.hpp:11
Definition gridwise_gemm_waveletmodel.hpp:103
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition thread_group.hpp:12
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340