Skip to content

Commit 55fb054

Browse files
author
peng.li24
committed
fix: linalg dot/norm — 0 ULP via OpenBLAS bridge
np.dot(a,b) and np.linalg.norm(a) internally call BLAS (OpenBLAS ILP64 sdot_64_/ddot_64_) — our previous pairwise_sum implementation gave the same mathematical result but different bit patterns (~70% mismatch rate). Changes: - numpy/blas_bridge.h (new): auto-discovers libopenblas64_p*.so from /proc/self/maps (same pattern as svml_bridge.h); provides detail::blas_ops<T>::dot/norm using ILP64 Fortran calling convention - numpy/core.h: numpy::dot<T> now calls blas_ops<T>::dot - numpy/linalg.h: linalg::norm(axis=None) now calls blas_ops<T>::norm = sqrt(blas_dot(x, x)) — matches np.linalg.norm exactly - linalg::norm_axis already 0 ULP (uses numpy pairwise sum) — unchanged - tests/test_all.py: test_dot and test_norm_* now compare against np.dot / np.linalg.norm (correct BLAS references) ULP scan (1000 random arrays, sizes 1-300): 0/1000 mismatches for dot f32/f64, norm(axis=None) f32/f64, norm(axis=1) f32/f64. All 548 tests pass.
1 parent dc21f03 commit 55fb054

4 files changed

Lines changed: 128 additions & 16 deletions

File tree

numpy/blas_bridge.h

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// INTERNAL HEADER — auto-included by core.h and linalg.h.
2+
// DO NOT include directly.
3+
//
4+
// BLAS bridge — bit-exact dot/norm vs numpy's OpenBLAS-backed np.dot /
5+
// np.linalg.norm (without axis).
6+
//
7+
// numpy routes 1-D dot and Frobenius norm through BLAS (OpenBLAS ILP64):
8+
// np.dot(a, b) → sdot_64_ / ddot_64_
9+
// np.linalg.norm(a) → sqrt(x.dot(x)) → same sdot_64_ / ddot_64_
10+
//
11+
// np.linalg.norm(a, axis=k) uses numpy's own pairwise sum — already
12+
// handled by norm_axis() in core.h, no BLAS needed.
13+
//
14+
// The OpenBLAS library path is auto-discovered from /proc/self/maps
15+
// (numpy loads it when imported), so no compile-time link flag is needed.
16+
//
17+
// ILP64 Fortran calling convention (OpenBLAS built with BLAS_SYMBOL_SUFFIX=64_):
18+
// sdot_64_(n*, x*, incx*, y*, incy*) → float (return in xmm0)
19+
// ddot_64_(n*, x*, incx*, y*, incy*) → double (return in xmm0)
20+
//
21+
// Fallback (if OpenBLAS not discovered): sequential accumulation.
22+
23+
#pragma once
24+
25+
#include <cstdint>
26+
#include <cmath>
27+
#include <dlfcn.h>
28+
#include <fstream>
29+
#include <string>
30+
31+
namespace numpy {
32+
namespace detail {
33+
34+
inline void* g_blas_handle = nullptr;
35+
36+
inline const char* find_openblas_path() {
37+
static std::string path;
38+
static bool tried = false;
39+
if (tried) return path.empty() ? nullptr : path.c_str();
40+
tried = true;
41+
42+
std::ifstream maps("/proc/self/maps");
43+
std::string line;
44+
while (std::getline(maps, line)) {
45+
if (line.find("libopenblas") != std::string::npos &&
46+
line.find(".so") != std::string::npos) {
47+
auto pos = line.rfind('/');
48+
auto start = line.rfind(' ', pos);
49+
if (start != std::string::npos && pos != std::string::npos) {
50+
path = line.substr(start + 1);
51+
// trim trailing whitespace / newline
52+
while (!path.empty() && (path.back() == '\n' || path.back() == '\r'
53+
|| path.back() == ' '))
54+
path.pop_back();
55+
break;
56+
}
57+
}
58+
}
59+
return path.empty() ? nullptr : path.c_str();
60+
}
61+
62+
inline void* resolve_blas(const char* sym) {
63+
if (!g_blas_handle) {
64+
const char* path = find_openblas_path();
65+
if (path) g_blas_handle = dlopen(path, RTLD_NOLOAD | RTLD_LAZY);
66+
}
67+
return g_blas_handle ? dlsym(g_blas_handle, sym) : nullptr;
68+
}
69+
70+
// ILP64 Fortran function types (all int args are int64_t by pointer)
71+
using sdot64_fn = float (const int64_t*, const float*, const int64_t*,
72+
const float*, const int64_t*);
73+
using ddot64_fn = double (const int64_t*, const double*, const int64_t*,
74+
const double*, const int64_t*);
75+
76+
inline float blas_sdot(const float* x, const float* y, size_t n) {
77+
static auto fn = (sdot64_fn*)resolve_blas("sdot_64_");
78+
if (__builtin_expect(fn != nullptr, 1)) {
79+
const int64_t ni = static_cast<int64_t>(n), inc = 1;
80+
return fn(&ni, x, &inc, y, &inc);
81+
}
82+
// Fallback: sequential accumulation
83+
float r = 0.0f;
84+
for (size_t i = 0; i < n; ++i) r += x[i] * y[i];
85+
return r;
86+
}
87+
88+
inline double blas_ddot(const double* x, const double* y, size_t n) {
89+
static auto fn = (ddot64_fn*)resolve_blas("ddot_64_");
90+
if (__builtin_expect(fn != nullptr, 1)) {
91+
const int64_t ni = static_cast<int64_t>(n), inc = 1;
92+
return fn(&ni, x, &inc, y, &inc);
93+
}
94+
double r = 0.0;
95+
for (size_t i = 0; i < n; ++i) r += x[i] * y[i];
96+
return r;
97+
}
98+
99+
// Template dispatcher
100+
template<typename T> struct blas_ops;
101+
102+
template<> struct blas_ops<float> {
103+
static float dot (const float* x, const float* y, size_t n) { return blas_sdot(x, y, n); }
104+
static float norm(const float* x, size_t n) { return std::sqrt(blas_sdot(x, x, n)); }
105+
};
106+
template<> struct blas_ops<double> {
107+
static double dot (const double* x, const double* y, size_t n) { return blas_ddot(x, y, n); }
108+
static double norm(const double* x, size_t n) { return std::sqrt(blas_ddot(x, x, n)); }
109+
};
110+
111+
} // namespace detail
112+
} // namespace numpy

