24template <
typename ALayout,
30 typename GemmAccDataType,
31 typename CShuffleDataType,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CElementwiseOperation,
46 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
47 typename ABlockTransferThreadClusterArrangeOrder,
48 typename ABlockTransferSrcAccessOrder,
49 index_t ABlockTransferSrcVectorDim,
50 index_t ABlockTransferSrcScalarPerVector,
51 index_t ABlockTransferDstScalarPerVector_AK1,
53 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
54 typename BBlockTransferThreadClusterArrangeOrder,
55 typename BBlockTransferSrcAccessOrder,
56 index_t BBlockTransferSrcVectorDim,
57 index_t BBlockTransferSrcScalarPerVector,
58 index_t BBlockTransferDstScalarPerVector_BK1,
60 index_t CShuffleMXdlPerWavePerShuffle,
61 index_t CShuffleNXdlPerWavePerShuffle,
62 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
63 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
66 typename ComputeTypeA = CDataType,
67 typename ComputeTypeB = ComputeTypeA>
74 AElementwiseOperation,
75 BElementwiseOperation,
76 CElementwiseOperation>
83 template <index_t NXdlPerWave_>
93 AElementwiseOperation,
94 BElementwiseOperation,
95 CElementwiseOperation,
107 ABlockTransferThreadClusterLengths_AK0_M_AK1,
108 ABlockTransferThreadClusterArrangeOrder,
109 ABlockTransferSrcAccessOrder,
110 ABlockTransferSrcVectorDim,
111 ABlockTransferSrcScalarPerVector,
112 ABlockTransferDstScalarPerVector_AK1,
115 BBlockTransferThreadClusterLengths_BK0_N_BK1,
116 BBlockTransferThreadClusterArrangeOrder,
117 BBlockTransferSrcAccessOrder,
118 BBlockTransferSrcVectorDim,
119 BBlockTransferSrcScalarPerVector,
120 BBlockTransferDstScalarPerVector_BK1,
123 CShuffleMXdlPerWavePerShuffle,
124 CShuffleNXdlPerWavePerShuffle,
125 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
126 CShuffleBlockTransferScalarPerVector_NPerBlock,
134 using Argument =
typename GridwiseGemm64::Argument;
140 template <
typename Gr
idwiseGemm>
141 float RunImp(
const typename GridwiseGemm::Argument& arg,
145 if(stream_config.log_level_ > 0)
150 if(!GridwiseGemm::CheckValidity(arg))
152 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
158 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
160 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
166 arg.p_c_grid, 0, arg.M * arg.N *
sizeof(CDataType), stream_config.stream_id_));
169 const auto Run = [&](
const auto& kernel) {
171 if(arg.Grid_size < 0)
173 int occupancy, num_cu;
175 &occupancy, kernel, BlockSize, 0));
176 hipDeviceProp_t dev_prop;
180 num_cu = dev_prop.multiProcessorCount;
181 arg.Grid_size = num_cu * occupancy;
182 grid_dim = arg.Grid_size;
185 grid_dim = arg.Grid_size;
187 if(stream_config.flush_cache)
192 stream_config.rotating_count,
193 arg_.M * arg_.K *
sizeof(ADataType),
194 arg_.K * arg_.N *
sizeof(BDataType));
195 rotating_mem.Print();
197 auto run_flush_cache = [&]() {
205 stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_);
213 stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
217 char* workspace_semaphore =
218 reinterpret_cast<char*
>(arg.p_workspace_) +
219 arg.block_2_ctile_map_streamk.get_workspace_size_for_acc(
220 sizeof(GemmAccDataType));
221 auto preprocess = [&]() {
222 hipError_t status = hipMemsetAsync(
226 arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(),
227 stream_config.stream_id_);
234 stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg);
239 constexpr index_t minimum_occupancy =
242 if(has_main_k_block_loop)
261 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
271 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
283 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
285 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
297 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
299 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
312 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
314 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
327 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
329 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
342 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
344 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
356 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
358 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
376 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
400 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
445 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
454 return p_arg->block_2_ctile_map_streamk.get_workspace_size(
sizeof(GemmAccDataType));
468 pArg_->p_workspace_ = p_workspace;
513 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
524 template <
typename Gr
idwiseGemm,
bool IsVal
id>
527 const BDataType* p_b,
537 AElementwiseOperation,
538 BElementwiseOperation,
539 CElementwiseOperation,
543 constexpr index_t minimum_occupancy =
545 index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock;
547 int occupancy = 1, num_cu = 1;
548 const auto calculate_grid_size = [&](
const auto& kernel) {
550 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));
551 hipDeviceProp_t dev_prop;
555 num_cu = dev_prop.multiProcessorCount;
556 Grid_size = num_cu * occupancy;
559 if constexpr(IsValid)
561 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
562 if(has_main_k_block_loop)
573 calculate_grid_size(kernel);
579 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::One)
587 calculate_grid_size(kernel);
589 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
597 calculate_grid_size(kernel);
600 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
602 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two)
610 calculate_grid_size(kernel);
614 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
624 calculate_grid_size(kernel);
628 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
638 calculate_grid_size(kernel);
642 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
652 calculate_grid_size(kernel);
656 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
658 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Six)
666 calculate_grid_size(kernel);
670 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
680 calculate_grid_size(kernel);
688 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
696 calculate_grid_size(kernel);
706 calculate_grid_size(kernel);
712 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
720 calculate_grid_size(kernel);
730 calculate_grid_size(kernel);
744 calculate_grid_size(kernel);
765 const BDataType* p_b,
775 AElementwiseOperation a_op,
776 BElementwiseOperation b_op,
777 CElementwiseOperation c_op,
834 AElementwiseOperation,
835 BElementwiseOperation,
836 CElementwiseOperation,
839 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
840 static_cast<const BDataType*
>(p_b),
841 static_cast<CDataType*
>(p_c),
856 return std::make_unique<Invoker>(
Invoker{});
862 auto str = std::stringstream();
864 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
868 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
876 str <<
"DeviceGemmXdlUniversal"
879 << std::string(ALayout::name)[0]
880 << std::string(BLayout::name)[0]
881 << std::string(CLayout::name)[0]
886 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
888 << MPerXDL<<
"x"<<NPerXDL <<
", "
890 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
892 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
893 <<
"BlkGemmPipelineScheduler: "
894 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
895 <<
"BlkGemmPipelineVersion: "
896 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
897 <<
"BlkGemmPipelinePrefetchStages: "
898 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
StreamKReductionStrategy
Definition block_to_ctile_map.hpp:1011
@ Atomic
Definition block_to_ctile_map.hpp:1012
@ Reduction
Definition block_to_ctile_map.hpp:1013
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:133
ck::GridwiseGemm_xdl_cshuffle_streamk_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:1024
Definition device_base.hpp:197
Definition device_gemm_streamk_v2.hpp:23
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:139
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:442
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:141
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:77
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:80
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:131
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:449
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:134
GridwiseGemm_xdl_cshuffle_streamk_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:84
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:471
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t streamk_sel, index_t Grid_size, AElementwiseOperation a_op, BElementwiseOperation b_op, CElementwiseOperation c_op, StreamKReductionStrategy reduction_strategy=StreamKReductionStrategy::Atomic)
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:764
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:462
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:477
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:819
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t streamk_sel, index_t Grid_size, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, StreamKReductionStrategy reduction_strategy=StreamKReductionStrategy::Atomic) override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:822
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:854
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:520
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:860
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:132
static auto MakeArgumentImp(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t streamk_sel, index_t Grid_size, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, StreamKReductionStrategy reduction_strategy=StreamKReductionStrategy::Atomic)
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:526
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_streamk_v3.hpp:79
Definition flush_cache.hpp:299