diff --git a/examples/pseudospectral_demo.py b/examples/pseudospectral_demo.py
new file mode 100644
index 000000000..fcfa25af1
--- /dev/null
+++ b/examples/pseudospectral_demo.py
@@ -0,0 +1,51 @@
+import torch
+import matplotlib.pyplot as plt
+from torchdiffeq import odeint
+
+
+class HarmonicOscillator:
+    """Simple harmonic oscillator: dx/dt = v, dv/dt = -k*x."""
+
+    def __init__(self, k=1.0):
+        self.k = k
+
+    def __call__(self, t, y):
+        # y[0] is position, y[1] is velocity
+        return torch.stack([
+            y[1],            # dx/dt = v
+            -self.k * y[0],  # dv/dt = -k*x
+        ])
+
+
+# Initial conditions: [x0, v0]
+y0 = torch.tensor([1., 0.])
+
+# Time points (dense grid for smooth plots)
+t = torch.linspace(0, 10, 1000)
+
+# NOTE(review): the original demo used method='rpm' with
+# options={'n_points': 16}, but no 'rpm' solver is registered in SOLVERS and
+# _impl/pseudospectral.py is an empty placeholder, so odeint() would raise.
+# Use dopri5 until the pseudospectral (RPM) solver is implemented.
+with torch.no_grad():
+    solution = odeint(HarmonicOscillator(), y0, t, method='dopri5')
+
+# Plot results
+plt.figure(figsize=(10, 5))
+plt.plot(t, solution[:, 0].numpy(), label='Position')
+plt.plot(t, solution[:, 1].numpy(), label='Velocity')
+plt.grid(True)
+plt.legend()
+plt.title('Harmonic Oscillator')
+plt.xlabel('Time')
+plt.ylabel('Value')
+plt.show()
+
+# Compare against the analytic solution (valid because k == 1):
+# x(t) = cos(t), v(t) = -sin(t).
+exact_x = torch.cos(t)
+exact_v = -torch.sin(t)
+max_error_x = torch.max(torch.abs(solution[:, 0] - exact_x))
+max_error_v = torch.max(torch.abs(solution[:, 1] - exact_v))
+print("Maximum error in position: {:.2e}".format(max_error_x))
+print("Maximum error in velocity: {:.2e}".format(max_error_v))
diff --git a/torchdiffeq/_impl/odeint.py b/torchdiffeq/_impl/odeint.py
index 07f8666a8..3e7d3e8ac 100644
--- a/torchdiffeq/_impl/odeint.py
+++ b/torchdiffeq/_impl/odeint.py
@@ -2,18 +2,20 @@
 from torch.autograd.functional import vjp
 from .dopri5 import Dopri5Solver
 from .bosh3 import Bosh3Solver
+from .radau import RadauSolver
 from .adaptive_heun import AdaptiveHeunSolver
 from .fehlberg2 import Fehlberg2
 from .fixed_grid import Euler, Midpoint, Heun2, Heun3, RK4
 from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton
 from .dopri8 import Dopri8Solver
-from .scipy_wrapper import ScipyWrapperODESolver
+from .scipy_wrapper import ScipyWrapperODESolver, RK45Solver, DOP853Solver, Radau, BDF
 from .misc import _check_inputs, _flat_to_shape
 from .interp import _interp_evaluate
 
 SOLVERS = {
     'dopri8': Dopri8Solver,
     'dopri5': Dopri5Solver,
+    'radau': RadauSolver,
     'bosh3': Bosh3Solver,
@@ -27,7 +29,13 @@
     # Backward compatibility: use the same name as before
     'fixed_adams': AdamsBashforthMoulton,
     # ~Backwards compatibility
-    'scipy_solver': ScipyWrapperODESolver,
+    # NOTE(review): keep the old generic 'scipy_solver' entry -- removing it
+    # breaks existing callers -- and register the new per-method wrappers.
+    'scipy_solver': ScipyWrapperODESolver,
+    'scipy_rk45': RK45Solver,
+    'scipy_dop853': DOP853Solver,
+    'scipy_radau': Radau,
+    'scipy_bdf': BDF,
 }