numpy/core.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <stdexcept>
2424

2525
#include "svml_bridge.h"
26+
#include "blas_bridge.h"
2627

2728
namespace numpy {
2829

@@ -974,15 +975,12 @@ inline T norm_sq(const T* data, size_t n) {
974975
return result;
975976
}
976977

977-
/// numpy.dot(a, b, out=None) — pairwise sum, matches np.sum(a*b)
978+
/// numpy.dot(a, b, out=None)
979+
/// Routes through OpenBLAS sdot_64_/ddot_64_ (auto-discovered via /proc/self/maps)
980+
/// for bit-exact match with np.dot(a, b) which calls BLAS internally.
978981
template<typename T>
979982
inline T dot(const T* a, const T* b, size_t n) {
980-
T buf[NUMPY_SMALL_STACK];
981-
T* prods = (n <= NUMPY_SMALL_STACK) ? buf : new T[n];
982-
for (size_t i = 0; i < n; ++i) prods[i] = a[i] * b[i];
983-
T result = pairwise_sum(prods, n);
984-
if (n > NUMPY_SMALL_STACK) delete[] prods;
985-
return result;
983+
return detail::blas_ops<T>::dot(a, b, n);
986984
}
987985

988986
/// numpy.linalg.norm(x, ord=None, axis=N, keepdims=False) — N-D

numpy/linalg.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
namespace numpy {
1010
namespace linalg {
1111

12-
/// numpy.linalg.norm(x, ord=None, axis=None, keepdims=False) — frobenius/vector
13-
// Uses norm_sq (pairwise sum) → matches np.sqrt(np.sum(x**2)).
14-
// For float32, norm_sq() and sqrt() stay in float32.
12+
/// numpy.linalg.norm(x, ord=None, axis=None, keepdims=False) — vector / Frobenius
13+
// np.linalg.norm(a) internally computes sqrt(a.dot(a)) via BLAS sdot/ddot.
14+
// We call the same OpenBLAS routine (auto-discovered) for bit-exact match.
1515
template<typename T>
1616
inline T norm(const T* data, size_t n) {
17-
T sqnorm = numpy::norm_sq(data, n); // pairwise sum of squares
18-
return std::sqrt(sqnorm);
17+
return numpy::detail::blas_ops<T>::norm(data, n);
1918
}
2019

2120
/// numpy.linalg.norm(x, ord=None, axis=N, keepdims=False) — N-D

tests/test_all.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -958,11 +958,13 @@ def test_to_vector_bool(cpp):
958958

959959
def test_norm_1d(cpp, dtype):
960960
a = random_array((100,), dtype=dtype)
961-
assert_bit_aligned(dtype(cpp.linalg.norm(a)), np.sqrt(np.sum(a * a)), "linalg.norm 1d")
961+
# np.linalg.norm internally computes sqrt(a.dot(a)) via BLAS
962+
assert_bit_aligned(dtype(cpp.linalg.norm(a)), dtype(np.linalg.norm(a)), "linalg.norm 1d")
962963

963964
def test_norm_2d(cpp, dtype):
964965
a = random_array((5, 4), dtype=dtype)
965-
assert_bit_aligned(dtype(cpp.linalg.norm(a)), np.sqrt(np.sum(a * a)), "linalg.norm 2d")
966+
# Frobenius norm: same BLAS path as 1d
967+
assert_bit_aligned(dtype(cpp.linalg.norm(a)), dtype(np.linalg.norm(a)), "linalg.norm 2d")
966968

967969
def test_norm_zero(cpp, dtype):
968970
a = np.zeros((100,), dtype=dtype)
@@ -981,12 +983,13 @@ def test_norm_1d_fallback(cpp, dtype):
981983
def test_dot(cpp, dtype):
982984
a = random_array((5,), dtype=dtype)
983985
b = random_array((5,), seed=99, dtype=dtype)
984-
assert_bit_aligned(cpp.dot(a, b), np.sum(a * b), "dot")
986+
# np.dot routes through BLAS sdot/ddot
987+
assert_bit_aligned(cpp.dot(a, b), np.dot(a, b), "dot")
985988

986989
def test_dot_orthogonal(cpp, dtype):
987990
a = np.array([1.0, 0.0], dtype=dtype)
988991
b = np.array([0.0, 1.0], dtype=dtype)
989-
assert_bit_aligned(cpp.dot(a, b), np.sum(a * b), "dot orthogonal")
992+
assert_bit_aligned(cpp.dot(a, b), np.dot(a, b), "dot orthogonal")
990993

991994

992995
# ============================================================================

0 commit comments

Comments
 (0)