27template <
typename GridwiseGemm,
29 bool HasMainKBlockLoop,
31 typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
32 typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
33 typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
35#if CK_USE_LAUNCH_BOUNDS
40 const AElementwiseOperation a_element_op,
41 const BElementwiseOperation b_element_op,
42 const CElementwiseOperation c_element_op)
44#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
45 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
47 constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
48 __shared__
uint8_t p_shared[shared_size];
51 const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*
>(
57 while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
58 block_id < gemm_desc_ptr[group_id].block_end_)) &&
61 if(block_id < gemm_desc_ptr[group_id].block_start_)
69 group_id =
index_t((left + right) / 2);
72 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
73 gemm_desc_ptr[group_id].karg_,
74 static_cast<void*
>(p_shared),
75 gemm_desc_ptr[group_id].block_2_ctile_map_,
89template <
typename ALayout,
96 typename CShuffleDataType,
99 typename AElementwiseOperation,
100 typename BElementwiseOperation,
101 typename CDEElementwiseOperation,
114 typename ABlockTransferThreadClusterLengths_K0_M_K1,
115 typename ABlockTransferThreadClusterArrangeOrder,
116 typename ABlockTransferSrcAccessOrder,
120 bool ABlockLdsExtraM,
121 typename BBlockTransferThreadClusterLengths_K0_N_K1,
122 typename BBlockTransferThreadClusterArrangeOrder,
123 typename BBlockTransferSrcAccessOrder,
127 bool BBlockLdsExtraN,
128 index_t CShuffleMXdlPerWavePerShuffle,
129 index_t CShuffleNXdlPerWavePerShuffle,
130 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
131 index_t CDEBlockTransferScalarPerVector_NPerBlock,
146 AElementwiseOperation,
147 BElementwiseOperation,
148 CDEElementwiseOperation>
159 static_assert(KPerBlock % AK1 == 0);
162 template <index_t NXdlPerWave_>
172 AElementwiseOperation,
173 BElementwiseOperation,
174 CDEElementwiseOperation,
176 NumGemmKPrefetchStage,
185 ABlockTransferThreadClusterLengths_K0_M_K1,
186 ABlockTransferThreadClusterArrangeOrder,
187 ABlockTransferSrcAccessOrder,
188 ABlockTransferSrcVectorDim,
189 ABlockTransferSrcScalarPerVector,
190 ABlockTransferDstScalarPerVector_K1,
193 BBlockTransferThreadClusterLengths_K0_N_K1,
194 BBlockTransferThreadClusterArrangeOrder,
195 BBlockTransferSrcAccessOrder,
196 BBlockTransferSrcVectorDim,
197 BBlockTransferSrcScalarPerVector,
198 BBlockTransferDstScalarPerVector_K1,
201 CShuffleMXdlPerWavePerShuffle,
202 CShuffleNXdlPerWavePerShuffle,
203 CDEBlockTransferScalarPerVector_NPerBlock,
204 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
218 template <
typename KernelArgument_>
246 std::vector<const void*>& p_Bs,
247 std::vector<void*>& p_Es,
248 std::vector<GemmDesc>& gemm_descs)
255 std::vector<const void*>& p_Bs,
256 std::vector<void*>& p_Es,
257 std::vector<GemmDesc>& gemm_descs,
268 throw std::runtime_error(
"wrong! group_count_ != p_As/b/c.size");
275 for(std::size_t i = 0; i < gemm_descs.size(); ++i)
277 const index_t M = gemm_descs[i].M_;
278 const index_t N = gemm_descs[i].N_;
279 const index_t K = gemm_descs[i].K_;
287 const index_t stride_a = gemm_descs[i].stride_A_;
288 const index_t stride_b = gemm_descs[i].stride_B_;
289 const index_t stride_c = gemm_descs[i].stride_C_;
296 const auto c_grid_desc_m_n =
299 const auto local_b2c_tile_map =
301 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
309 auto grouped_block_2_ctile_map =
328 std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
350 const auto c_grid_desc_m_n =
353 const auto local_b2c_tile_map =
355 const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
363 auto grouped_block_2_ctile_map =
366 karg.KPadded = k_padded;
367 karg.K0Padded = k0_padded;
388 template <
typename Gr
idwiseGemm>
391 hipStream_t cpy_stream =
nullptr,
392 hipEvent_t cpy_event =
nullptr)
396 static_assert(
sizeof(
typename GridwiseGemm::Argument) ==
397 sizeof(
typename GridwiseGemm64::Argument));
401 bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
405 const auto& karg =
reinterpret_cast<const typename GridwiseGemm::Argument&
>(
407 if(stream_config.log_level_ > 0)
412 auto kbatch = karg.k_batch;
414 if(!GridwiseGemm::CheckValidity(karg))
416 std::ostringstream err;
417 err <<
"Group id: " << i <<
" has invalid GridwiseGemm settings!" << __FILE__
418 <<
":" << __LINE__ <<
", in function: " << __func__;
419 throw std::runtime_error(err.str());
423 bool not_all_have_main_k0_block_loop_same =
424 all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
425 bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
427 if(not_all_have_main_k0_block_loop_same)
429 std::ostringstream err;
430 err <<
"Not all gemms have same value for main_k0_block_loop! in " << __FILE__
431 <<
":" << __LINE__ <<
", in function: " << __func__;
432 throw std::runtime_error(err.str());
435 if(not_all_have_kbatch_value_same)
437 std::ostringstream err;
438 err <<
"Not all gemms have same kbatch value (=1 or >1)! " <<
"group [" << i
439 <<
"], kbatch: " << kbatch
441 <<
" in " << __FILE__ <<
":" << __LINE__ <<
", in function: " << __func__;
442 throw std::runtime_error(err.str());
449 if(cpy_stream && cpy_event)
453 std::ostringstream err;
454 err <<
"No memory has been allocated for gemm kernel host args "
455 <<
"when providing the copy stream and copy event! In " << __FILE__ <<
":"
456 << __LINE__ <<
", in function: " << __func__;
457 throw std::runtime_error(err.str());
462 hipMemcpyHostToDevice,
474 hipMemcpyHostToDevice,
475 stream_config.stream_id_));
480 const auto Run = [&](
const auto& kernel) {
481 if(all_have_kbatch_gt_one)
485 const auto& karg = trans_arg.karg_;
488 karg.M * karg.N *
sizeof(EDataType),
489 stream_config.stream_id_));
506 if(all_have_main_k0_block_loop)
508 if(all_have_kbatch_gt_one)
531 if(all_have_kbatch_gt_one)
558 hipStream_t cpy_stream =
nullptr,
559 hipEvent_t cpy_event =
nullptr)
582 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
607 std::cout <<
"The group count is not equal to sum of skipped groups "
608 "and kernel args size!"
619 bool supported =
true;
624 bool group_arg_valid =
false;
637 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(
a));
641 if(not group_arg_valid)
645 std::cout <<
"[" << __func__ <<
"] group id: " << i
646 <<
" has invalid GridwiseGemm settings!" << std::endl;
650 supported = supported && group_arg_valid;
662 std::vector<const void*>& p_Bs,
663 std::vector<std::array<const void*, NumDTensor>>&,
664 std::vector<void*>& p_Es,
665 std::vector<GemmDesc> gemm_descs,
666 AElementwiseOperation,
667 BElementwiseOperation,
668 CDEElementwiseOperation)
670 return Argument{p_As, p_Bs, p_Es, gemm_descs};
676 std::unique_ptr<BaseArgument>
678 std::vector<const void*>& p_Bs,
679 std::vector<std::array<const void*, NumDTensor>>&,
680 std::vector<void*>& p_Es,
681 std::vector<GemmDesc>& gemm_descs,
682 AElementwiseOperation,
683 BElementwiseOperation,
684 CDEElementwiseOperation)
override
686 return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
692 return std::make_unique<Invoker>(
Invoker{});
698 auto str = std::stringstream();
701 str <<
"DeviceGroupedGemm_XdlSplitK"
703 << std::string(ALayout::name)[0] <<
","
704 << std::string(BLayout::name)[0] <<
","
705 << std::string(ELayout::name)[0] <<
","
714 << MXdlPerWave <<
", "
715 << NXdlPerWave <<
", "
716 << ABlockTransferSrcScalarPerVector <<
", "
717 << BBlockTransferSrcScalarPerVector <<
", "
718 << CShuffleMXdlPerWavePerShuffle <<
", "
719 << CShuffleNXdlPerWavePerShuffle <<
", "
729 auto p_arg_ =
dynamic_cast<const Argument*
>(p_arg);
735 throw std::runtime_error(
736 "The argument pointer is not an object of "
737 "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
753 auto p_arg_ =
dynamic_cast<Argument*
>(p_arg);
756 p_arg_->UpdateKBatch(kbatch);
759 throw std::runtime_error(
760 "The argument pointer is not an object of "
761 "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
783 throw std::runtime_error(
"Failed to cast argument pointer!");
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
__global__ void kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:38
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
LoopScheduler
Definition loop_scheduler.hpp:15
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
unsigned char uint8_t
Definition stdint.h:124
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:541
Definition gridwise_gemm_xdlops_v2r4r2.hpp:106
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, EDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:440
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, EDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer >::CalculateMPadded __host__ static __device__ auto CalculateMPadded(index_t M)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:196
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, EDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer >::CalculateKPadded __host__ static __device__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:213
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, EDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer >::CalculateK0Padded __host__ static __device__ auto CalculateK0Padded(index_t K, index_t K_Batch=1)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:206
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, EDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer >::CGridDesc_M_N remove_cvref_t< decltype(MakeCGridDescriptor_M_N(1, 1, 1))> CGridDesc_M_N
Definition gridwise_gemm_xdlops_v2r4r2.hpp:661
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, EDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer >::MakeCGridDescriptor_M_N __host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:371
ck::GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, EDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer >::CalculateNPadded __host__ static __device__ auto CalculateNPadded(index_t N)
Definition gridwise_gemm_xdlops_v2r4r2.hpp:201
Definition block_to_ctile_map.hpp:872
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition device_base.hpp:249
Definition device_grouped_gemm_splitk.hpp:33
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:243
index_t skipped_group_count_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:378
index_t K_BATCH
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:376
void UpdateKBatch(index_t kbatch)
Recalculate group grid size for all gemms and update B2C maps.
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:337
void * gemm_kernel_host_args_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:381
std::vector< GemmTransKernelArg > gemm_kernel_args_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:380
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, index_t kbatch)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:254
index_t grid_size_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:382
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:245
index_t group_count_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:377
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:220
GroupedGemmBlock2ETileMap block_2_ctile_map_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:222
KernelArgument karg_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:221
GemmTransKernelArgBase(KernelArgument_ &&karg, GroupedGemmBlock2ETileMap &&b2c_map, index_t block_start, index_t block_end)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:226
GemmTransKernelArgBase()=default
index_t block_end_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:223
index_t block_start_
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:223
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:387
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:579
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{}, hipStream_t cpy_stream=nullptr, hipEvent_t cpy_event=nullptr)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:556
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{}, hipStream_t cpy_stream=nullptr, hipEvent_t cpy_event=nullptr)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:389
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:149
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:592
BlockToCTileMap_KSplit_M00_N0_M01Adapt< MPerBlock, NPerBlock, CGridDesc_M_N > Block2ETileMapKSplit
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:211
std::string GetTypeString() const override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:696
void SetKBatchSize(BaseArgument *p_arg, index_t kbatch) const override
Sets the k batch size.
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:751
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:690
static constexpr auto I1
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:156
static auto MakeInvoker()
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:673
static constexpr auto I2
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:157
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:151
size_t GetHostKernelArgSize(const BaseArgument *p_arg) const
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:745
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, EDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer > GridwiseGemmBase
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:163
GemmTransKernelArgBase< KernelArgument > GemmTransKernelArg
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:237
static constexpr index_t DefaultKBatch
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:239
void SetHostKernelArgsPointer(BaseArgument *p_arg, void *p_host_kernel_args) const
Sets the host kernel arguments pointer and copies that data on the host side. This function can be ut...
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:778
static constexpr index_t B2E_M01
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:214
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:727
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:208
void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:764
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:740
static constexpr auto I3
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:158
OffsettedBlockToCTileMap< Block2ETileMapKSplit > GroupedGemmBlock2ETileMap
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:215
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:586
static constexpr index_t NumDTensor
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:153
typename GridwiseGemm64::CGridDesc_M_N CGridDesc_M_N
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:210
static constexpr auto NXdlPerWave32
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:152
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:217
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:207
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:656
static constexpr auto I0
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:155
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:661
static void SetKBatchSize(Argument &arg, index_t kbatch)
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:748
typename GridwiseGemm64::Argument KernelArgument
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:216
static constexpr index_t K0PerBlock
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:160
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor > > &, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation) override
Definition device_grouped_gemm_xdl_splitk_cshuffle.hpp:677
Definition device_grouped_gemm.hpp:80
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129