-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict_script.py
More file actions
90 lines (76 loc) · 2.86 KB
/
predict_script.py
File metadata and controls
90 lines (76 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Run it using: python .\predict_script.py './test_assets/data/monthly_kansas_ndvi.csv' './test_assets/data/predict_kansas_wheat_yield.csv' 'values_output.csv'
import pickle
import pandas as pd
from sys import argv
# Command-line arguments, in order: monthly NDVI CSV, yield-request CSV,
# and the path the prediction CSV is written to.
monthly_ndvi_file, yield_results_file, output_file = argv[1], argv[2], argv[3]

# Load both inputs; only the request file gets its column names lower-cased.
monthly_kansas_ndvi = pd.read_csv(monthly_ndvi_file)
predict_kansas_wheat_yield = pd.read_csv(yield_results_file)
predict_kansas_wheat_yield = predict_kansas_wheat_yield.rename(columns=str.lower)

# Title-case the text keys that are later used to match the two tables.
# NOTE(review): the NDVI file's `state` column is not normalized here even
# though it is later compared against the title-cased `state` of the request
# file -- confirm the NDVI file already stores states in title case.
monthly_kansas_ndvi['county'] = monthly_kansas_ndvi['county'].str.title()
for key_column in ('state', 'county'):
    predict_kansas_wheat_yield[key_column] = (
        predict_kansas_wheat_yield[key_column].str.title()
    )
# transformations
def shift_ndvi(x):
    """Add an ``ndvis`` column holding ``ndvi`` lagged by four rows.

    The function is applied once per county group, so the lag is computed on
    the group's own ``ndvi`` column.  Shifting the full, ungrouped column
    instead would let the first four rows of a county pick up values that
    belong to a different county.

    Args:
        x: DataFrame for a single county with an ``ndvi`` column.  Assumes
            the rows are in chronological order, one row per month -- TODO
            confirm against the input file.

    Returns:
        The same DataFrame ``x`` with the new ``ndvis`` column.  The first
        four rows of the group have no predecessor and therefore hold NaN.
    """
    x['ndvis'] = x['ndvi'].shift(4)
    return x
def get_county_id(x):
    """Store the numeric id of the row's county in ``x['county_id']``.

    The id is read from the module-level ``counties`` table by selecting the
    entries whose ``county`` equals ``x['county']`` and taking the first one.

    Args:
        x: Row (Series) with a ``county`` entry; it is modified in place.

    Returns:
        The same row ``x`` with ``county_id`` added.

    Raises:
        IndexError: If ``counties`` has no entry for ``x['county']``.
    """
    is_same_county = counties['county'] == x['county']
    x['county_id'] = counties.loc[is_same_county, 'county_id'].values[0]
    return x
# Lookup table: every distinct county name, sorted, numbered from 0 upwards.
unique_counties = monthly_kansas_ndvi['county'].drop_duplicates()
counties = unique_counties.sort_values().to_frame()
counties['county_id'] = range(len(counties))

# Maps the value of the `month` column to the feature column name used by the
# model.  NOTE(review): the `_p` / `_c` suffixes presumably mean "previous" and
# "current" year -- confirm with whoever trained the model.
month_to_feature = {
    1: 'ndvi_9_p',
    2: 'ndvi_10_p',
    3: 'ndvi_11_p',
    4: 'ndvi_12_p',
    5: 'ndvi_1_c',
    6: 'ndvi_2_c',
    7: 'ndvi_3_c',
    8: 'ndvi_4_c',
    9: 'ndvi_5_c',
    10: 'ndvi_6_c',
    11: 'ndvi_7_c',
    12: 'ndvi_8_c',
}

# Add the lagged `ndvis` column county by county.
transformed = monthly_kansas_ndvi.groupby(['county']).apply(shift_ndvi)

# Pivot to one row per (county, year, state): unstack() moves the innermost
# index level (`month`) into the columns.
transformed = transformed.set_index(['county', 'year', 'state', 'month'])
transformed = transformed.unstack()

# Keep only the lagged values, give the month columns their feature names,
# turn the index back into ordinary columns and drop rows with any gap.
transformed = transformed['ndvis'].rename(columns=month_to_feature)
transformed = transformed.reset_index().rename_axis("", axis="columns")
transformed = transformed.dropna()

# Attach the numeric county id to every remaining row.
transformed = transformed.apply(get_county_id, axis=1)
# transformations end
# load the best performing model
# The context manager closes the file handle as soon as unpickling is done.
# NOTE(security): pickle.load can execute arbitrary code contained in the
# file -- only ever load a 'finalized_model.sav' that comes from a trusted
# source.
with open('finalized_model.sav', 'rb') as model_file:
    loaded_model = pickle.load(model_file)
# function to apply the model to each row
def apply_model(x, features=None, model=None):
    """Add a ``value_prediction`` entry to one request row.

    The row is matched against the feature table on ``year``, ``county`` and
    ``state``.  The matching feature rows are handed to ``model.predict`` and
    the first returned value is stored in ``x['value_prediction']``.  When no
    feature row matches, ``value_prediction`` is set to NaN so that every
    returned row carries the key, even if no request row matches at all.

    Args:
        x: Row (Series) with ``year``, ``county`` and ``state`` entries; it
            is modified in place.
        features: Feature table to search.  Defaults to the module-level
            ``transformed`` table.
        model: Object exposing ``predict``.  Defaults to the module-level
            ``loaded_model``; it is only looked up when a match exists.

    Returns:
        The same row ``x`` with ``value_prediction`` set.
    """
    if features is None:
        features = transformed

    match = features[
        (features['year'] == x['year'])
        & (features['county'] == x['county'])
        & (features['state'] == x['state'])
    ]
    if match.empty:
        # No features for this request: mark the prediction as missing.
        x['value_prediction'] = float('nan')
        return x

    if model is None:
        model = loaded_model

    # Values are passed to predict() in exactly this column order.
    # NOTE(review): presumably this must equal the order used when the model
    # was trained -- confirm before changing it.
    feature_columns = [
        'year',
        'county_id',
        'ndvi_9_p',
        'ndvi_10_p',
        'ndvi_11_p',
        'ndvi_12_p',
        'ndvi_1_c',
        'ndvi_2_c',
        'ndvi_3_c',
        'ndvi_4_c',
        'ndvi_5_c',
        'ndvi_6_c',
        'ndvi_7_c',
        'ndvi_8_c',
    ]
    samples = list(zip(*(match[column] for column in feature_columns)))
    # Only the prediction for the first matching feature row is kept.
    x['value_prediction'] = model.predict(samples)[0]
    return x
# apply the model to every request row
predict_kansas_wheat_yield = predict_kansas_wheat_yield.apply(apply_model, axis=1)
# Keep only the output columns.  reindex() is used instead of [[...]] so that
# a missing 'value_prediction' column (no request rows, or no row produced a
# prediction) yields an empty column rather than a KeyError.
predict_kansas_wheat_yield = predict_kansas_wheat_yield.reindex(
    columns=['state', 'county', 'year', 'value_prediction']
)
# save the result
predict_kansas_wheat_yield.to_csv(output_file, index=False)