Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 74 additions & 129 deletions DSA/dmdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(
device="cpu",
verbose=False,
send_to_cpu=False,
svd_separate=True,
steps_ahead=1,
):
"""
Expand Down Expand Up @@ -128,9 +127,6 @@ def __init__(
self.rank_output = rank_output
self.rank_thresh_output = rank_thresh_output
self.rank_explained_variance_output = rank_explained_variance_output
self.svd_separate = (
svd_separate # do svd on H and u separately as well as regression
)
self.steps_ahead = steps_ahead

# Hankel matrix
Expand All @@ -139,21 +135,26 @@ def __init__(
# Control input Hankel matrix
self.Hu = None

# SVD attributes
self.U = None
self.S = None
self.V = None
self.S_mat = None
self.S_mat_inv = None

# Change of basis between the reduced-order subspace and the full space
self.U_out = None
self.S_out = None
self.V_out = None

# DMDc attributes
# Stacked formulation: Omega = vstack((X, U)), Y = state at t+steps_ahead
self.X = None
self.Y = None
self.Omega = None

# SVD attributes (from Omega and Y)
self.Up = None
self.Sp = None
self.Vp = None
self.Up1 = None
self.Up2 = None
self.Ur = None
self.Sr = None
self.Vr = None

# DMDc attributes (A, B full-space; A_v, B_v reduced-order; A_havok_dmd/B_havok_dmd aliases in compute_dmdc)
self.A_tilde = None
self.B_tilde = None
self.A_v = None
self.B_v = None
self.A = None
self.B = None
self.A_havok_dmd = None
Expand Down Expand Up @@ -235,7 +236,7 @@ def compute_hankel(
delay_interval=None,
):
"""
Computes the Hankel matrix from the provided data and forms Omega.
Computes the Hankel matrices H and Hu from the provided data.
"""
if self.verbose:
print("Computing Hankel matrices ...")
Expand Down Expand Up @@ -295,10 +296,11 @@ def compute_hankel(

def compute_svd(self):
"""
Computes the SVD of the Omega and Y matrices.
Computes the SVD of the Omega and Y matrices (standard stacked DMDc formulation).
Omega = vstack((X, U)) with X = state at t, U = control at t; Y = state at t+steps_ahead.
"""
if self.verbose:
print("Computing SVD on H and U matrices ...")
print("Computing SVD on Omega and Y matrices ...")

if self.is_list_data:
self.H_shapes = [h.shape for h in self.H]
Expand All @@ -314,7 +316,6 @@ def compute_svd(self):
else:
H_list.append(h_elem)

self.Hu_shapes = [h.shape for h in self.Hu]
for hu_elem in self.Hu:
if hu_elem.ndim == 3:
Hu_list.append(
Expand All @@ -326,8 +327,6 @@ def compute_svd(self):
Hu_list.append(hu_elem)
self.H = torch.cat(H_list, dim=0)
self.Hu = torch.cat(Hu_list, dim=0)
# H = torch.cat(H_list, dim=0)
self.H_row_counts = [h.shape[0] for h in H_list]
H = self.H
Hu = self.Hu

Expand All @@ -337,79 +336,42 @@ def compute_svd(self):
else:
H = self.H
Hu = self.Hu
self.Uh, self.Sh, self.Vh = torch.linalg.svd(H.T, full_matrices=False)
self.Uu, self.Su, self.Vu = torch.linalg.svd(Hu.T, full_matrices=False)

self.Vh = self.Vh.T
self.Vu = self.Vu.T
# Stacked formulation: X (state at t), Y (state at t+steps_ahead), U (control at t)
# H, Hu are (n_samples, n_features); we want (n_features, n_samples) for Omega
self.X = H[: -self.steps_ahead].T
self.Y = H[self.steps_ahead :].T
U = Hu[: -self.steps_ahead].T
self.Omega = torch.vstack((self.X, U))

# SVD of Omega and Y
Up, Sp, Vp = torch.linalg.svd(self.Omega, full_matrices=False)
Vp = Vp.conj().T

self.Sh_mat = torch.diag(self.Sh).to(self.device)
self.Sh_mat_inv = torch.diag(1 / self.Sh).to(self.device)
n_states = self.X.shape[0]
self.Up1 = Up[:n_states, :]
self.Up2 = Up[n_states:, :]

self.Su_mat = torch.diag(self.Su).to(self.device)
self.Su_mat_inv = torch.diag(1 / self.Su).to(self.device)
Ur, Sr, Vr = torch.linalg.svd(self.Y, full_matrices=False)
Vr = Vr.conj().T

self.Up = Up
self.Sp = Sp
self.Vp = Vp
self.Ur = Ur
self.Sr = Sr
self.Vr = Vr

self.cumulative_explained_variance_input = self._compute_explained_variance(
self.Su
self.Sp
)
self.cumulative_explained_variance_output = self._compute_explained_variance(
self.Sh
self.Sr
)

self.Vht_minus, self.Vht_plus = self.get_plus_minus(self.Vh, self.H,self.H_shapes if self.is_list_data else None)
self.Vut_minus, _ = self.get_plus_minus(self.Vu, self.Hu,self.Hu_shapes if self.is_list_data else None)

if self.verbose:
print("SVDs computed!")

def get_plus_minus(self, V, H,H_shapes=None):
if self.ntrials > 1:
if self.is_list_data:
V_split = torch.split(V, self.H_row_counts, dim=0)
Vt_minus_list, Vt_plus_list = [], []
for v_part, h_shape in zip(V_split, H_shapes):
if len(h_shape) == 3: # Has trials
v_part_reshaped = v_part.reshape(h_shape)
newshape = (
h_shape[0] * (h_shape[1] - self.steps_ahead),
h_shape[2],
)
Vt_minus_list.append(
v_part_reshaped[:, : -self.steps_ahead].reshape(newshape)
)
Vt_plus_list.append(
v_part_reshaped[:, self.steps_ahead :].reshape(newshape)
)
else: # No trials, just time and features
Vt_minus_list.append(v_part[: -self.steps_ahead])
Vt_plus_list.append(v_part[self.steps_ahead :])

Vt_minus = torch.cat(Vt_minus_list, dim=0)
Vt_plus = torch.cat(Vt_plus_list, dim=0)
else:

if V.numel() < H.numel():
raise ValueError(
"The dimension of the SVD of the Hankel matrix is smaller than the dimension of the Hankel matrix itself. \n \
This is likely due to the number of time points being smaller than the number of dimensions. \n \
Please reduce the number of delays."
)

V = V.reshape(H.shape)

# first reshape back into Hankel shape, separated by trials
newshape = (
H.shape[0] * (H.shape[1] - self.steps_ahead),
H.shape[2],
)
Vt_minus = V[:, : -self.steps_ahead].reshape(newshape)
Vt_plus = V[:, self.steps_ahead :].reshape(newshape)
else:
Vt_minus = V[: -self.steps_ahead]
Vt_plus = V[self.steps_ahead :]

return Vt_minus, Vt_plus

def recalc_rank(
self,
rank_input=None,
Expand All @@ -422,20 +384,19 @@ def recalc_rank(
"""
Recalculates the rank for input and output based on provided parameters.
"""
# Recalculate ranks for input
# Recalculate ranks for input (Omega) and output (Y)
self.rank_input = self._compute_rank_from_params(
S=self.Su,
S=self.Sp,
cumulative_explained_variance=self.cumulative_explained_variance_input,
max_rank=self.Hu.shape[-1],
max_rank=self.Omega.shape[-1],
rank=rank_input,
rank_thresh=rank_thresh_input,
rank_explained_variance=rank_explained_variance_input,
)
# Recalculate ranks for output
self.rank_output = self._compute_rank_from_params(
S=self.Sh,
S=self.Sr,
cumulative_explained_variance=self.cumulative_explained_variance_output,
max_rank=self.H.shape[-1],
max_rank=self.Y.shape[-1],
rank=rank_output,
rank_thresh=rank_thresh_output,
rank_explained_variance=rank_explained_variance_output,
Expand All @@ -446,46 +407,34 @@ def compute_dmdc(self, lamb=None):
if self.verbose:
print("Computing DMDc matrices ...")

self.lamb = self.lamb if lamb is None else lamb
lamb = self.lamb if lamb is None else lamb

V_minus_tot = torch.cat(
[
self.Vht_minus[:, : self.rank_output],
self.Vut_minus[:, : self.rank_input],
],
dim=1,
)
# Use stored SVD components (standard stacked formulation)
Up1 = self.Up1[:, : self.rank_input]
Up2 = self.Up2[:, : self.rank_input]
Vp = self.Vp[:, : self.rank_input]
Sp = self.Sp[: self.rank_input]
Ur = self.Ur[:, : self.rank_output]

A_v_tot = (
torch.linalg.inv(
V_minus_tot.T @ V_minus_tot
+ self.lamb * torch.eye(V_minus_tot.shape[1]).to(self.device)
)
@ V_minus_tot.T
@ self.Vht_plus[:, : self.rank_output]
).T
# split A_v_tot into A_v and B_v
self.A_v = A_v_tot[:, : self.rank_output]
self.B_v = A_v_tot[:, self.rank_output :]
self.A_havok_dmd = (
self.Uh
@ self.Sh_mat[: self.Uh.shape[1], : self.rank_output]
@ self.A_v
@ self.Sh_mat_inv[: self.rank_output, : self.Uh.shape[1]]
@ self.Uh.T
)
# Tikhonov-regularized inverse: Sp_inv = Sp / (Sp^2 + lamb)
Sp_inv = torch.diag(Sp / (Sp**2 + lamb)).to(self.device)

self.B_havok_dmd = (
self.Uh
@ self.Sh_mat[: self.Uh.shape[1], : self.rank_output]
@ self.B_v
@ self.Su_mat_inv[: self.rank_input, : self.Uu.shape[1]]
@ self.Uu.T
)
A_full = self.Y @ Vp @ Sp_inv @ Up1.T.conj()
B_full = self.Y @ Vp @ Sp_inv @ Up2.T.conj()

# Project A onto reduced output subspace for A_tilde, B_tilde
self.A_tilde = Ur.T.conj() @ A_full @ Ur
self.B_tilde = Ur.T.conj() @ B_full

# Reduced-order (fitted) operators in output subspace
self.A_v = self.A_tilde
self.B_v = self.B_tilde

# Set the A and B properties for backward compatibility and easier access
self.A = self.A_v
self.B = self.A_v
self.B = self.B_v

self.A_havok_dmd = A_full.float()
self.B_havok_dmd = B_full.float()

if self.verbose:
print("DMDc matrices computed!")
Expand Down Expand Up @@ -604,10 +553,6 @@ def predict(

for t in range(1, H_test.shape[1]):
u_t = H_control[:, t - 1]
# print(A.shape)
# print(H_test[:, t - 1].shape)
# print(B.shape)
# print(u_t.shape)
if t % reseed == 0:
H_test_dmdc[:, t] = (A @ H_test[:, t - 1].transpose(-2, -1)).transpose(
-2, -1
Expand Down
Loading