|
11 | 11 | from uncertaintyx.b.jax import BernsteinPoly |
12 | 12 | from uncertaintyx.b.jax import b_basis |
13 | 13 | from uncertaintyx.b.jax import b_poly |
| 14 | +from uncertaintyx.b.jax import solve |
14 | 15 |
|
15 | 16 |
|
16 | 17 | class BBasisTest(unittest.TestCase): |
@@ -277,6 +278,72 @@ def test_bernstein_poly(self): |
277 | 278 | self.assertEqual((3,) + d, g.shape) |
278 | 279 | self.assertTrue(np.all(g > 0.0)) |
279 | 280 |
|
| 281 | + def test_from_lookup_table(self): |
| 282 | + k = 5 |
| 283 | + x = np.array([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) |
| 284 | + y = np.array( # y = x ** 2 + 2 x + 3 |
| 285 | + [3.00, 3.44, 3.96, 4.56, 5.24, 6.00] |
| 286 | + ) |
| 287 | + |
| 288 | + f = BernsteinPoly.from_lookup_table((k,), (x,), y, non_negative=True) |
| 289 | + c = f.prior() |
| 290 | + self.assertEqual((k + 1,), c.shape) |
| 291 | + self.assertTrue(jnp.allclose(f.eval(c, x), y)) |
| 292 | + self.assertAlmostEqual(3.0, c[0]) |
| 293 | + self.assertAlmostEqual(3.4, c[1]) |
| 294 | + self.assertAlmostEqual(3.9, c[2]) |
| 295 | + self.assertAlmostEqual(4.5, c[3]) |
| 296 | + self.assertAlmostEqual(5.2, c[4]) |
| 297 | + self.assertAlmostEqual(6.0, c[5]) |
| 298 | + |
| 299 | + |
| 300 | +class SolveTest(unittest.TestCase): |
| 301 | + """Tests the solving function.""" |
| 302 | + |
| 303 | + def test_solve_degree_2(self): |
| 304 | + k = 2 |
| 305 | + x = jnp.array([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) |
| 306 | + y = jnp.array( # y = x ** 2 + 2 x + 3 |
| 307 | + [3.00, 3.44, 3.96, 4.56, 5.24, 6.00] |
| 308 | + ) |
| 309 | + |
| 310 | + c = solve((k,), (x,), y, non_negative=True) |
| 311 | + self.assertEqual((k + 1,), c.shape) |
| 312 | + self.assertTrue(jnp.allclose(b_poly(c, x), y)) |
| 313 | + self.assertAlmostEqual(3.0, c[0].item()) |
| 314 | + self.assertAlmostEqual(4.0, c[1].item()) |
| 315 | + self.assertAlmostEqual(6.0, c[2].item()) |
| 316 | + |
| 317 | + def test_solve_degree_5(self): |
| 318 | + k = 5 |
| 319 | + x = jnp.array([0.00, 0.20, 0.40, 0.60, 0.80, 1.00]) |
| 320 | + y = ( |
| 321 | + jnp.array( # y = x ** 2 + 2 x - 1 / 100 |
| 322 | + [0.00, 0.44, 0.96, 1.56, 2.24, 3.00] |
| 323 | + ) |
| 324 | + - 0.01 |
| 325 | + ) |
| 326 | + |
| 327 | + c = solve((k,), (x,), y) |
| 328 | + self.assertEqual((k + 1,), c.shape) |
| 329 | + self.assertTrue(jnp.allclose(b_poly(c, x), y)) |
| 330 | + self.assertAlmostEqual(-0.01, c[0].item()) |
| 331 | + self.assertAlmostEqual(0.39, c[1].item()) |
| 332 | + self.assertAlmostEqual(0.89, c[2].item()) |
| 333 | + self.assertAlmostEqual(1.49, c[3].item()) |
| 334 | + self.assertAlmostEqual(2.19, c[4].item()) |
| 335 | + self.assertAlmostEqual(2.99, c[5].item()) |
| 336 | + |
| 337 | + c = solve((k,), (x,), y, non_negative=True) |
| 338 | + self.assertEqual((k + 1,), c.shape) |
| 339 | + self.assertTrue(jnp.allclose(b_poly(c, x), y, atol=0.1)) |
| 340 | + self.assertAlmostEqual(0.00, c[0].item()) |
| 341 | + self.assertAlmostEqual(0.38, c[1].item(), places=2) |
| 342 | + self.assertAlmostEqual(0.90, c[2].item(), places=2) |
| 343 | + self.assertAlmostEqual(1.48, c[3].item(), places=2) |
| 344 | + self.assertAlmostEqual(2.19, c[4].item(), places=2) |
| 345 | + self.assertAlmostEqual(2.99, c[5].item(), places=2) |
| 346 | + |
280 | 347 |
|
281 | 348 | def to_var(u: np.ndarray) -> np.ndarray: |
282 | 349 | """ |
|
0 commit comments