device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp Source File

device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp Source File

Composable Kernel: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp Source File
device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <functional>
7#include <iostream>
8#include <iterator>
9#include <numeric>
10#include <sstream>
11
26
27#ifdef CK_EXPERIMENTAL_BUILDER
28#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
29#endif
30
31namespace ck {
32namespace tensor_operation {
33namespace device {
34
35namespace {
36
37/*
38 * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
39 *
40 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
41 * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
42 * strided batched, but we can easily extend to other layouts. The returned offset can be either \p
43 * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
44 * limitations.
45 *
46 * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
47 * returns the 2D index of the tile that it computes. \see
48 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
49 *
50 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
51 * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
52 * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
53 * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
54 * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
55 * pointer offset into \p ComputePtrOffsetOfStridedBatch.
56 *
57 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
58 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
59 * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
60 *
61 */
62template <typename GridwiseGemm,
63 typename ABDataType,
64 typename DsPointer,
65 typename EDataType,
66 typename AElementwiseOperation,
67 typename BElementwiseOperation,
68 typename CDEElementwiseOperation,
69 typename AGridDesc_K0_M0_M1_K1,
70 typename BGridDesc_K0_N0_N1_K1,
71 typename DsGridDesc_M0_M10_M11_N0_N10_N11,
72 typename CGridDesc_M0_M10_M11_N0_N10_N11,
73 typename Block2CTileMap,
74 typename ComputePtrOffsetOfBatch,
75 bool HasMainKBlockLoop,
76 bool HasDoubleTailKBlockLoop>
77__global__ void
78#if CK_USE_LAUNCH_BOUNDS
80#endif
81 kernel_grouped_conv_fwd_dl_multiple_d(
82 const ABDataType* __restrict__ p_a_grid,
83 const ABDataType* __restrict__ p_b_grid,
84 DsPointer p_ds_grid,
85 EDataType* __restrict__ p_e_grid,
86 const AElementwiseOperation a_element_op,
87 const BElementwiseOperation b_element_op,
88 const CDEElementwiseOperation cde_element_op,
89 const index_t batch_count,
90 const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
91 const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
92 const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
93 const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
94 const Block2CTileMap block_2_ctile_map,
95 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
96{
97#if(defined(__gfx906__) || defined(__gfx103__) || defined(__gfx90a__) || defined(__gfx908__) || \
98 defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__))
99 // offset base pointer for each work-group
100 const index_t num_blocks_per_batch =
101 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
102 const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
103
104 const long_index_t a_batch_offset = amd_wave_read_first_lane(
105 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
106 const long_index_t b_batch_offset = amd_wave_read_first_lane(
107 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
108 const long_index_t c_batch_offset = amd_wave_read_first_lane(
109 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
110
111 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
112
113 constexpr index_t shared_block_size =
114 GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
115
116 __shared__ ABDataType p_shared[shared_block_size];
117
118 DsPointer p_ds_grid_grp;
119
120 static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size();
121
122 static_for<0, NumDTensor, 1>{}(
123 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
124
125 GridwiseGemm::Run(p_a_grid + a_batch_offset,
126 p_b_grid + b_batch_offset,
127 p_ds_grid_grp,
128 p_e_grid + c_batch_offset,
129 p_shared,
130 a_element_op,
131 b_element_op,
132 cde_element_op,
133 a_grid_desc_k0_m0_m1_k1,
134 b_grid_desc_k0_n0_n1_k1,
135 ds_grid_desc_m0_m10_m11_n0_n10_n11,
136 e_grid_desc_m0_m10_m11_n0_n10_n11,
137 block_2_ctile_map,
138 integral_constant<bool, HasMainKBlockLoop>{},
139 integral_constant<bool, HasDoubleTailKBlockLoop>{});
140#else
141 ignore = p_a_grid;
142 ignore = p_b_grid;
143 ignore = p_ds_grid;
144 ignore = p_e_grid;
145 ignore = a_element_op;
146 ignore = b_element_op;
147 ignore = cde_element_op;
148 ignore = batch_count;
149 ignore = a_grid_desc_k0_m0_m1_k1;
150 ignore = b_grid_desc_k0_n0_n1_k1;
151 ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
152 ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
153 ignore = compute_ptr_offset_of_batch;
154 ignore = block_2_ctile_map;
155
156 compute_ptr_offset_of_batch.GetAPtrOffset(0);
157 compute_ptr_offset_of_batch.GetBPtrOffset(0);
158 compute_ptr_offset_of_batch.GetEPtrOffset(0);
159#endif
160}
161} // namespace
162
163//
164// @brief Device Convolution operation.
165//
166// Supports:
167// @li Forward convolution with up to 3 spatial dimensions
168// @li Input tensor in GNWC data format
169// @li Weight tensor in GKXC data format
170// @li Output tensor in GNWK data format
171//
172// 1D:
173// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
174// 2D:
175// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
176// 3D:
177// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
178//
179template <index_t NDimSpatial,
180 typename ADataType,
181 typename BDataType,
182 typename DsDataType,
183 typename EDataType,
184 typename AccDataType,
185 typename ALayout,
186 typename BLayout,
187 typename DsLayout,
188 typename ELayout,
189 typename AElementwiseOperation,
190 typename BElementwiseOperation,
191 typename CDEElementwiseOperation,
192 ConvolutionForwardSpecialization ConvForwardSpecialization,
193 GemmSpecialization GemmSpec,
194 index_t BlockSize,
195 index_t MPerBlock,
196 index_t NPerBlock,
197 index_t K0PerBlock,
198 index_t K1,
199 index_t M1PerThread,
200 index_t N1PerThread,
201 index_t KPerThread,
202 typename M1N1ThreadClusterM1Xs,
203 typename M1N1ThreadClusterN1Xs,
204 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
205 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
206 typename ABlockTransferThreadClusterArrangeOrder,
207 typename ABlockTransferSrcAccessOrder,
208 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
209 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
210 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
211 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
212 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
213 typename BBlockTransferThreadClusterArrangeOrder,
214 typename BBlockTransferSrcAccessOrder,
215 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
216 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
217 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
218 typename CThreadTransferSrcDstAccessOrder,
219 index_t CThreadTransferSrcDstVectorDim,
220 index_t CThreadTransferDstScalarPerVector>
222 : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
223 ALayout,
224 BLayout,
225 DsLayout,
226 ELayout,
227 ADataType,
228 BDataType,
229 DsDataType,
230 EDataType,
231 AElementwiseOperation,
232 BElementwiseOperation,
233 CDEElementwiseOperation>
234{
236
237 static constexpr index_t NumDTensor = DsDataType::Size();
238
239 static constexpr auto I0 = Number<0>{};
240 static constexpr auto I1 = Number<1>{};
241 static constexpr auto I2 = Number<2>{};
242 static constexpr auto I3 = Number<3>{};
243
245
246 static constexpr auto matrix_padder =
247 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
248
249 template <typename ALay>
250 static auto
252 {
253 const auto in_gemmmraw_gemmkraw_desc =
254 conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
255
256 const auto in_gemmm_gemmk_desc =
257 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
258
259 const auto M = in_gemmm_gemmk_desc.GetLength(I0);
260 const auto K = in_gemmm_gemmk_desc.GetLength(I1);
261 const auto AK0 = K / K1;
262
264 in_gemmm_gemmk_desc,
268 }
269
270 template <typename BLay>
271 static auto
273 {
274 const auto wei_gemmnraw_gemmkraw_desc =
275 conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
276
277 const auto wei_gemmn_gemmk_desc =
278 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
279
280 const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
281 const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
282
283 const auto BK0 = K / K1;
284
286 wei_gemmn_gemmk_desc,
290 }
291
292 template <typename ELay>
293 static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
294 {
295 const auto out_gemmmraw_gemmnraw_desc =
296 conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
297
298 const auto out_gemmm_gemmn_desc =
299 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
300
301 return out_gemmm_gemmn_desc;
302 }
303
304 static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
305 {
306 return generate_tuple(
307 [&](auto i) {
308 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
309
310 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
311 },
313 }
314
315 // desc for problem definition
325
326 // GridwiseGemm
329 ADataType,
330 AccDataType,
331 DsDataType,
332 EDataType,
333 AElementwiseOperation,
334 BElementwiseOperation,
335 CDEElementwiseOperation,
340 MPerBlock,
341 NPerBlock,
342 K0PerBlock,
343 K1,
344 M1PerThread,
345 N1PerThread,
346 KPerThread,
347 M1N1ThreadClusterM1Xs,
348 M1N1ThreadClusterN1Xs,
349 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
350 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
351 ABlockTransferThreadClusterArrangeOrder,
352 ABlockTransferSrcAccessOrder,
353 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
354 ABlockTransferSrcVectorTensorContiguousDimOrder,
355 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
356 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
357 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
358 BBlockTransferThreadClusterArrangeOrder,
359 BBlockTransferSrcAccessOrder,
360 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
361 BBlockTransferSrcVectorTensorContiguousDimOrder,
362 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
363 CThreadTransferSrcDstAccessOrder,
364 CThreadTransferSrcDstVectorDim,
365 CThreadTransferDstScalarPerVector>;
366
377
378 // Argument
379 struct Argument : public BaseArgument
380 {
381 Argument(const void* p_a,
382 const void* p_b,
383 const std::array<const void*, NumDTensor>& p_ds,
384 void* p_e,
385 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
386 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
387 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
388 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
389 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
390 ds_g_n_k_wos_lengths,
391 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
392 ds_g_n_k_wos_strides,
393 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
394 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
395 const std::array<index_t, NDimSpatial>& conv_filter_strides,
396 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
397 const std::array<index_t, NDimSpatial>& input_left_pads,
398 const std::array<index_t, NDimSpatial>& input_right_pads,
399 const AElementwiseOperation& a_element_op,
400 const BElementwiseOperation& b_element_op,
401 const CDEElementwiseOperation& cde_element_op)
402 : p_a_grid_{static_cast<const ADataType*>(p_a)},
403 p_b_grid_{static_cast<const BDataType*>(p_b)},
404 p_ds_grid_{},
405 p_e_grid_{static_cast<EDataType*>(p_e)},
406 num_group_{a_g_n_c_wis_lengths[0]},
407 conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
408 a_g_n_c_wis_strides,
409 b_g_k_c_xs_lengths,
410 b_g_k_c_xs_strides,
411 e_g_n_k_wos_lengths,
412 e_g_n_k_wos_strides,
413 conv_filter_strides,
414 conv_filter_dilations,
415 input_left_pads,
416 input_right_pads},
429 a_element_op_{a_element_op},
430 b_element_op_{b_element_op},
431 cde_element_op_{cde_element_op},
432 a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
433 a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
434 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
435 b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
436 e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
437 e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
438 conv_filter_strides_{conv_filter_strides},
439 conv_filter_dilations_{conv_filter_dilations},
440 input_left_pads_{input_left_pads},
441 input_right_pads_{input_right_pads}
442 {
443 // A/B/E Batch Stride
444 compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
445 compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
446 compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
447
448 // populate pointer, batch stride, desc for Ds
449 static_for<0, NumDTensor, 1>{}([&](auto i) {
450 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
451 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
452
453 ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
454 a_g_n_c_wis_strides,
455 b_g_k_c_xs_lengths,
456 b_g_k_c_xs_strides,
457 ds_g_n_k_wos_lengths[i],
458 ds_g_n_k_wos_strides[i],
459 conv_filter_strides,
460 conv_filter_dilations,
461 input_left_pads,
462 input_right_pads};
463
464 // D pointer
465 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
466
467 // D batch stride
468 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
469
470 // D desc
472 DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
473 });
474
475 // populate desc for Ds/E
478 {
479
486
489
491 }
492 }
493
494 void Print() const
495 {
496 std::cout << "A[K0, M, K1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
497 std::cout << "B[K0, N, K1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
498 std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
499 std::cout << "num_group: " << num_group_ << std::endl;
500
501 std::cout << "A[k0, m0, m1, k1]: " << a_grid_desc_k0_m0_m1_k1_ << std::endl;
502 std::cout << "B[k0, n0, n1, k1]: " << b_grid_desc_k0_n0_n1_k1_ << std::endl;
503 std::cout << "A[m0, m10, m11, n0, n10, n11]: " << e_grid_desc_m0_m10_m11_n0_n10_n11_
504 << std::endl;
505 }
506
507 // private:
508 // pointers
509 const ADataType* p_a_grid_;
510 const BDataType* p_b_grid_;
512 EDataType* p_e_grid_;
513
514 // tensor descriptors for problem definiton
516
518
523
524 // tensor descriptors for block/thread-wise copy
529
530 // block-to-e-tile map
532
533 // for computing batch offset
// Per-group pointer-offset calculator handed to the kernel. The constructor body sets its
// BatchStrideA_/B_/E_ and BatchStrideDs_ from entry [0] of the corresponding stride arrays.
534 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
535
536 // element-wise op
// Copies of the caller's element-wise operators, passed by value to the kernel launch.
537 AElementwiseOperation a_element_op_;
538 BElementwiseOperation b_element_op_;
539 CDEElementwiseOperation cde_element_op_;
540
541 // for checking IsSupportedArgument()
// Unmodified copies of the problem description. IsSupportedArgument() reads the A/B/E lengths,
// the filter strides and the input pads from here. Invoker::Run also passes
// a_g_n_c_wis_lengths_[0] to the kernel as the group count.
// NOTE(review): ds_g_n_k_wos_lengths_ and ds_g_n_k_wos_strides_ do not appear in the visible
// constructor member-initializer list, so they may be left uninitialized. Confirm before
// reading them.
542 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
543 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
544 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
545 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
546 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
547 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
548 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
549 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
550 std::array<index_t, NDimSpatial> conv_filter_strides_;
551 std::array<index_t, NDimSpatial> conv_filter_dilations_;
552 std::array<index_t, NDimSpatial> input_left_pads_;
553 std::array<index_t, NDimSpatial> input_right_pads_;
554 };
555
556 // Invoker
557 struct Invoker : public BaseInvoker
558 {
560
561 float Run(const Argument& arg, const StreamConfig& stream_config)
562 {
563 if(stream_config.log_level_ > 0)
564 {
565 arg.Print();
566 }
567
570 {
571 throw std::runtime_error(
572 "wrong! DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK has invalid setting");
573 }
574
575 const index_t grid_size =
577 arg.e_grid_desc_m_n_.GetLength(I1)) *
578 arg.num_group_;
579
580 auto launch_kernel = [&](auto has_main_k_block_loop,
581 auto has_double_tail_k_block_loop) {
582 constexpr bool has_main_loop = has_main_k_block_loop.value;
583 constexpr bool has_double_loop = has_double_tail_k_block_loop;
584
585 const auto kernel = kernel_grouped_conv_fwd_dl_multiple_d<
587 ADataType, // TODO: distiguish A/B datatype
589 EDataType,
590 AElementwiseOperation,
591 BElementwiseOperation,
592 CDEElementwiseOperation,
598 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
599 has_main_loop,
600 has_double_loop>;
601
602 return launch_and_time_kernel(stream_config,
603 kernel,
604 dim3(grid_size),
605 dim3(BlockSize),
606 0,
607 arg.p_a_grid_,
608 arg.p_b_grid_,
609 arg.p_ds_grid_,
610 arg.p_e_grid_,
611 arg.a_element_op_,
612 arg.b_element_op_,
613 arg.cde_element_op_,
614 arg.a_g_n_c_wis_lengths_[0], // Group count
621 };
622
623 const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
624 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
625 const bool has_double_tail_k_block_loop =
627
628 if(has_main_k_block_loop && has_double_tail_k_block_loop)
629 {
630 return launch_kernel(integral_constant<bool, true>{},
632 }
633 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
634 {
635 return launch_kernel(integral_constant<bool, true>{},
637 }
638 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
639 {
640 return launch_kernel(integral_constant<bool, false>{},
642 }
643 else
644 {
645 return launch_kernel(integral_constant<bool, false>{},
647 }
648 return 0;
649 }
650
651 float Run(const BaseArgument* p_arg,
652 const StreamConfig& stream_config = StreamConfig{}) override
653 {
654 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
655 }
656 };
657
658 static bool IsSupportedArgument(const Argument& arg)
659 {
660 namespace ctc = tensor_layout::convolution;
661
662 // check device
663 if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
665 {
666 return false;
667 }
668
669 // check ConvolutionForwardSpecialization
670 if constexpr(ConvForwardSpecialization ==
672 {
673 // check if it's 1x1, stride=1 conv
674 for(index_t i = 0; i < NDimSpatial; ++i)
675 {
676 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
677 const index_t ConvStride = arg.conv_filter_strides_[i];
678 const index_t LeftPad = arg.input_left_pads_[i];
679 const index_t RightPad = arg.input_right_pads_[i];
680
681 if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
682 {
683 std::cout << "Filter1x1Stride1Pad0 check: XY_index = " << i << " X = " << X
684 << " ConvStride = " << ConvStride << " LeftPad = " << LeftPad
685 << " RightPad = " << RightPad << std::endl;
686 return false;
687 }
688 }
689 }
690 else if constexpr(ConvForwardSpecialization ==
692 {
693 // check if it's 1x1 conv
694 for(index_t i = 0; i < NDimSpatial; ++i)
695 {
696 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
697 const index_t LeftPad = arg.input_left_pads_[i];
698 const index_t RightPad = arg.input_right_pads_[i];
699
700 if(!(X == 1 && LeftPad == 0 && RightPad == 0))
701 {
702 std::cout << "Filter1x1Stride1Pad0 check: XY_index = " << i << " X = " << X
703 << " LeftPad = " << LeftPad << " RightPad = " << RightPad
704 << std::endl;
705 return false;
706 }
707 }
708 }
709
710 // check vector access of A
711 // FIXME: layout
717 {
718 auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
719 if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
720 {
721 return false;
722 }
723 if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
724 {
725 return false;
726 }
727
728 const index_t C = arg.a_g_n_c_wis_lengths_[2];
729
730 if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
731 {
732 return false;
733 }
734 }
735 else
736 {
737 return false;
738 }
739
740 // check vector access of B
741 // FIXME: layout
747
748 {
749 auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
750 if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
751 {
752 return false;
753 }
754 if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
755 {
756 return false;
757 }
758
759 const index_t C = arg.b_g_k_c_xs_lengths_[2];
760
761 if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
762 {
763 return false;
764 }
765 }
766 else
767 {
768 return false;
769 }
770
771 // check vector access of E
777 {
778 const index_t K = arg.e_g_n_k_wos_lengths_[2];
779
780 if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5))
781 {
782 return false;
783 }
784 }
785 else
786 {
787 return false;
788 }
789
790 // check Gridwise GEMM
793 }
794
795 bool IsSupportedArgument(const BaseArgument* p_arg) override
796 {
797 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
798 }
799
800 static auto MakeArgument(
801 const void* p_a,
802 const void* p_b,
803 const std::array<const void*, NumDTensor>& p_ds,
804 void* p_e,
805 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
806 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
807 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
808 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
809 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
810 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
811 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
812 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
813 const std::array<index_t, NDimSpatial>& conv_filter_strides,
814 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
815 const std::array<index_t, NDimSpatial>& input_left_pads,
816 const std::array<index_t, NDimSpatial>& input_right_pads,
817 const AElementwiseOperation& a_element_op,
818 const BElementwiseOperation& b_element_op,
819 const CDEElementwiseOperation& cde_element_op)
820 {
821 return Argument{p_a,
822 p_b,
823 p_ds,
824 p_e,
825 a_g_n_c_wis_lengths,
826 a_g_n_c_wis_strides,
827 b_g_k_c_xs_lengths,
828 b_g_k_c_xs_strides,
829 ds_g_n_k_wos_lengths,
830 ds_g_n_k_wos_strides,
831 e_g_n_k_wos_lengths,
832 e_g_n_k_wos_strides,
833 conv_filter_strides,
834 conv_filter_dilations,
835 input_left_pads,
836 input_right_pads,
837 a_element_op,
838 b_element_op,
839 cde_element_op};
840 }
841
842 static auto
843 MakeArgument(const void* p_a,
844 const void* p_b,
845 const std::array<const void*, NumDTensor>& p_ds,
846 void* p_e,
847 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
848 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
849 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
850 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
851 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
852 ds_g_n_k_wos_lengths,
853 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
854 ds_g_n_k_wos_strides,
855 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
856 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
857 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
858 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
859 const std::array<long_index_t, NDimSpatial>& input_left_pads,
860 const std::array<long_index_t, NDimSpatial>& input_right_pads,
861 const AElementwiseOperation& a_element_op,
862 const BElementwiseOperation& b_element_op,
863 const CDEElementwiseOperation& cde_element_op)
864 {
865 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
866 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
867 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
868 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
869 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
870 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
871 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
872 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
873 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
874 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
875 std::array<index_t, NDimSpatial> input_left_pads_i32;
876 std::array<index_t, NDimSpatial> input_right_pads_i32;
877
878 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
879 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
880 array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
881 array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
882 for(index_t d = 0; d < NumDTensor; d++)
883 {
884 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
885 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
886 }
887 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
888 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
889 array_convert(conv_filter_strides_i32, conv_filter_strides);
890 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
891 array_convert(input_left_pads_i32, input_left_pads);
892 array_convert(input_right_pads_i32, input_right_pads);
893
894 return Argument{p_a,
895 p_b,
896 p_ds,
897 p_e,
898 a_g_n_c_wis_lengths_i32,
899 a_g_n_c_wis_strides_i32,
900 b_g_k_c_xs_lengths_i32,
901 b_g_k_c_xs_strides_i32,
902 ds_g_n_k_wos_lengths_i32,
903 ds_g_n_k_wos_strides_i32,
904 e_g_n_k_wos_lengths_i32,
905 e_g_n_k_wos_strides_i32,
906 conv_filter_strides_i32,
907 conv_filter_dilations_i32,
908 input_left_pads_i32,
909 input_right_pads_i32,
910 a_element_op,
911 b_element_op,
912 cde_element_op};
913 }
914
915 static auto MakeInvoker() { return Invoker{}; }
916
917 std::unique_ptr<BaseArgument> MakeArgumentPointer(
918 const void* p_a,
919 const void* p_b,
920 const std::array<const void*, NumDTensor>& p_ds,
921 void* p_e,
922 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
923 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
924 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
925 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
926 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
927 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
928 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
929 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
930 const std::array<index_t, NDimSpatial>& conv_filter_strides,
931 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
932 const std::array<index_t, NDimSpatial>& input_left_pads,
933 const std::array<index_t, NDimSpatial>& input_right_pads,
934 const AElementwiseOperation& a_element_op,
935 const BElementwiseOperation& b_element_op,
936 const CDEElementwiseOperation& cde_element_op) override
937 {
938 return std::make_unique<Argument>(p_a,
939 p_b,
940 p_ds,
941 p_e,
942 a_g_n_c_wis_lengths,
943 a_g_n_c_wis_strides,
944 b_g_k_c_xs_lengths,
945 b_g_k_c_xs_strides,
946 ds_g_n_k_wos_lengths,
947 ds_g_n_k_wos_strides,
948 e_g_n_k_wos_lengths,
949 e_g_n_k_wos_strides,
950 conv_filter_strides,
951 conv_filter_dilations,
952 input_left_pads,
953 input_right_pads,
954 a_element_op,
955 b_element_op,
956 cde_element_op);
957 }
958
959 std::unique_ptr<BaseArgument>
960 MakeArgumentPointer(const void* p_a,
961 const void* p_b,
962 const std::array<const void*, NumDTensor>& p_ds,
963 void* p_e,
964 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
965 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
966 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
967 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
968 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
969 ds_g_n_k_wos_lengths,
970 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
971 ds_g_n_k_wos_strides,
972 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
973 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
974 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
975 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
976 const std::array<long_index_t, NDimSpatial>& input_left_pads,
977 const std::array<long_index_t, NDimSpatial>& input_right_pads,
978 const AElementwiseOperation& a_element_op,
979 const BElementwiseOperation& b_element_op,
980 const CDEElementwiseOperation& cde_element_op) override
981 {
982 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
983 std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
984 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
985 std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
986 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
987 std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
988 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
989 std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
990 std::array<index_t, NDimSpatial> conv_filter_strides_i32;
991 std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
992 std::array<index_t, NDimSpatial> input_left_pads_i32;
993 std::array<index_t, NDimSpatial> input_right_pads_i32;
994
995 array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
996 array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
997 array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
998 array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
999 for(index_t d = 0; d < NumDTensor; d++)
1000 {
1001 array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
1002 array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
1003 }
1004 array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
1005 array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
1006 array_convert(conv_filter_strides_i32, conv_filter_strides);
1007 array_convert(conv_filter_dilations_i32, conv_filter_dilations);
1008 array_convert(input_left_pads_i32, input_left_pads);
1009 array_convert(input_right_pads_i32, input_right_pads);
1010
1011 return std::make_unique<Argument>(p_a,
1012 p_b,
1013 p_ds,
1014 p_e,
1015 a_g_n_c_wis_lengths_i32,
1016 a_g_n_c_wis_strides_i32,
1017 b_g_k_c_xs_lengths_i32,
1018 b_g_k_c_xs_strides_i32,
1019 ds_g_n_k_wos_lengths_i32,
1020 ds_g_n_k_wos_strides_i32,
1021 e_g_n_k_wos_lengths_i32,
1022 e_g_n_k_wos_strides_i32,
1023 conv_filter_strides_i32,
1024 conv_filter_dilations_i32,
1025 input_left_pads_i32,
1026 input_right_pads_i32,
1027 a_element_op,
1028 b_element_op,
1029 cde_element_op);
1030 }
1031
1032 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1033 {
1034 return std::make_unique<Invoker>(Invoker{});
1035 }
1036
1037 std::string GetTypeString() const override
1038 {
1039 auto str = std::stringstream();
1040
1041 // clang-format off
1042 str << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
1043 << "<"
1044 << BlockSize << ", "
1045 << MPerBlock << ", "
1046 << NPerBlock << ", "
1047 << K0PerBlock << ", "
1048 << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
1049 << K1
1050 << ">";
1051 // clang-format on
1052
1053 return str.str();
1054 }
1055
#ifdef CK_EXPERIMENTAL_BUILDER
// Full instance description produced by the experimental ck_tile reflection builder.
// The build fails with a pointer to the traits header if no instance_traits specialization
// matches this op's template arguments.
std::string GetInstanceString() const override
{
    static_assert(ck_tile::reflect::HasInstanceTraits<DeviceOp>,
                  "Specialization of instance_traits not found. Please check that a specialization "
                  "exists in file ck_tile/builder/reflect/"
                  "instance_traits_device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp for the "
                  "given template parameters.");

    return ck_tile::reflect::instance_string<DeviceOp>();
}
#endif
1068};
1069
1070} // namespace device
1071} // namespace tensor_operation
1072} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
bool is_xdl_supported()
Definition host_utility/device_prop.hpp:68
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
bool is_gfx103_supported()
Definition host_utility/device_prop.hpp:120
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
Definition ck/stream_config.hpp:10
int log_level_
Definition ck/stream_config.hpp:13
Definition gridwise_gemm_dl_multiple_d.hpp:60
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/sequence.hpp:43
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
Definition device_base.hpp:197
virtual std::string GetInstanceString() const
Definition device_base.hpp:230
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:380
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:545
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:542
void Print() const
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:494
const ADataType * p_a_grid_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:509
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:517
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:550
EDataType * p_e_grid_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:512
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:525
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:543
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:549
GridwiseGemm::DsGridPointer p_ds_grid_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:511
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:551
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_batch_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:534
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:526
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:519
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:547
const BDataType * p_b_grid_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:510
EGridDesc_M_N e_grid_desc_m_n_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:522
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:546
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:538
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:537
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:527
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:381
DefaultBlock2CTileMap block_2_ctile_map_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:531
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:521
index_t num_group_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:515
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:548
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:544
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:520
std::array< index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:553
CDEElementwiseOperation cde_element_op_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:539
CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:528
std::array< index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:552
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:558
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:651
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:559
float Run(const Argument &arg, const StreamConfig &stream_config)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:561
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:234
static constexpr auto I1
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:240
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:323
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:321
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})) DsGridDesc_M0_M10_M11_N0_N10_N11
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:371
static auto MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:272
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization > ConvToGemmFwdTransformer
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:244
static constexpr index_t NumDTensor
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:237
static constexpr auto I3
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:242
static constexpr auto I2
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:241
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{})) BGridDesc_K0_N0_N1_K1
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:369
GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:327
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK DeviceOp
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:235
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})) DefaultBlock2CTileMap
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:375
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{})) AGridDesc_K0_M0_M1_K1
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:367
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:1037
static auto MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:251
static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:246
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:1032
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:658
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:373
static constexpr ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:316
remove_cvref_t< decltype(MakeAGridDescriptor_AK0_M_AK1< ALayout >( dummy_conv_to_gemm_transformer))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:317
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:960
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:917
static auto MakeInvoker()
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:915
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:800
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:293
remove_cvref_t< decltype(MakeBGridDescriptor_BK0_N_BK1< BLayout >( dummy_conv_to_gemm_transformer))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:319
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:843
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:304
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:795
static constexpr auto I0
Definition device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:239
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
Definition matrix_padder.hpp:180