-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmain.py
More file actions
38 lines (27 loc) · 681 Bytes
/
Copy pathmain.py
File metadata and controls
38 lines (27 loc) · 681 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import numpy as np
import random
import dataset
import trainer
import linear_rfm
import svd_free_lin_rfm
import csv
# Seed every random number generator the experiment may touch, so that
# repeated runs produce the same synthetic data and the same results.
SEED = 6
# Same generators, same order as before. Per the PyTorch docs,
# torch.manual_seed already seeds all devices; the explicit CUDA call is
# kept for parity and is harmless (it is lazy when CUDA is unavailable).
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed, torch.cuda.manual_seed):
    _seed_fn(SEED)
del _seed_fn
def main(*, d=500, r=5, num_rfm_iters=3000, num_obs=10000, reg=5e-2):
    """Run both linear RFM variants on one generated dataset and print their losses.

    Every argument is keyword-only and defaults to the value this script has
    always used, so ``main()`` behaves exactly as before. The parameters only
    exist so other experiments can reuse this driver with different settings.

    Args:
        d: First argument to ``dataset.get_data`` (presumably the matrix
            dimension -- TODO confirm against ``dataset``).
        r: Second argument to ``dataset.get_data`` (presumably the target
            rank -- TODO confirm).
        num_rfm_iters: Iteration count passed to both RFM solvers.
        num_obs: Third argument to ``dataset.get_data`` (presumably the
            number of observed entries -- TODO confirm).
        reg: Regularization strength passed to both solvers as ``reg=``.
    """
    # Generate the data once, so both solvers see the same sample and their
    # printed losses are directly comparable.
    Y, unmasked = dataset.get_data(d, r, num_obs)

    loss = linear_rfm.linear_rfm(Y, unmasked, num_rfm_iters, reg=reg)
    print("Linear RFM Alpha = 1: ", loss)

    loss = svd_free_lin_rfm.rfm(Y, unmasked, num_rfm_iters, reg=reg)
    print("Linear RFM Alpha = 1/2: ", loss)
# Run the experiment only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()