Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions examples/pseudospectral_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch
import matplotlib.pyplot as plt
from torchdiffeq import odeint

# Define a simple harmonic oscillator
class HarmonicOscillator:
    """Right-hand side of the undamped harmonic oscillator x'' = -k * x.

    State layout is ``y = [position, velocity]``; calling the instance with
    ``(t, y)`` returns ``dy/dt`` as a tensor of the same shape, which is the
    interface ``odeint`` expects.
    """

    def __init__(self, k=1.0):
        # Spring constant (per unit mass).
        self.k = k

    def __call__(self, t, y):
        position, velocity = y[0], y[1]
        # dx/dt = v,  dv/dt = -k * x
        return torch.stack([velocity, -self.k * position])

# Initial conditions: [x0, v0]
y0 = torch.tensor([1., 0.])

# Time points (using more points for better resolution)
t = torch.linspace(0, 10, 1000)

# Solve the system.
# NOTE(review): the original requested method='rpm' with an 'n_points'
# option, but no 'rpm' solver is registered (pseudospectral.py is empty),
# so the call raised at runtime.  'radau' is the solver this change set
# actually adds, so the demo uses it instead.
with torch.no_grad():
    solution = odeint(HarmonicOscillator(), y0, t, method='radau')

# Plot results
plt.figure(figsize=(10, 5))
plt.plot(t, solution[:, 0].numpy(), label='Position')
plt.plot(t, solution[:, 1].numpy(), label='Velocity')
plt.grid(True)
plt.legend()
plt.title('Harmonic Oscillator - Radau Solution')
plt.xlabel('Time')
plt.ylabel('Value')
plt.show()

# Compare against the analytical solution x(t) = cos(t), v(t) = -sin(t),
# which holds for the default spring constant k = 1 used above.
exact_x = torch.cos(t)
exact_v = -torch.sin(t)
max_error_x = torch.max(torch.abs(solution[:, 0] - exact_x))
max_error_v = torch.max(torch.abs(solution[:, 1] - exact_v))
# .item() extracts the Python float so the format spec applies cleanly.
print("Maximum error in position: {:.2e}".format(max_error_x.item()))
print("Maximum error in velocity: {:.2e}".format(max_error_v.item()))
9 changes: 7 additions & 2 deletions torchdiffeq/_impl/odeint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@
from torch.autograd.functional import vjp
from .dopri5 import Dopri5Solver
from .bosh3 import Bosh3Solver
from .radau import RadauSolver
from .adaptive_heun import AdaptiveHeunSolver
from .fehlberg2 import Fehlberg2
from .fixed_grid import Euler, Midpoint, Heun2, Heun3, RK4
from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton
from .dopri8 import Dopri8Solver
from .scipy_wrapper import ScipyWrapperODESolver
from .scipy_wrapper import ScipyWrapperODESolver, RK45Solver, DOP853Solver, Radau, BDF
from .misc import _check_inputs, _flat_to_shape
from .interp import _interp_evaluate

SOLVERS = {
'dopri8': Dopri8Solver,
'dopri5': Dopri5Solver,
'radau': RadauSolver,
'bosh3': Bosh3Solver,
'fehlberg2': Fehlberg2,
'adaptive_heun': AdaptiveHeunSolver,
Expand All @@ -27,7 +29,10 @@
# Backward compatibility: use the same name as before
'fixed_adams': AdamsBashforthMoulton,
# ~Backwards compatibility
'scipy_solver': ScipyWrapperODESolver,
'scipy_rk45': RK45Solver,
'scipy_dop853': DOP853Solver,
'scipy_radau': Radau,
'scipy_bdf': BDF,
}


Expand Down
1 change: 1 addition & 0 deletions torchdiffeq/_impl/pseudospectral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

50 changes: 50 additions & 0 deletions torchdiffeq/_impl/radau.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
from .rk_common import _ButcherTableau, RKAdaptiveStepsizeODESolver

