Skip to content

Commit 15e4423

Browse files
[release/9.0-rc2] Replace VectorXx.Exp's edge case fallback with scalar processing (#107942)
* Replace VectorXx.Exp's edge case fallback with scalar processing

  The better, vectorized fix is more complex and can be done for .NET 10.

* Revert addition to Helpers.IsEqualWithTolerance

---------

Co-authored-by: Stephen Toub <[email protected]>
1 parent c478b2a commit 15e4423

File tree

3 files changed

+173
-167
lines changed

3 files changed

+173
-167
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Exp.cs

Lines changed: 102 additions & 115 deletions
Original file line number | Diff line number | Diff line change
@@ -157,9 +157,6 @@ public static Vector512<T> Invoke(Vector512<T> x)
157157
private const ulong V_ARG_MAX = 0x40862000_00000000;
158158
private const ulong V_DP64_BIAS = 1023;
159159

160-
private const double V_EXPF_MIN = -709.782712893384;
161-
private const double V_EXPF_MAX = +709.782712893384;
162-
163160
private const double V_EXPF_HUGE = 6755399441055744;
164161
private const double V_TBL_LN2 = 1.4426950408889634;
165162

@@ -183,155 +180,145 @@ public static Vector512<T> Invoke(Vector512<T> x)
183180

184181
public static Vector128<double> Invoke(Vector128<double> x)
185182
{
186-
// x * (64.0 / ln(2))
187-
Vector128<double> z = x * Vector128.Create(V_TBL_LN2);
188-
189-
Vector128<double> dn = z + Vector128.Create(V_EXPF_HUGE);
183+
// Check if -709 < vx < 709
184+
if (Vector128.LessThanOrEqualAll(Vector128.Abs(x).AsUInt64(), Vector128.Create(V_ARG_MAX)))
185+
{
186+
// x * (64.0 / ln(2))
187+
Vector128<double> z = x * Vector128.Create(V_TBL_LN2);
190188

191-
// n = (int)z
192-
Vector128<ulong> n = dn.AsUInt64();
189+
Vector128<double> dn = z + Vector128.Create(V_EXPF_HUGE);
193190

194-
// dn = (double)n
195-
dn -= Vector128.Create(V_EXPF_HUGE);
191+
// n = (int)z
192+
Vector128<ulong> n = dn.AsUInt64();
196193

197-
// r = x - (dn * (ln(2) / 64))
198-
// where ln(2) / 64 is split into Head and Tail values
199-
Vector128<double> r = x - (dn * Vector128.Create(V_LN2_HEAD)) - (dn * Vector128.Create(V_LN2_TAIL));
194+
// dn = (double)n
195+
dn -= Vector128.Create(V_EXPF_HUGE);
200196

201-
Vector128<double> r2 = r * r;
202-
Vector128<double> r4 = r2 * r2;
203-
Vector128<double> r8 = r4 * r4;
197+
// r = x - (dn * (ln(2) / 64))
198+
// where ln(2) / 64 is split into Head and Tail values
199+
Vector128<double> r = x - (dn * Vector128.Create(V_LN2_HEAD)) - (dn * Vector128.Create(V_LN2_TAIL));
204200

205-
// Compute polynomial
206-
Vector128<double> poly = ((Vector128.Create(C12) * r + Vector128.Create(C11)) * r2 +
207-
Vector128.Create(C10) * r + Vector128.Create(C9)) * r8 +
208-
((Vector128.Create(C8) * r + Vector128.Create(C7)) * r2 +
209-
(Vector128.Create(C6) * r + Vector128.Create(C5))) * r4 +
210-
((Vector128.Create(C4) * r + Vector128.Create(C3)) * r2 + (r + Vector128<double>.One));
201+
Vector128<double> r2 = r * r;
202+
Vector128<double> r4 = r2 * r2;
203+
Vector128<double> r8 = r4 * r4;
211204

212-
// m = (n - j) / 64
213-
// result = polynomial * 2^m
214-
Vector128<double> ret = poly * ((n + Vector128.Create(V_DP64_BIAS)) << 52).AsDouble();
205+
// Compute polynomial
206+
Vector128<double> poly = ((Vector128.Create(C12) * r + Vector128.Create(C11)) * r2 +
207+
Vector128.Create(C10) * r + Vector128.Create(C9)) * r8 +
208+
((Vector128.Create(C8) * r + Vector128.Create(C7)) * r2 +
209+
(Vector128.Create(C6) * r + Vector128.Create(C5))) * r4 +
210+
((Vector128.Create(C4) * r + Vector128.Create(C3)) * r2 + (r + Vector128<double>.One));
215211

216-
// Check if -709 < vx < 709
217-
if (Vector128.GreaterThanAny(Vector128.Abs(x).AsUInt64(), Vector128.Create(V_ARG_MAX)))
212+
// m = (n - j) / 64
213+
// result = polynomial * 2^m
214+
return poly * ((n + Vector128.Create(V_DP64_BIAS)) << 52).AsDouble();
215+
}
216+
else
218217
{
219-
// (x > V_EXPF_MAX) ? double.PositiveInfinity : x
220-
Vector128<double> infinityMask = Vector128.GreaterThan(x, Vector128.Create(V_EXPF_MAX));
221-
222-
ret = Vector128.ConditionalSelect(
223-
infinityMask,
224-
Vector128.Create(double.PositiveInfinity),
225-
ret
226-
);
218+
return ScalarFallback(x);
227219

228-
// (x < V_EXPF_MIN) ? 0 : x
229-
ret = Vector128.AndNot(ret, Vector128.LessThan(x, Vector128.Create(V_EXPF_MIN)));
220+
static Vector128<double> ScalarFallback(Vector128<double> x) =>
221+
Vector128.Create(Math.Exp(x.GetElement(0)),
222+
Math.Exp(x.GetElement(1)));
230223
}
231-
232-
return ret;
233224
}
234225

