gridwise_gemm_multiple_abd_xdl_cshuffle.hpp Source File

gridwise_gemm_multiple_abd_xdl_cshuffle.hpp Source File

Composable Kernel: gridwise_gemm_multiple_abd_xdl_cshuffle.hpp Source File
gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
19
20namespace ck {
21
22// GEMM:
23// input : A0[M, K], A1[M, K]
24// input : B0[N, K], B1[N, K]
25// input : D0[M, N], D1[M, N], ...
26// output : E[M, N]
27// C = a_op(A0, A1, ...) * b_op(B0, B1, ...)
28// E = cde_op(C, D0, D1, ...)
29// Assume:
30// D0, D1, ... and E have the same layout
31template <typename AsDataType,
32 typename BsDataType,
33 typename AComputeDataType_,
34 typename AccDataType,
35 typename CShuffleDataType,
36 typename DsDataType,
37 typename EDataType,
38 typename AElementwiseOperation,
39 typename BElementwiseOperation,
40 typename CDEElementwiseOperation,
41 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
42 index_t NumGemmKPrefetchStage,
43 index_t BlockSize,
44 index_t MPerBlock,
45 index_t NPerBlock,
46 index_t KPerBlock,
47 index_t AK1Value,
48 index_t BK1Value,
49 index_t MPerXdl,
50 index_t NPerXdl,
51 index_t MXdlPerWave,
52 index_t NXdlPerWave,
53 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
54 typename ABlockTransferThreadClusterArrangeOrder,
55 typename ABlockTransferSrcAccessOrder,
56 index_t ABlockTransferSrcVectorDim,
57 index_t ABlockTransferSrcScalarPerVector,
58 index_t ABlockTransferDstScalarPerVector_AK1,
59 bool AThreadTransferSrcResetCoordinateAfterRun,
60 index_t ABlockLdsExtraM,
61 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
62 typename BBlockTransferThreadClusterArrangeOrder,
63 typename BBlockTransferSrcAccessOrder,
64 index_t BBlockTransferSrcVectorDim,
65 index_t BBlockTransferSrcScalarPerVector,
66 index_t BBlockTransferDstScalarPerVector_BK1,
67 bool BThreadTransferSrcResetCoordinateAfterRun,
68 index_t BBlockLdsExtraN,
69 index_t CShuffleMXdlPerWavePerShuffle,
70 index_t CShuffleNXdlPerWavePerShuffle,
71 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
72 index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
73 LoopScheduler LoopSched,
75 typename BComputeDataType_ = AComputeDataType_>
77{
// Number of A, B and D input tensors, taken from the sizes of the data-type tuples.
78 static constexpr index_t NumATensor = AsDataType::Size();
79 static constexpr index_t NumBTensor = BsDataType::Size();
80 static constexpr index_t NumDTensor = DsDataType::Size();
81
83
// Compile-time index constants used to address tuple elements and descriptor dimensions.
84 static constexpr auto I0 = Number<0>{};
85 static constexpr auto I1 = Number<1>{};
86 static constexpr auto I2 = Number<2>{};
87 static constexpr auto I3 = Number<3>{};
88 static constexpr auto I4 = Number<4>{};
89 static constexpr auto I5 = Number<5>{};
90 static constexpr auto I6 = Number<6>{};
91 static constexpr auto I7 = Number<7>{};
92
93 // K1 should be Number<...>
// AK1/BK1 are the last-dimension K extents of the [K0, M|N, K1] tiles; AK0PerBlock and
// BK0PerBlock are the matching first-dimension extents per block, so that
// KPerBlock == AK0PerBlock * AK1 == BK0PerBlock * BK1 (CheckValidity static_asserts the
// divisibility this relies on).
94 static constexpr auto AK1 = Number<AK1Value>{};
95 static constexpr auto BK1 = Number<BK1Value>{};
96 static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
97 static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
98
100
103
// Element / compute data types.
// NOTE(review): this rendering drops original lines 99 and 101-102, the right-hand sides
// of the two aliases below (lines 106 and 108), and the #else-branch definitions
// (lines 111-114). AElementDataType and BElementDataType, used for LDS sizing and LDS
// buffers elsewhere in the file, are presumably declared in those dropped lines; confirm
// against the original header.
104#if CK_GFX90A_DENORM_WORKAROUND
105 using AComputeDataType =
107 using BComputeDataType =
109#else
110 // Element data type is used in LDS and registers. ComputeDataType_ is inside mfma, eg tf32.
115#endif
116
// Descriptor of the A tile held in LDS (destination of the A blockwise copy). Its
// dimensions are [AK0PerBlock, MPerBlock, AK1], as the name states. Its
// GetElementSpaceSize() sizes the A part of GetSharedMemoryNumberOfByte and the A LDS
// buffer in Run.
// NOTE(review): the body (original lines 120-122) is missing from this rendering;
// ABlockLdsExtraM is presumably applied here as LDS padding — confirm.
117 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
118 {
119 // A matrix in LDS memory, dst of blockwise copy
123 }
124
// Descriptor of the B tile held in LDS (destination of the B blockwise copy). Its
// dimensions are [BK0PerBlock, NPerBlock, BK1], as the name states. Its
// GetElementSpaceSize() sizes the B part of GetSharedMemoryNumberOfByte and the B LDS
// buffer in Run.
// NOTE(review): the body (original lines 128-130) is missing from this rendering;
// BBlockLdsExtraN is presumably applied here as LDS padding — confirm.
125 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
126 {
127 // B matrix in LDS memory, dst of blockwise copy
131 }
132
// Builds the descriptor of the LDS staging tile through which C is shuffled before it is
// combined with the Ds and written to E. It is presumably what GetSharedMemoryNumberOfByte
// and Run obtain as c_shuffle_block_desc_mblock_mperblock_nblock_nperblock (their
// initializer lines are also missing here). Run's C/D/E copy moves slices of
// [1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
//  1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl] through that tile.
// NOTE(review): the function-name line (original line 134) and most of the descriptor
// construction (lines 140-142, 144) are missing from this rendering.
133 __host__ __device__ static constexpr auto
135 {
// Number of waves along M and along N inside one block.
136 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
137 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
138
139 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
143 I1,
145
146 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
147 }
148
// Returns a tuple of null `const ADataType*`, one entry per A tensor. Only the type is
// used: AsGridPointer is declared as decltype(MakeAsGridPointer()).
// NOTE(review): the generate_tuple length argument (original line 157) is missing from
// this rendering; presumably Number<NumATensor>{}.
149 static constexpr auto MakeAsGridPointer()
150 {
151 return generate_tuple(
152 [&](auto i) {
153 using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
154
155 return static_cast<const ADataType*>(nullptr);
156 },
158 }
159
// Returns a tuple of null `const BDataType*`, one entry per B tensor. Only the type is
// used: BsGridPointer is declared as decltype(MakeBsGridPointer()).
// NOTE(review): the generate_tuple length argument (original line 168) is missing from
// this rendering; presumably Number<NumBTensor>{}.
160 static constexpr auto MakeBsGridPointer()
161 {
162 return generate_tuple(
163 [&](auto i) {
164 using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
165
166 return static_cast<const BDataType*>(nullptr);
167 },
169 }
170
171 // ck::Tuple<const D0DataType*, const D1DataType*, ...>
// Returns a tuple of null `const DDataType*`, one entry per D tensor. Only the type is
// used: DsGridPointer is declared as decltype(MakeDsGridPointer()).
// NOTE(review): the generate_tuple length argument (original line 180) is missing from
// this rendering; presumably Number<NumDTensor>{}.
172 static constexpr auto MakeDsGridPointer()
173 {
174 return generate_tuple(
175 [&](auto i) {
176 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
177
178 return static_cast<const DDataType*>(nullptr);
179 },
181 }
182
// Returns the number of LDS bytes one block needs:
//   max(A tile bytes + B tile bytes, C-shuffle tile bytes)
// The A and B tile sizes are each rounded up to a multiple of lcm(AK1, BK1) elements.
// The result is a max rather than a sum because Run uses the same p_shared allocation
// for the A/B tiles during the main K loop and, afterwards, for the C-shuffle tile.
// NOTE(review): B is accounted for here as b_block_space_size_aligned *
// sizeof(BElementDataType) bytes after an A region of a_block_space_size_aligned *
// sizeof(AElementDataType) bytes. Run, however, places the B LDS buffer at
// static_cast<BElementDataType*>(p_shared) + a_block_space_size_aligned, i.e. the offset
// is counted in B elements. The two only agree when sizeof(AElementDataType) ==
// sizeof(BElementDataType); confirm whether mixed-size A/B element types are possible.
183 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
184 {
185 // LDS allocation for A and B: be careful of alignment
186 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
187 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
188
189 // lds max alignment
190 constexpr auto max_lds_align = math::lcm(AK1, BK1);
191
192 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
193 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
194
195 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
196 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
197
198 // LDS allocation for C shuffle in LDS
// NOTE(review): the initializer (original line 200) is missing from this rendering.
199 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
201
202 constexpr auto c_block_size =
203 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
204
205 return math::max(a_block_space_size_aligned * sizeof(AElementDataType) +
206 b_block_space_size_aligned * sizeof(BElementDataType),
207 c_block_size * sizeof(CShuffleDataType));
208 }
209
210 // A desc for source in blockwise copy
// Reshapes one [M, K] A descriptor into the [AK0, M, AK1] form that the A blockwise copy in
// Run reads from, with AK0 = K / AK1.
// K / AK1 is integer division, so K is assumed to be a multiple of AK1. This holds for
// descriptors accepted by CheckValidity (K % KPerBlock == 0 and KPerBlock % AK1Value == 0).
// NOTE(review): the transform arguments (original lines 221-224) are missing from this
// rendering.
211 template <typename AGridDesc_M_K>
212 __host__ __device__ static constexpr auto
213 MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
214 {
215 const auto M = a_grid_desc_m_k.GetLength(I0);
216 const auto K = a_grid_desc_m_k.GetLength(I1);
217
218 const auto AK0 = K / AK1;
219
220 return transform_tensor_descriptor(a_grid_desc_m_k,
225 }
226
// Applies MakeDefaultAGridDescriptor_AK0_M_AK1 to every A descriptor in the tuple and
// returns the resulting tuple of [AK0, M, AK1] descriptors.
// NOTE(review): the generate_tuple length argument (original line 233) is missing from
// this rendering; presumably Number<NumATensor>{}.
227 template <typename AsGridDesc_M_K>
228 __host__ __device__ static constexpr auto
229 MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
230 {
231 return generate_tuple(
232 [&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
234 }
235
236 // B desc for source in blockwise copy
// Reshapes one [N, K] B descriptor into the [BK0, N, BK1] form that the B blockwise copy in
// Run reads from, with BK0 = K / BK1.
// K / BK1 is integer division, so K is assumed to be a multiple of BK1. This holds for
// descriptors accepted by CheckValidity (K % KPerBlock == 0 and KPerBlock % BK1Value == 0).
// NOTE(review): the transform arguments (original lines 247-250) are missing from this
// rendering.
237 template <typename BGridDesc_N_K>
238 __host__ __device__ static constexpr auto
239 MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
240 {
241 const auto N = b_grid_desc_n_k.GetLength(I0);
242 const auto K = b_grid_desc_n_k.GetLength(I1);
243
244 const auto BK0 = K / BK1;
245
246 return transform_tensor_descriptor(b_grid_desc_n_k,
251 }
252
// Applies MakeDefaultBGridDescriptor_BK0_N_BK1 to every B descriptor in the tuple and
// returns the resulting tuple of [BK0, N, BK1] descriptors.
// NOTE(review): the generate_tuple length argument (original line 259) is missing from
// this rendering; presumably Number<NumBTensor>{}.
253 template <typename BsGridDesc_N_K>
254 __host__ __device__ static constexpr auto
255 MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
256 {
257 return generate_tuple(
258 [&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
260 }
261
262 // E desc for destination in blockwise copy
// Views an [M, N] descriptor as [MBlock, MPerBlock, NBlock, NPerBlock], with
// MBlock = M / MPerBlock and NBlock = N / NPerBlock. Both are integer divisions;
// CheckValidity requires M % MPerBlock == 0 and N % NPerBlock == 0.
// NOTE(review): the transform arguments (original lines 275-278) are missing from this
// rendering; presumably M and N are each unmerged into (block, per-block) pairs, matching
// the result's name — confirm.
263 template <typename EGridDesc_M_N>
264 __host__ __device__ static constexpr auto
265 MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
266 {
267 const auto M = e_grid_desc_m_n.GetLength(I0);
268 const auto N = e_grid_desc_m_n.GetLength(I1);
269
270 const auto MBlock = M / MPerBlock;
271 const auto NBlock = N / NPerBlock;
272
273 const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
274 e_grid_desc_m_n,
279
280 return e_grid_desc_mblock_mperblock_nblock_nperblock;
281 }
282
283 // Ds desc for source in blockwise copy
// Builds one [MBlock, MPerBlock, NBlock, NPerBlock] descriptor per D tensor. Run reads the
// Ds through these descriptors while writing E.
// NOTE(review): the lambda body (original line 290) and the generate_tuple length argument
// (line 292) are missing from this rendering. The body presumably calls
// MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]) — confirm.
284 template <typename DsGridDesc_M_N>
285 __host__ __device__ static constexpr auto
286 MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
287 {
288 return generate_tuple(
289 [&](auto i) {
291 },
293 }
294
295 // return block_id to E matrix tile idx (m0, n0) mapping
// Run uses the returned map to turn get_block_1d_id() into an (m0, n0) E tile index
// (CalculateBottomIndex) and to reject out-of-range tiles (ValidCTileIndex). CheckValidity
// calls its CheckValidity(e_grid_desc_m_n).
// NOTE(review): the map type / constructor call (original line 300) and original line 304
// are missing from this rendering.
296 template <typename EGridDesc_M_N>
297 __host__ __device__ static constexpr auto
298 MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
299 {
301 e_grid_desc_m_n);
302 }
303
305
306 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
307 template <typename AsGridDesc_M_K,
308 typename BsGridDesc_N_K,
309 typename DsGridDesc_M_N,
310 typename EGridDesc_M_N,
311 typename Block2ETileMap>
312 __host__ __device__ static constexpr bool CheckValidity(const AsGridDesc_M_K& as_grid_desc_m_k,
313 const BsGridDesc_N_K& bs_grid_desc_n_k,
314 const DsGridDesc_M_N& ds_grid_desc_m_n,
315 const EGridDesc_M_N& e_grid_desc_m_n,
316 const Block2ETileMap& block_2_etile_map)
317 {
318 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
319 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
320 "Invalid tuning param!");
321
322 static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
323 "KPerBlock must be divisible by AK1Value and BK1Value!");
324
325 const auto M = as_grid_desc_m_k[I0].GetLength(I0);
326 const auto N = bs_grid_desc_n_k[I0].GetLength(I0);
327 const auto AK = as_grid_desc_m_k[I0].GetLength(I1);
328 const auto BK = bs_grid_desc_n_k[I0].GetLength(I1);
329
330 // check consistency of desc
331 if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
332 {
333 return false;
334 }
335
336 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
337
338 bool valid = true;
339 static_for<0, NumATensor, 1>{}([&](auto i) {
340 using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
341 valid =
342 valid && (as_grid_desc_m_k[i].GetElementSpaceSize() * sizeof(ADataType) <= TwoGB);
343 valid = valid && (M == as_grid_desc_m_k[i].GetLength(I0) &&
344 AK == as_grid_desc_m_k[i].GetLength(I1));
345 });
346
347 static_for<0, NumBTensor, 1>{}([&](auto i) {
348 using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
349 valid =
350 valid && (bs_grid_desc_n_k[i].GetElementSpaceSize() * sizeof(BDataType) <= TwoGB);
351 valid = valid && (N == bs_grid_desc_n_k[i].GetLength(I0) &&
352 BK == bs_grid_desc_n_k[i].GetLength(I1));
353 });
354
355 static_for<0, NumDTensor, 1>{}([&](auto i) {
356 valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
357 N == ds_grid_desc_m_n[i].GetLength(I1));
358 });
359
360 if(!valid)
361 {
362 return false;
363 }
364
365 // check tile size
366 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
367 {
368 return false;
369 }
370
371 // check gridwise gemm pipeline
372 const auto num_k_loop = AK / KPerBlock;
373
374 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
375 {
376 return false;
377 }
378
379 // check block-to-E-tile
380 if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
381 {
382 return false;
383 }
384
385 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
386 // check tensor size: cannot be larger than 2GB each
387
388 if(!(e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
389 {
390 return false;
391 }
392
393 return true;
394 }
395
396 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
397 {
398 const index_t num_loop = K / KPerBlock;
399
400 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
401 }
402
403 using AsGridPointer = decltype(MakeAsGridPointer());
404 using BsGridPointer = decltype(MakeBsGridPointer());
405 using DsGridPointer = decltype(MakeDsGridPointer());
406
// Builds the [M, K] descriptor of one A tensor from raw extents MRaw x KRaw and StrideA.
// It then pads the descriptor with the matrix padder for GemmSpec, constructed from
// MPerBlock / NPerBlock / KPerBlock.
// The first branch makes K the unit-stride dimension (strides {StrideA, 1}); the second
// makes M the unit-stride dimension (strides {1, StrideA}).
// NOTE(review): the signature line (original 409), the padder type (412) and the two
// if-constexpr layout conditions (416, 421) are missing from this rendering. The branches
// are presumably selected by whether ALayout is row- or column-major — confirm.
407 template <typename ALayout, GemmSpecialization GemmSpec>
408 __host__ __device__ static auto
410 {
411 constexpr auto matrix_padder =
413 MPerBlock, NPerBlock, KPerBlock};
414
415 const auto a_grid_desc_mraw_kraw = [&]() {
417 {
418 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
419 make_tuple(StrideA, I1));
420 }
422 {
423 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
424 make_tuple(I1, StrideA));
425 }
426 }();
427
428 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
429 }
430
// Builds one padded [M, K] descriptor per A tensor via MakeAGridDescriptor_M_K. Tensor i
// uses the i-th element of AsLayout and MRaws[i], KRaws[i], AsStride[i].
// NOTE(review): the first two CK_CODE_GEN_RTC parameters (original lines 434-435) and the
// generate_tuple length argument (line 450) are missing from this rendering.
431 template <typename AsLayout, GemmSpecialization GemmSpec>
432 __host__ __device__ static auto MakeAsGridDescriptor_M_K(
433#ifdef CK_CODE_GEN_RTC
436 const ck::Array<index_t, NumATensor>& AsStride
437#else
438 const std::array<index_t, NumATensor>& MRaws,
439 const std::array<index_t, NumATensor>& KRaws,
440 const std::array<index_t, NumATensor>& AsStride
441#endif
442 )
443 {
444 return generate_tuple(
445 [&](auto i) {
446 using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
447
448 return MakeAGridDescriptor_M_K<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
449 },
451 }
452
// Builds the [N, K] descriptor of one B tensor from raw extents NRaw x KRaw and StrideB.
// It then pads the descriptor with the matrix padder for GemmSpec.
// The first branch makes N the unit-stride dimension (strides {1, StrideB}); the second
// makes K the unit-stride dimension (strides {StrideB, 1}).
// NOTE(review): the padder type (original line 458) and the two if-constexpr layout
// conditions (462, 467) are missing from this rendering.
453 template <typename BLayout, GemmSpecialization GemmSpec>
454 __host__ __device__ static auto
455 MakeBGridDescriptor_N_K(const index_t NRaw, const index_t KRaw, const index_t StrideB)
456 {
457 constexpr auto matrix_padder =
459 MPerBlock, NPerBlock, KPerBlock};
460
461 const auto b_grid_desc_nraw_kraw = [&]() {
463 {
464 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
465 make_tuple(I1, StrideB));
466 }
468 {
469 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
470 make_tuple(StrideB, I1));
471 }
472 }();
473
474 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
475 }
476
// Builds one padded [N, K] descriptor per B tensor via MakeBGridDescriptor_N_K. Tensor i
// uses the i-th element of BsLayout and NRaws[i], KRaws[i], BsStride[i].
// NOTE(review): the first two CK_CODE_GEN_RTC parameters (original lines 480-481) and the
// generate_tuple length argument (line 496) are missing from this rendering.
477 template <typename BsLayout, GemmSpecialization GemmSpec>
478 __host__ __device__ static auto MakeBsGridDescriptor_N_K(
479#ifdef CK_CODE_GEN_RTC
482 const ck::Array<index_t, NumBTensor>& BsStride
483#else
484 const std::array<index_t, NumBTensor>& NRaws,
485 const std::array<index_t, NumBTensor>& KRaws,
486 const std::array<index_t, NumBTensor>& BsStride
487#endif
488 )
489 {
490 return generate_tuple(
491 [&](auto i) {
492 using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
493
494 return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(NRaws[i], KRaws[i], BsStride[i]);
495 },
497 }
498
// Builds the [M, N] descriptor of E (and, via MakeDsGridDescriptor_M_N, of each D) from
// raw extents MRaw x NRaw and StrideE. It then pads the descriptor with the matrix padder
// for GemmSpec.
// The first branch makes N the unit-stride dimension (strides {StrideE, 1}); the second
// makes M the unit-stride dimension (strides {1, StrideE}).
// NOTE(review): the signature line (original 501), the padder type (504) and the two
// if-constexpr layout conditions (507, 512) are missing from this rendering.
499 template <typename ELayout, GemmSpecialization GemmSpec>
500 __host__ __device__ static auto
502 {
503 constexpr auto matrix_padder =
505 MPerBlock, NPerBlock, KPerBlock};
506 const auto e_grid_desc_mraw_nraw = [&]() {
508 {
509 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
510 make_tuple(StrideE, I1));
511 }
513 {
514 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
515 make_tuple(I1, StrideE));
516 }
517 }();
518
519 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
520 }
521
// Builds one padded [M, N] descriptor per D tensor by reusing MakeEGridDescriptor_M_N.
// Tensor i uses the i-th element of DsLayout and MRaws[i], NRaws[i], DsStride[i].
// NOTE(review): each D gets its own DLayout here, whereas the file header comment assumes
// "D0, D1, ... and E have the same layout". The CDE copy in Run uses one vector dimension
// (3) for all sources, so confirm whether mixed D layouts are actually supported.
// NOTE(review): the first two CK_CODE_GEN_RTC parameters (original lines 525-526) and the
// generate_tuple length argument (line 541) are missing from this rendering.
522 template <typename DsLayout, GemmSpecialization GemmSpec>
523 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
524#ifdef CK_CODE_GEN_RTC
527 const ck::Array<index_t, NumDTensor>& DsStride
528#else
529 const std::array<index_t, NumDTensor>& MRaws,
530 const std::array<index_t, NumDTensor>& NRaws,
531 const std::array<index_t, NumDTensor>& DsStride
532#endif
533 )
534 {
535 return generate_tuple(
536 [&](auto i) {
537 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
538
539 return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
540 },
542 }
543
544 __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
545
// Per-block GEMM body operating on pre-built grid descriptors.
//  1. Maps get_block_1d_id() to an (m0, n0) E tile via block_2_etile_map. A block whose
//     tile index is rejected by ValidCTileIndex returns immediately.
//  2. Copies the A tile (from all As, applying a_element_op) and the B tile (from all Bs,
//     applying b_element_op) from global memory into LDS. It then runs the blockwise GEMM
//     through the gridwise pipeline, accumulating C in c_thread_buf.
//  3. Shuffles C through LDS in slices, reads the matching D slices, and writes
//     E = cde_element_op(C, D0, D1, ...) to global memory.
// p_shared is used for the A/B LDS tiles during the main loop and reused for the C-shuffle
// tile afterwards. It is expected to hold GetSharedMemoryNumberOfByte() bytes.
// HasMainKBlockLoop is forwarded to the pipeline. It presumably must equal
// CalculateHasMainKBlockLoop(K) for the launched problem — confirm at the launch site.
546 template <bool HasMainKBlockLoop,
547 typename AsGridDesc_AK0_M_AK1,
548 typename BsGridDesc_BK0_N_BK1,
549 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
550 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
551 typename Block2ETileMap>
552 __device__ static void Run(AsGridPointer p_as_grid,
553 BsGridPointer p_bs_grid,
554 DsGridPointer p_ds_grid,
555 EDataType* __restrict__ p_e_grid,
556 void* __restrict__ p_shared,
557 const AElementwiseOperation& a_element_op,
558 const BElementwiseOperation& b_element_op,
559 const CDEElementwiseOperation& cde_element_op,
560 const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
561 const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
562 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
563 ds_grid_desc_mblock_mperblock_nblock_nperblock,
564 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
565 e_grid_desc_mblock_mperblock_nblock_nperblock,
566 const Block2ETileMap& block_2_etile_map)
567 {
// Wrap each raw global pointer in a buffer whose size is its descriptor's element space.
// NOTE(review): the buffer-construction calls and generate_tuple length arguments
// (original lines 570, 573, 577, 580, 584, 588, 590) are missing from this rendering.
568 const auto as_grid_buf = generate_tuple(
569 [&](auto i) {
571 p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
572 },
574
575 const auto bs_grid_buf = generate_tuple(
576 [&](auto i) {
578 p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
579 },
581
582 const auto ds_grid_buf = generate_tuple(
583 [&](auto i) {
585 p_ds_grid[i],
586 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
587 },
589
591 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
592
593 // divide block work by [M, N]
594 const auto block_work_idx =
595 block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
596
// The early return depends only on the block id, so all threads of a block take it
// together; it happens before any LDS use in this function.
597 if(!block_2_etile_map.ValidCTileIndex(
598 block_work_idx,
599 make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
600 e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
601 {
602 return;
603 }
604 // HACK: this force m/n_block_data_idx_on_grid into SGPR
605 const index_t m_block_data_idx_on_grid =
606 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
607
608 const index_t n_block_data_idx_on_grid =
609 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
610
611 // lds max alignment
612 constexpr auto max_lds_align = math::lcm(AK1, BK1);
613
614 // A matrix in LDS memory, dst of blockwise copy
615 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
616
617 // B matrix in LDS memory, dst of blockwise copy
618 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
619
// Every A tensor starts at the same tile origin: (K0 = 0, M = first row of this block, K1 = 0).
620 const auto idx_as_block_begin =
621 generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
623
// Block-level copy from all A tensors (global) into the single A LDS tile, applying
// a_element_op. NOTE(review): several template/constructor arguments (original lines 625,
// 627, 631-632, 636, 641, 645) are missing from this rendering.
624 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
626 AsDataType,
628 decltype(as_grid_desc_ak0_m_ak1),
629 decltype(tie(a_block_desc_ak0_m_ak1)),
630 AElementwiseOperation,
633 ABlockTransferThreadClusterLengths_AK0_M_AK1,
634 ABlockTransferThreadClusterArrangeOrder,
635 ABlockTransferSrcAccessOrder,
637 ABlockTransferSrcVectorDim,
638 2,
639 ABlockTransferSrcScalarPerVector,
640 ABlockTransferDstScalarPerVector_AK1,
642 Sequence<true>>{as_grid_desc_ak0_m_ak1,
643 idx_as_block_begin,
644 tie(a_block_desc_ak0_m_ak1),
646 a_element_op};
647
// Every B tensor starts at the same tile origin: (K0 = 0, N = first column of this block, K1 = 0).
648 const auto idx_bs_block_begin =
649 generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
651
// Block-level copy from all B tensors (global) into the single B LDS tile, applying
// b_element_op. NOTE(review): several template/constructor arguments (original lines 653,
// 655, 659-660, 664, 669, 673) are missing from this rendering.
652 auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
654 BsDataType,
656 decltype(bs_grid_desc_bk0_n_bk1),
657 decltype(tie(b_block_desc_bk0_n_bk1)),
658 BElementwiseOperation,
661 BBlockTransferThreadClusterLengths_BK0_N_BK1,
662 BBlockTransferThreadClusterArrangeOrder,
663 BBlockTransferSrcAccessOrder,
665 BBlockTransferSrcVectorDim,
666 2,
667 BBlockTransferSrcScalarPerVector,
668 BBlockTransferDstScalarPerVector_BK1,
670 Sequence<true>>{bs_grid_desc_bk0_n_bk1,
671 idx_bs_block_begin,
672 tie(b_block_desc_bk0_n_bk1),
674 b_element_op};
675
676 // GEMM definition
677 // c_mtx += transpose(a_mtx) * b_mtx
678 // a_mtx[K0PerBlock, MPerBlock] is in LDS
679 // b_mtx[K0PerBlock, NPerBlock] is in LDS
680 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
681 // register
682 // sanity check
// KPack (passed to the blockwise GEMM) is the larger of lcm(AK1, BK1) and the selected
// MFMA's k_per_blk. NOTE(review): parts of the is_single_rate_mfma condition (original
// lines 685-686, 689-690) are missing from this rendering.
683 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
684 constexpr bool is_single_rate_mfma =
687 lcm_AK1_BK1 <= 4) ||
688 (is_same<AComputeDataType_, int8_t>::value && lcm_AK1_BK1 <= 8) ||
691 lcm_AK1_BK1 < 32))
692 ? true
693 : false;
694 static constexpr auto is_scale_mfma = false;
695 constexpr index_t KPack = math::max(lcm_AK1_BK1,
696 MfmaSelector<AComputeDataType_,
697 MPerXdl,
698 NPerXdl,
699 BComputeDataType_,
700 is_single_rate_mfma,
701 is_scale_mfma>::selected_mfma.k_per_blk);
702
// NOTE(review): the declaration line of blockwise_gemm (original line 703) and lines
// 705-706 are missing from this rendering. The page's cross-references suggest
// BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<...>() — confirm.
704 BlockSize,
707 AccDataType,
708 decltype(a_block_desc_ak0_m_ak1),
709 decltype(b_block_desc_bk0_n_bk1),
710 MPerXdl,
711 NPerXdl,
712 MXdlPerWave,
713 NXdlPerWave,
714 KPack,
715 LoopSched,
716 AComputeDataType_,
717 BComputeDataType_>();
718
719 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
720
// NOTE(review): the B LDS buffer below starts at a_block_space_size_aligned elements of
// BElementDataType past p_shared. GetSharedMemoryNumberOfByte sizes the A region as
// a_block_space_size_aligned * sizeof(AElementDataType) bytes. If the two element types
// differ in size, B either overlaps A or runs past the allocation. Confirm they are always
// the same size, or compute the B offset in bytes.
// NOTE(review): the buffer declaration lines (original 725 and 728) are missing from this
// rendering.
721 // LDS allocation for A and B: be careful of alignment
722 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
723 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
724
726 static_cast<AElementDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
727
729 static_cast<BElementDataType*>(p_shared) + a_block_space_size_aligned,
730 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
731
// Per K block, the A/B source windows advance by KPerBlock / K1 along the K0 dimension.
732 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
733 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
734
// Number of K blocks, computed from the first A descriptor as (AK0 * AK1) / KPerBlock.
735 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
736 (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
737 KPerBlock);
738
// NOTE(review): the pipeline initializer (original line 741) is missing from this rendering.
739 // gridwise GEMM pipeline
740 const auto gridwise_gemm_pipeline =
742
743 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(as_grid_desc_ak0_m_ak1,
744 a_block_desc_ak0_m_ak1,
745 a_blockwise_copy,
746 as_grid_buf,
747 a_block_buf,
748 a_block_slice_copy_step,
749 bs_grid_desc_bk0_n_bk1,
750 b_block_desc_bk0_n_bk1,
751 b_blockwise_copy,
752 bs_grid_buf,
753 b_block_buf,
754 b_block_slice_copy_step,
755 blockwise_gemm,
756 c_thread_buf,
757 num_k_block_main_loop);
758
759 // shuffle C and write out
760 {
761 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
762 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
763 "wrong!");
764
765 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
766 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
767
768 // TODO: hacky, fix it!
769 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
770 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
771
772 // TODO: hacky, fix it!
773 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
774 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
775 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
776
777 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
778 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
779 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
780 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
781 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
782 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
783 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
784 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
785
// The C-shuffle LDS buffer starts at p_shared, i.e. it reuses the memory of the A/B tiles
// from the main loop. NOTE(review): the descriptor initializer (original line 788) is
// missing from this rendering.
786 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
788
789 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
790 static_cast<CShuffleDataType*>(p_shared),
791 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
792
// Views the C-shuffle LDS tile in the same 8-d (M0, N0, M1, N1, M2, M3, M4, N2) layout as
// the accumulator, covering CShuffle{M,N}XdlPerWavePerShuffle XDL tiles per wave.
// NOTE(review): transform lines 795-797, 803-804 and 808-810 are missing from this rendering.
793 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
794 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
798 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
799 M1, // M1 = MWave
800 M2, // M2 * M3 * M4 = MPerXdl
801 M3,
802 M4)),
805 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
806 N1, // N1 = NWave
807 N2))), // N2 = NPerXdl
811
812 // calculate origin of this thread's C sub-tile within the block tile
813 // blockwise GEMM c matrix starting index
814 const auto c_thread_mtx_on_block =
815 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
816
817 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
818 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
819
// Decompose the block-local m / n origin into (M0..M4) and (N0..N2) coordinates.
// NOTE(review): adaptor construction lines 821, 823-824 and 831-834 are missing from this
// rendering.
820 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
822 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
825
826 const auto m_thread_data_on_block_idx =
827 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
828 make_multi_index(m_thread_data_on_block));
829
830 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
835
836 const auto n_thread_data_on_block_idx =
837 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
838 make_multi_index(n_thread_data_on_block));
839
// NOTE(review): lines 842, 846, 855, 858, 862 and 870 of this construction are missing
// from this rendering.
840 // shuffle: threadwise copy C from VGPR to LDS
841 auto c_thread_copy_vgpr_to_lds =
843 CShuffleDataType,
844 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
845 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
847 Sequence<CShuffleMXdlPerWavePerShuffle,
848 CShuffleNXdlPerWavePerShuffle,
849 I1,
850 I1,
851 M2,
852 I1,
853 M4,
854 I1>,
856 7,
857 1,
859 1,
860 true>{
861 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
863 0,
864 m_thread_data_on_block_idx[I1],
865 n_thread_data_on_block_idx[I1],
866 m_thread_data_on_block_idx[I2],
867 m_thread_data_on_block_idx[I3],
868 m_thread_data_on_block_idx[I4],
869 n_thread_data_on_block_idx[I2]),
871
872 // tuple of reference to C/Ds tensor descriptors
873 const auto c_ds_desc_refs = concat_tuple_of_reference(
874 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
875 generate_tie([&](auto i) -> const auto& // return type should be reference
876 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
878
879 // tuple of reference to C/Ds tensor buffers
880 const auto c_ds_buf_refs = concat_tuple_of_reference(
881 tie(c_shuffle_block_buf),
882 generate_tie([&](auto i) -> const auto& // return type should be reference
883 { return ds_grid_buf[i]; },
885
886 // tuple of starting index of C/Ds blockwise copy
// C starts at the LDS tile origin; every D starts at this block's (MBlock, NBlock) tile.
887 const auto idx_c_ds_block_begin = container_concat(
888 make_tuple(make_multi_index(0, 0, 0, 0)),
890 [&](auto) {
891 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
892 },
894
// Source 0 is the C-shuffle LDS tile and sources 1..NumDTensor are the Ds; the destination
// is E, starting at this block's tile. NOTE(review): template arguments on original lines
// 897, 899 and 917-919 are missing from this rendering.
895 // blockwise copy C/D/E between LDS and global
896 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2<
898 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
900 decltype(c_ds_desc_refs),
901 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
902 CDEElementwiseOperation,
903 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
904 // support arbitray type
905 Sequence<1,
906 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
907 1,
908 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
909 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
910 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
911 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
912 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
913 3, // index_t SrcVectorDim,
914 3, // index_t DstVectorDim,
915 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
916 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
920 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
921 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
922 {c_ds_desc_refs,
923 idx_c_ds_block_begin,
924 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
925 make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
926 cde_element_op};
927
// NOTE(review): the curve type lines (original 930-931 and 943-944) are missing from this
// rendering.
928 // space filling curve for threadwise C in VGPR before shuffle
929 constexpr auto sfc_c_vgpr =
932 Sequence<CShuffleMXdlPerWavePerShuffle,
933 CShuffleNXdlPerWavePerShuffle,
934 1,
935 1,
936 M2,
937 1,
938 M4,
939 1>>{};
940
941 // space filling curve for shuffled blockwise C/D/E
942 constexpr auto sfc_cde_block =
945 Sequence<1,
946 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
947 1,
948 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
949
950 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
951
952 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
953
// Each access stages one C slice from registers into LDS. It then reads that slice back
// together with the matching D slices, applies cde_element_op and stores the result to E.
// NOTE(review): the synchronization statements after the two "make sure it's safe"
// comments (original lines 956 and 966) are missing from this rendering.
954 static_for<0, num_access, 1>{}([&](auto access_id) {
955 // make sure it's safe to write to LDS
957
958 // each thread write its data from VGPR to LDS
959 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
960 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
961 c_thread_buf,
962 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
963 c_shuffle_block_buf);
964
965 // make sure it's safe to read from LDS
967
968 // each block copy its data from LDS to global
969 cde_block_copy_lds_and_global.Run(
970 c_ds_desc_refs,
971 c_ds_buf_refs,
972 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
973 tie(e_grid_buf));
974
// Advance all windows except source 0 (the C-shuffle LDS tile, which is rewritten in
// place on every access). Nothing moves after the last access.
975 if constexpr(access_id < num_access - 1)
976 {
977 constexpr auto cde_lds_and_global_step =
978 sfc_cde_block.GetForwardStep(access_id);
979
980 // move on Ds
981 static_for<0, NumDTensor, 1>{}([&](auto i) {
982 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
983 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
984 });
985
986 // move on E
987 cde_block_copy_lds_and_global.MoveDstSliceWindow(
988 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
989 I0,
990 cde_lds_and_global_step);
991 }
992 });
993 }
994 }
995
// Convenience entry point. It builds every grid descriptor from problem sizes, strides and
// layouts, then forwards to the descriptor-based Run overload.
//  - each A is described as [M, K] with StrideAs[j] and the j-th AsLayout;
//  - each B as [N, K] with StrideBs[j] and the j-th BsLayout;
//  - each D as [M, N] with StrideDs[j] and the j-th DsLayout; E as [M, N] with StrideE;
//  - all of them are padded according to GemmSpec by the Make*GridDescriptor_* helpers.
// p_e_grid_ is reinterpreted as EDataType*.
// NOTE(review): the descriptor tuple type definitions (original lines 1027, 1029, 1031)
// and the two Make*_MBlock_MPerBlock_NBlock_NPerBlock calls (lines 1065, 1068) are missing
// from this rendering.
996 template <bool HasMainKBlockLoop,
997 GemmSpecialization GemmSpec,
998 typename AsLayout,
999 typename BsLayout,
1000 typename DsLayout,
1001 typename ELayout,
1002 typename Block2ETileMap>
1003 __device__ static void Run(AsGridPointer p_as_grid,
1004 BsGridPointer p_bs_grid,
1005 DsGridPointer p_ds_grid,
1006 void* __restrict__ p_e_grid_,
1007 void* __restrict__ p_shared,
1008 const AElementwiseOperation& a_element_op,
1009 const BElementwiseOperation& b_element_op,
1010 const CDEElementwiseOperation& cde_element_op,
1011 const index_t M,
1012 const index_t N,
1013 const index_t K,
1014#ifdef CK_CODE_GEN_RTC
1015 const ck::Array<index_t, NumATensor> StrideAs,
1016 const ck::Array<index_t, NumBTensor> StrideBs,
1017 const ck::Array<index_t, NumDTensor> StrideDs,
1018#else
1019 const std::array<index_t, NumATensor> StrideAs,
1020 const std::array<index_t, NumBTensor> StrideBs,
1021 const std::array<index_t, NumDTensor> StrideDs,
1022#endif
1023 const index_t StrideE,
1024 const Block2ETileMap& block_2_etile_map)
1025 {
1026 using AsGridDesc_M_K =
1028 using BsGridDesc_N_K =
1030 using DsGridDesc_M_N =
1032
1033 const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
1034
// The descriptor tuples are default-constructed, then each element is assigned below.
1035 AsGridDesc_M_K as_grid_desc_m_k;
1036 BsGridDesc_N_K bs_grid_desc_n_k;
1037 DsGridDesc_M_N ds_grid_desc_m_n;
1038
// All A tensors share the problem extents M x K; only stride and layout differ.
1039 static_for<0, NumATensor, 1>{}([&](auto j) {
1040 using ALayout = remove_cvref_t<tuple_element_t<j.value, AsLayout>>;
1041
1042 as_grid_desc_m_k(j) = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideAs[j]);
1043 });
1044
// All B tensors share the problem extents N x K.
1045 static_for<0, NumBTensor, 1>{}([&](auto j) {
1046 using BLayout = remove_cvref_t<tuple_element_t<j.value, BsLayout>>;
1047
1048 bs_grid_desc_n_k(j) = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(N, K, StrideBs[j]);
1049 });
1050
// All D tensors share the problem extents M x N and reuse the E descriptor builder.
1051 static_for<0, NumDTensor, 1>{}([&](auto j) {
1052 using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
1053
1054 ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
1055 });
1056
1057 const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
1058
1059 // tensor descriptors for block/thread-wise copy
1060 const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
1061
1062 const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
1063
1064 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1066
1067 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1069
1070 Run<HasMainKBlockLoop>(p_as_grid,
1071 p_bs_grid,
1072 p_ds_grid,
1073 p_e_grid,
1074 p_shared,
1075 a_element_op,
1076 b_element_op,
1077 cde_element_op,
1078 as_grid_desc_ak0_m_ak1,
1079 bs_grid_desc_bk0_n_bk1,
1080 ds_grid_desc_mblock_mperblock_nblock_nperblock,
1081 e_grid_desc_mblock_mperblock_nblock_nperblock,
1082 block_2_etile_map);
1083 }
1084};
1085
1086} // namespace ck
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition utility/array.hpp:14
Definition block_to_ctile_map.hpp:261
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:77
static __device__ void Run(AsGridPointer p_as_grid, BsGridPointer p_bs_grid, DsGridPointer p_ds_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:552
static __device__ void Run(AsGridPointer p_as_grid, BsGridPointer p_bs_grid, DsGridPointer p_ds_grid, void *__restrict__ p_e_grid_, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op, const index_t M, const index_t N, const index_t K, const std::array< index_t, NumATensor > StrideAs, const std::array< index_t, NumBTensor > StrideBs, const std::array< index_t, NumDTensor > StrideDs, const index_t StrideE, const Block2ETileMap &block_2_etile_map)
Definition gridwise_gemm_multiple_abd_xdl_cshuffle.hpp:1003
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition thread_group_tensor_slice_transfer_v7r2.hpp:47
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
__host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition matrix_padder.hpp:155
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:163
__host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition matrix_padder.hpp:147
Definition matrix_padder.hpp:180
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340