From 377740549d4d5b519d2e938de5cffe1672dc91d9 Mon Sep 17 00:00:00 2001 From: rayanirban Date: Sat, 14 Dec 2024 12:01:24 +0100 Subject: [PATCH 1/4] added an option to handle conditional input to the func --- torchdiffeq/_impl/solvers.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py index cc64218b1..387d96d59 100644 --- a/torchdiffeq/_impl/solvers.py +++ b/torchdiffeq/_impl/solvers.py @@ -99,7 +99,7 @@ def _grid_constructor(func, y0, t): def _step_func(self, func, t0, dt, t1, y0): pass - def integrate(self, t): + def integrate(self, t, condition=None): time_grid = self.grid_constructor(self.func, self.y0, t) assert time_grid[0] == t[0] and time_grid[-1] == t[-1] @@ -108,11 +108,16 @@ def integrate(self, t): j = 1 y0 = self.y0 + if condition is not None: + y0_condition = y0.clone() for t0, t1 in zip(time_grid[:-1], time_grid[1:]): dt = t1 - t0 self.func.callback_step(t0, y0, dt) dy, f0 = self._step_func(self.func, t0, dt, t1, y0) y1 = y0 + dy + if condition is not None: + #replace all channels of y1 except ch 0 with y0_condition + y1[:,condition:,:,:] = y0_condition[:,condition:,:,:] while j < len(t) and t1 >= t[j]: if self.interp == "linear": From b7dd3f601d0819ccfc5c9cf4dd5dcd2cafe7db51 Mon Sep 17 00:00:00 2001 From: rayanirban Date: Sun, 15 Dec 2024 12:11:50 +0100 Subject: [PATCH 2/4] added an option to handle conditional input to the func --- torchdiffeq/_impl/odeint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index 07f8666a8..690f8de6d 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -31,7 +31,7 @@ } -def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None): +def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None, condition=None): """Integrate a system of ordinary differential equations. Solves the initial value problem for a non-stiff system of first order ODEs: @@ -77,7 +77,7 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options) if event_fn is None: - solution = solver.integrate(t) + solution = solver.integrate(t, condition) else: event_t, solution = solver.integrate_until_event(t[0], event_fn) event_t = event_t.to(t) From 74a3924afd0f611848d85bd30c7ceb3e870ddf5a Mon Sep 17 00:00:00 2001 From: rayanirban Date: Sun, 23 Feb 2025 11:16:58 +0100 Subject: [PATCH 3/4] also returning all the ys during integration --- torchdiffeq/_impl/odeint.py | 9 +++------ torchdiffeq/_impl/solvers.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py index 690f8de6d..c795e29c4 100644 --- a/torchdiffeq/_impl/odeint.py +++ b/torchdiffeq/_impl/odeint.py @@ -4,7 +4,7 @@ from .bosh3 import Bosh3Solver from .adaptive_heun import AdaptiveHeunSolver from .fehlberg2 import Fehlberg2 -from .fixed_grid import Euler, Midpoint, Heun2, Heun3, RK4 +from .fixed_grid import Euler, Midpoint, Heun3, RK4 from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton from .dopri8 import Dopri8Solver from .scipy_wrapper import ScipyWrapperODESolver @@ -19,7 +19,6 @@ 'adaptive_heun': AdaptiveHeunSolver, 'euler': Euler, 'midpoint': Midpoint, - 'heun2': Heun2, 'heun3': Heun3, 'rk4': RK4, 'explicit_adams': AdamsBashforth, @@ -29,8 +28,6 @@ # ~Backwards compatibility 'scipy_solver': ScipyWrapperODESolver, } - - def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None, condition=None): """Integrate a system of ordinary differential equations. @@ -91,7 +88,7 @@ def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, even return solution else: return event_t, solution - + def odeint_dense(func, y0, t0, t1, *, rtol=1e-7, atol=1e-9, method=None, options=None): @@ -213,4 +210,4 @@ def backward(ctx, grad_t, grad_state): grad_state = grad_state + dstate - return None, None, None, grad_state + return None, None, None, grad_state \ No newline at end of file diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py index 387d96d59..cc8dacb0d 100644 --- a/torchdiffeq/_impl/solvers.py +++ b/torchdiffeq/_impl/solvers.py @@ -105,6 +105,8 @@ def integrate(self, t, condition=None): solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) solution[0] = self.y0 + all_ys = torch.empty(len(t-1), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) + all_ys[0] = self.y0 j = 1 y0 = self.y0 @@ -114,9 +116,9 @@ def integrate(self, t, condition=None): dt = t1 - t0 self.func.callback_step(t0, y0, dt) dy, f0 = self._step_func(self.func, t0, dt, t1, y0) - y1 = y0 + dy + y1 = y0 + dy[:,0:1,...] if condition is not None: - #replace all channels of y1 except ch 0 with y0_condition + #replace all channels of y1 except ch 0 with y0_condition y1[:,condition:,:,:] = y0_condition[:,condition:,:,:] while j < len(t) and t1 >= t[j]: @@ -129,8 +131,9 @@ def integrate(self, t, condition=None): raise ValueError(f"Unknown interpolation method {self.interp}") j += 1 y0 = y1 - - return solution + all_ys[j-1]=y0 + + return solution, all_ys def integrate_until_event(self, t0, event_fn): assert self.step_size is not None, "Event handling for fixed step solvers currently requires `step_size` to be provided in options." @@ -183,4 +186,4 @@ def _linear_interp(self, t0, t1, y0, y1, t): if t == t1: return y1 slope = (t - t0) / (t1 - t0) - return y0 + slope * (y1 - y0) + return y0 + slope * (y1 - y0) \ No newline at end of file From e8a6aedf6a18b90356db6b3357fb965120eb18d4 Mon Sep 17 00:00:00 2001 From: rayanirban Date: Tue, 10 Jun 2025 18:25:56 +0200 Subject: [PATCH 4/4] removed unnecessary slicing --- torchdiffeq/_impl/solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdiffeq/_impl/solvers.py b/torchdiffeq/_impl/solvers.py index cc8dacb0d..8cbe7a779 100644 --- a/torchdiffeq/_impl/solvers.py +++ b/torchdiffeq/_impl/solvers.py @@ -116,7 +116,7 @@ def integrate(self, t, condition=None): dt = t1 - t0 self.func.callback_step(t0, y0, dt) dy, f0 = self._step_func(self.func, t0, dt, t1, y0) - y1 = y0 + dy[:,0:1,...] + y1 = y0 + dy if condition is not None: #replace all channels of y1 except ch 0 with y0_condition y1[:,condition:,:,:] = y0_condition[:,condition:,:,:]