235226
public static Vector256<double> Invoke(Vector256<double> x)
236227
{
237-
// x * (64.0 / ln(2))
238-
Vector256<double> z = x * Vector256.Create(V_TBL_LN2);
239-
240-
Vector256<double> dn = z + Vector256.Create(V_EXPF_HUGE);
228+
// Check if -709 < vx < 709
229+
if (Vector256.LessThanOrEqualAll(Vector256.Abs(x).AsUInt64(), Vector256.Create(V_ARG_MAX)))
230+
{
231+
// x * (64.0 / ln(2))
232+
Vector256<double> z = x * Vector256.Create(V_TBL_LN2);
241233

242-
// n = (int)z
243-
Vector256<ulong> n = dn.AsUInt64();
234+
Vector256<double> dn = z + Vector256.Create(V_EXPF_HUGE);
244235

245-
// dn = (double)n
246-
dn -= Vector256.Create(V_EXPF_HUGE);
236+
// n = (int)z
237+
Vector256<ulong> n = dn.AsUInt64();
247238

248-
// r = x - (dn * (ln(2) / 64))
249-
// where ln(2) / 64 is split into Head and Tail values
250-
Vector256<double> r = x - (dn * Vector256.Create(V_LN2_HEAD)) - (dn * Vector256.Create(V_LN2_TAIL));
239+
// dn = (double)n
240+
dn -= Vector256.Create(V_EXPF_HUGE);
251241

252-
Vector256<double> r2 = r * r;
253-
Vector256<double> r4 = r2 * r2;
254-
Vector256<double> r8 = r4 * r4;
242+
// r = x - (dn * (ln(2) / 64))
243+
// where ln(2) / 64 is split into Head and Tail values
244+
Vector256<double> r = x - (dn * Vector256.Create(V_LN2_HEAD)) - (dn * Vector256.Create(V_LN2_TAIL));
255245

256-
// Compute polynomial
257-
Vector256<double> poly = ((Vector256.Create(C12) * r + Vector256.Create(C11)) * r2 +
258-
Vector256.Create(C10) * r + Vector256.Create(C9)) * r8 +
259-
((Vector256.Create(C8) * r + Vector256.Create(C7)) * r2 +
260-
(Vector256.Create(C6) * r + Vector256.Create(C5))) * r4 +
261-
((Vector256.Create(C4) * r + Vector256.Create(C3)) * r2 + (r + Vector256<double>.One));
246+
Vector256<double> r2 = r * r;
247+
Vector256<double> r4 = r2 * r2;
248+
Vector256<double> r8 = r4 * r4;
262249

263-
// m = (n - j) / 64
264-
// result = polynomial * 2^m
265-
Vector256<double> ret = poly * ((n + Vector256.Create(V_DP64_BIAS)) << 52).AsDouble();
250+
// Compute polynomial
251+
Vector256<double> poly = ((Vector256.Create(C12) * r + Vector256.Create(C11)) * r2 +
252+
Vector256.Create(C10) * r + Vector256.Create(C9)) * r8 +
253+
((Vector256.Create(C8) * r + Vector256.Create(C7)) * r2 +
254+
(Vector256.Create(C6) * r + Vector256.Create(C5))) * r4 +
255+
((Vector256.Create(C4) * r + Vector256.Create(C3)) * r2 + (r + Vector256<double>.One));
266256

267-
// Check if -709 < vx < 709
268-
if (Vector256.GreaterThanAny(Vector256.Abs(x).AsUInt64(), Vector256.Create(V_ARG_MAX)))
257+
// m = (n - j) / 64
258+
// result = polynomial * 2^m
259+
return poly * ((n + Vector256.Create(V_DP64_BIAS)) << 52).AsDouble();
260+
}
261+
else
269262
{
270-
// (x > V_EXPF_MAX) ? double.PositiveInfinity : x
271-
Vector256<double> infinityMask = Vector256.GreaterThan(x, Vector256.Create(V_EXPF_MAX));
263+
return ScalarFallback(x);
272264

273-
ret = Vector256.ConditionalSelect(
274-
infinityMask,
275-
Vector256.Create(double.PositiveInfinity),
276-
ret
277-
);
278-
279-
// (x < V_EXPF_MIN) ? 0 : x
280-
ret = Vector256.AndNot(ret, Vector256.LessThan(x, Vector256.Create(V_EXPF_MIN)));
265+
static Vector256<double> ScalarFallback(Vector256<double> x) =>
266+
Vector256.Create(Math.Exp(x.GetElement(0)),
267+
Math.Exp(x.GetElement(1)),
268+
Math.Exp(x.GetElement(2)),
269+
Math.Exp(x.GetElement(3)));
281270
}
282-
283-
return ret;
284271
}
285272

