reference_batched_softmax.hpp Source File

reference_batched_softmax.hpp Source File#

Composable Kernel: reference_batched_softmax.hpp Source File
reference_batched_softmax.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <thread>
9
10namespace ck_tile {
11
12template <typename ADataType,
13 typename CompDataType,
14 typename BDataType,
15 typename CompElementOp = ck_tile::identity>
17 const HostTensor<ADataType>& a_b_m_n,
18 HostTensor<BDataType>& b_b_m_n,
19 const CompElementOp& comp_element_op = {},
20 std::optional<std::reference_wrapper<HostTensor<CompDataType>>> lse_b_m = std::nullopt)
21{
22 const int N = a_b_m_n.mDesc.get_lengths()[2];
23
24 auto f = [&](auto batch, auto m) {
25 CompDataType v_max = -ck_tile::numeric<CompDataType>::infinity();
26
27 // max
28 for(int n = 0; n < N; ++n)
29 {
30 const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
31
32 v_max = v_max < v_a ? v_a : v_max;
33 }
34
35 CompDataType v_exp_sum = 0;
36 // validate v_max if all the elements within a row are -INF
37 if(std::isinf(v_max) && v_max < 0)
38 {
40 }
41
42 // sum
43 for(int n = 0; n < N; ++n)
44 {
45 const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
46
47 v_exp_sum += ck_tile::exp(v_a - v_max);
48 }
49
50 // if sum is zero(masked), or nan/inf(other computation error), don't do divide
51 CompDataType inv_sum = (v_exp_sum == 0.f ? 1.f : 1.f / v_exp_sum);
52
53 // elementwise
54 for(int n = 0; n < N; ++n)
55 {
56 const CompDataType v_a = ck_tile::type_convert<CompDataType>(a_b_m_n(batch, m, n));
57 const CompDataType v_b = ck_tile::exp(v_a - v_max) * inv_sum;
58
59 b_b_m_n(batch, m, n) = ck_tile::type_convert<BDataType>(comp_element_op(v_b));
60 }
61 // lse
62 if(lse_b_m)
63 {
64 lse_b_m->get()(batch, m) = v_max + ck_tile::log(v_exp_sum);
65 }
66 };
67
68 make_ParallelTensorFunctor(f, b_b_m_n.mDesc.get_lengths()[0], b_b_m_n.mDesc.get_lengths()[1])(
69 std::thread::hardware_concurrency());
70}
71} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
CK_TILE_HOST void reference_batched_softmax(const HostTensor< ADataType > &a_b_m_n, HostTensor< BDataType > &b_b_m_n, const CompElementOp &comp_element_op={}, std::optional< std::reference_wrapper< HostTensor< CompDataType > > > lse_b_m=std::nullopt)
Definition reference_batched_softmax.hpp:16
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
const std::vector< std::size_t > & get_lengths() const
Definition tile/host/host_tensor.hpp:198
Definition tile/host/host_tensor.hpp:336
Descriptor mDesc
Definition tile/host/host_tensor.hpp:800
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38