device_grouped_query_attention_forward_wmma.hpp Source File
device_grouped_query_attention_forward_wmma.hpp
Go to the documentation of this file.
75 ? std::array<ck::index_t, array_size>{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K]
81 ? std::array<ck::index_t, array_size>{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K]
87 ? std::array<ck::index_t, array_size>{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O]
93 ? std::array<ck::index_t, array_size>{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O]
123 typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n};
125 const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})};
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
std::string getMaskingSpecializationString(const MaskingSpecialization &s)
Definition masking_specialization.hpp:17
MaskingSpecialization
Definition masking_specialization.hpp:11
@ MaskDisabled
Definition masking_specialization.hpp:12
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
TensorSpecialization
Definition tensor_specialization.hpp:11
__global__ void kernel_grouped_query_attention_wmma(const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_grouped_query_attention_forward_wmma.hpp:50
GemmSpecialization
Definition gemm_specialization.hpp:11
std::string getTensorSpecializationString(const TensorSpecialization &s)
Definition tensor_specialization.hpp:16
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:93
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:672
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:679
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:653
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc &a_grid_desc, const B0GridDesc &b0_grid_desc, const B1GridDesc &b1_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp:511
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition transform_contraction_to_gemm_arraybase.hpp:122
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k, const Number &BK1)
Definition transform_contraction_to_gemm_arraybase.hpp:245
__host__ static __device__ auto MakeCGridDescriptor_G_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:375
__host__ static __device__ constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const Number &AK1)
Definition transform_contraction_to_gemm_arraybase.hpp:172
static constexpr auto matrix_padder
Definition transform_contraction_to_gemm_arraybase.hpp:140
__host__ static __device__ auto MakeB1GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:307
__host__ static __device__ auto MakeB1GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_os_ns_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:301
__host__ static __device__ constexpr auto MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(const BGridDesc_L_K &b_grid_desc_l_k, const WmmaK &, const LRepeat &, const LWaves &, const LPerWmma &, const BK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:266
__host__ static __device__ auto MakeB0GridDescriptor_G_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:228
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K &b1_grid_desc_n_k, const Number &B1K1)
Definition transform_contraction_to_gemm_arraybase.hpp:318
__host__ static __device__ auto MakeAGridDescriptor_G_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:156
__host__ static __device__ constexpr auto MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(const AGridDesc_M_K &a_grid_desc_m_k, const WmmaK &, const MRepeat &, const MWaves &, const MPerWmma &, const AK1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:193
__host__ static __device__ auto MakeCGridDescriptor_M_N(const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_os_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:381
__host__ static __device__ constexpr auto MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(const BGridDesc_N_L &b_grid_desc_n_l, const WmmaL &, const NRepeat &, const NWaves &, const NPerWmma &, const BL1 &)
Definition transform_contraction_to_gemm_arraybase.hpp:340
__host__ static __device__ auto MakeAGridDescriptor_M_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:162
__host__ static __device__ auto MakeB0GridDescriptor_N_K(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ns_ks_strides_vec)
Definition transform_contraction_to_gemm_arraybase.hpp:234
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition masking_specialization.hpp:57
Definition device_batched_gemm_softmax_gemm_permute.hpp:34
Definition device_grouped_query_attention_forward_wmma.hpp:756
CGridDesc_M_N c_grid_desc_m_n_
Definition device_grouped_query_attention_forward_wmma.hpp:853
const B0DataType * p_b0_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:845
AGridDesc_G_M_K a_grid_desc_g_m_k_
Definition device_grouped_query_attention_forward_wmma.hpp:855
AccElementwiseOperation acc_element_op_
Definition device_grouped_query_attention_forward_wmma.hpp:869
B1ElementwiseOperation b1_element_op_
Definition device_grouped_query_attention_forward_wmma.hpp:870
GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_query_attention_forward_wmma.hpp:861
std::array< index_t, NumDimG+NumDimM+NumDimN > a_mz_kz_strides_
Definition device_grouped_query_attention_forward_wmma.hpp:879
ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma::Argument::b0_grid_desc_g_l_k_
B0GridDesc_G_L_K b0_grid_desc_g_l_k_
Definition device_grouped_query_attention_forward_wmma.hpp:856
ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_
Definition device_grouped_query_attention_forward_wmma.hpp:886
AGridDesc a_grid_desc
Definition device_grouped_query_attention_forward_wmma.hpp:850
GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_grouped_query_attention_forward_wmma.hpp:864
std::array< index_t, NumDimG+NumDimM+NumDimN > b1_nz_lz_strides_
Definition device_grouped_query_attention_forward_wmma.hpp:881
CElementwiseOperation c_element_op_
Definition device_grouped_query_attention_forward_wmma.hpp:871
B1GridDesc b1_grid_desc
Definition device_grouped_query_attention_forward_wmma.hpp:852
const ADataType * p_a_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:844
B0ElementwiseOperation b0_element_op_
Definition device_grouped_query_attention_forward_wmma.hpp:868
C0MatrixMask c0_matrix_mask_
Definition device_grouped_query_attention_forward_wmma.hpp:874
std::array< index_t, NumDimG+NumDimM+NumDimN > b0_lz_kz_strides_
Definition device_grouped_query_attention_forward_wmma.hpp:880
CDataType * p_c_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:847
CGridDesc_G_M_N c_grid_desc_g_m_n_
Definition device_grouped_query_attention_forward_wmma.hpp:858
B0GridDesc b0_grid_desc
Definition device_grouped_query_attention_forward_wmma.hpp:851
std::array< index_t, NumDimG+NumDimM+NumDimN > raw_lengths_mz_lz_kz_nz_
Definition device_grouped_query_attention_forward_wmma.hpp:878
Argument(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_lengths, const std::array< index_t, NumDimG+NumDimM+NumDimN > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, const index_t M01, const index_t N01, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op)
Definition device_grouped_query_attention_forward_wmma.hpp:757
index_t batch_count_
Definition device_grouped_query_attention_forward_wmma.hpp:884
ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma::Argument::b1_grid_desc_g_n_l_
B1GridDesc_G_N_L b1_grid_desc_g_n_l_
Definition device_grouped_query_attention_forward_wmma.hpp:857
AElementwiseOperation a_element_op_
Definition device_grouped_query_attention_forward_wmma.hpp:867
std::array< index_t, NumDimG+NumDimM+NumDimN > c_mz_nz_strides_
Definition device_grouped_query_attention_forward_wmma.hpp:882
const B1DataType * p_b1_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:846
Definition device_grouped_query_attention_forward_wmma.hpp:417
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
Definition device_grouped_query_attention_forward_wmma.hpp:429
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
Definition device_grouped_query_attention_forward_wmma.hpp:434
__host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K &a_grid_desc_g_m_k, const B0GridDesc_G_L_K &b0_grid_desc_g_l_k, const B1GridDesc_G_N_L &b1_grid_desc_g_n_l, const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition device_grouped_query_attention_forward_wmma.hpp:418
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
Definition device_grouped_query_attention_forward_wmma.hpp:444
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
Definition device_grouped_query_attention_forward_wmma.hpp:439
Definition device_grouped_query_attention_forward_wmma.hpp:890
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_query_attention_forward_wmma.hpp:947
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_query_attention_forward_wmma.hpp:893
DeviceOp::RawArg Argument
Definition device_grouped_query_attention_forward_wmma.hpp:891
Definition device_grouped_query_attention_forward_wmma.hpp:533
index_t G1_
Definition device_grouped_query_attention_forward_wmma.hpp:574
index_t N_
Definition device_grouped_query_attention_forward_wmma.hpp:570
float alpha_
Definition device_grouped_query_attention_forward_wmma.hpp:575
index_t K_
Definition device_grouped_query_attention_forward_wmma.hpp:571
bool output_permute_
Definition device_grouped_query_attention_forward_wmma.hpp:577
const ADataType * p_a_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:563
index_t M_
Definition device_grouped_query_attention_forward_wmma.hpp:569
const B1DataType * p_b1_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:565
RawArg(const ADataType *p_a_grid, const B0DataType *p_b0_grid, const B1DataType *p_b1_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_grouped_query_attention_forward_wmma.hpp:534
index_t O_
Definition device_grouped_query_attention_forward_wmma.hpp:572
bool input_permute_
Definition device_grouped_query_attention_forward_wmma.hpp:576
index_t G0_
Definition device_grouped_query_attention_forward_wmma.hpp:573
const B0DataType * p_b0_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:564
CDataType * p_c_grid_
Definition device_grouped_query_attention_forward_wmma.hpp:566
Definition device_grouped_query_attention_forward_wmma.hpp:266
static constexpr auto NWaves
Definition device_grouped_query_attention_forward_wmma.hpp:297
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_query_attention_forward_wmma.hpp:954
static constexpr auto I0
Definition device_grouped_query_attention_forward_wmma.hpp:285
static constexpr auto B0EnableLds_manu
Definition device_grouped_query_attention_forward_wmma.hpp:304
static constexpr index_t NumDimGemm1K
Definition device_grouped_query_attention_forward_wmma.hpp:281
static constexpr auto AEnableLds_auto
Definition device_grouped_query_attention_forward_wmma.hpp:299
decltype(MakeB0GridDescriptor({}, {})) B0GridDesc
Definition device_grouped_query_attention_forward_wmma.hpp:395
decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})) B1GridDesc_G_N_L
Definition device_grouped_query_attention_forward_wmma.hpp:400
static constexpr auto B1EnableLds_auto
Definition device_grouped_query_attention_forward_wmma.hpp:301
DeviceGroupedQueryAttentionForward_Wmma DeviceOp
Definition device_grouped_query_attention_forward_wmma.hpp:283
static constexpr auto I1
Definition device_grouped_query_attention_forward_wmma.hpp:286
static constexpr auto I5
Definition device_grouped_query_attention_forward_wmma.hpp:290
decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})) CGridDesc_G_M_N
Definition device_grouped_query_attention_forward_wmma.hpp:401
decltype(MakeB1GridDescriptor({}, {})) B1GridDesc
Definition device_grouped_query_attention_forward_wmma.hpp:396
__host__ __device__ static constexpr auto make_MaskOutPredicate()
Definition device_grouped_query_attention_forward_wmma.hpp:403
decltype(MakeAGridDescriptor({}, {})) AGridDesc
Definition device_grouped_query_attention_forward_wmma.hpp:394
__host__ static __device__ auto MakeAGridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &a_gs_ms_ks_strides_vec)
Definition device_grouped_query_attention_forward_wmma.hpp:320
static constexpr auto B0EnableLds
Definition device_grouped_query_attention_forward_wmma.hpp:308
C0MatrixMask_impl< decltype(make_MaskOutPredicate())> C0MatrixMask
Definition device_grouped_query_attention_forward_wmma.hpp:414
static constexpr auto MWaves
Definition device_grouped_query_attention_forward_wmma.hpp:295
static constexpr index_t NumDimGemm0M
Definition device_grouped_query_attention_forward_wmma.hpp:276
static constexpr index_t NumDimGemm0N
Definition device_grouped_query_attention_forward_wmma.hpp:277
static constexpr auto I6
Definition device_grouped_query_attention_forward_wmma.hpp:291
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_query_attention_forward_wmma.hpp:1197
static constexpr index_t NumDimGemm1M
Definition device_grouped_query_attention_forward_wmma.hpp:279
__host__ static __device__ auto MakeB0GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b0_gs_ls_ks_strides_vec)
Definition device_grouped_query_attention_forward_wmma.hpp:344
static bool IsSupportedArgument(const RawArg &arg)
Definition device_grouped_query_attention_forward_wmma.hpp:598
static constexpr index_t NumAcc1Bias
Definition device_grouped_query_attention_forward_wmma.hpp:271
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_query_attention_forward_wmma.hpp:749
TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< Sequence< NumDimG, NumDimM, NumDimL, NumDimK, NumDimN >, Sequence< MPerBlock, LPerBlock, KPerBlock, NPerBlock >, GemmSpec, ASpec, B0Spec, B1Spec, CSpec > Transform
Definition device_grouped_query_attention_forward_wmma.hpp:311
static constexpr auto B1EnableLds_manu
Definition device_grouped_query_attention_forward_wmma.hpp:305
static constexpr index_t NumDimGemm1N
Definition device_grouped_query_attention_forward_wmma.hpp:280
static constexpr auto B0EnableLds_auto
Definition device_grouped_query_attention_forward_wmma.hpp:300
static constexpr auto AEnableLds_manu
Definition device_grouped_query_attention_forward_wmma.hpp:303
static auto MakeInvoker()
Definition device_grouped_query_attention_forward_wmma.hpp:1194
static constexpr auto I4
Definition device_grouped_query_attention_forward_wmma.hpp:289
static auto MakeArgument(const ADataType *p_a, const B0DataType *p_b0, const B1DataType *p_b1, CDataType *p_c, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
Definition device_grouped_query_attention_forward_wmma.hpp:580
decltype(Transform::MakeCGridDescriptor_M_N({}, {})) CGridDesc_M_N
Definition device_grouped_query_attention_forward_wmma.hpp:397
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b0, const void *p_b1, void *p_c, const std::array< void *, NumAcc0Bias > p_acc0_biases, const std::array< void *, NumAcc1Bias > p_acc1_biases, const std::vector< index_t > &a_gs_ms_ks_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b0_gs_ls_ks_lengths, const std::vector< index_t > &b0_gs_ls_ks_strides, const std::vector< index_t > &b1_gs_ns_ls_lengths, const std::vector< index_t > &b1_gs_ns_ls_strides, const std::vector< index_t > &c_gs_ms_ns_lengths, const std::vector< index_t > &c_gs_ms_ns_strides, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_lengths, const std::array< std::vector< ck::index_t >, NumAcc0Bias > acc0_biases_gs_ms_ls_strides, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_lengths, const std::array< std::vector< ck::index_t >, NumAcc1Bias > acc1_biases_gs_ms_ns_strides, AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, AccElementwiseOperation acc_element_op, B1ElementwiseOperation b1_element_op, CElementwiseOperation c_element_op) override
Definition device_grouped_query_attention_forward_wmma.hpp:1102
static constexpr index_t NumDimGemm0K
Definition device_grouped_query_attention_forward_wmma.hpp:278
static constexpr auto I2
Definition device_grouped_query_attention_forward_wmma.hpp:287
static constexpr auto B1EnableLds
Definition device_grouped_query_attention_forward_wmma.hpp:309
__host__ static __device__ auto MakeB1GridDescriptor(const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_lengths_vec, const std::array< index_t, NumDimG+NumDimM+NumDimN > &b1_gs_ns_ls_strides_vec)
Definition device_grouped_query_attention_forward_wmma.hpp:369
static constexpr auto LWaves
Definition device_grouped_query_attention_forward_wmma.hpp:296
static constexpr auto AEnableLds
Definition device_grouped_query_attention_forward_wmma.hpp:307
decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})) B0GridDesc_G_L_K
Definition device_grouped_query_attention_forward_wmma.hpp:399
std::string GetTypeString() const override
Definition device_grouped_query_attention_forward_wmma.hpp:1203
static constexpr auto WmmaK
Definition device_grouped_query_attention_forward_wmma.hpp:293
decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})) AGridDesc_G_M_K
Definition device_grouped_query_attention_forward_wmma.hpp:398
static constexpr index_t NumAcc0Bias
Definition device_grouped_query_attention_forward_wmma.hpp:270
GridwiseBatchedGemmSoftmaxGemm_Wmma< ADataType, B0DataType, Acc0DataType, B1DataType, Acc1DataType, CShuffleDataType, CDataType, AElementwiseOperation, B0ElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc, B0GridDesc, B1GridDesc, CGridDesc_M_N, MPerBlock, LPerBlock, KPerBlock, AK1, BK1, NPerBlock, LTilePerBlock, L1, MPerWmma, LPerWmma, NPerWmma, MRepeat, LRepeat, NRepeat, BlockSize, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, true, AEnableLds, ABlockLdsAddExtraM, B0BlockTransferThreadClusterLengths_K0_L_K1, B0BlockTransferThreadClusterArrangeOrder, B0BlockTransferSrcAccessOrder, B0BlockTransferSrcVectorDim, B0BlockTransferSrcScalarPerVector, B0BlockTransferDstScalarPerVector_K1, true, B0EnableLds, B0BlockLdsAddExtraL, B1BlockTransferThreadClusterLengths_L0_N_L1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_L1, false, B1EnableLds, B1BlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, NumPrefetch, LoopSched, PipelineVer > GridwiseOp
Definition device_grouped_query_attention_forward_wmma.hpp:457
static constexpr auto I3
Definition device_grouped_query_attention_forward_wmma.hpp:288
Definition masking_specialization.hpp:29
Definition masking_specialization.hpp:43