device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp Source File
device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
Go to the documentation of this file.
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecialization &s)
Definition convolution_backward_data_specialization.hpp:17
ConvolutionBackwardDataSpecialization
Definition convolution_backward_data_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_backward_data_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__global__ void kernel_grouped_conv_multiple_d_wmma_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc, const BGridDesc_BK0_N_BK1 b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:40
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:326
remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:891
remove_cvref_t< decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))> DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:888
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:809
__host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N_ &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:819
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const EGridDesc_M_N &e_grid_desc_m_n, index_t, index_t)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:850
__host__ static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_ &ds_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:840
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc, const BGridDesc_BK0_N_BK1 &b_grid_desc, const EGridDesc_M_N &e_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_multiple_d_wmma_cshuffle.hpp:608
Definition functional2.hpp:33
Definition transform_conv_bwd_data_to_gemm_v1.hpp:44
__host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:659
__host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:943
__host__ __device__ auto MakeCDescriptor_M_N() const
Definition transform_conv_bwd_data_to_gemm_v1.hpp:1150
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:210
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:465
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:466
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:464
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, const ck::index_t split_k=1)
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:211
std::vector< EGridDesc_M_N > e_grid_desc_m_n_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:441
void Print() const
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:410
std::vector< AGridDesc_AK0_M_AK1 > a_grid_desc_ak0_m_ak1_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:444
long_index_t e_space_size_bytes
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:470
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:455
bool bwd_needs_zero_out
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:469
index_t num_group_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:439
std::vector< DsGridDesc_M_N > ds_grid_desc_m_n_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:440
BElementwiseOp b_element_op_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:459
std::vector< BGridDesc_BK0_N_BK1 > b_grid_desc_bk0_n_bk1_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:445
std::vector< DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock > ds_grid_desc_mblock_mperblock_nblock_nperblock_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:447
std::vector< typename GridwiseGemm::DefaultBlock2CTileMap > block_2_ctile_map_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:452
CDEElementwiseOp cde_element_op_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:460
const index_t k_batch_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:468
std::array< index_t, NDimSpatial+3 > a_g_n_k_wos_lengths_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:462
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:433
AElementwiseOp a_element_op_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:458
EDataType * p_e_grid_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:436
GridwiseGemm::DsGridPointer p_ds_grid_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:435
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:434
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:463
std::vector< EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock > e_grid_desc_mblock_mperblock_nblock_nperblock_container_
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:449
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:475
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:478
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:560
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:476
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:92
remove_cvref_t< tuple_element_t< 3, ABDsEGridDesc > > EGridDesc_M_N
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:143
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:705
static auto MakeInvoker()
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:756
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle DeviceOp
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:97
remove_cvref_t< tuple_element_t< 1, ABDsEGridDesc > > BGridDesc_BK0_N_BK1
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:141
static constexpr index_t NumDTensor
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:99
static constexpr auto I1
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:105
static constexpr auto I0
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:104
decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)) ABDsEGridDesc
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:138
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, const ck::index_t split_k=1)
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:711
static constexpr auto I3
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:107
remove_cvref_t< tuple_element_t< 2, ABDsEGridDesc > > DsGridDesc_M_N
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:142
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:809
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:804
GridwiseGemmMultipleD_Wmma< ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, DsGridDesc_M_N, EGridDesc_M_N, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, InMemoryDataOperationEnum::Set, MPerBlock, NPerBlock, KPerBlock, MPerWMMA, NPerWMMA, K1, MRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, true, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVector_NPerBlock, NumGemmKPrefetchStage, LoopSched, PipelineVer > GridwiseGemm
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:146
ADataType ABDataType
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:102
static constexpr ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:137
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{})) EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:204
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOp &a_element_op, const BElementwiseOp &b_element_op, const CDEElementwiseOp &cde_element_op, const ck::index_t split_k=1) override
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:758
remove_cvref_t< tuple_element_t< 0, ABDsEGridDesc > > AGridDesc_AK0_M_AK1
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:140
TransformConvBwdDataToGemm_v1< NDimSpatial, ConvBackwardDataSpecialization, K1, K1, MPerBlock, NPerBlock, KPerBlock, true, true, ALayout, BLayout, ELayout > ConvToGemmBwdDataTransform
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:110
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:567
static constexpr auto I2
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:106
static auto GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform &conv_to_gemm_transform)
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:124
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{})) DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:201
static constexpr index_t KPerBlock
Definition device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp:108
Definition device_grouped_conv_bwd_data_multiple_d.hpp:36