24template <
typename GridwiseGemm,
29 typename AElementwiseOperation,
30 typename BElementwiseOperation,
31 typename CDEElementwiseOperation,
32 typename AsGridDesc_AK0_M_AK1,
33 typename BsGridDesc_BK0_N_BK1,
34 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
35 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
36 typename Block2ETileMap,
37 bool HasMainKBlockLoop>
39#if CK_USE_LAUNCH_BOUNDS
46 EDataType* __restrict__ p_e_grid,
47 const AElementwiseOperation a_element_op,
48 const BElementwiseOperation b_element_op,
49 const CDEElementwiseOperation cde_element_op,
50 const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
51 const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
52 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
53 ds_grid_desc_mblock_mperblock_nblock_nperblock,
54 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
55 e_grid_desc_mblock_mperblock_nblock_nperblock,
56 const Block2ETileMap block_2_etile_map)
58#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
59 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
61 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
63 GridwiseGemm::template Run<HasMainKBlockLoop>(
72 as_grid_desc_ak0_m_ak1,
73 bs_grid_desc_bk0_n_bk1,
74 ds_grid_desc_mblock_mperblock_nblock_nperblock,
75 e_grid_desc_mblock_mperblock_nblock_nperblock,
86 ignore = as_grid_desc_ak0_m_ak1;
87 ignore = bs_grid_desc_bk0_n_bk1;
88 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
89 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
90 ignore = block_2_etile_map;
114 typename AccDataType,
115 typename CShuffleDataType,
118 typename AElementwiseOperation,
119 typename BElementwiseOperation,
120 typename CDEElementwiseOperation,
133 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
134 typename ABlockTransferThreadClusterArrangeOrder,
135 typename ABlockTransferSrcAccessOrder,
136 index_t ABlockTransferSrcVectorDim,
137 index_t ABlockTransferSrcScalarPerVector,
138 index_t ABlockTransferDstScalarPerVector_AK1,
140 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
141 typename BBlockTransferThreadClusterArrangeOrder,
142 typename BBlockTransferSrcAccessOrder,
143 index_t BBlockTransferSrcVectorDim,
144 index_t BBlockTransferSrcScalarPerVector,
145 index_t BBlockTransferDstScalarPerVector_BK1,
147 index_t CShuffleMXdlPerWavePerShuffle,
148 index_t CShuffleNXdlPerWavePerShuffle,
149 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
150 index_t CDEBlockTransferScalarPerVector_NPerBlock,
161 AElementwiseOperation,
162 BElementwiseOperation,
163 CDEElementwiseOperation>
183 template <index_t NXdlPerWave_>
192 AElementwiseOperation,
193 BElementwiseOperation,
194 CDEElementwiseOperation,
196 NumGemmKPrefetchStage,
207 ABlockTransferThreadClusterLengths_AK0_M_AK1,
208 ABlockTransferThreadClusterArrangeOrder,
209 ABlockTransferSrcAccessOrder,
210 ABlockTransferSrcVectorDim,
211 ABlockTransferSrcScalarPerVector,
212 ABlockTransferDstScalarPerVector_AK1,
215 BBlockTransferThreadClusterLengths_BK0_N_BK1,
216 BBlockTransferThreadClusterArrangeOrder,
217 BBlockTransferSrcAccessOrder,
218 BBlockTransferSrcVectorDim,
219 BBlockTransferSrcScalarPerVector,
220 BBlockTransferDstScalarPerVector_BK1,
223 CShuffleMXdlPerWavePerShuffle,
224 CShuffleNXdlPerWavePerShuffle,
225 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
226 CDEBlockTransferScalarPerVector_NPerBlock,
234 MPerBlock, NPerBlock, KPerBlock};
237 const std::vector<index_t>& a_ms_ks_strides_)
239 assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK &&
240 a_ms_ks_strides_.size() == NumDimM + NumDimK);
242 const auto to_tuple = [&](
auto& vec,
auto num) {
253 constexpr auto kDimIds =
263 const auto a_grid_desc_ms_ks =
273 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
276 __host__ __device__
static auto
278 const std::array<std::vector<index_t>,
NumATensor>& as_ms_ks_strides)
289 const std::vector<index_t>& b_ns_ks_strides_)
291 assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK &&
292 b_ns_ks_strides_.size() == NumDimN + NumDimK);
294 const auto to_tuple = [&](
auto& vec,
auto num) {
305 constexpr auto kDimIds =
315 const auto b_grid_desc_ns_ks =
325 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
328 __host__ __device__
static auto
330 const std::array<std::vector<index_t>,
NumBTensor>& bs_ns_ks_strides)
341 const std::vector<index_t>& e_ms_ns_strides_)
343 assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN &&
344 e_ms_ns_strides_.size() == NumDimM + NumDimN);
346 const auto to_tuple = [&](
auto& vec,
auto num) {
357 constexpr auto nDimIds =
367 const auto e_grid_desc_ms_ns =
377 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
382 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_strides)
418 Argument(std::array<const void*, NumATensor> p_as_grid,
419 std::array<const void*, NumBTensor> p_bs_grid,
420 std::array<const void*, NumDTensor> p_ds_grid,
422 const std::array<std::vector<index_t>,
NumATensor>& a_ms_ks_lengths,
423 const std::array<std::vector<index_t>,
NumATensor>& a_ms_ks_strides,
424 const std::array<std::vector<index_t>,
NumBTensor>& b_ns_ks_lengths,
425 const std::array<std::vector<index_t>,
NumBTensor>& b_ns_ks_strides,
426 const std::array<std::vector<index_t>,
NumDTensor>& d_ms_ns_lengths,
427 const std::array<std::vector<index_t>,
NumDTensor>& d_ms_ns_strides,
428 const std::vector<index_t>& e_ms_ns_length,
429 const std::vector<index_t>& e_ms_ns_stride,
430 AElementwiseOperation a_element_op,
431 BElementwiseOperation b_element_op,
432 CDEElementwiseOperation cde_element_op)
436 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
452 p_as_grid_(i) =
static_cast<const ADataType*
>(p_as_grid[i]);
465 p_bs_grid_(i) =
static_cast<const BDataType*
>(p_bs_grid[i]);
478 p_ds_grid_(i) =
static_cast<const DDataType*
>(p_ds_grid[i]);
545 template <
typename Gr
idwiseGemm>
554 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
556 auto as_grid_desc_ak0_m_ak1 =
559 auto bs_grid_desc_bk0_n_bk1 =
562 auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
563 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
566 auto e_grid_desc_mblock_mperblock_nblock_nperblock =
567 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
572 auto launch_kernel = [&](
auto has_main_k_block_loop) {
573 constexpr bool has_main_loop = has_main_k_block_loop.value;
577 typename GridwiseGemm::AsGridPointer,
578 typename GridwiseGemm::BsGridPointer,
579 typename GridwiseGemm::DsGridPointer,
581 AElementwiseOperation,
582 BElementwiseOperation,
583 CDEElementwiseOperation,
603 as_grid_desc_ak0_m_ak1,
604 bs_grid_desc_bk0_n_bk1,
605 ds_grid_desc_mblock_mperblock_nblock_nperblock,
606 e_grid_desc_mblock_mperblock_nblock_nperblock,
612 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
618 return launch_kernel(integral_constant<bool, false>{});
628 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
642 bool valid_as_access =
true;
644 const bool valid_a_vector_size =
646 const bool valid_a_access_dim_m =
648 const bool valid_a_access_dim_k =
650 const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
651 if(!((valid_a_vector_size && valid_a_access_dim) ||
652 ABlockTransferSrcScalarPerVector == 1))
654 valid_as_access =
false;
662 bool valid_bs_access =
true;
664 const bool valid_b_vector_size =
666 const bool valid_b_access_dim_n =
668 const bool valid_b_access_dim_k =
670 const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
671 if(!((valid_b_vector_size && valid_b_access_dim) ||
672 BBlockTransferSrcScalarPerVector == 1))
674 valid_bs_access =
false;
682 bool valid_ds_access =
true;
684 const bool valid_d_vector_size =
688 if(!((valid_d_vector_size && valid_d_access_dim) ||
689 CDEBlockTransferScalarPerVector_NPerBlock == 1))
691 valid_ds_access =
false;
699 const bool valid_e_vector_size =
703 if(!((valid_e_vector_size && valid_e_access_dim) ||
704 CDEBlockTransferScalarPerVector_NPerBlock == 1))
743 std::array<const void*, NumBTensor> p_bs,
744 std::array<const void*, NumDTensor> p_ds,
746 const std::array<std::vector<index_t>,
NumATensor>& a_ms_ks_lengths,
747 const std::array<std::vector<index_t>,
NumATensor>& a_ms_ks_strides,
748 const std::array<std::vector<index_t>,
NumBTensor>& b_ns_ks_lengths,
749 const std::array<std::vector<index_t>,
NumBTensor>& b_ns_ks_strides,
750 const std::array<std::vector<index_t>,
NumDTensor>& d_ms_ns_lengths,
751 const std::array<std::vector<index_t>,
NumDTensor>& d_ms_ns_strides,
752 const std::vector<index_t>& e_ms_ns_length,
753 const std::vector<index_t>& e_ms_ns_stride,
754 AElementwiseOperation a_element_op,
755 BElementwiseOperation b_element_op,
756 CDEElementwiseOperation cde_element_op)
778 std::unique_ptr<BaseArgument>
780 std::array<const void*, NumBTensor> p_bs,
781 std::array<const void*, NumDTensor> p_ds,
783 const std::array<std::vector<index_t>,
NumATensor>& as_ms_ks_lengths,
784 const std::array<std::vector<index_t>,
NumATensor>& as_ms_ks_strides,
785 const std::array<std::vector<index_t>,
NumBTensor>& bs_ns_ks_lengths,
786 const std::array<std::vector<index_t>,
NumBTensor>& bs_ns_ks_strides,
787 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_lengths,
788 const std::array<std::vector<index_t>,
NumDTensor>& ds_ms_ns_strides,
789 const std::vector<index_t>& e_ms_ns_length,
790 const std::vector<index_t>& e_ms_ns_stride,
791 AElementwiseOperation a_element_op,
792 BElementwiseOperation b_element_op,
793 CDEElementwiseOperation cde_element_op)
override
795 return std::make_unique<Argument>(p_as,
815 return std::make_unique<Invoker>(
Invoker{});
821 auto str = std::stringstream();
823 std::map<LoopScheduler, std::string> LoopSchedToString{
826 std::map<PipelineVersion, std::string> PipelineVersionToString{{
PipelineVersion::v1,
"v1"},
830 str <<
"DeviceContractionMultipleABD_Xdl_CShuffle"
840 << MXdlPerWave <<
", "
841 << NXdlPerWave <<
", "
842 << ABlockTransferSrcScalarPerVector <<
", "
843 << BBlockTransferSrcScalarPerVector <<
", "
844 << CShuffleMXdlPerWavePerShuffle <<
", "
845 << CShuffleNXdlPerWavePerShuffle <<
", "
848 <<
" LoopScheduler: "
849 << LoopSchedToString[LoopSched] <<
", "
850 <<
"PipelineVersion: "
851 << PipelineVersionToString[PipelineVer];
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
auto CalculateMaxRead(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_contraction_utils.hpp:33
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__global__ void kernel_contraction_multiple_abd_xdl_cshuffle(AsPointer p_as_grid, BsPointer p_bs_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:42
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:77
ck::GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer >::MakeDefaultAsGridDescriptor_AK0_M_AK1 __host__ static __device__ constexpr auto MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K &as_grid_desc_m_k)
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:229
ck::GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer >::MakeDefaultBlock2ETileMap __host__ static __device__ constexpr auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:298
ck::GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer >::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:286
ck::GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer >::AsGridPointer decltype(MakeAsGridPointer()) AsGridPointer
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:403
ck::GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AsGridDesc_M_K &as_grid_desc_m_k, const BsGridDesc_N_K &bs_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:312
ck::GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer >::BsGridPointer decltype(MakeBsGridPointer()) BsGridPointer
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:404
ck::GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer >::DsGridPointer decltype(MakeDsGridPointer()) DsGridPointer
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:405
ck::GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer >::MakeDefaultBsGridDescriptor_BK0_N_BK1 __host__ static __device__ constexpr auto MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K &bs_grid_desc_n_k)
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:255
ck::GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer >::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock __host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:265
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:417
EDataType * p_e_grid_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:512
std::array< index_t, NumATensor > bs_continous_dim_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:530
GridwiseGemm64::BsGridPointer p_bs_grid_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:510
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:517
std::array< index_t, NumATensor > as_continous_dim_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:529
BsGridDesc_N_K bs_grid_desc_n_k_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:516
AElementwiseOperation a_element_op_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:524
BElementwiseOperation b_element_op_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:525
EGridDesc_M_N e_grid_desc_m_n_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:518
std::array< index_t, NumBTensor > ds_continous_dim_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:531
Argument(std::array< const void *, NumATensor > p_as_grid, std::array< const void *, NumBTensor > p_bs_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_lengths, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_strides, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_lengths, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_strides, const std::vector< index_t > &e_ms_ns_length, const std::vector< index_t > &e_ms_ns_stride, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:418
index_t e_max_write_elems_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:537
std::array< index_t, NumDTensor > ds_max_read_elems_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:536
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:511
AsGridDesc_M_K as_grid_desc_m_k_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:515
index_t e_continous_dim_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:532
std::array< index_t, NumATensor > as_max_read_elems_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:534
GridwiseGemm64::AsGridPointer p_as_grid_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:509
std::array< index_t, NumBTensor > bs_max_read_elems_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:535
Block2ETileMap block_2_etile_map_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:521
CDEElementwiseOperation cde_element_op_
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:526
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:542
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:546
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:625
DeviceOp::Argument Argument
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:543
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:164
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:813
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:230
remove_cvref_t< decltype(MakeEGridDescriptor_M_N({}, {}))> EGridDesc_M_N
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:395
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAsGridDescriptor_AK0_M_AK1( AsGridDesc_M_K{}))> AsGridDesc_AK0_M_AK1
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:398
static constexpr index_t NumATensor
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:171
static auto MakeDsGridDescriptor_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:381
remove_cvref_t< decltype(MakeBsGridDescriptor_N_K({}, {}))> BsGridDesc_N_K
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:393
DeviceContractionMultipleABD_Xdl_CShuffle DeviceOp
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:165
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:737
__host__ static __device__ auto MakeAsGridDescriptor_M_K(const std::array< std::vector< index_t >, NumATensor > &as_ms_ks_lengths, const std::array< std::vector< index_t >, NumATensor > &as_ms_ks_strides)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:277
static constexpr auto I2
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:177
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::array< std::vector< index_t >, NumATensor > &as_ms_ks_lengths, const std::array< std::vector< index_t >, NumATensor > &as_ms_ks_strides, const std::array< std::vector< index_t >, NumBTensor > &bs_ns_ks_lengths, const std::array< std::vector< index_t >, NumBTensor > &bs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_ms_ns_strides, const std::vector< index_t > &e_ms_ns_length, const std::vector< index_t > &e_ms_ns_stride, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:779
static auto MakeBGridDescriptor_N_K(const std::vector< index_t > &b_ns_ks_lengths_, const std::vector< index_t > &b_ns_ks_strides_)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:288
std::string GetTypeString() const override
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:819
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:404
static constexpr auto I1
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:176
static bool IsSupportedArgument(const Argument &arg)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:632
__host__ static __device__ auto MakeBsGridDescriptor_N_K(const std::array< std::vector< index_t >, NumBTensor > &bs_ns_ks_lengths, const std::array< std::vector< index_t >, NumBTensor > &bs_ns_ks_strides)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:329
GridwiseGemmMultipleABD_xdl_cshuffle< AsDataType, BsDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer > GridwiseGemmBase
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:184
remove_cvref_t< decltype(MakeAsGridDescriptor_M_K({}, {}))> AsGridDesc_M_K
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:392
static constexpr auto matrix_padder
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:232
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:229
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBsGridDescriptor_BK0_N_BK1( BsGridDesc_N_K{}))> BsGridDesc_BK0_N_BK1
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:401
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:407
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}))> DsGridDesc_M_N
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:394
static constexpr auto NXdlPerWave32
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:169
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:168
EDataType ComputeDataType
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:180
static constexpr auto I3
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:178
static auto MakeArgument(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_lengths, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_strides, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_lengths, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_strides, const std::vector< index_t > &e_ms_ns_length, const std::vector< index_t > &e_ms_ns_stride, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:742
static constexpr index_t NumDTensor
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:173
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_ms_ks_lengths_, const std::vector< index_t > &a_ms_ks_strides_)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:236
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:412
static constexpr index_t NumBTensor
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:172
static auto MakeInvoker()
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:775
static constexpr auto I0
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:175
static auto MakeEGridDescriptor_M_N(const std::vector< index_t > &e_ms_ns_lengths_, const std::vector< index_t > &e_ms_ns_strides_)
Definition device_contraction_multiple_abd_xdl_cshuffle.hpp:340
Definition device_contraction_multiple_abd.hpp:34
Definition matrix_padder.hpp:180