gridwise_gemm_xdl_cshuffle_v3_mx.hpp Source File

gridwise_gemm_xdl_cshuffle_v3_mx.hpp Source File#

Composable Kernel: gridwise_gemm_xdl_cshuffle_v3_mx.hpp Source File
gridwise_gemm_xdl_cshuffle_v3_mx.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16#include "ck/utility/env.hpp"
18
19namespace ck {
20
21#ifndef KERNEL_GEMM_XDL_CSHUFFLE_V3_MX
22#define KERNEL_GEMM_XDL_CSHUFFLE_V3_MX
23// Currently there is no elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines
24// in the same kernel function. Blockers:
25// 1. Two separate __shared__ declarations are the key to making sure data accesses operate on
26// two distinct LDS chunks.
27// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C could
28// not share the same LDS buffer if __shared__ were declared inside the block GEMM pipeline.
// Kernel entry point for the single-LDS-buffer path.
// NOTE(review): the TailNum template parameter and the return type (which presumably uses
// Use2LDS to select between this kernel and the 2-LDS overload below) are on lines not
// visible in this listing — confirm against the full header.
// Each workgroup takes the split-K batch selected by blockIdx.z. It shifts the A, A-scale,
// B, B-scale and C pointers by the offsets computed in GridwiseGemm::SplitKBatchOffset,
// then delegates the work to GridwiseGemm::Run. On targets other than gfx950 the body
// reduces to consuming karg.
29template <bool Use2LDS,
30 typename GridwiseGemm,
31 bool HasMainKBlockLoop,
32 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
33 index_t MinimumOccupancy = 1,
36#if CK_USE_LAUNCH_BOUNDS
37__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
38#endif
39 // __attribute__((amdgpu_waves_per_eu(1, 1)))
40 kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg)
41{
42#if defined(__gfx950__)
43 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
44 {
 // Single LDS allocation. GetSharedMemoryNumberOfByte() sizes it as
 // max(A tile + B tile, C-shuffle tile), so the A/B stage and the C-shuffle stage reuse it.
45 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
46
 // blockIdx.z is the split-K batch index. The constructor also rewrites karg.K to
 // this batch's K extent; karg is a by-value copy, so only this workgroup sees that.
47 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
48
49 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
50 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
51 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
52 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
53 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
54 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
55 p_shared,
56 karg);
57 }
58#else
 // Unsupported target: only consume karg so the parameter is not reported as unused.
59 ignore = karg;
60#endif // end of if (defined(__gfx950__))
61}
62
// Kernel entry point for the double-LDS-buffer path.
// NOTE(review): as with the single-LDS kernel, the TailNum parameter and the return type
// (presumably constrained on Use2LDS) are on lines not visible in this listing — confirm.
// This kernel is identical to the single-LDS kernel except for two points. It declares two
// independent __shared__ arrays, each GetSharedMemoryNumberOfByte() bytes, so it uses twice
// the LDS of the single-buffer kernel. It also calls GridwiseGemm::Run_2Lds with both arrays.
63template <bool Use2LDS,
64 typename GridwiseGemm,
65 bool HasMainKBlockLoop,
66 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
67 index_t MinimumOccupancy = 1,
70#if CK_USE_LAUNCH_BOUNDS
71__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
72#endif
73 // __attribute__((amdgpu_waves_per_eu(1, 1)))
74 kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg)
75{
76#if defined(__gfx950__)
77 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
78 {
79 // Passing two separate LDS pointers tells the compiler that ds_read/ds_write on the
80 // two chunks can proceed at the same time without an ordering dependency.
81 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
82 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
83
 // blockIdx.z is the split-K batch index. The constructor also rewrites karg.K
 // (a by-value copy) to this batch's K extent.
84 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
85
86 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
87 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
88 karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
89 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
90 karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
91 karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
92 p_shared_0,
93 p_shared_1,
94 karg);
95 }
96#else
 // Unsupported target: only consume karg so the parameter is not reported as unused.
