24 template <index_t ndim>
27 return number < i == ndim - 2 ? ndim - 1 : i == ndim - 1 ? ndim - 2 : i > {};
31 template <
typename Problem>
36 typename Problem::KDataType,
37 typename Problem::AccDataType,
40 Problem::BlockFmhaShape::kN0,
41 Problem::BlockFmhaShape::kK0>,
42 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
43 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
46 typename Problem::QDataType,
47 typename Problem::KDataType,
48 typename Problem::AccDataType,
49 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}),
50 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<1>{}),
51 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<2>{}),
53 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}) == 16 ?
false :
true>;
55 using BlockGemmPolicy =
57 typename Problem::KDataType,
58 typename Problem::AccDataType,
59 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
65 template <
typename Problem>
70 typename Problem::OGradDataType,
71 typename Problem::AccDataType,
74 Problem::BlockFmhaShape::kVHeaddim,
75 Problem::BlockFmhaShape::kK1>,
76 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
77 typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
81 typename Problem::OGradDataType,
82 typename Problem::AccDataType,
83 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<0>{}),
84 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<1>{}),
85 Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<2>{}),
89 (Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<2>{}) == 32)
90 ? WGAttrNumAccessEnum ::Double
91 : WGAttrNumAccessEnum ::Single>;
93 using BlockGemmPolicy =
95 typename Problem::OGradDataType,
96 typename Problem::AccDataType,
97 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
103 template <
typename Problem>
108 typename Problem::VDataType,
109 typename Problem::AccDataType,
112 Problem::BlockFmhaShape::kN0,
113 Problem::BlockFmhaShape::kK2>,
114 typename Problem::BlockFmhaShape::Gemm2BlockWarps,
115 typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
118 typename Problem::OGradDataType,
119 typename Problem::VDataType,
120 typename Problem::AccDataType,
121 Problem::BlockFmhaShape::Gemm2WarpTile::at(
number<0>{}),
122 Problem::BlockFmhaShape::Gemm2WarpTile::at(
number<1>{}),
123 Problem::BlockFmhaShape::Gemm2WarpTile::at(
number<2>{}),
125 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{}) == 16 ?
false :
true>;
127 using BlockGemmPolicy =
129 typename Problem::VDataType,
130 typename Problem::AccDataType,
131 typename Problem::BlockFmhaShape::Gemm2BlockWarps,
137 template <
typename Problem>
142 typename Problem::QDataType,
143 typename Problem::AccDataType,
146 Problem::BlockFmhaShape::kQKHeaddim,
147 Problem::BlockFmhaShape::kK3>,
148 typename Problem::BlockFmhaShape::Gemm3BlockWarps,
149 typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
153 typename Problem::QDataType,
154 typename Problem::AccDataType,
155 Problem::BlockFmhaShape::Gemm3WarpTile::at(
number<0>{}),
156 Problem::BlockFmhaShape::Gemm3WarpTile::at(
number<1>{}),
157 Problem::BlockFmhaShape::Gemm3WarpTile::at(
number<2>{}),
161 (Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<2>{}) == 32)
162 ? WGAttrNumAccessEnum ::Double
163 : WGAttrNumAccessEnum ::Single>;
165 using BlockGemmPolicy =
167 typename Problem::QDataType,
168 typename Problem::AccDataType,
169 typename Problem::BlockFmhaShape::Gemm3BlockWarps,
175 template <
typename Problem>
180 typename Problem::KDataType,
181 typename Problem::AccDataType,
184 Problem::BlockFmhaShape::kQKHeaddim,
185 Problem::BlockFmhaShape::kK4>,
186 typename Problem::BlockFmhaShape::Gemm4BlockWarps,
187 typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
190 typename Problem::KDataType,
191 typename Problem::AccDataType,
192 Problem::BlockFmhaShape::Gemm4WarpTile::at(
number<0>{}),
193 Problem::BlockFmhaShape::Gemm4WarpTile::at(
number<1>{}),
194 Problem::BlockFmhaShape::Gemm4WarpTile::at(
number<2>{}),
197 using BlockGemmPolicy =
199 typename Problem::KDataType,
200 typename Problem::AccDataType,
201 typename Problem::BlockFmhaShape::Gemm4BlockWarps,
208 template <
typename Problem>
212 constexpr index_t kBlockSize = Problem::kBlockSize;
213 constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
214 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
215 constexpr index_t kMaxVecLoad = 16 /
sizeof(QDataType);
216 constexpr index_t kMinVecLoad = 4 /
sizeof(QDataType);
218 constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
220 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
222 : (total_pixels / kMinVecLoad);
227 template <
typename Problem>
231 constexpr index_t kBlockSize = Problem::kBlockSize;
232 constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
233 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
234 constexpr index_t kMaxVecLoad = 16 /
sizeof(KDataType);
235 constexpr index_t kMinVecLoad = 4 /
sizeof(KDataType);
237 constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
239 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
241 : (total_pixels / kMinVecLoad);
246 template <
typename Problem>
250 constexpr index_t kBlockSize = Problem::kBlockSize;
251 constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
252 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
253 constexpr index_t kMaxVecLoad = 16 /
sizeof(VDataType);
254 constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
256 return total_pixels > kMaxVecLoad ? kMaxVecLoad : total_pixels;
259 template <
typename Problem>
263 return 16 /
sizeof(ODataType);
266 template <
typename Problem>
270 constexpr index_t kBlockSize = Problem::kBlockSize;
271 constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
272 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
273 constexpr index_t kMaxVecLoad = 16 /
sizeof(OGradDataType);
274 constexpr index_t kMinVecLoad = 4 /
sizeof(OGradDataType);
276 constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
278 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
280 : (total_pixels / kMinVecLoad);
285 template <
typename Problem>
289 constexpr index_t kBlockSize = Problem::kBlockSize;
290 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
291 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
292 constexpr index_t kMaxVecLoad = 16 /
sizeof(BiasDataType);
293 constexpr index_t kMinVecLoad = 4 /
sizeof(BiasDataType);
295 constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
297 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
299 : (total_pixels / kMinVecLoad);
304 template <
typename Problem>
308 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
310 using CWarpDstr =
typename WG::CWarpDstr;
316 template <
typename Problem>
320 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
322 using CWarpDstr =
typename WG::CWarpDstr;
328 template <
typename Problem>
331 constexpr index_t kBlockSize = Problem::kBlockSize;
332 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
333 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
335 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
340 template <
typename Problem>
343 constexpr index_t kBlockSize = Problem::kBlockSize;
344 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
345 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
346 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
351 template <
typename Problem>
354 constexpr index_t kBlockSize = Problem::kBlockSize;
355 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
356 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
358 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
363 template <
typename Problem>
366 constexpr index_t kBlockSize = Problem::kBlockSize;
367 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
368 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
370 constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
375 template <
typename Problem>
379 return 16 /
sizeof(AccDataType);
382 template <
typename Problem>
388 template <
typename Problem>
391 constexpr index_t kBlockSize = Problem::kBlockSize;
393 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
394 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
397 constexpr index_t K0 = kKPerBlock / K1;
400 constexpr index_t N2 = kNPerBlock / (N1 * N0);
410 if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0)
416 constexpr index_t kKPerIter = 32;
417 static_assert(kKPerBlock % kKPerIter == 0);
418 constexpr index_t K0_m = kKPerBlock / kKPerIter;
420 constexpr index_t K1_m = kKPerIter / K2;
422 constexpr index_t N2_m = kNPerBlock / (N1_m * N0);
431 static_assert(
container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
432 kNPerBlock * kKPerBlock);
437 template <
typename Problem>
440 constexpr index_t kBlockSize = Problem::kBlockSize;
442 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
443 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
446 constexpr index_t K0 = kKPerBlock / K1;
449 constexpr index_t N0 = kNPerBlock / (N2 * N1);
458 if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0)
464 constexpr index_t kKPerIter = 32;
465 static_assert(kKPerBlock % kKPerIter == 0);
466 constexpr index_t K0_m = kKPerBlock / kKPerIter;
468 constexpr index_t K1_m = kKPerIter / K2;
470 constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
479 static_assert(
container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
480 kNPerBlock * kKPerBlock);
485 template <
typename Problem>
488 constexpr index_t kBlockSize = Problem::kBlockSize;
490 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
491 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
494 constexpr index_t K0 = kKPerBlock / K1;
497 constexpr index_t M2 = kMPerBlock / (M1 * M0);
507 if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0)
514 constexpr index_t kKPerIter = 32;
515 static_assert(kKPerBlock % kKPerIter == 0);
516 constexpr index_t K0_m = kKPerBlock / kKPerIter;
518 constexpr index_t K1_m = kKPerIter / K2;
520 constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
529 static_assert(
container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
530 kMPerBlock * kKPerBlock);
535 template <
typename Problem>
538 constexpr index_t kBlockSize = Problem::kBlockSize;
540 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
541 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
544 constexpr index_t K0 = kKPerBlock / K1;
547 constexpr index_t M2 = kMPerBlock / (M1 * M0);
557 if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0)
564 constexpr index_t kKPerIter = 32;
565 static_assert(kKPerBlock % kKPerIter == 0);
566 constexpr index_t K0_m = kKPerBlock / kKPerIter;
568 constexpr index_t K1_m = kKPerIter / K2;
570 constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
579 static_assert(
container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
580 kMPerBlock * kKPerBlock);
585 template <
typename Problem,
typename BlockGemm>
588 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
589 constexpr index_t MWarp = config.template at<1>();
590 constexpr index_t NWarp = config.template at<2>();
592 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
613 template <
typename Problem>
616 constexpr index_t kBlockSize = Problem::kBlockSize;
618 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
619 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
622 constexpr index_t N0 = kNPerBlock / N1;
625 constexpr index_t M2 = kMPerBlock / (M1 * M0);
634 static_assert(
container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
635 kMPerBlock * kNPerBlock);
639 template <
typename DataType, index_t MPerBlock, index_t KPerBlock>
642 constexpr index_t K1 = 16 /
sizeof(DataType);
643 constexpr index_t K0 = KPerBlock / K1;
646 constexpr index_t M0 = MPerBlock / M1;
655 static_assert(
container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
656 MPerBlock * KPerBlock);
660 template <
typename Problem>
665 constexpr index_t kBlockSize = Problem::kBlockSize;
666 constexpr index_t kKPerBlock = Problem::kVHeaddim;
671 template <
typename Problem>
676 constexpr index_t kBlockSize = Problem::kBlockSize;
677 constexpr index_t kKPerBlock = Problem::kVHeaddim;
682 template <
typename Problem>
685 constexpr index_t kBlockSize = Problem::kBlockSize;
686 constexpr index_t kMPerBlock = Problem::kM0;
687 constexpr index_t kKPerBlock = Problem::kQKHeaddim;
691 constexpr index_t K0 = kKPerBlock / (K1 * K2);
695 constexpr index_t M0 = kMPerBlock / (M1 * M2);
705 static_assert(
container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
706 kMPerBlock * kKPerBlock);
710 template <
typename Problem>
713 constexpr index_t kBlockSize = Problem::kBlockSize;
714 constexpr index_t kMPerBlock = Problem::kM0;
715 constexpr index_t kKPerBlock = Problem::kQKHeaddim;
719 constexpr index_t K0 = kKPerBlock / (K1 * K2);
723 constexpr index_t M0 = kMPerBlock / (M1 * M2);
732 static_assert(
container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
733 kMPerBlock * kKPerBlock);
738 template <
typename Problem>
744 template <
typename Problem>
750 template <
typename Problem>
756 template <
typename Problem>
762 template <
typename Problem>
768 template <
typename Problem>
774 template <
typename Problem>
780 template <
typename Problem>
786 template <
typename Problem>
792 template <
typename Problem>
797 return 16 /
sizeof(GemmDataType);
800 template <index_t KIter, index_t MNPerBlock, index_t KPerSubBlock, index_t KPack>
803 constexpr auto DataTypeSize = 2;
804 constexpr auto MNLdsLayer =
805 (32 * 4 / KPerSubBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerSubBlock / DataTypeSize);
807 constexpr auto x_lds_block_desc_0 =
809 number<KPerSubBlock / KPack * MNLdsLayer>{},
810 number<MNPerBlock / MNLdsLayer>{},
823 number<KPerSubBlock / KPack * MNLdsLayer>{})),
829 x_lds_block_desc_permuted,
839 x_lds_block_desc_xk0_mnldslayer_mn_xk1,
848 std::multiplies<index_t>{},
849 1) == KIter * MNPerBlock * KPerSubBlock);
850 return x_lds_block_desc;
853 template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
858 template <
typename Problem,
867 template <
typename Problem,
878 constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{});
879 constexpr auto kBlockSize = Problem::kBlockSize;
881 constexpr auto MN0 = MNPerSubBlock / KPack;
882 constexpr auto MN1 = KPack;
884 constexpr auto KThreadWrite = kBlockSize / MN0;
885 constexpr auto K0Number = KPerBlock / KPackT;
886 constexpr auto K0PerThreadWrite = K0Number / KThreadWrite;
888 constexpr auto K0PerThreadRead = K0Number / KThreadRead;
890 constexpr auto kfold = (KPackT * MN0 * 2 > 128) ? 1 : 128 / (KPackT * MN0 * 2);
891 constexpr auto KThreadReadPerm =
892 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
893 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
897 constexpr auto mnpair =
898 (KPackT * MNPerXDL * 2 > 128)
900 : ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2));
904 number<KThreadWrite / kfold / KThreadReadPerm>{},
907 number<kfold * MN0 / mnpair>{},
921 xt_lds_block_desc_raw,
944 xt_lds_block_desc_permuted,
969 xt_lds_block_desc_unmerged,
973 number<KThreadWrite / kfold / KThreadReadPerm>{},
982 std::multiplies<index_t>{},
983 1) == MNPerSubBlock * MNIter * KPerBlock);
984 return xt_lds_block_desc;
987 template <
typename Problem>
990 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
991 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
994 constexpr index_t dram_y_ndim =
typename dram_encoding::Ys2RHsMajor{}.size();
995 if constexpr(dram_y_ndim == 2)
1000 else if constexpr(dram_y_ndim == 3)
1002 constexpr index_t KIter =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(0);
1003 constexpr index_t kKPack =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(2);
1008 static_assert(
false,
"Unexpected dram y dimension");
1012 template <
typename Problem>
1016 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1017 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1019 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<0>{});
1020 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<1>{});
1022 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
1023 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
1025 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
1026 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1028 constexpr auto k_block_outer_dstr_encoding =
1037 k_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
1040 static_assert(
container_reduce(k_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
1041 kNPerBlock * kKPerBlock);
1042 return k_block_dstr;
1045 template <
typename Problem>
1048 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
1049 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
1052 constexpr index_t dram_y_ndim =
typename dram_encoding::Ys2RHsMajor{}.size();
1053 if constexpr(dram_y_ndim == 2)
1058 else if constexpr(dram_y_ndim == 3)
1060 constexpr index_t KIter =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(0);
1061 constexpr index_t kVPack =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(2);
1066 static_assert(
false,
"Unexpected dram y dimension");
1070 template <
typename Problem>
1074 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1075 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1077 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(
number<0>{});
1078 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(
number<1>{});
1080 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
1081 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
1083 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
1084 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1086 constexpr auto v_block_outer_dstr_encoding =
1095 v_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
1098 static_assert(
container_reduce(v_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
1099 kNPerBlock * kKPerBlock);
1100 return v_block_dstr;
1103 template <
typename Problem>
1107 constexpr index_t y_ndim =
typename dram_encoding::Ys2RHsMajor{}.size();
1108 static_assert(y_ndim >= 2);
1109 using shuffled_encoding_t =
1115 template <
typename Problem>
1119 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1120 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
1123 constexpr index_t dram_y_ndim =
typename dram_encoding::Ys2RHsMajor{}.size();
1124 if constexpr(dram_y_ndim == 2)
1130 else if constexpr(dram_y_ndim == 3)
1132 constexpr index_t KIter =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(0);
1133 constexpr index_t kKPack =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(2);
1134 constexpr index_t kKPackT =
typename dram_encoding::HsLengthss{}.at(
number<0>{}).at(2);
1144 static_assert(
false,
"Unexpected dram y dimension");
1148 template <
typename Problem>
1151 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1152 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
1157 shuffled_k_lds_block_desc,
1164 template <
typename Problem>
1168 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1169 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1171 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<0>{});
1172 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<1>{});
1174 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1175 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
1177 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
1178 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1180 constexpr auto kt_block_outer_dstr_encoding =
1189 kt_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
1193 std::multiplies<index_t>{},
1194 1) == kNPerBlock * kKPerBlock);
1195 return kt_block_dstr;
1198 template <
typename Problem>
1201 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1202 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1205 constexpr index_t dram_y_ndim =
typename dram_encoding::Ys2RHsMajor{}.size();
1206 if constexpr(dram_y_ndim == 2)
1211 else if constexpr(dram_y_ndim == 3)
1213 constexpr index_t KIter =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(0);
1214 constexpr index_t kKPack =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(2);
1219 static_assert(
false,
"Unexpected dram y dimension");
1223 template <
typename Problem>
1227 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1228 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1230 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<0>{});
1231 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(
number<1>{});
1233 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1234 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
1236 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1237 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1239 constexpr auto q_block_outer_dstr_encoding =
1248 q_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
1251 static_assert(
container_reduce(q_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
1252 kMPerBlock * kKPerBlock);
1253 return q_block_dstr;
1256 template <
typename Problem>
1260 constexpr index_t y_ndim =
typename dram_encoding::Ys2RHsMajor{}.size();
1261 static_assert(y_ndim >= 2);
1262 using shuffled_encoding_t =
1268 template <
typename Problem>
1272 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1273 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
1276 constexpr index_t dram_y_ndim =
typename dram_encoding::Ys2RHsMajor{}.size();
1277 if constexpr(dram_y_ndim == 2)
1283 else if constexpr(dram_y_ndim == 3)
1285 constexpr index_t KIter =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(0);
1286 constexpr index_t kKPack =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(2);
1287 constexpr index_t kKPackT =
typename dram_encoding::HsLengthss{}.at(
number<0>{}).at(2);
1297 static_assert(
false,
"Unexpected dram y dimension");
1301 template <
typename Problem>
1305 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1306 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
1311 shuffled_q_lds_block_desc,
1318 template <
typename Problem>
1322 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1323 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1325 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(
number<0>{});
1326 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(
number<1>{});
1328 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1329 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
1331 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
1332 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1334 constexpr auto qt_block_outer_dstr_encoding =
1343 qt_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
1347 std::multiplies<index_t>{},
1348 1) == kNPerBlock * kKPerBlock);
1350 return qt_block_dstr;
1353 template <
typename Problem>
1357 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1358 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1360 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(
number<0>{});
1361 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(
number<1>{});
1363 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
1364 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
1366 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1367 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1369 constexpr auto dst_block_outer_dstr_encoding =
1378 dst_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
1382 std::multiplies<index_t>{},
1383 1) == kMPerBlock * kKPerBlock);
1384 return dst_block_dstr;
1387 template <
typename Problem>
1390 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1392 constexpr index_t kMPack = 16 /
sizeof(LSEDType);
1394 constexpr auto lsed_lds_block_desc =
1400 return lsed_lds_block_desc;
1403 template <
typename Problem,
typename BlockGemm>
1406 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1408 constexpr index_t MWarp = config.template at<1>();
1409 constexpr index_t NWarp = config.template at<2>();
1411 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1413 constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane;
1417 constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2;
1419 constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig;
1420 constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
1421 constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig;
1423 constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
1432 static_assert(
container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
1437 template <
typename Problem>
1441 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1442 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
1444 using dram_encoding =
1446 constexpr index_t dram_y_ndim =
typename dram_encoding::Ys2RHsMajor{}.size();
1447 if constexpr(dram_y_ndim == 2)
1452 else if constexpr(dram_y_ndim == 3)
1454 constexpr index_t KIter =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(0);
1455 constexpr index_t kKPack =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(2);
1460 static_assert(
false,
"Unexpected dram y dimension");
1464 template <
typename Problem>
1468 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1469 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1471 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(
number<0>{});
1472 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(
number<1>{});
1474 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1475 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
1477 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1478 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1480 constexpr auto do_block_outer_dstr_encoding =
1489 do_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
1493 std::multiplies<index_t>{},
1494 1) == kMPerBlock * kKPerBlock);
1495 return do_block_dstr;
1498 template <
typename Problem>
1502 using dram_encoding =
1504 constexpr index_t y_ndim =
typename dram_encoding::Ys2RHsMajor{}.size();
1505 static_assert(y_ndim >= 2);
1506 using shuffled_encoding_t =
1512 template <
typename Problem>
1516 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
1517 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
1519 using dram_encoding =
1521 constexpr index_t dram_y_ndim =
typename dram_encoding::Ys2RHsMajor{}.size();
1522 if constexpr(dram_y_ndim == 2)
1528 else if constexpr(dram_y_ndim == 3)
1530 constexpr index_t KIter =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(0);
1531 constexpr index_t kKPack =
typename dram_encoding::HsLengthss{}.at(
number<1>{}).at(2);
1532 constexpr index_t kKPackT =
typename dram_encoding::HsLengthss{}.at(
number<0>{}).at(2);
1542 static_assert(
false,
"Unexpected dram y dimension");
1546 template <
typename Problem>
1550 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
1551 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
1555 shuffled_do_lds_block_desc,
1562 template <
typename Problem>
1566 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1567 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1569 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<0>{});
1570 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<1>{});
1572 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
1574 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
1576 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
1577 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1579 constexpr auto dot_block_outer_dstr_encoding =
1588 dot_block_outer_dstr_encoding,
typename WarpGemm::BWarpDstrEncoding{});
1592 std::multiplies<index_t>{},
1593 1) == kNPerBlock * kKPerBlock);
1594 return dot_block_dstr;
1597 template <
typename Problem>
1601 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1602 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1604 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<0>{});
1605 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<1>{});
1607 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
1608 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
1610 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1611 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1613 constexpr auto pt_block_outer_dstr_encoding =
1622 pt_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
1626 std::multiplies<index_t>{},
1627 1) == kMPerBlock * kKPerBlock);
1628 return pt_block_dstr;
1631 template <
typename Problem>
1634 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1635 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
1641 template <
typename Problem>
1645 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1646 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1648 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<0>{});
1649 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<1>{});
1651 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1652 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK4;
1654 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1655 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1657 constexpr auto ds_block_outer_dstr_encoding =
1666 ds_block_outer_dstr_encoding,
typename WarpGemm::AWarpDstrEncoding{});
1670 std::multiplies<index_t>{},
1671 1) == kMPerBlock * kKPerBlock);
1672 return ds_block_dstr;
1675 template <
typename Problem,
typename PTOutTensor,
typename PInTensor>
1677 const PInTensor& p_in)
1679 if constexpr(Problem::BlockFmhaShape::Gemm1WarpTile::at(
number<0>{}) == 16)
1682 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1683 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1685 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(
number<0>{});
1687 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
1688 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
1690 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1691 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1693 using AWarpDstr =
typename WarpGemm::AWarpDstr;
1694 using CWarpDstr =
typename WarpGemm::CWarpDstr;
1695 auto pt_warp_tensor =
1698 constexpr auto a_warp_y_lengths =
1699 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1700 constexpr auto c_warp_y_lengths =
1701 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1708 pt_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
1712 pt_out.set_y_sliced_thread_data(
1715 pt_warp_tensor.get_thread_buffer());
1721 pt_out.get_thread_buffer() = p_in.get_thread_buffer();
1725 template <
typename Problem,
typename SGradTOutTensor,
typename SGradInTensor>
1727 const SGradInTensor& ds_in)
1729 if constexpr(Problem::BlockFmhaShape::Gemm3WarpTile::at(
number<0>{}) == 16)
1732 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1733 using WarpGemm =
remove_cvref_t<
decltype(config.template at<0>())>;
1735 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(
number<0>{});
1737 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
1738 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
1740 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1741 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1743 using AWarpDstr =
typename WarpGemm::AWarpDstr;
1744 using CWarpDstr =
typename WarpGemm::CWarpDstr;
1745 auto dst_warp_tensor =
1748 constexpr auto a_warp_y_lengths =
1749 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1750 constexpr auto c_warp_y_lengths =
1751 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1758 dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
1762 dst_out.set_y_sliced_thread_data(
1765 dst_warp_tensor.get_thread_buffer());
1771 dst_out.get_thread_buffer() = ds_in.get_thread_buffer();
1775 template <
typename Problem>
1778 constexpr index_t kBlockSize = Problem::kBlockSize;
1780 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
1783 constexpr index_t N0 = kNPerBlock / N1;
1797 template <
typename Problem>
1801 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
1802 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1810 template <
typename BlockGemm>
1813 using c_block_tensor_type =
decltype(BlockGemm{}.MakeCBlockTile());
1814 return c_block_tensor_type::get_tile_distribution();
1817 template <
typename Problem>
1820 constexpr index_t smem_size_q =
sizeof(
typename Problem::QDataType) *
1825 template <
typename Problem>
1828 constexpr index_t smem_size_qt =
1829 sizeof(
typename Problem::QDataType) *
1832 return smem_size_qt;
1835 template <
typename Problem>
1838 constexpr index_t smem_size_k =
1839 sizeof(
typename Problem::KDataType) *
1844 template <
typename Problem>
1847 constexpr index_t smem_size_kt =
1848 sizeof(
typename Problem::KDataType) *
1850 return smem_size_kt;
1853 template <
typename Problem>
1856 constexpr index_t smem_size_lse =
1857 sizeof(
typename Problem::LSEDataType) *
1859 return smem_size_lse;
1862 template <
typename Problem>
1865 constexpr index_t smem_size_d =
1866 sizeof(
typename Problem::DDataType) *
1871 template <
typename Problem>
1874 constexpr index_t smem_size_v =
1875 sizeof(
typename Problem::VDataType) *
1880 template <
typename Problem>
1883 constexpr index_t smem_size_do =
1884 sizeof(
typename Problem::OGradDataType) *
1886 return smem_size_do;
1889 template <
typename Problem>
1892 constexpr index_t smem_size_dot =
1893 sizeof(
typename Problem::OGradDataType) *
1895 return smem_size_dot;
1898 template <
typename Problem>
1901 constexpr index_t smem_size_ds =
1902 sizeof(
typename Problem::GemmDataType) *
1904 return smem_size_ds;
1907 template <
typename Problem>
1910 constexpr index_t smem_size_bias = [&]() {
1912 return sizeof(
typename Problem::BiasDataType) *
1917 return smem_size_bias;
1920 template <
typename Problem>
1935 constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
1936 constexpr index_t smem_size_stage0_1 = smem_size_v;
1937 constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
1938 smem_size_do + smem_size_lse + smem_size_d +
1939 max(smem_size_bias, smem_size_ds);
1941 return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
1944 template <
typename Problem_>
1949 template <index_t GemmStage>
1959 constexpr index_t VMEM_READ_INST =
1960 Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
1961 constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
1962 constexpr index_t MFMA_INST = Gemm0MFMA;
1965 constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
1966 constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
1968 constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
1972 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
1975 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1976 __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0);
1981 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1982 __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0);
1991 constexpr index_t LDS_READ_INST = QT_LDS_READ;
1992 constexpr index_t MFMA_INST = Gemm1MFMA;
1995 constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
1999 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
2000 __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0);
2009 constexpr index_t LDS_WRITE_INST = Q_LDS_WRITE + QT_LDS_WRITE + OGrad_LDS_WRITE +
2010 OGradT_LDS_WRITE + LSE_LDS_WRITE + D_LDS_WRITE;
2011 constexpr index_t MFMA_INST = Gemm2MFMA;
2014 constexpr index_t LDS_WRITE_PER_MFMA = LDS_WRITE_INST / MFMA_INST;
2018 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
2019 __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0);
2028 constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
2029 constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
2030 constexpr index_t MFMA_INST = Gemm3MFMA;
2033 constexpr index_t LDS_WRITE_PER_MFMA =
2034 LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1;
2035 constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
2037 constexpr index_t LDS_READ_PER_MFMA =
2038 (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
2039 ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
2040 ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE)
2046 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
2047 __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0);
2050 static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](
auto i) {
2052 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
2053 __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0);
2062 constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
2063 constexpr index_t MFMA_INST = Gemm4MFMA;
2066 constexpr index_t LDS_READ_PER_MFMA =
2067 LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1;
2071 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
2072 __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0);
2077 static constexpr index_t kBlockSize = Problem::kBlockSize;
2078 static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
2079 static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
2080 static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
2081 static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
2082 static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
2083 static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
2084 static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
2086 static constexpr index_t WarpGemmM =
2087 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<0>{});
2088 static constexpr index_t WarpGemmN =
2089 Problem::BlockFmhaShape::Gemm0WarpTile::at(
number<1>{});
2090 static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8;
2091 static constexpr index_t Gemm4MWarp =
2092 Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<0>{});
2093 static constexpr index_t Gemm4NWarp =
2094 Problem::BlockFmhaShape::Gemm4BlockWarps::at(
number<1>{});
2097 static constexpr index_t Gemm0MFMA =
2098 kM0 * kN0 * kK0 / (kBlockSize /
get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
2099 static constexpr index_t Gemm1MFMA =
2100 kN0 * kVHeaddim * kM0 /
2101 (kBlockSize /
get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
2102 static constexpr index_t Gemm2MFMA =
2103 kM0 * kN0 * kK2 / (kBlockSize /
get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
2104 static constexpr index_t Gemm3MFMA =
2105 kN0 * kQKHeaddim * kM0 /
2106 (kBlockSize /
get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
2107 static constexpr index_t Gemm4MFMA =
2108 kM0 * kQKHeaddim * kN0 /
2109 (kBlockSize /
get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
2112 static constexpr index_t Q_VMEM_READ =
2114 static constexpr index_t OGrad_VMEM_READ =
2116 static constexpr index_t LSE_VMEM_READ = 1;
2117 static constexpr index_t D_VMEM_READ = 1;
2120 static constexpr index_t OGradT_LDS_READ =
2122 static constexpr index_t QT_LDS_READ =
2124 static constexpr index_t SGradT_LDS_READ_P1 =
2127 static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
2128 static constexpr index_t SGradT_LDS_READ_P2 =
2130 static constexpr index_t OGrad_LDS_READ =
2132 static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
2135 static constexpr index_t Q_LDS_WRITE =
2137 static constexpr index_t QT_LDS_WRITE =
2139 static constexpr index_t OGrad_LDS_WRITE =
2141 static constexpr index_t OGradT_LDS_WRITE =
2143 static constexpr index_t LSE_LDS_WRITE = 1;
2144 static constexpr index_t D_LDS_WRITE = 1;
2145 static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition tile/core/container/container_helper.hpp:198
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F &&f, number< N >)
Definition tile/core/container/sequence.hpp:1045
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
typename tile_distribution_encoding_shuffle< encoding, shuffle >::type tile_distribution_encoding_shuffle_t
Definition tile_distribution_encoding.hpp:451
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_bwd_pipeline_default_policy.hpp:1946
Problem_ Problem
Definition block_fmha_bwd_pipeline_default_policy.hpp:1947
static CK_TILE_DEVICE constexpr void GemmStagedScheduler()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1950
Definition block_fmha_bwd_pipeline_default_policy.hpp:23
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeOGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1881
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1513
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledOGradRegWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1499
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1438
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentPostQGradAcc()
Definition block_fmha_bwd_pipeline_default_policy.hpp:376
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledKLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1116
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackKT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:757
static CK_TILE_HOST_DEVICE constexpr auto MakeQRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1224
static CK_TILE_HOST_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:389
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1354
static CK_TILE_HOST_DEVICE constexpr auto GetOGradVBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:104
static CK_TILE_HOST_DEVICE constexpr auto MakeKTLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1149
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1563
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1404
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentO()
Definition block_fmha_bwd_pipeline_default_policy.hpp:260
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackBiasT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:775
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:32
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeBias()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1908
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:536
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1642
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentK()
Definition block_fmha_bwd_pipeline_default_policy.hpp:341
static CK_TILE_HOST_DEVICE constexpr auto MakePreOGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:672
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledKRegWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1104
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeLSE()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1854
static CK_TILE_HOST_DEVICE constexpr auto MakeVRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1071
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_bwd_pipeline_default_policy.hpp:228
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentOGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:267
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeQ()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1818
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentVGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:317
static CK_TILE_DEVICE constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor &dst_out, const SGradInTensor &ds_in)
Definition block_fmha_bwd_pipeline_default_policy.hpp:1726
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeK()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1836
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1465
static CK_TILE_HOST_DEVICE constexpr auto MakeXLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:854
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1798
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackBias()
Definition block_fmha_bwd_pipeline_default_policy.hpp:769
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradTLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1547
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeD()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1863
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackSGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:793
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1046
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentPostQGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:383
static constexpr auto swap_last2
Definition block_fmha_bwd_pipeline_default_policy.hpp:25
static CK_TILE_HOST_DEVICE constexpr auto MakeXLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:801
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackK()
Definition block_fmha_bwd_pipeline_default_policy.hpp:751
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackOGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:781
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_bwd_pipeline_default_policy.hpp:209
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentBias()
Definition block_fmha_bwd_pipeline_default_policy.hpp:364
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1921
static CK_TILE_HOST_DEVICE constexpr auto MakePTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1598
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1632
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeQT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1826
static CK_TILE_HOST_DEVICE constexpr auto MakePostQGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:711
static CK_TILE_HOST_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:438
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeV()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1872
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentOGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:352
static CK_TILE_HOST_DEVICE constexpr auto MakeXTLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:873
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeOGradT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1890
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:486
static CK_TILE_HOST_DEVICE constexpr auto MakePreODramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:661
static CK_TILE_HOST_DEVICE constexpr auto MakeXTLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:863
static CK_TILE_HOST_DEVICE constexpr auto GetSGradTQTBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:138
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:988
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:586
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeSGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1899
static CK_TILE_DEVICE constexpr auto GetPTOGradTBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:66
static CK_TILE_HOST_DEVICE constexpr auto GetSGradKTBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:176
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:614
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentQ()
Definition block_fmha_bwd_pipeline_default_policy.hpp:329
static CK_TILE_HOST_DEVICE constexpr auto MakeQTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1319
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasSTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1811
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackOGradT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:787
static CK_TILE_HOST_DEVICE constexpr auto MakeQTLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1302
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledQLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1269
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackV()
Definition block_fmha_bwd_pipeline_default_policy.hpp:763
static CK_TILE_DEVICE constexpr void PTFromGemm0CToGemm1A(PTOutTensor &pt_out, const PInTensor &p_in)
Definition block_fmha_bwd_pipeline_default_policy.hpp:1676
static CK_TILE_HOST_DEVICE constexpr auto MakePreXDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:640
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackQT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:745
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1199
static CK_TILE_HOST_DEVICE constexpr auto MakeKTRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1165
static CK_TILE_HOST_DEVICE constexpr auto MakeKRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1013
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledQRegWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1257
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledBiasTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1776
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentKGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:305
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1388
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackQ()
Definition block_fmha_bwd_pipeline_default_policy.hpp:739
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentBias()
Definition block_fmha_bwd_pipeline_default_policy.hpp:286
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeKT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1845
static CK_TILE_HOST_DEVICE constexpr auto MakePostQGradAccDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:683
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_bwd_pipeline_default_policy.hpp:247
Definition block_gemm_areg_breg_creg_v1_custom_policy.hpp:16
Definition block_gemm_areg_breg_creg_v1.hpp:18
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192