block_masking.hpp Source File

block_masking.hpp Source File

Composable Kernel: block_masking.hpp Source File
block_masking.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
9
11{
   // GenericAttentionMaskEnum: the kinds of attention mask.
   // NOTE(review): this rendered listing drops the enum head (source line 10) and the enumerator
   // lines. Per the symbol index, the enumerators are NO_MASK (line 12),
   // MASK_FROM_TOP_LEFT (line 15), MASK_FROM_BOTTOM_RIGHT (line 16) and MASK_GENERIC (line 21).
13
14 // The two enumerators below can describe either a causal mask or a sliding-window mask.
17
18 // This enumerator may not be used by xformer/FA, since it is hard to
19 // specify a left/right window for the varlen case; it is kept here for
20 // debugging purposes.
22};
23
24// clang-format off
25/* generic Attention Mask Coordinate
26 use x(horizontal axis), y(vertical axis) to describe mask.
27 top-left corner is origin
28
29 x=1/y=5(top-left) x=4/y=5(botm-r) x=6/y=5 x=8/y=5(no mask)
30 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
31 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
32 1 1 1 * * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
33 1 1 1 1 * * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
34 1 1 1 1 1 * * * 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
35 l=7,-1/r=0(tl) l=7,-1/r=0(br)
36
37 x=1/y=2 x=4/y=2 x=6/y=2 x=8/y=2
38 1 * * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1 1 1
39 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1 1
40 * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 1 * 1 1 1 1 1 1 1
41 * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 1 * * 1 1 1 1 1 1
42 * * * 1 1 * * * * * * 1 1 1 1 1 * * * 1 1 1 1 1 * * * 1 1 1 1 1
43 l=1/r=0(tl) l=1/r=3(tl) l=1/r=5(tl) l=1/r=7(tl)
44 l=4/r=0(br) l=4/r=2(br) l=4/r=4(br)
45
46 x=4/y=-1 x=6/y=-1 x=8/y=-1
47 * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1 1
48 * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1 1
49 * * * * 1 1 * * * * * * 1 1 1 1 * * * * 1 1 1 1
50 * * * * * 1 1 * * * * * * 1 1 1 * * * * * 1 1 1
51 * * * * * * 1 1 * * * * * * 1 1 * * * * * * 1 1
52
53 x=-2/y=5 x=1/y=5(top-left) x=0/y=5(botm-r)
54 * * * * * * * * 1 * * * * * * *
55 * * * * * * * * 1 1 * * 1 * * *
56 * * * * * * * * 1 1 1 * 1 1 * *
57 1 * * * * * * * 1 1 1 1 1 1 1 *
58 1 1 * * * * * * 1 1 1 1 1 1 1 1
59
60 Validations:
61 x + y > 1 (x + y >= 2)
62
63 Note:
64 y = seq_q, x = 1 -> top-left
65 y = seq_q, x = seq_k - seq_q + 1 -> bottom-right
66 y < seq_q, x < seq_k -> local-attn
67 y = seq_q, x = seq_k -> no mask
68
69*/
namespace impl {
    // Kernel-name suffix for GenericAttentionMask<IsMasking_, IsLocal_>:
    //   "mn" -> masking disabled (IsLocal_ is irrelevant then)
    //   "mc" -> masking, non-local
    //   "mg" -> masking, local
    template <bool IsMasking_, bool IsLocal_> struct MaskName;
    // Without masking both IsLocal_ variants share one name, so a single partial specialization
    // covers them.
    template <bool IsLocal_> struct MaskName<false, IsLocal_> { static constexpr const char * name = "mn"; };
    template <> struct MaskName<true, false> { static constexpr const char * name = "mc"; };
    template <> struct MaskName<true, true> { static constexpr const char * name = "mg"; };
}
77// clang-format on
78
79template <bool IsMasking_ = true, bool IsLocal_ = false>
81{
82 static constexpr bool IsMasking = IsMasking_; // false will disable masking
83 static constexpr bool IsLocal = IsLocal_; // if true, upper/lower area could have mask,
84 // else only upper-right could have mask
85
86 static constexpr const char* name = impl::MaskName<IsMasking, IsLocal>::name;
87
89 : GenericAttentionMask(0, 0, y_total_, x_total_)
90 {
91 }
92
95 : y(y_), x(x_), y_total(y_total_), x_total(x_total_)
96 {
97 }
98 template <typename MaskCoordinates>
99 CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates& mask_coord)
100 : y(mask_coord.at(number<0>{})),
101 x(mask_coord.at(number<1>{})),
102 y_total(mask_coord.at(number<2>{})),
103 x_total(mask_coord.at(number<3>{}))
104 {
105 }
106
107 // to get the loop length along X axis, return index:[start, end), end-start=length
108 // use this if need loop over X axis tile by tile (like k-seqlen loopover)
109 // TODO: x_end still could be negative, so end-start could be negative(need check)
110 template <index_t YTile, index_t XTile>
111 CK_TILE_HOST_DEVICE constexpr auto
113 {
114 if constexpr(!IsMasking)
115 {
116 return ck_tile::make_tuple(0, x_total);
117 }
118 else
119 {
120 // get the tile start/end range assum we loop over along X tile by tile
121 index_t x_start = [&]() {
122 if constexpr(IsLocal)
123 {
124 index_t tmp = max(-y + i_y + 1, 0);
125 return (tmp / XTile) * XTile; // round to tile aligned
126 }
127 else
128 {
129 return 0;
130 }
131 }();
132
133 // TODO: end could be negative, we ignore clamp here, and let caller to check
134 // ... in which case end-start is negative
135 index_t x_end = [&]() {
136 index_t tmp = min(i_y + YTile - 1 + x, x_total);
137 return ((tmp + XTile - 1) / XTile) * XTile;
138 }();
139
140 return ck_tile::make_tuple(x_start, x_end);
141 }
142 }
143
144 // to get the loop length along Y axis, return index:[start, end), end-start=length
145 // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
146 // TODO: y_end still could be negative, so end-start could be negative(need check)
147 template <index_t YTile, index_t XTile>
148 CK_TILE_HOST_DEVICE constexpr auto
150 {
151 if constexpr(!IsMasking)
152 {
153 return ck_tile::make_tuple(0, y_total);
154 }
155 else
156 {
157 // get the tile start/end range assum we loop over along Y tile by tile
158 index_t y_start = [&]() {
159 index_t tmp = max(-x + i_x + 1, 0);
160 return (tmp / YTile) * YTile; // round to tile aligned
161 }();
162
163 // TODO: end could be negative, we ignore clamp here, and let caller to check
164 // ... in which case end-start is negative
165 index_t y_end = [&]() {
166 index_t tmp = min(i_x + XTile - 1 + y, y_total);
167 return ((tmp + YTile - 1) / YTile) * YTile;
168 }();
169
170 return ck_tile::make_tuple(y_start, y_end);
171 }
172 }
173
174 // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
175 CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
176 {
177 if constexpr(!IsMasking)
178 {
179 return i_x >= x_total;
180 }
181 else
182 {
183 // no need to do min/max here, since i_x will never be < 0 or >= x_total
184 index_t x_start = -y + i_y + 1;
185 index_t x_end = min(i_y + x, x_total);
186
187 if constexpr(IsLocal)
188 {
189 return i_x < x_start || i_x >= x_end;
190 }
191 else
192 {
193 return i_x >= x_end || i_y >= y_total;
194 }
195 }
196 }
197
198 // if current tile is at the edge, means need per-pixel mask check.
199 // otherwise no need to check per-pixel
200 // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
201 // can be used as a fast-path to decide if do per-pixel check or not
202 template <index_t TileHeight, index_t TileWidth>
203 CK_TILE_HOST_DEVICE constexpr auto
205 {
206 if constexpr(!IsMasking)
207 {
208 // TODO: no need to check begin
209 return (i_tile_left + TileWidth) > x_total;
210 }
211 else
212 {
213 if constexpr(IsLocal)
214 {
215 // check top-right corner > x or left-borrom corner < x
216 index_t i_tile_right = i_tile_left + TileWidth;
217 index_t i_tile_bottom = i_tile_top + TileHeight;
218 index_t x_end = min(i_tile_top + x, x_total);
219
220 bool top_right_edge = i_tile_right > (i_tile_top + x);
221 bool bottom_left_edge = i_tile_bottom > (i_tile_left + y);
222 bool is_partial_out_of_bound =
223 i_tile_right > x_end; // only consider right-pad for now
224
225 return top_right_edge || bottom_left_edge || is_partial_out_of_bound;
226 }
227 else
228 {
229 // only need to check top-right corner > x
230 index_t i_tile_right = i_tile_left + TileWidth;
231 index_t x_end = min(i_tile_top + x, x_total);
232
233 bool top_right_edge = i_tile_right > x_end;
234 return top_right_edge;
235 }
236 }
237 }
238
239 private:
240 index_t y, x;
241 index_t y_total, x_total;
242};
243
244// clang-format off
namespace impl {
    // Kernel-name suffix for SimplifiedGenericAttentionMask: "mask" when masking, else "nomask".
    template <bool IsMasking_> struct SimplifiedMaskName
    {
        static constexpr const char * name = IsMasking_ ? "mask" : "nomask";
    };
}
250// clang-format on
251
252// this version only have 2 variation: masking and non-masking
253// This is more friendly to codegen (e.g. need generate less kernel)
254// ... with the trade-off that may have more instruction in causal mode
255template <bool IsMasking_ = true>
257{
258 static constexpr bool IsMasking = IsMasking_; // false will disable masking
259
260 static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
261
263 : SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
264 {
265 }
266
269 : y(y_), x(x_), y_total(y_total_), x_total(x_total_)
270 {
271 }
272 template <typename MaskCoordinates>
273 CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
274 : y(mask_coord.at(number<0>{})),
275 x(mask_coord.at(number<1>{})),
276 y_total(mask_coord.at(number<2>{})),
277 x_total(mask_coord.at(number<3>{}))
278 {
279 }
280
281 // to get the loop length along X axis, return index:[start, end), end-start=length
282 // use this if need loop over X axis tile by tile (like k-seqlen loopover)
283 // TODO: x_end still could be negative, so end-start could be negative(need check)
284 template <index_t YTile, index_t XTile>
285 CK_TILE_HOST_DEVICE constexpr auto
287 {
288 if constexpr(!IsMasking)
289 {
290 return ck_tile::make_tuple(0, x_total);
291 }
292 else
293 {
294 // get the tile start/end range assum we loop over along X tile by tile
295 index_t x_start = [&]() {
296 index_t tmp = max(-y + i_y + 1, 0);
297 return (tmp / XTile) * XTile; // round to tile aligned
298 }();
299
300 // TODO: end could be negative, we ignore clamp here, and let caller to check
301 // ... in which case end-start is negative
302 index_t x_end = [&]() {
303 index_t tmp = min(i_y + YTile - 1 + x, x_total);
304 return ((tmp + XTile - 1) / XTile) * XTile;
305 }();
306
307 return ck_tile::make_tuple(x_start, x_end);
308 }
309 }
310
311 template <index_t TileHeight, index_t TileWidth>
313 number<TileHeight> height,
314 number<TileWidth> width,
315 index_t num_splits,
316 index_t i_split) const
317 {
318 auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
319
320 const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
321 const index_t split_start = x_per_split * i_split;
322 const index_t split_end = ck_tile::min(x_total, split_start + x_per_split);
323
324 return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
325 ck_tile::min(origin_end, split_end));
326 }
327
328 // to get the loop length along Y axis, return index:[start, end), end-start=length
329 // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
330 // TODO: y_end still could be negative, so end-start could be negative(need check)
331 template <index_t YTile, index_t XTile>
332 CK_TILE_HOST_DEVICE constexpr auto
334 {
335 if constexpr(!IsMasking)
336 {
337 return ck_tile::make_tuple(0, y_total);
338 }
339 else
340 {
341 // get the tile start/end range assum we loop over along Y tile by tile
342 index_t y_start = [&]() {
343 index_t tmp = max(-x + i_x + 1, 0);
344 return (tmp / YTile) * YTile; // round to tile aligned
345 }();
346
347 // TODO: end could be negative, we ignore clamp here, and let caller to check
348 // ... in which case end-start is negative
349 index_t y_end = [&]() {
350 index_t tmp = min(i_x + XTile - 1 + y, y_total);
351 return ((tmp + YTile - 1) / YTile) * YTile;
352 }();
353
354 return ck_tile::make_tuple(y_start, y_end);
355 }
356 }
357
358 // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
359 CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
360 {
361 if constexpr(!IsMasking)
362 {
363 // the only case that need do following compare is under kPadSeqLenK
364 // ... for non-masking kernel.
365 return i_x >= x_total;
366 }
367 else
368 {
369 index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
370 index_t x_end = min(i_y + x, x_total); // need min in case x is padded
371
372 return i_x < x_start || i_x >= x_end || i_y >= y_total;
373 }
374 }
375
376 // if current tile is at the edge, means need per-pixel mask check.
377 // otherwise no need to check per-pixel
378 // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
379 // can be used as a fast-path to decide if do per-pixel check or not
380 template <index_t TileHeight, index_t TileWidth>
381 CK_TILE_HOST_DEVICE constexpr auto
383 {
384 if constexpr(!IsMasking)
385 {
386 // the only case that need do following compare is under kPadSeqLenK
387 // ... for non-masking kernel.
388 // return (i_x < x_total) && ((i_x + TileWidth) > x_total);
389
390 // TODO: no need to check begin
391 return (i_x + TileWidth) > x_total;
392 }
393 else
394 {
395 // check top-right corner > x or left-borrom corner < x
396 index_t i_x_end = i_x + TileWidth;
397 index_t i_y_end = i_y + TileHeight;
398 // index_t x_end = min(i_y + x, x_total);
399
400 bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
401 bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
402 // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
403
404 return top_right_edge || bottom_left_edge;
405 }
406 }
407
408 private:
409 index_t y, x;
410 index_t y_total, x_total;
411};
412
413// clang-format off
namespace impl {
    // Kernel-name suffix for SimplifiedRatioAttentionMask: "nomask" unless masking is enabled.
    template <bool IsMasking_> struct SimplifiedRatioMaskName { static constexpr const char * name = "nomask"; };
    template <> struct SimplifiedRatioMaskName<true> { static constexpr const char * name = "mask"; };
}
419// clang-format on
420
421// This version is for masks whose diagonal advances by one column only every y_ratio (> 1)
422// rows, i.e. the mask is not a regular triangular matrix.
423
424// clang-format off
425/* y_ratio is the number of consecutive rows that share the same mask boundary (the step
426 length of the diagonal along y). It arises in performance optimizations such as merging
427 seqlen with qk_head_ratio, for example:
428
429 x=1/y=6/y_ratio=2(top-left)
430 1 * * * * * * *
431 1 * * * * * * *
432 1 1 * * * * * *
433 1 1 * * * * * *
434 1 1 1 * * * * *
435 1 1 1 * * * * *
436
437*/
438// clang-format on
439template <bool IsMasking_ = true>
441{
442 static constexpr bool IsMasking = IsMasking_; // false will disable masking
443
445
447 : SimplifiedRatioAttentionMask(0, 0, y_total_, x_total_, 0, 1, mdiv{})
448 {
449 }
450
453 index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
454 : SimplifiedRatioAttentionMask(/*y_=*/y_real_ * static_cast<index_t>(y_ratio_mdiv_.get()),
455 /*x_=*/x_,
456 /*y_total_=*/y_total_,
457 /*x_total_=*/x_total_,
458 /*y_real_=*/y_real_,
459 /*y_ratio_=*/static_cast<index_t>(y_ratio_mdiv_.get()),
460 /*y_ratio_mdiv_=*/y_ratio_mdiv_)
461
462 {
463 }
466 index_t x_,
467 index_t y_total_,
468 index_t x_total_,
469 index_t y_real_,
470 index_t y_ratio_,
471 mdiv y_ratio_mdiv_)
472 : y(y_),
473 x(x_),
474 y_total(y_total_),
475 x_total(x_total_),
476 y_real(y_real_),
477 y_ratio(y_ratio_),
478 y_ratio_mdiv(y_ratio_mdiv_)
479 {
480 }
481
482 // to get the loop length along X axis, return index:[start, end), end-start=length
483 // use this if need loop over X axis tile by tile (like k-seqlen loopover)
484 // TODO: x_end still could be negative, so end-start could be negative(need check)
485 template <index_t YTile, index_t XTile>
486 CK_TILE_HOST_DEVICE constexpr auto
488 {
489 if constexpr(!IsMasking)
490 {
491 return ck_tile::make_tuple(0, x_total);
492 }
493 else
494 {
495 // get the tile start/end range assum we loop over along X tile by tile
496 index_t x_start = [&]() {
497 index_t tmp = -y_real +
498 static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y))) +
499 1;
500
501 return (tmp / XTile) * XTile; // round to tile aligned
502 }();
503
504 // TODO: end could be negative, we ignore clamp here, and let caller to check
505 // ... in which case end-start is negative
506 index_t x_end = [&]() {
507 uint32_t y_offset = i_y + YTile - 1;
508 index_t tmp = min(static_cast<index_t>(y_ratio_mdiv.div(y_offset)) + x, x_total);
509 return ((tmp + XTile - 1) / XTile) * XTile;
510 }();
511
512 return ck_tile::make_tuple(x_start, x_end);
513 }
514 }
515
516 // to get the loop length along Y axis, return index:[start, end), end-start=length
517 // use this if need loop over Y axis tile by tile (like q-seqlen loopover)
518 // TODO: y_end still could be negative, so end-start could be negative(need check)
519 template <index_t YTile, index_t XTile>
520 CK_TILE_HOST_DEVICE constexpr auto
522 {
523 if constexpr(!IsMasking)
524 {
525 return ck_tile::make_tuple(0, y_total);
526 }
527 else
528 {
529 // get the tile start/end range assum we loop over along Y tile by tile
530 index_t y_start = [&]() {
531 index_t tmp = max((-x + i_x + 1) * y_ratio, 0);
532 return (tmp / YTile) * YTile; // round to tile aligned
533 }();
534
535 // TODO: end could be negative, we ignore clamp here, and let caller to check
536 // ... in which case end-start is negative
537 index_t y_end = [&]() {
538 index_t tmp = min((i_x + XTile - 1) * y_ratio + y, y_total);
539 return ((tmp + YTile - 1) / YTile) * YTile;
540 }();
541
542 return ck_tile::make_tuple(y_start, y_end);
543 }
544 }
545
546 // per-pixel check if out-of-bound, if true, need mask a value(like -INF)
547 CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
548 {
549 if constexpr(!IsMasking)
550 {
551 return i_x >= x_total;
552 }
553 else
554 {
555 index_t x_tmp = static_cast<index_t>(y_ratio_mdiv.div(static_cast<uint32_t>(i_y)));
556 index_t x_start = -y_real + x_tmp + 1;
557 index_t x_end = min(x_tmp + x,
558 x_total); // need min in case x is padded
559 return i_x < x_start || i_x >= x_end || i_y >= y_total;
560 }
561 }
562
563 // if current tile is at the edge, means need per-pixel mask check.
564 // otherwise no need to check per-pixel
565 // Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX/Y()
566 // can be used as a fast-path to decide if do per-pixel check or not
567 template <index_t TileHeight, index_t TileWidth>
568 CK_TILE_HOST_DEVICE constexpr auto
570 {
571 if constexpr(!IsMasking)
572 {
573 // the only case that need do following compare is under kPadSeqLenK
574 // ... for non-masking kernel.
575 // return (i_x < x_total) && ((i_x + TileWidth) > x_total);
576
577 return (i_x + TileWidth) > x_total;
578 }
579 else
580 {
581 // check top-right corner > x or left-borrom corner < x
582 index_t i_x_end = i_x + TileWidth;
583 index_t i_y_end = i_y + TileHeight;
584 // index_t x_end = min(i_y + x, x_total);
585 uint32_t y_tmp = static_cast<uint32_t>(i_y);
586 bool top_right_edge = i_x_end > min(static_cast<index_t>(y_ratio_mdiv.div(y_tmp)) + x,
587 x_total); // consider right pad
588 bool bottom_left_edge =
589 i_y_end > min(i_x * y_ratio + y, y_total); // consider bottom pad
590 return top_right_edge || bottom_left_edge;
591 }
592 }
593
594 private:
595 index_t y, x;
596 index_t y_total, x_total;
597 // y_real is vertical axis before multiplying y_ratio. y_real * y_ratio = y
598 index_t y_real;
599 index_t y_ratio;
600 mdiv y_ratio_mdiv;
601};
602
603// TODO: prefer use this function in host code
604// can convert from the FA style left/right to our generic coordinate
605// if left_size < 0 && right_size = 0, it is normal causal mask
606// local is left_size >=0 or right_size >=0
607CK_TILE_HOST_DEVICE constexpr auto
609 index_t right_size,
610 index_t y_total,
611 index_t x_total,
612 bool is_top_left = true)
613{
614 // TODO: below should all use sgpr arithmetic
615 index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1;
616 index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;
617
618 left_size = left_size < 0 ? left_size_tmp : left_size;
619 right_size = right_size < 0 ? right_size_tmp : right_size;
620
621 index_t x_tmp = is_top_left ? 0 : x_total - y_total;
622 index_t y_tmp = is_top_left ? 0 : y_total - x_total;
623
624 index_t x = 1 + right_size + x_tmp;
625 index_t y = 1 + left_size + y_tmp;
626
627 return ck_tile::make_tuple(y, x, y_total, x_total);
628}
629
630template <typename MaskType>
631CK_TILE_HOST_DEVICE constexpr auto
633 index_t right_size,
634 index_t y_total,
635 index_t x_total,
636 bool is_top_left = true)
637{
639 left_size, right_size, y_total, x_total, is_top_left);
640 return MaskType{r.at(number<0>{}), r.at(number<1>{}), y_total, x_total};
641}
642} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition block_masking.hpp:632
CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_coordinates_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition block_masking.hpp:608
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
GenericAttentionMaskEnum
Definition block_masking.hpp:11
@ MASK_GENERIC
Definition block_masking.hpp:21
@ MASK_FROM_TOP_LEFT
Definition block_masking.hpp:15
@ NO_MASK
Definition block_masking.hpp:12
@ MASK_FROM_BOTTOM_RIGHT
Definition block_masking.hpp:16
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
unsigned int uint32_t
Definition stdint.h:126
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
Definition block_masking.hpp:94
CK_TILE_HOST_DEVICE GenericAttentionMask(index_t y_total_, index_t x_total_)
Definition block_masking.hpp:88
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition block_masking.hpp:149
static constexpr const char * name
Definition block_masking.hpp:86
CK_TILE_HOST_DEVICE constexpr auto IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number< TileHeight >, number< TileWidth >) const
Definition block_masking.hpp:204
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition block_masking.hpp:175
CK_TILE_HOST_DEVICE GenericAttentionMask(const MaskCoordinates &mask_coord)
Definition block_masking.hpp:99
static constexpr bool IsMasking
Definition block_masking.hpp:82
static constexpr bool IsLocal
Definition block_masking.hpp:83
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition block_masking.hpp:112
CK_TILE_HOST_DEVICE constexpr auto IsEdgeTile(index_t i_y, index_t i_x, number< TileHeight >, number< TileWidth >) const
Definition block_masking.hpp:382
static constexpr const char * name
Definition block_masking.hpp:260
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition block_masking.hpp:286
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates &mask_coord)
Definition block_masking.hpp:273
static constexpr bool IsMasking
Definition block_masking.hpp:258
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number< TileHeight > height, number< TileWidth > width, index_t num_splits, index_t i_split) const
Definition block_masking.hpp:312
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition block_masking.hpp:359
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition block_masking.hpp:333
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
Definition block_masking.hpp:262
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
Definition block_masking.hpp:268
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_, index_t y_real_, index_t y_ratio_, mdiv y_ratio_mdiv_)
Definition block_masking.hpp:465
static constexpr const char * name
Definition block_masking.hpp:444
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_real_, index_t x_, index_t y_total_, index_t x_total_, mdiv y_ratio_mdiv_)
Definition block_masking.hpp:452
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number< YTile >, number< XTile >) const
Definition block_masking.hpp:487
CK_TILE_HOST_DEVICE SimplifiedRatioAttentionMask(index_t y_total_, index_t x_total_)
Definition block_masking.hpp:446
CK_TILE_HOST_DEVICE constexpr auto IsEdgeTile(index_t i_y, index_t i_x, number< TileHeight >, number< TileWidth >) const
Definition block_masking.hpp:569
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
Definition block_masking.hpp:547
CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongY(index_t i_x, number< YTile >, number< XTile >) const
Definition block_masking.hpp:521
static constexpr bool IsMasking
Definition block_masking.hpp:442
static constexpr const char * name
Definition block_masking.hpp:72
static constexpr const char * name
Definition block_masking.hpp:73
static constexpr const char * name
Definition block_masking.hpp:74
static constexpr const char * name
Definition block_masking.hpp:75
Definition block_masking.hpp:71
static constexpr const char * name
Definition block_masking.hpp:247
static constexpr const char * name
Definition block_masking.hpp:248
Definition block_masking.hpp:246
static constexpr const char * name
Definition block_masking.hpp:416
static constexpr const char * name
Definition block_masking.hpp:417
Definition block_masking.hpp:415
Definition magic_div.hpp:186