#include <variants.hpp>
|
| __device__ __host__ | StandardAttention ()=default |
| template<typename Params, typename T> |
| __device__ __forceinline__ T | QueryTransform (const Params &params, T q) const |
| template<typename Params, typename T> |
| __device__ __forceinline__ T | LogitsTransform (const Params &params, T logits, uint32_t batch_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const |
| template<typename Params> |
| __device__ __forceinline__ bool | LogitsMask (const Params &params, uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const |
◆ StandardAttention()
| __device__ __host__ ck_tile::StandardAttention::StandardAttention |
( |
| ) |
|
|
default |
◆ LogitsMask()
template<typename Params>
| __device__ __forceinline__ bool ck_tile::StandardAttention::LogitsMask |
( |
const Params & | params, |
|
|
uint32_t | batch_idx, |
|
|
uint32_t | qo_idx, |
|
|
uint32_t | kv_idx, |
|
|
uint32_t | qo_head_idx, |
|
|
uint32_t | kv_head_idx ) const |
|
inline |
◆ LogitsTransform()
template<typename Params, typename T>
| __device__ __forceinline__ T ck_tile::StandardAttention::LogitsTransform |
( |
const Params & | params, |
|
|
T | logits, |
|
|
uint32_t | batch_idx, |
|
|
uint32_t | qo_head_idx, |
|
|
uint32_t | kv_head_idx ) const |
|
inline |
NOTICE: For better performance, we simply transform the thread buffer in place, without calculating qo_idx/kv_idx.
◆ QueryTransform()
template<typename Params, typename T>
| __device__ __forceinline__ T ck_tile::StandardAttention::QueryTransform |
( |
const Params & | params, |
|
|
T | q ) const |
|
inline |
The documentation for this struct was generated from the following file: