device_grouped_gemm_xdl_splitk_cshuffle.hpp Source File

device_grouped_gemm_xdl_splitk_cshuffle.hpp Source File

Composable Kernel: device_grouped_gemm_xdl_splitk_cshuffle.hpp Source File
device_grouped_gemm_xdl_splitk_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
9#include "ck/ck.hpp"
10#include "ck/utility/env.hpp"
15#include "ck/utility/tuple.hpp"
22
23namespace ck {
24namespace tensor_operation {
25namespace device {
26
27template <typename GridwiseGemm,
28 typename GemmDesc,
29 bool HasMainKBlockLoop,
30 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
31 typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
32 typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
33 typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
34__global__ void
35#if CK_USE_LAUNCH_BOUNDS
37#endif
39 const index_t group_count,
40 const AElementwiseOperation a_element_op,
41 const BElementwiseOperation b_element_op,
42 const CElementwiseOperation c_element_op)
43{
44#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
45 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
46 {
47 constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
48 __shared__ uint8_t p_shared[shared_size];
49
50 const index_t block_id = get_block_1d_id();
51 const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
53
54 index_t left = 0;
55 index_t right = group_count;
56 index_t group_id = index_t((left + right) / 2);
57 while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
58 block_id < gemm_desc_ptr[group_id].block_end_)) &&
59 left <= right)
60 {
61 if(block_id < gemm_desc_ptr[group_id].block_start_)
62 {
63 right = group_id;
64 }
65 else
66 {
67 left = group_id;
68 }
69 group_id = index_t((left + right) / 2);
70 }
71
72 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
73 gemm_desc_ptr[group_id].karg_,
74 static_cast<void*>(p_shared),
75 gemm_desc_ptr[group_id].block_2_ctile_map_,
76 a_element_op,
77 b_element_op,
78 c_element_op);
79 }
80#else
81 ignore = gemm_descs_const;
82 ignore = group_count;
83 ignore = a_element_op;
84 ignore = b_element_op;
85 ignore = c_element_op;
86#endif // end of if (defined(__gfx9__))
87}
88
89template <typename ALayout,
90 typename BLayout,
91 typename DsLayout,
92 typename ELayout,
93 typename ADataType,
94 typename BDataType,
95 typename AccDataType,
96 typename CShuffleDataType,
97 typename DsDataType,
98 typename EDataType,
99 typename AElementwiseOperation,
100 typename BElementwiseOperation,
101 typename CDEElementwiseOperation,
102 GemmSpecialization GemmSpec,
103 ck::index_t NumGemmKPrefetchStage,
104 ck::index_t BlockSize,
105 ck::index_t MPerBlock,
106 ck::index_t NPerBlock,
107 ck::index_t KPerBlock,
108 ck::index_t AK1,
109 ck::index_t BK1,
110 ck::index_t MPerXDL,
111 ck::index_t NPerXDL,
112 ck::index_t MXdlPerWave,
113 ck::index_t NXdlPerWave,
114 typename ABlockTransferThreadClusterLengths_K0_M_K1,
115 typename ABlockTransferThreadClusterArrangeOrder,
116 typename ABlockTransferSrcAccessOrder,
117 ck::index_t ABlockTransferSrcVectorDim,
118 ck::index_t ABlockTransferSrcScalarPerVector,
119 ck::index_t ABlockTransferDstScalarPerVector_K1,
120 bool ABlockLdsExtraM,
121 typename BBlockTransferThreadClusterLengths_K0_N_K1,
122 typename BBlockTransferThreadClusterArrangeOrder,
123 typename BBlockTransferSrcAccessOrder,
124 ck::index_t BBlockTransferSrcVectorDim,
125 ck::index_t BBlockTransferSrcScalarPerVector,
126 ck::index_t BBlockTransferDstScalarPerVector_K1,
127 bool BBlockLdsExtraN,
128 index_t CShuffleMXdlPerWavePerShuffle,
129 index_t CShuffleNXdlPerWavePerShuffle,
130 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
131 index_t CDEBlockTransferScalarPerVector_NPerBlock,
134 // Current implementation does not support multiple D fusions.
137 bool> = false>
139 BLayout,
140 DsLayout,
141 ELayout,
142 ADataType,
143 BDataType,
144 DsDataType,
145 EDataType,
146 AElementwiseOperation,
147 BElementwiseOperation,
148 CDEElementwiseOperation>
149{
151 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
152 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
153 static constexpr index_t NumDTensor = DsDataType::Size();
154
155 static constexpr auto I0 = Number<0>{};
156 static constexpr auto I1 = Number<1>{};
157 static constexpr auto I2 = Number<2>{};
158 static constexpr auto I3 = Number<3>{};
159 static_assert(KPerBlock % AK1 == 0);
160 static constexpr index_t K0PerBlock = KPerBlock / AK1;
161
162 template <index_t NXdlPerWave_>
164 BlockSize,
165 ADataType,
166 BDataType,
167 AccDataType,
168 EDataType,
169 ALayout,
170 BLayout,
171 ELayout,
172 AElementwiseOperation,
173 BElementwiseOperation,
174 CDEElementwiseOperation,
175 GemmSpec,
176 NumGemmKPrefetchStage,
177 MPerBlock,
178 NPerBlock,
180 MPerXDL,
181 NPerXDL,
182 AK1,
183 MXdlPerWave,
184 NXdlPerWave_,
185 ABlockTransferThreadClusterLengths_K0_M_K1,
186 ABlockTransferThreadClusterArrangeOrder,
187 ABlockTransferSrcAccessOrder,
188 ABlockTransferSrcVectorDim,
189 ABlockTransferSrcScalarPerVector,
190 ABlockTransferDstScalarPerVector_K1,
191 false, // AThreadTransferSrcResetCoordinateAfterRun,
192 ABlockLdsExtraM,
193 BBlockTransferThreadClusterLengths_K0_N_K1,
194 BBlockTransferThreadClusterArrangeOrder,
195 BBlockTransferSrcAccessOrder,
196 BBlockTransferSrcVectorDim,
197 BBlockTransferSrcScalarPerVector,
198 BBlockTransferDstScalarPerVector_K1,
199 false, // BThreadTransferSrcResetCoordinateAfterRun,
200 BBlockLdsExtraN,
201 CShuffleMXdlPerWavePerShuffle,
202 CShuffleNXdlPerWavePerShuffle,
203 CDEBlockTransferScalarPerVector_NPerBlock,
204 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
205 LoopSched,
206 PipelineVer>;
209
213 // Block2CTileMap configuration parameter.
214 static constexpr index_t B2E_M01 = 8;
216 using KernelArgument = typename GridwiseGemm64::Argument;
218 template <typename KernelArgument_>
220 {
221 KernelArgument_ karg_;
224
226 GemmTransKernelArgBase(KernelArgument_&& karg,
228 index_t block_start,
229 index_t block_end)
230 : karg_{karg},
231 block_2_ctile_map_{b2c_map},
232 block_start_{block_start},
233 block_end_{block_end}
234 {
235 }
236 };
238
239 static constexpr index_t DefaultKBatch = 1;
240
241 // Argument
242 struct Argument : public BaseArgument
243 {
244
245 Argument(std::vector<const void*>& p_As,
246 std::vector<const void*>& p_Bs,
247 std::vector<void*>& p_Es,
248 std::vector<GemmDesc>& gemm_descs)
249 : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
250 {
251 // TODO: use occupancy api to calculate appropriate batch size.
252 }
253
254 Argument(std::vector<const void*>& p_As,
255 std::vector<const void*>& p_Bs,
256 std::vector<void*>& p_Es,
257 std::vector<GemmDesc>& gemm_descs,
258 index_t kbatch)
259 : K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr}
260 {
261 grid_size_ = 0;
262 group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
263
264 if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
267 {
268 throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
269 }
270
272
274
275 for(std::size_t i = 0; i < gemm_descs.size(); ++i)
276 {
277 const index_t M = gemm_descs[i].M_;
278 const index_t N = gemm_descs[i].N_;
279 const index_t K = gemm_descs[i].K_;
280
281 if(M == 0)
282 {
284 continue;
285 }
286
287 const index_t stride_a = gemm_descs[i].stride_A_;
288 const index_t stride_b = gemm_descs[i].stride_B_;
289 const index_t stride_c = gemm_descs[i].stride_C_;
290
291 const index_t m_padded = GridwiseGemm64::CalculateMPadded(M);
292 const index_t n_padded = GridwiseGemm64::CalculateNPadded(N);
295
296 const auto c_grid_desc_m_n =
298
299 const auto local_b2c_tile_map =
300 Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
301 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
302
303 const index_t block_start = grid_size_;
304 const index_t block_end = grid_size_ + grid_size_grp;
305
306 grid_size_ += grid_size_grp;
307
308 // block-to-e-tile map
309 auto grouped_block_2_ctile_map =
310 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
311
315 M,
316 N,
317 K,
318 stride_a,
319 stride_b,
320 stride_c,
321 m_padded,
322 n_padded,
323 k_padded,
324 k0_padded,
325 K_BATCH};
326
327 gemm_kernel_args_.emplace_back(
328 std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
329 }
330 }
331
338 {
339 K_BATCH = kbatch;
340 grid_size_ = 0;
341
342 for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
343 {
344
345 auto& karg = gemm_kernel_args_[i].karg_;
346
347 const index_t k_padded = GridwiseGemm64::CalculateKPadded(karg.K, K_BATCH);
348 const index_t k0_padded = GridwiseGemm64::CalculateK0Padded(karg.K, K_BATCH);
349
350 const auto c_grid_desc_m_n =
351 GridwiseGemm64::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
352
353 const auto local_b2c_tile_map =
354 Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
355 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
356
357 const index_t block_start = grid_size_;
358 const index_t block_end = grid_size_ + grid_size_grp;
359
360 grid_size_ += grid_size_grp;
361
362 // block-to-e-tile map
363 auto grouped_block_2_ctile_map =
364 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
365
366 karg.KPadded = k_padded;
367 karg.K0Padded = k0_padded;
368 karg.k_batch = K_BATCH;
369 gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
370 gemm_kernel_args_[i].block_start_ = block_start;
371 gemm_kernel_args_[i].block_end_ = block_end;
372 }
373 }
374
375 // private:
379
380 std::vector<GemmTransKernelArg> gemm_kernel_args_;
383 };
384
385 // Invoker
386 struct Invoker : public BaseInvoker
387 {
388 template <typename GridwiseGemm>
389 float RunImp(const Argument& arg,
390 const StreamConfig& stream_config = StreamConfig{},
391 hipStream_t cpy_stream = nullptr,
392 hipEvent_t cpy_event = nullptr)
393 {
395 static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg));
396 static_assert(sizeof(typename GridwiseGemm::Argument) ==
397 sizeof(typename GridwiseGemm64::Argument));
398
399 index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded;
400 bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
401 bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
402
403 for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
404 {
405 const auto& karg = reinterpret_cast<const typename GridwiseGemm::Argument&>(
406 arg.gemm_kernel_args_[i].karg_);
407 if(stream_config.log_level_ > 0)
408 {
409 karg.Print();
410 }
411
412 auto kbatch = karg.k_batch;
413
414 if(!GridwiseGemm::CheckValidity(karg))
415 {
416 std::ostringstream err;
417 err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
418 << ":" << __LINE__ << ", in function: " << __func__;
419 throw std::runtime_error(err.str());
420 }
421
422 K0 = karg.K0Padded;
423 bool not_all_have_main_k0_block_loop_same =
424 all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
425 bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
426
427 if(not_all_have_main_k0_block_loop_same)
428 {
429 std::ostringstream err;
430 err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
431 << ":" << __LINE__ << ", in function: " << __func__;
432 throw std::runtime_error(err.str());
433 }
434
435 if(not_all_have_kbatch_value_same)
436 {
437 std::ostringstream err;
438 err << "Not all gemms have same kbatch value (=1 or >1)! " << "group [" << i
439 << "], kbatch: " << kbatch
440 << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
441 << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
442 throw std::runtime_error(err.str());
443 }
444 }
445
446 // If the user provides copy stream and copy event, we assume that they're also
447 // responsible for providing allocated host memory (eg. pinned) which
448 // would be used to copy kernel arguments to the device.
449 if(cpy_stream && cpy_event)
450 {
451 if(arg.gemm_kernel_host_args_ == nullptr)
452 {
453 std::ostringstream err;
454 err << "No memory has been allocated for gemm kernel host args "
455 << "when providing the copy stream and copy event! In " << __FILE__ << ":"
456 << __LINE__ << ", in function: " << __func__;
457 throw std::runtime_error(err.str());
458 }
459 hip_check_error(hipMemcpyAsync(arg.p_workspace_,
461 arg.group_count_ * sizeof(GemmTransKernelArg_),
462 hipMemcpyHostToDevice,
463 cpy_stream));
464 hip_check_error(hipEventRecord(cpy_event, cpy_stream));
465 hip_check_error(hipEventSynchronize(cpy_event));
466 }
467 else // In this case CK owns memory allocated on host.
468 {
469
471 hipMemcpyAsync(arg.p_workspace_,
472 arg.gemm_kernel_args_.data(),
473 arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg_),
474 hipMemcpyHostToDevice,
475 stream_config.stream_id_));
476 }
477
478 float ave_time = 0;
479
480 const auto Run = [&](const auto& kernel) {
481 if(all_have_kbatch_gt_one)
482 {
483 for(const auto& trans_arg : arg.gemm_kernel_args_)
484 {
485 const auto& karg = trans_arg.karg_;
486 hip_check_error(hipMemsetAsync(karg.p_c_grid,
487 0,
488 karg.M * karg.N * sizeof(EDataType),
489 stream_config.stream_id_));
490 }
491 }
492
493 ave_time =
494 launch_and_time_kernel(stream_config,
495 kernel,
496 dim3(arg.grid_size_),
497 dim3(BlockSize),
498 0,
500 arg.gemm_kernel_args_.size(),
501 PassThrough{},
502 PassThrough{},
503 PassThrough{});
504 };
505
506 if(all_have_main_k0_block_loop)
507 {
508 if(all_have_kbatch_gt_one)
509 {
510 const auto kernel =
512 GemmTransKernelArg_,
513 true,
515
516 Run(kernel);
517 }
518 else
519 {
520 const auto kernel =
522 GemmTransKernelArg_,
523 true,
525
526 Run(kernel);
527 }
528 }
529 else
530 {
531 if(all_have_kbatch_gt_one)
532 {
533 const auto kernel =
535 GemmTransKernelArg_,
536 false,
538
539 Run(kernel);
540 }
541 else
542 {
543 const auto kernel =
545 GemmTransKernelArg_,
546 false,
548
549 Run(kernel);
550 }
551 }
552
553 return ave_time;
554 }
555
556 float Run(const Argument& arg,
557 const StreamConfig& stream_config = StreamConfig{},
558 hipStream_t cpy_stream = nullptr,
559 hipEvent_t cpy_event = nullptr)
560 {
561 if(get_warp_size() == 64)
562 {
563 if constexpr(NXdlPerWave64 > 0)
564 {
565 return RunImp<GridwiseGemm64>(arg, stream_config, cpy_stream, cpy_event);
566 }
567 }
568 else
569 {
570 if constexpr(NXdlPerWave32 > 0)
571 {
572 return RunImp<GridwiseGemm32>(arg, stream_config, cpy_stream, cpy_event);
573 }
574 }
575 return 0;
576 }
577
578 // polymorphic
579 float Run(const BaseArgument* p_arg,
580 const StreamConfig& stream_config = StreamConfig{}) override
581 {
582 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
583 }
584 };
585
/// Compile-time validity check for this instantiation; currently accepts everything.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool accept_all_instantiations = true;
    return accept_all_instantiations;
}
591
592 static bool IsSupportedArgument(const Argument& arg)
593 {
595 {
596 return false;
597 }
598 if(is_gfx11_supported() && arg.K_BATCH > 1)
599 {
600 return false;
601 }
604 {
605 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
606 {
607 std::cout << "The group count is not equal to sum of skipped groups "
608 "and kernel args size!"
609 << std::endl;
610 }
611 return false;
612 }
613
614 if(std::is_same_v<EDataType, ck::bhalf_t> && arg.K_BATCH > 1 && !is_bf16_atomic_supported())
615 {
616 return false;
617 }
618
619 bool supported = true;
620 bool isWave64 = get_warp_size() == 64;
621 for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
622 {
623 const auto& a = arg.gemm_kernel_args_[i].karg_;
624 bool group_arg_valid = false;
625 if(isWave64)
626 {
627 if constexpr(NXdlPerWave64 > 0)
628 {
629 group_arg_valid = GridwiseGemm64::CheckValidity(a);
630 }
631 }
632 else
633 {
634 if constexpr(NXdlPerWave32 > 0)
635 {
636 group_arg_valid = GridwiseGemm32::CheckValidity(
637 reinterpret_cast<const typename GridwiseGemm32::Argument&>(a));
638 }
639 }
640
641 if(not group_arg_valid)
642 {
643 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
644 {
645 std::cout << "[" << __func__ << "] group id: " << i
646 << " has invalid GridwiseGemm settings!" << std::endl;
647 a.Print();
648 }
649 }
650 supported = supported && group_arg_valid;
651 }
652 return supported;
653 }
654
655 // polymorphic
656 bool IsSupportedArgument(const BaseArgument* p_arg) override
657 {
658 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
659 }
660
661 static auto MakeArgument(std::vector<const void*>& p_As,
662 std::vector<const void*>& p_Bs,
663 std::vector<std::array<const void*, NumDTensor>>&,
664 std::vector<void*>& p_Es,
665 std::vector<GemmDesc> gemm_descs,
666 AElementwiseOperation,
667 BElementwiseOperation,
668 CDEElementwiseOperation)
669 {
670 return Argument{p_As, p_Bs, p_Es, gemm_descs};
671 }
672
673 static auto MakeInvoker() { return Invoker{}; }
674
675 // polymorphic
676 std::unique_ptr<BaseArgument>
677 MakeArgumentPointer(std::vector<const void*>& p_As,
678 std::vector<const void*>& p_Bs,
679 std::vector<std::array<const void*, NumDTensor>>&,
680 std::vector<void*>& p_Es,
681 std::vector<GemmDesc>& gemm_descs,
682 AElementwiseOperation,
683 BElementwiseOperation,
684 CDEElementwiseOperation) override
685 {
686 return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
687 }
688
689 // polymorphic
690 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
691 {
692 return std::make_unique<Invoker>(Invoker{});
693 }
694
695 // polymorphic
696 std::string GetTypeString() const override
697 {
698 auto str = std::stringstream();
699
700 // clang-format off
701 str << "DeviceGroupedGemm_XdlSplitK"
702 << "<"
703 << std::string(ALayout::name)[0] << ","
704 << std::string(BLayout::name)[0] << ","
705 << std::string(ELayout::name)[0] << ","
706 << BlockSize << ", "
707 << MPerBlock << ", "
708 << NPerBlock << ", "
709 << KPerBlock << ", "
710 << AK1 << ", "
711 << BK1 << ", "
712 << MPerXDL << ", "
713 << NPerXDL << ", "
714 << MXdlPerWave << ", "
715 << NXdlPerWave << ", "
716 << ABlockTransferSrcScalarPerVector << ", "
717 << BBlockTransferSrcScalarPerVector << ", "
718 << CShuffleMXdlPerWavePerShuffle << ", "
719 << CShuffleNXdlPerWavePerShuffle << ", "
720 << getGemmSpecializationString(GemmSpec)
721 << ">";
722 // clang-format on
723
724 return str.str();
725 }
726
727 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
728 {
729 auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
730 if(p_arg_)
731 {
732 return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
733 }
734 else
735 throw std::runtime_error(
736 "The argument pointer is not an object of "
737 "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
738 }
739
740 size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
741 {
742 return GetWorkSpaceSize(p_arg);
743 }
744
745 size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
746
747 // TODO: deperecation notice.
748 static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
749
750 // polymorphic
751 void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
752 {
753 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
754 if(p_arg_)
755 {
756 p_arg_->UpdateKBatch(kbatch);
757 }
758 else
759 throw std::runtime_error(
760 "The argument pointer is not an object of "
761 "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
762 }
763
764 void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
765 {
766 return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
767 }
768
769 //----------------------------------------------------------------------------------------------
778 void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
779 {
780 Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
781 if(!pArg_)
782 {
783 throw std::runtime_error("Failed to cast argument pointer!");
784 }
785
786 pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
787 std::copy(pArg_->gemm_kernel_args_.begin(),
788 pArg_->gemm_kernel_args_.end(),
789 static_cast<GemmTransKernelArg*>(pArg_->gemm_kernel_host_args_));
790 }
791};
792
793} // namespace device
794} // namespace tensor_operation
795} // namespace ck
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
__global__ void kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:38
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
LoopScheduler
Definition loop_scheduler.hpp:15
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
unsigned char uint8_t
Definition stdint.h:124
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:541
Definition gridwise_gemm_xdlops_v2r4r2.hpp:106
Definition block_to_ctile_map.hpp:872
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition device_base.hpp:249
Definition device_grouped_gemm_splitk.hpp:33
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:243
index_t skipped_group_count_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:378
index_t K_BATCH
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:376
void UpdateKBatch(index_t kbatch)
Recalculate group grid size for all gemms and update B2C maps.
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:337
void * gemm_kernel_host_args_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:381
std::vector< GemmTransKernelArg > gemm_kernel_args_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:380
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, index_t kbatch)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:254
index_t grid_size_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:382
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:245
index_t group_count_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:377
GroupedGemmBlock2ETileMap block_2_ctile_map_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:222
KernelArgument karg_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:221
GemmTransKernelArgBase(KernelArgument_ &&karg, GroupedGemmBlock2ETileMap &&b2c_map, index_t block_start, index_t block_end)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:226
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:387
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:579
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{}, hipStream_t cpy_stream=nullptr, hipEvent_t cpy_event=nullptr)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:556
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{}, hipStream_t cpy_stream=nullptr, hipEvent_t cpy_event=nullptr)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:389
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:149
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:592
BlockToCTileMap_KSplit_M00_N0_M01Adapt< MPerBlock, NPerBlock, CGridDesc_M_N > Block2ETileMapKSplit
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:211
std::string GetTypeString() const override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:696
void SetKBatchSize(BaseArgument *p_arg, index_t kbatch) const override
Sets the k batch size.
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:751
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:690
static constexpr auto I1
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:156
static auto MakeInvoker()
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:673
static constexpr auto I2
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:157
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:151
size_t GetHostKernelArgSize(const BaseArgument *p_arg) const
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:745
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, EDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer > GridwiseGemmBase
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:163
GemmTransKernelArgBase< KernelArgument > GemmTransKernelArg
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:237
static constexpr index_t DefaultKBatch
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:239
void SetHostKernelArgsPointer(BaseArgument *p_arg, void *p_host_kernel_args) const
Sets the host kernel arguments pointer and copies that data on the host side. This function can be ut...
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:778
static constexpr index_t B2E_M01
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:214
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:727
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:208
void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:764
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:740
static constexpr auto I3
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:158
OffsettedBlockToCTileMap< Block2ETileMapKSplit > GroupedGemmBlock2ETileMap
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:215
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:586
static constexpr index_t NumDTensor
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:153
typename GridwiseGemm64::CGridDesc_M_N CGridDesc_M_N
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:210
static constexpr auto NXdlPerWave32
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:152
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:217
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:207
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:656
static constexpr auto I0
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:155
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:661
static void SetKBatchSize(Argument &arg, index_t kbatch)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:748
typename GridwiseGemm64::Argument KernelArgument
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:216
static constexpr index_t K0PerBlock
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:160
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation) override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:677
Definition device_grouped_gemm.hpp:80
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129