-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict.py
More file actions
58 lines (49 loc) · 1.82 KB
/
predict.py
File metadata and controls
58 lines (49 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import sys
import pickle
import joblib
import pandas as pd
import numpy as np
# Name of the embedding file(s) located in the data/embeddings/ folder for model training
# Optional: To combine multiple embedding files comma-separate each input file name
# (For example: combining training and validation data after hyperparameter tuning)
name = sys.argv[1]
# Load the embedding data file(s) for the model to be trained on
X = pd.read_pickle("data/embeddings/{}.pkl".format(name))
y = True if "y" in X.columns else False
# Load all the site names for the residue in all three splits
residues = sys.argv[2]
test = sys.argv[3]
if test == "test":
with open("data/sites/{}.{}.test.pkl".format(name, residues), 'rb') as f:
sites = pickle.load(f)
# Prepare the training data
X = X.loc[sites]
if y:
y = X["y"].tolist()
X = X.drop("y", axis=1)
elif test == "all":
X["res"] = X.index.str.split("_").str[-1].str[0]
X = X[X["res"].isin(list(residues))]
X = X.drop("res", axis=1)
if y:
y = X["y"].tolist()
X = X.drop("y", axis=1)
else:
print("The test variable should be either 'test' or 'all'")
# Load the model pickle file
model_name = sys.argv[4]
model = joblib.load("data/models/{}.{}.model".format(model_name, residues))
# Name of the output results saved to data/preds/
pred_name = sys.argv[5]
preds_proba = model.predict_proba(X)
preds = pd.DataFrame()
preds["site"] = X.index.tolist()
if y:
preds["label"] = y
preds["prob"] = preds_proba[:,1]
preds["pred"] = preds["prob"].apply(lambda x: 1 if x >= 0.5 else 0)
preds.to_csv("data/preds/{}.{}.{}.csv".format(pred_name, residues, test), index=False)
print("Prediction results for model file:", model_name)
print("Prediction results for input file:", name)
print("Name of output file:", pred_name)
print("Residues used:", residues)