device_fpAintB_gemm_wmma.hpp Source File#
device_fpAintB_gemm_wmma.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__global__ void kernel_fpAintB_gemm_wmma(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, const ScaleDataType *__restrict__ p_scale_grid, CDataType *__restrict__ p_c_grid, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const ScaleGridDesc scale_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_fpAintB_gemm_wmma.hpp:40
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_fpAintB_gemm_wmma.hpp:136
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_fpAintB_gemm_wmma.hpp:529
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_fpAintB_gemm_wmma.hpp:421
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition gridwise_fpAintB_gemm_wmma.hpp:558
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_fpAintB_gemm_wmma.hpp:555
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_fpAintB_gemm_wmma.hpp:521
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_fpAintB_gemm_wmma.hpp:548
Definition utility/sequence.hpp:43
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_fpAintB_gemm_wmma.hpp:344
GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_fpAintB_gemm_wmma.hpp:407
index_t M01_
Definition device_fpAintB_gemm_wmma.hpp:408
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, const ScaleDataType *p_scale_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_fpAintB_gemm_wmma.hpp:345
index_t KRaw_
Definition device_fpAintB_gemm_wmma.hpp:416
AGridDesc a_grid_desc_
Definition device_fpAintB_gemm_wmma.hpp:401
BGridDesc b_grid_desc_
Definition device_fpAintB_gemm_wmma.hpp:402
CElementwiseOperation c_element_op_
Definition device_fpAintB_gemm_wmma.hpp:412
index_t MRaw_
Definition device_fpAintB_gemm_wmma.hpp:414
const BDataType * p_b_grid_
Definition device_fpAintB_gemm_wmma.hpp:398
CDataType * p_c_grid_
Definition device_fpAintB_gemm_wmma.hpp:400
index_t N01_
Definition device_fpAintB_gemm_wmma.hpp:409
index_t NRaw_
Definition device_fpAintB_gemm_wmma.hpp:415
CGridDesc_M_N c_grid_desc_m_n_
Definition device_fpAintB_gemm_wmma.hpp:404
ScaleGridDesc scale_grid_desc_
Definition device_fpAintB_gemm_wmma.hpp:403
const ADataType * p_a_grid_
Definition device_fpAintB_gemm_wmma.hpp:397
BElementwiseOperation b_element_op_
Definition device_fpAintB_gemm_wmma.hpp:411
AElementwiseOperation a_element_op_
Definition device_fpAintB_gemm_wmma.hpp:410
const ScaleDataType * p_scale_grid_
Definition device_fpAintB_gemm_wmma.hpp:399
GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_fpAintB_gemm_wmma.hpp:406
Definition device_fpAintB_gemm_wmma.hpp:421
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_fpAintB_gemm_wmma.hpp:497
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_fpAintB_gemm_wmma.hpp:424
DeviceOp::Argument Argument
Definition device_fpAintB_gemm_wmma.hpp:422
Definition device_fpAintB_gemm_wmma.hpp:79
static constexpr auto AEnableLds_manu
Definition device_fpAintB_gemm_wmma.hpp:101
static constexpr auto I5
Definition device_fpAintB_gemm_wmma.hpp:85
static constexpr auto AEnableLds
Definition device_fpAintB_gemm_wmma.hpp:104
decltype(MakeBGridDescriptor(1, 1, 1)) BGridDesc
Definition device_fpAintB_gemm_wmma.hpp:287
static constexpr auto I6
Definition device_fpAintB_gemm_wmma.hpp:86
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_scale, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_fpAintB_gemm_wmma.hpp:633
std::string GetTypeString() const override
Definition device_fpAintB_gemm_wmma.hpp:671
decltype(MakeScaleGridDescriptor(1, 1, 0)) ScaleGridDesc
Definition device_fpAintB_gemm_wmma.hpp:288
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_fpAintB_gemm_wmma.hpp:289
static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB=0)
Definition device_fpAintB_gemm_wmma.hpp:221
static constexpr auto I1
Definition device_fpAintB_gemm_wmma.hpp:81
DeviceFpAintBGemm_Wmma_CShuffle DeviceOp
Definition device_fpAintB_gemm_wmma.hpp:110
static constexpr auto BEnableLds_auto
Definition device_fpAintB_gemm_wmma.hpp:96
static constexpr auto BEnableLds_manu
Definition device_fpAintB_gemm_wmma.hpp:102
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_fpAintB_gemm_wmma.hpp:665
decltype(MakeAGridDescriptor(1, 1, 1)) AGridDesc
Definition device_fpAintB_gemm_wmma.hpp:286
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_fpAintB_gemm_wmma.hpp:267
static bool IsSupportedArgument(const Argument &arg)
Definition device_fpAintB_gemm_wmma.hpp:510
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_fpAintB_gemm_wmma.hpp:113
static constexpr auto I0
Definition device_fpAintB_gemm_wmma.hpp:80
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_fpAintB_gemm_wmma.hpp:167
static constexpr auto I3
Definition device_fpAintB_gemm_wmma.hpp:83
static constexpr auto AEnableLds_auto
Definition device_fpAintB_gemm_wmma.hpp:94
static constexpr auto I4
Definition device_fpAintB_gemm_wmma.hpp:84
static constexpr auto NWaves
Definition device_fpAintB_gemm_wmma.hpp:91
static constexpr auto matrix_padder
Definition device_fpAintB_gemm_wmma.hpp:107
static constexpr auto MWaves
Definition device_fpAintB_gemm_wmma.hpp:90
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_fpAintB_gemm_wmma.hpp:594
static constexpr auto K1Number
Definition device_fpAintB_gemm_wmma.hpp:88
GridwiseFpAintBGemm_Wmma< BlockSize, ADataType, BDataType, ScaleDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, ScaleGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer > GridwiseGemm
Definition device_fpAintB_gemm_wmma.hpp:292
static constexpr auto WmmaK
Definition device_fpAintB_gemm_wmma.hpp:92
static constexpr auto BEnableLds
Definition device_fpAintB_gemm_wmma.hpp:105
static constexpr auto I2
Definition device_fpAintB_gemm_wmma.hpp:82
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const ScaleDataType *p_scale, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_fpAintB_gemm_wmma.hpp:599
static auto MakeInvoker()
Definition device_fpAintB_gemm_wmma.hpp:630
static constexpr bool IsValidCompilationParameter()
Definition device_fpAintB_gemm_wmma.hpp:504
Definition device_gemm_dequantB.hpp:25
Definition matrix_padder.hpp:180