From 0fae86f99ac9bd7a8f513f205239274cd8d0e20d Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 21 May 2025 18:17:05 -0700 Subject: [PATCH 1/2] Update misc.py --- torchdiffeq/_impl/misc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchdiffeq/_impl/misc.py b/torchdiffeq/_impl/misc.py index 685cc8da2..761a2e7bc 100644 --- a/torchdiffeq/_impl/misc.py +++ b/torchdiffeq/_impl/misc.py @@ -104,7 +104,8 @@ def _assert_one_dimensional(name, t): def _assert_increasing(name, t): - assert (t[1:] > t[:-1]).all(), '{} must be strictly increasing or decreasing'.format(name) + cond = (t[1:] > t[:-1]).all().item() + torch._check(cond, f"{name} must be strictly increasing or decreasing") def _assert_floating(name, t): @@ -380,7 +381,8 @@ def _check_timelike(name, timelike, can_grad): if not can_grad: assert not timelike.requires_grad, "{} cannot require gradient".format(name) diff = timelike[1:] > timelike[:-1] - assert diff.all() or (~diff).all(), '{} must be strictly increasing or decreasing'.format(name) + cond = torch.logical_or(diff.all(), (~diff).all()).item() + torch._check(cond, f"{name} must be strictly increasing or decreasing") def _flip_option(options, option_name): From b7e9143da11d9d5679c30f53b411274ab9838a9b Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 21 May 2025 18:17:52 -0700 Subject: [PATCH 2/2] Update solvers.py --- torchdiffeq/_impl/solvers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py index cc64218b1..219e55a11 100644 --- a/torchdiffeq/_impl/solvers.py +++ b/torchdiffeq/_impl/solvers.py @@ -101,7 +101,8 @@ def _step_func(self, func, t0, dt, t1, y0): def integrate(self, t): time_grid = self.grid_constructor(self.func, self.y0, t) - assert time_grid[0] == t[0] and time_grid[-1] == t[-1] + torch._check((time_grid[0] == t[0]).item()) + torch._check((time_grid[-1] == t[-1]).item()) solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) solution[0] = self.y0