Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions torchdiffeq/_impl/odeint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .bosh3 import Bosh3Solver
from .adaptive_heun import AdaptiveHeunSolver
from .fehlberg2 import Fehlberg2
from .fixed_grid import Euler, Midpoint, Heun2, Heun3, RK4
from .fixed_grid import Euler, Midpoint, Heun3, RK4
from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton
from .dopri8 import Dopri8Solver
from .scipy_wrapper import ScipyWrapperODESolver
Expand All @@ -19,7 +19,6 @@
'adaptive_heun': AdaptiveHeunSolver,
'euler': Euler,
'midpoint': Midpoint,
'heun2': Heun2,
'heun3': Heun3,
'rk4': RK4,
'explicit_adams': AdamsBashforth,
Expand All @@ -29,9 +28,7 @@
# ~Backwards compatibility
'scipy_solver': ScipyWrapperODESolver,
}


def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None):
def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None, condition=None):
"""Integrate a system of ordinary differential equations.

Solves the initial value problem for a non-stiff system of first order ODEs:
Expand Down Expand Up @@ -77,7 +74,7 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even
solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options)

if event_fn is None:
solution = solver.integrate(t)
solution = solver.integrate(t, condition)
else:
event_t, solution = solver.integrate_until_event(t[0], event_fn)
event_t = event_t.to(t)
Expand All @@ -91,7 +88,7 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even
return solution
else:
return event_t, solution


def odeint_dense(func, y0, t0, t1, *, rtol=1e-7, atol=1e-9, method=None, options=None):

Expand Down Expand Up @@ -213,4 +210,4 @@ def backward(ctx, grad_t, grad_state):

grad_state = grad_state + dstate

return None, None, None, grad_state
return None, None, None, grad_state
16 changes: 12 additions & 4 deletions torchdiffeq/_impl/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,27 @@ def _grid_constructor(func, y0, t):
def _step_func(self, func, t0, dt, t1, y0):
pass

def integrate(self, t):
def integrate(self, t, condition=None):
time_grid = self.grid_constructor(self.func, self.y0, t)
assert time_grid[0] == t[0] and time_grid[-1] == t[-1]

solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
solution[0] = self.y0
all_ys = torch.empty(len(t-1), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
all_ys[0] = self.y0

j = 1
y0 = self.y0
if condition is not None:
y0_condition = y0.clone()
for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
dt = t1 - t0
self.func.callback_step(t0, y0, dt)
dy, f0 = self._step_func(self.func, t0, dt, t1, y0)
y1 = y0 + dy
if condition is not None:
#replace all channels of y1 except ch 0 with y0_condition
y1[:,condition:,:,:] = y0_condition[:,condition:,:,:]

while j < len(t) and t1 >= t[j]:
if self.interp == "linear":
Expand All @@ -124,8 +131,9 @@ def integrate(self, t):
raise ValueError(f"Unknown interpolation method {self.interp}")
j += 1
y0 = y1

return solution
all_ys[j-1]=y0

return solution, all_ys

def integrate_until_event(self, t0, event_fn):
assert self.step_size is not None, "Event handling for fixed step solvers currently requires `step_size` to be provided in options."
Expand Down Expand Up @@ -178,4 +186,4 @@ def _linear_interp(self, t0, t1, y0, y1, t):
if t == t1:
return y1
slope = (t - t0) / (t1 - t0)
return y0 + slope * (y1 - y0)
return y0 + slope * (y1 - y0)