Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
2ded706
updates from dsa-v2
mitchellostrow Oct 27, 2025
0884e18
add modified version of pykoopman (until they accept my pull request)
mitchellostrow Oct 27, 2025
53ef2f1
add new files for inputDSA
mitchellostrow Oct 27, 2025
0c90eac
big alignment of dsa class, fixes on dmdc, simdist_controllability, s…
mitchellostrow Oct 28, 2025
bf04fc9
convert subspace_dmdc to working in torch
mitchellostrow Oct 28, 2025
0c3aa9a
fix inits
mitchellostrow Oct 28, 2025
f456323
fix imports
mitchellostrow Oct 28, 2025
df63ee6
simplify angular distance
mitchellostrow Oct 28, 2025
1ddac90
fix comparing A on input-dsa: should use SimDist, not Controllability…
mitchellostrow Oct 28, 2025
5ee7489
update readme
mitchellostrow Oct 28, 2025
a53098c
update simdist to not do wasserstein over anything but eigenvalues
mitchellostrow Oct 28, 2025
25c385d
add docstrings for dsa
mitchellostrow Oct 28, 2025
8d46192
add docstring for simdist_controllability
mitchellostrow Oct 28, 2025
88d8e31
black formatting
mitchellostrow Oct 28, 2025
1a9dd03
bug fixes
mitchellostrow Oct 28, 2025
9732614
fix torch and device things
mitchellostrow Oct 29, 2025
1d3e9d3
zeros dmd catch, streamlining subspace_dmdc
mitchellostrow Oct 30, 2025
5de9f4f
some bug fixes
mitchellostrow Oct 30, 2025
f6f1d1a
bug fixes
mitchellostrow Oct 30, 2025
db4ccff
pykoopman
mitchellostrow Oct 30, 2025
4327124
bug fixes for comparisons
mitchellostrow Oct 30, 2025
4374db5
add error for align_inputs = True (need to fix later, quick hack)
mitchellostrow Oct 30, 2025
293cd17
add dmdc config
mitchellostrow Oct 30, 2025
8b0cae8
torch compatibility, dmdc bug, allow passing config in directly witho…
mitchellostrow Oct 30, 2025
bd103d4
fix inputdsa graceful switching of different comparisons based on con…
mitchellostrow Oct 30, 2025
d2a9667
dmdc ragged lists bug fix, starting to fix subspace dmdc but not quit…
mitchellostrow Oct 30, 2025
839143a
fixed prediction function & delay issue
Nov 3, 2025
7f4f644
checked gDSA with various data structures, metrics, and configs
Nov 3, 2025
00bffc1
updated import section
Nov 4, 2025
6fa708d
Remove tracked .pyc files
Nov 4, 2025
cef93e8
changed subspace dmdc code to work with (n_timepoints, n_features) da…
Nov 4, 2025
e8c93af
changed subspace dmdc code to work with (n_timepoints, n_features) da…
Nov 4, 2025
d7af660
Merge pull request #12 from Ann-Huang-0/inputdsa
mitchellostrow Nov 4, 2025
7b3b924
update readme
mitchellostrow Nov 4, 2025
1412ce8
Add abstract for InputDSA paper to README
mitchellostrow Nov 4, 2025
1c4e9b0
precompute eigenvalues for wasserstein distance before comparison, re…
mitchellostrow Nov 4, 2025
788918d
add the koopstd tutorial figure
mitchellostrow Nov 5, 2025
a6640d0
replicate dsa paper fig 3
mitchellostrow Nov 5, 2025
ab51bb2
bug fixes, add tests, add docstrings to DSA and inputDSA
mitchellostrow Nov 5, 2025
a13180f
bug fixes and addition of a new tutorial demonstrating all the differ…
mitchellostrow Nov 5, 2025
5eff5b8
add unmentioned detail to tutorial
mitchellostrow Nov 5, 2025
5160214
updated sweep_ranks_delays to work with DMDc and SubspaceDMDc
Nov 6, 2025
bc0f348
Merge pull request #13 from Ann-Huang-0/inputdsa
mitchellostrow Nov 7, 2025
9fd4514
bug fix
mitchellostrow Nov 7, 2025
17b025b
dmdc model tutorial notebook
mitchellostrow Nov 7, 2025
0be3dd8
input dsa figure 2 working!
mitchellostrow Nov 7, 2025
4e1a652
compatibility bw local dmd and pykoopman
mitchellostrow Nov 8, 2025
6abae78
replicate rings with new dsa!
mitchellostrow Nov 9, 2025
3b273a6
fix scaling of wasserstein
mitchellostrow Nov 9, 2025
c7ecae9
remove subset index bug for time delays (too few timepoints are selec…
mitchellostrow Nov 12, 2025
ea2b483
bug fix
mitchellostrow Dec 2, 2025
a145afe
bug fix for data handling
mitchellostrow Dec 9, 2025
5653842
Merge branch 'inputdsa' of https://github.com/mitchellostrow/DSA into…
mitchellostrow Dec 9, 2025
fb047cf
dataclass bug fix
mitchellostrow Jan 10, 2026
a2e14f4
Merge branch 'main' into inputdsa
mitchellostrow Feb 6, 2026
34151cf
Update DSA/pykoopman/__init__.py
mitchellostrow Feb 6, 2026
bcfd862
Update DSA/pykoopman/common/__init__.py
mitchellostrow Feb 6, 2026
a5e931f
Fix docstring formatting in simdist_controllability
mitchellostrow Feb 6, 2026
67516a9
Fix syntax error in setup.py requirements
mitchellostrow Feb 6, 2026
6d2021b
Initial plan
Copilot Feb 6, 2026
00c14a8
Initial plan
Copilot Feb 6, 2026
bcfd5cd
Fix file resource leaks in _nndmd.py using context managers
Copilot Feb 6, 2026
6b8ad7e
Fix file resource leaks by using context managers
Copilot Feb 6, 2026
ee17131
Remove pycache file and add .gitignore
Copilot Feb 6, 2026
0d5cb2c
Initial plan
Copilot Feb 6, 2026
d850b2d
Initial plan
Copilot Feb 6, 2026
34ef070
Add prettytable to pyproject.toml dependencies
Copilot Feb 6, 2026
0bd1975
Add proper error handling for UMAP import and make it an optional dep…
Copilot Feb 6, 2026
a9c2a1f
Remove __pycache__ files and add .gitignore
Copilot Feb 6, 2026
358a3f1
Merge pull request #15 from mitchellostrow/copilot/sub-pr-14
mitchellostrow Feb 6, 2026
196138e
Merge branch 'inputdsa' into copilot/sub-pr-14-again
mitchellostrow Feb 6, 2026
dcb7466
Merge pull request #16 from mitchellostrow/copilot/sub-pr-14-again
mitchellostrow Feb 6, 2026
4f348a3
Initial plan
Copilot Feb 6, 2026
e733867
Update DSA/dsa.py
mitchellostrow Feb 6, 2026
447b462
Fix division by zero in nmse() for constant arrays
Copilot Feb 6, 2026
525bda7
Update DSA/preprocessing.py
mitchellostrow Feb 6, 2026
fe35859
Optimize nmse to avoid computing MSE twice
Copilot Feb 6, 2026
f807592
Improve clarity of zero variance handling in nmse
Copilot Feb 6, 2026
ddf84c2
Update DSA/__init__.py
mitchellostrow Feb 6, 2026
5746234
Merge pull request #17 from mitchellostrow/copilot/sub-pr-14-another-one
mitchellostrow Feb 6, 2026
bc01eff
Merge branch 'inputdsa' into copilot/sub-pr-14-yet-again
mitchellostrow Feb 6, 2026
dffbff7
Merge pull request #18 from mitchellostrow/copilot/sub-pr-14-yet-again
mitchellostrow Feb 6, 2026
86c6209
Merge pull request #19 from mitchellostrow/copilot/sub-pr-14-one-more…
mitchellostrow Feb 6, 2026
14177f4
Initial plan
Copilot Feb 6, 2026
eeef423
Remove unused pad_zeros import from tests/simdist_test.py
Copilot Feb 6, 2026
aa02f37
Initial plan
Copilot Feb 6, 2026
b4dde3a
Remove unused import of SimilarityTransformDist
Copilot Feb 6, 2026
845843d
Merge pull request #21 from mitchellostrow/copilot/sub-pr-14-8afcd5b6…
mitchellostrow Feb 6, 2026
61ca512
Merge pull request #22 from mitchellostrow/copilot/sub-pr-14-cd2ee78c…
mitchellostrow Feb 6, 2026
cc2af67
fixing bugs on inputdsa
mitchellostrow Feb 6, 2026
69dd4fc
fix some tests and redo whole sweep class
mitchellostrow Feb 6, 2026
f8e782c
add option for sweeping over multiple observable params
mitchellostrow Feb 7, 2026
972fd98
resdmd updates
mitchellostrow Feb 7, 2026
8d6a36a
resdmd with control bug fixes
mitchellostrow Feb 7, 2026
05b7a25
kalman smoothing for postprocessing latent state inference
mitchellostrow Feb 7, 2026
88488e9
resdmd for subspacedmdc fix with kalman smoothing for test error
mitchellostrow Feb 7, 2026
68383f1
pykoopman handle dmdc correctly
mitchellostrow Feb 7, 2026
a42cc4a
small bugs and tutorial fixes
mitchellostrow Feb 7, 2026
ea0d14c
add flag for differentiability
mitchellostrow Feb 13, 2026
147c39d
Update DSA/pykoopman/__init__.py
mitchellostrow Feb 14, 2026
6bf0b7a
update pyproject, clean and make more pythonic
mitchellostrow Feb 14, 2026
dea2af5
Merge branch 'inputdsa' of https://github.com/mitchellostrow/DSA into…
mitchellostrow Feb 14, 2026
2eac6e3
updates for uv, simpler imports, resolving comments
mitchellostrow Feb 14, 2026
f1caab0
update notebooks and packages to run seamlessly with provided uv inst…
mitchellostrow Feb 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Python cache files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# IDEs
.vscode/
.idea/
*.swp
*.swo
*~

# OS
.DS_Store
Thumbs.db
15 changes: 11 additions & 4 deletions DSA/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from DSA.dsa import DSA
__version__ = "2.0.0"

from DSA.dsa import DSA, ControllabilitySimilarityTransformDistConfig, GeneralizedDSA, InputDSA, SimilarityTransformDistConfig
from DSA.dsa import DefaultDMDConfig as DMDConfig
from DSA.dsa import pyKoopmanDMDConfig, SubspaceDMDcConfig, DMDcConfig
from DSA.dsa import SimilarityTransformDistConfig, ControllabilitySimilarityTransformDistConfig
from DSA.dmd import DMD
from DSA.kerneldmd import KernelDMD
from DSA.dmdc import DMDc
from DSA.subspace_dmdc import SubspaceDMDc
from DSA.simdist import SimilarityTransformDist
from DSA.simdist_controllability import ControllabilitySimilarityTransformDist
from DSA.stats import *
from DSA.sweeps import *
from DSA.sweeps import PyKoopmanSweeper, DefaultSweeper
from DSA.preprocessing import *
from DSA.resdmd import *
from DSA.resdmd import ResidualComputer
282 changes: 282 additions & 0 deletions DSA/base_dmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
"""Base class for DMD implementations."""

import numpy as np
import torch
import warnings
from abc import ABC, abstractmethod


class BaseDMD(ABC):
    """Abstract base class for DMD implementations with shared utilities.

    Centralizes functionality common to all DMD variants:

    - device selection with graceful CUDA fallback (``_setup_device``),
    - normalization of input data (numpy arrays, torch tensors, and ragged
      lists of trials) to torch (``_process_single_dataset``,
      ``_init_single_data``),
    - SVD rank selection from explicit rank / singular-value threshold /
      explained-variance threshold (``_compute_rank_from_params``),
    - tensor/device housekeeping (``_to_torch``, ``_to_numpy``,
      ``all_to_device``).

    Subclasses must implement ``fit``, ``predict``, ``compute_hankel``,
    and ``compute_svd``.
    """

    def __init__(
        self,
        device="cpu",
        verbose=False,
        send_to_cpu=False,
        lamb=0,
    ):
        """
        Parameters
        ----------
        device: string, int, or torch.device
            A string, int or torch.device object to indicate the device to torch.
            If 'cuda' or 'cuda:X' is specified but not available, will fall back to 'cpu' with a warning.
        verbose: bool
            If True, print statements will be provided about the progress of the fitting procedure.
        send_to_cpu: bool
            If True, will send all tensors in the object back to the cpu after everything is computed.
            This is implemented to prevent gpu memory overload when computing multiple DMDs.
        lamb : float
            Regularization parameter for ridge regression. Defaults to 0.
        """
        self.device = device
        self.verbose = verbose
        self.send_to_cpu = send_to_cpu
        self.lamb = lamb

        # Safe default so _to_torch/_to_numpy never raise AttributeError if a
        # subclass forgets to assign the result of _setup_device(). Subclasses
        # that use torch are expected to overwrite this via:
        #   self.device, self.use_torch = self._setup_device(device, use_torch)
        self.use_torch = False

        # Common data attributes, populated by _init_single_data
        self.data = None
        self.n = None  # number of features (last data dimension)
        self.ntrials = None  # number of trials in the data
        self.is_list_data = False  # True when data is a ragged list of tensors

        # SVD attributes - will be set by subclasses
        self.cumulative_explained_variance = None

    def _setup_device(self, device='cpu', use_torch=None):
        """
        Smart device setup with graceful fallback and auto-detection.

        Parameters
        ----------
        device : str or torch.device
            Requested device ('cpu', 'cuda', 'cuda:0', etc.)
        use_torch : bool or None
            Whether to use PyTorch. If None, auto-detected:
            - True if device contains 'cuda'
            - False otherwise (numpy is faster on CPU)

        Returns
        -------
        tuple
            (device, use_torch) - validated device and use_torch flag.
            ``device`` is a torch.device when use_torch is True, else None
            (numpy path needs no torch device).

        Notes
        -----
        This method only *returns* the validated pair; callers are
        responsible for assigning the result to ``self.device`` /
        ``self.use_torch``.
        """
        # Convert device to string for checking
        device_str = str(device).lower()

        # Auto-detect use_torch if not specified
        if use_torch is None:
            use_torch = 'cuda' in device_str

        # If CUDA requested, check availability
        if 'cuda' in device_str:
            if not torch.cuda.is_available():
                warnings.warn(
                    f"CUDA device '{device}' requested but CUDA is not available. "
                    "Falling back to CPU. "
                    "To use GPU acceleration, ensure PyTorch with CUDA support is installed.",
                    RuntimeWarning,
                    stacklevel=3
                )
                device = 'cpu'
                use_torch = False  # Use numpy on CPU for better performance
            else:
                # CUDA is available, verify the specific device exists
                try:
                    test_device = torch.device(device)
                    # Test if we can actually use this device
                    torch.tensor([1.0], device=test_device)
                    use_torch = True
                except (RuntimeError, AssertionError) as e:
                    warnings.warn(
                        f"CUDA device '{device}' requested but not accessible: {e}. "
                        f"Falling back to CPU.",
                        RuntimeWarning,
                        stacklevel=3
                    )
                    device = 'cpu'
                    use_torch = False

        # Convert to torch.device if using torch
        if use_torch:
            device = torch.device(device)
        else:
            device = None  # Use numpy (no torch device needed)

        return device, use_torch

    def _process_single_dataset(self, data):
        """Normalize one dataset to torch.

        Parameters
        ----------
        data : np.ndarray, torch.Tensor, or list of either
            Input dataset. Lists may be ragged (per-trial tensors of
            different lengths) as long as the feature (last) dimension
            agrees across elements.

        Returns
        -------
        tuple
            (processed, is_ragged). ``processed`` is a single stacked tensor
            when all list elements share a shape (or when ``data`` was already
            an array/tensor), otherwise the list of per-trial tensors with
            ``is_ragged=True``.

        Raises
        ------
        ValueError
            If a ragged list's elements disagree on the feature dimension.
        """
        if isinstance(data, list):
            # Convert once up front; the try/except below only decides
            # whether the elements can be stacked into a single tensor.
            tensors = [
                torch.from_numpy(d) if isinstance(d, np.ndarray) else d
                for d in data
            ]
            try:
                # Non-ragged: every element has the same shape.
                return torch.stack(tensors), False
            except (RuntimeError, ValueError):
                # Ragged: keep as a list, but the feature dimension must agree.
                n_features = tensors[0].shape[-1]
                if not all(d.shape[-1] == n_features for d in tensors):
                    raise ValueError(
                        "All tensors in the list must have the same number of features (last dimension)."
                    )
                return tensors, True

        elif isinstance(data, np.ndarray):
            # Copy so the tensor does not share memory with the caller's array.
            return torch.from_numpy(data.copy()), False

        # Already a tensor (or other torch-compatible object): pass through.
        return data, False

    def _init_single_data(self, data):
        """Initialize data attributes (``n``, ``ntrials``, ``is_list_data``)
        for a single dataset and return the processed data.

        A 3-d input is interpreted as (n_trials, n_timepoints, n_features);
        a 2-d input as (n_timepoints, n_features) with a single trial.
        """
        processed_data, is_ragged = self._process_single_dataset(data)

        if is_ragged:
            # Ragged list: total trials is the sum over elements, where a 3-d
            # element contributes its leading dimension and a 2-d element one trial.
            n_features = processed_data[0].shape[-1]
            self.n = n_features
            self.ntrials = sum(d.shape[0] if d.ndim == 3 else 1 for d in processed_data)
            self.trial_counts = [
                d.shape[0] if d.ndim == 3 else 1 for d in processed_data
            ]
            self.is_list_data = True
        else:
            # Single stacked tensor.
            if processed_data.ndim == 3:
                self.ntrials = processed_data.shape[0]
                self.n = processed_data.shape[2]
            else:
                self.n = processed_data.shape[1]
                self.ntrials = 1
            self.is_list_data = False

        return processed_data

    def _compute_explained_variance(self, S):
        """Compute cumulative explained variance from singular values.

        Parameters
        ----------
        S : torch.Tensor
            1-d tensor of singular values.

        Returns
        -------
        torch.Tensor
            Cumulative fraction of variance (S**2) explained; last entry is 1.
        """
        exp_variance = S**2 / torch.sum(S**2)
        return torch.cumsum(exp_variance, 0)

    def _compute_rank_from_params(
        self,
        S,
        cumulative_explained_variance,
        max_rank,
        rank=None,
        rank_thresh=None,
        rank_explained_variance=None,
    ):
        """
        Compute rank based on provided parameters.

        Exactly one of ``rank``, ``rank_thresh``, ``rank_explained_variance``
        may be given; with none given, the full rank ``len(S)`` is used.
        The result is always clamped to ``max_rank``.

        Parameters
        ----------
        S : torch.Tensor
            Singular values
        cumulative_explained_variance : torch.Tensor
            Cumulative explained variance
        max_rank : int
            Maximum possible rank
        rank : int, optional
            Explicit rank specification
        rank_thresh : float, optional
            Threshold for singular values
        rank_explained_variance : float, optional
            Explained variance threshold

        Returns
        -------
        int
            Computed rank

        Raises
        ------
        ValueError
            If more than one rank parameter is provided.
        """
        parameters_provided = [
            rank is not None,
            rank_thresh is not None,
            rank_explained_variance is not None,
        ]
        num_parameters_provided = sum(parameters_provided)

        if num_parameters_provided > 1:
            raise ValueError(
                "More than one rank parameter was provided. Please provide only one of rank, rank_thresh, or rank_explained_variance."
            )
        elif num_parameters_provided == 0:
            computed_rank = len(S)
        else:
            if rank is not None:
                computed_rank = rank
            elif rank_thresh is not None:
                # Find the number of singular values greater than the threshold
                computed_rank = int((S > rank_thresh).sum().item())
                if computed_rank == 0:
                    computed_rank = 1  # Ensure at least rank 1
            elif rank_explained_variance is not None:
                # Smallest rank whose cumulative explained variance reaches
                # the requested fraction.
                cumulative_explained_variance_cpu = (
                    cumulative_explained_variance.cpu().numpy()
                )
                computed_rank = int(
                    np.searchsorted(
                        cumulative_explained_variance_cpu, rank_explained_variance
                    )
                    + 1
                )
                if computed_rank > len(S):
                    computed_rank = len(S)

        # Ensure rank doesn't exceed maximum possible
        if computed_rank > max_rank:
            computed_rank = max_rank

        return computed_rank

    def _to_torch(self, x):
        """Convert numpy array to torch tensor on the appropriate device.

        No-op (returns ``x`` unchanged) when ``self.use_torch`` is False
        or ``x`` is None.
        """
        if not self.use_torch or x is None:
            return x
        if isinstance(x, torch.Tensor):
            return x.to(self.device)
        return torch.from_numpy(x).to(self.device)

    def _to_numpy(self, x):
        """Convert torch tensor to numpy array.

        No-op (returns ``x`` unchanged) when ``self.use_torch`` is False
        or ``x`` is None.
        """
        if not self.use_torch or x is None:
            return x
        if isinstance(x, torch.Tensor):
            return x.cpu().numpy()
        return x

    def all_to_device(self, device="cpu"):
        """Move all tensor attributes (and lists of tensors) to the
        specified device in place.

        Used with ``send_to_cpu`` to release GPU memory after fitting.
        """
        for k, v in self.__dict__.items():
            if isinstance(v, torch.Tensor):
                self.__dict__[k] = v.to(device)
            elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
                self.__dict__[k] = [tensor.to(device) for tensor in v]

    @abstractmethod
    def fit(self, *args, **kwargs):
        """Fit the DMD model. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def predict(self, *args, **kwargs):
        """Make predictions with the DMD model. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def compute_hankel(self, *args, **kwargs):
        """Compute Hankel matrix. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def compute_svd(self, *args, **kwargs):
        """Compute SVD. Must be implemented by subclasses."""
        pass
Loading