diff --git a/torchdiffeq/_impl/pseudospectral.py b/torchdiffeq/_impl/pseudospectral.py
new file mode 100644
index 000000000..8b1378917
--- /dev/null
+++ b/torchdiffeq/_impl/pseudospectral.py
@@ -0,0 +1,2 @@
+# TODO(review): placeholder module -- the pseudospectral ('rpm') solver that
+# examples/pseudospectral_demo.py advertises is not implemented yet.
diff --git a/torchdiffeq/_impl/radau.py b/torchdiffeq/_impl/radau.py
new file mode 100644
index 000000000..a7c237b89
--- /dev/null
+++ b/torchdiffeq/_impl/radau.py
@@ -0,0 +1,67 @@
+import torch
+from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver
+
+# Radau IIA coefficients (order 5, 3 stages)
+# Reference: E. Hairer, G. Wanner, "Solving Ordinary Differential Equations
+# II: Stiff and Differential-Algebraic Problems"
+#
+# NOTE(review): Radau IIA is a *fully implicit* Runge-Kutta method -- its
+# Butcher matrix has nonzero entries on and above the diagonal (e.g. the
+# first beta row). RKAdaptiveStepsizeODESolver evaluates stages explicitly,
+# so feeding it this tableau cannot reproduce the implicit Radau IIA scheme;
+# a per-step nonlinear (Newton) solve is required. Confirm before relying on
+# this solver, especially for the stiff problems Radau is intended for.
+
+# Evaluate sqrt(6) once, in float64, so the float64 tableau does not inherit
+# float32 rounding error from torch.sqrt(torch.tensor(6.)).
+_SQRT6 = torch.sqrt(torch.tensor(6., dtype=torch.float64))
+
+_RADAU_IIA_TABLEAU = _ButcherTableau(
+    # Stage times c_i.
+    alpha=torch.tensor([
+        (4 - _SQRT6) / 10,
+        (4 + _SQRT6) / 10,
+        1.
+    ], dtype=torch.float64),
+    # Stage coefficient rows a_ij.
+    beta=[
+        torch.tensor([(4 - _SQRT6) / 10], dtype=torch.float64),
+        torch.tensor([
+            (88 - 7 * _SQRT6) / 360,
+            (296 + 169 * _SQRT6) / 1800
+        ], dtype=torch.float64),
+        torch.tensor([
+            (296 - 169 * _SQRT6) / 1800,
+            (88 + 7 * _SQRT6) / 360,
+            (16 - _SQRT6) / 36
+        ], dtype=torch.float64),
+    ],
+    # Solution weights b_i (the trailing 0 pads the extra stage slot).
+    c_sol=torch.tensor([
+        (16 - _SQRT6) / 36,
+        (16 + _SQRT6) / 36,
+        1 / 9,
+        0.
+    ], dtype=torch.float64),
+    # b_i minus the embedded lower-order weights, used for error control.
+    c_error=torch.tensor([
+        (16 - _SQRT6) / 36 - (1 / 9),
+        (16 + _SQRT6) / 36 - (1 / 9),
+        0.,
+        0.
+    ], dtype=torch.float64),
+)
+
+# Interpolation coefficients for dense output
+RADAU_C_MID = torch.tensor([
+    0.5 * ((16 - _SQRT6) / 36),
+    0.5 * ((16 + _SQRT6) / 36),
+    0.5 * (1 / 9),
+    0.
+], dtype=torch.float64)
+
+
+class RadauSolver(RKAdaptiveStepsizeODESolver):
+    """Radau IIA tableau driven by the adaptive RK machinery (see module NOTE)."""
+    order = 5
+    tableau = _RADAU_IIA_TABLEAU
+    mid = RADAU_C_MID
diff --git a/torchdiffeq/_impl/rk_common.py b/torchdiffeq/_impl/rk_common.py
index 3b07877b7..493b7757d 100644
--- a/torchdiffeq/_impl/rk_common.py
+++ b/torchdiffeq/_impl/rk_common.py
@@ -281,7 +281,8 @@ def _adaptive_step(self, rk_state):
         ########################################################
         #                      Assertions                      #
         ########################################################
-        assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
+        # Include t0 in the message to make step-size underflow easier to debug.
+        assert t0 + dt > t0, 'underflow in dt {}, t0 {}'.format(dt.item(), t0)
         assert torch.isfinite(y0).all(), 'non-finite values in state `y`: {}'.format(y0)
