device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp Source File

device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp Source File#

Composable Kernel: device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp Source File
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
25template <typename InDataType,
26 typename WeiDataType,
27 typename OutDataType,
28 typename AccDataType,
29 typename InElementwiseOperation,
30 typename WeiElementwiseOperation,
31 typename OutElementwiseOperation,
32 ConvolutionForwardSpecialization ConvForwardSpecialization,
33 ck::index_t BlockSize,
34 ck::index_t MPerBlock,
35 ck::index_t NPerBlock,
36 ck::index_t K0PerBlock,
37 ck::index_t K1,
38 ck::index_t MPerXDL,
39 ck::index_t NPerXDL,
40 ck::index_t MXdlPerWave,
41 ck::index_t NXdlPerWave,
42 typename ABlockTransferThreadClusterLengths_K0_M_K1,
43 typename ABlockTransferThreadClusterArrangeOrder,
44 typename ABlockTransferSrcAccessOrder,
45 ck::index_t ABlockTransferSrcVectorDim,
46 ck::index_t ABlockTransferSrcScalarPerVector,
47 ck::index_t ABlockTransferDstScalarPerVector_K1,
48 bool ABlockLdsAddExtraM,
49 typename BBlockTransferThreadClusterLengths_K0_N_K1,
50 typename BBlockTransferThreadClusterArrangeOrder,
51 typename BBlockTransferSrcAccessOrder,
52 ck::index_t BBlockTransferSrcVectorDim,
53 ck::index_t BBlockTransferSrcScalarPerVector,
54 ck::index_t BBlockTransferDstScalarPerVector_K1,
55 bool BBlockLdsAddExtraN,
56 ck::index_t CThreadTransferSrcDstVectorDim,
57 ck::index_t CThreadTransferDstScalarPerVector>
59 : public DeviceConvFwd<2,
60 ck::tensor_layout::convolution::NHWC,
61 ck::tensor_layout::convolution::KYXC,
62 ck::tensor_layout::convolution::NHWK,
63 InDataType,
64 WeiDataType,
65 OutDataType,
66 InElementwiseOperation,
67 WeiElementwiseOperation,
68 OutElementwiseOperation>
69{
71
73 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
74 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
75
76 using ADataType = InDataType;
77 using BDataType = WeiDataType;
78 using CDataType = OutDataType;
79
80 // TODO make A/B datatype different
81 using ABDataType = InDataType;
82
83 static constexpr index_t NDimSpatial = 2;
84
85 static constexpr auto I0 = Number<0>{};
86 static constexpr auto I1 = Number<1>{};
87 static constexpr auto I2 = Number<2>{};
88 static constexpr auto I3 = Number<3>{};
89
90 static constexpr auto K1Number = Number<K1>{};
91 static constexpr auto GemmK1Number = K1Number;
92
93 static auto
97 std::vector<ck::index_t> input_spatial_lengths,
98 std::vector<ck::index_t> filter_spatial_lengths,
99 std::vector<ck::index_t> output_spatial_lengths,
100 std::vector<ck::index_t> conv_filter_strides,
101 std::vector<ck::index_t> conv_filter_dilations,
102 std::vector<ck::index_t> input_left_pads,
103 std::vector<ck::index_t> input_right_pads)
104 {
105 using namespace ck;
106
107 const index_t Hi = input_spatial_lengths[0];
108 const index_t Wi = input_spatial_lengths[1];
109
110 const index_t Ho = output_spatial_lengths[0];
111 const index_t Wo = output_spatial_lengths[1];
112
113 const index_t Y = filter_spatial_lengths[0];
114 const index_t X = filter_spatial_lengths[1];
115
116 const index_t ConvStrideH = conv_filter_strides[0];
117 const index_t ConvStrideW = conv_filter_strides[1];
118
119 const index_t ConvDilationH = conv_filter_dilations[0];
120 const index_t ConvDilationW = conv_filter_dilations[1];
121
122 const index_t InLeftPadH = input_left_pads[0];
123 const index_t InLeftPadW = input_left_pads[1];
124
125 const index_t InRightPadH = input_right_pads[0];
126 const index_t InRightPadW = input_right_pads[1];
127
128 const index_t GemmMRaw = N * Ho * Wo;
129 const index_t GemmN = K;
130 const index_t GemmK = Y * X * C;
131
132 const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw;
133
134 assert(GemmK % GemmK1Number == 0);
135
136 const index_t GemmK0 = GemmK / GemmK1Number;
137
138 if constexpr(ConvForwardSpecialization ==
140 {
141 // A: input tensor
142 const auto in_gemmmraw_gemmk_grid_desc =
144
145 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
146 in_gemmmraw_gemmk_grid_desc,
148 make_right_pad_transform(GemmMRaw, GemmMPad)),
151
152 // B: weight tensor
153 const auto wei_gemmn_gemmk_grid_desc =
155
156 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
157 wei_gemmn_gemmk_grid_desc,
162
163 // C: output tensor
164 const auto out_gemmmraw_gemmn_grid_desc =
166
167 const auto out_gemmm_gemmn_grid_desc =
168 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
169 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
173
174 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
175 wei_gemmk0_gemmn_gemmk1_grid_desc,
176 out_gemmm_gemmn_grid_desc);
177 }
178 else if constexpr(ConvForwardSpecialization ==
180 {
181 // A: input tensor
182 const auto in_n_hi_wi_c_grid_desc =
184
185 const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
186 in_n_hi_wi_c_grid_desc,
188 make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
189 make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
193
194 const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
195 in_n_ho_wo_c_grid_desc,
197 make_merge_transform(make_tuple(N, Ho, Wo))),
200
201 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
202 in_gemmk0_gemmmraw_gemmk1_grid_desc,
204 make_right_pad_transform(GemmMRaw, GemmMPad),
208
209 // B: weight tensor
210 const auto wei_gemmn_gemmk_grid_desc =
212
213 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
214 wei_gemmn_gemmk_grid_desc,
219
220 // C: output tensor
221 const auto out_gemmmraw_gemmn_grid_desc =
223
224 const auto out_gemmm_gemmn_grid_desc =
225 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
226 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
230
231 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
232 wei_gemmk0_gemmn_gemmk1_grid_desc,
233 out_gemmm_gemmn_grid_desc);
234 }
235 else
236 {
237 // A: input tensor
238 const auto in_n_hi_wi_c_grid_desc =
240
241 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
242 in_n_hi_wi_c_grid_desc,
244 make_pad_transform(Hi, InLeftPadH, InRightPadH),
245 make_pad_transform(Wi, InLeftPadW, InRightPadW),
249
250 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
251 in_n_hip_wip_c_grid_desc,
254 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
255 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
259
260 const auto in_gemmk_gemmmraw_grid_desc =
261 transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
263 make_merge_transform(make_tuple(N, Ho, Wo))),
266
267 const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
268 in_gemmk_gemmmraw_grid_desc,
273
274 const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
275 in_gemmk0_gemmmraw_gemmk1_grid_desc,
277 make_right_pad_transform(GemmMRaw, GemmMPad),
281
282 // B: weight tensor
283 const auto wei_k_yxc_grid_desc =
285
286 const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
287 wei_k_yxc_grid_desc,
291
292 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
293 wei_gemmk_gemmn_grid_desc,
298
299 // C: output tensor
300 const auto out_nhowo_k_grid_desc =
302
303 const auto out_gemmmraw_gemmn_grid_desc =
304 transform_tensor_descriptor(out_nhowo_k_grid_desc,
309
310 const auto out_gemmm_gemmn_grid_desc =
311 transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
312 make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
316
317 return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
318 wei_gemmk0_gemmn_gemmk1_grid_desc,
319 out_gemmm_gemmn_grid_desc);
320 }
321 }
322
324 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
325
329
330 // GridwiseGemm
331 template <index_t NXdlPerWave_>
333 BlockSize,
334 ABDataType, // TODO: distinguish A/B datatype
335 AccDataType,
336 CDataType,
338 InElementwiseOperation,
339 WeiElementwiseOperation,
340 OutElementwiseOperation,
341 MPerBlock,
342 NPerBlock,
343 K0PerBlock,
344 MPerXDL,
345 NPerXDL,
346 K1,
347 MXdlPerWave,
348 NXdlPerWave_,
349 ABlockTransferThreadClusterLengths_K0_M_K1,
350 Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
351 Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
352 2, // ABlockTransferSrcVectorDim,
353 ABlockTransferSrcScalarPerVector,
354 ABlockTransferDstScalarPerVector_K1,
355 false, // AThreadTransferSrcResetCoordinateAfterRun,
356 ABlockLdsAddExtraM,
357 BBlockTransferThreadClusterLengths_K0_N_K1,
358 Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
359 Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
360 2, // BBlockTransferSrcVectorDim,
361 BBlockTransferSrcScalarPerVector,
362 BBlockTransferDstScalarPerVector_K1,
363 false, // BThreadTransferSrcResetCoordinateAfterRun,
364 BBlockLdsAddExtraN,
365 Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
366 7, // CThreadTransferSrcDstVectorDim,
367 CThreadTransferDstScalarPerVector>;
370
371 // Argument
372 struct Argument : public BaseArgument
373 {
374 Argument(const InDataType* p_in_grid,
375 const WeiDataType* p_wei_grid,
376 OutDataType* p_out_grid,
377 ck::index_t N,
378 ck::index_t K,
379 ck::index_t C,
380 std::vector<ck::index_t> input_spatial_lengths,
381 std::vector<ck::index_t> filter_spatial_lengths,
382 std::vector<ck::index_t> output_spatial_lengths,
383 std::vector<ck::index_t> conv_filter_strides,
384 std::vector<ck::index_t> conv_filter_dilations,
385 std::vector<ck::index_t> input_left_pads,
386 std::vector<ck::index_t> input_right_pads)
387 : p_a_grid_{p_in_grid},
388 p_b_grid_{p_wei_grid},
389 p_c_grid_{p_out_grid},
393 Conv_N_{N},
394 Conv_K_{K},
395 Conv_C_{C},
396 filter_spatial_lengths_{filter_spatial_lengths},
397 conv_filter_strides_{conv_filter_strides},
398 input_left_pads_{input_left_pads},
399 input_right_pads_{input_right_pads}
400 {
401 const auto descs =
403 K,
404 C,
405 input_spatial_lengths,
406 filter_spatial_lengths,
407 output_spatial_lengths,
408 conv_filter_strides,
409 conv_filter_dilations,
410 input_left_pads,
411 input_right_pads);
412
413 a_grid_desc_k0_m_k1_ = descs[I0];
414 b_grid_desc_k0_n_k1_ = descs[I1];
415 c_grid_desc_m_n_ = descs[I2];
416 }
417
418 // private:
425 // for checking IsSupportedArgument()
429 std::vector<index_t> filter_spatial_lengths_;
430 std::vector<index_t> conv_filter_strides_;
431 std::vector<index_t> input_left_pads_;
432 std::vector<index_t> input_right_pads_;
433 };
434
435 // Invoker
436 struct Invoker : public BaseInvoker
437 {
439
440 template <typename GridwiseGemm>
441 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
442 {
443 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
444 {
445 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
446 << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
447 << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
448
449 std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
450 << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
451 << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
452
453 std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
454 << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
455 }
456
457 if(!GridwiseGemm::CheckValidity(
459 {
460 throw std::runtime_error(
461 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
462 }
463
464 const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
465
466 const auto K =
467 arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
468
469 float ave_time = 0;
470
471 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
472 {
473 const auto kernel =
474 kernel_gemm_xdlops_v2r3<GridwiseGemm,
475 ADataType, // TODO: distinguish A/B datatype
476 CDataType,
480 true>;
481
482 ave_time = launch_and_time_kernel(stream_config,
483 kernel,
484 dim3(gdx, gdy, gdz),
485 dim3(BlockSize),
486 0,
487 arg.p_a_grid_,
488 arg.p_b_grid_,
489 arg.p_c_grid_,
492 arg.c_grid_desc_m_n_);
493 }
494 else
495 {
496 const auto kernel =
497 kernel_gemm_xdlops_v2r3<GridwiseGemm,
498 ADataType, // TODO: distinguish A/B datatype
499 CDataType,
503 false>;
504
505 ave_time = launch_and_time_kernel(stream_config,
506 kernel,
507 dim3(gdx, gdy, gdz),
508 dim3(BlockSize),
509 0,
510 arg.p_a_grid_,
511 arg.p_b_grid_,
512 arg.p_c_grid_,
515 arg.c_grid_desc_m_n_);
516 }
517
518 return ave_time;
519 }
520
522
// Type-erased entry point of the Invoker: downcasts the BaseArgument pointer
// to this operation's Argument and forwards it, together with stream_config,
// to another Run overload (that overload's definition is not visible in this
// listing). Returns whatever that overload returns.
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so p_arg must really point at an Argument of this operation -- confirm that
// callers guarantee this.
523 float Run(const BaseArgument* p_arg,
524 const StreamConfig& stream_config = StreamConfig{}) override
525 {
526 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
527 }
528 };
529
// Compile-time hook for validating the template configuration.
// As written it performs no validation at all and unconditionally returns
// true (see the TODO below).
530 static constexpr bool IsValidCompilationParameter()
531 {
532 // TODO: properly implement this check
533 return true;
534 }
535
536 static bool IsSupportedArgument(const Argument& arg)
537 {
539 {
540 return false;
541 }
542 if constexpr(ConvForwardSpecialization ==
544 {
545 // check if it's 1x1, stride=1 conv
546 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
547 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
548 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
549 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
550 {
551 return false;
552 }
553 }
554 else if constexpr(ConvForwardSpecialization ==
556 {
557 // check if it's 1x1 conv
558 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
559 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
560 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
561 {
562 return false;
563 }
564 }
565
566 // vector load A/B matrix from global memory
567 if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
568 arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
569 arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
570 {
571 return false;
572 }
573
574 // vector store C matrix into global memory
575 if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0))
576 {
577 return false;
578 }
579
580 // Gridwise GEMM size
581 if(get_warp_size() == 64)
582 {
583 if constexpr(NXdlPerWave64 > 0)
584 {
587 }
588 }
589 else
590 {
591 if constexpr(NXdlPerWave32 > 0)
592 {
595 }
596 }
597 return false;
598 }
599
// Type-erased overload: downcasts p_arg to this operation's Argument and
// delegates to the static IsSupportedArgument(const Argument&) overload,
// returning its result unchanged.
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so passing a BaseArgument of a different concrete type is not handled here
// -- confirm that callers only pass arguments created by this operation.
600 bool IsSupportedArgument(const BaseArgument* p_arg) override
601 {
602 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
603 }
604
// Typed factory: builds and returns an Argument by value.
// Every parameter is passed straight through to the Argument constructor in
// the same order it is received; no validation or transformation happens
// here.
//   p_in_grid / p_wei_grid / p_out_grid : typed pointers for the input,
//       weight and output data.
//   N, K, C : scalar problem sizes, forwarded as-is.
//   *_spatial_lengths, conv_filter_strides, conv_filter_dilations,
//   input_left_pads, input_right_pads : per-dimension vectors, forwarded
//       as-is. They are received by value and then handed on to Argument.
605 static auto MakeArgument(const InDataType* p_in_grid,
606 const WeiDataType* p_wei_grid,
607 OutDataType* p_out_grid,
608 ck::index_t N,
609 ck::index_t K,
610 ck::index_t C,
611 std::vector<ck::index_t> input_spatial_lengths,
612 std::vector<ck::index_t> filter_spatial_lengths,
613 std::vector<ck::index_t> output_spatial_lengths,
614 std::vector<ck::index_t> conv_filter_strides,
615 std::vector<ck::index_t> conv_filter_dilations,
616 std::vector<ck::index_t> input_left_pads,
617 std::vector<ck::index_t> input_right_pads)
618 {
619 return Argument{p_in_grid,
620 p_wei_grid,
621 p_out_grid,
622 N,
623 K,
624 C,
625 input_spatial_lengths,
626 filter_spatial_lengths,
627 output_spatial_lengths,
628 conv_filter_strides,
629 conv_filter_dilations,
630 input_left_pads,
631 input_right_pads};
632 }
633
// Typed factory: returns a value-initialized Invoker by value.
634 static auto MakeInvoker() { return Invoker{}; }
635
// Type-erased factory: creates an Argument with std::make_unique and returns
// it as std::unique_ptr<BaseArgument>.
// The untyped grid pointers are static_cast to InDataType / WeiDataType /
// OutDataType before being handed to the Argument constructor; no check is
// made here that they actually point at data of those types.
// The remaining problem-description parameters are forwarded unchanged and in
// order. The three trailing elementwise-operation parameters are unnamed and
// are not used by this function.
636 std::unique_ptr<BaseArgument>
637 MakeArgumentPointer(const void* p_in_grid,
638 const void* p_wei_grid,
639 void* p_out_grid,
640 ck::index_t N,
641 ck::index_t K,
642 ck::index_t C,
643 std::vector<ck::index_t> input_spatial_lengths,
644 std::vector<ck::index_t> filter_spatial_lengths,
645 std::vector<ck::index_t> output_spatial_lengths,
646 std::vector<ck::index_t> conv_filter_strides,
647 std::vector<ck::index_t> conv_filter_dilations,
648 std::vector<ck::index_t> input_left_pads,
649 std::vector<ck::index_t> input_right_pads,
650 InElementwiseOperation,
651 WeiElementwiseOperation,
652 OutElementwiseOperation) override
653 {
654 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
655 static_cast<const WeiDataType*>(p_wei_grid),
656 static_cast<OutDataType*>(p_out_grid),
657 N,
658 K,
659 C,
660 input_spatial_lengths,
661 filter_spatial_lengths,
662 output_spatial_lengths,
663 conv_filter_strides,
664 conv_filter_dilations,
665 input_left_pads,
666 input_right_pads);
667 }
668
// Type-erased factory: creates an Invoker with std::make_unique (constructed
// from a temporary Invoker{}) and returns it as std::unique_ptr<BaseInvoker>.
669 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
670 {
671 return std::make_unique<Invoker>(Invoker{});
672 }
673
// Builds a descriptive string for this instantiation: the fixed operation
// name followed by a comma-separated list of selected template parameters
// enclosed in '<' and '>'.
// Field order as emitted: BlockSize, MPerBlock, NPerBlock, K0PerBlock, the
// string returned by getConvForwardSpecializationString for the
// specialization, K1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, then the
// A and B block-transfer source/destination scalar-per-vector values.
// The NXdlPerWave printed is the template parameter itself, not the
// NXdlPerWave64 / NXdlPerWave32 constants.
674 std::string GetTypeString() const override
675 {
676 auto str = std::stringstream();
677
678 // clang-format off
679 str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
680 << "<"
681 << BlockSize << ", "
682 << MPerBlock << ", "
683 << NPerBlock << ", "
684 << K0PerBlock << ", "
685 << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
686 << K1 << ", "
687 << MPerXDL << ", "
688 << NPerXDL << ", "
689 << MXdlPerWave << ", "
690 << NXdlPerWave << ", "
691 << ABlockTransferSrcScalarPerVector << ", "
692 << ABlockTransferDstScalarPerVector_K1 << ", "
693 << BBlockTransferSrcScalarPerVector << ", "
694 << BBlockTransferDstScalarPerVector_K1
695 << ">";
696 // clang-format on
697
698 return str.str();
699 }
700};
701
702} // namespace device
703} // namespace tensor_operation
704} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__global__ void kernel_gemm_xdlops_v2r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n)
Definition gridwise_gemm_xdlops_v2r3.hpp:34
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r3.hpp:142
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
CGridDesc_M_N c_grid_desc_m_n_
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:424
std::vector< index_t > input_left_pads_
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:431
const ADataType * p_a_grid_
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:419
const BDataType * p_b_grid_
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:420
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:422
Argument(const InDataType *p_in_grid, const WeiDataType *p_wei_grid, OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:374
std::vector< index_t > input_right_pads_
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:432
std::vector< index_t > conv_filter_strides_
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:430
std::vector< index_t > filter_spatial_lengths_
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:429
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:423
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:523
DeviceOp::Argument Argument
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:438
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:441
OutDataType CDataType
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:78
std::string GetTypeString() const override
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:674
static constexpr auto K1Number
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:90
static auto MakeInvoker()
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:634
static bool IsSupportedArgument(const Argument &arg)
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:536
static constexpr auto GemmK1Number
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:91
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:600
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:73
InDataType ABDataType
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:81
InDataType ADataType
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:76
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:328
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, const void *p_wei_grid, void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation) override
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:637
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:368
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, Sequence< 1, 0, 2 >, Sequence< 1, 0, 2 >, 2, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 2, 3, 0, 1, 7, 5, 4, 6 >, 7, CThreadTransferDstScalarPerVector > GridwiseGemmBase
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:332
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:94
DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceOp
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:70
WeiDataType BDataType
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:77
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:326
decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})) ABCGridDescs
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:323
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:369
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:669
static constexpr auto I1
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:86
static constexpr auto I3
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:88
static constexpr bool IsValidCompilationParameter()
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:530
static constexpr index_t NDimSpatial
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:83
static constexpr auto I2
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:87
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:327
static auto MakeArgument(const InDataType *p_in_grid, const WeiDataType *p_wei_grid, OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:605
static constexpr auto NXdlPerWave32
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:74
static constexpr auto I0
Definition device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp:85
Definition device_conv_fwd.hpp:25
#define CK_ENV(name)
Definition utility/env.hpp:129