block_fmha_fwd_v3_pipeline_default_policy.hpp Source File

block_fmha_fwd_v3_pipeline_default_policy.hpp Source File#

Composable Kernel: block_fmha_fwd_v3_pipeline_default_policy.hpp Source File
block_fmha_fwd_v3_pipeline_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
15{
16 static constexpr ck_tile::index_t NumWarpPerGroup = 4;
19
20 // TODO: GetAlignment*() currently does not consider whether padding is needed,
21 // so the pipeline still needs to check the padding requirement
// Vector length (in elements) reported for Q.
// Returns the smaller of:
//   - the number of Problem::QDataType elements that fit in 16 bytes, and
//   - WG::kK / WG::WarpGemmAttribute::Impl::kABKLane, where WG is element 0 of the tuple
//     returned by BlockGemm::Policy::GetWarpGemmMWarpNWarp<Problem>().
// NOTE(review): the `BlockGemm` alias is declared on a listing line that is absent here
// (line 27) -- presumably derived from GetQKBlockGemm<Problem>(); confirm against the header.
22 template <typename Problem>
23 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
24 {
25 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
26
28 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
29 using WG = remove_cvref_t<decltype(config.template at<0>())>;
30
31 return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
32 }
33
// Vector length (in elements) reported for K: MaxReadSizeInBytes / sizeof(KDataType).
// MaxReadSizeInBytes is 16 when the translation unit is compiled with __gfx950__ defined
// and 4 otherwise, so the result depends on the device target being compiled.
// NOTE(review): the `KDataType` alias is declared on a listing line that is absent here
// (line 38) -- presumably `typename Problem::KDataType`; confirm against the header.
34 template <typename Problem>
35 CK_TILE_DEVICE static constexpr auto GetAlignmentK()
36 {
37 using namespace ck_tile;
39#if defined(__gfx950__)
40 constexpr index_t MaxReadSizeInBytes = 16;
41#else
42 constexpr index_t MaxReadSizeInBytes = 4;
43#endif
44 return MaxReadSizeInBytes / sizeof(KDataType);
45 }
46
// Vector length (in elements) reported for V: MaxReadSizeInBytes / sizeof(VDataType).
// Uses the same byte budget as GetAlignmentK(): 16 when __gfx950__ is defined, 4 otherwise.
// NOTE(review): the `VDataType` alias is declared on a listing line that is absent here
// (line 51) -- presumably `typename Problem::VDataType`; confirm against the header.
47 template <typename Problem>
48 CK_TILE_DEVICE static constexpr auto GetAlignmentV()
49 {
50 using namespace ck_tile;
52#if defined(__gfx950__)
53 constexpr index_t MaxReadSizeInBytes = 16;
54#else
55 constexpr index_t MaxReadSizeInBytes = 4;
56#endif
57 return MaxReadSizeInBytes / sizeof(VDataType);
58 }
59
// Vector length (in elements) reported for O.
// Returns WG::WarpGemmAttribute::Impl::kCM1PerLane, where WG is element 0 of the tuple
// returned by BlockGemm::Policy::GetWarpGemmMWarpNWarp<Problem>().
// NOTE(review): the `BlockGemm` alias is declared on a listing line that is absent here
// (line 63) -- presumably derived from GetPVBlockGemm<Problem>(); confirm against the header.
60 template <typename Problem>
61 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
62 {
64 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
65 using WG = remove_cvref_t<decltype(config.template at<0>())>;
66
67 return WG::WarpGemmAttribute::Impl::kCM1PerLane;
68 }
69
// Pack size (in elements) used for K in shared memory: the number of KDataType elements
// that fit in 16 bytes.
// NOTE(review): the `KDataType` alias is declared on a listing line that is absent here
// (line 76) -- presumably `typename Problem::KDataType`; confirm against the header.
70 template <typename Problem>
71 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
72 {
73 using namespace ck_tile;
74
75 // TODO: this is for 3d layout
77 return 16 / sizeof(KDataType);
78 }
79
// Pack size (in elements) used for V in shared memory: the number of VDataType elements
// that fit in 16 bytes.
// NOTE(review): the `VDataType` alias is declared on a listing line that is absent here
// (line 86) -- presumably `typename Problem::VDataType`; confirm against the header.
80 template <typename Problem>
81 CK_TILE_HOST_DEVICE static constexpr auto GetSmemVPackK()
82 {
83 using namespace ck_tile;
84
85 // TODO: this is for 3d layout
87 return 16 / sizeof(VDataType);
88 }
89
90 template <typename Problem>
92 {
93 using namespace ck_tile;
94
95 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
96 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
97 constexpr index_t kBlockSize = Problem::kBlockSize;
98 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
99 constexpr index_t WarpSize = ck_tile::get_warp_size();
100
101 constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
102
103 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
104 constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
105 constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
106 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
107 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
108
109 constexpr index_t N0 = NumIssues;
110 constexpr index_t N1 = LaneGroups;
111 constexpr index_t N2 = NumWarps;
112 constexpr index_t K0 = LanesPerK;
113 constexpr index_t K1 = KVector;
114
121 sequence<0, 1>>{});
122 }
123
124 template <typename Problem>
126 {
127 using namespace ck_tile;
128
129 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
130 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1;
131 constexpr index_t kBlockSize = Problem::kBlockSize;
132 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
133 constexpr index_t WarpSize = ck_tile::get_warp_size();
134
135 constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
136
137 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
138 constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
139 constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
140 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
141 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
142
143 constexpr index_t N0 = NumIssues;
144 constexpr index_t N1 = LaneGroups;
145 constexpr index_t N2 = NumWarps;
146 constexpr index_t K0 = LanesPerK;
147 constexpr index_t K1 = KVector;
148
155 sequence<0, 1>>{});
156 }
157
158 template <typename Problem>
160 {
161 using namespace ck_tile;
162
164
165 return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
166 }
167
168 template <typename Problem>
170 {
171 using namespace ck_tile;
172
174
175 return make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
176 }
177
178 template <typename Problem>
180 {
181 using namespace ck_tile;
182
184
185 return make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
186 }
187
188 template <typename Problem>
190 {
191 using namespace ck_tile;
192
194 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
195 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
196
197 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
198 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
199
200 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
201 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
202
203 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
204 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
205
206 constexpr auto v_block_outer_dstr_encoding =
213
214 constexpr auto v_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
215 v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
216
217 // compute the encoding before transpose
218 constexpr auto v_block_dstr =
220 decltype(v_block_dstr_encode),
221 typename Problem::VDataType>::TransposedDstrEncode{});
222
223 return v_block_dstr;
224 }
225
// Assembles the block GEMM configuration for the Q*K product.
// NOTE(review): the bodies of both `if constexpr` branches, the tail of the BlockGemmPolicy
// alias and the return statement sit on listing lines that are absent here
// (247-249, 255-257, 267, 269); the comments below describe only what is visible.
226 template <typename Problem>
227 CK_TILE_DEVICE static constexpr auto GetQKBlockGemm()
228 {
229 using namespace ck_tile;
230
// GEMM problem: A = QDataType, B = KDataType, accumulator = SaccDataType, with a
// kM0 x kN0 x kK0 block tile and the Gemm0 warp layout / warp tile of the FMHA shape.
231 using GemmProblem =
232 BlockGemmProblem<typename Problem::QDataType,
233 typename Problem::KDataType,
234 typename Problem::SaccDataType,
235 Problem::kBlockSize,
236 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
237 Problem::BlockFmhaShape::kN0,
238 Problem::BlockFmhaShape::kK0>,
239 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
240 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
241
// The warp GEMM is chosen at compile time from the data types. Only two combinations are
// handled: half_t/half_t -> float and bf16_t/bf16_t -> float. There is no trailing `else`,
// so any other combination falls through both branches.
242 constexpr auto warp_gemm = []() {
243 if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
244 std::is_same_v<typename Problem::KDataType, half_t> &&
245 std::is_same_v<typename Problem::SaccDataType, float>)
246 {
250 }
251 else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
252 std::is_same_v<typename Problem::KDataType, bf16_t> &&
253 std::is_same_v<typename Problem::SaccDataType, float>)
254 {
258 }
259 }();
260
// Policy type parameterised on the same data types, the Gemm0 warp layout and the type of
// the warp GEMM object selected above.
261 using BlockGemmPolicy =
262 BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
263 typename Problem::KDataType,
264 typename Problem::SaccDataType,
265 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
266 decltype(warp_gemm),
268
270 }
271
// Assembles the block GEMM configuration for the P*V product.
// NOTE(review): the last WarpGemmDispatcher argument, the tail of the BlockGemmPolicy alias
// and the return statement sit on listing lines that are absent here (287-288, 298,
// 306-307); the comments below describe only what is visible.
272 template <typename Problem>
273 CK_TILE_DEVICE static constexpr auto GetPVBlockGemm()
274 {
275 using namespace ck_tile;
276
// GEMM problem: A = PDataType, B = VDataType, accumulator = OaccDataType, with a
// kM0 x kN1 x kK1 block tile and the Gemm1 warp layout / warp tile of the FMHA shape.
277 using GemmProblem =
278 BlockGemmProblem<typename Problem::PDataType,
279 typename Problem::VDataType,
280 typename Problem::OaccDataType,
281 Problem::kBlockSize,
282 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
283 Problem::BlockFmhaShape::kN1,
284 Problem::BlockFmhaShape::kK1>,
285 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
286 typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
// Warp GEMM picked through WarpGemmDispatcher. Per the dispatcher's parameter list
// (AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA,
// UseStructuredSparsity, AttrNumAccess) the three booleans below mean TransposeC = true,
// SwizzleA = false, UseStructuredSparsity = false.
289 using WarpGemm = WarpGemmDispatcher<typename Problem::PDataType,
290 typename Problem::VDataType,
291 typename Problem::OaccDataType,
292 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
293 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
294 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
295 true,
296 false,
297 false,
299
// Policy type parameterised on the same data types, the Gemm1 warp layout and WarpGemm.
300 using BlockGemmPolicy =
301 BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
302 typename Problem::VDataType,
303 typename Problem::OaccDataType,
304 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
305 WarpGemm,
308 }
309
// Padding amounts expressed in bytes: 4 * 4 = 16 for K and 4 * 16 = 64 for V.
// NOTE(review): presumably these feed the `kPad` computations of the K/V LDS block
// descriptors (divided by the element size); the lines that would show this are absent
// from this listing -- confirm against the header.
310 static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords
311 static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords
312
313 template <typename Problem, ck_tile::index_t IBuf = 0>
314 CK_TILE_DEVICE static constexpr auto
316 {
317 using namespace ck_tile;
318
319 // K is always k-major, we use async-copy to load into LDS
320 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
321 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
322 constexpr index_t kBlockSize = Problem::kBlockSize;
323 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
324 constexpr index_t WarpSize = ck_tile::get_warp_size();
325
326 [[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
327 constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
328 constexpr index_t kPad =
330 sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps.
331 // Optimize this for lds_read speed
332
333 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
334 constexpr index_t LanesPerK =
335 kKPerBlock / KVector; // how many lane (within a wave) to load K
336 constexpr index_t LaneGroups =
337 WarpSize /
338 LanesPerK; // how many groups (within a wave), they may load different N, but same K
339 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
340 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
341
342 constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
344 number<LaneGroups>{}, // n1
345 number<NumWarps>{}, // n2
346 number<LanesPerK>{}, // k0
347 number<KVector>{}), // k1
348 make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
352 number<1>{}),
355 number<1>{});
356
357 // TODO this layout is hard coded, and will be used in async copy buffer view load
358 // in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
359 constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
360 k_lds_block_desc_0,
365 make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
366 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
367
368 return k_lds_block_desc_issues_warps_lanes;
369 }
370
371 template <typename Problem>
373 {
374 using namespace ck_tile;
375
376 // K is always k-major, we use async-copy to load into LDS
377 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
378 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
379 constexpr index_t kBlockSize = Problem::kBlockSize;
380 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
381 constexpr index_t WarpSize = ck_tile::get_warp_size();
382
383 constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
384 constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
385 constexpr index_t kPad =
387 sizeof(typename Problem::KDataType); // for async-copy, this pad is between warps
388
389 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
390 constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
391 constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
392 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
393 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
394
395 constexpr auto k_lds_block_desc_0 =
397 number<NumWarps>{}, // n2
398 number<LaneGroups>{}, // n1
399 number<kKPerBlock / KPack>{}, // k0
400 number<KPack>{}), // k1
401 make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
405 number<1>{}),
407 number<1>{});
408
409 constexpr auto k_lds_block_desc = transform_tensor_descriptor(
410 k_lds_block_desc_0,
417
418 return k_lds_block_desc;
419 }
420
421 template <typename Problem>
423 {
424 // this function assumes K/V can share smem
425 constexpr index_t SingleKSize = [&]() {
426 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
427 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
428 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
429 constexpr index_t WarpSize = ck_tile::get_warp_size();
430
431 constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
432 constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
433 constexpr index_t kPad = KPack;
434
435 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
436 constexpr index_t LanesPerK = kKPerBlock / KVector;
437 constexpr index_t LaneGroups = WarpSize / LanesPerK;
438 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
439
440 return NumIssues * NumWarps * (WarpSize * KVector + kPad);
441 }();
442
443 constexpr index_t SingleVSize = [&]() {
445 constexpr index_t Banks = get_n_lds_banks();
446 constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
447 constexpr index_t kKPack = GetSmemKPackK<Problem>();
448 static_assert(PixelsPerRow % kKPack == 0);
449 constexpr index_t NPerRow = PixelsPerRow / kKPack;
450 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
451 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
452 static_assert(kNPerBlock % NPerRow == 0);
453 static_assert(kKPerBlock % kKPack == 0);
454
455 return (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
456 }();
457
458 return max(SingleKSize, SingleVSize);
459 }
460
461 template <typename Problem, ck_tile::index_t IBuf = 0>
462 CK_TILE_DEVICE static constexpr auto
464 {
465 using namespace ck_tile;
466
468 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
469 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1;
470 constexpr index_t kBlockSize = Problem::kBlockSize;
471 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
472 constexpr index_t WarpSize = ck_tile::get_warp_size();
473
474 [[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
475 constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
476 constexpr index_t kPad =
478 sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps.
479 // Optimize this for lds_read speed
480
481 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
482 constexpr index_t LanesPerK =
483 kKPerBlock / KVector; // how many lane (within a wave) to load K
484 constexpr index_t LaneGroups =
485 WarpSize /
486 LanesPerK; // how many groups (within a wave), they may load different N, but same K
487 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
488 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
489
490 constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
492 number<LaneGroups>{}, // n1
493 number<NumWarps>{}, // n2
494 number<LanesPerK>{}, // k0
495 number<KVector>{}), // k1
496 make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
500 number<1>{}),
503 number<1>{});
504
505 // TODO this layout is hard coded, and will be used in async copy buffer view load
506 // in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
507 constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
508 v_lds_block_desc_0,
513 make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
514 make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
515
516 return v_lds_block_desc_issues_warps_lanes;
517 }
518
// Body of MakeVLdsLoadBlockDescriptor: builds the tensor descriptor `v_lds_block_desc`
// for the V tile in LDS and returns it.
// NOTE(review): the declaration line and parts of the descriptor construction sit on
// listing lines that are absent here (520, 524, 534, 544, 550-552, 554, 559-564); every
// visible line is kept as-is except the KVector initialiser described below.
519 template <typename Problem>
521 {
522 using namespace ck_tile;
523
// For V the roles of the shape constants are swapped relative to K: the "N" extent of this
// tile is kK1 and the "K" extent is kN1.
525 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK1;
526 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN1;
527 constexpr index_t kBlockSize = Problem::kBlockSize;
528 constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
529 constexpr index_t WarpSize = ck_tile::get_warp_size();
530
531 constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
// Use the V-specific alignment (previously GetAlignmentK) so this descriptor is derived
// from the same vector length as MakeVLdsStoreBlockDescriptor, which uses GetAlignmentV.
// GetSmemSizeKV statically asserts that the load and store descriptors have the same
// element space size, so both must agree on KVector.
532 constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
533 constexpr index_t kPad =
535 sizeof(typename Problem::VDataType); // for async-copy, this pad is between warps
536
// One wave's worth of vectors must cover at least one row of kKPerBlock elements and be
// an exact multiple of it.
537 static_assert(WarpSize * KVector >= kKPerBlock && WarpSize * KVector % kKPerBlock == 0);
538 constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
539 constexpr index_t LaneGroups = WarpSize / LanesPerK; // within a wave
540 constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
// Cross-check: the issue count must equal tile elements / (threads * vector length).
541 static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
542
543 constexpr auto v_lds_block_desc_0 =
545 number<NumWarps>{}, // n2
546 number<LaneGroups>{}, // n1
547 number<kKPerBlock / KPack>{}, // k0
548 number<KPack>{}), // k1
549 make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
553 number<1>{}),
555 number<1>{});
556
557 constexpr auto v_lds_block_desc = transform_tensor_descriptor(
558 v_lds_block_desc_0,
565
566 return v_lds_block_desc;
567 }
568
569 template <typename Problem>
571 {
572 using namespace ck_tile;
573
574 static_assert(MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
575 MakeKLdsStoreBlockDescriptor<Problem>().get_element_space_size());
576 constexpr index_t k_element_space_size =
577 MakeKLdsLoadBlockDescriptor<Problem>().get_element_space_size();
578
579 static_assert(MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size() ==
580 MakeVLdsStoreBlockDescriptor<Problem>().get_element_space_size());
581 constexpr index_t v_element_space_size =
582 MakeVLdsLoadBlockDescriptor<Problem>().get_element_space_size();
583
584 static_assert(ck_tile::max(k_element_space_size, v_element_space_size) <=
586
589 static_assert(std::is_same_v<typename Problem::KDataType, typename Problem::VDataType>);
590 constexpr index_t kv_element_space_size_in_bytes =
591 GetSingleSmemElementSpaceSize<Problem>() * sizeof(typename Problem::KDataType);
592
593 return kv_element_space_size_in_bytes;
594 }
595
596 template <typename Problem>
598 {
599 return 4 * GetSmemSizeKV<Problem>();
600 }
601};
602
603} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
@ Double
Definition warp_gemm_attribute_mfma.hpp:15
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
TransposeTileDistributionTraits< TileDistributionEncoding_, DataType_, Policy, true > InputTileDistributionTraits
Definition load_tile_transpose.hpp:343
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_with_offset(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, const offset &os, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:319
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
WarpGemmImpl< WarpGemmAttributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplF16F16F32M32N32K8< WGAttrCtlEnum::Default_ >, 2, AttrNumAccess > > WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
Definition warp_gemm.hpp:91
WarpGemmImpl< WarpGemmAttributeMfmaIterateKAndTransposedCDistribution< WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8< WGAttrCtlEnum::Default_ >, 2, AttrNumAccess > > WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
Definition warp_gemm.hpp:213
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
@ MNK
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:13
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:15
static CK_TILE_DEVICE constexpr auto MakePRegTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:179
static CK_TILE_DEVICE constexpr auto MakeKLdsStoreBlockDescriptor(ck_tile::number< IBuf >=ck_tile::number< 0 >{})
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:315
static CK_TILE_DEVICE constexpr auto MakeVRegTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:189
static constexpr ck_tile::index_t kKLdsPadInBytes
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:310
static CK_TILE_DEVICE constexpr auto MakeQRegTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:159
static constexpr ck_tile::index_t NumThreadPerWarpGroup
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:17
static CK_TILE_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:597
static constexpr ck_tile::index_t NumWarpPerGroup
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:16
static CK_TILE_DEVICE constexpr auto MakeKLdsLoadBlockDescriptor()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:372
static CK_TILE_DEVICE constexpr ck_tile::index_t GetSmemSizeKV()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:570
static CK_TILE_HOST_DEVICE constexpr auto GetSmemVPackK()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:81
static CK_TILE_DEVICE constexpr auto GetSingleSmemElementSpaceSize()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:422
static CK_TILE_DEVICE constexpr auto MakeKRegTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:169
static CK_TILE_DEVICE constexpr auto MakeVLdsLoadBlockDescriptor()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:520
static CK_TILE_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:35
static CK_TILE_DEVICE constexpr auto GetPVBlockGemm()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:273
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackK()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:71
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentO()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:61
static CK_TILE_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:91
static CK_TILE_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:125
static CK_TILE_DEVICE constexpr auto MakeVLdsStoreBlockDescriptor(ck_tile::number< IBuf >=ck_tile::number< 0 >{})
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:463
static constexpr ck_tile::index_t kVLdsPadInBytes
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:311
static CK_TILE_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:227
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:23
static CK_TILE_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_fwd_v3_pipeline_default_policy.hpp:48
Definition block_gemm_areg_breg_creg_v2_custom_policy.hpp:23
Definition block_gemm_areg_breg_creg_v2.hpp:17
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192