device_softmax.hpp Source File

device_softmax.hpp Source File

Composable Kernel: device_softmax.hpp Source File
device_softmax.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <memory>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
11
12namespace ck {
13namespace tensor_operation {
14namespace device {
15
16template <typename InDataType,
17 typename AccDataType,
18 typename OutDataType,
19 typename InElementwiseOp,
20 typename AccElementwiseOp,
21 index_t Rank,
22 index_t NumReduceDim>
24{
25 //
26 // @brief Makes a pointer to Argument class.
27 //
28 // @param[in] inLengths Input tensor extent(s) from high to low dimension
29 // @param[in] inStrides Input tensor stride(s) from high to low dimension
30 // @param[in] reduceDims The dimension(s) the normalization operation is applied
31 // @param[in] alpha double type value
32 // @param[in] beta double type value
33 // @param[in] in_dev Typeless const pointer in device memory storing the input
34 // tensor
35 // @param out_dev Typeless pointer in device memory storing the output tensor
36 // @param[in] in_elementwise_op The input elementwise operation.
37 // @param[in] acc_elementwise_op The accumulation elementwise operation.
38 //
39 // @return Unique pointer to the Argument class.
40 //
41 virtual std::unique_ptr<BaseArgument>
42 MakeArgumentPointer(const std::vector<index_t> inLengths,
43 const std::vector<index_t> inStrides,
44 const std::vector<int> reduceDims,
45 double alpha,
46 double beta,
47 const void* in_dev,
48 void* out_dev,
49 InElementwiseOp in_elementwise_op,
50 AccElementwiseOp acc_elementwise_op) = 0;
51
52 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
53};
54
55template <typename InDataType,
56 typename AccDataType,
57 typename OutDataType,
58 typename InElementwiseOp,
59 typename AccElementwiseOp,
60 index_t Rank,
61 index_t NumReduceDim>
62using DeviceSoftmaxPtr = std::unique_ptr<DeviceSoftmax<InDataType,
63 AccDataType,
64 OutDataType,
65 InElementwiseOp,
66 AccElementwiseOp,
67 Rank,
68 NumReduceDim>>;
69
70} // namespace device
71} // namespace tensor_operation
72} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceSoftmax< InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim > > DeviceSoftmaxPtr
Definition device_softmax.hpp:62
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_softmax.hpp:24
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > inLengths, const std::vector< index_t > inStrides, const std::vector< int > reduceDims, double alpha, double beta, const void *in_dev, void *out_dev, InElementwiseOp in_elementwise_op, AccElementwiseOp acc_elementwise_op)=0