thread_group_tensor_slice_transfer_v4r1_dequant.hpp Source File

thread_group_tensor_slice_transfer_v4r1_dequant.hpp Source File

Composable Kernel: thread_group_tensor_slice_transfer_v4r1_dequant.hpp Source File
thread_group_tensor_slice_transfer_v4r1_dequant.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
27template <typename ThreadGroup,
28 typename SrcElementwiseOperation,
29 typename ScaleElementwiseOperation,
30 typename DstElementwiseOperation,
32 typename BlockSliceLengths,
33 typename BlockScaleSliceLengths,
34 typename ThreadClusterLengths,
35 typename ThreadClusterArrangeOrder,
36 typename SrcData,
37 typename ScaleData,
38 typename DstData,
39 typename SrcDesc,
40 typename ScaleDesc,
41 typename DstDesc,
42 typename SrcDimAccessOrder,
43 typename DstDimAccessOrder,
44 index_t SrcVectorDim,
45 index_t DstVectorDim,
46 index_t SrcScalarPerVector,
47 index_t ScaleScalarPerVector,
48 index_t DstScalarPerVector,
49 index_t SrcScalarStrideInVector,
50 index_t ScaleScalarStrideInVector,
51 index_t DstScalarStrideInVector,
52 bool ThreadTransferSrcResetCoordinateAfterRun,
53 bool ThreadTransferDstResetCoordinateAfterRun,
54 index_t NumThreadScratch = 1>
56{
58
59 static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
60 static constexpr auto scale_thread_slice_lengths =
61 BlockScaleSliceLengths{} / ThreadClusterLengths{};
62
64
66 const SrcDesc& src_desc,
67 const Index& src_block_slice_origin,
68 const SrcElementwiseOperation& src_element_op,
69 const ScaleDesc& scale_desc,
70 const Index& scale_block_slice_origin,
71 const ScaleElementwiseOperation& scale_element_op,
72 const DstDesc& dst_desc,
73 const Index& dst_block_slice_origin,
74 const DstElementwiseOperation& dst_element_op)
75 : threadwise_transfer_(src_desc,
77 src_element_op,
78 scale_desc,
80 scale_element_op,
81 dst_desc,
83 dst_element_op)
84
85 {
89 nDim == ThreadClusterLengths::Size() &&
90 nDim == ThreadClusterArrangeOrder::Size() &&
91 nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
92 "wrong! nDim not consistent");
93
94 static_assert(
95 is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{} &&
96 is_same<BlockScaleSliceLengths,
97 decltype(scale_thread_slice_lengths * ThreadClusterLengths{})>{},
98 "wrong! threads should be mapped to cover entire slicing window");
99
100 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
101 "wrong! ThreadGroup::GetNumOfThread() too small");
102
103 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
104 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
105 {
106 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
107 make_multi_index(ThreadGroup::GetThreadId()));
108
109 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
110
111 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
112 src_block_slice_origin + thread_data_idx_begin);
113 threadwise_transfer_.SetScaleSliceOrigin(
114 scale_desc, scale_block_slice_origin + thread_data_idx_begin);
115 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
116 dst_block_slice_origin + thread_data_idx_begin);
117 }
118 }
119
120 template <typename SrcBuffer, index_t ThreadScratchId = 0>
121 __device__ void RunRead(const SrcDesc& src_desc,
122 const SrcBuffer& src_buf,
124 {
125 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
126 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
127 {
128 threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
129 }
130 }
131
132 // With the assumption, scale scratch is always one
133 template <typename ScaleBuffer>
134 __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf)
135 {
136 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
137 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
138 {
139 threadwise_transfer_.RunScaleRead(scale_desc, scale_buf);
140 }
141 }
142
143 template <typename DstBuffer, index_t ThreadScratchId = 0>
144 __device__ void RunWrite(const DstDesc& dst_desc,
145 DstBuffer& dst_buf,
147 {
148 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
149 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
150 {
151 threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
152 }
153 }
154
155 // We don't prefer use this API directly
156 /*
157 template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
158 __device__ void Run(const SrcDesc& src_desc,
159 const SrcBuffer& src_buf,
160 const DstDesc& dst_desc,
161 DstBuffer& dst_buf,
162 Number<ThreadScratchId> thread_scratch_id)
163 {
164 RunRead(src_desc, src_buf, thread_scratch_id);
165 RunWrite(dst_desc, dst_buf, thread_scratch_id);
166 }
167 */
168
169 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
170 {
171 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
172 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
173 {
174 threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
175 }
176 }
177
178 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
179 {
180 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
181 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
182 {
183 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
184 }
185 }
186
187 // With the assumption, scale buffer don't need move slice window method
188
189 private:
190 static constexpr auto thread_cluster_desc_ =
191 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
192
193 using ThreadwiseTransfer =
194 ThreadwiseTensorSliceTransfer_v3r1_dequant<decltype(thread_slice_lengths),
196 SrcElementwiseOperation,
197 ScaleElementwiseOperation,
198 DstElementwiseOperation,
199 DstInMemOp,
200 SrcData,
201 ScaleData,
202 DstData,
203 SrcDesc,
204 ScaleDesc,
205 DstDesc,
206 SrcDimAccessOrder,
207 DstDimAccessOrder,
208 SrcVectorDim,
209 DstVectorDim,
210 SrcScalarPerVector,
211 ScaleScalarPerVector,
212 DstScalarPerVector,
213 SrcScalarStrideInVector,
214 ScaleScalarStrideInVector,
215 DstScalarStrideInVector,
216 ThreadTransferSrcResetCoordinateAfterRun,
217 ThreadTransferDstResetCoordinateAfterRun,
218 NumThreadScratch>;
219
220 ThreadwiseTransfer threadwise_transfer_;
221};
222
223} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:121
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:59
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:144
static constexpr auto scale_thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:60
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:178
__device__ void RunScaleRead(const ScaleDesc &scale_desc, const ScaleBuffer &scale_buf)
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:134
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant(const SrcDesc &src_desc, const Index &src_block_slice_origin, const SrcElementwiseOperation &src_element_op, const ScaleDesc &scale_desc, const Index &scale_block_slice_origin, const ScaleElementwiseOperation &scale_element_op, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const DstElementwiseOperation &dst_element_op)
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:65
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:63
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:57
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:169
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:548
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:129
Definition type.hpp:177