@@ -148,29 +148,39 @@ NUMPY_SVML_F32(cbrt, "__svml_cbrtf16", "npy_cbrtf")
148148NUMPY_SVML_F32(expm1, " __svml_expm1f16" ," npy_expm1f" )
149149NUMPY_SVML_F32(log1p, " __svml_log1pf16" ," npy_log1pf" )
150150
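// pow / atan2 — two-argument SVML functions: use the vector symbols
// (e.g. __svml_pow8 / __svml_atan28), broadcast the scalars into a 512-bit
// vector, call the SVML vector function, and extract the result. This keeps the
// results bit-identical to the SVML implementation NumPy uses internally
// (npy_pow / npy_atan2 are scalar libm fallbacks and can differ by 1 ULP).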
152154__attribute__((target(" avx512f" )))
153155inline double pow_svml_f64(double x, double e) {
154- static auto fn = (double (*)(double, double))resolve_svml(" npy_pow" );
155- if (fn) return fn(x, e);
156+ static auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml(" __svml_pow8" );
157+ if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x), _mm512_set1_pd(e)));
158+ static auto scalar_fn = (double (*)(double, double))resolve_svml(" npy_pow" );
159+ if (scalar_fn) return scalar_fn(x, e);
156160 return std::pow(x, e);
157161}
158162__attribute__((target(" avx512f" )))
159163inline float pow_svml_f32(float x, float e) {
160- static auto fn = (float (*)(float, float))resolve_svml(" npy_powf" );
161- if (fn) return fn(x, e);
164+ static auto fn = (__m512 (*)(__m512, __m512))resolve_svml(" __svml_powf16" );
165+ if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x), _mm512_set1_ps(e)));
166+ static auto scalar_fn = (float (*)(float, float))resolve_svml(" npy_powf" );
167+ if (scalar_fn) return scalar_fn(x, e);
162168 return std::pow(x, e);
163169}
164170__attribute__((target(" avx512f" )))
165171inline double atan2_svml_f64(double y, double x) {
166- static auto fn = (double (*)(double, double))resolve_svml(" npy_atan2" );
167- if (fn) return fn(y, x);
172+ static auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml(" __svml_atan28" );
173+ if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(y), _mm512_set1_pd(x)));
174+ static auto scalar_fn = (double (*)(double, double))resolve_svml(" npy_atan2" );
175+ if (scalar_fn) return scalar_fn(y, x);
168176 return std::atan2(y, x);
169177}
170178__attribute__((target(" avx512f" )))
171179inline float atan2_svml_f32(float y, float x) {
172- static auto fn = (float (*)(float, float))resolve_svml(" npy_atan2f" );
173- if (fn) return fn(y, x);
180+ static auto fn = (__m512 (*)(__m512, __m512))resolve_svml(" __svml_atan2f16" );
181+ if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(y), _mm512_set1_ps(x)));
182+ static auto scalar_fn = (float (*)(float, float))resolve_svml(" npy_atan2f" );
183+ if (scalar_fn) return scalar_fn(y, x);
174184 return std::atan2(y, x);
175185}
176186
0 commit comments