num_traits/ops/mul_add.rs

/// Fused multiply-add. Computes `(self * a) + b` with only one rounding
/// error, yielding a more accurate result than an unfused multiply-add.
///
/// Using `mul_add` can be more performant than an unfused multiply-add if
/// the target architecture has a dedicated `fma` CPU instruction.
///
/// Note that `A` and `B` are `Self` by default, but this is not mandatory.
///
/// For the primitive integer implementations there is no rounding step to
/// fuse: `mul_add` evaluates exactly `(self * a) + b` with the ordinary
/// operators, so it overflows under the same conditions as `*` and `+`.
///
/// # Example
///
/// ```
/// use std::f32;
///
/// let m = 10.0_f32;
/// let x = 4.0_f32;
/// let b = 60.0_f32;
///
/// // m * x + b == 100.0
/// let abs_difference = (m.mul_add(x, b) - (m*x + b)).abs();
///
/// assert!(abs_difference <= 100.0 * f32::EPSILON);
/// ```
///
/// Method-call syntax on a float resolves to the inherent `f32::mul_add`.
/// To go through this trait explicitly, call it by path:
///
/// ```
/// use num_traits::MulAdd;
///
/// let m: i32 = 2;
/// let x: i32 = 3;
/// let b: i32 = 4;
///
/// assert_eq!(MulAdd::mul_add(m, x, b), 10);
/// ```
pub trait MulAdd<A = Self, B = Self> {
    /// The resulting type after applying the fused multiply-add.
    type Output;

    /// Performs the fused multiply-add operation `(self * a) + b`.
    fn mul_add(self, a: A, b: B) -> Self::Output;
}
30
/// The fused multiply-add assignment operation.
///
/// `x.mul_add_assign(a, b)` computes `(x * a) + b` and stores the result
/// back into `x`; it is the in-place counterpart of `MulAdd::mul_add`.
///
/// Note that `A` and `B` are `Self` by default, but this is not mandatory.
pub trait MulAddAssign<A = Self, B = Self> {
    /// Performs the fused multiply-add operation `(self * a) + b`, writing
    /// the result back into `self`.
    fn mul_add_assign(&mut self, a: A, b: B);
}
36
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f32, f32> for f32 {
    type Output = f32;

    // Delegates to the crate-level `Float::mul_add` for `f32`.
    #[inline]
    fn mul_add(self, a: f32, b: f32) -> f32 {
        <f32 as ::Float>::mul_add(self, a, b)
    }
}
46
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAdd<f64, f64> for f64 {
    type Output = f64;

    // Delegates to the crate-level `Float::mul_add` for `f64`.
    #[inline]
    fn mul_add(self, a: f64, b: f64) -> f64 {
        <f64 as ::Float>::mul_add(self, a, b)
    }
}
56
57macro_rules! mul_add_impl {
58    ($trait_name:ident for $($t:ty)*) => {$(
59        impl $trait_name for $t {
60            type Output = Self;
61
62            #[inline]
63            fn mul_add(self, a: Self, b: Self) -> Self::Output {
64                (self * a) + b
65            }
66        }
67    )*}
68}
69
70mul_add_impl!(MulAdd for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
71#[cfg(has_i128)]
72mul_add_impl!(MulAdd for i128 u128);
73
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f32, f32> for f32 {
    // Reads the current value, combines it with `a` and `b` through the
    // crate-level `Float::mul_add`, and writes the result back in place.
    #[inline]
    fn mul_add_assign(&mut self, a: f32, b: f32) {
        let current = *self;
        *self = <f32 as ::Float>::mul_add(current, a, b);
    }
}
81
#[cfg(any(feature = "std", feature = "libm"))]
impl MulAddAssign<f64, f64> for f64 {
    // Reads the current value, combines it with `a` and `b` through the
    // crate-level `Float::mul_add`, and writes the result back in place.
    #[inline]
    fn mul_add_assign(&mut self, a: f64, b: f64) {
        let current = *self;
        *self = <f64 as ::Float>::mul_add(current, a, b);
    }
}
89
90macro_rules! mul_add_assign_impl {
91    ($trait_name:ident for $($t:ty)*) => {$(
92        impl $trait_name for $t {
93            #[inline]
94            fn mul_add_assign(&mut self, a: Self, b: Self) {
95                *self = (*self * a) + b
96            }
97        }
98    )*}
99}
100
101mul_add_assign_impl!(MulAddAssign for isize usize i8 u8 i16 u16 i32 u32 i64 u64);
102#[cfg(has_i128)]
103mul_add_assign_impl!(MulAddAssign for i128 u128);
104
#[cfg(test)]
mod tests {
    use super::*;

    // Checks that `MulAdd::mul_add(2, 3, 4)` agrees with the plain
    // `2 * 3 + 4` for each listed integer type.
    macro_rules! test_mul_add {
        ($($t:ident)+) => {
            $(
                {
                    let m: $t = 2;
                    let x: $t = 3;
                    let b: $t = 4;

                    assert_eq!(MulAdd::mul_add(m, x, b), (m*x + b));
                }
            )+
        };
    }

    // Checks that `MulAddAssign::mul_add_assign` stores `2 * 3 + 4` back into
    // the receiver for each listed integer type.
    macro_rules! test_mul_add_assign {
        ($($t:ident)+) => {
            $(
                {
                    let mut m: $t = 2;
                    let x: $t = 3;
                    let b: $t = 4;
                    let expected = m*x + b;

                    m.mul_add_assign(x, b);

                    assert_eq!(m, expected);
                }
            )+
        };
    }

    #[test]
    fn mul_add_integer() {
        test_mul_add!(usize u8 u16 u32 u64 isize i8 i16 i32 i64);
    }

    // The 128-bit impls are gated on `cfg(has_i128)`, so their tests are
    // gated the same way.
    #[test]
    #[cfg(has_i128)]
    fn mul_add_integer_i128() {
        test_mul_add!(i128 u128);
    }

    #[test]
    fn mul_add_assign_integer() {
        test_mul_add_assign!(usize u8 u16 u32 u64 isize i8 i16 i32 i64);
    }

    #[test]
    #[cfg(has_i128)]
    fn mul_add_assign_integer_i128() {
        test_mul_add_assign!(i128 u128);
    }

    #[test]
    #[cfg(feature = "std")]
    fn mul_add_float() {
        macro_rules! test_mul_add {
            ($($t:ident)+) => {
                $(
                    {
                        use core::$t;

                        let m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;

                        let abs_difference = (MulAdd::mul_add(m, x, b) - (m*x + b)).abs();

                        // 12.0 * 3.4 + 5.6 == 46.4, so allow a few ULPs at that magnitude.
                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add!(f32 f64);
    }

    #[test]
    #[cfg(feature = "std")]
    fn mul_add_assign_float() {
        macro_rules! test_mul_add_assign {
            ($($t:ident)+) => {
                $(
                    {
                        use core::$t;

                        let mut m: $t = 12.0;
                        let x: $t = 3.4;
                        let b: $t = 5.6;
                        // Unfused reference value, computed before `m` is overwritten.
                        let unfused = m*x + b;

                        m.mul_add_assign(x, b);
                        let abs_difference = (m - unfused).abs();

                        assert!(abs_difference <= 46.4 * $t::EPSILON);
                    }
                )+
            };
        }

        test_mul_add_assign!(f32 f64);
    }
}