diff --git a/torchdiffeq/_impl/scipy_wrapper.py b/torchdiffeq/_impl/scipy_wrapper.py
index 41fe90149..461dbac5f 100644
--- a/torchdiffeq/_impl/scipy_wrapper.py
+++ b/torchdiffeq/_impl/scipy_wrapper.py
@@ -27,34 +27,68 @@ def __init__(self, func, y0, rtol, atol, min_step=0, max_step=float('inf'), solv
     def integrate(self, t):
         if t.numel() == 1:
             return torch.tensor(self.y0)[None].to(self.device, self.dtype)
-        t = t.detach().cpu().numpy()
+        # Keep the torch tensor and its numpy view under distinct names.
+        t_np = t.detach().cpu().numpy()
         sol = solve_ivp(
             self.func,
-            t_span=[t.min(), t.max()],
+            t_span=[t_np.min(), t_np.max()],
             y0=self.y0,
             t_eval=t_np,
             method=self.solver,
             rtol=self.rtol,
-            atol=self.atol,
-            min_step=self.min_step,
-            max_step=self.max_step
+            atol=self.atol,
+            # NOTE(review): max_step is a standard solve_ivp option and is kept;
+            # min_step is method-specific (only some scipy integrators accept
+            # it), so it is no longer forwarded by this generic wrapper.
+            max_step=self.max_step
         )
-        sol = torch.tensor(sol.y).T.to(self.device, self.dtype)
-        sol = sol.reshape(-1, *self.shape)
-        return sol
+        # No autograd graph can connect to scipy's numpy output, so
+        # requires_grad=True here would be meaningless.
+        sol_tensor = torch.tensor(sol.y).T.to(self.device, self.dtype)
+        sol_tensor = sol_tensor.reshape(-1, *self.shape)
+        return sol_tensor
 
     @classmethod
     def valid_callbacks(cls):
         return set()
 
 
-def convert_func_to_numpy(func, shape, device, dtype):
+class RK45Solver(ScipyWrapperODESolver):
+    """Explicit Runge-Kutta method of order 5(4)."""
+    def __init__(self, func, y0, rtol=1e-7, atol=1e-9, **kwargs):
+        super().__init__(func, y0, rtol, atol, solver="RK45", **kwargs)
+
+
+class DOP853Solver(ScipyWrapperODESolver):
+    """Explicit Runge-Kutta method of order 8."""
+    def __init__(self, func, y0, rtol=1e-7, atol=1e-9, **kwargs):
+        super().__init__(func, y0, rtol, atol, solver="DOP853", **kwargs)
+
+
+class Radau(ScipyWrapperODESolver):
+    """Implicit Runge-Kutta method of the Radau IIA family of order 5."""
+    def __init__(self, func, y0, rtol=1e-7, atol=1e-9, **kwargs):
+        super().__init__(func, y0, rtol, atol, solver="Radau", **kwargs)
+
+
+class BDF(ScipyWrapperODESolver):
+    """Implicit multi-step variable-order method based on backward differentiation formula."""
+    def __init__(self, func, y0, rtol=1e-7, atol=1e-9, **kwargs):
+        super().__init__(func, y0, rtol, atol, solver="BDF", **kwargs)
+
+
+def convert_func_to_numpy(func, shape, device, dtype):
+    """Wrap a torch callable as a numpy f(t, y) suitable for solve_ivp.
+
+    NOTE(review): an earlier revision set requires_grad=True on the inputs
+    and ALSO re-evaluated func through torch.utils.checkpoint -- that runs
+    func twice per scipy callback and builds a graph that is immediately
+    discarded by .detach().numpy(). Gradients cannot flow across scipy, so
+    evaluate once under no_grad.
+    """
     def np_func(t, y):
-        t = torch.tensor(t).to(device, dtype)
-        y = torch.reshape(torch.tensor(y).to(device, dtype), shape)
-        with torch.no_grad():
-            f = func(t, y)
-        return f.detach().cpu().numpy().reshape(-1)
+        t_tensor = torch.as_tensor(t, dtype=dtype, device=device)
+        y_tensor = torch.as_tensor(y, dtype=dtype, device=device).reshape(shape)
+        with torch.no_grad():
+            f = func(t_tensor, y_tensor)
+        return f.detach().cpu().numpy().reshape(-1)
     return np_func