# Radau IIA coefficients (order 5, 3 stages)
# Reference: E. Hairer, G. Wanner, "Solving Ordinary Differential Equations II: Stiff and Differential-Algebraic Problems"
#
# NOTE(review): Radau IIA is a *fully implicit* Runge-Kutta method — its A
# matrix is dense.  _ButcherTableau in rk_common encodes explicit methods
# (strictly lower-triangular beta), so the super-diagonal entries of the true
# Radau IIA A matrix cannot be represented here and appear to have been
# dropped.  The resulting explicit scheme is not Radau IIA and should not be
# assumed to be order 5 or suitable for stiff problems — verify against the
# tableau in Hairer & Wanner before relying on this solver.
_RADAU_IIA_TABLEAU = _ButcherTableau(
    # Stage abscissae c_i = (4 - sqrt(6))/10, (4 + sqrt(6))/10, 1.
    alpha=torch.tensor([
        (4 - torch.sqrt(torch.tensor(6.))) / 10,
        (4 + torch.sqrt(torch.tensor(6.))) / 10,
        1.
    ], dtype=torch.float64),
    # Lower-triangular slices only (see NOTE above) of the implicit A matrix.
    beta=[
        torch.tensor([(4 - torch.sqrt(torch.tensor(6.))) / 10], dtype=torch.float64),
        torch.tensor([
            (88 - 7 * torch.sqrt(torch.tensor(6.))) / 360,
            (296 + 169 * torch.sqrt(torch.tensor(6.))) / 1800
        ], dtype=torch.float64),
        torch.tensor([
            (296 - 169 * torch.sqrt(torch.tensor(6.))) / 1800,
            (88 + 7 * torch.sqrt(torch.tensor(6.))) / 360,
            (16 - torch.sqrt(torch.tensor(6.))) / 36
        ], dtype=torch.float64),
    ],
    # Solution weights b_i; the trailing 0. presumably pads for an extra
    # stage appended by rk_common — TODO confirm against RKAdaptiveStepsizeODESolver.
    c_sol=torch.tensor([
        (16 - torch.sqrt(torch.tensor(6.))) / 36,
        (16 + torch.sqrt(torch.tensor(6.))) / 36,
        1/9,
        0.
    ], dtype=torch.float64),
    # Error weights (b - b_hat); the embedded lower-order weights b_hat are
    # not stated anywhere in this file — verify their derivation.
    c_error=torch.tensor([
        (16 - torch.sqrt(torch.tensor(6.))) / 36 - (1/9),
        (16 + torch.sqrt(torch.tensor(6.))) / 36 - (1/9),
        0.,
        0.
    ], dtype=torch.float64),
)

# Interpolation coefficients for dense output
# NOTE(review): these values are simply 0.5 * the c_sol weights above, not
# interpolation coefficients derived specifically for this tableau —
# dense-output accuracy should be verified before use.
RADAU_C_MID = torch.tensor([
    0.5 * ((16 - torch.sqrt(torch.tensor(6.))) / 36),
    0.5 * ((16 + torch.sqrt(torch.tensor(6.))) / 36),
    0.5 * (1/9),
    0.
], dtype=torch.float64)


class RadauSolver(RKAdaptiveStepsizeODESolver):
    """Adaptive-step solver built on the tableau defined above.

    NOTE(review): advertised as Radau IIA (order 5), but see the caveats on
    ``_RADAU_IIA_TABLEAU`` — an explicit Butcher tableau cannot reproduce the
    implicit Radau IIA method, so the effective order and stability of this
    solver should be verified before it is used on stiff problems.
    """
    order = 5
    tableau = _RADAU_IIA_TABLEAU
    mid = RADAU_C_MID
4 changes: 3 additions & 1 deletion torchdiffeq/_impl/rk_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ def _adaptive_step(self, rk_state):
########################################################
# Assertions #
########################################################
assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
# if (dt <= 0).any():
# print('dt < 0')
assert t0 + dt > t0, 'underflow in dt {}, t0 {}'.format(dt.item(), t0)
assert torch.isfinite(y0).all(), 'non-finite values in state `y`: {}'.format(y0)

