Skip to content

Commit d463175

Browse files
committed
New: non-negative least squares
1 parent efefd31 commit d463175

2 files changed

Lines changed: 108 additions & 16 deletions

File tree

test/b/test_b_jax.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from uncertaintyx.b.jax import BernsteinPoly
1212
from uncertaintyx.b.jax import b_basis
1313
from uncertaintyx.b.jax import b_poly
14+
from uncertaintyx.b.jax import solve
1415

1516

1617
class BBasisTest(unittest.TestCase):
@@ -277,6 +278,72 @@ def test_bernstein_poly(self):
277278
self.assertEqual((3,) + d, g.shape)
278279
self.assertTrue(np.all(g > 0.0))
279280

281+
def test_from_lookup_table(self):
282+
k = 5
283+
x = np.array([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
284+
y = np.array( # y = x ** 2 + 2 x + 3
285+
[3.00, 3.44, 3.96, 4.56, 5.24, 6.00]
286+
)
287+
288+
f = BernsteinPoly.from_lookup_table((k,), (x,), y, non_negative=True)
289+
c = f.prior()
290+
self.assertEqual((k + 1,), c.shape)
291+
self.assertTrue(jnp.allclose(f.eval(c, x), y))
292+
self.assertAlmostEqual(3.0, c[0])
293+
self.assertAlmostEqual(3.4, c[1])
294+
self.assertAlmostEqual(3.9, c[2])
295+
self.assertAlmostEqual(4.5, c[3])
296+
self.assertAlmostEqual(5.2, c[4])
297+
self.assertAlmostEqual(6.0, c[5])
298+
299+
300+
class SolveTest(unittest.TestCase):
301+
"""Tests the solving function."""
302+
303+
def test_solve_degree_2(self):
304+
k = 2
305+
x = jnp.array([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
306+
y = jnp.array( # y = x ** 2 + 2 x + 3
307+
[3.00, 3.44, 3.96, 4.56, 5.24, 6.00]
308+
)
309+
310+
c = solve((k,), (x,), y, non_negative=True)
311+
self.assertEqual((k + 1,), c.shape)
312+
self.assertTrue(jnp.allclose(b_poly(c, x), y))
313+
self.assertAlmostEqual(3.0, c[0].item())
314+
self.assertAlmostEqual(4.0, c[1].item())
315+
self.assertAlmostEqual(6.0, c[2].item())
316+
317+
def test_solve_degree_5(self):
318+
k = 5
319+
x = jnp.array([0.00, 0.20, 0.40, 0.60, 0.80, 1.00])
320+
y = (
321+
jnp.array( # y = x ** 2 + 2 x - 1 / 100
322+
[0.00, 0.44, 0.96, 1.56, 2.24, 3.00]
323+
)
324+
- 0.01
325+
)
326+
327+
c = solve((k,), (x,), y)
328+
self.assertEqual((k + 1,), c.shape)
329+
self.assertTrue(jnp.allclose(b_poly(c, x), y))
330+
self.assertAlmostEqual(-0.01, c[0].item())
331+
self.assertAlmostEqual(0.39, c[1].item())
332+
self.assertAlmostEqual(0.89, c[2].item())
333+
self.assertAlmostEqual(1.49, c[3].item())
334+
self.assertAlmostEqual(2.19, c[4].item())
335+
self.assertAlmostEqual(2.99, c[5].item())
336+
337+
c = solve((k,), (x,), y, non_negative=True)
338+
self.assertEqual((k + 1,), c.shape)
339+
self.assertTrue(jnp.allclose(b_poly(c, x), y, atol=0.1))
340+
self.assertAlmostEqual(0.00, c[0].item())
341+
self.assertAlmostEqual(0.38, c[1].item(), places=2)
342+
self.assertAlmostEqual(0.90, c[2].item(), places=2)
343+
self.assertAlmostEqual(1.48, c[3].item(), places=2)
344+
self.assertAlmostEqual(2.19, c[4].item(), places=2)
345+
self.assertAlmostEqual(2.99, c[5].item(), places=2)
346+
280347

281348
def to_var(u: np.ndarray) -> np.ndarray:
282349
"""

uncertaintyx/b/jax.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import jax.numpy as jnp
99
import jax.numpy.linalg as jli
1010
import numpy as np
11+
import optax
12+
import optimistix
1113
from jax import Array
1214
from jax.scipy.special import gammaln
1315

@@ -97,7 +99,7 @@ def b_basis(k: int, x: Array) -> Array:
9799

98100

99101
@jax.jit(static_argnums=(0,), static_argnames=("non_negative", "max_steps"))
100-
def linear_solve(
102+
def solve(
101103
k: tuple[int, ...],
102104
x: tuple[Array, ...],
103105
y: Array,
@@ -123,37 +125,60 @@ def linear_solve(
123125

124126
def nnls(c: Array):
125127
"""
126-
Non-negative linear least squares solver (yet incomplete).
128+
Non-negative least-squares solver.
129+
130+
Uses a quadratic transformation and an L-BFGS minimizer.
127131
"""
128132

129133
def hvp(c: Array):
130134
"""The Hessian-vector product."""
131-
hvp = c
135+
res = c
132136
for i in range(N):
133137
G = grams[i] # noqa: N806
134-
hvp = jnp.tensordot(hvp, G, axes=(0, 1))
135-
return hvp
136-
137-
return c
138+
res = jnp.tensordot(res, G, axes=(0, 1))
139+
return res
140+
141+
def misfit(u: Array, _: None = None) -> Array:
142+
"""The misfit function with quadratic transformation."""
143+
c_ = jnp.square(u)
144+
return 0.5 * jnp.sum(c_ * hvp(c_)) - jnp.sum(c_ * rhs)
145+
146+
def make_minimizer():
147+
"""Returns the minimizer."""
148+
return optimistix.OptaxMinimiser(
149+
optax.lbfgs(), atol=atol, rtol=rtol, norm=optimistix.max_norm
150+
)
151+
152+
u = jnp.sqrt(jnp.maximum(0.0, c))
153+
minimum = optimistix.minimise(
154+
misfit, make_minimizer(), u, max_steps=max_steps, throw=False
155+
)
156+
return jnp.square(minimum.value)
138157

139158
N = len(k) # noqa: N806
140159
bases = [b_basis(k[i], x[i]) for i in range(N)]
141160
grams = [jnp.dot(B, B.T) for B in bases] # noqa: N806
142161

143162
# compute the right hand side of the normal equation
163+
rhs = y
144164
for i in range(N):
145165
B = bases[i] # noqa: N806
146-
y = jnp.tensordot(y, B, axes=(0, 1))
166+
rhs = jnp.tensordot(rhs, B, axes=(0, 1))
147167
# solve the normal equation
148-
c = y
168+
c_unconstrained = rhs
149169
for i in range(N):
150170
G = grams[i] # noqa: N806
151-
c = jnp.tensordot(c, jli.inv(G), axes=(0, 1))
152-
# solve with non-negativity constraint, if requested
153-
if non_negative:
154-
c = nnls(jnp.maximum(0.0, c))
155-
156-
return c
171+
c_unconstrained = jnp.tensordot(
172+
c_unconstrained, jli.pinv(G), axes=(0, 1)
173+
)
174+
# solve iteratively with non-negativity constraint, if needed
175+
nnls_needed = non_negative and jnp.any(c_unconstrained < 0.0)
176+
return jax.lax.cond(
177+
nnls_needed,
178+
nnls,
179+
lambda _: _, # forwards the unconstrained solution
180+
c_unconstrained,
181+
)
157182

158183

159184
@jax.jit
@@ -380,7 +405,7 @@ def from_lookup_table(
380405
b = _upper_bounds(b, x)
381406
x_ = tuple(jnp.asarray((x_ - a) / (b - a)) for x_ in x)
382407
y_ = jnp.asarray(y)
383-
c_ = linear_solve(
408+
c_ = solve(
384409
k,
385410
x_,
386411
y_,

0 commit comments

Comments
 (0)