97 ignore = karg;
98#endif // end of if (defined(__gfx950__))
99}
100#endif
101
102template <typename ALayout,
103 typename BLayout,
104 typename CLayout,
105 typename ADataType,
106 typename AScaleDataType,
107 typename BDataType,
108 typename BScaleDataType,
109 typename AccDataType,
110 typename CShuffleDataType,
111 typename CDataType,
112 typename AElementwiseOperation,
113 typename BElementwiseOperation,
114 typename CElementwiseOperation,
116 index_t ScaleBlockSize, // Scaling block size
117 index_t BlockSize, // Thread block size
118 index_t MPerBlock,
119 index_t NPerBlock,
120 index_t KPerBlock,
121 index_t AK1Value,
122 index_t BK1Value,
123 index_t MPerXdl,
124 index_t NPerXdl,
125 index_t MXdlPerWave,
126 index_t NXdlPerWave,
127 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
128 typename ABlockTransferThreadClusterArrangeOrder,
129 typename ABlockTransferSrcAccessOrder,
130 index_t ABlockTransferSrcVectorDim,
131 index_t ABlockTransferSrcScalarPerVector,
132 index_t ABlockTransferDstScalarPerVector_AK1,
133 bool AThreadTransferSrcResetCoordinateAfterRun,
134 index_t ABlockLdsExtraM,
135 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
136 typename BBlockTransferThreadClusterArrangeOrder,
137 typename BBlockTransferSrcAccessOrder,
138 index_t BBlockTransferSrcVectorDim,
139 index_t BBlockTransferSrcScalarPerVector,
140 index_t BBlockTransferDstScalarPerVector_BK1,
141 bool BThreadTransferSrcResetCoordinateAfterRun,
142 index_t BBlockLdsExtraN,
143 index_t CShuffleMXdlPerWavePerShuffle,
144 index_t CShuffleNXdlPerWavePerShuffle,
145 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
146 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
149 typename ComputeTypeA =
150 ADataType, // XXX: These should always be the same as ADataType and BDataType
151 typename ComputeTypeB =
152 BDataType, // TODO: Hardcode them and remove from the list of template parameters
153 bool PermuteA = false,
154 bool PermuteB = false>
156{
157
158 static constexpr auto I0 = Number<0>{};
159 static constexpr auto I1 = Number<1>{};
160 static constexpr auto I2 = Number<2>{};
161 static constexpr auto I3 = Number<3>{};
162 static constexpr auto I4 = Number<4>{};
163 static constexpr auto I5 = Number<5>{};
164 static constexpr auto I6 = Number<6>{};
165 static constexpr auto I7 = Number<7>{};
166 static constexpr auto I8 = Number<8>{};
167 static constexpr auto I9 = Number<9>{};
168
169 // K1 should be Number<...>
170 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
171 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
172 static constexpr auto AK1Number = Number<AK1Value>{};
173 static constexpr auto BK1Number = Number<BK1Value>{};
174
175 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
176 static constexpr bool is_single_rate_mfma = false;
177 static constexpr auto is_scale_mfma = true;
178
179 static constexpr auto MXdlPack = 2;
180 static constexpr auto NXdlPack = 2;
181 static constexpr auto KXdlPack = 2;
182
183 //> KPack is at least the k_per_blk of selected mfma
184 //
185 // Should be a multiple of k_per_blk.
186 // TODO: Move this to blockwise pipeline base
187 // KPack in packed data types for pk A/B
188
191
192 static constexpr index_t KPack =
194 MfmaSelector<ComputeTypeA,
195 MPerXdl,
196 NPerXdl,
197 ComputeTypeB,
199 is_scale_mfma>::selected_mfma.k_per_blk /
201
203
204 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
205 {
206 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
207 }
208
209 __host__ static auto CalculateMPadded(index_t M)
210 {
211 return math::integer_least_multiple(M, MPerBlock);
212 }
213
214 __host__ static auto CalculateNPadded(index_t N)
215 {
216 return math::integer_least_multiple(N, NPerBlock);
217 }
218
219 __host__ static auto CalculateKPadded(index_t K)
220 {
221 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
222 }
223
224 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
225 {
226 auto K_t = K_Batch * KPerBlock;
227 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
228 }
229
230 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
231 {
232 auto K_t = K_Batch * KPerBlock;
233 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
234 }
235
236 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
237 {
238 auto K_t = K_Batch * KPerBlock;
239 return (K + K_t - 1) / K_t * KPerBlock;
240 }
241
242 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
243 {
244 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
245 auto K_t = K_Batch * KReadVec;
246 return (K + K_t - 1) / K_t * KReadVec;
247 }
248
249 __host__ static auto CalculateMBlock(index_t M)
250 {
251 return math::integer_divide_ceil(M, MPerBlock);
252 }
253
254 __host__ static auto CalculateNBlock(index_t N)
255 {
256 return math::integer_divide_ceil(N, NPerBlock);
257 }
258
259 template <index_t MNXdlPerWave,
260 index_t MNWaves,
261 index_t MNXdlPack,
262 index_t MNPerXdl,
263 typename TileDesc_K0_MN_K1>
264 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
265 {
266 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
267 constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
268 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
269
270 constexpr auto permuted_desc = transform_tensor_descriptor(
271 TileDesc_K0_MN_K1{},
276
278 permuted_desc,
283 Number<MNPerXdl>{}))),
286 }
287
288 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
289 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
290 {
291 const auto a_grid_desc_mraw_kraw = [&]() {
293 {
294 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
295 }
297 {
298 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
299 }
300 }();
301
302 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
303
304 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
305 GemmSpec == GemmSpecialization::MNKPadding)
306 {
307 // pad both M and K
308 const auto a_grid_desc_m_k =
309 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
311 make_right_pad_transform(K, KPad - K)),
314
315 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
316 a_grid_desc_m_k,
321
322 return a_grid_desc_ak0_m_ak1;
323 }
324 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
325 GemmSpec == GemmSpecialization::MNPadding)
326 {
327 // pad M, but not K
328 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
329 a_grid_desc_mraw_kraw,
330 make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
331 make_right_pad_transform(M, MPad - M)),
334
335 const auto a_grid_desc_permuted = transform_tensor_descriptor(
336 a_grid_desc_ak0_m_ak1,
342
343 const auto a_grid_desc = transform_tensor_descriptor(
344 a_grid_desc_permuted,
351 return a_grid_desc;
352 }
353 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
354 GemmSpec == GemmSpecialization::NKPadding)
355 {
356 // pad K, but not M
357 const auto a_grid_desc_m_k = transform_tensor_descriptor(
358 a_grid_desc_mraw_kraw,
362
363 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
364 a_grid_desc_m_k,
369
370 return a_grid_desc_ak0_m_ak1;
371 }
372 else
373 {
374 // not pad M or K
375 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
376 a_grid_desc_mraw_kraw,
377 make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
381
382 const auto a_grid_desc_permuted = transform_tensor_descriptor(
383 a_grid_desc_ak0_m_ak1,
389
390 const auto a_grid_desc = transform_tensor_descriptor(
391 a_grid_desc_permuted,
398
399 return a_grid_desc;
400 }
401 }
402
403 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
404 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
405 {
406 const auto b_grid_desc_nraw_kraw = [&]() {
408 {
409 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
410 }
412 {
413 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
414 }
415 }();
416
417 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
418
420 GemmSpec != GemmSpecialization::Default),
421 "pk_i4_t does not support padding");
423 (GemmSpec != GemmSpecialization::Default &&
424 GemmSpec != GemmSpecialization::MPadding)),
425 "f4x2_pk_t does not support K padding");
430 GemmSpec != GemmSpecialization::Default),
431 "Packed F6 types do not support padding");
432
433 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
434 GemmSpec == GemmSpecialization::MNKPadding)
435 {
436 // pad both N and K
437 const auto b_grid_desc_n_k =
438 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
440 make_right_pad_transform(K, KPad - K)),
443
444 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
445 b_grid_desc_n_k,
450
451 return b_grid_desc_bk0_n_bk1;
452 }
453 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
454 GemmSpec == GemmSpecialization::MNPadding)
455 {
456 // pad N, but not K
457 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
458 b_grid_desc_nraw_kraw,
460 make_right_pad_transform(N, NPad - N)),
463
464 return b_grid_desc_bk0_n_bk1;
465 }
466 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
467 GemmSpec == GemmSpecialization::MKPadding)
468 {
469 // pad K, but not N
470 const auto b_grid_desc_n_k = transform_tensor_descriptor(
471 b_grid_desc_nraw_kraw,
475
476 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
477 b_grid_desc_n_k,
482
483 return b_grid_desc_bk0_n_bk1;
484 }
485 else
486 {
487 if constexpr(!PermuteB)
488 {
489 // not pad N or K
490 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
491 b_grid_desc_nraw_kraw,
493 make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
497
498 const auto b_grid_desc_permuted = transform_tensor_descriptor(
499 b_grid_desc_bk0_n_bk1,
505
506 const auto b_grid_desc = transform_tensor_descriptor(
507 b_grid_desc_permuted,
514
515 return b_grid_desc;
516 }
517 else
518 {
519 // Weight Tile Permute
520 constexpr index_t BK01 = KPerBlock / BK1Value;
521 // const index_t BK00 = BK0 / BK01;
522 const index_t BK0_ = StrideB / BK1Value;
523 const index_t BK00 = BK0_ / BK01;
524
525 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
526 make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
527
528 const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
529 b_grid_desc_bk00_n_bk01_bk1_permute,
535
536 return b_grid_desc_bk0_n_bk1_permute;
537 }
538 }
539 }
540
541 template <typename ABlockDesc_AK0_M_AK1>
542 __host__ __device__ static constexpr auto
543 MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
544 {
545 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
546
548 ABlockDesc_AK0_M_AK1{});
549 }
550
551 template <typename BBlockDesc_BK0_N_BK1>
552 __host__ __device__ static constexpr auto
553 MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
554 {
555 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
556
558 BBlockDesc_BK0_N_BK1{});
559 }
560
561 __host__ __device__ static auto
563 {
564 const auto c_grid_desc_mraw_nraw = [&]() {
566 {
567 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
568 }
570 {
571 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
572 }
573 }();
574
575 // pad M and N
576 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
578 make_right_pad_transform(N, NPad - N)),
581#if 0
582 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
583
584 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
585 GemmSpec == GemmSpecialization::MNKPadding)
586 {
587 // pad M and N
588 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
590 make_right_pad_transform(N, NPad - N)),
593 }
594 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
595 GemmSpec == GemmSpecialization::MKPadding)
596 {
597 // pad M, but not N
599 c_grid_desc_mraw_nraw,
603 }
604 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
605 GemmSpec == GemmSpecialization::NKPadding)
606 {
607 // pad N, but not M
609 c_grid_desc_mraw_nraw,
613 }
614 else
615 {
616 // not pad M or N
617 return c_grid_desc_mraw_nraw;
618 }
619#endif
620 }
621
622 struct Problem
623 {
624 __host__ Problem(index_t M_,
625 index_t N_,
626 index_t K_,
627 index_t StrideA_,
628 index_t StrideScaleA_,
629 index_t StrideB_,
630 index_t StrideScaleB_,
631 index_t StrideC_,
632 index_t KBatch_)
633 : M{M_},
634 N{N_},
635 K{K_},
636 StrideA{StrideA_},
637 StrideScaleA{StrideScaleA_},
638 StrideB{StrideB_},
639 StrideScaleB{StrideScaleB_},
640 StrideC{StrideC_},
641 KBatch{KBatch_},
644 KRead{CalculateKRead(K_, KBatch_)},
645 KPadded{CalculateKPadded(K_, KBatch_)},
646 AK0{CalculateAK0Padded(K_, KBatch_)},
647 BK0{CalculateBK0Padded(K_, KBatch_)},
650 {
651 }
652
653 __host__ void Print() const
654 {
655 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
656 << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
657 << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
658 << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
659 << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
660 << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
661 << ", " << "NBlock: " << NBlock << "}" << std::endl;
662 }
663
681 };
682
683 // Argument
685 {
686 __host__ Argument(const ADataType* p_a_grid_,
687 const AScaleDataType* p_a_scale_grid_,
688 const BDataType* p_b_grid_,
689 const BScaleDataType* p_b_scale_grid_,
690 CDataType* p_c_grid_,
691 index_t M_,
692 index_t N_,
693 index_t K_,
694 index_t StrideA_,
695 index_t StrideScaleA_,
696 index_t StrideB_,
697 index_t StrideScaleB_,
698 index_t StrideC_,
699 index_t k_batch_,
700 AElementwiseOperation a_element_op_,
701 BElementwiseOperation b_element_op_,
702 CElementwiseOperation c_element_op_,
703 bool is_reduce_ = false)
704 : Problem{M_,
705 N_,
706 K_ / APackedSize,
707 StrideA_ / APackedSize,
708 StrideScaleA_,
709 StrideB_ / BPackedSize,
710 StrideScaleB_,
711 StrideC_,
712 k_batch_},
713 p_a_grid{p_a_grid_},
714 p_a_scale_grid{p_a_scale_grid_},
715 p_b_grid{p_b_grid_},
716 p_b_scale_grid{p_b_scale_grid_},
717 p_c_grid{p_c_grid_},
718 a_element_op{a_element_op_},
719 b_element_op{b_element_op_},
720 c_element_op{c_element_op_},
721 is_reduce(is_reduce_)
722 {
723 }
724
725 __host__ __device__ inline bool IsReduceAdd() const
726 {
727 return (Problem::KBatch > 1) && is_reduce;
728 }
729
730 __host__ __device__ inline bool IsAtomicAdd() const
731 {
732 return (Problem::KBatch > 1) && (!is_reduce);
733 }
734
735 const ADataType* p_a_grid;
736 const AScaleDataType* p_a_scale_grid;
737 const BDataType* p_b_grid;
738 const BScaleDataType* p_b_scale_grid;
739 CDataType* p_c_grid;
740
741 const AElementwiseOperation a_element_op;
742 const BElementwiseOperation b_element_op;
743 const CElementwiseOperation c_element_op;
745 };
746
748 {
749
750 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
751 {
753 {
754 a_k_split_offset = k_id * karg.KRead;
755 }
757 {
758 a_k_split_offset = k_id * karg.KRead * karg.StrideA;
759 }
760
762 {
763 b_k_split_offset = k_id * karg.KRead * karg.StrideB;
764 }
766 {
767 if constexpr(!PermuteB)
768 {
769 b_k_split_offset = k_id * karg.KRead;
770 }
771 else
772 {
773 const int k0_offset = karg.KRead * karg.N;
774 b_k_split_offset = k_id * k0_offset;
775 }
776 }
777
778 // Calculate A scale offset
780 k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack * MPerXdl;
781
782 // Calculate B scale offset
784 k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack * NPerXdl;
785
786 if(k_id < (karg.KBatch - 1))
787 {
788 karg.K = karg.KRead;
789 }
790 else
791 {
792 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
793 }
794
795 if(karg.IsReduceAdd())
796 {
797 c_reduce_offset = k_id * karg.M * karg.N;
798 }
799 else
800 {
801 c_reduce_offset = 0;
802 }
803 }
804
807 index_t a_scale_k_split_offset; // New member for scale matrix offset
808 index_t b_scale_k_split_offset; // New member for scale matrix offset
810 };
811
812 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
813 {
814 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
815 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
816 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
817
818 // A matrix in LDS memory, dst of blockwise copy
819 if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
820 {
821 // contiguous in LDS
825 }
826 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
827 // in some cases.
829 {
830 constexpr auto a_lds_block_desc =
833
834 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
835 a_lds_block_desc,
841
842 return a_lds_block_desc_permuted;
843 }
844 else // ColumnMajor A
845 {
846 // kfold and mpair dimension is not always required.
847 // more dimension in merge_transform increase the difficulty of generating immarg offset
848 // for compiler.
849 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
850 constexpr auto M1 = MPerBlock / M0;
851
852 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
853 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
854 constexpr auto KThreadRead = WaveSize / MPerXdl;
855 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
856
857 constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
858 ? 1
859 : 128 / (AK1Number * M0 * sizeof(ADataType));
860 constexpr auto KThreadReadPerm =
861 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
862 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
863 : KThreadRead;
864
865 // 1<=mpair<=n0
866 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
867 ? 1
868 : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
869 ? M0
870 : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
871
872 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
876 Number<kfold * M0 / mpair>{},
878 AK1Number));
879
880 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
881 a_lds_block_desc,
886 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
893
894 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
895 a_lds_block_desc_permuted,
904 Sequence<1>{},
905 Sequence<2>{},
906 Sequence<3>{},
907 Sequence<4>{},
908 Sequence<5>{}),
910 Sequence<2>{},
913 Sequence<6>{},
914 Sequence<7>{}));
915
916 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
917 a_lds_block_desc_unmerged,
920 Number<KThreadWrite / kfold / KThreadReadPerm>{},
928
929 return a_lds_block_desc_ak0_m_ak1;
930 }
931 }
932
933 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
934 {
935 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
936 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
937 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
938 // B matrix in LDS memory, dst of blockwise copy
939 if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
940 {
941 // contiguous in lds
945 }
947 {
948 // NLdsLayer * K0 as logical Bank
949 constexpr auto b_lds_block_desc =
952
953 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
954 b_lds_block_desc,
960
961 return b_lds_block_desc_permuted;
962 }
963 else // RowMajor B
964 {
965 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
966 constexpr auto N1 = NPerBlock / N0;
967
968 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
969 constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
970 constexpr auto KThreadRead = WaveSize / NPerXdl;
971 constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
972
973 constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
974 ? 1
975 : 128 / (BK1Number * N0 * sizeof(BDataType));
976 constexpr auto KThreadReadPerm =
977 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
978 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
979 : KThreadRead;
980
981 // 1<=npair<=n0
982 constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
983 ? 1
984 : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
985 ? N0
986 : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
987
988 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
992 Number<kfold * N0 / npair>{},
994 BK1Number));
995
996 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
997 b_lds_block_desc,
1002 make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1005 make_tuple(
1007 make_tuple(
1009
1010 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1011 b_lds_block_desc_permuted,
1012 make_tuple(
1020 Sequence<1>{},
1021 Sequence<2>{},
1022 Sequence<3>{},
1023 Sequence<4>{},
1024 Sequence<5>{}),
1026 Sequence<2>{},
1029 Sequence<6>{},
1030 Sequence<7>{}));
1031
1032 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1033 b_lds_block_desc_unmerged,
1036 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1037 Number<kfold>{},
1044
1045 return b_lds_block_desc_bk0_n_bk1;
1046 }
1047 }
1048
1050 {
1051 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1052 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1053
1054 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1056 make_tuple(I1,
1058 I1,
1060
1061 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1062 }
1063
1066 BlkGemmPipelineVer,
1067 BlkGemmPipeSched,
1068 BlockSize,
1069 ScaleBlockSize,
1070 ADataType,
1071 AScaleDataType,
1072 BDataType,
1073 BScaleDataType,
1074 ComputeTypeA,
1075 AccDataType,
1082 ABlockTransferSrcScalarPerVector,
1083 BBlockTransferSrcScalarPerVector,
1084 MPerBlock,
1085 NPerBlock,
1086 KPerBlock,
1087 MPerXdl,
1088 NPerXdl,
1089 MXdlPerWave,
1090 NXdlPerWave,
1091 KPack>())>;
1092
// Returns the LDS bytes one buffer must hold. This is the larger of two sizes:
// (A tile + B tile) and the C-shuffle staging tile. The kernels allocate one such buffer
// (two for the 2-LDS kernel), and the A/B stage and the C-shuffle stage reuse it.
1093 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1094 {
1095 // LDS allocation for A and B: be careful of alignment
1096 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1097 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1098
1099 // lds max alignment
1100 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1101
 // Tile sizes are rounded up in elements (not bytes) to a multiple of lcm(AK1, BK1)
 // before being scaled by the element size below.
1102 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1103 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1104
1105 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1106 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1107
1108 // LDS allocation for C shuffle in LDS
1109 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1111
1112 constexpr auto c_block_size =
1113 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1114
 // A and B are resident together, so their byte sizes add. The C-shuffle stage reuses
 // the same memory, so only the maximum is needed.
1115 return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1116 b_block_space_size_aligned * sizeof(BDataType)),
1117 c_block_size * sizeof(CShuffleDataType));
1118 }
1119
1121
1122 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
1123 __host__ static constexpr bool CheckValidity(const Argument& karg)
1124 {
1125 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1126 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1127 "Invalid tuning param!");
1128
1129 static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1130 "KPerBlock should be multiple of ScaleBlockSize");
1131
1132 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1137 {
1138 if(!(karg.M % MPerBlock == 0))
1139 {
1140 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1141 {
1142 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1143 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1144 << std::endl;
1145 }
1146 return false;
1147 }
1148 }
1149
1150 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1155 {
1156 if(!(karg.N % NPerBlock == 0))
1157 {
1158 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1159 {
1160 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1161 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1162 << std::endl;
1163 }
1164 return false;
1165 }
1166 }
1167
1168 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1172 {
1173 auto K_t = karg.KBatch * KPerBlock;
1174 if(!(karg.K % K_t == 0))
1175 {
1176 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1177 {
1178 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1179 << karg.K << " " << __FILE__ << ":" << __LINE__
1180 << ", in function: " << __func__ << std::endl;
1181 }
1182 return false;
1183 }
1184 }
1185 else
1186 {
1187 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1188 auto K_t = karg.KBatch * KReadVec;
1189 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1190 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1191 {
1192 return false;
1193 }
1194 }
1195
1197 {
1198 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1199 {
1200 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1201 {
1202 std::cout << "Arg K (" << karg.K
1203 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1204 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1205 << __LINE__ << ", in function: " << __func__ << std::endl;
1206 }
1207 return false;
1208 }
1209 }
1210 else
1211 {
1212 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1213 {
1214 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1215 {
1216 std::cout << "Arg M (" << karg.M
1217 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1218 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1219 << __LINE__ << ", in function: " << __func__ << std::endl;
1220 }
1221 return false;
1222 }
1223 }
1224
1226 {
1227 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1228 {
1229 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1230 {
1231 std::cout << "Arg N (" << karg.N
1232 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1233 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1234 << __LINE__ << ", in function: " << __func__ << std::endl;
1235 }
1236 return false;
1237 }
1238 }
1239 else
1240 {
1241 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1242 {
1243 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1244 {
1245 std::cout << "Arg K (" << karg.K
1246 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1247 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1248 << __LINE__ << ", in function: " << __func__ << std::endl;
1249 }
1250 return false;
1251 }
1252 }
1253
1255 {
1256 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1257 {
1258 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1259 {
1260 std::cout << "Arg N (" << karg.N
1261 << ") value is not a multiple of "
1262 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1263 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1264 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1265 << std::endl;
1266 }
1267 return false;
1268 }
1269 }
1270 else
1271 {
1272 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1273 {
1274 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1275 {
1276 std::cout << "Arg M (" << karg.M
1277 << ") value is not a multiple of "
1278 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1279 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1280 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1281 << std::endl;
1282 }
1283 return false;
1284 }
1285 }
1286
1291 {
1292 if(!karg.IsReduceAdd())
1293 {
1294 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1295 {
1296 std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
1297 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1298 }
1299 if(karg.KBatch > 1)
1300 {
1301 return false;
1302 }
1303 }
1304 }
1305#if 0
1306 // check gridwise gemm pipeline
1307 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1308
1309 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1310 {
1311 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1312 {
1313 return false;
1314 }
1315 }
1316#endif
1317 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1318 return true;
1319 }
1320
1321 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1322 {
1323 const index_t num_loop = K / KPerBlock;
1324
1325 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1326 }
1327
1328 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1329 {
1330 const index_t num_loop = K / KPerBlock;
1331
1332 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1333 }
1334
// Re-expresses a 2-D C grid descriptor (M, N) as a 4-D descriptor indexed by
// (MBlock, MPerBlock, NBlock, NPerBlock) using transform_tensor_descriptor.
// NOTE(review): the exact transform list is not visible here; presumably M is unmerged into
// (MBlock, MPerBlock) and N into (NBlock, NPerBlock) — confirm against the full source.
1335 template <typename CGridDesc>
1336 __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1337 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1338 {
1339 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1340 c_grid_desc_m_n,
1345
1346 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1347 }
1348
1349 // return block_id to C matrix tile idx (m0, n0) mapping
1350 // if arch = gfx942
1352 // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1353
// Number of mx_scale_t scale values packed into one AScaleDataType / BScaleDataType
// element (ratio of the type sizes; integer division).
1355 static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1356 static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
// Each thread copies KXdlPack * M/NXdlPack scales per access (see the scale thread copies in
// Run), so that count must be a whole number of packed scale elements. Note `*` and `%` have
// equal precedence and associate left-to-right, i.e. (KXdlPack * MXdlPack) % scale_pack_size_a.
1357 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1358 "A scale pack data type too large!");
1359 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1360 "B scale pack data type too large!");
1361
1364 "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
1365
// Device-side GEMM body for one workgroup, given pre-built grid descriptors.
//
// - Maps get_block_1d_id() to an (m, n) C tile through Block2CTileMap and returns early when
//   that tile index is not valid for the C grid descriptor.
// - Sets up blockwise copies that stage A/B tiles from global memory into LDS (p_shared) and
//   threadwise copies for the A/B scale tensors, then hands everything to
//   BlockwiseGemmPipe::Run, which accumulates into c_thread_buf.
// - Epilogue: shuffles the accumulators through LDS and writes them to p_c_grid using
//   CGlobalMemoryDataOperation, one space-filling-curve access at a time.
1366 template <typename AGridDesc_AK0_M_K1,
1367 typename AScaleGridDesc_AM_AK,
1368 typename BGridDesc_BK0_N_K1,
1369 typename BScaleGridDesc_BN_AK,
1370 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1371 bool HasMainKBlockLoop,
1372 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1373 TailNumber TailNum = TailNumber::Odd>
1374 __device__ static void Run(const ADataType* p_a_grid,
1375 const AScaleDataType* p_a_scale_grid,
1376 const BDataType* p_b_grid,
1377 const BScaleDataType* p_b_scale_grid,
1378 CDataType* p_c_grid,
1379 void* p_shared,
1380 const Problem& problem,
1381 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1382 const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak,
1383 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1384 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1385 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1386 c_grid_desc_mblock_mperblock_nblock_nperblock)
1387 {
1388 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1389 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1390 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1391 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1393 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1394
1395 // A Scale buffer
1396 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1397 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1398
1399 // B Scale buffer
1400 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1401 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1402
1403 const CElementwiseOperation c_element_op{};
1404
1405 // divide block work by [M, N]
1406 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1407
1408 const auto block_work_idx =
1409 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1410
1411 if(!block_2_ctile_map.ValidCTileIndex(
1412 block_work_idx,
1413 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1414 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1415 {
1416 return;
1417 }
1418
1419 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1420 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1421
1422 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1423 const index_t m_block_data_idx_on_grid =
1424 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1425
1426 const index_t n_block_data_idx_on_grid =
1427 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1428
1429 // lds max alignment
1430 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1431
1432 // A matrix in LDS memory, dst of blockwise copy
1433 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1434
1435 // B matrix in LDS memory, dst of blockwise copy
1436 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1437
1438 auto a_blockwise_copy =
1441 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1442 ABlockTransferThreadClusterArrangeOrder,
1443 ADataType,
1444 ADataType,
1445 decltype(a_grid_desc_ak0_m_ak1),
1446 decltype(a_block_desc_ak0_m_ak1),
1447 ABlockTransferSrcAccessOrder,
1448 ABlockTransferSrcVectorDim,
1449 2,
1450 ABlockTransferSrcScalarPerVector>(
1451 a_grid_desc_ak0_m_ak1,
1452 make_multi_index(0, m_block_data_idx_on_grid, 0),
1453 a_block_desc_ak0_m_ak1,
1454 make_multi_index(0, 0, 0));
1455
1456 // B matrix blockwise copy
1457 auto b_blockwise_copy =
1460 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1461 BBlockTransferThreadClusterArrangeOrder,
1462 BDataType,
1463 BDataType,
1464 decltype(b_grid_desc_bk0_n_bk1),
1465 decltype(b_block_desc_bk0_n_bk1),
1466 BBlockTransferSrcAccessOrder,
1467 BBlockTransferSrcVectorDim,
1468 2,
1469 BBlockTransferSrcScalarPerVector>(
1470 b_grid_desc_bk0_n_bk1,
1471 make_multi_index(0, n_block_data_idx_on_grid, 0),
1472 b_block_desc_bk0_n_bk1,
1473 make_multi_index(0, 0, 0));
1474
1475 // LDS allocation for A and B: be careful of alignment
1476 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1477 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1478
1479 // Cast after lds
1481 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1482
1484 reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1485 a_block_space_size_aligned * sizeof(ADataType)),
1486 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1487
1488 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1489 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1490
1491 // Blockwise GEMM pipeline
1492 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1493 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1494 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1495
1496 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1497 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1498 KPerBlock);
1499
1500 // Initial thread mapping for:
1501 // BlockSize = 256
1502 // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2
1503 // For each [m0, n0] tile, there are 4 waves:
1504 // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0]
1505 // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1]
1506 // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0]
1507 // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1]
1508
1509 // BlockSize = 128
1510 // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1
1511 // For each [m0, n0] tile, there are 2 waves:
1512 // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0]
1513 // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0]
1514
1515 // TODO: Document initial thread mapping for more combinations of parameters
1516
1517 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1518 const auto waveId_m = wave_idx[I0];
1519 const auto waveId_n = wave_idx[I1];
1520
1521 // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
1522
1523 // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
1524 // mfma.selected_mfma.num_threads_per_blk;
1525
1526 // A wave access continuous memory
// Per-lane start offset for the A scales: each lane owns KXdlPack * MXdlPack scale values.
1527 auto thread_offset_shuffled =
1528 get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1529
1530 auto a_thread_offset_m = waveId_m;
1531
1532 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1533 AScaleDataType,
1534 AScaleDataType,
1535 decltype(a_scale_grid_desc_am_ak),
1536 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1537 Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1538 Sequence<0, 1, 2>, // DimAccessOrder
1539 2, // SrcVectorDim
1540 KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1541 1, // SrcScalarStrideInVector
1542 true>(a_scale_grid_desc_am_ak,
1543 make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1544 0,
1545 thread_offset_shuffled / scale_pack_size_a));
1546
1547 auto b_thread_offset_n = waveId_n;
// Per-lane start offset for the B scales. Each lane copies KXdlPack * NXdlPack scale values
// (see SliceLengths below), so the stride must use NXdlPack, not the A-side MXdlPack. The
// two are identical only when MXdlPack == NXdlPack.
auto b_thread_offset_shuffled =
    get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * NXdlPack;