286273
public static Vector512<double> Invoke(Vector512<double> x)
287274
{
288-
// x * (64.0 / ln(2))
289-
Vector512<double> z = x * Vector512.Create(V_TBL_LN2);
290-
291-
Vector512<double> dn = z + Vector512.Create(V_EXPF_HUGE);
275+
// Check if -709 < vx < 709
276+
if (Vector512.LessThanOrEqualAll(Vector512.Abs(x).AsUInt64(), Vector512.Create(V_ARG_MAX)))
277+
{
278+
// x * (64.0 / ln(2))
279+
Vector512<double> z = x * Vector512.Create(V_TBL_LN2);
292280

293-
// n = (int)z
294-
Vector512<ulong> n = dn.AsUInt64();
281+
Vector512<double> dn = z + Vector512.Create(V_EXPF_HUGE);
295282

296-
// dn = (double)n
297-
dn -= Vector512.Create(V_EXPF_HUGE);
283+
// n = (int)z
284+
Vector512<ulong> n = dn.AsUInt64();
298285

299-
// r = x - (dn * (ln(2) / 64))
300-
// where ln(2) / 64 is split into Head and Tail values
301-
Vector512<double> r = x - (dn * Vector512.Create(V_LN2_HEAD)) - (dn * Vector512.Create(V_LN2_TAIL));
286+
// dn = (double)n
287+
dn -= Vector512.Create(V_EXPF_HUGE);
302288

303-
Vector512<double> r2 = r * r;
304-
Vector512<double> r4 = r2 * r2;
305-
Vector512<double> r8 = r4 * r4;
289+
// r = x - (dn * (ln(2) / 64))
290+
// where ln(2) / 64 is split into Head and Tail values
291+
Vector512<double> r = x - (dn * Vector512.Create(V_LN2_HEAD)) - (dn * Vector512.Create(V_LN2_TAIL));
306292

307-
// Compute polynomial
308-
Vector512<double> poly = ((Vector512.Create(C12) * r + Vector512.Create(C11)) * r2 +
309-
Vector512.Create(C10) * r + Vector512.Create(C9)) * r8 +
310-
((Vector512.Create(C8) * r + Vector512.Create(C7)) * r2 +
311-
(Vector512.Create(C6) * r + Vector512.Create(C5))) * r4 +
312-
((Vector512.Create(C4) * r + Vector512.Create(C3)) * r2 + (r + Vector512<double>.One));
293+
Vector512<double> r2 = r * r;
294+
Vector512<double> r4 = r2 * r2;
295+
Vector512<double> r8 = r4 * r4;
313296

314-
// m = (n - j) / 64
315-
// result = polynomial * 2^m
316-
Vector512<double> ret = poly * ((n + Vector512.Create(V_DP64_BIAS)) << 52).AsDouble();
297+
// Compute polynomial
298+
Vector512<double> poly = ((Vector512.Create(C12) * r + Vector512.Create(C11)) * r2 +
299+
Vector512.Create(C10) * r + Vector512.Create(C9)) * r8 +
300+
((Vector512.Create(C8) * r + Vector512.Create(C7)) * r2 +
301+
(Vector512.Create(C6) * r + Vector512.Create(C5))) * r4 +
302+
((Vector512.Create(C4) * r + Vector512.Create(C3)) * r2 + (r + Vector512<double>.One));
317303

318-
// Check if -709 < vx < 709
319-
if (Vector512.GreaterThanAny(Vector512.Abs(x).AsUInt64(), Vector512.Create(V_ARG_MAX)))
304+
// m = (n - j) / 64
305+
// result = polynomial * 2^m
306+
return poly * ((n + Vector512.Create(V_DP64_BIAS)) << 52).AsDouble();
307+
}
308+
else
320309
{
321-
// (x > V_EXPF_MAX) ? double.PositiveInfinity : x
322-
Vector512<double> infinityMask = Vector512.GreaterThan(x, Vector512.Create(V_EXPF_MAX));
323-
324-
ret = Vector512.ConditionalSelect(
325-
infinityMask,
326-
Vector512.Create(double.PositiveInfinity),
327-
ret
328-
);
329-
330-
// (x < V_EXPF_MIN) ? 0 : x
331-
ret = Vector512.AndNot(ret, Vector512.LessThan(x, Vector512.Create(V_EXPF_MIN)));
310+
return ScalarFallback(x);
311+
312+
static Vector512<double> ScalarFallback(Vector512<double> x) =>
313+
Vector512.Create(Math.Exp(x.GetElement(0)),
314+
Math.Exp(x.GetElement(1)),
315+
Math.Exp(x.GetElement(2)),
316+
Math.Exp(x.GetElement(3)),
317+
Math.Exp(x.GetElement(4)),
318+
Math.Exp(x.GetElement(5)),
319+
Math.Exp(x.GetElement(6)),
320+
Math.Exp(x.GetElement(7)));
332321
}
333-
334-
return ret;
335322
}
336323
}
337324

