gridwise_elementwise_1d_scale.hpp Source File

gridwise_elementwise_1d_scale.hpp Source File

Composable Kernel: gridwise_elementwise_1d_scale.hpp Source File
gridwise_elementwise_1d_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
13template <typename GridwiseElementwise1dFunctor,
14 typename InGrid1dDescTuple,
15 typename OutGrid1dDescTuple,
16 typename InDataTypePointerTuple,
17 typename OutDataTypePointerTuple,
18 typename ElementwiseOperation,
19 typename UnaryOperation,
20 typename Scale>
21__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple,
22 const OutGrid1dDescTuple out_grid_1d_desc_tuple,
23 const InDataTypePointerTuple p_in_global_tuple,
24 const OutDataTypePointerTuple p_out_global_tuple,
25 const ElementwiseOperation elementwise_op,
26 const UnaryOperation unary_op,
27 const Scale scale_op)
28{
29 GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
30 out_grid_1d_desc_tuple,
31 p_in_global_tuple,
32 p_out_global_tuple,
33 elementwise_op,
34 unary_op,
35 scale_op);
36}
37
// Gridwise implementation of a 1-D elementwise operation with in-place input
// pre-processing. For every element, Run() does three things in order:
//   (1) applies `unary_op` in place to the input values;
//   (2) applies `scale_op` in place to the input values;
//   (3) evaluates `elementwise_op` with the output values as destinations and the
//       transformed input values as sources.
//
// Template parameters (as used below):
//   MPerThread               - number of consecutive elements each thread moves per loop
//                              iteration; it is the slice length of every transfer.
//   In/OutScalarPerVectorSeq - per-tensor vector-width template argument handed to the
//                              threadwise load/store transfers.
//
// NOTE(review): this listing was extracted from generated documentation. The struct
// name line and several body lines are missing, so the comments describe only the
// visible code.
38template <typename InGrid1dDescTuple,
39 typename OutGrid1dDescTuple,
40 typename InDataTypePointerTuple,
41 typename OutDataTypePointerTuple,
42 typename ElementwiseOperation,
43 typename UnaryOperation,
44 typename Scale,
45 index_t MPerThread,
46 typename InScalarPerVectorSeq,
47 typename OutScalarPerVectorSeq>
49{
// Number of input / output tensors, taken from the pointer tuples.
50 static constexpr index_t NumInput = InDataTypePointerTuple::Size();
51 static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
52
// Every per-tensor tuple must hold exactly one entry per input / output.
53 static_assert(NumInput == InScalarPerVectorSeq::Size() &&
54 NumOutput == OutScalarPerVectorSeq::Size() &&
55 NumInput == InGrid1dDescTuple::Size() &&
56 NumOutput == OutGrid1dDescTuple::Size(),
57 "Tuple size is inconsistent with the number of in/out!");
58
59 static constexpr auto I0 = Number<0>{};
60
// Descriptor of each thread's private staging buffer. All transfers below use it with
// Sequence<MPerThread> slices. Its initializer (original line 62) is missing from this
// listing.
61 static constexpr auto thread_buffer_desc_m =
63
65
// Per-thread body. Each thread starts at element thread_global_id * MPerThread. Every
// loop iteration then advances all load/store windows by
// grid size * block size * MPerThread elements, which is a grid-stride loop.
66 __device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple,
67 const OutGrid1dDescTuple out_grid_1d_desc_tuple,
68 const InDataTypePointerTuple p_in_global_tuple,
69 const OutDataTypePointerTuple p_out_global_tuple,
70 const ElementwiseOperation elementwise_op,
71 const UnaryOperation unary_op,
72 const Scale scale_op)
73 {
74 const index_t thread_global_id = get_thread_global_1d_id();
75
// Per-thread staging buffer for each input. The construction lines are missing here.
76 auto in_thread_buf_tuple = generate_tuple(
77 [&](auto I) {
78 using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
80
82 },
84
// Per-thread staging buffer for each output. The construction lines are missing here.
85 auto out_thread_buf_tuple = generate_tuple(
86 [&](auto I) {
87 using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
88 using DataType = remove_pointer_t<DataTypePointer>;
89
91 },
93
// Buffer views over the global inputs, sized by each descriptor's element space.
// Only 1-D descriptors are accepted (static_assert).
94 auto in_global_buf_tuple = generate_tuple(
95 [&](auto I) {
96 static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
97
99 p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
100 },
102
// The same buffer views, over the global outputs.
103 auto out_global_buf_tuple = generate_tuple(
104 [&](auto I) {
105 static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
106
108 p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
109 },
111
// First element handled by this thread in the first iteration.
112 const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread);
113
114 const index_t blockSize = get_block_size();
115 const index_t blockPerGrid = get_grid_size();
// NOTE(review): the iteration length comes from the FIRST input descriptor only.
// All other inputs and outputs are assumed to have the same length M.
116 const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0);
// Elements covered by the whole grid per iteration; this is the window stride.
117 const index_t loop_step = blockPerGrid * blockSize * MPerThread;
118 const auto loop_step_index = make_multi_index(loop_step);
119
// One threadwise loader per input: copies an MPerThread-element global slice at
// thread_global_offset into the thread buffer.
120 auto in_global_load_tuple = generate_tuple(
121 [&](auto I) {
122 using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
124
125 return ThreadwiseTensorSliceTransfer_v2<DataType,
126 DataType,
127 decltype(in_grid_1d_desc_tuple[I]),
128 decltype(thread_buffer_desc_m),
129 Sequence<MPerThread>, // SliceLengths
130 Sequence<0>, // DimAccessOrder
131 0, // SrcVectorDim
132 InScalarPerVectorSeq::At(
133 I), // ScalarPerVector
134 1, // SrcScalarStrideInVector
135 false>{in_grid_1d_desc_tuple[I],
136 thread_global_offset};
137 },
139
// One threadwise storer per output: writes the thread buffer back to global memory at
// thread_global_offset, using a PassThroughOp element operation.
140 auto out_global_store_tuple = generate_tuple(
141 [&](auto I) {
142 using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
143 using DataType = remove_pointer_t<DataTypePointer>;
144
146 DataType,
147 decltype(thread_buffer_desc_m),
148 decltype(out_grid_1d_desc_tuple[I]),
150 Sequence<MPerThread>, // SliceLengths
151 Sequence<0>, // DimAccessOrder
152 0, // SrcVectorDim
153 OutScalarPerVectorSeq::At(I),
155 1,
156 false>(
157 out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{});
158 },
160
// NOTE(review): M / loop_step truncates, so elements past num_iter * loop_step are never
// processed. The loop is also a do/while on --num_iter (index_t is int32_t). If
// M < loop_step, num_iter is 0: the body runs once, the counter decrements to -1, and
// the loop keeps reading and writing past the data. The caller presumably guarantees
// that M is a positive multiple of loop_step -- confirm.
161 index_t num_iter = M / (loop_step);
162 do
163 {
// Step 1: load this iteration's slice of every input, then advance its source window.
164 static_for<0, NumInput, 1>{}([&](auto I) {
165 in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
166 in_global_buf_tuple[I],
168 make_tuple(I0),
169 in_thread_buf_tuple(I));
170
171 in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
172 loop_step_index);
173 });
174
// Step 2: per element, transform the inputs in place, then compute the outputs.
175 static_for<0, MPerThread, 1>{}([&](auto iM) {
176 // get reference to in data
177 auto uop_data_refs = generate_tie(
178 // return type should be lvalue
179 [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
181
182 // get reference to dst data
183 auto out_data_refs = generate_tie(
184 // return type should be lvalue
185 [&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); },
187
// unary_op in place: the same input references are passed as both destination and source.
188 unpack2(unary_op, uop_data_refs, uop_data_refs);
189
// scale_op is also applied in place: both tie-tuples reference the input thread buffers.
190 auto sop_in_data_refs = generate_tie(
191 // return type should be lvalue
192 [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
194
195 auto sop_out_data_refs = generate_tie(
196 // return type should be lvalue
197 [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); },
199
200 unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);
201
// elementwise_op reads the (now transformed) inputs through const refs and writes the outputs.
202 const auto in_data_refs = generate_tie(
203 // return type should be lvalue
204 [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); },
206
207 unpack2(elementwise_op, out_data_refs, in_data_refs);
208 });
209
// Step 3: store every output slice, then advance its destination window.
210 static_for<0, NumOutput, 1>{}([&](auto I) {
211 out_global_store_tuple(I).Run(thread_buffer_desc_m,
212 make_tuple(I0),
213 out_thread_buf_tuple[I],
214 out_grid_1d_desc_tuple[I],
215 out_global_buf_tuple(I));
216
217 out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
218 loop_step_index);
219 });
220 } while(--num_iter);
221 }
222};
223
224} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__device__ index_t get_block_size()
Definition get_id.hpp:51
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, const OutGrid1dDescTuple out_grid_1d_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const ElementwiseOperation elementwise_op, const UnaryOperation unary_op, const Scale scale_op)
Definition gridwise_elementwise_1d_scale.hpp:21
__device__ index_t get_thread_global_1d_id()
Definition get_id.hpp:43
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Definition gridwise_elementwise_1d_scale.hpp:49
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_elementwise_1d_scale.hpp:64
static constexpr index_t NumOutput
Definition gridwise_elementwise_1d_scale.hpp:51
static __device__ void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple, const OutGrid1dDescTuple out_grid_1d_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const ElementwiseOperation elementwise_op, const UnaryOperation unary_op, const Scale scale_op)
Definition gridwise_elementwise_1d_scale.hpp:66
static constexpr auto thread_buffer_desc_m
Definition gridwise_elementwise_1d_scale.hpp:61
static constexpr auto I0
Definition gridwise_elementwise_1d_scale.hpp:59
static constexpr index_t NumInput
Definition gridwise_elementwise_1d_scale.hpp:50
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340