1548
1549 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1550 BScaleDataType,
1551 BScaleDataType,
1552 decltype(b_scale_grid_desc_bn_ak),
1553 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1554 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1555 Sequence<0, 1, 2>, // DimAccessOrder
1556 2, // SrcVectorDim
1557 KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector (must match SliceLengths[2])
1558 1, // SrcScalarStrideInVector
1559 true>(b_scale_grid_desc_bn_ak,
1560 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1561 0,
1562 b_thread_offset_shuffled / scale_pack_size_b));
1563
1564 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1565 a_block_desc_ak0_m_ak1,
1566 a_blockwise_copy,
1567 a_grid_buf,
1568 a_block_buf,
1569 a_block_slice_copy_step,
1570 b_grid_desc_bk0_n_bk1,
1571 b_block_desc_bk0_n_bk1,
1572 b_blockwise_copy,
1573 b_grid_buf,
1574 b_block_buf,
1575 b_block_slice_copy_step,
1576 c_thread_buf,
1577 a_scale_grid_desc_am_ak,
1578 a_scale_thread_copy,
1579 a_scale_grid_buf,
1580 b_scale_grid_desc_bn_ak,
1581 b_scale_thread_copy,
1582 b_scale_grid_buf,
1583 num_k_block_main_loop);
1584
1585 // shuffle C and write out
1586 {
1587 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1588 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1589 "wrong!");
1590 static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1591 CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1592 "wrong!");
1593
1594 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1595 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1596
1597 // TODO: hacky, fix it!
1598 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1599 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1600
1601 // TODO: hacky, fix it!
1602 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1603 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1604 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1605
1606 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1607 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1608 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1609 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1610 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1611 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1612 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1613 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1614 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1615 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1616
1617 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1619
// NOTE: the C-shuffle staging buffer is created on the same p_shared pointer that backed
// the A tile buffer above, i.e. it reuses the A/B LDS region after the main loop.
1620 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1621 static_cast<CShuffleDataType*>(p_shared),
1622 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1623
1624 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1625 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1626 make_tuple(
1629 Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
1630 // shuffle
1631 M1, // M1 = MWave
1632 M2, // M2 = MXdlPack
1633 M3, // M3 * M4 * M5 = MPerXdl
1634 M4,
1635 M5)),
1638 Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave) per
1639 // shuffle
1640 N1, // N1 = NWave
1641 N2, // N2 = NXdlPack
1642 N3))), // N3 = NPerXdl
1646 Sequence<>{},
1648
1649 // calculate origin of thread output tensor on global memory
1650 // blockwise GEMM c matrix starting index
1651 const auto c_thread_mtx_on_block =
1652 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1653
1654 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1655 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1656
1657 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1659 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1662
1663 const auto m_thread_data_on_block_idx =
1664 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1665 make_multi_index(m_thread_data_on_block));
1666
1667 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1669 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1672
1673 const auto n_thread_data_on_block_idx =
1674 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1675 make_multi_index(n_thread_data_on_block));
1676
1677 // shuffle: threadwise copy C from VGPR to LDS
1678 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1679 AccDataType,
1680 CShuffleDataType,
1681 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1682 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1684 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1685 CShuffleNXdlPerWavePerShuffle / NXdlPack,
1686 I1,
1687 I1,
1688 M2,
1689 N2,
1690 M3,
1691 I1,
1692 M5,
1693 I1>,
1695 9,
1696 1,
1698 1,
1699 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1701 0,
1702 m_thread_data_on_block_idx[I1],
1703 n_thread_data_on_block_idx[I1],
1704 m_thread_data_on_block_idx[I2],
1705 n_thread_data_on_block_idx[I2],
1706 m_thread_data_on_block_idx[I3],
1707 m_thread_data_on_block_idx[I4],
1708 m_thread_data_on_block_idx[I5],
1709 n_thread_data_on_block_idx[I3]),
1711
1712 // shuffle: blockwise copy C from LDS to global
1713 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1714 ThisThreadBlock, // ThreadGroup
1715 CElementwiseOperation, // ElementwiseOperation,
1716 CGlobalMemoryDataOperation, // DstInMemOp,
1717 Sequence<1,
1718 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1719 1,
1720 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1721 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1722 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1723 CShuffleDataType, // typename SrcData,
1724 CDataType, // typename DstData,
1725 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1726 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1727 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1728 3, // index_t VectorDim,
1729 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1730 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1731 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1732 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1733 make_multi_index(0, 0, 0, 0),
1734 c_grid_desc_mblock_mperblock_nblock_nperblock,
1735 make_multi_index(block_m_id, 0, block_n_id, 0),
1736 c_element_op};
1737
1738 // space filling curve for threadwise C in VGPR
1739 constexpr auto sfc_c_vgpr =
1740 SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
1741 NXdlPerWave / NXdlPack,
1742 1,
1743 1,
1744 MXdlPack,
1745 NXdlPack,
1746 M2,
1747 1,
1748 M4,
1749 1>,
1751 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1752 CShuffleNXdlPerWavePerShuffle / NXdlPack,
1753 1,
1754 1,
1755 MXdlPack,
1756 NXdlPack,
1757 M2,
1758 1,
1759 M4,
1760 1>>{};
1761
1762 // space filling curve for shuffled blockwise C in global mem
1763 constexpr auto sfc_c_global =
1766 Sequence<1,
1767 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1768 1,
1769 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1770
1771 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1772
1773 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1774
1775 static_for<0, num_access, 1>{}([&](auto access_id) {
1776 // make sure it's safe to write to LDS
1778
1779 // each thread write its data from VGPR to LDS
1780 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1781 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1782 c_thread_buf,
1783 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1784 c_shuffle_block_buf);
1785
1786 // make sure it's safe to read from LDS
1788
1789 // each block copy its data from LDS to global
1790 c_shuffle_block_copy_lds_to_global.Run(
1791 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1792 c_shuffle_block_buf,
1793 c_grid_desc_mblock_mperblock_nblock_nperblock,
1794 c_grid_buf);
1795
1796 if constexpr(access_id < num_access - 1)
1797 {
1798 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1799
1800 // move on C
1801 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1802 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1803 }
1804 });
1805 }
1806 }
1807
// Convenience overload: builds the A/B/C grid descriptors and the A/B scale-tensor
// descriptors from `problem`, then forwards everything to the descriptor-taking Run overload.
1808 template <bool HasMainKBlockLoop,
1809 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1810 TailNumber TailNum = TailNumber::Odd>
1811 __device__ static void Run(const ADataType* p_a_grid,
1812 const AScaleDataType* p_a_scale_grid,
1813 const BDataType* p_b_grid,
1814 const BScaleDataType* p_b_scale_grid,
1815 CDataType* p_c_grid,
1816 void* p_shared,
1817 const Problem& problem)
1818 {
1819 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1820 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1821 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1822 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1823 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1824 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1825 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1827 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1828
1829 // A/B shuffled scale for better 8-bit scale access pattern
1830 // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
// A scale rows: M is first rounded up to a multiple of ScaleBlockSize, then divided by
// MXdlPack * MPerXdl. The row stride is derived from problem.K * problem.KBatch, presumably
// the full (un-split) K extent so split-K batches share one scale layout — TODO confirm.
1831 const auto Padded_Scale_M =
1832 math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
1833 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1834 make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
1835 math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1836 (KXdlPack * 64 / MPerXdl),
1838 make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
1839 (ScaleBlockSize / APackedSize)) *
1840 MPerXdl * MXdlPack / scale_pack_size_a,
1842 1));
1843
// NOTE(review): unlike the A path above, the B scale row count divides problem.N directly
// (no rounding up), so an N that is not a multiple of NXdlPack * NPerXdl is truncated by the
// integer division. Confirm that argument validation or padding rules this case out.
1844 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1845 make_tuple(problem.N / (NXdlPack * NPerXdl),
1846 math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1847 (KXdlPack * 64 / NPerXdl),
1849 make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
1850 (ScaleBlockSize / BPackedSize)) *
1851 NPerXdl * NXdlPack / scale_pack_size_b,
1853 1));
1854
1855 Run<decltype(a_grid_desc_ak0_m_ak1),
1856 decltype(a_scale_grid_desc_am_ak),
1857 decltype(b_grid_desc_bk0_n_bk1),
1858 decltype(b_scale_grid_desc_bn_ak),
1859 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1860 HasMainKBlockLoop,
1861 CGlobalMemoryDataOperation,
1862 TailNum>(p_a_grid,
1863 p_a_scale_grid,
1864 p_b_grid,
1865 p_b_scale_grid,
1866 p_c_grid,
1867 p_shared,
1868 problem,
1869 a_grid_desc_ak0_m_ak1,
1870 a_scale_grid_desc_am_ak,
1871 b_grid_desc_bk0_n_bk1,
1872 b_scale_grid_desc_bn_ak,
1873 c_grid_desc_mblock_mperblock_nblock_nperblock);
1874 }
1875
1876 template <typename AGridDesc_AK0_M_K1,
1877 typename AScaleGridDesc_AM_AK,
1878 typename BGridDesc_BK0_N_K1,
1879 typename BScaleGridDesc_BN_AK,
1880 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1881 bool HasMainKBlockLoop,
1882 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1883 TailNumber TailNum = TailNumber::Odd>
1884 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1885 const AScaleDataType* p_a_scale_grid,
1886 const BDataType* p_b_grid,
1887 const BScaleDataType* p_b_scale_grid,
1888 CDataType* p_c_grid,
1889 void* p_shared_0,
1890 void* p_shared_1,
1891 const Problem& problem,
1892 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1893 const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak,
1894 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1895 const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
1896 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1897 c_grid_desc_mblock_mperblock_nblock_nperblock)
1898 {
1899 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1900 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1901 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1902 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1904 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1905
1906 // A Scale buffer
1907 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1908 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1909
1910 // B Scale buffer
1911 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1912 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1913
1914 const CElementwiseOperation c_element_op{};
1915
1916 // divide block work by [M, N]
1917 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1918
1919 const auto block_work_idx =
1920 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1921
1922 if(!block_2_ctile_map.ValidCTileIndex(
1923 block_work_idx,
1924 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1925 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1926 {
1927 return;
1928 }
1929
1930 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1931 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1932
1933 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1934 const index_t m_block_data_idx_on_grid =
1935 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1936
1937 const index_t n_block_data_idx_on_grid =
1938 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1939
1940 // lds max alignment
1941 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1942
1943 // A matrix in LDS memory, dst of blockwise copy
1944 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1945
1946 // B matrix in LDS memory, dst of blockwise copy
1947 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1948
1949 auto a_blockwise_copy =
1952 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1953 ABlockTransferThreadClusterArrangeOrder,
1954 ADataType,
1955 ADataType,
1956 decltype(a_grid_desc_ak0_m_ak1),
1957 decltype(a_block_desc_ak0_m_ak1),
1958 ABlockTransferSrcAccessOrder,
1959 ABlockTransferSrcVectorDim,
1960 2,
1961 ABlockTransferSrcScalarPerVector>(
1962 a_grid_desc_ak0_m_ak1,
1963 make_multi_index(0, m_block_data_idx_on_grid, 0),
1964 a_block_desc_ak0_m_ak1,
1965 make_multi_index(0, 0, 0));
1966
1967 // B matrix blockwise copy
1968 auto b_blockwise_copy =
1971 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1972 BBlockTransferThreadClusterArrangeOrder,
1973 BDataType,
1974 BDataType,
1975 decltype(b_grid_desc_bk0_n_bk1),
1976 decltype(b_block_desc_bk0_n_bk1),
1977 BBlockTransferSrcAccessOrder,
1978 BBlockTransferSrcVectorDim,
1979 2,
1980 BBlockTransferSrcScalarPerVector>(
1981 b_grid_desc_bk0_n_bk1,
1982 make_multi_index(0, n_block_data_idx_on_grid, 0),
1983 b_block_desc_bk0_n_bk1,
1984 make_multi_index(0, 0, 0));
1985
1986 // LDS allocation for A and B: be careful of alignment
1987 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1988 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1989
1990 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1991 static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1992
1993 auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1994 bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
1995 a_block_space_size_aligned * sizeof(ADataType)),
1996 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1997
1998 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1999 static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2000
2001 auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2003 a_block_space_size_aligned * sizeof(ADataType)),
2004 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2005
2006 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2007 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2008
2009 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2010 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
2011
2012 // Blockwise GEMM pipeline
2013 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2014 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2015 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2016
2017 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2018 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2019 KPerBlock);
2020
2021 // Initial thread mapping for:
2022 // BlockSize = 256
2023 // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2
2024 // For each [m0, n0] tile, there are 4 waves:
2025 // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0]
2026 // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1]
2027 // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0]
2028 // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1]
2029
2030 // BlockSize = 128
2031 // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1
2032 // For each [m0, n0] tile, there are 2 waves:
2033 // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0]
2034 // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0]
2035
2036 // TODO: Document initial thread mapping for more combinations of parameters
2037
2038 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2039 const auto waveId_m = wave_idx[I0];
2040 const auto waveId_n = wave_idx[I1];
2041
2042 // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma;
2043
2044 // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) /
2045 // mfma.selected_mfma.num_threads_per_blk;
2046
2047 // A wave access continuous memory
2048 auto thread_offset_shuffled =
2049 get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2050
2051 auto a_thread_offset_m = waveId_m;
2052
2053 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2054 AScaleDataType,
2055 AScaleDataType,
2056 decltype(a_scale_grid_desc_am_ak),
2057 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2058 Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2059 Sequence<0, 1, 2>, // DimAccessOrder
2060 2, // SrcVectorDim
2061 KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2062 1, // SrcScalarStrideInVector
2063 true>(a_scale_grid_desc_am_ak,
2064 make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2065 0,
2066 thread_offset_shuffled / scale_pack_size_a));
2067
2068 auto b_thread_offset_n = waveId_n;
2069
2070 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2071 BScaleDataType,
2072 BScaleDataType,
2073 decltype(b_scale_grid_desc_bn_ak),
2074 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2075 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2076 Sequence<0, 1, 2>, // DimAccessOrder
2077 2, // SrcVectorDim
2078 KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
2079 1, // SrcScalarStrideInVector
2080 true>(b_scale_grid_desc_bn_ak,
2081 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2082 0,
2083 thread_offset_shuffled / scale_pack_size_b));
2084
2085 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
2086 a_block_desc_ak0_m_ak1,
2087 a_blockwise_copy,
2088 a_grid_buf,
2089 a_block_bufs,
2090 a_block_slice_copy_step,
2091 b_grid_desc_bk0_n_bk1,
2092 b_block_desc_bk0_n_bk1,
2093 b_blockwise_copy,
2094 b_grid_buf,
2095 b_block_bufs,
2096 b_block_slice_copy_step,
2097 c_thread_buf,
2098 a_scale_grid_desc_am_ak,
2099 a_scale_thread_copy,
2100 a_scale_grid_buf,
2101 b_scale_grid_desc_bn_ak,
2102 b_scale_thread_copy,
2103 b_scale_grid_buf,
2104 num_k_block_main_loop);
2105
2106 // shuffle C and write out
2107 {
2108 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2109 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2110 "wrong!");
2111 static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
2112 CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
2113 "wrong!");
2114
2115 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2116 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2117
2118 // TODO: hacky, fix it!
2119 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2120 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2121
2122 // TODO: hacky, fix it!
2123 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2124 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2125 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
2126
2127 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2128 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2129 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2130 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2131 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2132 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2133 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2134 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2135 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
2136 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
2137
2138 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2140
2141 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2142 static_cast<CShuffleDataType*>(p_shared_0),
2143 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2144
2145 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2146 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2147 make_tuple(
2150 Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{}, // M0 (MXdlPerWave) per
2151 // shuffle
2152 M1, // M1 = MWave
2153 M2, // M2 = MXdlPack
2154 M3, // M3 * M4 * M5 = MPerXdl
2155 M4,
2156 M5)),
2159 Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{}, // N0 (NXdlPerWave) per
2160 // shuffle
2161 N1, // N1 = NWave
2162 N2, // N2 = NXdlPack
2163 N3))), // N3 = NPerXdl
2167 Sequence<>{},
2169
2170 // calculate origin of thread output tensor on global memory
2171 // blockwise GEMM c matrix starting index
2172 const auto c_thread_mtx_on_block =
2173 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2174
2175 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2176 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2177
2178 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2180 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
2183
2184 const auto m_thread_data_on_block_idx =
2185 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2186 make_multi_index(m_thread_data_on_block));
2187
2188 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2190 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
2193
2194 const auto n_thread_data_on_block_idx =
2195 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2196 make_multi_index(n_thread_data_on_block));
2197
2198 // shuffle: threadwise copy C from VGPR to LDS
2199 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
2200 AccDataType,
2201 CShuffleDataType,
2202 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2203 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2205 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2206 CShuffleNXdlPerWavePerShuffle / NXdlPack,
2207 I1,
2208 I1,
2209 M2,
2210 N2,
2211 M3,
2212 I1,
2213 M5,
2214 I1>,
2216 9,
2217 1,
2219 1,
2220 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2222 0,
2223 m_thread_data_on_block_idx[I1],
2224 n_thread_data_on_block_idx[I1],
2225 m_thread_data_on_block_idx[I2],
2226 n_thread_data_on_block_idx[I2],
2227 m_thread_data_on_block_idx[I3],
2228 m_thread_data_on_block_idx[I4],
2229 m_thread_data_on_block_idx[I5],
2230 n_thread_data_on_block_idx[I3]),
2232
2233 // shuffle: blockwise copy C from LDS to global
2234 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
2235 ThisThreadBlock, // ThreadGroup
2236 CElementwiseOperation, // ElementwiseOperation,
2237 CGlobalMemoryDataOperation, // DstInMemOp,
2238 Sequence<1,
2239 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2240 1,
2241 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2242 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2243 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2244 CShuffleDataType, // typename SrcData,
2245 CDataType, // typename DstData,
2246 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2247 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2248 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2249 3, // index_t VectorDim,
2250 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
2251 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
2252 false> // bool ThreadTransferDstResetCoordinateAfterRun>
2253 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2254 make_multi_index(0, 0, 0, 0),
2255 c_grid_desc_mblock_mperblock_nblock_nperblock,
2256 make_multi_index(block_m_id, 0, block_n_id, 0),
2257 c_element_op};
2258
2259 // space filling curve for threadwise C in VGPR
2260 constexpr auto sfc_c_vgpr =
2261 SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2262 NXdlPerWave / NXdlPack,
2263 1,
2264 1,
2265 MXdlPack,
2266 NXdlPack,
2267 M2,
2268 1,
2269 M4,
2270 1>,
2272 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2273 CShuffleNXdlPerWavePerShuffle / NXdlPack,
2274 1,
2275 1,
2276 MXdlPack,
2277 NXdlPack,
2278 M2,
2279 1,
2280 M4,
2281 1>>{};
2282
2283 // space filling curve for shuffled blockwise C in global mem
2284 constexpr auto sfc_c_global =
2287 Sequence<1,
2288 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2289 1,
2290 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2291
2292 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2293
2294 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
2295
2296 static_for<0, num_access, 1>{}([&](auto access_id) {
2297 // make sure it's safe to write to LDS
2299
2300 // each thread write its data from VGPR to LDS
2301 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2302 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2303 c_thread_buf,
2304 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2305 c_shuffle_block_buf);
2306
2307 // make sure it's safe to read from LDS
2309
2310 // each block copy its data from LDS to global
2311 c_shuffle_block_copy_lds_to_global.Run(
2312 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2313 c_shuffle_block_buf,
2314 c_grid_desc_mblock_mperblock_nblock_nperblock,
2315 c_grid_buf);
2316
2317 if constexpr(access_id < num_access - 1)
2318 {
2319 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2320
2321 // move on C
2322 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2323 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2324 }
2325 });
2326 }
2327 }
2328
// Run_2Lds (Problem overload): convenience entry point that builds every grid
// descriptor the GEMM needs from the runtime Problem (sizes, padded sizes,
// strides, AK0/BK0, MBlock/NBlock, KBatch) and forwards everything to the
// descriptor-taking Run_2Lds overload.
//
// Template parameters (all forwarded unchanged to the descriptor overload):
//   HasMainKBlockLoop          - selects the main-K-loop variant of the pipeline.
//   CGlobalMemoryDataOperation - how C is written to global memory.
//   TailNum                    - pipeline tail selector; defaults to TailNumber::Odd.
// Parameters:
//   p_a_grid, p_b_grid             - A and B operand pointers.
//   p_a_scale_grid, p_b_scale_grid - A and B scale pointers.
//   p_c_grid                       - output C pointer.
//   p_shared_0, p_shared_1         - the two LDS buffers used by the 2-LDS path;
//                                    passed through untouched.
//   problem                        - runtime problem description used to build
//                                    the descriptors below.
2329 template <bool HasMainKBlockLoop,
2330 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2331 TailNumber TailNum = TailNumber::Odd>
2332 __device__ static void Run_2Lds(const ADataType* p_a_grid,
2333 const AScaleDataType* p_a_scale_grid,
2334 const BDataType* p_b_grid,
2335 const BScaleDataType* p_b_scale_grid,
2336 CDataType* p_c_grid,
2337 void* p_shared_0,
2338 void* p_shared_1,
2339 const Problem& problem)
2340 {
// Grid descriptors for A, B and C built from the problem's raw sizes,
// padded sizes and strides.
2341 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2342 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2343 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2344 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2345 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
2346 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
// Re-view C in (MBlock, MPerBlock, NBlock, NPerBlock) form, using the block
// counts stored in the problem.
2347 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2349 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2350
2351 // A/B shuffled scale for better 8-bit scale access pattern
2352 // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
// M rounded up to the next multiple of ScaleBlockSize.
2353 const auto Padded_Scale_M =
2354 math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize;
// Strided (naive) descriptor over the A scales.
// - The lengths divide the K-direction scale count by (KXdlPack * 64 / MPerXdl).
//   NOTE(review): the literal 64 appears to be the wavefront size; confirm, and
//   consider a named constant.
// - The lengths use problem.K, but the first stride uses problem.K * problem.KBatch.
//   NOTE(review): this suggests problem.K is the per-split K here; confirm
//   against the caller.
2355 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
2356 make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl),
2357 math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2358 (KXdlPack * 64 / MPerXdl),
2360 make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2361 (ScaleBlockSize / APackedSize)) *
2362 MPerXdl * MXdlPack / scale_pack_size_a,
2364 1));
2365
// Strided (naive) descriptor over the B scales, mirroring the A-scale layout.
// NOTE(review): unlike A, which uses the rounded-up Padded_Scale_M, the first
// length here is problem.N / (NXdlPack * NPerXdl) with truncating integer
// division. Confirm that N is always a multiple of NXdlPack * NPerXdl on this
// path; otherwise the trailing N rows of scales are not covered.
2366 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
2367 make_tuple(problem.N / (NXdlPack * NPerXdl),
2368 math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2369 (KXdlPack * 64 / NPerXdl),
2371 make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch,
2372 (ScaleBlockSize / BPackedSize)) *
2373 NPerXdl * NXdlPack / scale_pack_size_b,
2375 1));
2376
// Hand off to the descriptor-taking overload. The descriptor types are
// supplied via decltype, so that overload is instantiated for exactly the
// descriptor types built above. Pointers, both LDS buffers and the problem
// are forwarded unchanged.
2377 Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
2378 decltype(a_scale_grid_desc_am_ak),
2379 decltype(b_grid_desc_bk0_n_bk1),
2380 decltype(b_scale_grid_desc_bn_ak),
2381 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2382 HasMainKBlockLoop,
2383 CGlobalMemoryDataOperation,
2384 TailNum>(p_a_grid,
2385 p_a_scale_grid,
2386 p_b_grid,
2387 p_b_scale_grid,
2388 p_c_grid,
2389 p_shared_0,
2390 p_shared_1,
2391 problem,
2392 a_grid_desc_ak0_m_ak1,
2393 a_scale_grid_desc_am_ak,
2394 b_grid_desc_bk0_n_bk1,
2395 b_scale_grid_desc_bn_ak,
2396 c_grid_desc_mblock_mperblock_nblock_nperblock);
2397 }
2398};
2399
2400} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
f6_pk_t< f6_t, 16 > f6x16_pk_t
Definition data_type.hpp:180
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
constexpr auto BlockGemmMXPipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp:36
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
f6_pk_t< bf6_t, 32 > bf6x32_pk_t
Definition data_type.hpp:183
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
f6_pk_t< f6_t, 32 > f6x32_pk_t
Definition data_type.hpp:181
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__global__ enable_if_t<!Use2LDS, void > kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:40
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr index_t packed_size_v
Definition data_type.hpp:411
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
f6_pk_t< bf6_t, 16 > bf6x16_pk_t
Definition data_type.hpp:182
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
signed int int32_t
Definition stdint.h:123
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:685
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:741
__host__ Argument(const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_, bool is_reduce_=false)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:686
__host__ __device__ bool IsAtomicAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:730
__host__ __device__ bool IsReduceAdd() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:725
const AScaleDataType * p_a_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:736
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:739
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:743
bool is_reduce
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:744
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:735
const BScaleDataType * p_b_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:738
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:737
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:742
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:653
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:669
index_t StrideScaleA
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:668
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:676
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:667
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:664
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:673
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:680
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:665
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:674
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:624
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:678
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:666
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:675
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:679
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:671
index_t StrideScaleB
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:670
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:672
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:677
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:806
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:750
index_t a_scale_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:807
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:805
index_t b_scale_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:808
index_t c_reduce_offset
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:809
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:156
static __device__ void Run(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1811
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1321
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:403
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:230
static __device__ void Run(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const AScaleGridDesc_AM_AK &a_scale_grid_desc_am_ak, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1374
static __device__ void Run_2Lds(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:2332
static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:812
__host__ static __device__ constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:264
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:224
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:204
static __host__ auto CalculateNBlock(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:254
static __host__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:214
static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1049
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1336
static __host__ auto CalculateKPadded(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:219
static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1093
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:242
static __device__ void Run_2Lds(const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const AScaleGridDesc_AM_AK &a_scale_grid_desc_am_ak, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const BScaleGridDesc_BN_AK &b_scale_grid_desc_bn_ak, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1884
__host__ static __device__ constexpr auto MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:543
remove_cvref_t< decltype(BlockGemmMXPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1064
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:562
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:288
static __host__ auto CalculateMBlock(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:249
__host__ static __device__ constexpr auto MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:553
static __host__ constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:1328
static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:933
static __host__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:209
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:236
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition thread_group_tensor_slice_transfer_direct_load.hpp:55
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition data_type.hpp:42
Definition type.hpp:177
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129