src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -163,9 +163,17 @@ protected T NextRandom(T avoid)
163163
/// the value is stored into a random position in <paramref name="x"/>, and the original
164164
/// value is subsequently restored.
165165
/// </summary>
166-
protected void RunForEachSpecialValue(Action action, BoundedMemory<T> x)
166+
protected void RunForEachSpecialValue(Action action, BoundedMemory<T> x) =>
167+
RunForEachSpecialValue(action, x, GetSpecialValues());
168+
169+
/// <summary>
170+
/// Runs the specified action for each special value. Before the action is invoked,
171+
/// the value is stored into a random position in <paramref name="x"/>, and the original
172+
/// value is subsequently restored.
173+
/// </summary>
174+
protected void RunForEachSpecialValue(Action action, BoundedMemory<T> x, IEnumerable<T> specialValues)
167175
{
168-
Assert.All(GetSpecialValues(), value =>
176+
Assert.All(specialValues, value =>
169177
{
170178
int pos = Random.Next(x.Length);
171179
T orig = x[pos];
@@ -1021,14 +1029,25 @@ public void Exp_SpecialValues()
10211029
using BoundedMemory<T> x = CreateAndFillTensor(tensorLength);
10221030
using BoundedMemory<T> destination = CreateTensor(tensorLength);
10231031

1032+
T[] additionalSpecialValues =
1033+
[
1034+
typeof(T) == typeof(float) ? (T)(object)-709.7f :
1035+
typeof(T) == typeof(double) ? (T)(object)-709.7 :
1036+
default,
1037+
1038+
typeof(T) == typeof(float) ? (T)(object)709.7f :
1039+
typeof(T) == typeof(double) ? (T)(object)709.7 :
1040+
default,
1041+
];
1042+
10241043
RunForEachSpecialValue(() =>
10251044
{
10261045
Exp(x, destination);
10271046
for (int i = 0; i < tensorLength; i++)
10281047
{
10291048
AssertEqualTolerance(Exp(x[i]), destination[i]);
10301049
}
1031-
}, x);
1050+
}, x, GetSpecialValues().Concat(additionalSpecialValues));
10321051
});
10331052
}
10341053

0 commit comments

Comments
 (0)