fmha_fwd_appendkv_kernel.hpp Source File

fmha_fwd_appendkv_kernel.hpp Source File

Composable Kernel: fmha_fwd_appendkv_kernel.hpp Source File
fmha_fwd_appendkv_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <string>
9#include <type_traits>
10
11namespace ck_tile {
12
13template <typename FmhaPipeline_>
15{
17 static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
18 static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
19
20 static_assert(kBlockPerCu > 0);
21 static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
22
26
28 static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE;
29 static constexpr bool kIsPagedKV = FmhaPipeline::kIsPagedKV;
30
31 static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
32 static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
33 static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
34 static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
35
36 // clang-format off
37 template <typename T> struct t2s;
38 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
39 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
40 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
41 template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
42 template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
43 // clang-format on
44
// Returns the kernel's identifying name, built from its compile-time configuration:
// head dim (kK0), Q data type, tile sizes (kM0 x kN0 x kK0 x kN1), an optional
// occupancy token, V layout ("r" = row-major, "c" otherwise), padding suffix and the
// paged-KV suffix. The format must be kept in sync with generate.py.
// NOTE(review): _SS_ and _TS_ are reserved identifiers (leading underscore followed by
// an uppercase letter); they are #undef'd again before the function ends.
45 CK_TILE_HOST static std::string GetName()
46 {
47 // sync with generate.py
48 // clang-format off
49
50 #define _SS_ std::string
51 #define _TS_ std::to_string
// pn: padding suffix -- "p" followed by one tag per enabled pad flag
// ("s" seqlen_q, "sk" seqlen_k, "d" hdim_q, "dv" hdim_v); empty when nothing is padded.
52 auto pn = [&] () {
53 std::string n;
54 if (kPadSeqLenQ) n += "s";
55 if (kPadSeqLenK) n += "sk";
56 if (kPadHeadDimQ) n += "d";
57 if (kPadHeadDimV) n += "dv";
58 return n.empty() ? n : std::string("p") + n; }();
// The "o<kBlockPerCu>_" token is emitted only when kBlockPerCuInput != -1.
59 return
60 _SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) + "_"
61 "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
62 _TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
63 "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
// NOTE(review): original header line 64 is absent from this listing (numbering jumps
// 63 -> 65); confirm the complete name expression against the real header.
65 + (kIsPagedKV ? "_pagedkv" : "" );
66 #undef _SS_
67 #undef _TS_
68 // clang-format on
69 }
70
71 template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
72 // arg
74 {
75 };
76
77 // Kargs uses aggregate initialization, so no constructor is provided.
78 // Inheritance is used to minimize the kargs size.
79 // Users need to call MakeKargs() to create kargs.
119
127
134
136 {
138 };
139
141 std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>,
142 std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
143 {
144 };
145
146 CK_TILE_HOST static constexpr Kargs MakeKargs(void* q_ptr,
147 void* k_ptr,
148 const void* knew_ptr,
149 void* v_ptr,
150 const void* vnew_ptr,
151 ck_tile::index_t seqlen_q,
152 const void* seqlen_k_ptr,
153 ck_tile::index_t seqlen_knew,
154 ck_tile::index_t hdim_q,
155 ck_tile::index_t hdim_v,
156 ck_tile::index_t num_head_q,
157 ck_tile::index_t nhead_ratio_qk,
158 const void* rotary_cos_ptr,
159 const void* rotary_sin_ptr,
160 ck_tile::index_t rotary_dim,
161 bool has_mask,
162 const void* block_table_ptr,
163 ck_tile::index_t batch_stride_block_table,
164 ck_tile::index_t page_block_size,
165 const void* cache_batch_idx,
166 ck_tile::index_t stride_q,
167 ck_tile::index_t stride_k,
168 ck_tile::index_t stride_knew,
169 ck_tile::index_t stride_v,
170 ck_tile::index_t stride_vnew,
171 ck_tile::index_t nhead_stride_q,
172 ck_tile::index_t nhead_stride_k,
173 ck_tile::index_t nhead_stride_knew,
174 ck_tile::index_t nhead_stride_v,
175 ck_tile::index_t nhead_stride_vnew,
176 ck_tile::index_t batch_stride_q,
177 ck_tile::index_t batch_stride_k,
178 ck_tile::index_t batch_stride_knew,
179 ck_tile::index_t batch_stride_v,
180 ck_tile::index_t batch_stride_vnew)
181 {
182 Kargs kargs{
183 {q_ptr,
184 k_ptr,
185 knew_ptr,
186 v_ptr,
187 vnew_ptr,
188 reinterpret_cast<const int32_t*>(seqlen_k_ptr),
189 seqlen_q,
190 -1, // seqlen_k will be updated by content of seqlen_k_ptr
191 seqlen_knew,
192 hdim_q,
193 hdim_v,
194 num_head_q,
195 nhead_ratio_qk,
196 stride_q,
197 stride_k,
198 stride_knew,
199 stride_v,
200 stride_vnew,
201 nhead_stride_q,
202 nhead_stride_k,
203 nhead_stride_knew,
204 nhead_stride_v,
205 nhead_stride_vnew,
206 batch_stride_q,
207 batch_stride_k,
208 batch_stride_knew,
209 batch_stride_v,
210 batch_stride_vnew}, // args for common karg
211 {}, // placeholder for rope
212 {} // placeholder for paged-block table or cache_batch_idx
213 };
214
215 if constexpr(kApplyRoPE)
216 {
217 kargs.rotary_cos_ptr = rotary_cos_ptr;
218 kargs.rotary_sin_ptr = rotary_sin_ptr;
219 kargs.rotary_dim = rotary_dim;
220 kargs.has_mask = has_mask;
221 }
222
223 if constexpr(kIsPagedKV)
224 {
225 kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
226 kargs.batch_stride_block_table = batch_stride_block_table;
227 kargs.page_block_size = page_block_size;
228 }
229 else
230 {
231 kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
232 }
233
234 return kargs;
235 }
236
237 CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
238 ck_tile::index_t nhead,
239 ck_tile::index_t seqlen_q,
240 ck_tile::index_t seqlen_knew)
241 {
242 // TODO: this may need tuning
243 return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, FmhaPipeline::kM0),
244 ck_tile::integer_divide_ceil(seqlen_knew, FmhaPipeline::kN0)),
245 nhead,
246 batch_size);
247 }
248
249 CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& /* kargs */)
250 {
251 const index_t i_tile = blockIdx.x;
252 const index_t i_nhead = blockIdx.y;
253 const index_t i_batch = blockIdx.z;
254
255 return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
256 }
257
258 CK_TILE_HOST static dim3 BlockSize() { return dim3(kBlockSize); }
259
261 {
262 // divide problem
263 const auto [i_tile, i_nhead, i_batch] = GetTileIndex(kargs);
264
265 const index_t i_m0 = amd_wave_read_first_lane(i_tile * FmhaPipeline::kM0);
266 const index_t i_n0 = amd_wave_read_first_lane(i_tile * FmhaPipeline::kN0);
267
268 const index_t i_cache_batch = [&, i_batch_ = i_batch] {
269 if constexpr(kIsPagedKV)
270 {
271 return i_batch_;
272 }
273 else
274 {
275 return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
276 : i_batch_);
277 }
278 }();
279
280 const long_index_t batch_offset_q =
281 static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
282 const long_index_t batch_offset_k =
283 static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
284 const long_index_t batch_offset_knew =
285 static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
286 const long_index_t batch_offset_v =
287 static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
288 const long_index_t batch_offset_vnew =
289 static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;
290
291 kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
292
293 // for simplicity, batch stride we just modify the pointer
294 QDataType* q_ptr = reinterpret_cast<QDataType*>(kargs.q_ptr) +
295 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
296 batch_offset_q;
297 KDataType* k_ptr =
298 reinterpret_cast<KDataType*>(kargs.k_ptr) +
299 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
300 batch_offset_k;
301 const KDataType* knew_ptr =
302 reinterpret_cast<const KDataType*>(kargs.knew_ptr) +
303 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_knew +
304 batch_offset_knew;
305 VDataType* v_ptr =
306 reinterpret_cast<VDataType*>(kargs.v_ptr) +
307 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
308 batch_offset_v;
309 const VDataType* vnew_ptr =
310 reinterpret_cast<const VDataType*>(kargs.vnew_ptr) +
311 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_vnew +
312 batch_offset_vnew;
313
314 // Q/K/V DRAM and DRAM window
315 const auto q_dram = [&]() {
317 q_ptr,
318 make_tuple(kargs.seqlen_q, kargs.hdim_q),
319 make_tuple(kargs.stride_q, 1),
321 number<1>{});
322
323 return pad_tensor_view(
324 q_dram_naive,
327 }();
328
329 const auto make_k_dram = [&](KDataType* data, index_t height) {
331 data, // will update this pointer if using paged-kvcache
332 make_tuple(height, kargs.hdim_q),
333 make_tuple(kargs.stride_k, 1),
335 number<1>{});
336
337 return pad_tensor_view(
338 k_dram_naive,
341 };
342 const auto k_dram = [&]() {
343 if constexpr(kIsPagedKV)
344 {
345 return make_k_dram(nullptr, kargs.page_block_size);
346 }
347 else
348 {
349 return make_k_dram(k_ptr, kargs.seqlen_k + kargs.seqlen_knew);
350 }
351 }();
352
353 const auto knew_dram = [&]() {
354 const auto knew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
355 knew_ptr,
356 make_tuple(kargs.seqlen_knew, kargs.hdim_q),
357 make_tuple(kargs.stride_knew, 1),
359 number<1>{});
360
361 return pad_tensor_view(
362 knew_dram_naive,
365 }();
366
367 const auto make_v_dram = [&](VDataType* data, index_t length) {
368 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
369 {
371 data, // will update this pointer if using paged-kvcache
372 make_tuple(length, kargs.hdim_v),
373 make_tuple(kargs.stride_v, 1),
375 number<1>{});
376
377 const auto v_dram_transposed =
378 transform_tensor_view(v_dram_naive,
383
384 return pad_tensor_view(
385 v_dram_transposed,
388 }
389 else
390 {
392 data, // will update this pointer if using paged-kvcache
393 make_tuple(kargs.hdim_v, length),
394 make_tuple(kargs.stride_v, 1),
396 number<1>{});
397
398 return pad_tensor_view(
399 v_dram_naive,
402 }
403 };
404 const auto v_dram = [&]() {
405 if constexpr(kIsPagedKV)
406 {
407 return make_v_dram(nullptr, kargs.page_block_size);
408 }
409 else
410 {
411 return make_v_dram(v_ptr, kargs.seqlen_k + kargs.seqlen_knew);
412 }
413 }();
414
415 const auto vnew_dram = [&]() {
416 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
417 {
418 const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
419 vnew_ptr,
420 make_tuple(kargs.seqlen_knew, kargs.hdim_v),
421 make_tuple(kargs.stride_vnew, 1),
423 number<1>{});
424
425 const auto vnew_dram_transposed = transform_tensor_view(
426 vnew_dram_naive,
431
432 return pad_tensor_view(
433 vnew_dram_transposed,
436 }
437 else
438 {
439 const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
440 vnew_ptr,
441 make_tuple(kargs.hdim_v, kargs.seqlen_knew),
442 make_tuple(kargs.stride_vnew, 1),
444 number<1>{});
445
446 return pad_tensor_view(
447 vnew_dram_naive,
450 }
451 }();
452
453 constexpr auto q_rotary_cos_sin_dram_window_lengths =
454 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0 / 2>{});
455 const auto q_rotary_cos_dram_window = [&]() {
456 if constexpr(kApplyRoPE)
457 {
458 const auto rotary_cos_dram_native =
460 reinterpret_cast<const QDataType*>(kargs.rotary_cos_ptr) +
461 kargs.seqlen_k * (kargs.rotary_dim / 2),
462 make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
463 make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
464 number<8>{},
465 number<1>{});
466
467 const auto rotary_cos_dram = [&]() {
468 return pad_tensor_view(rotary_cos_dram_native,
469 q_rotary_cos_sin_dram_window_lengths,
471 }();
472
473 return make_tile_window(
474 rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
475 }
476 else
477 {
478 return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
479 }
480 }();
481 const auto q_rotary_sin_dram_window = [&]() {
482 if constexpr(kApplyRoPE)
483 {
484 const auto rotary_sin_dram_native =
486 reinterpret_cast<const QDataType*>(kargs.rotary_sin_ptr) +
487 kargs.seqlen_k * (kargs.rotary_dim / 2),
488 make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
489 make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
490 number<8>{},
491 number<1>{});
492
493 const auto rotary_sin_dram = [&]() {
494 return pad_tensor_view(rotary_sin_dram_native,
495 q_rotary_cos_sin_dram_window_lengths,
497 }();
498
499 return make_tile_window(
500 rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
501 }
502 else
503 {
504 return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
505 }
506 }();
507
508 constexpr auto knew_rotary_cos_sin_dram_window_lengths =
509 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0 / 2>{});
510 const auto knew_rotary_cos_dram_window = [&]() {
511 if constexpr(kApplyRoPE)
512 {
513 const auto rotary_cos_dram_native =
515 reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr) +
516 kargs.seqlen_k * (kargs.rotary_dim / 2),
517 make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
518 make_tuple(kargs.rotary_dim / 2, 1),
519 number<8>{},
520 number<1>{});
521
522 const auto rotary_cos_dram = [&]() {
523 return pad_tensor_view(rotary_cos_dram_native,
524 knew_rotary_cos_sin_dram_window_lengths,
526 }();
527
528 return make_tile_window(
529 rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
530 }
531 else
532 {
533 return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
534 }
535 }();
536 const auto knew_rotary_sin_dram_window = [&]() {
537 if constexpr(kApplyRoPE)
538 {
539 const auto rotary_sin_dram_native =
541 reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr) +
542 kargs.seqlen_k * (kargs.rotary_dim / 2),
543 make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
544 make_tuple(kargs.rotary_dim / 2, 1),
545 number<8>{},
546 number<1>{});
547
548 const auto rotary_sin_dram = [&]() {
549 return pad_tensor_view(rotary_sin_dram_native,
550 knew_rotary_cos_sin_dram_window_lengths,
552 }();
553
554 return make_tile_window(
555 rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
556 }
557 else
558 {
559 return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
560 }
561 }();
562
563 auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
564 if constexpr(kIsPagedKV)
565 {
566 const auto* block_indices =
567 reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
568 i_batch_ * kargs.batch_stride_block_table;
569 const index_t num_blocks =
570 integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
571
572 const long_index_t fixed_offset =
573 static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
574 kargs.nhead_stride_k;
575
577 kargs.k_ptr,
578 kargs.batch_stride_k,
579 fixed_offset,
580 block_indices,
581 num_blocks,
582 kargs.page_block_size,
583 k_dram,
584 make_k_dram(nullptr,
585 (kargs.seqlen_k + kargs.seqlen_knew) -
586 (num_blocks - 1) * kargs.page_block_size));
587 }
588 else
589 {
590 return make_page_block_navigator(k_dram);
591 }
592 }();
593
594 auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
595 if constexpr(kIsPagedKV)
596 {
597 const auto* block_indices =
598 reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
599 i_batch_ * kargs.batch_stride_block_table;
600 const index_t num_blocks =
601 integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
602
603 const long_index_t fixed_offset =
604 static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
605 kargs.nhead_stride_v;
606
608 kargs.v_ptr,
609 kargs.batch_stride_v,
610 fixed_offset,
611 block_indices,
612 num_blocks,
613 kargs.page_block_size,
614 v_dram,
615 make_v_dram(nullptr,
616 (kargs.seqlen_k + kargs.seqlen_knew) -
617 (num_blocks - 1) * kargs.page_block_size));
618 }
619 else
620 {
621 return make_page_block_navigator(v_dram);
622 }
623 }();
624
625 auto q_dram_window =
626 make_tile_window(q_dram,
628 {i_m0, 0});
629
630 const bool skip_append_kv = kargs.seqlen_knew <= i_n0;
631 // window origin = (0, 0) if no work to do for current block
632 auto [i_page_block_k, k_dram_window] = k_page_block_navigator.make_tile_window(
634 {!skip_append_kv * (kargs.seqlen_k + i_n0), 0});
635
636 auto knew_dram_window =
637 make_tile_window(knew_dram,
639 {i_n0, 0});
640
641 // window origin = (0, 0) if no work to do for current block
642 auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
644 {0, !skip_append_kv * (kargs.seqlen_k + i_n0)});
645
646 auto vnew_dram_window =
647 make_tile_window(vnew_dram,
649 {0, i_n0});
650
651 // If kApplyRoPe is false, we set the rotary_dim to 0
652 auto rotary_dim = [&]() {
653 if constexpr(kApplyRoPE)
654 return kargs.rotary_dim;
655 else
656 return 0;
657 }();
658 FmhaPipeline{}(q_dram_window,
659 k_dram_window,
660 i_page_block_k,
661 k_page_block_navigator,
662 knew_dram_window,
663 v_dram_window,
664 i_page_block_v,
665 v_page_block_navigator,
666 vnew_dram_window,
667 q_rotary_cos_dram_window,
668 q_rotary_sin_dram_window,
669 knew_rotary_cos_dram_window,
670 knew_rotary_sin_dram_window,
671 rotary_dim,
672 kargs.seqlen_q <= i_m0,
673 skip_append_kv);
674 }
675};
676
677} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition page_block_navigator.hpp:333
bfloat16_t bf16_t
Definition bfloat16.hpp:113
@ NONE
Definition block_rotary_embedding.hpp:13
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition fmha_fwd_appendkv_kernel.hpp:81
ck_tile::index_t stride_q
Definition fmha_fwd_appendkv_kernel.hpp:101
const int32_t * seqlen_k_ptr
Definition fmha_fwd_appendkv_kernel.hpp:88
const void * knew_ptr
Definition fmha_fwd_appendkv_kernel.hpp:84
ck_tile::index_t stride_k
Definition fmha_fwd_appendkv_kernel.hpp:102
ck_tile::index_t batch_stride_knew
Definition fmha_fwd_appendkv_kernel.hpp:115
ck_tile::index_t batch_stride_v
Definition fmha_fwd_appendkv_kernel.hpp:116
ck_tile::index_t nhead_stride_knew
Definition fmha_fwd_appendkv_kernel.hpp:109
ck_tile::index_t nhead_stride_vnew
Definition fmha_fwd_appendkv_kernel.hpp:111
ck_tile::index_t batch_stride_k
Definition fmha_fwd_appendkv_kernel.hpp:114
ck_tile::index_t stride_knew
Definition fmha_fwd_appendkv_kernel.hpp:103
ck_tile::index_t nhead_stride_q
Definition fmha_fwd_appendkv_kernel.hpp:107
ck_tile::index_t hdim_q
Definition fmha_fwd_appendkv_kernel.hpp:93
ck_tile::index_t stride_v
Definition fmha_fwd_appendkv_kernel.hpp:104
ck_tile::index_t batch_stride_q
Definition fmha_fwd_appendkv_kernel.hpp:113
ck_tile::index_t nhead_stride_k
Definition fmha_fwd_appendkv_kernel.hpp:108
ck_tile::index_t hdim_v
Definition fmha_fwd_appendkv_kernel.hpp:94
ck_tile::index_t batch_stride_vnew
Definition fmha_fwd_appendkv_kernel.hpp:117
void * q_ptr
Definition fmha_fwd_appendkv_kernel.hpp:82
ck_tile::index_t stride_vnew
Definition fmha_fwd_appendkv_kernel.hpp:105
void * v_ptr
Definition fmha_fwd_appendkv_kernel.hpp:85
const void * vnew_ptr
Definition fmha_fwd_appendkv_kernel.hpp:86
ck_tile::index_t nhead_stride_v
Definition fmha_fwd_appendkv_kernel.hpp:110
void * k_ptr
Definition fmha_fwd_appendkv_kernel.hpp:83
ck_tile::index_t seqlen_k
Definition fmha_fwd_appendkv_kernel.hpp:91
ck_tile::index_t seqlen_q
Definition fmha_fwd_appendkv_kernel.hpp:90
ck_tile::index_t seqlen_knew
Definition fmha_fwd_appendkv_kernel.hpp:92
ck_tile::index_t num_head_q
Definition fmha_fwd_appendkv_kernel.hpp:96
ck_tile::index_t nhead_ratio_qk
Definition fmha_fwd_appendkv_kernel.hpp:99
Definition fmha_fwd_appendkv_kernel.hpp:136
const int32_t * cache_batch_idx
Definition fmha_fwd_appendkv_kernel.hpp:137
Definition fmha_fwd_appendkv_kernel.hpp:74
Definition fmha_fwd_appendkv_kernel.hpp:143
Definition fmha_fwd_appendkv_kernel.hpp:129
ck_tile::index_t batch_stride_block_table
Definition fmha_fwd_appendkv_kernel.hpp:131
const int32_t * block_table_ptr
Definition fmha_fwd_appendkv_kernel.hpp:130
ck_tile::index_t page_block_size
Definition fmha_fwd_appendkv_kernel.hpp:132
Definition fmha_fwd_appendkv_kernel.hpp:121
ck_tile::index_t rotary_dim
Definition fmha_fwd_appendkv_kernel.hpp:124
const void * rotary_sin_ptr
Definition fmha_fwd_appendkv_kernel.hpp:123
bool has_mask
Definition fmha_fwd_appendkv_kernel.hpp:125
const void * rotary_cos_ptr
Definition fmha_fwd_appendkv_kernel.hpp:122
static constexpr const char * name
Definition fmha_fwd_appendkv_kernel.hpp:40
static constexpr const char * name
Definition fmha_fwd_appendkv_kernel.hpp:42
static constexpr const char * name
Definition fmha_fwd_appendkv_kernel.hpp:39
static constexpr const char * name
Definition fmha_fwd_appendkv_kernel.hpp:41
static constexpr const char * name
Definition fmha_fwd_appendkv_kernel.hpp:38
Definition fmha_fwd_appendkv_kernel.hpp:37
Definition fmha_fwd_appendkv_kernel.hpp:15
static constexpr bool kPadHeadDimV
Definition fmha_fwd_appendkv_kernel.hpp:34
static constexpr ck_tile::index_t kBlockPerCuInput
Definition fmha_fwd_appendkv_kernel.hpp:21
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_fwd_appendkv_kernel.hpp:16
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition fmha_fwd_appendkv_kernel.hpp:23
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition fmha_fwd_appendkv_kernel.hpp:24
static constexpr bool kPadSeqLenQ
Definition fmha_fwd_appendkv_kernel.hpp:31
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_fwd_appendkv_kernel.hpp:260
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_fwd_appendkv_kernel.hpp:258
static CK_TILE_HOST constexpr Kargs MakeKargs(void *q_ptr, void *k_ptr, const void *knew_ptr, void *v_ptr, const void *vnew_ptr, ck_tile::index_t seqlen_q, const void *seqlen_k_ptr, ck_tile::index_t seqlen_knew, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, const void *rotary_cos_ptr, const void *rotary_sin_ptr, ck_tile::index_t rotary_dim, bool has_mask, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_knew, ck_tile::index_t stride_v, ck_tile::index_t stride_vnew, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_knew, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_vnew, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_knew, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_vnew)
Definition fmha_fwd_appendkv_kernel.hpp:146
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition fmha_fwd_appendkv_kernel.hpp:27
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_knew)
Definition fmha_fwd_appendkv_kernel.hpp:237
static CK_TILE_DEVICE constexpr auto GetTileIndex(const Kargs &)
Definition fmha_fwd_appendkv_kernel.hpp:249
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_fwd_appendkv_kernel.hpp:18
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition fmha_fwd_appendkv_kernel.hpp:25
static constexpr bool kPadHeadDimQ
Definition fmha_fwd_appendkv_kernel.hpp:33
static constexpr bool kPadSeqLenK
Definition fmha_fwd_appendkv_kernel.hpp:32
static constexpr bool kIsPagedKV
Definition fmha_fwd_appendkv_kernel.hpp:29
static CK_TILE_HOST std::string GetName()
Definition fmha_fwd_appendkv_kernel.hpp:45
static constexpr bool kApplyRoPE
Definition fmha_fwd_appendkv_kernel.hpp:28
static constexpr ck_tile::index_t kBlockSize
Definition fmha_fwd_appendkv_kernel.hpp:17
Definition block_rotary_embedding.hpp:19
Definition tile/core/container/sequence.hpp:49