convolution_parameter.hpp Source File

convolution_parameter.hpp Source File

Composable Kernel: convolution_parameter.hpp Source File
library/utility/convolution_parameter.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <cstdlib>
7#include <numeric>
8#include <iterator>
9#include <vector>
10
11#include "ck/ck.hpp"
12
14
15namespace ck {
16namespace utils {
17namespace conv {
18
20{
23 ck::index_t group_count,
24 ck::index_t n_batch,
25 ck::index_t n_out_channels,
26 ck::index_t n_in_channels,
27 const std::vector<ck::index_t>& filters_len,
28 const std::vector<ck::index_t>& input_len,
29 const std::vector<ck::index_t>& strides,
30 const std::vector<ck::index_t>& dilations,
31 const std::vector<ck::index_t>& left_pads,
32 const std::vector<ck::index_t>& right_pads);
33
35 ck::long_index_t group_count,
36 ck::long_index_t n_batch,
37 ck::long_index_t n_out_channels,
38 ck::long_index_t n_in_channels,
39 const std::vector<ck::long_index_t>& filters_len,
40 const std::vector<ck::long_index_t>& input_len,
41 const std::vector<ck::long_index_t>& strides,
42 const std::vector<ck::long_index_t>& dilations,
43 const std::vector<ck::long_index_t>& left_pads,
44 const std::vector<ck::long_index_t>& right_pads);
45
51
52 std::vector<ck::long_index_t> filter_spatial_lengths_;
53 std::vector<ck::long_index_t> input_spatial_lengths_;
54 std::vector<ck::long_index_t> output_spatial_lengths_;
55
56 std::vector<ck::long_index_t> conv_filter_strides_;
57 std::vector<ck::long_index_t> conv_filter_dilations_;
58
59 std::vector<ck::long_index_t> input_left_pads_;
60 std::vector<ck::long_index_t> input_right_pads_;
61
62 std::vector<ck::long_index_t> GetOutputSpatialLengths() const;
63
64 std::size_t GetFlops() const;
65
66 template <typename InDataType>
67 std::size_t GetInputByte() const
68 {
69 // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
70 return sizeof(InDataType) *
71 (G_ * N_ * C_ *
73 std::begin(input_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()));
74 }
75
76 template <typename WeiDataType>
77 std::size_t GetWeightByte() const
78 {
79 // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
80 return sizeof(WeiDataType) *
81 (G_ * K_ * C_ *
83 std::begin(filter_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>()));
84 }
85
86 template <typename OutDataType>
87 std::size_t GetOutputByte() const
88 {
89 // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
90 return sizeof(OutDataType) * (G_ * N_ * K_ *
91 std::accumulate(std::begin(output_spatial_lengths_),
93 static_cast<std::size_t>(1),
94 std::multiplies<std::size_t>()));
95 }
96
97 template <typename InDataType, typename WeiDataType, typename OutDataType>
98 std::size_t GetByte() const
99 {
102 }
103};
104
106
107ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]);
108
109} // namespace conv
110} // namespace utils
111} // namespace ck
112
113std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p);
std::ostream & operator<<(std::ostream &os, const ck::utils::conv::ConvParam &p)
Definition library/utility/convolution_host_tensor_descriptor_helper.hpp:14
ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char *const argv[])
std::string get_conv_param_parser_helper_msg()
Definition library/utility/check_err.hpp:24
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
int64_t long_index_t
Definition ck.hpp:300
Definition library/utility/convolution_parameter.hpp:20
std::size_t GetFlops() const
ck::long_index_t C_
Definition library/utility/convolution_parameter.hpp:50
std::vector< ck::long_index_t > input_right_pads_
Definition library/utility/convolution_parameter.hpp:60
std::vector< ck::long_index_t > input_left_pads_
Definition library/utility/convolution_parameter.hpp:59
ConvParam(ck::long_index_t n_dim, ck::long_index_t group_count, ck::long_index_t n_batch, ck::long_index_t n_out_channels, ck::long_index_t n_in_channels, const std::vector< ck::long_index_t > &filters_len, const std::vector< ck::long_index_t > &input_len, const std::vector< ck::long_index_t > &strides, const std::vector< ck::long_index_t > &dilations, const std::vector< ck::long_index_t > &left_pads, const std::vector< ck::long_index_t > &right_pads)
std::vector< ck::long_index_t > conv_filter_dilations_
Definition library/utility/convolution_parameter.hpp:57
ck::long_index_t num_dim_spatial_
Definition library/utility/convolution_parameter.hpp:46
std::vector< ck::long_index_t > input_spatial_lengths_
Definition library/utility/convolution_parameter.hpp:53
ConvParam(ck::index_t n_dim, ck::index_t group_count, ck::index_t n_batch, ck::index_t n_out_channels, ck::index_t n_in_channels, const std::vector< ck::index_t > &filters_len, const std::vector< ck::index_t > &input_len, const std::vector< ck::index_t > &strides, const std::vector< ck::index_t > &dilations, const std::vector< ck::index_t > &left_pads, const std::vector< ck::index_t > &right_pads)
std::size_t GetByte() const
Definition library/utility/convolution_parameter.hpp:98
std::vector< ck::long_index_t > output_spatial_lengths_
Definition library/utility/convolution_parameter.hpp:54
std::vector< ck::long_index_t > conv_filter_strides_
Definition library/utility/convolution_parameter.hpp:56
std::size_t GetInputByte() const
Definition library/utility/convolution_parameter.hpp:67
ck::long_index_t N_
Definition library/utility/convolution_parameter.hpp:48
std::vector< ck::long_index_t > GetOutputSpatialLengths() const
ck::long_index_t G_
Definition library/utility/convolution_parameter.hpp:47
std::size_t GetWeightByte() const
Definition library/utility/convolution_parameter.hpp:77
ck::long_index_t K_
Definition library/utility/convolution_parameter.hpp:49
std::vector< ck::long_index_t > filter_spatial_lengths_
Definition library/utility/convolution_parameter.hpp:52
std::size_t GetOutputByte() const
Definition library/utility/convolution_parameter.hpp:87