Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions predictability_utils/methods/lrlin_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,42 @@ def default_device():
return torch.ones((1,)).device


def run_lrlin(source_data, target_data, n_latents, idcs, if_plot=False, map_shape=None,
n_epochs=10000, lr=1e-4, batch_size=None, weight_decay=0., weight_lasso=0.):
def run_lrlin(source_data, target_data, n_pca_x, n_pca_y, n_latents, idcs, n_epochs, lr,
if_plot=False, map_shape=None, batch_size=None, weight_decay=0., weight_lasso=0.):

T = source_data.shape[0]
assert T == target_data.shape[0]
idx_source_train, idx_target_train, idx_source_test, idx_target_test = idcs

# predict T2ms in Summer from soil moisture levels in Spring
#Training: predict t2m for train data
X = torch.tensor(source_data.reshape(T, -1)[idx_source_train,:].mean(axis=0))
Y = torch.tensor(target_data.reshape(T, -1)[idx_target_train,:].mean(axis=0))

# pca decomposition
n_pca_x = np.min(Xd.shape) if n_pca_x is None else n_pca_x
n_pca_y = np.min(Yd.shape) if n_pca_y is None else n_pca_y

pca_x = PCA(n_components=n_pca_x, copy=True, whiten=False)
pca_y = PCA(n_components=n_pca_y, copy=True, whiten=False)

pca_x.fit(Xd)
pca_y.fit(Yd)
#print(f'yi: {pca_y.explained_variance_ratio_}'); print(f'xi: {pca_x.explained_variance_ratio_}')
#print(f'sumy: {pca_y.explained_variance_ratio_.sum()}'); print(f'sumx: {pca_x.explained_variance_ratio_.sum()}')

if n_pca_x <= np.min(Xd.shape):
X, A = pca_x.transform(Xd), pca_x.components_
else:
X, A = Xd, None
if n_pca_y <= np.min(Yd.shape):
Y, B = pca_y.transform(Yd), pca_y.components_
else:
Y, B = Yd, None

# fit linear model
X = torch.tensor(X, dtype=torch.float32)
Y = torch.tensor(Y, dtype=torch.float32)

# fit CCA-based model
lrlm = LR_lin_method(n_latents=n_latents)
loss_vals = lrlm.fit(X,Y,
n_epochs=n_epochs,
Expand All @@ -36,14 +60,18 @@ def run_lrlin(source_data, target_data, n_latents, idcs, if_plot=False, map_shap
weight_decay=weight_decay,
weight_lasso=weight_lasso)

# predict T2ms for test data (1951 - 2010)
#Forecasting: predict t2m for test data
X_f = source_data.reshape(T, -1)[idx_source_test,:].mean(axis=0)
X_f = pca_x.transform(X_f) if not A is None else X_f
X_f = torch.tensor(X_f, dtype=torch.float32)

out_pred = lrlm.predict(X_f)
out_pred = pca_y.inverse_transform(out_pred) if not B is None else out_pred

# evaluate prediction performance
out_true = target_data.reshape(T, -1)[idx_target_test,:].mean(axis=0)
anomaly_corrs = helpers.compute_anomaly_corrs(out_true, out_pred)

# visualize anomaly correlations and loss curve during training
if if_plot:

Expand All @@ -53,7 +81,8 @@ def run_lrlin(source_data, target_data, n_latents, idcs, if_plot=False, map_shap
plt.title('loss curve')
plt.show()

params = {'U' : lrlm._U.detach().numpy(), 'V': lrlm._V.detach().numpy() }
params = {'U' : lrlm._U.detach().numpy(), 'V': lrlm._V.detach().numpy(),
'out_pred': out_pred, 'out_true': out_true, 'PCx': X, 'EOFx': A, 'PCy': Y, 'EOFy': B }

return anomaly_corrs, params

Expand All @@ -66,7 +95,7 @@ def __init__(self, n_latents, device=None):
self._U, self.V = None, None
self._device = default_device() if device is None else device

def fit(self, X, Y, lr=1e-2, n_epochs=2000, batch_size=None, weight_decay=0., weight_lasso=0.):
def fit(self, X, Y, lr, n_epochs, batch_size=None, weight_decay=0., weight_lasso=0.):

batch_size = X.shape[0] if batch_size is None else batch_size

Expand Down