########################################################
Expand Down
62 changes: 47 additions & 15 deletions torchdiffeq/_impl/scipy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,34 +27,66 @@ def __init__(self, func, y0, rtol, atol, min_step=0, max_step=float('inf'), solv
def integrate(self, t):
    """Integrate the wrapped ODE with scipy's ``solve_ivp`` at times ``t``.

    Args:
        t: 1-D tensor of evaluation times (assumed sorted ascending, since
            ``t_span`` is built from its min/max — TODO confirm callers
            guarantee this).

    Returns:
        Tensor of shape ``(len(t), *self.shape)`` on ``self.device``.
    """
    if t.numel() == 1:
        # Nothing to integrate: the solution at the single time is y0.
        return torch.tensor(self.y0)[None].to(self.device, self.dtype)
    t_np = t.detach().cpu().numpy()
    sol = solve_ivp(
        self.func,
        t_span=[t_np.min(), t_np.max()],
        y0=self.y0,
        t_eval=t_np,
        method=self.solver,
        rtol=self.rtol,
        atol=self.atol,
        # max_step is a valid solve_ivp option and is still accepted by
        # __init__, so forward it again; min_step is NOT a solve_ivp option
        # and stays removed.
        max_step=self.max_step,
    )
    # No gradient can flow back through scipy's numpy computation, so the
    # result must not be created with requires_grad=True — that only made a
    # fresh graph leaf that misleadingly looked differentiable.
    solution = torch.tensor(sol.y).T.to(self.device, self.dtype)
    return solution.reshape(-1, *self.shape)

@classmethod
def valid_callbacks(cls):
    """Return the set of callback names this solver supports: none."""
    return set()


def convert_func_to_numpy(func, shape, device, dtype):
class RK45Solver(ScipyWrapperODESolver):
    """Convenience wrapper pinning scipy's explicit Runge-Kutta 5(4) method."""

    def __init__(self, func, y0, rtol=1e-7, atol=1e-9, **kwargs):
        # All real work happens in the generic scipy wrapper.
        super().__init__(func=func, y0=y0, rtol=rtol, atol=atol,
                         solver="RK45", **kwargs)


class DOP853Solver(ScipyWrapperODESolver):
    """Convenience wrapper pinning scipy's explicit Runge-Kutta order-8 method."""

    def __init__(self, func, y0, rtol=1e-7, atol=1e-9, **kwargs):
        # All real work happens in the generic scipy wrapper.
        super().__init__(func=func, y0=y0, rtol=rtol, atol=atol,
                         solver="DOP853", **kwargs)


class Radau(ScipyWrapperODESolver):
    """Convenience wrapper pinning scipy's implicit Radau IIA order-5 method."""

    def __init__(self, func, y0, rtol=1e-7, atol=1e-9, **kwargs):
        # All real work happens in the generic scipy wrapper.
        super().__init__(func=func, y0=y0, rtol=rtol, atol=atol,
                         solver="Radau", **kwargs)


class BDF(ScipyWrapperODESolver):
    """Convenience wrapper pinning scipy's implicit variable-order BDF method."""

    def __init__(self, func, y0, rtol=1e-7, atol=1e-9, **kwargs):
        # All real work happens in the generic scipy wrapper.
        super().__init__(func=func, y0=y0, rtol=rtol, atol=atol,
                         solver="BDF", **kwargs)


def convert_func_to_numpy(func, shape, device, dtype):
    """Adapt a torch callable ``func(t, y)`` to scipy's ``f(t, y) -> ndarray``.

    Args:
        func: torch function mapping (scalar t, tensor y of ``shape``) to a
            tensor with the same number of elements.
        shape: tensor shape to restore from scipy's flat state vector.
        device: torch device on which ``func`` is evaluated.
        dtype: torch dtype used for the converted inputs.

    Returns:
        A numpy-in / numpy-out function suitable for ``solve_ivp``.
    """
    def np_func(t, y):
        t_tensor = torch.as_tensor(t, dtype=dtype, device=device)
        y_tensor = torch.as_tensor(y, dtype=dtype, device=device).reshape(shape)
        # The result is immediately converted to numpy, so no gradient can
        # ever flow back through this call.  A previous revision evaluated
        # func twice (once directly, once via torch.utils.checkpoint) with
        # requires_grad=True inputs, building graphs that were discarded on
        # every one of scipy's many RHS evaluations — pure waste of time and
        # memory.  Evaluate once, under no_grad.
        with torch.no_grad():
            f = func(t_tensor, y_tensor)
        return f.detach().cpu().numpy().reshape(-1)

    return np_func