Skip to content

Commit 9cc1aff

Browse files
author
peng.li24
committed
refactor(numpycpp): update init, linalg, numpy headers and svml_bridge
1 parent a40eedb commit 9cc1aff

4 files changed

Lines changed: 33 additions & 10 deletions

File tree

numpycpp/detail/svml_bridge.h

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,29 +148,39 @@ NUMPY_SVML_F32(cbrt, "__svml_cbrtf16", "npy_cbrtf")
148148
NUMPY_SVML_F32(expm1, "__svml_expm1f16","npy_expm1f")
149149
NUMPY_SVML_F32(log1p, "__svml_log1pf16","npy_log1pf")
150150
151-
// pow / atan2 — SVML 2-arg
151+
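// pow / atan2 — SVML 2-arg: use the __svml_pow8 / __svml_atan28 vector symbols,
152+
// broadcast the scalars into __m512, call the SVML vector function, and extract the result. This ensures
153+
// bit-level agreement with the SVML implementation numpy uses internally (npy_pow / npy_atan2 are scalar libm fallbacks and differ by 1 ULP).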
152154
__attribute__((target("avx512f")))
153155
inline double pow_svml_f64(double x, double e) {
154-
static auto fn = (double (*)(double, double))resolve_svml("npy_pow");
155-
if (fn) return fn(x, e);
156+
static auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml("__svml_pow8");
157+
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x), _mm512_set1_pd(e)));
158+
static auto scalar_fn = (double (*)(double, double))resolve_svml("npy_pow");
159+
if (scalar_fn) return scalar_fn(x, e);
156160
return std::pow(x, e);
157161
}
158162
__attribute__((target("avx512f")))
159163
inline float pow_svml_f32(float x, float e) {
160-
static auto fn = (float (*)(float, float))resolve_svml("npy_powf");
161-
if (fn) return fn(x, e);
164+
static auto fn = (__m512 (*)(__m512, __m512))resolve_svml("__svml_powf16");
165+
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x), _mm512_set1_ps(e)));
166+
static auto scalar_fn = (float (*)(float, float))resolve_svml("npy_powf");
167+
if (scalar_fn) return scalar_fn(x, e);
162168
return std::pow(x, e);
163169
}
164170
__attribute__((target("avx512f")))
165171
inline double atan2_svml_f64(double y, double x) {
166-
static auto fn = (double (*)(double, double))resolve_svml("npy_atan2");
167-
if (fn) return fn(y, x);
172+
static auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml("__svml_atan28");
173+
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(y), _mm512_set1_pd(x)));
174+
static auto scalar_fn = (double (*)(double, double))resolve_svml("npy_atan2");
175+
if (scalar_fn) return scalar_fn(y, x);
168176
return std::atan2(y, x);
169177
}
170178
__attribute__((target("avx512f")))
171179
inline float atan2_svml_f32(float y, float x) {
172-
static auto fn = (float (*)(float, float))resolve_svml("npy_atan2f");
173-
if (fn) return fn(y, x);
180+
static auto fn = (__m512 (*)(__m512, __m512))resolve_svml("__svml_atan2f16");
181+
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(y), _mm512_set1_ps(x)));
182+
static auto scalar_fn = (float (*)(float, float))resolve_svml("npy_atan2f");
183+
if (scalar_fn) return scalar_fn(y, x);
174184
return std::atan2(y, x);
175185
}
176186

numpycpp/init.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ inline void logspace(T* dst, T start, T stop, size_t num,
9797
bool endpoint = true, T base = T(10)) {
9898
linspace(dst, start, stop, num, endpoint);
9999
for (size_t i = 0; i < num; ++i)
100-
dst[i] = std::pow(base, dst[i]);
100+
dst[i] = detail::pow(base, dst[i]);
101101
}
102102

103103
// ============================================================================

numpycpp/linalg.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <map>
2121
#include <set>
2222
#include <algorithm>
23+
#include <numeric>
2324
#include <stdexcept>
2425
#include <type_traits>
2526
#include <immintrin.h> // SSE/AVX intrinsics (SSE2 is baseline on x86_64)

numpycpp/numpy.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@
1717
// ════════════════════════════════════════════════════════════════════════════
1818
#pragma once
1919

20+
// Internal math backend — loaded first so init.h et al. can use detail::pow etc.
21+
#ifndef NUMPYCPP_INTERNAL_INCLUDE
22+
# define NUMPYCPP_INTERNAL_INCLUDE
23+
# ifdef NUMPYCPP_STD_ONLY
24+
# include "detail/std_math_backend.h"
25+
# else
26+
# include "detail/npy_math_float.h"
27+
# include "detail/svml_bridge.h"
28+
# endif
29+
# undef NUMPYCPP_INTERNAL_INCLUDE
30+
#endif
31+
2032
#include "init.h"
2133
#include "elementwise.h"
2234
#include "reduce.h"

0 commit comments

Comments
 (0)