blockwise_gemm_pipeline_xdlops_v1.hpp Source File

blockwise_gemm_pipeline_xdlops_v1.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_v1.hpp Source File
blockwise_gemm_pipeline_xdlops_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Naive pipeline with lowest resource request per WGP
11// GlobalPrefetchStages: 1
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 0
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::I0;
122 using Base::KRepeat;
123 using Base::xdlops_gemm;
124
136
139
140 using Base::AMmaKStride;
141 using Base::BMmaKStride;
142
144
145 static constexpr index_t PrefetchStages = 1;
146 static constexpr index_t PrefillStages = 1;
147 static constexpr index_t GlobalBufferNum = 1;
148
// Reports whether `num_loop` exceeds PrefetchStages (declared as 1 just above in
// this struct), i.e. whether any iterations remain beyond the prefetched stage.
// NOTE(review): presumably the caller uses this result to choose the
// `HasMainLoop` template argument of Run() -- confirm against the call site.
149 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
150 {
151 return num_loop > PrefetchStages;
152 }
153
// Returns the tail variant for a given loop count. This pipeline has a single
// tail shape: the argument is deliberately discarded and TailNumber::Full is
// always returned.
154 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
155 {
// Assigning to `ignore` marks the parameter as intentionally unused.
156 ignore = num_loop;
157 return TailNumber::Full;
158 }
159
// Runs the block-level GEMM K-loop for this pipeline variant.
//
// Sequence visible in this listing:
//   1. one global read of A and B, source windows advanced by one copy step,
//      the data written into the block buffers, and the C accumulator cleared;
//   2. if HasMainLoop: a do/while hot loop that issues the next global read,
//      copies block-buffer data into per-thread buffers, performs the xdlops
//      multiply-accumulate, then writes the newly read tile into the block
//      buffers;
//   3. if TailNum == TailNumber::Full: one last copy + multiply-accumulate pass
//      on the data already in the block buffers (no further global read).
//
// Template parameters:
//   HasMainLoop - compile-time switch enabling the hot loop.
//   TailNum     - compile-time tail selector; only TailNumber::Full has a body.
// Parameters:
//   a_/b_grid_desc, a_/b_grid_buf   - descriptor and buffer read by RunRead().
//   a_/b_block_desc, a_/b_block_buf - descriptor and buffer written by RunWrite()
//                                     and read when filling the thread buffers.
//   a_/b_blockwise_copy             - copy objects providing RunRead(),
//                                     MoveSrcSliceWindow() and RunWrite().
//   a_/b_block_copy_step            - step passed to MoveSrcSliceWindow().
//   c_thread_buf                    - accumulator; cleared here, then updated
//                                     by xdlops_gemm.Run().
//   num_loop                        - loop count; the hot loop repeats while
//                                     i < num_loop - 1.
//
// NOTE(review): this text comes from an HTML source listing in which several
// lines are absent (gaps in the listing numbers). Comments below mark those
// gaps; confirm the missing statements against the original header.
160 template <bool HasMainLoop,
161 TailNumber TailNum,
162 typename AGridDesc,
163 typename ABlockDesc,
164 typename ABlockTransfer,
165 typename AGridBuffer,
166 typename ABlockBuffer,
167 typename ABlockTransferStep,
168 typename BGridDesc,
169 typename BBlockDesc,
170 typename BBlockTransfer,
171 typename BGridBuffer,
172 typename BBlockBuffer,
173 typename BBlockTransferStep,
174 typename CThreadBuffer>
175 __device__ void Run(const AGridDesc& a_grid_desc,
176 const ABlockDesc& a_block_desc,
177 ABlockTransfer& a_blockwise_copy,
178 const AGridBuffer& a_grid_buf,
179 ABlockBuffer& a_block_buf,
180 const ABlockTransferStep& a_block_copy_step,
181 const BGridDesc& b_grid_desc,
182 const BBlockDesc& b_block_desc,
183 BBlockTransfer& b_blockwise_copy,
184 const BGridBuffer& b_grid_buf,
185 BBlockBuffer& b_block_buf,
186 const BBlockTransferStep& b_block_copy_step,
187 CThreadBuffer& c_thread_buf,
188 index_t num_loop) const
189 {
// NOTE(review): listing lines 190 and 192 are absent; the two lines below are
// the tails of the statements that presumably create a_thread_buf and
// b_thread_buf, sized by the thread descriptors' element space -- confirm.
191 a_thread_desc_.GetElementSpaceSize());
193 b_thread_desc_.GetElementSpaceSize());
194
195 // Global prefetch 1
196 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
197 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
198
// Advance both source windows so the next RunRead() fetches the following tile.
199 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
200 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
201
202 // Local prefill 1
203 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
204 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
205
206 // Initialize C
207 c_thread_buf.Clear();
208
209 // main body
210 if constexpr(HasMainLoop)
211 {
// do/while: the body always executes at least once, then repeats while
// i < num_loop - 1.
212 index_t i = 0;
213 do
214 {
215 // -------------------------------------------------------------------------------------------
// Issue the global read for the next tile before computing on the current one.
216 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
217 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
218
219 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
220 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
221
// NOTE(review): listing line 222 is absent here -- confirm what it contains.
// Fill the thread buffers from the block buffers. The heads of the two copy
// calls (listing lines 225-226/228 and 232-233/235) are absent; only their
// trailing arguments remain. As written, the B-side statement sits inside the
// m0 loop, so it is visited for every (k, m0, n0) combination.
223 static_for<0, KRepeat, 1>{}([&](auto k) {
224 static_for<0, MRepeat, 1>{}([&](auto m0) {
227 a_block_buf,
229 make_tuple(m0, I0, k, I0),
230 a_thread_buf);
231 static_for<0, NRepeat, 1>{}([&](auto n0) {
234 b_block_buf,
236 make_tuple(n0, I0, k, I0),
237 b_thread_buf);
238 });
239 });
240 });
241
// Multiply-accumulate over every (k0, m0, n0) combination.
242 static_for<0, KRepeat, 1>{}([&](auto k0) {
243 static_for<0, MRepeat, 1>{}([&](auto m0) {
244 static_for<0, NRepeat, 1>{}([&](auto n0) {
// NOTE(review): listing lines 245-246 are absent; presumably they declare
// a_thread_vec and b_thread_vec used below -- confirm.
247
// Gather KPack consecutive elements of A and B from the thread buffers
// into the operand vectors.
248 static_for<0, KPack, 1>{}([&](auto ik) {
249 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
250 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
251 make_tuple(m0, I0, k0, ik))>{}];
252 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
253 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
254 make_tuple(n0, I0, k0, ik))>{}];
255 });
256
// NOTE(review): listing line 258 is absent; the tail pass of this function
// spells the same alias in full as
// vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type.
257 using mfma_input_type =
259 xdlops_gemm.K1PerXdlops>::type;
260
// Offset of the (m0, n0) accumulator tile inside c_thread_buf.
261 constexpr index_t c_offset =
262 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
263
264 xdlops_gemm.Run(
265 a_thread_vec.template AsType<mfma_input_type>(),
266 b_thread_vec.template AsType<mfma_input_type>(),
267 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
268 });
269 });
270 });
271
// NOTE(review): listing line 272 is absent here -- confirm what it contains.
// Store the tile fetched at the top of this iteration into the block buffers.
273 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
274 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
275
276 i += 1;
277 } while(i < (num_loop - 1));
278 }
279
280 // tail
// Final pass: same thread-buffer fill and multiply-accumulate as in the hot
// loop, without a further global read or block-buffer write.
281 if constexpr(TailNum == TailNumber::Full)
282 {
// NOTE(review): listing line 283 is absent here, as are the heads of the two
// copy calls below (286-287/289 and 293-294/296).
284 static_for<0, KRepeat, 1>{}([&](auto k) {
285 static_for<0, MRepeat, 1>{}([&](auto m0) {
288 a_block_buf,
290 make_tuple(m0, I0, k, I0),
291 a_thread_buf);
292 static_for<0, NRepeat, 1>{}([&](auto n0) {
295 b_block_buf,
297 make_tuple(n0, I0, k, I0),
298 b_thread_buf);
299 });
300 });
301 });
302
303 static_for<0, KRepeat, 1>{}([&](auto k0) {
304 static_for<0, MRepeat, 1>{}([&](auto m0) {
305 static_for<0, NRepeat, 1>{}([&](auto n0) {
// NOTE(review): listing lines 306-307 (presumably the operand vector
// declarations) are absent -- confirm.
308
309 static_for<0, KPack, 1>{}([&](auto ik) {
310 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
311 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
312 make_tuple(m0, I0, k0, ik))>{}];
313 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
314 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
315 make_tuple(n0, I0, k0, ik))>{}];
316 });
317
318 using mfma_input_type =
319 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
320
321 constexpr index_t c_offset =
322 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
323
324 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
325 b_thread_vec.template AsType<mfma_input_type>(),
326 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
327 });
328 });
329 });
330 }
331 }
332
333 protected:
334 using Base::a_thread_copy_;
335 using Base::a_thread_desc_;
336 using Base::b_thread_copy_;
337 using Base::b_thread_desc_;
338 using Base::c_thread_desc_;
339};
340
341template <index_t BlockSize,
342 typename ADataType,
343 typename BDataType,
344 typename ComputeDataType,
345 typename AccDataType,
346 typename ATileDesc,
347 typename BTileDesc,
348 typename AMmaTileDesc,
349 typename BMmaTileDesc,
350 index_t ABlockTransferSrcScalarPerVector,
351 index_t BBlockTransferSrcScalarPerVector,
352 index_t MPerBlock,
353 index_t NPerBlock,
354 index_t KPerBlock,
355 index_t MPerXDL,
356 index_t NPerXDL,
357 index_t MRepeat,
358 index_t NRepeat,
359 index_t KPack
360 // ,bool TransposeC //disable transposec right now...
361 >
363 BlockSize,
364 ADataType,
365 BDataType,
366 ComputeDataType,
367 AccDataType,
368 ATileDesc,
369 BTileDesc,
370 AMmaTileDesc,
371 BMmaTileDesc,
372 ABlockTransferSrcScalarPerVector,
373 BBlockTransferSrcScalarPerVector,
374 MPerBlock,
375 NPerBlock,
376 KPerBlock,
377 MPerXDL,
378 NPerXDL,
379 MRepeat,
380 NRepeat,
381 KPack>
383 ADataType,
384 BDataType,
385 ComputeDataType,
386 AccDataType,
387 ATileDesc,
388 BTileDesc,
389 AMmaTileDesc,
390 BMmaTileDesc,
391 ABlockTransferSrcScalarPerVector,
392 BBlockTransferSrcScalarPerVector,
393 MPerBlock,
394 NPerBlock,
395 KPerBlock,
396 MPerXDL,
397 NPerXDL,
398 MRepeat,
399 NRepeat,
400 KPack>
401
402{
404 ADataType,
405 BDataType,
406 ComputeDataType,
407 AccDataType,
408 ATileDesc,
409 BTileDesc,
410 AMmaTileDesc,
411 BMmaTileDesc,
412 ABlockTransferSrcScalarPerVector,
413 BBlockTransferSrcScalarPerVector,
414 MPerBlock,
415 NPerBlock,
416 KPerBlock,
417 MPerXDL,
418 NPerXDL,
419 MRepeat,
420 NRepeat,
421 KPack>;
422 using Base::A_K1;
423 using Base::B_K1;
424 using Base::I0;
425 using Base::I1;
426 using Base::KPerThread;
427 using Base::xdlops_gemm;
428
440
443
445
449 static constexpr index_t PrefetchStages = 1;
450 static constexpr index_t PrefillStages = 1;
451 static constexpr index_t GlobalBufferNum = 1;
// Reports whether `num_loop` exceeds PrefetchStages (declared as 1 just above in
// this struct), i.e. whether any iterations remain beyond the prefetched stage.
// NOTE(review): presumably the caller uses this result to choose the
// `HasMainLoop` template argument of Run() -- confirm against the call site.
452 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
453 {
454 return num_loop > PrefetchStages;
455 }
456
// Returns the tail variant for a given loop count. This pipeline has a single
// tail shape: the argument is deliberately discarded and TailNumber::Full is
// always returned.
457 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
458 {
// Assigning to `ignore` marks the parameter as intentionally unused.
459 ignore = num_loop;
460 return TailNumber::Full;
461 }
462
// Runs the block-level GEMM K-loop for this pipeline variant, with explicit
// workgroup barriers, scheduling barriers and wave-priority changes placed
// around each multiply-accumulate ("MAC") cluster.
//
// Sequence visible in this listing:
//   1. one global read of A and B, source windows advanced by one copy step,
//      the data written into the block buffers, and the C accumulator cleared;
//   2. if HasMainLoop: a do/while hot loop. Each iteration issues the next
//      global read, then for every k0 in [0, KRepeat) fills the thread buffers
//      and runs one MAC cluster, and finally writes the newly read tile into
//      the block buffers;
//   3. if TailNum == TailNumber::Full: one last pass over k0 with the same
//      fill + MAC structure and no further global read.
//
// Unlike a pipeline that fills all thread buffers before any MAC, here the
// thread-buffer fill and the MAC cluster for a given k0 sit inside the same
// k0 iteration.
//
// Template parameters:
//   HasMainLoop - compile-time switch enabling the hot loop.
//   TailNum     - compile-time tail selector; only TailNumber::Full has a body.
// Parameters:
//   a_/b_grid_desc, a_/b_grid_buf   - descriptor and buffer read by RunRead().
//   a_/b_block_desc, a_/b_block_buf - descriptor and buffer written by RunWrite()
//                                     and read when filling the thread buffers.
//   a_/b_blockwise_copy             - copy objects providing RunRead(),
//                                     MoveSrcSliceWindow() and RunWrite().
//   a_/b_block_copy_step            - step passed to MoveSrcSliceWindow().
//   c_thread_buf                    - accumulator; cleared here, then updated
//                                     by xdlops_gemm.Run().
//   num_loop                        - loop count; the hot loop repeats while
//                                     i < num_loop - 1.
//
// NOTE(review): this text comes from an HTML source listing in which several
// lines are absent (gaps in the listing numbers). Comments below mark those
// gaps; confirm the missing statements against the original header.
463 template <bool HasMainLoop,
464 TailNumber TailNum,
465 typename AGridDesc,
466 typename ABlockDesc,
467 typename ABlockTransfer,
468 typename AGridBuffer,
469 typename ABlockBuffer,
470 typename ABlockTransferStep,
471 typename BGridDesc,
472 typename BBlockDesc,
473 typename BBlockTransfer,
474 typename BGridBuffer,
475 typename BBlockBuffer,
476 typename BBlockTransferStep,
477 typename CThreadBuffer>
478 __device__ void Run(const AGridDesc& a_grid_desc,
479 const ABlockDesc& a_block_desc,
480 ABlockTransfer& a_blockwise_copy,
481 const AGridBuffer& a_grid_buf,
482 ABlockBuffer& a_block_buf,
483 const ABlockTransferStep& a_block_copy_step,
484 const BGridDesc& b_grid_desc,
485 const BBlockDesc& b_block_desc,
486 BBlockTransfer& b_blockwise_copy,
487 const BGridBuffer& b_grid_buf,
488 BBlockBuffer& b_block_buf,
489 const BBlockTransferStep& b_block_copy_step,
490 CThreadBuffer& c_thread_buf,
491 index_t num_loop) const
492 {
// NOTE(review): listing lines 493 and 495 are absent; the two lines below are
// the tails of the statements that presumably create a_thread_buf and
// b_thread_buf, sized by the thread descriptors' element space -- confirm.
494 a_thread_desc_.GetElementSpaceSize());
496 b_thread_desc_.GetElementSpaceSize());
497
498 // Global prefetch 1
499 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
500 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
501
// Advance both source windows so the next RunRead() fetches the following tile.
502 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
503 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
504
505 // Local prefill 1
506 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
507 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
508
509 // Initialize C
510 c_thread_buf.Clear();
511
512 // main body
513 if constexpr(HasMainLoop)
514 {
// do/while: the body always executes at least once, then repeats while
// i < num_loop - 1.
515 index_t i = 0;
516 do
517 {
518 // -------------------------------------------------------------------------------------------
// Issue the global read for the next tile before computing on the current one.
519 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
520 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
521
522 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
523 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
524
// NOTE(review): listing line 525 is absent here -- confirm what it contains.
// One k0 iteration = thread-buffer fill followed by one MAC cluster.
526 static_for<0, KRepeat, 1>{}([&](auto k0) {
// The heads of the two copy calls (listing lines 528-529/531 and 535-536/538)
// are absent; only their trailing arguments remain. As written, the B-side
// statement sits inside the m0 loop.
527 static_for<0, MRepeat, 1>{}([&](auto m0) {
530 a_block_buf,
532 make_tuple(m0, I0, k0, I0),
533 a_thread_buf);
534 static_for<0, NRepeat, 1>{}([&](auto n0) {
537 b_block_buf,
539 make_tuple(n0, I0, k0, I0),
540 b_thread_buf);
541 });
542 });
543 __builtin_amdgcn_sched_barrier(0);
544 // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster
545 // except the first, as we can shorten the non-MAC cluster a bit and there's no
546 // observable negative impact. The desired effect is waves in a workgroup
547 // executing MAC in sync. This avoids some out-of-sync waves hijacking MAC
548 // resource from other workgroups and reducing the chance of latency hiding by
549 // waiting for the rest of the workgroup at the eventual sync point.
// The `KRepeat == 1` term keeps the barrier when there is only a single cluster.
550 if constexpr(k0.value != 0 || KRepeat == 1)
551 {
552 __builtin_amdgcn_s_barrier();
553 __builtin_amdgcn_sched_barrier(0);
554 }
// NOTE(review): listing line 555 is absent. `k_` is used below and three
// closing `});` follow the MAC body, so it presumably opens a static_for over
// `k_` (compared below against KPerInnerLoop - KPack) -- confirm.
556 static_for<0, MRepeat, 1>{}([&](auto m0) {
557 static_for<0, NRepeat, 1>{}([&](auto n0) {
// NOTE(review): listing lines 558-559 are absent; presumably they declare
// a_thread_vec and b_thread_vec used below -- confirm.
560
// Gather KPack elements starting at offset k_ within the k0 slice.
561 static_for<0, KPack, 1>{}([&](auto ik) {
562 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
563 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
564 make_tuple(m0, I0, k0, k_ + ik))>{}];
565 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
566 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
567 make_tuple(n0, I0, k0, k_ + ik))>{}];
568 });
569
// NOTE(review): listing line 571 (middle of this alias) is absent.
570 using mfma_input_type =
572 xdlops_gemm.K1PerXdlops>::type;
573
// Offset of the (m0, n0) accumulator tile inside c_thread_buf.
574 constexpr index_t c_offset =
575 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
576
577 // The block_sync_lds() here performs double duty:
578 // A) safeguard against data hazard, because the barrier from
579 // blockwise_gemm is moved here;
580 // B) reduce VMEM FIFO congestion by applying small delays to
581 // different wavefronts.
// It is performed near the end of the MAC cluster to minimize lgkmcnt penalty.
// The condition below selects the very last MAC of the last cluster.
582 if constexpr(k0.value == KRepeat - 1 &&
583 k_.value == KPerInnerLoop - KPack &&
584 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
585 {
586 __builtin_amdgcn_sched_barrier(0);
// NOTE(review): listing line 587 is absent; it presumably holds the
// block_sync_lds() call described in the comment above -- confirm.
588 __builtin_amdgcn_sched_barrier(0);
589 }
590 xdlops_gemm.Run(
591 a_thread_vec.template AsType<mfma_input_type>(),
592 b_thread_vec.template AsType<mfma_input_type>(),
593 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// After the first MAC of a cluster, set the wave priority value to 1.
594 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
595 {
596 __builtin_amdgcn_sched_barrier(0);
597 __builtin_amdgcn_s_setprio(1);
598 __builtin_amdgcn_sched_barrier(0);
599 }
600 });
601 });
602 });
// End of the MAC cluster: set the wave priority value back to 0.
603 __builtin_amdgcn_sched_barrier(0);
604 __builtin_amdgcn_s_setprio(0);
605 __builtin_amdgcn_sched_barrier(0);
606 });
607
608 // block_sync_lds();
// Store the tile fetched at the top of this iteration into the block buffers.
609 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
610 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
611
612 i += 1;
613 } while(i < (num_loop - 1));
614 }
615
616 // tail
// Final pass: same per-k0 fill + MAC cluster structure as the hot loop,
// without a further global read or block-buffer write.
617 if constexpr(TailNum == TailNumber::Full)
618 {
// NOTE(review): listing line 619 is absent here, as are the heads of the two
// copy calls below (622-623/625 and 629-630/632).
620 static_for<0, KRepeat, 1>{}([&](auto k0) {
621 static_for<0, MRepeat, 1>{}([&](auto m0) {
624 a_block_buf,
626 make_tuple(m0, I0, k0, I0),
627 a_thread_buf);
628 static_for<0, NRepeat, 1>{}([&](auto n0) {
631 b_block_buf,
633 make_tuple(n0, I0, k0, I0),
634 b_thread_buf);
635 });
636 });
637
// Workgroup barrier before every MAC cluster except the first (kept when
// KRepeat == 1), as in the hot loop.
638 __builtin_amdgcn_sched_barrier(0);
639 if constexpr(k0.value != 0 || KRepeat == 1)
640 {
641 __builtin_amdgcn_s_barrier();
642 __builtin_amdgcn_sched_barrier(0);
643 }
// NOTE(review): listing line 644 is absent; presumably it opens the static_for
// over `k_` used below -- confirm.
645 static_for<0, MRepeat, 1>{}([&](auto m0) {
646 static_for<0, NRepeat, 1>{}([&](auto n0) {
// NOTE(review): listing lines 647-648 (presumably the operand vector
// declarations) are absent -- confirm.
649
650 static_for<0, KPack, 1>{}([&](auto ik) {
651 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
652 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
653 make_tuple(m0, I0, k0, k_ + ik))>{}];
654 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
655 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
656 make_tuple(n0, I0, k0, k_ + ik))>{}];
657 });
658
// NOTE(review): listing line 660 (middle of this alias) is absent.
659 using mfma_input_type =
661 xdlops_gemm.K1PerXdlops>::type;
662
663 constexpr index_t c_offset =
664 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
665
// Selects the very last MAC of the last cluster.
666 if constexpr(k0.value == KRepeat - 1 &&
667 k_.value == KPerInnerLoop - KPack &&
668 m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
669 {
670 __builtin_amdgcn_sched_barrier(0);
// NOTE(review): listing line 671 is absent here -- confirm what it contains.
672 __builtin_amdgcn_sched_barrier(0);
673 }
674 xdlops_gemm.Run(
675 a_thread_vec.template AsType<mfma_input_type>(),
676 b_thread_vec.template AsType<mfma_input_type>(),
677 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
// After the first MAC of a cluster, set the wave priority value to 1.
678 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
679 {
680 __builtin_amdgcn_sched_barrier(0);
681 __builtin_amdgcn_s_setprio(1);
682 __builtin_amdgcn_sched_barrier(0);
683 }
684 });
685 });
686 });
// End of the MAC cluster: set the wave priority value back to 0.
687 __builtin_amdgcn_sched_barrier(0);
688 __builtin_amdgcn_s_setprio(0);
689 __builtin_amdgcn_sched_barrier(0);
690 });
691 }
692 }
693
694 protected:
695 // K->M loopover
701 I1));
702
708 I1));
709
712 decltype(a_block_desc_m0_m1_m2_k),
713 decltype(a_thread_desc_),
716 3,
717 A_K1,
718 A_K1>;
719
722 decltype(b_block_desc_n0_n1_n2_k),
723 decltype(b_thread_desc_),
726 3,
727 B_K1,
728 B_K1>;
729
732 using Base::c_thread_desc_;
733};
734
735// Naive pipeline with lowest resource request per WGP
736// Implementation with direct load
737// GlobalPrefetchStages: 1
738// LocalPreFillStages: 1
739// LocalPreFetchStages: 0
740// LocalSharedMemoryBuffer: 1
741
742template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
743 index_t BlockSize,
744 typename ADataType,
745 typename BDataType,
746 typename ComputeDataType,
747 typename AccDataType,
748 typename ATileDesc,
749 typename BTileDesc,
750 typename AMmaTileDesc,
751 typename BMmaTileDesc,
752 index_t ABlockTransferSrcScalarPerVector,
753 index_t BBlockTransferSrcScalarPerVector,
754 index_t MPerBlock,
755 index_t NPerBlock,
756 index_t KPerBlock,
757 index_t MPerXDL,
758 index_t NPerXDL,
759 index_t MRepeat,
760 index_t NRepeat,
761 index_t KPacks>
765
766template <index_t BlockSize,
767 typename ADataType,
768 typename BDataType,
769 typename ComputeDataType,
770 typename AccDataType,
771 typename ATileDesc,
772 typename BTileDesc,
773 typename AMmaTileDesc,
774 typename BMmaTileDesc,
775 index_t ABlockTransferSrcScalarPerVector,
776 index_t BBlockTransferSrcScalarPerVector,
777 index_t MPerBlock,
778 index_t NPerBlock,
779 index_t KPerBlock,
780 index_t MPerXDL,
781 index_t NPerXDL,
782 index_t MRepeat,
783 index_t NRepeat,
784 index_t KPack
785 // ,bool TransposeC //disable transposec right now...
786 >
788 BlockSize,
789 ADataType,
790 BDataType,
791 ComputeDataType,
792 AccDataType,
793 ATileDesc,
794 BTileDesc,
795 AMmaTileDesc,
796 BMmaTileDesc,
797 ABlockTransferSrcScalarPerVector,
798 BBlockTransferSrcScalarPerVector,
799 MPerBlock,
800 NPerBlock,
801 KPerBlock,
802 MPerXDL,
803 NPerXDL,
804 MRepeat,
805 NRepeat,
806 KPack>
808 ADataType,
809 BDataType,
810 ComputeDataType,
811 AccDataType,
812 ATileDesc,
813 BTileDesc,
814 AMmaTileDesc,
815 BMmaTileDesc,
816 ABlockTransferSrcScalarPerVector,
817 BBlockTransferSrcScalarPerVector,
818 MPerBlock,
819 NPerBlock,
820 KPerBlock,
821 MPerXDL,
822 NPerXDL,
823 MRepeat,
824 NRepeat,
825 KPack>
826
827{
829 ADataType,
830 BDataType,
831 ComputeDataType,
832 AccDataType,
833 ATileDesc,
834 BTileDesc,
835 AMmaTileDesc,
836 BMmaTileDesc,
837 ABlockTransferSrcScalarPerVector,
838 BBlockTransferSrcScalarPerVector,
839 MPerBlock,
840 NPerBlock,
841 KPerBlock,
842 MPerXDL,
843 NPerXDL,
844 MRepeat,
845 NRepeat,
846 KPack>;
847 using Base::I0;
848 using Base::KRepeat;
849 using Base::xdlops_gemm;
850
862
865
866 using Base::AMmaKStride;
867 using Base::BMmaKStride;
868
870
871 static constexpr index_t PrefetchStages = 1;
872 static constexpr index_t PrefillStages = 1;
873 static constexpr index_t GlobalBufferNum = 1;
874
// Reports whether `num_loop` exceeds PrefetchStages (declared as 1 just above in
// this struct), i.e. whether any iterations remain beyond the prefetched stage.
// NOTE(review): presumably the caller uses this result to choose the
// `HasMainLoop` template argument of Run() -- confirm against the call site.
875 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
876 {
877 return num_loop > PrefetchStages;
878 }
879
// Returns the tail variant for a given loop count. This pipeline has a single
// tail shape: the argument is deliberately discarded and TailNumber::Full is
// always returned.
880 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
881 {
// Assigning to `ignore` marks the parameter as intentionally unused.
882 ignore = num_loop;
883 return TailNumber::Full;
884 }
885
// Runs the block-level GEMM K-loop for the direct-load variant of this
// pipeline: the blockwise copy objects are driven through a single Run() call
// that takes both the grid-side and the block-side descriptor/buffer, instead
// of separate RunRead()/RunWrite() calls.
//
// Sequence visible in this listing:
//   1. one blockwise copy Run() for A and B, source windows advanced by one
//      copy step, and the C accumulator cleared;
//   2. if HasMainLoop: a do/while hot loop that first fills the thread buffers
//      from the block buffers, then issues the next blockwise copy Run(),
//      advances the windows, and performs the xdlops multiply-accumulate on
//      the thread buffers;
//   3. if TailNum == TailNumber::Full: one last thread-buffer fill +
//      multiply-accumulate pass with no further blockwise copy.
//
// Template parameters:
//   HasMainLoop - compile-time switch enabling the hot loop.
//   TailNum     - compile-time tail selector; only TailNumber::Full has a body.
// Parameters:
//   a_/b_grid_desc, a_/b_grid_buf   - source descriptor and buffer of the copy.
//   a_/b_block_desc, a_/b_block_buf - destination descriptor and buffer of the
//                                     copy; also read when filling the thread
//                                     buffers.
//   a_/b_blockwise_copy             - copy objects providing Run() and
//                                     MoveSrcSliceWindow().
//   a_/b_block_copy_step            - step passed to MoveSrcSliceWindow().
//   c_thread_buf                    - accumulator; cleared here, then updated
//                                     by xdlops_gemm.Run().
//   num_loop                        - loop count; the hot loop repeats while
//                                     i < num_loop - 1.
//
// NOTE(review): this text comes from an HTML source listing in which several
// lines are absent (gaps in the listing numbers). Comments below mark those
// gaps; confirm the missing statements against the original header.
886 template <bool HasMainLoop,
887 TailNumber TailNum,
888 typename AGridDesc,
889 typename ABlockDesc,
890 typename ABlockTransfer,
891 typename AGridBuffer,
892 typename ABlockBuffer,
893 typename ABlockTransferStep,
894 typename BGridDesc,
895 typename BBlockDesc,
896 typename BBlockTransfer,
897 typename BGridBuffer,
898 typename BBlockBuffer,
899 typename BBlockTransferStep,
900 typename CThreadBuffer>
901 __device__ void Run(const AGridDesc& a_grid_desc,
902 const ABlockDesc& a_block_desc,
903 ABlockTransfer& a_blockwise_copy,
904 const AGridBuffer& a_grid_buf,
905 ABlockBuffer& a_block_buf,
906 const ABlockTransferStep& a_block_copy_step,
907 const BGridDesc& b_grid_desc,
908 const BBlockDesc& b_block_desc,
909 BBlockTransfer& b_blockwise_copy,
910 const BGridBuffer& b_grid_buf,
911 BBlockBuffer& b_block_buf,
912 const BBlockTransferStep& b_block_copy_step,
913 CThreadBuffer& c_thread_buf,
914 index_t num_loop) const
915 {
// NOTE(review): listing lines 916 and 918 are absent; the two lines below are
// the tails of the statements that presumably create a_thread_buf and
// b_thread_buf, sized by the thread descriptors' element space -- confirm.
917 a_thread_desc_.GetElementSpaceSize());
919 b_thread_desc_.GetElementSpaceSize());
920
921 // Global prefetch 1
922 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
923 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
924
// Advance both source windows so the next copy fetches the following tile.
925 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
926 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
927
// NOTE(review): listing line 928 is absent here -- confirm what it contains
// (presumably a synchronization that waits for the copy above).
929
930 // Initialize C
931 c_thread_buf.Clear();
932
933 // main body
934 if constexpr(HasMainLoop)
935 {
// do/while: the body always executes at least once, then repeats while
// i < num_loop - 1.
936 index_t i = 0;
937 do
938 {
// Fill the thread buffers from the block buffers before the next copy is
// issued into those same block buffers. The heads of the two copy calls
// (listing lines 941-942/944 and 948-949/951) are absent; only their trailing
// arguments remain. As written, the B-side statement sits inside the m0 loop.
939 static_for<0, KRepeat, 1>{}([&](auto k) {
940 static_for<0, MRepeat, 1>{}([&](auto m0) {
943 a_block_buf,
945 make_tuple(m0, I0, k, I0),
946 a_thread_buf);
947 static_for<0, NRepeat, 1>{}([&](auto n0) {
950 b_block_buf,
952 make_tuple(n0, I0, k, I0),
953 b_thread_buf);
954 });
955 });
956 });
957
// NOTE(review): listing line 958 is absent here -- confirm what it contains.
// Issue the copy of the next tile into the block buffers.
959 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
960 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
961
962 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
963 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
964
// Multiply-accumulate over every (k0, m0, n0) combination, reading only the
// thread buffers filled above.
965 static_for<0, KRepeat, 1>{}([&](auto k0) {
966 static_for<0, MRepeat, 1>{}([&](auto m0) {
967 static_for<0, NRepeat, 1>{}([&](auto n0) {
// NOTE(review): listing lines 968-969 are absent; presumably they declare
// a_thread_vec and b_thread_vec used below -- confirm.
970
// Gather KPack consecutive elements of A and B into the operand vectors.
971 static_for<0, KPack, 1>{}([&](auto ik) {
972 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
973 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
974 make_tuple(m0, I0, k0, ik))>{}];
975 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
976 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
977 make_tuple(n0, I0, k0, ik))>{}];
978 });
979
// NOTE(review): listing line 981 is absent; the tail pass of this function
// spells the same alias in full as
// vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type.
980 using mfma_input_type =
982 xdlops_gemm.K1PerXdlops>::type;
983
// Offset of the (m0, n0) accumulator tile inside c_thread_buf.
984 constexpr index_t c_offset =
985 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
986
987 xdlops_gemm.Run(
988 a_thread_vec.template AsType<mfma_input_type>(),
989 b_thread_vec.template AsType<mfma_input_type>(),
990 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
991 });
992 });
993 });
994
// NOTE(review): listing line 995 is absent here -- confirm what it contains.
996
997 i += 1;
998 } while(i < (num_loop - 1));
999 }
1000
1001 // tail
// Final pass: thread-buffer fill and multiply-accumulate on the data already
// in the block buffers; no further blockwise copy is issued.
1002 if constexpr(TailNum == TailNumber::Full)
1003 {
// NOTE(review): the heads of the two copy calls below (listing lines
// 1006-1007/1009 and 1013-1014/1016) are absent.
1004 static_for<0, KRepeat, 1>{}([&](auto k) {
1005 static_for<0, MRepeat, 1>{}([&](auto m0) {
1008 a_block_buf,
1010 make_tuple(m0, I0, k, I0),
1011 a_thread_buf);
1012 static_for<0, NRepeat, 1>{}([&](auto n0) {
1015 b_block_buf,
1017 make_tuple(n0, I0, k, I0),
1018 b_thread_buf);
1019 });
1020 });
1021 });
1022
1023 static_for<0, KRepeat, 1>{}([&](auto k0) {
1024 static_for<0, MRepeat, 1>{}([&](auto m0) {
1025 static_for<0, NRepeat, 1>{}([&](auto n0) {
// NOTE(review): listing lines 1026-1027 (presumably the operand vector
// declarations) are absent -- confirm.
1028
1029 static_for<0, KPack, 1>{}([&](auto ik) {
1030 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1031 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1032 make_tuple(m0, I0, k0, ik))>{}];
1033 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1034 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1035 make_tuple(n0, I0, k0, ik))>{}];
1036 });
1037
1038 using mfma_input_type =
1039 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
1040
1041 constexpr index_t c_offset =
1042 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1043
1044 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
1045 b_thread_vec.template AsType<mfma_input_type>(),
1046 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1047 });
1048 });
1049 });
1050 }
1051 }
1052
1053 protected:
1054 using Base::a_thread_copy_;
1055 using Base::a_thread_desc_;
1056 using Base::b_thread_copy_;
1057 using Base::b_thread_desc_;
1058 using Base::c_thread_desc_;
1059};
1060
1061} // namespace ck
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition ck.hpp:209
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__device__ void block_sync_lds_direct_load()
Definition synchronization.hpp:43
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:147
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:125
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
static constexpr index_t KPerThread
Definition blockwise_gemm_pipeline_xdlops_base.hpp:63
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
ThreadwiseTensorSliceTransfer_v4< BDataType, ComputeDataTypeBuf, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPerInnerLoop >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:720
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:403
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:478
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataTypeBuf, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPerInnerLoop >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:710
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:175
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:102
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:37
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:901
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:828
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:763
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10