gridwise_gemm_multiple_d_wmma_cshuffle.hpp Source File

gridwise_gemm_multiple_d_wmma_cshuffle.hpp Source File

Composable Kernel: gridwise_gemm_multiple_d_wmma_cshuffle.hpp Source File
gridwise_gemm_multiple_d_wmma_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/utility/env.hpp"
18
19namespace ck {
20
// Batched launcher for GridwiseOp::Run with multiple D tensors; the body is only compiled
// for gfx11/gfx12 targets.
//
// The grid is divided evenly among `batch_count` batches. Each workgroup derives its batch
// index g_idx from its 1-D block id, offsets the A, B, E and every D base pointer by the
// per-batch offsets returned by `compute_ptr_offset_of_batch`, and then calls GridwiseOp::Run
// with a workgroup-local LDS buffer of GridwiseOp::SharedMemTrait::lds_size bytes.
// NOTE(review): assumes the grid size is a multiple of batch_count; otherwise the truncating
// division below yields g_idx >= batch_count for the trailing workgroups — confirm callers.
// On other targets every argument is marked unused and the kernel does nothing.
21template <typename GridwiseOp,
22 typename ADataType,
23 typename BDataType,
24 typename DsPointer,
25 typename EDataType,
26 typename AElementwiseOperation,
27 typename BElementwiseOperation,
28 typename CDEElementwiseOperation,
29 typename AGridDesc_AK0_M_AK1,
30 typename BGridDesc_BK0_N_BK1,
31 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
32 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
33 typename Block2CTileMap,
34 typename ComputePtrOffsetOfBatch,
35 bool HasMainKBlockLoop>
36__global__ void
37#if CK_USE_LAUNCH_BOUNDS
39#endif
41 const ADataType* __restrict__ p_a_grid,
42 const BDataType* __restrict__ p_b_grid,
43 DsPointer p_ds_grid,
44 EDataType* __restrict__ p_e_grid,
45 const AElementwiseOperation a_element_op,
46 const BElementwiseOperation b_element_op,
47 const CDEElementwiseOperation cde_element_op,
48 const index_t batch_count,
49 const AGridDesc_AK0_M_AK1 a_grid_desc,
50 const BGridDesc_BK0_N_BK1 b_grid_desc,
51 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
52 ds_grid_desc_mblock_mperblock_nblock_nperblock,
53 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
54 e_grid_desc_mblock_mperblock_nblock_nperblock_,
55 const Block2CTileMap block_2_ctile_map,
56 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
57{
58#if(defined(__gfx11__) || defined(__gfx12__))
59 // offset base pointer for each work-group
// Batch index of this workgroup; readfirstlane makes the values wave-uniform.
60 const index_t num_blocks_per_batch =
61 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
62 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
63
// Per-batch element offsets of A, B and E, widened to long_index_t before pointer arithmetic.
64 const long_index_t a_batch_offset = amd_wave_read_first_lane(
65 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
66 const long_index_t b_batch_offset = amd_wave_read_first_lane(
67 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
68 const long_index_t e_batch_offset = amd_wave_read_first_lane(
69 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
70
// Per-batch offsets of the D tensors, indexed per D tensor below.
71 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
72
// Workgroup-local LDS scratch space (lds_size bytes) handed to GridwiseOp::Run.
73 __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
74
75 DsPointer p_ds_grid_grp;
76
77 static constexpr index_t NumDTensor =
78 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
// Shift each D base pointer by its per-batch offset.
79
81 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
82
// Compute this workgroup's tile on the batch-shifted pointers.
83 GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
84 p_b_grid + b_batch_offset,
85 p_ds_grid_grp,
86 p_e_grid + e_batch_offset,
87 p_shared,
88 a_grid_desc,
89 b_grid_desc,
90 ds_grid_desc_mblock_mperblock_nblock_nperblock,
91 e_grid_desc_mblock_mperblock_nblock_nperblock_,
92 a_element_op,
93 b_element_op,
94 cde_element_op,
95 block_2_ctile_map);
96#else
// Unsupported target: only silence unused-parameter warnings.
97 ignore = p_a_grid;
98 ignore = p_b_grid;
99 ignore = p_ds_grid;
100 ignore = p_e_grid;
101 ignore = batch_count;
102 ignore = a_grid_desc;
103 ignore = b_grid_desc;
104 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
105 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
106 ignore = a_element_op;
107 ignore = b_element_op;
108 ignore = cde_element_op;
109 ignore = compute_ptr_offset_of_batch;
110 ignore = block_2_ctile_map;
111#endif
112}
113
// Batched launcher for GridwiseOp::Run with multiple D tensors; the body is only compiled
// for gfx11/gfx12 targets. It takes its arguments in a different order from the other
// batched kernel in this file: batch_count comes before the descriptors, and the element-wise
// ops and compute_ptr_offset_of_batch come after them.
//
// Each workgroup derives its batch index g_idx from its 1-D block id and offsets the A, B, E
// and every D base pointer by the per-batch offsets from `compute_ptr_offset_of_batch`. It
// then runs GridwiseOp::Run with a workgroup-local LDS buffer.
// NOTE(review): assumes the grid size is a multiple of batch_count; otherwise the truncating
// division below yields g_idx >= batch_count for the trailing workgroups — confirm callers.
114template <typename GridwiseOp,
115 typename ADataType,
116 typename BDataType,
117 typename DsPointer,
118 typename EDataType,
119 typename AGridDesc,
120 typename BGridDesc,
121 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
122 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
123 typename AElementwiseOperation,
124 typename BElementwiseOperation,
125 typename CDEElementwiseOperation,
126 typename ComputePtrOffsetOfBatch,
127 typename Block2CTileMap,
128 bool HasMainKBlockLoop>
129__global__ void
130#if CK_USE_LAUNCH_BOUNDS
132#endif
134 const ADataType* __restrict__ p_a_grid,
135 const BDataType* __restrict__ p_b_grid,
136 DsPointer p_ds_grid,
137 EDataType* __restrict__ p_e_grid,
138 const index_t batch_count,
139 const AGridDesc a_grid_desc,
140 const BGridDesc b_grid_desc,
141 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
142 ds_grid_desc_mblock_mperblock_nblock_nperblock,
143 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
144 e_grid_desc_mblock_mperblock_nblock_nperblock,
145 const AElementwiseOperation a_element_op,
146 const BElementwiseOperation b_element_op,
147 const CDEElementwiseOperation cde_element_op,
148 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
149 const Block2CTileMap block_2_etile_map)
150{
151#if(defined(__gfx11__) || defined(__gfx12__))
152 // Workgroup-local LDS scratch space (lds_size bytes) handed to GridwiseOp::Run.
153 __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
154
// Batch index of this workgroup; readfirstlane makes the values wave-uniform.
155 const index_t num_blocks_per_batch =
156 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
157 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
158
// Per-batch element offsets of A, B and E, widened to long_index_t before pointer arithmetic.
159 const long_index_t a_batch_offset = amd_wave_read_first_lane(
160 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
161 const long_index_t b_batch_offset = amd_wave_read_first_lane(
162 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
163 const long_index_t e_batch_offset = amd_wave_read_first_lane(
164 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
165
// Per-batch offsets of the D tensors, indexed per D tensor below.
166 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
167
168 static constexpr index_t NumDTensor =
169 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
170
171 DsPointer p_ds_grid_grp;
// Shift each D base pointer by its per-batch offset.
172
174 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
175
// Compute this workgroup's tile on the batch-shifted pointers.
176 GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
177 p_b_grid + b_batch_offset,
178 p_ds_grid_grp,
179 p_e_grid + e_batch_offset,
180 p_shared,
181 a_grid_desc,
182 b_grid_desc,
183 ds_grid_desc_mblock_mperblock_nblock_nperblock,
184 e_grid_desc_mblock_mperblock_nblock_nperblock,
185 a_element_op,
186 b_element_op,
187 cde_element_op,
188 block_2_etile_map);
189#else
// Unsupported target: only silence unused-parameter warnings.
190 ignore = p_a_grid;
191 ignore = p_b_grid;
192 ignore = p_ds_grid;
193 ignore = p_e_grid;
194 ignore = batch_count;
195 ignore = a_element_op;
196 ignore = b_element_op;
197 ignore = cde_element_op;
198 ignore = a_grid_desc;
199 ignore = b_grid_desc;
200 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
201 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
202 ignore = block_2_etile_map;
203 ignore = compute_ptr_offset_of_batch;
204#endif
205}
206
// Non-batched launcher: passes the grid pointers and descriptors unchanged to
// GridwiseOp::Run, together with a workgroup-local LDS buffer of
// GridwiseOp::SharedMemTrait::lds_size bytes. The body is only compiled for gfx11/gfx12
// targets; elsewhere every argument is marked unused.
// NOTE(review): the kernel name misspells "multiple"; it is kept as-is because renaming
// would break any caller that refers to it by this name.
207template <typename GridwiseOp,
208 typename ADataType,
209 typename BDataType,
210 typename DsPointer,
211 typename EDataType,
212 typename AGridDesc,
213 typename BGridDesc,
214 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
215 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
216 typename AElementwiseOperation,
217 typename BElementwiseOperation,
218 typename CDEElementwiseOperation,
219 typename Block2CTileMap,
220 bool HasMainKBlockLoop>
221__global__ void
222#if CK_USE_LAUNCH_BOUNDS
224#endif
225 kernel_gemm_mupltipe_d_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
226 const BDataType* __restrict__ p_b_grid,
227 DsPointer p_ds_grid,
228 EDataType* __restrict__ p_e_grid,
229 const AGridDesc a_grid_desc,
230 const BGridDesc b_grid_desc,
231 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
232 ds_grid_desc_mblock_mperblock_nblock_nperblock,
233 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
234 e_grid_desc_mblock_mperblock_nblock_nperblock,
235 const AElementwiseOperation a_element_op,
236 const BElementwiseOperation b_element_op,
237 const CDEElementwiseOperation cde_element_op,
238 const Block2CTileMap block_2_ctile_map)
239{
240#if(defined(__gfx11__) || defined(__gfx12__))
// Workgroup-local LDS scratch space (lds_size bytes) handed to GridwiseOp::Run.
241 __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
242
243 GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
244 p_b_grid,
245 p_ds_grid,
246 p_e_grid,
247 p_shared,
248 a_grid_desc,
249 b_grid_desc,
250 ds_grid_desc_mblock_mperblock_nblock_nperblock,
251 e_grid_desc_mblock_mperblock_nblock_nperblock,
252 a_element_op,
253 b_element_op,
254 cde_element_op,
255 block_2_ctile_map);
256#else
// Unsupported target: only silence unused-parameter warnings.
257 ignore = p_a_grid;
258 ignore = p_b_grid;
259 ignore = p_ds_grid;
260 ignore = p_e_grid;
261 ignore = a_grid_desc;
262 ignore = b_grid_desc;
263 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
264 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
265 ignore = a_element_op;
266 ignore = b_element_op;
267 ignore = cde_element_op;
268 ignore = block_2_ctile_map;
269#endif // end of #if (defined(__gfx11__) || defined(__gfx12__))
270}
271
272template < // DataType Family
273 typename ADataType,
274 typename BDataType,
275 typename AccDataType,
276 typename CShuffleDataType,
277 typename DsDataType,
278 typename EDataType,
279 // InMemory Data Descriptor
280 typename AGridDesc,
281 typename BGridDesc,
282 typename DsGridDesc_M_N,
283 typename EGridDesc_M_N,
284 // ElementwiseOp Family
285 typename AElementwiseOperation,
286 typename BElementwiseOperation,
287 typename CDEElementwiseOperation,
288 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
289 // Tiling Family
290 index_t MPerBlock,
291 index_t NPerBlock,
292 index_t KPerBlock,
293 index_t MPerWmma,
294 index_t NPerWmma,
295 index_t K1Value,
296 index_t MRepeat,
297 index_t NRepeat,
298 // ThreadCluster Family
299 index_t BlockSize,
300 typename ABlockTransferThreadClusterLengths_K0_M_K1,
301 typename ABlockTransferThreadClusterArrangeOrder,
302 typename ABlockTransferSrcAccessOrder,
303 index_t ABlockTransferSrcVectorDim,
304 index_t ABlockTransferSrcScalarPerVector,
305 index_t ABlockTransferDstScalarPerVector_K1,
306 bool AThreadTransferSrcResetCoordinateAfterRun,
307 bool AEnableLds,
308 bool ABlockLdsExtraM,
309 typename BBlockTransferThreadClusterLengths_K0_N_K1,
310 typename BBlockTransferThreadClusterArrangeOrder,
311 typename BBlockTransferSrcAccessOrder,
312 index_t BBlockTransferSrcVectorDim,
313 index_t BBlockTransferSrcScalarPerVector,
314 index_t BBlockTransferDstScalarPerVector_K1,
315 bool BThreadTransferSrcResetCoordinateAfterRun,
316 bool BEnableLds,
317 bool BBlockLdsExtraN,
318 index_t CShuffleMRepeatPerShuffle,
319 index_t CShuffleNRepeatPerShuffle,
320 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
321 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
322 index_t NumGemmKPrefetchStage = 1,
326{
// Number of D tensors, taken from the size of the DsDataType tuple.
327 static constexpr index_t NumDTensor = DsDataType::Size();
328
// Compile-time dimension indices used to address descriptor lengths and tuple elements.
329 static constexpr auto I0 = Number<0>{};
330 static constexpr auto I1 = Number<1>{};
331 static constexpr auto I2 = Number<2>{};
332 static constexpr auto I3 = Number<3>{};
333 static constexpr auto I4 = Number<4>{};
334 static constexpr auto I5 = Number<5>{};
335 static constexpr auto I6 = Number<6>{};
336 static constexpr auto I7 = Number<7>{};
337
338 // K1 should be Number<...> (CheckValidity static_asserts that it is known at compile time)
339 static constexpr auto K1 = Number<K1Value>{};
340
// Waves per block along M (N): each wave covers MRepeat * MPerWmma rows (NRepeat * NPerWmma columns).
341 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
342 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
// K elements per WMMA step: 32 when K1 == 16, otherwise 16.
343 static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
344
346
349 NumGemmKPrefetchStage,
350 LoopSched,
351 AEnableLds,
352 BEnableLds>())>;
353
354 // Describe how data is stored into the (LDS/VGPR) buffer from global memory
// Returns the block-level A descriptor that is the destination of the A global->block copy:
//  - AEnableLds: a K0PerBlock x MPerBlock x K1 layout in LDS aligned to max_lds_align (K1);
//    ABlockLdsExtraM selects an alternate layout (presumably padded along M — confirm);
//  - otherwise: a per-thread VGPR layout
//    KWmma x MRepeat x MWave x K0PerWmma x KRow x MPerWmma x K1.
355 __host__ __device__ static constexpr auto MakeABlockDescriptor()
356 {
357 constexpr auto a_block_desc = [&]() {
358 if constexpr(AEnableLds)
359 {
360 // K0->M->K1 Per Block
361 constexpr auto K0PerBlock = KPerBlock / K1;
362 constexpr auto max_lds_align = K1;
363
364 if constexpr(ABlockLdsExtraM)
365 {
369 }
370 else
371 {
373 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
374 }
375 }
376 else
377 {
// One WMMA K step is split as WmmaK = K0PerWmma * A_KRow * K1, with A_KRow fixed to 2.
378 constexpr auto A_KRow = I2;
379 constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
380 constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
381 // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
385 I1,
387 I1,
388 I1,
389 K1),
393 K1,
394 K1,
395 K1,
396 I1));
397 }
398 }();
399
400 return a_block_desc;
401 }
402
// Returns the block-level B descriptor that is the destination of the B global->block copy:
//  - BEnableLds: a K0PerBlock x NPerBlock x K1 layout in LDS aligned to max_lds_align (K1);
//    BBlockLdsExtraN selects an alternate layout (presumably padded along N — confirm);
//  - otherwise: a per-thread VGPR layout
//    KWmma x NRepeat x NWave x K0PerWmma x KRow x NPerWmma x K1.
403 __host__ __device__ static constexpr auto MakeBBlockDescriptor()
404 {
405 constexpr auto b_block_desc = [&]() {
406 if constexpr(BEnableLds)
407 {
408 // K0->N->K1 Per Block
409 constexpr auto K0PerBlock = KPerBlock / K1;
410 constexpr auto max_lds_align = K1;
411
412 if constexpr(BBlockLdsExtraN)
413 {
417 }
418 else
419 {
421 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
422 }
423 }
424 else
425 {
// One WMMA K step is split as WmmaK = K0PerWmma * B_KRow * K1, with B_KRow fixed to 2.
426 constexpr auto B_KRow = I2;
427 constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
428 constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
429 // KWmma->NRepeat->NWave->K0PerWmma->KRow->NPerWmma->K1 Per Thread
433 I1,
435 I1,
436 I1,
437 K1),
441 K1,
442 K1,
443 K1,
444 I1));
445 }
446 }();
447
448 return b_block_desc;
449 }
450
451 __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
452 {
453 constexpr auto a_block_copy_step = [&]() {
454 if constexpr(AEnableLds)
455 {
456 constexpr auto K0PerBlock = KPerBlock / K1;
457
458 return make_multi_index(K0PerBlock, 0, 0);
459 }
460 else
461 {
462 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
463
464 return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
465 }
466 }();
467
468 return a_block_copy_step;
469 }
470
471 __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep()
472 {
473 constexpr auto b_block_copy_step = [&]() {
474 if constexpr(BEnableLds)
475 {
476 constexpr auto K0PerBlock = KPerBlock / K1;
477
478 return make_multi_index(K0PerBlock, 0, 0);
479 }
480 else
481 {
482 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
483
484 return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
485 }
486 }();
487
488 return b_block_copy_step;
489 }
490
491 // Describe how data is read from the (LDS/VGPR) buffer
// Builds, from the A block descriptor type, the per-wave view consumed by BlockwiseGemmWMMA.
// Only the descriptor type is used; a default-constructed ABlockDesc_{} supplies the lengths.
492 template <typename ABlockDesc_>
493 __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
494 {
495
496 constexpr auto a_wave_desc = [&]() {
497 if constexpr(AEnableLds)
498 {
499 // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
500 constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
501 constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
// KRow is 2 when compiling for gfx12 and 1 otherwise.
502#ifdef __gfx12__
503 constexpr auto A_KRow = I2;
504#else
505 constexpr auto A_KRow = I1;
506#endif
508 ABlockDesc_{},
515 }
516 else
517 {
518 // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
519 constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
520 constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
521 constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4);
522 constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6);
523
526 I1,
528 I1,
529 Number<A_K1>{}));
530 }
531 }();
532
533 return a_wave_desc;
534 }
535
// Builds, from the B block descriptor type, the per-wave view consumed by BlockwiseGemmWMMA.
// Only the descriptor type is used; a default-constructed BBlockDesc_{} supplies the lengths.
536 template <typename BBlockDesc_>
537 __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
538 {
539 constexpr auto b_wave_desc = [&]() {
540 if constexpr(BEnableLds)
541 {
542 // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
543 constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
544 constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
// KRow is 2 when compiling for gfx12 and 1 otherwise.
545#ifdef __gfx12__
546 constexpr auto B_KRow = I2;
547#else
548 constexpr auto B_KRow = I1;
549#endif
551 BBlockDesc_{},
558 }
559 else
560 {
561 // KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
562 constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
563 constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
564 constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4);
565 constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6);
566
567 // Workaround, Freeze transform
570 I1,
572 I1,
573 Number<B_K1>{}));
574 }
575 }();
576
577 return b_wave_desc;
578 }
579
// Returns the block descriptor used by the C-shuffle stage. Judging by the variable name,
// its dimensions are (MShRepeat, MPerShRepeat, NShRepeat, NPerShRepeat).
580 __host__ __device__ static constexpr auto
581 // *Caution: "repeat" here means the shuffle repeat, not MRepeat/NRepeat
583 {
584 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
588 I1,
590
591 return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
592 }
593
594 // ck::Tuple<const D0DataType*, const D1DataType*, ...>
// Builds a tuple of null `const DDataType*` pointers, one per entry of DsDataType. The
// class uses it for its type via `using DsGridPointer = decltype(MakeDsGridPointer())`.
595 static constexpr auto MakeDsGridPointer()
596 {
597 return generate_tuple(
598 [&](auto i) {
599 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
600
601 return static_cast<const DDataType*>(nullptr);
602 },
604 }
605
606 // CheckValidity for kernels without multi D
607 template <typename Block2CTileMap>
608 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
609 const BGridDesc& b_grid_desc,
610 const EGridDesc_M_N& e_grid_desc_m_n,
611 const Block2CTileMap& block_2_ctile_map)
612 {
613 static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
614 "wrong! K1 need to be known at compile-time");
615
616 static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
617 (NPerBlock % (NRepeat * NPerWmma)) == 0,
618 "Invalid tuning param!");
619
620 const auto GetAProblemsizeMK = [&]() {
621 if constexpr(AEnableLds)
622 {
623 return make_tuple(a_grid_desc.GetLength(I1),
624 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
625 }
626 else
627 {
628 return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
629 a_grid_desc.GetLength(I5),
630 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
631 a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
632 }
633 };
634
635 const auto GetBProblemsizeNK = [&]() {
636 if constexpr(BEnableLds)
637 {
638 return make_tuple(b_grid_desc.GetLength(I1),
639 b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
640 }
641 else
642 {
643 return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
644 b_grid_desc.GetLength(I5),
645 b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
646 b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
647 }
648 };
649
650 const auto M = GetAProblemsizeMK()[I0];
651 const auto N = GetBProblemsizeNK()[I0];
652 const auto K = GetAProblemsizeMK()[I1];
653
654 if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
655 K == GetBProblemsizeNK()[I1]))
656 {
657 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
658 {
659 printf("GridwiseOp: ABE descriptor dimension cross check failure\n");
660 }
661 return false;
662 }
663
664 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
665 {
666 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
667 {
668 printf("GridwiseOp: Problemsize descriptor dimension check failure\n");
669 }
670 return false;
671 }
672
673 // check gridwise gemm pipeline
674 const auto num_k_loop = K / KPerBlock;
675
676 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
677 {
678 return false;
679 }
680
681 if(!block_2_ctile_map.CheckValidity(e_grid_desc_m_n))
682 {
683 return false;
684 }
685
686 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
687 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
688
689 if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
690 b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
691 e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
692 {
693 return false;
694 }
695
696 return true;
697 }
698
699 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
700 template <typename Block2CTileMap>
701 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
702 const BGridDesc& b_grid_desc,
703 const DsGridDesc_M_N& ds_grid_desc_m_n,
704 const EGridDesc_M_N& e_grid_desc_m_n,
705 const Block2CTileMap& block_2_ctile_map)
706 {
707 static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
708 "wrong! K1 need to be known at compile-time");
709
710 static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
711 (NPerBlock % (NRepeat * NPerWmma)) == 0,
712 "Invalid tuning param!");
713
714 const auto GetAProblemsizeMK = [&]() {
715 if constexpr(AEnableLds)
716 {
717 return make_tuple(a_grid_desc.GetLength(I1),
718 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
719 }
720 else
721 {
722 return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
723 a_grid_desc.GetLength(I5),
724 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
725 a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
726 }
727 };
728
729 const auto GetBProblemsizeNK = [&]() {
730 if constexpr(BEnableLds)
731 {
732 return make_tuple(b_grid_desc.GetLength(I1),
733 b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
734 }
735 else
736 {
737 return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
738 b_grid_desc.GetLength(I5),
739 b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
740 b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6));
741 }
742 };
743
744 const auto M = GetAProblemsizeMK()[I0];
745 const auto N = GetBProblemsizeNK()[I0];
746 const auto K = GetAProblemsizeMK()[I1];
747
748 bool valid = true;
749
750 static_for<0, NumDTensor, 1>{}([&](auto i) {
751 valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
752 N == ds_grid_desc_m_n[i].GetLength(I1));
753 });
754
755 if(!valid)
756 {
757 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
758 {
759 printf("GridwiseOp: D descriptor dimension check failure\n");
760 }
761 return false;
762 }
763
764 if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
765 K == GetBProblemsizeNK()[I1]))
766 {
767 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
768 {
769 printf("GridwiseOp: ABE descriptor dimension cross check failure\n");
770 }
771 return false;
772 }
773
774 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
775 {
776 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
777 {
778 printf("GridwiseOp: Problemsize descriptor dimension check failure\n");
779 }
780 return false;
781 }
782
783 // check gridwise gemm pipeline
784 const auto num_k_loop = K / KPerBlock;
785
786 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
787 {
788 return false;
789 }
790
791 if(!block_2_ctile_map.CheckValidity(e_grid_desc_m_n))
792 {
793 return false;
794 }
795
796 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
797 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
798
799 if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
800 b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
801 e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
802 {
803 return false;
804 }
805
806 return true;
807 }
808
809 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
810 {
811 const index_t num_loop = K / KPerBlock;
812
813 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
814 }
815
816 // E desc for destination in blockwise copy
// Re-expresses an M x N descriptor in tile-blocked form (per the function name:
// MBlock, MPerBlock, NBlock, NPerBlock), with MBlock = M / MPerBlock and
// NBlock = N / NPerBlock. The division truncates, so M and N are expected to be tile
// multiples (CheckValidity rejects problems where they are not).
817 template <typename EGridDesc_M_N_>
818 __host__ __device__ static constexpr auto
819 MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N_& e_grid_desc_m_n)
820 {
821 const auto M = e_grid_desc_m_n.GetLength(I0);
822 const auto N = e_grid_desc_m_n.GetLength(I1);
823
824 const auto MBlock = M / MPerBlock;
825 const auto NBlock = N / NPerBlock;
826
827 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
828 e_grid_desc_m_n,
833
834 return e_grid_desc_mblock_mperblock_nblock_nperblock;
835 }
836
837 // Ds desc for source in blockwise copy
// Returns a tuple with one tile-blocked descriptor per D tensor, generated per index i
// from the corresponding entry of ds_grid_desc_m_n.
838 template <typename DsGridDesc_M_N_>
839 __host__ __device__ static constexpr auto
840 MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_& ds_grid_desc_m_n)
841 {
842 return generate_tuple(
843 [&](auto i) {
845 },
847 }
848
849 // return block_id to C matrix tile idx (m0, n0) mapping
// The M01/N01 arguments are accepted for interface compatibility but are unused; the
// mapping is built from e_grid_desc_m_n alone.
850 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
851 const EGridDesc_M_N& e_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
852 {
854 e_grid_desc_m_n);
855 }
856
858 {
859 // LDS allocation for A and B: be careful of alignment
// Element counts of the A and B LDS tiles, rounded up with math::integer_least_multiple
// (presumably to max_lds_align). A count is 0 when that operand bypasses LDS.
860
861 static constexpr auto max_lds_align = K1;
862
863 static constexpr auto a_block_space_size_aligned =
864 AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
866 : 0;
867 static constexpr auto b_block_space_size_aligned =
868 BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
870 : 0;
871
872 static constexpr auto a_block_space_offset = 0;
874
875 // LDS allocation for C shuffle in LDS
876 static constexpr auto c_shuffle_block_space_size =
878 .GetElementSpaceSize();
879
880 static constexpr auto c_shuffle_block_space_offset = 0;
881
// lds_size is in bytes. The A/B staging area and the C-shuffle area both start at
// offset 0, so the allocation is the larger of the two (the shuffle presumably reuses it).
// NOTE(review): Run places the B tile at static_cast<BDataType*>(p_shared) +
// a_block_space_size_aligned, i.e. at byte offset a_block_space_size_aligned * sizeof(BDataType).
// The size below assumes B starts at a_block_space_size_aligned * sizeof(ADataType).
// If sizeof(ADataType) != sizeof(BDataType), the B tile overlaps A or overruns lds_size —
// confirm, and offset B in bytes instead.
882 static constexpr auto lds_size =
883 math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType),
884 a_block_space_size_aligned * sizeof(ADataType) +
885 b_block_space_size_aligned * sizeof(BDataType));
886 };
887
890 DsGridDesc_M_N{}))>;
893 EGridDesc_M_N{}))>;
895 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(EGridDesc_M_N{}, 1, 1))>;
896 using DsGridPointer = decltype(MakeDsGridPointer());
897
898 template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
899 __device__ static void Run(const ADataType* __restrict__ p_a_grid,
900 const BDataType* __restrict__ p_b_grid,
901 DsGridPointer p_ds_grid,
902 EDataType* __restrict__ p_e_grid,
903 void* __restrict__ p_shared,
904 const AGridDesc& a_grid_desc,
905 const BGridDesc& b_grid_desc,
907 ds_grid_desc_mblock_mperblock_nblock_nperblock,
909 e_grid_desc_mblock_mperblock_nblock_nperblock,
910 const AElementwiseOperation& a_element_op,
911 const BElementwiseOperation& b_element_op,
912 const CDEElementwiseOperation& cde_element_op,
913 const Block2CTileMap& block_2_ctile_map)
914 {
915 // clang-format off
916/*******************************************************************************/
917// Memory buffer zone.
918 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
919 p_a_grid, a_grid_desc.GetElementSpaceSize());
920 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
921 p_b_grid, b_grid_desc.GetElementSpaceSize());
922 const auto ds_grid_buf = generate_tuple(
923 [&](auto i) {
925 p_ds_grid[i],
926 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
927 },
930 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
931
932/*******************************************************************************/
933// BlockIdx.x -> [BlockId.m, BlockId.n]
934 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
935 if(!block_2_ctile_map.ValidCTileIndex(
936 block_work_idx,
937 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
938 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
939 { return; }
940
941 // Store BlockId into SGPR
942 const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
943 const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
944
945/*******************************************************************************/
946// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
947 const auto K = [&](){
948 if constexpr(AEnableLds){
949 return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
950 }
951 else{
952 return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
953 a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
954 }
955 }();
956
957 constexpr auto a_block_desc = MakeABlockDescriptor();
958 constexpr auto b_block_desc = MakeBBlockDescriptor();
959
960 auto a_block_trait = [&](){
961 // A matrix blockwise copy
962 if constexpr(AEnableLds)
963 {
964 constexpr auto K0PerBlock = KPerBlock/ K1;
966 static_cast<ADataType*>(p_shared),
967 a_block_desc.GetElementSpaceSize());
968
969 auto a_blockwise_copy =
971/* typename SrcElementwiseOperation, */ AElementwiseOperation,
972/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough,
973/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set,
974/* typename BlockSliceLengths, */ Sequence<K0PerBlock, MPerBlock, K1>,
975/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
976/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
977/* typename SrcData, */ ADataType,
978/* typename DstData, */ ADataType,
979/* typename SrcDesc, */ decltype(a_grid_desc),
980/* typename DstDesc, */ decltype(a_block_desc),
981/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
982/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>,
983/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim,
984/* index_t DstVectorDim, */ 2,
985/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector,
986/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1,
987/* index_t SrcScalarStrideInVector, */ 1,
988/* index_t DstScalarStrideInVector, */ 1,
989/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
990/* bool ThreadTransferDstResetCoordinateAfterRun, */ true,
991 NumGemmKPrefetchStage>(
992 a_grid_desc,
993 make_multi_index(0, m_block_data_idx_on_grid, 0),
994 a_element_op,
995 a_block_desc,
996 make_multi_index(0, 0, 0),
998
999 return make_tuple(a_block_buf, a_blockwise_copy);
1000 }
1001 else
1002 {
1003 // Thread-wise copy
1004 // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1
1005 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
1006 constexpr auto K0PerWmma = WmmaK/2/K1Value;
1008 a_block_desc.GetElementSpaceSize());
1009
1010 // Limitation: NumDim of Src and Dst descriptor should be identical
1011 auto a_blockwise_copy =
1013 ADataType,
1014 decltype(a_grid_desc),
1015 decltype(a_block_desc),
1018 I1,
1020 I1,
1021 I1,
1022 Number<K1Value>{}>,
1024 6,
1025 ABlockTransferSrcScalarPerVector,
1026 AThreadTransferSrcResetCoordinateAfterRun,
1027 true>(
1028 a_grid_desc,
1030 m_block_data_idx_on_grid/(MWaves * MPerWmma),
1032 0,
1033 (get_thread_local_1d_id() % 32 )/ 16,
1035 0));
1036
1037 return make_tuple(a_block_buf, a_blockwise_copy);
1038 }
1039 };
1040
1041 auto b_block_trait = [&](){
1042 if constexpr(BEnableLds)
1043 {
1044 constexpr auto K0PerBlock = KPerBlock/ K1;
1046 static_cast<BDataType*>(p_shared) + SharedMemTrait::a_block_space_size_aligned,
1047 b_block_desc.GetElementSpaceSize());
1048
1049 auto b_blockwise_copy =
1051 BElementwiseOperation,
1055 BBlockTransferThreadClusterLengths_K0_N_K1,
1056 BBlockTransferThreadClusterArrangeOrder,
1057 BDataType,
1058 BDataType,
1059 decltype(b_grid_desc),
1060 decltype(b_block_desc),
1061 BBlockTransferSrcAccessOrder,
1063 BBlockTransferSrcVectorDim,
1064 2,
1065 BBlockTransferSrcScalarPerVector,
1066 BBlockTransferDstScalarPerVector_K1,
1067 1,
1068 1,
1069 BThreadTransferSrcResetCoordinateAfterRun,
1070 true,
1071 NumGemmKPrefetchStage>(
1072 b_grid_desc,
1073 make_multi_index(0, n_block_data_idx_on_grid, 0),
1074 b_element_op,
1075 b_block_desc,
1076 make_multi_index(0, 0, 0),
1078
1079 return make_tuple(b_block_buf, b_blockwise_copy);
1080 }
1081 else
1082 {
1083 // Thread-wise copy
1084 // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
1085 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
1086 constexpr auto K0PerWmma = WmmaK/2/K1Value;
1088 b_block_desc.GetElementSpaceSize());
1089
1090 // Limitation: NumDim of Src and Dst descriptor should be identical
1091 auto b_blockwise_copy =
1093 BDataType,
1094 decltype(b_grid_desc),
1095 decltype(b_block_desc),
1098 I1,
1100 I1,
1101 I1,
1102 Number<K1Value>{}>,
1104 6,
1105 BBlockTransferSrcScalarPerVector,
1106 BThreadTransferSrcResetCoordinateAfterRun,
1107 true>(
1108 b_grid_desc,
1110 n_block_data_idx_on_grid/(NWaves * NPerWmma),
1112 0,
1113 (get_thread_local_1d_id() % 32 )/ 16,
1115 0));
1116
1117 return make_tuple(b_block_buf, b_blockwise_copy);
1118 }
1119 };
1120
1121 auto a_block_buf = a_block_trait()[I0];
1122 auto a_blockwise_copy = a_block_trait()[I1];
1123
1124 auto b_block_buf = b_block_trait()[I0];
1125 auto b_blockwise_copy = b_block_trait()[I1];
1126/*******************************************************************************/
1127 // GEMM
1128 constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
1129
1130 auto blockwise_gemm =
1131 BlockwiseGemmWMMA<BlockSize,
1132 ADataType,
1133 BDataType,
1134 AccDataType,
1135 decltype(MakeAWaveDescriptor(a_block_desc)),
1136 decltype(MakeBWaveDescriptor(b_block_desc)),
1137 MPerBlock,
1138 NPerBlock,
1139 KPerBlock,
1140 MPerWmma,
1141 NPerWmma,
1142 MRepeat,
1143 NRepeat,
1144 KPack,
1145 AEnableLds,
1146 BEnableLds>{};
1147
1148 // Prepare Register for C matrix
1149 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
1150
1151/*******************************************************************************/
1152 // Shift Per SUB_K
1153 constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
1154 constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();
1155
1156 // gridwise GEMM pipeline
1157 const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
1158 GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
1159 a_block_desc,
1160 a_blockwise_copy,
1161 a_grid_buf,
1162 a_block_buf,
1163 a_block_slice_copy_step,
1164 b_grid_desc,
1165 b_block_desc,
1166 b_blockwise_copy,
1167 b_grid_buf,
1168 b_block_buf,
1169 b_block_slice_copy_step,
1170 blockwise_gemm,
1171 c_thread_buf,
1172 KBlockMainLoop);
1173/*******************************************************************************/
1174 // write out to C, implement shuffle
1175 {
1176 constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
1177 blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
1178
1179 // This API Provide All dimension (size) you need
1180 constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
1181 blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
1182
1183 constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1);
1184 constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2);
1185 constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4);
1186 constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5);
1187 constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6);
1188
1189 // LDS descriptor, shuffle and write out in MRepeat x NRepeat times
1190 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
1192
1193 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1194 static_cast<CShuffleDataType*>(p_shared),
1195 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
1196
1197 constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor(
1198 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1199 make_tuple(
1202 Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
1203 MWave, // MWave
1204 MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
1205 MAccVgprs)),
1208 Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
1209 NWave, // NWave
1210 NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
1213
1214 // calculate origin of thread output tensor on global memory
1215 // blockwise GEMM c matrix starting index
1216 const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
1217
1218 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1219 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1220
1221 const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
1223 make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
1226
1227 const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
1229 make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
1232
1233 const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
1234 make_multi_index(m_thread_data_on_block));
1235
1236 const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
1237 make_multi_index(n_thread_data_on_block));
1238
1239 // shuffle: threadwise copy C from VGPR to LDS
1240 auto c_thread_copy_vgpr_to_lds =
1242 CShuffleDataType,
1243 decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
1244 decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
1246 Sequence<CShuffleMRepeatPerShuffle,
1247 I1,
1248 I1,
1249 CShuffleNRepeatPerShuffle,
1250 I1,
1251 I1,
1252 MAccVgprs>,
1254 6,
1255 1, // vector write pixel
1257 1,
1258 true>{
1259 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1261 m_thread_data_on_block_idx[I1],
1262 m_thread_data_on_block_idx[I2],
1263 0,
1264 n_thread_data_on_block_idx[I1],
1265 n_thread_data_on_block_idx[I2],
1266 m_thread_data_on_block_idx[I3]),
1268
1269 // tuple of reference to C/Ds tensor descriptors
1270 const auto c_ds_desc_refs = concat_tuple_of_reference(
1271 tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
1273 [&](auto i) -> const auto& // return type should be reference
1274 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1276
1277 // tuple of reference to C/Ds tensor buffers
1278 const auto c_ds_buf_refs = concat_tuple_of_reference(
1279 tie(c_shuffle_block_buf),
1281 [&](auto i) -> const auto& // return type should be reference
1282 { return ds_grid_buf[i]; },
1284
1285 // tuple of starting index of C/Ds blockwise copy
1286 const auto idx_c_ds_block_begin = container_concat(
1287 make_tuple(make_multi_index(0, 0, 0, 0)),
1289 [&](auto) {
1290 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
1291 },
1293
1294 // shuffle: blockwise copy C from LDS to global
1295 auto cde_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
1296 ThisThreadBlock, // ThreadGroup
1297 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1299 decltype(c_ds_desc_refs),
1300 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1301 CDEElementwiseOperation, // ElementwiseOperation,
1302 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // DstInMemOp,
1303 Sequence<1,
1304 CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1305 1,
1306 CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
1307 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1308 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1309 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1310 3, // index_t VectorDim,
1311 CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1315 false>>, // bool ThreadTransferSrcResetCoordinateAfterRun,
1316 Sequence<false>> // bool ThreadTransferDstResetCoordinateAfterRun>
1317 {c_ds_desc_refs,
1318 idx_c_ds_block_begin,
1319 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1320 make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
1321 cde_element_op};
1322
1323 // space filling curve for local reg & global memory
1324 // space filling curve for threadwise C in VGPR
1325 constexpr auto sfc_c_vgpr =
1328 Sequence<CShuffleMRepeatPerShuffle,
1329 1,
1330 1,
1331 CShuffleNRepeatPerShuffle,
1332 1,
1333 1,
1334 MAccVgprs>>{};
1335
1336 // space filling curve for shuffled blockwise C in global mem
1337 constexpr auto sfc_cde_global =
1340 Sequence<1,
1341 CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1342 1,
1343 CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
1344
1345 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1346
1347 static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
1348
1349 static_for<0, num_access, 1>{}([&](auto access_id) {
1350 // make sure it's safe to write to LDS
1352
1353 // each thread write its data from VGPR to LDS
1354 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1355 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1356 c_thread_buf,
1357 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1358 c_shuffle_block_buf);
1359
1360 // make sure it's safe to read from LDS
1362
1363 // each block copy its data from LDS to global
1364 cde_shuffle_block_copy_lds_to_global.Run(
1365 c_ds_desc_refs,
1366 c_ds_buf_refs,
1367 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1368 tie(e_grid_buf));
1369
1370 if constexpr(access_id < num_access - 1)
1371 {
1372 constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
1373 // move on Ds
1374 static_for<0, NumDTensor, 1>{}([&](auto i) {
1375 cde_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
1376 c_ds_desc_refs, i + I1, cde_global_step);
1377 });
1378
1379 // move on E
1380 cde_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1381 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1382 I0,
1383 cde_global_step);
1384 }
1385 });
1386 }
1387 // clang-format on
1388 }
1389};
1390
1391} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__global__ void kernel_gemm_mupltipe_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:225
__global__ void kernel_grouped_conv_multiple_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc, const BGridDesc_BK0_N_BK1 b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:40
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__global__ void kernel_contraction_multiple_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:133
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition block_to_ctile_map.hpp:261
Definition blockwise_gemm_wmma.hpp:550
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_wmma.hpp:585
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:858
static constexpr auto lds_size
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:882
static constexpr auto max_lds_align
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:861
static constexpr auto b_block_space_size_aligned
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:867
static constexpr auto a_block_space_offset
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:872
static constexpr auto c_shuffle_block_space_offset
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:880
static constexpr auto c_shuffle_block_space_size
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:876
static constexpr auto b_block_space_offset
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:873
static constexpr auto a_block_space_size_aligned
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:863
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:326
__host__ static __device__ constexpr auto MakeBBlockDescriptor()
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:403
__host__ static __device__ constexpr auto MakeAWaveDescriptor(const ABlockDesc_ &)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:493
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:809
__host__ static __device__ constexpr auto MakeBWaveDescriptor(const BBlockDesc_ &)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:537
__host__ static __device__ constexpr auto MakeBBlockSliceCopyStep()
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:471
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:701
__host__ static __device__ constexpr auto MakeABlockDescriptor()
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:355
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:899
__host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N_ &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:819
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const EGridDesc_M_N &e_grid_desc_m_n, index_t, index_t)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:850
__host__ static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_ &ds_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:840
static constexpr auto MakeDsGridPointer()
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:595
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const EGridDesc_M_N &e_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:608
__host__ static __device__ constexpr auto MakeABlockSliceCopyStep()
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:451
__host__ static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:582
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7.hpp:42
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/tuple.hpp:117
Definition is_known_at_compile_time.hpp:14
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129