diff --git a/torchdiffeq/_impl/misc.py b/torchdiffeq/_impl/misc.py index 685cc8da..761a2e7b 100644 --- a/torchdiffeq/_impl/misc.py +++ b/torchdiffeq/_impl/misc.py @@ -104,7 +104,8 @@ def _assert_one_dimensional(name, t): def _assert_increasing(name, t): - assert (t[1:] > t[:-1]).all(), '{} must be strictly increasing or decreasing'.format(name) + cond = (t[1:] > t[:-1]).all().item() + torch._check(cond, f"{name} must be strictly increasing or decreasing") def _assert_floating(name, t): @@ -380,7 +381,8 @@ def _check_timelike(name, timelike, can_grad): if not can_grad: assert not timelike.requires_grad, "{} cannot require gradient".format(name) diff = timelike[1:] > timelike[:-1] - assert diff.all() or (~diff).all(), '{} must be strictly increasing or decreasing'.format(name) + cond = torch.logical_or(diff.all(), (~diff).all()).item() + torch._check(cond, f"{name} must be strictly increasing or decreasing") def _flip_option(options, option_name): diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py index cc64218b..219e55a1 100644 --- a/torchdiffeq/_impl/solvers.py +++ b/torchdiffeq/_impl/solvers.py @@ -101,7 +101,8 @@ def _step_func(self, func, t0, dt, t1, y0): def integrate(self, t): time_grid = self.grid_constructor(self.func, self.y0, t) - assert time_grid[0] == t[0] and time_grid[-1] == t[-1] + torch._check((time_grid[0] == t[0]).item()) + torch._check((time_grid[-1] == t[-1]).item()) solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) solution[0] = self.y0