Skip to content

Expose the FusedMultiplyAdd and MultiplyAddEstimate APIs on relevant vector and scalar types #102181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 16, 2024
Prev Previous commit
Next Next commit
Ensure TensorPrimitives uses the xplat APIs on .NET 9+
  • Loading branch information
tannergooding committed May 14, 2024
commit d5b2ad9eb24c9f95d64f92604c895d36cccc161a
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ public static void FusedMultiplyAdd<T>(ReadOnlySpan<T> x, T y, ReadOnlySpan<T> a

public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y, Vector128<T> z)
{
#if NET9_0_OR_GREATER
if (typeof(T) == typeof(double))
{
return Vector128.FusedMultiplyAdd(x.AsDouble(), y.AsDouble(), z.AsDouble()).As<double, T>();
}
else
{
Debug.Assert(typeof(T) == typeof(float));
return Vector128.FusedMultiplyAdd(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
}
#else
if (Fma.IsSupported)
{
if (typeof(T) == typeof(float)) return Fma.MultiplyAdd(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
Expand Down Expand Up @@ -137,10 +148,22 @@ public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y, Vector128<T> z
double.FusedMultiplyAdd(xDoubles[0], yDoubles[0], zDoubles[0]),
double.FusedMultiplyAdd(xDoubles[1], yDoubles[1], zDoubles[1])).As<double, T>();
}
#endif
}

public static Vector256<T> Invoke(Vector256<T> x, Vector256<T> y, Vector256<T> z)
{
#if NET9_0_OR_GREATER
if (typeof(T) == typeof(double))
{
return Vector256.FusedMultiplyAdd(x.AsDouble(), y.AsDouble(), z.AsDouble()).As<double, T>();
}
else
{
Debug.Assert(typeof(T) == typeof(float));
return Vector256.FusedMultiplyAdd(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
}
#else
if (Fma.IsSupported)
{
if (typeof(T) == typeof(float)) return Fma.MultiplyAdd(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
Expand All @@ -150,10 +173,22 @@ public static Vector256<T> Invoke(Vector256<T> x, Vector256<T> y, Vector256<T> z
return Vector256.Create(
Invoke(x.GetLower(), y.GetLower(), z.GetLower()),
Invoke(x.GetUpper(), y.GetUpper(), z.GetUpper()));
#endif
}

public static Vector512<T> Invoke(Vector512<T> x, Vector512<T> y, Vector512<T> z)
{
#if NET9_0_OR_GREATER
if (typeof(T) == typeof(double))
{
return Vector512.FusedMultiplyAdd(x.AsDouble(), y.AsDouble(), z.AsDouble()).As<double, T>();
}
else
{
Debug.Assert(typeof(T) == typeof(float));
return Vector512.FusedMultiplyAdd(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
}
#else
if (Avx512F.IsSupported)
{
if (typeof(T) == typeof(float)) return Avx512F.FusedMultiplyAdd(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
Expand All @@ -163,6 +198,7 @@ public static Vector512<T> Invoke(Vector512<T> x, Vector512<T> y, Vector512<T> z
return Vector512.Create(
Invoke(x.GetLower(), y.GetLower(), z.GetLower()),
Invoke(x.GetUpper(), y.GetUpper(), z.GetUpper()));
#endif
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
Expand Down Expand Up @@ -88,40 +89,51 @@ public static void MultiplyAddEstimate<T>(ReadOnlySpan<T> x, T y, ReadOnlySpan<T
InvokeSpanScalarSpanIntoSpan<T, MultiplyAddEstimateOperator<T>>(x, y, addend, destination);

/// <summary>(x * y) + z</summary>
private readonly struct MultiplyAddEstimateOperator<T> : ITernaryOperator<T> where T : INumberBase<T>
private readonly struct MultiplyAddEstimateOperator<T> : ITernaryOperator<T>
where T : INumberBase<T>
{
/// <summary>Computes <c>(x * y) + z</c> for a single element.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static T Invoke(T x, T y, T z)
{
// TODO https://github.com/dotnet/runtime/issues/98053: Use T.MultiplyAddEstimate when it's available.

#if NET9_0_OR_GREATER
// .NET 9+ build: delegate straight to the generic T.MultiplyAddEstimate.
return T.MultiplyAddEstimate(x, y, z);
#else
// Older targets: when Fma or AdvSimd reports support, Half/float/double are routed
// through their own FusedMultiplyAdd; every other case uses the plain expression at the end.
if (Fma.IsSupported || AdvSimd.IsSupported)
{
if (typeof(T) == typeof(Half))
{
Half result = Half.FusedMultiplyAdd(Unsafe.As<T, Half>(ref x), Unsafe.As<T, Half>(ref y), Unsafe.As<T, Half>(ref z));
return Unsafe.As<Half, T>(ref result);
// NOTE(review): unreachable as written; this looks like the other side of a diff hunk for the two lines above. Confirm which form is intended.
return (T)(object)Half.FusedMultiplyAdd((Half)(object)x, (Half)(object)y, (Half)(object)z);
}

if (typeof(T) == typeof(float))
{
float result = float.FusedMultiplyAdd(Unsafe.As<T, float>(ref x), Unsafe.As<T, float>(ref y), Unsafe.As<T, float>(ref z));
return Unsafe.As<float, T>(ref result);
// NOTE(review): unreachable as written; presumably the alternate diff side. Confirm.
return (T)(object)float.FusedMultiplyAdd((float)(object)x, (float)(object)y, (float)(object)z);
}

if (typeof(T) == typeof(double))
{
double result = double.FusedMultiplyAdd(Unsafe.As<T, double>(ref x), Unsafe.As<T, double>(ref y), Unsafe.As<T, double>(ref z));
return Unsafe.As<double, T>(ref result);
// NOTE(review): unreachable as written; presumably the alternate diff side. Confirm.
return (T)(object)double.FusedMultiplyAdd((double)(object)x, (double)(object)y, (double)(object)z);
}
}

// Fallback for any other T, or when neither Fma nor AdvSimd is supported: separate multiply and add.
return (x * y) + z;
#endif
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y, Vector128<T> z)
{
#if NET9_0_OR_GREATER
if (typeof(T) == typeof(double))
{
return Vector128.MultiplyAddEstimate(x.AsDouble(), y.AsDouble(), z.AsDouble()).As<double, T>();
}
else
{
Debug.Assert(typeof(T) == typeof(float));
return Vector128.MultiplyAddEstimate(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
}
#else
if (Fma.IsSupported)
{
if (typeof(T) == typeof(float)) return Fma.MultiplyAdd(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
Expand All @@ -139,30 +151,55 @@ public static Vector128<T> Invoke(Vector128<T> x, Vector128<T> y, Vector128<T> z
}

return (x * y) + z;
#endif
}

/// <summary>Computes <c>(x * y) + z</c> for each element of three 256-bit vectors.</summary>
/// <param name="x">The vector of left multiplicands.</param>
/// <param name="y">The vector of right multiplicands.</param>
/// <param name="z">The vector of addends.</param>
/// <returns>A vector holding <c>(x * y) + z</c> for each element.</returns>
/// <remarks>
/// Only <see cref="float"/> and <see cref="double"/> elements are routed to the dedicated
/// multiply-add helpers. Every other element type uses the plain multiply-then-add expression,
/// so its bits are never reinterpreted as floating-point lanes.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector256<T> Invoke(Vector256<T> x, Vector256<T> y, Vector256<T> z)
{
#if NET9_0_OR_GREATER
    if (typeof(T) == typeof(double))
    {
        return Vector256.MultiplyAddEstimate(x.AsDouble(), y.AsDouble(), z.AsDouble()).As<double, T>();
    }

    // Checked explicitly rather than assumed: T is only constrained to INumberBase<T>,
    // and a Debug.Assert would not protect release builds from a non-float T.
    if (typeof(T) == typeof(float))
    {
        return Vector256.MultiplyAddEstimate(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
    }
#else
    if (Fma.IsSupported)
    {
        if (typeof(T) == typeof(float))
        {
            return Fma.MultiplyAdd(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
        }

        if (typeof(T) == typeof(double))
        {
            return Fma.MultiplyAdd(x.AsDouble(), y.AsDouble(), z.AsDouble()).As<double, T>();
        }
    }
#endif

    // Generic fallback shared by both build configurations.
    return (x * y) + z;
}

/// <summary>Computes <c>(x * y) + z</c> for each element of three 512-bit vectors.</summary>
/// <param name="x">The vector of left multiplicands.</param>
/// <param name="y">The vector of right multiplicands.</param>
/// <param name="z">The vector of addends.</param>
/// <returns>A vector holding <c>(x * y) + z</c> for each element.</returns>
/// <remarks>
/// Only <see cref="float"/> and <see cref="double"/> elements are routed to the dedicated
/// multiply-add helpers. Every other element type uses the plain multiply-then-add expression,
/// so its bits are never reinterpreted as floating-point lanes.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector512<T> Invoke(Vector512<T> x, Vector512<T> y, Vector512<T> z)
{
#if NET9_0_OR_GREATER
    if (typeof(T) == typeof(double))
    {
        return Vector512.MultiplyAddEstimate(x.AsDouble(), y.AsDouble(), z.AsDouble()).As<double, T>();
    }

    // Checked explicitly rather than assumed: T is only constrained to INumberBase<T>,
    // and a Debug.Assert would not protect release builds from a non-float T.
    if (typeof(T) == typeof(float))
    {
        return Vector512.MultiplyAddEstimate(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
    }
#else
    if (Avx512F.IsSupported)
    {
        if (typeof(T) == typeof(float))
        {
            return Avx512F.FusedMultiplyAdd(x.AsSingle(), y.AsSingle(), z.AsSingle()).As<float, T>();
        }

        if (typeof(T) == typeof(double))
        {
            return Avx512F.FusedMultiplyAdd(x.AsDouble(), y.AsDouble(), z.AsDouble()).As<double, T>();
        }
    }
#endif

    // Generic fallback shared by both build configurations.
    return (x * y) + z;
}
}
}
Expand Down
Loading