Inputdsa #14 (Merged)
Changes from all commits (104 commits):

0884e18  updates from dsa-v2 (mitchellostrow)
53ef2f1  add modified version of pykoopman (until they accept my pull request) (mitchellostrow)
0c90eac  add new files for inputDSA (mitchellostrow)
bf04fc9  big alignment of dsa class, fixes on dmdc, simdist_controlalbility, s… (mitchellostrow)
0c3aa9a  convert subspace_dmdc to working in torch (mitchellostrow)
f456323  fix inits (mitchellostrow)
df63ee6  fix imports (mitchellostrow)
1ddac90  simplify angular distance (mitchellostrow)
5ee7489  fix comparing A on input-dsa: should use SimDist, not Controllability… (mitchellostrow)
a53098c  update readme (mitchellostrow)
25c385d  update simdist to not do wasserstein over anything but eigenvalues (mitchellostrow)
8d46192  add docstrings for dsa (mitchellostrow)
88d8e31  add docstring for simdist_controllability (mitchellostrow)
1a9dd03  black formatting (mitchellostrow)
9732614  bug fixes (mitchellostrow)
1d3e9d3  fix torch and device things (mitchellostrow)
5de9f4f  zeros dmd catch, streamlining subspace_dmdc (mitchellostrow)
f6f1d1a  some bug fixes (mitchellostrow)
db4ccff  bug fixes (mitchellostrow)
4327124  pykoopman (mitchellostrow)
4374db5  bug fixes for comparisons (mitchellostrow)
293cd17  add error for align_inputs = True (need to fix later, quick hack) (mitchellostrow)
8b0cae8  add dmdc config (mitchellostrow)
bd103d4  torch compatibility, dmdc bug, allow passing config in directly witho… (mitchellostrow)
d2a9667  fix inputdsa graceful switching of different comparisons based on con… (mitchellostrow)
839143a  dmdc ragged lists bug fix, starting to fix subspace dmdc but not quit… (mitchellostrow)
7f4f644  fixed prediction function & delay issue
00bffc1  checked gDSA with various data structures, metrics, and configs
6fa708d  updated import section
cef93e8  Remove tracked .pyc files
e8c93af  changed subspace dmdc code to work with (n_timepoints, n_features) da…
d7af660  changed subspace dmdc code to work with (n_timepoints, n_features) da…
7b3b924  Merge pull request #12 from Ann-Huang-0/inputdsa (mitchellostrow)
1412ce8  update readme (mitchellostrow)
1c4e9b0  Add abstract for InputDSA paper to README (mitchellostrow)
788918d  precompute eigenvalues for wasserstein distance before comparison, re… (mitchellostrow)
a6640d0  add the koopstd tutorial figure (mitchellostrow)
ab51bb2  replicate dsa paper fig 3 (mitchellostrow)
a13180f  bug fixes, add tests, add docstrings to DSA and inputDSA (mitchellostrow)
5eff5b8  bug fixes and addition of a new tutorial demonstrating all the differ… (mitchellostrow)
5160214  add unmentioned detail to tutorial (mitchellostrow)
bc0f348  updated sweep_ranks_delays to work with DMDc and SubspaceDMDc
9fd4514  Merge pull request #13 from Ann-Huang-0/inputdsa (mitchellostrow)
17b025b  bug fix (mitchellostrow)
0be3dd8  dmdc model tutorial notebook (mitchellostrow)
4e1a652  input dsa figure 2 working! (mitchellostrow)
6abae78  compatibility bw local dmd and pykoopman (mitchellostrow)
3b273a6  replicate rings with new dsa! (mitchellostrow)
c7ecae9  fix scaling of wasserstein (mitchellostrow)
ea2b483  remove subset index bug for time delays (too few timepoints are selec… (mitchellostrow)
a145afe  bug fix (mitchellostrow)
5653842  bug fix for data handling (mitchellostrow)
fb047cf  Merge branch 'inputdsa' of https://github.com/mitchellostrow/DSA into… (mitchellostrow)
a2e14f4  dataclass bug fix (mitchellostrow)
34151cf  Merge branch 'main' into inputdsa (mitchellostrow)
bcfd862  Update DSA/pykoopman/__init__.py (mitchellostrow)
a5e931f  Update DSA/pykoopman/common/__init__.py (mitchellostrow)
67516a9  Fix docstring formatting in simdist_controllability (mitchellostrow)
6d2021b  Fix syntax error in setup.py requirements (mitchellostrow)
00c14a8  Initial plan (Copilot)
bcfd5cd  Initial plan (Copilot)
6b8ad7e  Fix file resource leaks in _nndmd.py using context managers (Copilot)
ee17131  Fix file resource leaks by using context managers (Copilot)
0d5cb2c  Remove pycache file and add .gitignore (Copilot)
d850b2d  Initial plan (Copilot)
34ef070  Initial plan (Copilot)
0bd1975  Add prettytable to pyproject.toml dependencies (Copilot)
a9c2a1f  Add proper error handling for UMAP import and make it an optional dep… (Copilot)
358a3f1  Remove __pycache__ files and add .gitignore (Copilot)
196138e  Merge pull request #15 from mitchellostrow/copilot/sub-pr-14 (mitchellostrow)
dcb7466  Merge branch 'inputdsa' into copilot/sub-pr-14-again (mitchellostrow)
4f348a3  Merge pull request #16 from mitchellostrow/copilot/sub-pr-14-again (mitchellostrow)
e733867  Initial plan (Copilot)
447b462  Update DSA/dsa.py (mitchellostrow)
525bda7  Fix division by zero in nmse() for constant arrays (Copilot)
fe35859  Update DSA/preprocessing.py (mitchellostrow)
f807592  Optimize nmse to avoid computing MSE twice (Copilot)
ddf84c2  Improve clarity of zero variance handling in nmse (Copilot)
5746234  Update DSA/__init__.py (mitchellostrow)
bc01eff  Merge pull request #17 from mitchellostrow/copilot/sub-pr-14-another-one (mitchellostrow)
dffbff7  Merge branch 'inputdsa' into copilot/sub-pr-14-yet-again (mitchellostrow)
86c6209  Merge pull request #18 from mitchellostrow/copilot/sub-pr-14-yet-again (mitchellostrow)
14177f4  Merge pull request #19 from mitchellostrow/copilot/sub-pr-14-one-more… (mitchellostrow)
eeef423  Initial plan (Copilot)
aa02f37  Remove unused pad_zeros import from tests/simdist_test.py (Copilot)
b4dde3a  Initial plan (Copilot)
845843d  Remove unused import of SimilarityTransformDist (Copilot)
61ca512  Merge pull request #21 from mitchellostrow/copilot/sub-pr-14-8afcd5b6… (mitchellostrow)
cc2af67  Merge pull request #22 from mitchellostrow/copilot/sub-pr-14-cd2ee78c… (mitchellostrow)
69dd4fc  fixing bugs on inputdsa (mitchellostrow)
f8e782c  fix some tests and redo whole sweep class (mitchellostrow)
972fd98  add option for sweeping over multiple observable params (mitchellostrow)
8d6a36a  resdmd updates (mitchellostrow)
05b7a25  resdmd with control bug fixes (mitchellostrow)
88488e9  kalman smoothing for postprocessing latent state inference (mitchellostrow)
68383f1  resdmd for subspacedmdc fix with kalman smoothing for test error (mitchellostrow)
a42cc4a  pykoopman handle dmdc correctly (mitchellostrow)
ea0d14c  small bugs and tutorial fixes (mitchellostrow)
147c39d  add flag for differentiability (mitchellostrow)
6bf0b7a  Update DSA/pykoopman/__init__.py (mitchellostrow)
dea2af5  update pyproject, clean and make more pythonic (mitchellostrow)
2eac6e3  Merge branch 'inputdsa' of https://github.com/mitchellostrow/DSA into… (mitchellostrow)
f1caab0  updates for uv, simpler imports, resolving comments (mitchellostrow)
update notebooks and packages to run seamlessly with provided uv inst… (mitchellostrow)
.gitignore (new file, 63 lines added):

```
# Python cache files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# IDEs
.vscode/
.idea/
*.swp
*.swo
*~

# OS
.DS_Store
Thumbs.db
```
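The patterns above use glob syntax. As an illustrative aside, Python's standard-library `fnmatch` can show how an individual pattern matches a filename (real gitignore matching additionally handles directories, anchoring, and `!` negation, which `fnmatch` does not):

```python
from fnmatch import fnmatch

# `*.py[cod]` matches compiled-Python artifacts: .pyc, .pyo, .pyd
print(fnmatch("module.pyc", "*.py[cod]"))  # True
print(fnmatch("module.py", "*.py[cod]"))   # False

# `*.swp` matches Vim swap files listed under the IDEs section
print(fnmatch("editor.swp", "*.swp"))      # True
```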
DSA/__init__.py (updated; reconstructed from the rendered diff, so the old/new attribution of individual lines is inferred). The new file pins a package version and replaces several `import *` lines with explicit names:

```python
__version__ = "2.0.0"

from DSA.dsa import DSA, ControllabilitySimilarityTransformDistConfig, GeneralizedDSA, InputDSA, SimilarityTransformDistConfig
from DSA.dsa import DefaultDMDConfig as DMDConfig
from DSA.dsa import pyKoopmanDMDConfig, SubspaceDMDcConfig, DMDcConfig
from DSA.dmd import DMD
from DSA.kerneldmd import KernelDMD
from DSA.dmdc import DMDc
from DSA.subspace_dmdc import SubspaceDMDc
from DSA.simdist import SimilarityTransformDist
from DSA.simdist_controllability import ControllabilitySimilarityTransformDist
from DSA.stats import *
from DSA.sweeps import PyKoopmanSweeper, DefaultSweeper
from DSA.preprocessing import *
from DSA.resdmd import ResidualComputer
```
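The diff narrows several `from DSA.<module> import *` lines to explicit names. A small self-contained illustration of the behavior being avoided: a star import pulls in whatever the module's `__all__` declares, so the package namespace depends on each submodule's export list rather than on names visible in `__init__.py`. The module name `demo_mod` here is hypothetical:

```python
import sys
import types

# Build a throwaway module with one public and one "private" function
mod = types.ModuleType("demo_mod")
exec(
    "__all__ = ['public_fn']\n"
    "def public_fn():\n    return 42\n"
    "def _private_fn():\n    return 0\n",
    mod.__dict__,
)
sys.modules["demo_mod"] = mod

ns = {}
exec("from demo_mod import *", ns)  # star import honors __all__
print("public_fn" in ns)    # True
print("_private_fn" in ns)  # False
```

Explicit imports, as in the updated `__init__.py`, make the exported surface readable in one place.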
New file: base class for DMD implementations (282 lines added; the capture does not show the file path):

```python
"""Base class for DMD implementations."""

import numpy as np
import torch
import warnings
from abc import ABC, abstractmethod


class BaseDMD(ABC):
    """Base class for DMD implementations with common functionality."""

    def __init__(
        self,
        device="cpu",
        verbose=False,
        send_to_cpu=False,
        lamb=0,
    ):
        """
        Parameters
        ----------
        device: string, int, or torch.device
            A string, int or torch.device object to indicate the device to torch.
            If 'cuda' or 'cuda:X' is specified but not available, will fall back to 'cpu' with a warning.
        verbose: bool
            If True, print statements will be provided about the progress of the fitting procedure.
        send_to_cpu: bool
            If True, will send all tensors in the object back to the cpu after everything is computed.
            This is implemented to prevent gpu memory overload when computing multiple DMDs.
        lamb : float
            Regularization parameter for ridge regression. Defaults to 0.
        """
        self.device = device
        self.verbose = verbose
        self.send_to_cpu = send_to_cpu
        self.lamb = lamb

        # Common attributes
        self.data = None
        self.n = None
        self.ntrials = None
        self.is_list_data = False

        # SVD attributes - will be set by subclasses
        self.cumulative_explained_variance = None

    def _setup_device(self, device='cpu', use_torch=None):
        """
        Smart device setup with graceful fallback and auto-detection.

        Parameters
        ----------
        device : str or torch.device
            Requested device ('cpu', 'cuda', 'cuda:0', etc.)
        use_torch : bool or None
            Whether to use PyTorch. If None, auto-detected:
            - True if device contains 'cuda'
            - False otherwise (numpy is faster on CPU)

        Returns
        -------
        tuple
            (device, use_torch) - validated device and use_torch flag
        """
        # Convert device to string for checking
        device_str = str(device).lower()

        # Auto-detect use_torch if not specified
        if use_torch is None:
            use_torch = 'cuda' in device_str

        # If CUDA requested, check availability
        if 'cuda' in device_str:
            if not torch.cuda.is_available():
                warnings.warn(
                    f"CUDA device '{device}' requested but CUDA is not available. "
                    "Falling back to CPU. "
                    "To use GPU acceleration, ensure PyTorch with CUDA support is installed.",
                    RuntimeWarning,
                    stacklevel=3
                )
                device = 'cpu'
                use_torch = False  # Use numpy on CPU for better performance
            else:
                # CUDA is available, verify the specific device exists
                try:
                    test_device = torch.device(device)
                    # Test if we can actually use this device
                    torch.tensor([1.0], device=test_device)
                    use_torch = True
                except (RuntimeError, AssertionError) as e:
                    warnings.warn(
                        f"CUDA device '{device}' requested but not accessible: {e}. "
                        f"Falling back to CPU.",
                        RuntimeWarning,
                        stacklevel=3
                    )
                    device = 'cpu'
                    use_torch = False

        # Convert to torch.device if using torch
        if use_torch:
            device = torch.device(device)
        else:
            device = None  # Use numpy (no torch device needed)

        return device, use_torch

    def _process_single_dataset(self, data):
        """Process a single dataset, handling numpy arrays, tensors, and lists."""
        if isinstance(data, list):
            try:
                # Attempt to convert to a single tensor if possible (non-ragged)
                processed_data = [
                    torch.from_numpy(d) if isinstance(d, np.ndarray) else d
                    for d in data
                ]
                return torch.stack(processed_data), False
            except (RuntimeError, ValueError):
                # Handle ragged lists
                processed_data = [
                    torch.from_numpy(d) if isinstance(d, np.ndarray) else d
                    for d in data
                ]
                # Check for consistent last dimension
                n_features = processed_data[0].shape[-1]
                if not all(d.shape[-1] == n_features for d in processed_data):
                    raise ValueError(
                        "All tensors in the list must have the same number of features (last dimension)."
                    )
                return processed_data, True

        elif isinstance(data, np.ndarray):
            return torch.from_numpy(data.copy()), False

        return data, False

    def _init_single_data(self, data):
        """Initialize data attributes for a single dataset."""
        processed_data, is_ragged = self._process_single_dataset(data)

        if is_ragged:
            # Set attributes for ragged data
            n_features = processed_data[0].shape[-1]
            self.n = n_features
            self.ntrials = sum(d.shape[0] if d.ndim == 3 else 1 for d in processed_data)
            self.trial_counts = [
                d.shape[0] if d.ndim == 3 else 1 for d in processed_data
            ]
            self.is_list_data = True
        else:
            # Set attributes for non-ragged data
            if processed_data.ndim == 3:
                self.ntrials = processed_data.shape[0]
                self.n = processed_data.shape[2]
            else:
                self.n = processed_data.shape[1]
                self.ntrials = 1
            self.is_list_data = False

        return processed_data

    def _compute_explained_variance(self, S):
        """Compute cumulative explained variance from singular values."""
        exp_variance = S**2 / torch.sum(S**2)
        return torch.cumsum(exp_variance, 0)

    def _compute_rank_from_params(
        self,
        S,
        cumulative_explained_variance,
        max_rank,
        rank=None,
        rank_thresh=None,
        rank_explained_variance=None,
    ):
        """
        Compute rank based on provided parameters.

        Parameters
        ----------
        S : torch.Tensor
            Singular values
        cumulative_explained_variance : torch.Tensor
            Cumulative explained variance
        max_rank : int
            Maximum possible rank
        rank : int, optional
            Explicit rank specification
        rank_thresh : float, optional
            Threshold for singular values
        rank_explained_variance : float, optional
            Explained variance threshold

        Returns
        -------
        int
            Computed rank
        """
        parameters_provided = [
            rank is not None,
            rank_thresh is not None,
            rank_explained_variance is not None,
        ]
        num_parameters_provided = sum(parameters_provided)

        if num_parameters_provided > 1:
            raise ValueError(
                "More than one rank parameter was provided. Please provide only one of rank, rank_thresh, or rank_explained_variance."
            )
        elif num_parameters_provided == 0:
            computed_rank = len(S)
        else:
            if rank is not None:
                computed_rank = rank
            elif rank_thresh is not None:
                # Find the number of singular values greater than the threshold
                computed_rank = int((S > rank_thresh).sum().item())
                if computed_rank == 0:
                    computed_rank = 1  # Ensure at least rank 1
            elif rank_explained_variance is not None:
                cumulative_explained_variance_cpu = (
                    cumulative_explained_variance.cpu().numpy()
                )
                computed_rank = int(
                    np.searchsorted(
                        cumulative_explained_variance_cpu, rank_explained_variance
                    )
                    + 1
                )
                if computed_rank > len(S):
                    computed_rank = len(S)

        # Ensure rank doesn't exceed maximum possible
        if computed_rank > max_rank:
            computed_rank = max_rank

        return computed_rank

    def _to_torch(self, x):
        """Convert numpy array to torch tensor on the appropriate device."""
        if not self.use_torch or x is None:
            return x
        if isinstance(x, torch.Tensor):
            return x.to(self.device)
        return torch.from_numpy(x).to(self.device)

    def _to_numpy(self, x):
        """Convert torch tensor to numpy array."""
        if not self.use_torch or x is None:
            return x
        if isinstance(x, torch.Tensor):
            return x.cpu().numpy()
        return x

    def all_to_device(self, device="cpu"):
        """Move all tensor attributes to specified device."""
        for k, v in self.__dict__.items():
            if isinstance(v, torch.Tensor):
                self.__dict__[k] = v.to(device)
            elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
                self.__dict__[k] = [tensor.to(device) for tensor in v]

    @abstractmethod
    def fit(self, *args, **kwargs):
        """Fit the DMD model. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def predict(self, *args, **kwargs):
        """Make predictions with the DMD model. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def compute_hankel(self, *args, **kwargs):
        """Compute Hankel matrix. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def compute_svd(self, *args, **kwargs):
        """Compute SVD. Must be implemented by subclasses."""
        pass
```
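`_compute_rank_from_params` accepts three mutually exclusive truncation criteria: an explicit rank, a singular-value threshold, or an explained-variance target. A numpy-only sketch of the same selection logic (illustrative and independent of the class above, with a hypothetical helper name `rank_from_params`):

```python
import numpy as np

def rank_from_params(S, max_rank, rank=None, rank_thresh=None,
                     rank_explained_variance=None):
    """Pick a truncation rank from singular values S using one of three criteria."""
    provided = sum(p is not None for p in (rank, rank_thresh, rank_explained_variance))
    if provided > 1:
        raise ValueError("Provide only one of rank, rank_thresh, rank_explained_variance.")
    if provided == 0:
        r = len(S)                                  # no criterion: keep everything
    elif rank is not None:
        r = rank                                    # explicit rank
    elif rank_thresh is not None:
        r = max(int((S > rank_thresh).sum()), 1)    # count values above threshold, keep >= 1
    else:
        cum = np.cumsum(S**2) / np.sum(S**2)        # cumulative explained variance
        r = min(int(np.searchsorted(cum, rank_explained_variance)) + 1, len(S))
    return min(r, max_rank)                         # never exceed the maximum possible rank

S = np.array([3.0, 2.0, 1.0])  # explained variance fractions: 9/14, 4/14, 1/14
print(rank_from_params(S, max_rank=3, rank_thresh=1.5))              # 2
print(rank_from_params(S, max_rank=3, rank_explained_variance=0.9))  # 2
```

The `searchsorted` call finds the first index where the cumulative variance reaches the target; adding 1 converts the index to a rank.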
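The four `@abstractmethod` declarations mean a concrete DMD subclass must implement all of `fit`, `predict`, `compute_hankel`, and `compute_svd` before it can be instantiated. A minimal torch-free illustration of how `abc` enforces this (class names here are hypothetical, not from the PR):

```python
from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def fit(self): ...
    @abstractmethod
    def predict(self): ...

class Incomplete(Base):   # implements only fit, so it is still abstract
    def fit(self):
        return "fitted"

class Complete(Base):     # implements both abstract methods
    def fit(self):
        return "fitted"
    def predict(self):
        return "predicted"

try:
    Incomplete()          # instantiation fails while predict is unimplemented
except TypeError as e:
    print("TypeError:", e)

print(Complete().fit())   # fitted
```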