Skip to content

Commit 8e45c51

Browse files
author
peng.li24
committed
feat: add comprehensive special-value tests + fix NaN/signed-zero in all AVX-512 paths
- avx512_loops.h: add explicit NaN passthrough blend to exp<f32>, sin<f32>, cos<f32> after polynomial computation, guaranteeing same NaN bit pattern as numpy's scalar npy_expf / npy_sinf / npy_cosf (which return x unchanged for NaN input) - svml_bridge.h: fix sin(±0)=±0 for both f32 and f64 scalar paths * sin_f32(): npy_sinf polynomial fma(sp,r,r) with r=±0 gives +0 by IEEE 754 RN * sin_f64(): SVML broadcast scalar path __svml_sin8(-0) returns +0 Both fixed with cheap branch: if (x==0 && r==0) return x - test_all.py section 16 – 206 new bit-exact special-value tests: * NaN passthrough: all 21 unary math functions × f32/f64 × sizes 1,16,17 * Mixed NaN/finite (17 elements): NaN must not corrupt neighbours in SIMD path * Signed zero: sin(±0)=±0, cos(±0)=1, log(-0)=-inf, exp(±0)=1 * Infinity: exp(±inf), log(+inf), sqrt(+inf), sin/cos(±inf)→NaN * Domain errors: log(neg), sqrt(neg), arcsin/arccos(|x|>1) → NaN bit-exact * sign(NaN)=NaN, sign(±inf)=±1, sign(±0)=0 * unwrap NaN propagation (mid, leading, all-NaN) * linalg: norm/dot with NaN or Inf inputs * AVX-512 boundary sizes 15/16/17/32 for exp/log/sin/cos - test_all.py check_bit_aligned: upgraded to uint-view bit comparison for float arrays with matching dtype so NaN==NaN passes at bit level; dtype-mismatch case (C++ returns float64 for float32 input) falls back to numeric equality
1 parent 55fb054 commit 8e45c51

3 files changed

Lines changed: 370 additions & 9 deletions

File tree

numpy/avx512_loops.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ inline void exp<float>(const float* __restrict__ s,
200200
poly, Vinf);
201201
poly = _mm512_mask_blend_ps(_mm512_cmp_ps_mask(x, VXmin, _CMP_LE_OQ),
202202
poly, _mm512_setzero_ps());
203+
// NaN passthrough: ordered comparisons above return false for NaN → poly holds
204+
// polynomial-derived NaN; blend back original x to guarantee bit-exact match
205+
// with numpy's scalar npy_expf which returns x unchanged for NaN input.
206+
__mmask16 is_nan_e = _mm512_cmp_ps_mask(x, x, _CMP_UNORD_Q);
207+
poly = _mm512_mask_blend_ps(is_nan_e, poly, x);
203208
_mm512_storeu_ps(d + i, poly);
204209
}
205210
for (; i < n; ++i) d[i] = detail::exp_npy_f32(s[i]);
@@ -366,9 +371,14 @@ inline void sin<float>(const float* __restrict__ s,
366371
if (!((inr >> j) & 1)) rt[j] = std::sin(xt[j]);
367372
result = _mm512_loadu_ps(rt);
368373
}
374+
// NaN passthrough: blend back original x after fallback so NaN output = NaN
375+
// input (bit-exact with numpy's scalar npy_sinf which returns x for NaN).
376+
__mmask16 is_nan_s = _mm512_cmp_ps_mask(x, x, _CMP_UNORD_Q);
377+
result = _mm512_mask_blend_ps(is_nan_s, result, x);
369378
_mm512_storeu_ps(d + i, result);
370379
}
371-
for (; i < n; ++i) d[i] = detail::sin_npy_f32(s[i]);
380+
// sin_f32 adds signed-zero fix: sin(±0)=±0 (npy_sinf polynomial gives +0 for -0).
381+
for (; i < n; ++i) d[i] = detail::sin_f32(s[i]);
372382
}
373383

374384
// ----------------------------------------------------------------------------
@@ -434,6 +444,9 @@ inline void cos<float>(const float* __restrict__ s,
434444
if (!((inr >> j) & 1)) rt[j] = std::cos(xt[j]);
435445
result = _mm512_loadu_ps(rt);
436446
}
447+
// NaN passthrough: blend back original x (bit-exact with numpy scalar npy_cosf).
448+
__mmask16 is_nan_c = _mm512_cmp_ps_mask(x, x, _CMP_UNORD_Q);
449+
result = _mm512_mask_blend_ps(is_nan_c, result, x);
437450
_mm512_storeu_ps(d + i, result);
438451
}
439452
for (; i < n; ++i) d[i] = detail::cos_npy_f32(s[i]);

numpy/svml_bridge.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,13 @@ inline float atan2_npy_f32(float y, float x) {
274274

275275
DISPATCH_F64(exp)
276276
DISPATCH_F64(log)
277-
DISPATCH_F64(sin)
277+
// sin_f64: custom — SVML scalar broadcast path loses signed zero (sin(-0)→+0).
278+
// IEEE 754 requires sin(±0) = ±0; preserve sign of zero explicitly.
279+
inline double sin_f64(double x) {
280+
double r = cpu_has_avx512f() ? sin_svml_f64(x) : sin_npy_f64(x);
281+
if (__builtin_expect(x == 0.0 && r == 0.0, 0)) return x; // ±0 → ±0
282+
return r;
283+
}
278284
DISPATCH_F64(cos)
279285
DISPATCH_F64(tan)
280286
DISPATCH_F64(asin)
@@ -301,7 +307,13 @@ DISPATCH_F32(log1p)
301307
// (npy_math_float.h), NOT SVML. These are bit-exact on all architectures.
302308
inline float exp_f32(float x) { return exp_npy_f32(x); }
303309
inline float log_f32(float x) { return log_npy_f32(x); }
304-
inline float sin_f32(float x) { return sin_npy_f32(x); }
310+
// sin_f32: npy_sinf polynomial computes fma(sp,r,r) with r=±0 → +0 (IEEE RN rule),
311+
// losing the sign. Restore: IEEE 754 mandates sin(±0) = ±0.
312+
inline float sin_f32(float x) {
313+
float r = sin_npy_f32(x);
314+
if (__builtin_expect(x == 0.0f && r == 0.0f, 0)) return x; // sin(±0)=±0
315+
return r;
316+
}
305317
inline float cos_f32(float x) { return cos_npy_f32(x); }
306318

307319
// pow / atan2 dispatchers

0 commit comments

Comments
 (0)