mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp Source File

mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp Source File

Composable Kernel: mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp Source File
mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
// Problem descriptor for the mixed-precision flatmm pipeline defined below.
// B is the quantized operand: the "fixed for fp4" notes and the MXF4 naming
// indicate packed FP4 weights with per-block scales. The pipeline struct reads
// ContinuousKPerThread, ContinuousScaleNPerThread, ContinuousScaleKPerThread
// and flatKPerWarp from here.
//
// Template parameters visible in this listing: ADataType_, BDataType_,
// CDataType_, BlockGemmShape_, Traits_, HasHotLoop_ (default true) and
// ComputeDataType_ (defaults to ADataType_). Scheduler_ and TailNum_ are
// forwarded to the base class, but their declarations are not shown here.
//
// NOTE(review): BDataType_ does not appear in the visible part of the
// base-class argument list (which continues with ADataType_); the original B
// type is exposed as QuantType instead. Presumably the base is instantiated
// with A's type in B's slot because B is dequantized before the MFMA. Confirm
// against the base-class argument that is omitted from this listing.
14template <typename ADataType_,
15 typename BDataType_,
16 typename CDataType_,
17 typename BlockGemmShape_,
18 typename Traits_,
20 bool HasHotLoop_ = true,
22 typename ComputeDataType_ = ADataType_>
24 ADataType_,
25 CDataType_,
26 BlockGemmShape_,
27 Traits_,
28 Scheduler_,
29 HasHotLoop_,
30 TailNum_,
31 ComputeDataType_>
32{
33 using BlockGemmShape = BlockGemmShape_;
34
// Original (packed, quantized) B element type; it is not among the visible
// base-class arguments above.
35 using QuantType = BDataType_;
36
37 static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
38
// K-granularity of the B scales (by name: one MXF4 scale per 32 K elements).
39 static constexpr int MXF4ScaleGranularityK = 32;
40
41 static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4
42 static constexpr int ContinuousScaleNPerThread = 2; // it's fixed for fp4
43 static constexpr int ContinuousScaleKPerThread = 2; // it's fixed for fp4
// flatKPerWarp = 64 * ContinuousKPerThread.
// NOTE(review): 64 is presumably the wavefront size (the pipeline itself
// queries get_warp_size()). Confirm before targeting wave32 hardware.
44 static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread;
45};
46
47template <typename Problem, typename PipelinePolicy = F16xMXF4FlatmmPipelineAgBgCrPolicy>
49 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
50{
52
57
59 static_assert(sizeof(ADataType) >= sizeof(BDataType));
60
64
67
68 static constexpr auto config =
69 BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
70
71 using WG = remove_cvref_t<decltype(config.template at<0>())>;
72
73 static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
74 static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
75
76 static constexpr index_t BlockSize = Problem::kBlockSize;
77 static constexpr index_t WaveSize = get_warp_size();
78
79 static constexpr index_t kMPerBlock = BlockGemmShape::kM;
80 static constexpr index_t kNPerBlock = BlockGemmShape::kN;
81 static constexpr index_t kKPerBlock = BlockGemmShape::kK;
82
83 static constexpr index_t flatKPerWarp = Problem::flatKPerWarp;
84 static constexpr index_t flatNPerWarp = Problem::flatNPerWarp;
85
// Vector sizes reported for A, B and C (presumably consumed by the kernel for
// vectorized memory accesses; confirm at the call sites).
// A and C forward Problem::VectorSizeA / Problem::VectorSizeC. B is
// hard-coded to 32 for the FP4 shuffle layout and does not consult the Problem.
86 static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
87 static constexpr index_t GetVectorSizeB() { return 32; /* fixed for fp4 shuffle layout*/ }
88 static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
89
90 static constexpr bool kPadM = Problem::kPadM;
91 static constexpr bool kPadN = Problem::kPadN;
92 static constexpr bool kPadK = Problem::kPadK;
93
94 static constexpr index_t kLdsAlignmentInBytes = 16;
95 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
96 static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
97
98 static constexpr auto I0 = number<0>();
99 static constexpr auto I1 = number<1>();
100 static constexpr auto I2 = number<2>();
101 static constexpr auto idxM = I0;
102 static constexpr auto idxN = I1;
103 static constexpr auto idxK = I2;
107
108 static constexpr index_t MWarp = config.template at<1>();
109 static constexpr index_t NWarp = config.template at<2>();
110
111 static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
112 static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
113 static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
114
117
120
121 static constexpr int MXFP4PackedSize = 2;
122 static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType);
123 static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * MXFP4PackedSize;
127
128 static constexpr int ContinuousKPerThread = Problem::ContinuousKPerThread;
129 static constexpr int ContinuousScaleNPerThread = Problem::ContinuousScaleNPerThread;
130 static constexpr int ContinuousScaleKPerThread = Problem::ContinuousScaleKPerThread;
131
132 static constexpr int ScaleKFlatPerWarp =
134
135 static constexpr int XDLK_PerThread =
136 WarpTile::at(I2) / (get_warp_size() / WarpTile::at(I1)); // 8
137
138 static constexpr int XDL_PerWeightK = 4; // 4
140 static constexpr int XDL_PerScaleN = ContinuousScaleNPerThread; // 2
141 static_assert(XDL_PerScaleK % XDL_PerWeightK == 0);
142 static_assert(KIterPerWarp % XDL_PerScaleK == 0);
143 static_assert(NIterPerWarp % XDL_PerScaleN == 0);
144
145 static constexpr int MXFP4KPerWarp = KIterPerWarp / XDL_PerWeightK;
146 static constexpr int ScaleKPerWarp = KIterPerWarp / XDL_PerScaleK;
147 static constexpr int ScaleNPerWarp = NIterPerWarp / XDL_PerScaleN;
148
150
151 static constexpr bool HasHotLoop = Problem::HasHotLoop;
152 static constexpr auto TailNum = Problem::TailNum;
153
154#ifdef __gfx942__
155 static constexpr index_t mfma_per_wg = 2;
156#else
157 static constexpr index_t mfma_per_wg = 1;
158#endif
159 static constexpr index_t dsread_per_wg =
160 WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
161 static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
162
167 static constexpr index_t Aload_rep = dswrite_rep;
168 static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
170 static constexpr index_t ScaleBload_num =
172 WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
173 static constexpr index_t Bload_total_num =
176 static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
177 static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
178
182
183 // For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false.
// NOTE(review): operator() in this struct takes separate p_smem_ping and
// p_smem_pong LDS pointers. Confirm what this flag controls for the caller
// (e.g. LDS sizing) and that false is intended for this ping-pong pipeline.
184 static constexpr bool DoubleSmemBuffer = false;
185
// Emits AMDGPU scheduling-group hints for one M step of the hot loop.
//
// For each of the mfma_perM_perK MFMA slots it requests one MFMA group
// (mask 0x008), then up to round_data_inst data-movement groups drawn from
// inst_order. The masks used, per the inline comments below, are:
// 0x008 MFMA, 0x200 DS write, 0x020 VMEM read, 0x100 DS read.
// Every group has size 1 and sync ID 0.
//
// dsread_perM  : number of DS-read groups to request for this M step.
// dswrite_perM : number of DS-write groups; overridden to 0 when
//                CKTILE_FLATMM_USE_BUFFER_LOAD_LDS is enabled.
// load_perM    : number of VMEM-read groups to request for this M step.
//
// NOTE(review): mfma_perM_perK is declared outside this listing. It is used
// as a divisor below, so it must be non-zero.
186 CK_TILE_HOST_DEVICE static constexpr auto
187 SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
188 {
189#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
190 // GFX950 uses BUFFER_LOAD_LDS to fill lds_buffer_A.
191 // There is no separate DS_WRITE instruction at all.
192 dswrite_perM = 0;
193#endif
194 // Init inst order
// max_data_inst = max(dsread_perM, dswrite_perM, load_perM).
// round_data_inst = ceil(sum_data_inst / mfma_perM_perK); this is the largest
// number of data groups that will follow any single MFMA group.
195 index_t max_data_inst = dsread_perM > load_perM
196 ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
197 : (load_perM > dswrite_perM ? load_perM : dswrite_perM);
198 index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
199 index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK;
200
// inst_order holds the data groups to place, encoded as 1 = DS write,
// 2 = VMEM read, 3 = DS read. Zero entries emit nothing.
// NOTE(review): entries up to round_data_inst * mfma_perM_perK - 1 are read
// below. Nothing here checks that this stays within NIterPerWarp * 10.
201 index_t inst_order[NIterPerWarp * 10];
202 _Pragma("unroll") for(int idx = 0; idx < NIterPerWarp * 10; idx++) { inst_order[idx] = 0; }
203
// Fill inst_order by interleaving. On pass j, append one DS write, one VMEM
// read and one DS read (in that order) for each kind whose count exceeds j.
// Afterwards index == sum_data_inst.
204 index_t index = 0;
205 _Pragma("unroll") for(int j = 0; j < max_data_inst; j++)
206 {
207 if(dswrite_perM > j)
208 {
209 inst_order[index] = 1;
210 index++;
211 }
212 if(load_perM > j)
213 {
214 inst_order[index] = 2;
215 index++;
216 }
217 if(dsread_perM > j)
218 {
219 inst_order[index] = 3;
220 index++;
221 }
222 }
223
224 // Schedule IGLP
// Treat inst_order as rows of mfma_perM_perK entries. MFMA slot j owns one
// column, inst_idx, chosen as:
//   j == 0 -> 0
//   j == 1 -> (mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2)
//   j == 2 -> mfma_perM_perK - 1
//   j >= 3 -> mfma_perM_perK - j
// This mapping is a permutation of 0..mfma_perM_perK-1. Combined with the
// mirroring below, every entry in the first round_data_inst rows is visited
// exactly once.
225 _Pragma("unroll") for(int j = 0; j < mfma_perM_perK; j++)
226 {
227 index_t inst_idx = 0;
228 if(j == 0)
229 ;
230 else if(j == 1)
231 inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2;
232 else if(j == 2)
233 inst_idx = mfma_perM_perK - 1;
234 else
235 inst_idx = mfma_perM_perK - j;
236
237 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
238
// Round r reads row r. Even rounds use column inst_idx; odd rounds use the
// mirrored column (mfma_perM_perK - 1 - inst_idx). Each read requests at
// most one data group after this MFMA.
239 _Pragma("unroll") for(int r = 0; r < round_data_inst; r++)
240 {
241 if(r % 2 == 0)
242 {
243 if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
244 {
245 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
246 }
247 if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
248 {
249 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
250 }
251 if(inst_order[inst_idx + r * mfma_perM_perK] == 3)
252 {
253 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
254 }
255 }
256 else
257 {
258 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
259 {
260 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
261 }
262 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
263 {
264 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
265 }
266 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3)
267 {
268 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
269 }
270 }
271 }
272 }
273 }
// Requests the scheduling groups for one full hot-loop iteration. It makes
// one SchedulerPerM call per (kIter, mIter) pair (KIterPerWarp * MIterPerWarp
// calls, kIter-major), then ends with a scheduling barrier.
// Per M step it requests:
//   - dsread_per_wg DS-read groups;
//   - no DS-write groups (dswrite_perM is never changed from 0 here);
//   - a VMEM-load count built from the A, B and scale shares computed below.
// NOTE(review): Aload_num_perK and KPerScaleLoad are declared outside this
// listing.
274 CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
275 {
276 // Key point of the pipeline optimization: balance the workload over time.
277 // instruction schedule example(128X256X256, 1X4, 16X16X128):
278 // Iter MNK MFMA ds_read ds_write A_load b_load
279 // -1 M6N0: 57 - 8 - -
280 // -1 M6N1: 58 1 - - -
281 // -1 M6N2: 59 - - 7 -
282 // -1 M6N3: 60 2 - - -
283 // -1 M7N0: 61 - - - -
284 // -1 M7N1: 62 3 - - -
285 // -1 M7N2: 63 - - 8 -
286 // -1 M7N3: 64 4 - - -
287 // 0 M0N0K0: 1 - - - 1
288 // 0 M0N1: 2 5 - - -
289 // 0 M0N2: 3 - - - 2
290 // 0 M0N3: 4 6 - - -
291 // 0 M1N0: 5 - - - 3
292 // 0 M1N1: 6 7 - - -
293 // 0 M1N2: 7 - - - 4
294 // 0 M1N3: 8 8 - - -
295 // 0 M2N0: 9 - - - 5
296 // 0 M2N1: 10 9 - - -
297 // 0 M2N2: 11 - - - 6
298 // 0 M2N3: 12 10 - - -
299 // 0 M3N0: 13 - 1 - 7
300 // 0 M3N1: 14 11 - - -
301 // 0 M3N2: 15 - - - 8
302 // 0 M3N3: 16 12 - - -
303 // 0 M4N0: 17 - 2 - -
304 // 0 M4N1: 18 13 - - -
305 // 0 M4N2: 19 - - 1 -
306 // 0 M4N3: 20 14 - - -
307 // 0 M5N0: 21 - 3 - -
308 // 0 M5N1: 22 15 - - -
309 // 0 M5N2: 23 - - 2 -
310 // 0 M5N3: 24 16 - - -
311 // 0 M6N0: 25 - 4 - -
312 // 0 M6N1: 26 17 - - -
313 // 0 M6N2: 27 - - 3 -
314 // 0 M6N3: 28 18 - - -
315 // 0 M7N0: 29 - - - -
316 // 0 M7N1: 30 19 - - -
317 // 0 M7N2: 31 - - 4 -
318 // 0 M7N3: 32 20 - - -
319 // 0 M0N0K1: 33 - - - 9
320 // 0 M0N1: 34 21 - - -
321 // 0 M0N2: 35 - - - 10
322 // 0 M0N3: 36 22 - - -
323 // 0 M1N0: 37 - - - 11
324 // 0 M1N1: 38 23 - - -
325 // 0 M1N2: 39 - - - 12
326 // 0 M1N3: 40 24 - - -
327 // 0 M2N0: 41 - - - 13
328 // 0 M2N1: 42 25 - - -
329 // 0 M2N2: 43 - - - 14
330 // 0 M2N3: 44 26 - - -
331 // 0 M3N0: 45 - 5 - 15
332 // 0 M3N1: 46 27 - - -
333 // 0 M3N2: 47 - - - 16
334 // 0 M3N3: 48 28 - - -
335 // 0 M4N0: 49 - 6 - -
336 // 0 M4N1: 50 29 - - -
337 // 0 M4N2: 51 - - 5 -
338 // 0 M4N3: 52 30 - - -
339 // 0 M5N0: 53 - 7 - -
340 // 0 M5N1: 54 31 - - -
341 // 0 M5N2: 55 - - 6 -
342 // 0 M5N3: 56 32 - - -
343 // 0 M6N0: 57 - 8 - -
344 // 0 M6N1: 58 1 - - -
345 // 0 M6N2: 59 - - 7 -
346 // 0 M6N3: 60 2 - - -
347 // 0 M7N0: 61 - - - -
348 // 0 M7N1: 62 3 - - -
349 // 0 M7N2: 63 - - 8 -
350 // 0 M7N3: 64 4 - - -
351
// NOTE(review): the example table shows non-empty ds_write entries and no
// scale loads. This function never requests DS-write groups and does add
// scale-load slots, so the table may describe a different pipeline variant.
352 _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
353 {
354 _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
355 {
356 index_t dsread_perM = 0;
357 index_t dswrite_perM = 0;
358 index_t load_perM = 0;
359
360 // Calculate ds_read number per M
361 dsread_perM = dsread_per_wg;
362
363 // Calculate buffer_load number per M
// A share: Aload_rep groups in each trailing M step, counting back from
// mIter == MIterPerWarp - 1 while
// (MIterPerWarp - 1 - mIter) * Aload_rep < Aload_num_perK.
// B share (first HalfMIter steps only): Bload_rep groups, counting back
// from mIter == HalfMIter - 1 in the same way.
364 if(mIter < HalfMIter)
365 {
366 load_perM =
367 ((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
368 : 0) +
369 ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
370 : 0);
371 }
372 else
373 {
374 load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
375 ? Aload_rep
376 : 0;
377 }
// One extra VMEM group at mIter == 0 on every KPerScaleLoad-th kIter.
// This is presumably the B-scale load; confirm against KPerScaleLoad.
378 if((kIter % KPerScaleLoad == 0) && (mIter == 0))
379 {
380 load_perM = load_perM + 1;
381 }
382 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
383 }
384 }
// When Aload_num_perK is 0, the loops above requested no A-load groups, so
// one extra VMEM-read group is appended here instead.
385 // Add Aload when Aload data > needed
386 if(Aload_num_perK == 0)
387 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// Scheduling barrier with mask 0 closes this region. Per the LLVM AMDGPU
// builtin docs, no instruction may be scheduled across it.
388 __builtin_amdgcn_sched_barrier(0);
389 }
390
392 {
393 _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
394 {
395 _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
396 {
397 index_t dsread_perM = 0;
398 index_t dswrite_perM = 0;
399 index_t load_perM = 0;
400
401 // Calculate ds_read number per M
402 dsread_perM = dsread_per_wg;
403
404 // Calculate buffer_load number per M
405 if(mIter < HalfMIter)
406 {
407 load_perM =
408 ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
409 : 0);
410 }
411 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
412 }
413 }
414 __builtin_amdgcn_sched_barrier(0);
415 }
416
418 {
419 _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
420 {
421 _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
422 {
423 index_t dsread_perM = 0;
424 index_t dswrite_perM = 0;
425 index_t load_perM = 0;
426
427 // Calculate ds_read number per M
428 if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
429 dsread_perM = dsread_per_wg;
430
431 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
432 }
433 }
434 // __builtin_amdgcn_sched_barrier(0);
435 }
436
438 {
439 return PipelinePolicy::template MakeADramTileDistribution<Problem>();
440 }
441
442 template <typename ADramBlockWindowTmp,
443 typename AElementFunction,
444 typename BFlatBlockWindowTmp,
445 typename DequantBFlatWindow>
446 CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window_,
447 const AElementFunction& a_element_func,
448 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
449 const DequantBFlatWindow& scale_b_flat_window,
450 const index_t num_loop,
451 const index_t k_padded_zeros,
452 void* p_smem_ping,
453 void* p_smem_pong) const
454 {
455 static_assert(
456 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
457 "wrong!");
458
459 static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
460 "wrong!");
461 static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
462 "wrong!");
463
464 constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
465 const index_t iMWarp = get_warp_id() / NWarp;
466
467 using CWarpDstr = typename WG::CWarpDstr;
468 using CWarpTensor = typename WG::CWarpTensor;
469
470 constexpr auto c_warp_y_lengths =
471 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
472 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
473
474 __builtin_amdgcn_sched_barrier(0);
475
476 auto a_copy_dram_window = replace_bottom_tensor_view(
477 PipelinePolicy::template TransformF16xF4_ATensorView<Problem>(
478 a_copy_dram_window_.get_bottom_tensor_view()),
479 a_copy_dram_window_);
480
481 // A tile in LDS
482 ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
483 ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
484
485 constexpr auto write_a_lds_block_desc =
486 PipelinePolicy::template MakeF16xF4_WriteALdsBlockDescriptor<Problem>();
487 constexpr auto read_a_lds_block_desc =
488 PipelinePolicy::template MakeF16xF4_ReadALdsBlockDescriptor<Problem>();
489
490 auto write_a_lds_block_ping =
491 make_tensor_view<address_space_enum::lds>(p_a_lds_ping, write_a_lds_block_desc);
492 auto write_a_lds_block_pong =
493 make_tensor_view<address_space_enum::lds>(p_a_lds_pong, write_a_lds_block_desc);
494 auto read_a_lds_block_ping =
495 make_tensor_view<address_space_enum::lds>(p_a_lds_ping, read_a_lds_block_desc);
496 auto read_a_lds_block_pong =
497 make_tensor_view<address_space_enum::lds>(p_a_lds_pong, read_a_lds_block_desc);
498
499 auto a_copy_lds_window_ping =
500 make_tile_window(write_a_lds_block_ping,
502 {0, 0},
503 PipelinePolicy::template MakeADramTileDistribution<Problem>());
504 auto a_copy_lds_window_pong =
505 make_tile_window(write_a_lds_block_pong,
507 {0, 0},
508 PipelinePolicy::template MakeADramTileDistribution<Problem>());
509
510 // ping-pong window for A LDS
511 auto a_warp_window_ping_tmp =
512 make_tile_window(read_a_lds_block_ping,
514 {iMWarp * WG::kM, 0},
515 PipelinePolicy::template MakeF16xF4_ALDS_TileDistribution<Problem>());
516 auto a_warp_window_pong_tmp =
517 make_tile_window(read_a_lds_block_pong,
519 {iMWarp * WG::kM, 0},
520 PipelinePolicy::template MakeF16xF4_ALDS_TileDistribution<Problem>());
521
523 statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
525 a_warp_windows_ping;
526
528 statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
530 a_warp_windows_pong;
531
532 auto A_Lds_Stride = 8;
533 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
534 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
535 a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
536 a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
537
538 auto weight_k_idx = kIter / number<XDL_PerWeightK>{};
539 auto weight_k_rank = kIter % number<XDL_PerWeightK>{};
541 a_warp_windows_ping(mIter)(kIter),
542 {mIter * MPerBlockPerIter,
543 weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
545 a_warp_windows_pong(mIter)(kIter),
546 {mIter * MPerBlockPerIter,
547 weight_k_rank * A_Lds_Stride + weight_k_idx * XDL_PerWeightK * WG::kK});
548 });
549 });
550
551 // Block GEMM
552 auto block_flatmm = BlockFlatmm();
553 // Acc register tile
554 auto c_block_tile = block_flatmm.MakeCBlockTile();
555
556 // B flat DRAM window for load
557 auto b_flat_distribution =
558 PipelinePolicy::template MakeFp4BFlatDramTileDistribution<Problem>();
559 auto scale_b_flat_distribution =
560 PipelinePolicy::template MakeFp4ScaleBFlatDramTileDistribution<Problem>();
561
562 auto b_flat_dram_window = make_tile_window(
563 b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
565 b_flat_dram_block_window_tmp.get_window_origin(),
566 b_flat_distribution);
567
568 auto scale_b_flat_dram_window = make_tile_window(
569 scale_b_flat_window.get_bottom_tensor_view(), // from kernel gemm_pad_views
571 scale_b_flat_window.get_window_origin(),
572 scale_b_flat_distribution);
573
574 using MXFP4_Buffer = decltype(load_tile(b_flat_dram_window));
575 // use v4i32 as the data type between basicblock to avoid unpack and repack operation.
576 using V4UInt_Buffer = thread_buffer<uint32_t, XDL_PerWeightK>;
577 union UnionB
578 {
579 V4UInt_Buffer u = 0;
580 MXFP4_Buffer mxfp4;
581 } ub;
582
583 // pingpong buffer for B
585 statically_indexed_array<decltype(b_flat_dram_window), MXFP4KPerWarp>,
587 b_flat_dram_windows;
590 b_warp_tensor_ping;
593 b_warp_tensor_pong;
594
596 statically_indexed_array<decltype(scale_b_flat_dram_window), ScaleKPerWarp>,
598 scale_b_flat_dram_windows;
600 statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), ScaleKPerWarp>,
602 scale_b_warp_tensor_ping;
604 statically_indexed_array<decltype(load_tile(scale_b_flat_dram_window)), ScaleKPerWarp>,
606 scale_b_warp_tensor_pong;
607
608 using ABlockTile = decltype(load_tile(a_copy_dram_window));
609 ABlockTile a_block_tile;
610
611 enum
612 {
613 PrefillBeforeGemm = 1,
614 PrefillAfterGemm = 2,
615 PrefillAlways = PrefillBeforeGemm | PrefillAfterGemm,
616 };
617#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
618 auto prefill_lds_a_stage1 =
619 [&]([[maybe_unused]] auto lds_tile_a, auto dram_tile_a, auto prefill_location) {
620 // global -> lds
621 if constexpr(prefill_location & PrefillAfterGemm)
622 async_load_tile(lds_tile_a, dram_tile_a);
623 };
624 auto prefill_lds_a_stage2 = [&](auto lds_tile_a) {
625 // async_load_fence();
626 // __builtin_amdgcn_s_waitcnt(0x03fc);
627 // data has been stored in lds, no need more operation.
628 static_assert(std::is_same_v<AElementFunction, identity>,
629 "buffer_load_lds don't support element func fot A before mfma");
630 };
631#else
632 auto prefill_lds_a_stage1 =
633 [&]([[maybe_unused]] auto lds_tile_a, auto dram_tile_a, auto prefill_location) {
634 // global -> vgpr
635 if constexpr(prefill_location & PrefillBeforeGemm)
636 a_block_tile = load_tile(dram_tile_a);
637 };
638 auto prefill_lds_a_stage2 = [&]([[maybe_unused]] auto lds_tile_a) {
639 // vgpr -> lds
640 auto a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
641 store_tile(lds_tile_a, a_block_tile_transformed);
642 };
643#endif
644
645 // HEAD
646 // Prefetch A0
647 prefill_lds_a_stage1(a_copy_lds_window_ping, a_copy_dram_window, number<PrefillAlways>{});
648
649 // move A window to next k
650 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
651
652 // prefetch B
653 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
654 static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
655 if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
656 {
657 auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
658 auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
659
660 scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
661 scale_b_flat_dram_window;
663 scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
664 {scale_n_iter * NFlatPerBlockPerIter, scale_k_iter * ScaleKFlatPerWarp});
665 scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) =
666 load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
667 }
668 auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
669 auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
670
671 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
672 move_tile_window(b_flat_dram_windows(nIter)(kIter),
674 packed_n_rank,
675 kIter * KFlatPerBlockPerIter});
676
677 ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
678 b_warp_tensor_ping(nIter)(kIter) = ub.u;
679 });
680 });
681 // move B window to next flat K
682 move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
683 move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
684
685 prefill_lds_a_stage2(a_copy_lds_window_ping);
686
687 __builtin_amdgcn_sched_barrier(0);
688
689 // Prefetch A1
690 prefill_lds_a_stage1(a_copy_lds_window_pong, a_copy_dram_window, number<PrefillAlways>{});
691 // move A window to next k
692 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
693
694 // initialize C
695 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
696
697 __builtin_amdgcn_s_waitcnt(Bload_total_num);
699
700 // preload A00,A10... from lds
701 statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
702 m_preload>
703 a_warp_tensor;
704
705 static_for<0, m_preload, 1>{}([&](auto loadIter) {
706 constexpr auto mIter = loadIter % MIterPerWarp;
707 constexpr auto kIter = loadIter / MIterPerWarp;
708 a_warp_tensor(loadIter) =
709 load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
710 });
711 __builtin_amdgcn_sched_barrier(0);
712
714
715 auto dequant_mxfp4 = [&](const auto& quant_weight_tensor,
716 const auto& scale_tensor,
717 auto xdl_nIter,
718 auto xdl_kIter) {
719 auto quant_idx_k = xdl_kIter % number<XDL_PerWeightK>{};
720
721 auto scale_idx_n = xdl_nIter % number<XDL_PerScaleN>{};
722 auto scale_idx_k = (xdl_kIter % number<XDL_PerScaleK>{}) / number<XDL_PerWeightK>{};
723 auto scale_offset = scale_idx_n + scale_idx_k * number<XDL_PerScaleN>{};
724
725 auto scale = scale_tensor.get_thread_buffer()[scale_offset];
726
727 constexpr int ScalarCnt = WG::BWarpTensor::get_thread_buffer_size();
728 constexpr int PackedCnt = ScalarCnt / MXFP4PackedSize;
729 constexpr int float_mantissa = 23;
730
731 uint32_t uscale = uint32_t(scale.data) << float_mantissa;
732
733 using ComputeV2Type =
734 std::conditional_t<std::is_same_v<ComputeType, half_t>, fp16x2_t, bf16x2_t>;
735
736#if defined(__gfx950__)
737 auto pk_mxfp4x4_to_compute_v2 = [](auto pk_mxfp4x4, float fscale, auto byte_idx) {
738 if constexpr(std::is_same_v<ComputeType, half_t>)
739 {
740 return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(
741 pk_mxfp4x4, fscale, int(byte_idx));
742 }
743 else if constexpr(std::is_same_v<ComputeType, bf16_t>)
744 {
745 return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(
746 pk_mxfp4x4, fscale, int(byte_idx));
747 }
748 else
749 {
750 static_assert(sizeof(pk_mxfp4x4) == 0, "unsupported compute type");
751 }
752 };
753 static_for<0, PackedCnt, 1>{}([&](auto i) {
754 dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
755 i,
756 pk_mxfp4x4_to_compute_v2(
757 quant_weight_tensor[quant_idx_k], bit_cast<float>(uscale), i));
758 });
759#else
760 auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
761 if constexpr(std::is_same_v<ComputeType, half_t>)
762 {
763 return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
764 }
765 else if constexpr(std::is_same_v<ComputeType, bf16_t>)
766 {
767 return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
768 }
769 else
770 {
771 static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type");
772 }
773 };
774 static_for<0, PackedCnt, 1>{}([&](auto i) {
775 dequant_B_n[xdl_nIter].get_thread_buffer().template set_as<ComputeV2Type>(
776 i,
777 pk_mxfp4_to_compute_v2(
778 bit_cast<thread_buffer<pk_fp4_t, 4>>(quant_weight_tensor[quant_idx_k])
779 .at(i),
780 bit_cast<float>(uscale)));
781 });
782#endif
783 };
784
785 // MAIN LOOP
786 index_t iCounter = (num_loop - 1) / 2;
787 while(iCounter > 0)
788 {
789 // prefetch B(2i+1)
790 static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
791 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
792 if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
793 {
794 auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
795 auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
796
797 scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
798 scale_b_flat_dram_window;
799
800 move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
801 {scale_n_iter * NFlatPerBlockPerIter,
802 scale_k_iter * ScaleKFlatPerWarp});
803
804 scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) =
805 load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
806 }
807
808 auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
809 auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
810
811 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
812
814 b_flat_dram_windows(nIter)(kIter),
816 packed_n_rank,
817 kIter * KFlatPerBlockPerIter});
818
819 ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
820 b_warp_tensor_pong(nIter)(kIter) = ub.u;
821 });
822 });
823
824 // Prefill A(2i+1)
825 prefill_lds_a_stage2(a_copy_lds_window_pong);
826
827 // Prefetch A(2i+2)
828 prefill_lds_a_stage1(
829 a_copy_lds_window_ping, a_copy_dram_window, number<PrefillBeforeGemm>{});
830 // GEMM 2i
831 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
832 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
833 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
834 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
835 // read C warp tensor from C block tensor
836 CWarpTensor c_warp_tensor;
837
838 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
839 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
840 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
841
842 if constexpr(mIter == 0)
843 dequant_mxfp4(
844 b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
845 scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
846 kIter / number<XDL_PerScaleK>{}),
847 nIter,
848 kIter);
849
850 // warp GEMM
851 WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
852
853 // write C warp tensor into C block tensor
854 c_block_tile.set_y_sliced_thread_data(
855 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
856 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
857 c_warp_tensor.get_thread_buffer());
858 });
859 // preload next A from lds
860 if constexpr((kIter * MIterPerWarp + mIter) <
862 {
863 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
864 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
865 a_warp_tensor(number<AwarpIter>{}) =
866 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
867 }
868
869 // barrier
870 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
871 {
872 __builtin_amdgcn_s_waitcnt(Bload_total_num);
874 }
875 });
876 });
877 prefill_lds_a_stage1(
878 a_copy_lds_window_ping, a_copy_dram_window, number<PrefillAfterGemm>{});
879
880 // move A window to next k
881 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
882
883 // move B window to next flat K
884 move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
885 move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
886
887 static_for<0, m_preload, 1>{}([&](auto loadIter) {
888 constexpr auto mIter = loadIter % MIterPerWarp;
889 constexpr auto kIter = loadIter / MIterPerWarp;
890 a_warp_tensor(loadIter) =
891 load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
892 });
894
895 // Next K
896
897 // prefetch B(2i+2)
898 static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
899 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
900 if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
901 {
902 auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
903 auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
904
905 scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
906 scale_b_flat_dram_window;
907
908 move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
909 {scale_n_iter * NFlatPerBlockPerIter,
910 scale_k_iter * ScaleKFlatPerWarp});
911
912 scale_b_warp_tensor_ping(scale_n_iter)(scale_k_iter) =
913 load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
914 }
915
916 auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
917 auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
918
919 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
921 b_flat_dram_windows(nIter)(kIter),
923 packed_n_rank,
924 kIter * KFlatPerBlockPerIter});
925
926 ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
927 b_warp_tensor_ping(nIter)(kIter) = ub.u;
928 });
929 });
930
931 // Prefill A(2i+2)
932 prefill_lds_a_stage2(a_copy_lds_window_ping);
933
934 // Prefetch A(2i+3)
935 prefill_lds_a_stage1(
936 a_copy_lds_window_pong, a_copy_dram_window, number<PrefillBeforeGemm>{});
937
938 // GEMM 2i+1
939 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
940 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
941 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
942 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
943 // read C warp tensor from C block tensor
944 CWarpTensor c_warp_tensor;
945 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
946 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
947 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
948
949 if constexpr(mIter == 0)
950 dequant_mxfp4(
951 b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
952 scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
953 kIter / number<XDL_PerScaleK>{}),
954 nIter,
955 kIter);
956
957 // warp GEMM
958 WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
959
960 // write C warp tensor into C block tensor
961 c_block_tile.set_y_sliced_thread_data(
962 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
963 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
964 c_warp_tensor.get_thread_buffer());
965 });
966 // preload next A from lds
967 if constexpr((kIter * MIterPerWarp + mIter) <
969 {
970 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
971 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
972 a_warp_tensor(number<AwarpIter>{}) =
973 load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
974 }
975
976 // barrier
977 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
978 {
979 __builtin_amdgcn_s_waitcnt(Bload_total_num);
981 }
982 });
983 });
984 prefill_lds_a_stage1(
985 a_copy_lds_window_pong, a_copy_dram_window, number<PrefillAfterGemm>{});
986
987 // move A window to next k
988 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
989 // move B window to next flat K
990 move_tile_window(b_flat_dram_window, {0, MXFP4KPerWarp * KFlatPerBlockPerIter});
991 move_tile_window(scale_b_flat_dram_window, {0, ScaleKPerWarp * ScaleKFlatPerWarp});
992
993 static_for<0, m_preload, 1>{}([&](auto loadIter) {
994 constexpr auto mIter = loadIter % MIterPerWarp;
995 constexpr auto kIter = loadIter / MIterPerWarp;
996 a_warp_tensor(loadIter) =
997 load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
998 });
1000
1001 iCounter--;
1002 }
1003
1004 // TAIL
1005 if constexpr(TailNum == TailNumber::Even)
1006 {
1007 // prefetch B(loopK)
1008 const int b_k_off = b_flat_dram_window.get_tile_distribution().calculate_index()[I1] /
1010 static_for<0, MXFP4KPerWarp, 1>{}([&](auto kIter) {
1011 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
1012 if constexpr(nIter % XDL_PerScaleN == 0 && kIter % MXFP4K_PerScaleK == 0)
1013 {
1014 auto scale_n_iter = nIter / number<XDL_PerScaleN>{};
1015 auto scale_k_iter = kIter / number<MXFP4K_PerScaleK>{};
1016
1017 scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter) =
1018 scale_b_flat_dram_window;
1019
1020 move_tile_window(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter),
1021 {scale_n_iter * NFlatPerBlockPerIter,
1022 scale_k_iter * ScaleKFlatPerWarp});
1023
1024 scale_b_warp_tensor_pong(scale_n_iter)(scale_k_iter) =
1025 load_tile(scale_b_flat_dram_windows(scale_n_iter)(scale_k_iter));
1026 }
1027 });
1028
1029 const int b_k_off_inter = kIter * kKPerBlock / MXFP4KPerWarp + b_k_off;
1030 if(b_k_off_inter < kKPerBlock - k_padded_zeros)
1031 {
1032 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
1033 auto packed_n_idx = nIter / number<ContinuousScaleNPerThread>{};
1034 auto packed_n_rank = nIter % number<ContinuousScaleNPerThread>{};
1035
1036 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
1037
1039 b_flat_dram_windows(nIter)(kIter),
1041 packed_n_rank,
1042 kIter * KFlatPerBlockPerIter});
1043
1044 ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
1045 b_warp_tensor_pong(nIter)(kIter) = ub.u;
1046 });
1047 }
1048 });
1049
1050 // Prefill A(loopK)
1051 prefill_lds_a_stage2(a_copy_lds_window_pong);
1052
1053 // GEMM loopK-1
1054 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
1055 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
1056 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
1057 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
1058 // read C warp tensor from C block tensor
1059 CWarpTensor c_warp_tensor;
1060
1061 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
1062 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
1063 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
1064
1065 if constexpr(mIter == 0)
1066 dequant_mxfp4(
1067 b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
1068 scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
1069 kIter / number<XDL_PerScaleK>{}),
1070 nIter,
1071 kIter);
1072
1073 // warp GEMM
1074 WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
1075
1076 // write C warp tensor into C block tensor
1077 c_block_tile.set_y_sliced_thread_data(
1078 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
1079 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
1080 c_warp_tensor.get_thread_buffer());
1081 });
1082 // preload next A from lds
1083 if constexpr((kIter * MIterPerWarp + mIter) <
1085 {
1086 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
1087 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
1088 a_warp_tensor(number<AwarpIter>{}) =
1089 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
1090 }
1091
1092 // barrier
1093 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
1094 {
1095 __builtin_amdgcn_s_waitcnt(Bload_total_num);
1097 }
1098 });
1099 });
1100
1101 static_for<0, m_preload, 1>{}([&](auto loadIter) {
1102 constexpr auto mIter = loadIter % MIterPerWarp;
1103 constexpr auto kIter = loadIter / MIterPerWarp;
1104 a_warp_tensor(loadIter) =
1105 load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
1106 });
1107
1108 __builtin_amdgcn_sched_barrier(0);
1109 // Last2ndHotLoopScheduler();
1110
1111 // GEMM loopK
1112 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
1113 if(kIter * WG::kK < kKPerBlock - k_padded_zeros)
1114 {
1115 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
1116 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
1117 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
1118 // read C warp tensor from C block tensor
1119 CWarpTensor c_warp_tensor;
1120
1121 c_warp_tensor.get_thread_buffer() =
1122 c_block_tile.get_y_sliced_thread_data(
1123 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
1124 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
1125
1126 if constexpr(mIter == 0)
1127 dequant_mxfp4(
1128 b_warp_tensor_pong(nIter)(kIter / number<XDL_PerWeightK>{}),
1129 scale_b_warp_tensor_pong(nIter / number<XDL_PerScaleN>{})(
1130 kIter / number<XDL_PerScaleK>{}),
1131 nIter,
1132 kIter);
1133
1134 // warp GEMM
1135 WG{}(c_warp_tensor,
1136 a_warp_tensor(number<AwarpIter>{}),
1137 dequant_B_n[nIter]);
1138
1139 // write C warp tensor into C block tensor
1140 c_block_tile.set_y_sliced_thread_data(
1141 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
1142 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
1143 c_warp_tensor.get_thread_buffer());
1144 });
1145 if constexpr((kIter * MIterPerWarp + mIter) <
1147 {
1148 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
1149 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
1150 a_warp_tensor(number<AwarpIter>{}) =
1151 load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
1152 }
1153 // barrier
1154 // if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
1155 // {
1156 // block_sync_lds();
1157 // }
1158 });
1159 }
1160 });
1162 }
1163 else if constexpr(TailNum == TailNumber::Odd)
1164 {
1165 // GEMM loopK
1166 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
1167 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
1168 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
1169 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
1170 // read C warp tensor from C block tensor
1171 CWarpTensor c_warp_tensor;
1172
1173 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
1174 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
1175 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
1176
1177 if constexpr(mIter == 0)
1178 dequant_mxfp4(
1179 b_warp_tensor_ping(nIter)(kIter / number<XDL_PerWeightK>{}),
1180 scale_b_warp_tensor_ping(nIter / number<XDL_PerScaleN>{})(
1181 kIter / number<XDL_PerScaleK>{}),
1182 nIter,
1183 kIter);
1184 // warp GEMM
1185 WG{}(c_warp_tensor, a_warp_tensor(number<AwarpIter>{}), dequant_B_n[nIter]);
1186
1187 // write C warp tensor into C block tensor
1188 c_block_tile.set_y_sliced_thread_data(
1189 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
1190 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
1191 c_warp_tensor.get_thread_buffer());
1192 });
1193 // preload next A from lds
1194 if constexpr((kIter * MIterPerWarp + mIter) <
1196 {
1197 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
1198 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
1199 a_warp_tensor(number<AwarpIter>{}) =
1200 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
1201 }
1202
1203 // barrier
1204 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
1205 {
1206 __builtin_amdgcn_s_waitcnt(Bload_total_num);
1208 }
1209 });
1210 });
1212 }
1213
1214 return c_block_tile;
1215 }
1216
1217 template <typename ADramBlockWindowTmp,
1218 typename BFlatBlockWindowTmp,
1219 typename DequantBFlatWindow>
1220 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
1221 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
1222 const DequantBFlatWindow& scale_b_flat_window,
1223 const index_t num_loop,
1224 const index_t k_padded_zeros,
1225 void* p_smem_ping,
1226 void* p_smem_pong) const
1227 {
1228 return operator()(a_dram_block_window_tmp,
1229 identity{},
1230 b_flat_dram_block_window_tmp,
1231 scale_b_flat_window,
1232 num_loop,
1233 k_padded_zeros,
1234 p_smem_ping,
1235 p_smem_pong);
1236 }
1237
1238 template <typename ADramBlockWindowTmp,
1239 typename BFlatBlockWindowTmp,
1240 typename DequantBFlatWindow>
1241 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
1242 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
1243 const DequantBFlatWindow& scale_b_flat_window,
1244 const index_t num_loop,
1245 void* p_smem_ping,
1246 void* p_smem_pong) const
1247 {
1248 return operator()(a_dram_block_window_tmp,
1249 identity{},
1250 b_flat_dram_block_window_tmp,
1251 scale_b_flat_window,
1252 num_loop,
1253 0,
1254 p_smem_ping,
1255 p_smem_pong);
1256 }
1257};
1258
1259} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
_Float16 fp16x2_t
Definition half.hpp:385
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:119
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t &x, float scale)
Definition pk_fp4.hpp:354
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
@ Full
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:39
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t &x, float scale)
Definition pk_fp4.hpp:358
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition tile_scatter_gather.hpp:1043
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
bfloat16_t bf16x2_t
Definition pk_fp4.hpp:24
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
unsigned int uint32_t
Definition stdint.h:126
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:50
static constexpr index_t MPerBlockPerIter
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:118
static constexpr auto I1
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:99
static constexpr int ContinuousScaleNPerThread
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:129
remove_cvref_t< typename Problem::QuantType > BDataType
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:54
static constexpr index_t GetVectorSizeA()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:86
ADataType ComputeType
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:58
static constexpr int ContinuousScaleKPerThread
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:130
static constexpr bool kPadN
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:91
static constexpr index_t KPerScaleLoad
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:175
static constexpr index_t BlockSize
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:76
static constexpr int ContinuousKPerThread
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:128
static constexpr index_t DsReadPreload
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:74
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, const DequantBFlatWindow &scale_b_flat_window, const index_t num_loop, const index_t k_padded_zeros, void *p_smem_ping, void *p_smem_pong) const
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:1220
static constexpr index_t MIterPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:111
static constexpr index_t KIterPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:113
remove_cvref_t< typename Problem::CDataType > CDataType
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:55
static CK_TILE_HOST_DEVICE constexpr auto GetADramTileDistribution()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:437
remove_cvref_t< typename Problem::BLayout > BLayout
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:62
remove_cvref_t< typename BlockGemmShape::WarpTile > WarpTile
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:106
static constexpr auto idxK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:103
static constexpr index_t NIterPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:112
static constexpr index_t KFlatPerBlockPerIter
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:115
static constexpr index_t AK1
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:122
static constexpr int ScaleKPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:146
static constexpr int XDLK_PerThread
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:135
static constexpr index_t mfma_per_wg
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:157
static constexpr index_t NFlatPerBlockPerIter
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:116
static constexpr bool UsePersistentKernel
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:96
static constexpr index_t mfma_perM_perK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:179
static CK_TILE_HOST_DEVICE constexpr auto HotLoopScheduler()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:274
static constexpr int ScaleNPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:147
remove_cvref_t< decltype(PipelinePolicy::template GetBlockFlatmm< Problem >())> BlockFlatmm
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:65
static constexpr index_t GetVectorSizeB()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:87
static constexpr auto config
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:68
static constexpr int MXFP4K_PerScaleK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:149
static constexpr index_t dsread_num_perK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:163
static constexpr index_t Bload_rep
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:177
static constexpr index_t dswrite_num_perK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:164
static constexpr bool kPadK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:92
static constexpr index_t Bload_total_num
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:173
remove_cvref_t< typename Problem::ALayout > ALayout
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:61
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window_, const AElementFunction &a_element_func, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, const DequantBFlatWindow &scale_b_flat_window, const index_t num_loop, const index_t k_padded_zeros, void *p_smem_ping, void *p_smem_pong) const
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:446
static constexpr index_t kNPerBlock
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:80
static constexpr index_t WaveSize
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:77
static constexpr index_t dsread_per_wg
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:159
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:56
static constexpr index_t DsWritePreIssue
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:73
static constexpr index_t kMPerBlock
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:79
static constexpr index_t KPerBlockPerIter
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:119
remove_cvref_t< typename BlockGemmShape::BlockWarps > BlockWarps
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:105
static constexpr auto I0
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:98
static constexpr index_t NumWaveGroups
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:95
static constexpr int XDL_PerWeightK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:138
static constexpr index_t ScaleBload_K1
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:169
static constexpr int MXFP4KPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:145
static CK_TILE_HOST_DEVICE constexpr auto Last2ndHotLoopScheduler()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:391
static CK_TILE_HOST_DEVICE constexpr auto LastHotLoopScheduler()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:417
static constexpr int XDL_PerScaleK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:139
remove_cvref_t< decltype(config.template at< 0 >())> WG
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:71
static CK_TILE_HOST_DEVICE constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:187
static constexpr auto idxN
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:102
static constexpr int MXFP4PackedSize
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:121
static constexpr int ScaleKFlatPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:132
FlatmmPipelineAGmemBGmemCRegV1< Problem, PipelinePolicy > Underlying
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:51
static constexpr index_t dswrite_kIter
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:181
static constexpr bool HasHotLoop
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:151
static constexpr index_t MWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:108
static constexpr int XDL_PerScaleN
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:140
static constexpr index_t NWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:109
static constexpr auto I2
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:100
remove_cvref_t< typename Problem::CLayout > CLayout
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:63
remove_cvref_t< typename BlockGemmShape::BlockTile > BlockTile
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:104
static constexpr index_t ScaleBload_num
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:170
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, const DequantBFlatWindow &scale_b_flat_window, const index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:1241
static constexpr index_t GetVectorSizeC()
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:88
static constexpr index_t BK1
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:123
static constexpr index_t dswrite_rep
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:165
remove_cvref_t< typename Problem::ADataType > ADataType
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:53
static constexpr index_t Aload_rep
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:167
static constexpr index_t Bload_num_perK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:168
static constexpr index_t kLdsAlignmentInBytes
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:94
static constexpr index_t m_preload
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:124
static constexpr index_t dswrite_mIter
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:180
static constexpr index_t flatKPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:83
static constexpr index_t HalfMIter
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:176
static constexpr index_t Aload_num_perK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:166
static constexpr bool DoubleSmemBuffer
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:184
static constexpr auto idxM
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:101
static constexpr bool kPadM
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:90
static constexpr index_t kKPerBlock
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:81
static constexpr index_t flatNPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:84
static constexpr auto TailNum
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:152
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:32
BlockGemmShape_ BlockGemmShape
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:33
BDataType_ QuantType
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:35
static constexpr index_t flatKPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:44
static constexpr int ContinuousScaleKPerThread
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:43
static constexpr int ContinuousKPerThread
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:41
static constexpr index_t flatNPerWarp
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:37
static constexpr int ContinuousScaleNPerThread
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:42
static constexpr int MXF4ScaleGranularityK
Definition mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp:39
Definition flatmm_pipeline_agmem_bgmem_creg_v1.hpp:47
Definition gemm_pipeline_problem.hpp:323
Definition tile/core/utility/functional.hpp:86
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile/core/utility/debug.hpp:67