-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathPredict.py
More file actions
executable file
·91 lines (81 loc) · 3.01 KB
/
Predict.py
File metadata and controls
executable file
·91 lines (81 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.utils.data as Data
import torchvision
import time
import json
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer
import pandas as pd
from datetime import timedelta
from sklearn import metrics
from sklearn.model_selection import StratifiedKFold
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
from DataUtils import DefineTestDataset
from DataUtils import readTestData
from model import CNN
from model import Net
from model import mymodel
def get_time_dif(start_time):
    """Return the wall-clock time elapsed since ``start_time``.

    Args:
        start_time: a timestamp in seconds, as returned by ``time.time()``.

    Returns:
        A ``timedelta`` with the elapsed time rounded to whole seconds.
    """
    elapsed_seconds = int(round(time.time() - start_time))
    return timedelta(seconds=elapsed_seconds)
def test_model(model, test_data, device, batch_size=64, num_workers=8, pin_memory=True):
    """Run ``model`` over ``test_data`` and return the class-1 probabilities.

    Args:
        model: a module called as ``model(data, feature)``. It is switched to
            eval mode in place.
        test_data: a dataset whose items unpack into ``(data, feature)`` pairs.
        device: the device each batch is moved to before the forward pass.
        batch_size: samples per batch (default 64).
        num_workers: DataLoader worker processes (default 8).
        pin_memory: whether the DataLoader pins host memory (default True).

    Returns:
        A list of floats, one per sample, in dataset order. Each value is
        column 1 of the softmax over the model output. The output is assumed
        to have at least two columns along dim 1.
    """
    model.eval()
    test_loader = Data.DataLoader(
        test_data,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    y_pred_prob = []
    # Inference only: disabling autograd avoids building and holding a
    # computation graph for every batch.
    with torch.no_grad():
        for data, feature in test_loader:
            data = data.to(device, non_blocking=True)
            feature = feature.to(device, non_blocking=True)
            output = model(data, feature)
            prob_class1 = torch.softmax(output, dim=1)[:, 1]
            y_pred_prob.extend(prob_class1.cpu().tolist())
    return y_pred_prob


# ``test_model`` matches pytest's default ``test_*`` naming pattern. Opt it out
# of collection so that importing this module from a test file does not make
# pytest try to run it (its arguments are not fixtures).
test_model.__test__ = False
def _write_predictions(output, idset, y_pred):
    """Write one ``<id>:<score>`` line per ``idset`` entry to ``output``.

    Entries equal to ``False`` are written as ``False:None`` and do not consume
    a prediction. Every other entry takes the next value from ``y_pred`` in
    order.
    """
    count = 0
    with open(output, 'w', encoding='utf-8') as fw_test:
        for line in idset:
            # readTestData presumably uses ``False`` as a "no sample" sentinel
            # (verify against DataUtils). Equality is kept on purpose to match
            # the original semantics.
            if line == False:  # noqa: E712
                fw_test.write(f"{line}:None\n")
            else:
                fw_test.write(f"{line}:{y_pred[count]}\n")
                count += 1


def main(iexp, itheory, ifeature, output, fmodel):
    """Score every sample returned by ``readTestData`` and write the results.

    Args:
        iexp, itheory, ifeature: input file paths, forwarded unchanged to
            ``readTestData``.
        output: path of the text file to write. It is overwritten.
        fmodel: path of a ``state_dict`` saved from ``mymodel(CNN(), Net())``
            wrapped in ``nn.DataParallel``.

    Prints the elapsed seconds for loading plus inference, then writes one
    ``<id>:<probability>`` line per ``idset`` entry.
    """
    start = time.time()
    records, idset = readTestData(iexp, itheory, ifeature)
    test_data = DefineTestDataset(list(range(len(records))), records)

    # Use the GPU when present, but do not crash on CPU-only machines.
    # DataParallel degrades to a plain pass-through wrapper when CUDA is
    # unavailable.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Wrap in DataParallel *before* loading: the checkpoint's keys are expected
    # to carry the ``module.`` prefix that DataParallel introduces.
    model = nn.DataParallel(mymodel(CNN(), Net())).to(device)
    # map_location keeps the loaded tensors on the CPU. load_state_dict then
    # copies them into the parameters on ``device``. torch.load unpickles, so
    # only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(fmodel, map_location=lambda storage, loc: storage))

    y_pred = test_model(model, test_data, device)
    print(time.time() - start)

    _write_predictions(output, idset, y_pred)
if __name__ == "__main__":
    # Command line (extra trailing arguments are ignored, as before):
    #   Predict.py <exp_file> <theory_file> <feature_file> <output_file> <model_file>
    if len(sys.argv) < 6:
        sys.exit(
            "usage: %s <exp_file> <theory_file> <feature_file> <output_file> <model_file>"
            % sys.argv[0]
        )
    inputfile_exp, inputfile_theory, inputfile_feature, outputfile, modelfile = sys.argv[1:6]
    main(inputfile_exp, inputfile_theory, inputfile_feature, outputfile, modelfile)