device_grouped_conv_utils.hpp Source File

device_grouped_conv_utils.hpp Source File#

Composable Kernel: device_grouped_conv_utils.hpp Source File
device_grouped_conv_utils.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck {
10namespace tensor_operation {
11namespace device {
12
13// 1d
14template <typename InLayout, typename WeiLayout, typename OutLayout>
21
22template <typename InLayout, typename WeiLayout, typename OutLayout>
29
30template <typename InLayout, typename WeiLayout, typename OutLayout>
37
38// 2d
39template <typename InLayout, typename WeiLayout, typename OutLayout>
46
47template <typename InLayout, typename WeiLayout, typename OutLayout>
54
55template <typename InLayout, typename WeiLayout, typename OutLayout>
62
63template <typename InLayout, typename WeiLayout, typename OutLayout>
70
71template <typename InLayout, typename WeiLayout, typename OutLayout>
77
78// 3d
79template <typename InLayout, typename WeiLayout, typename OutLayout>
86
87template <typename InLayout, typename WeiLayout, typename OutLayout>
94
95template <typename InLayout, typename WeiLayout, typename OutLayout>
102
103template <typename InLayout, typename WeiLayout, typename OutLayout>
110
111template <typename InLayout, typename WeiLayout, typename OutLayout>
117
118template <typename InLayout, typename WeiLayout, typename OutLayout>
125
126template <typename InLayout, typename WeiLayout, typename OutLayout>
133
134template <typename InLayout, typename WeiLayout, typename OutLayout>
141
142template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
143struct ComputePtrOffsetOfStridedBatch
144{
145};
146
147template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
149 NumBTensor,
150 NumDTensor,
151 enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
152{
154
156 Array<long_index_t, NumBTensor>& BatchStrideBs,
157 Array<long_index_t, NumDTensor>& BatchStrideDs,
158 long_index_t BatchStrideE)
159 : BatchStrideA_(BatchStrideAs),
160 BatchStrideB_(BatchStrideBs),
161 BatchStrideDs_(BatchStrideDs),
162 BatchStrideE_(BatchStrideE)
163 {
164 }
165
166 __host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
167 {
170 [&](auto i) { as_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideA_[i]; });
171 return as_offset;
172 }
173
174 __host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
175 {
178 [&](auto i) { bs_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideB_[i]; });
179 return bs_offset;
180 }
181
182 __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
183 {
186 [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
187 return ds_offset;
188 }
189
190 [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
191 {
192 return static_cast<long_index_t>(g_idx) * BatchStrideE_;
193 }
194
195 // alias for kernels without multiple D
196 [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
197 {
198 return static_cast<long_index_t>(g_idx) * BatchStrideE_;
199 }
200
205 long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
206};
207
208template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
210 NumBTensor,
211 NumDTensor,
212 enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
213{
215
217 long_index_t BatchStrideB,
219 long_index_t BatchStrideE)
220 : BatchStrideA_(BatchStrideA),
221 BatchStrideB_(BatchStrideB),
222 BatchStrideDs_(BatchStrideDs),
223 BatchStrideE_(BatchStrideE)
224 {
225 }
226
227 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
228 {
229 return static_cast<long_index_t>(g_idx) * BatchStrideA_;
230 }
231
232 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
233 {
234 return static_cast<long_index_t>(g_idx) * BatchStrideB_;
235 }
236
237 __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
238 {
241 [&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
242 return ds_offset;
243 }
244
245 [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
246 {
247 return static_cast<long_index_t>(g_idx) * BatchStrideE_;
248 }
249
250 // alias for kernels without multiple D
251 [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
252 {
253 return static_cast<long_index_t>(g_idx) * BatchStrideE_;
254 }
255
260 long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
261};
262
263template <bool isTuple, typename Tensors>
264constexpr static auto GetNumABTensors()
265{
266 if constexpr(isTuple)
267 {
268 return Number<Tensors::Size()>{};
269 }
270 else
271 {
272 return Number<1>{};
273 }
274}
275
276template <bool isTuple, typename GridwiseGemm, typename DataType>
277constexpr static auto GetAGridPointer()
278{
279 if constexpr(isTuple)
280 {
281 return typename GridwiseGemm::AsGridPointer{};
282 }
283 else
284 {
285 return Tuple<const DataType*>{};
286 }
287}
288
289template <bool isTuple, typename GridwiseGemm, typename DataType>
290constexpr static auto GetBGridPointer()
291{
292 if constexpr(isTuple)
293 {
294 return typename GridwiseGemm::BsGridPointer{};
295 }
296 else
297 {
298 return Tuple<const DataType*>{};
299 }
300}
301
302template <bool isTuple, typename Id, typename Type>
303constexpr static auto UnpackDataType()
304{
305 if constexpr(isTuple)
306 {
307 // unpack if tuple
308 return tuple_element_t<Id{}, Type>{};
309 }
310 else
311 {
312 // if no, return Type
313 return Type{};
314 }
315}
316
317} // namespace device
318} // namespace tensor_operation
319} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
constexpr bool is_NWGC_GKXC_NWGK()
Definition device_grouped_conv_utils.hpp:15
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
Definition device_grouped_conv_utils.hpp:119
constexpr bool is_GNWC_GKXC_GNWK()
Definition device_grouped_conv_utils.hpp:23
constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
Definition device_grouped_conv_utils.hpp:88
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial()
Definition device_grouped_conv_utils.hpp:135
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition device_grouped_conv_utils.hpp:40
constexpr bool is_NGCHW_GKYXC_NGKHW()
Definition device_grouped_conv_utils.hpp:56
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCW_GKXC_NGKW()
Definition device_grouped_conv_utils.hpp:31
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
Definition device_grouped_conv_utils.hpp:127
constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
Definition device_grouped_conv_utils.hpp:96
constexpr bool is_GNHWC_GKYXC_GNHWK()
Definition device_grouped_conv_utils.hpp:48
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr bool is_same_v
Definition type.hpp:283
int64_t long_index_t
Definition ck.hpp:300
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Type
Type of JSON value.
Definition rapidjson.h:760
Definition utility/array.hpp:14
Definition functional2.hpp:33
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_grouped_conv_utils.hpp:232
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_grouped_conv_utils.hpp:227
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_grouped_conv_utils.hpp:251
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
Definition device_grouped_conv_utils.hpp:245
ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, long_index_t BatchStrideB, Array< long_index_t, NumDTensor > BatchStrideDs, long_index_t BatchStrideE)
Definition device_grouped_conv_utils.hpp:216
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_grouped_conv_utils.hpp:237
ComputePtrOffsetOfStridedBatch(Array< long_index_t, NumATensor > &BatchStrideAs, Array< long_index_t, NumBTensor > &BatchStrideBs, Array< long_index_t, NumDTensor > &BatchStrideDs, long_index_t BatchStrideE)
Definition device_grouped_conv_utils.hpp:155
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_grouped_conv_utils.hpp:196
__host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
Definition device_grouped_conv_utils.hpp:166
__host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
Definition device_grouped_conv_utils.hpp:174
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_grouped_conv_utils.hpp:182
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
Definition device_grouped_conv_utils.hpp:190