-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutil.py
More file actions
85 lines (62 loc) · 2.58 KB
/
util.py
File metadata and controls
85 lines (62 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import pickle
import numpy as np
import os
from DBConnection import db
def sensor_list(location='data/sensor_graph/adj_mx_bay_DCRNN.pkl'):
    """Return element 0 of the pickle at ``location``, converted to a list.

    The file is read with ``load_pickle``. Based on the function name and the
    default path, element 0 is presumably the sequence of sensor IDs of the
    adjacency-matrix pickle (NOTE(review): confirm against the pickle's
    producer).

    Args:
        location: Path of the pickle file to read.

    Returns:
        list built from the first element of the unpickled object.
    """
    contents = load_pickle(location)
    first_entry = contents[0]
    return list(first_entry)
def load_pickle(pickle_file):
    """Load and return the object stored in the pickle file ``pickle_file``.

    Python 3 decodes Python 2 byte strings as ASCII by default. If that raises
    ``UnicodeDecodeError``, the file is read a second time with
    ``encoding='latin1'``, which can decode any byte value.

    SECURITY: unpickling can execute arbitrary code. Only call this on files
    from a trusted source.

    Args:
        pickle_file: Path of the pickle file to read.

    Returns:
        The unpickled object.

    Raises:
        Any exception raised while opening or unpickling the file, including
        one from the latin1 retry. It is reported on stdout and then re-raised.
    """
    try:
        try:
            with open(pickle_file, 'rb') as f:
                return pickle.load(f)
        except UnicodeDecodeError:
            # Retry for Python 2 pickles that contain non-ASCII byte strings.
            with open(pickle_file, 'rb') as f:
                return pickle.load(f, encoding='latin1')
    except Exception as e:
        # Report every failure, including one from the retry above, then
        # let it propagate to the caller.
        print('Unable to load data ', pickle_file, ':', e)
        raise
def load_dataset(dataset_dir, categories=('val',)):
    """Load the ``x``/``y`` arrays of each requested split from ``dataset_dir``.

    For every category ``c``, the file ``<dataset_dir>/<c>.npz`` is read. Its
    ``x`` and ``y`` entries are stored under the keys ``'x_<c>'`` and
    ``'y_<c>'``.

    Args:
        dataset_dir: Directory containing the ``<category>.npz`` files.
        categories: Iterable of split names to load, or a single name as a
            string. Defaults to ``('val',)``, the only split that was
            previously loaded.

    Returns:
        dict mapping ``'x_<c>'`` / ``'y_<c>'`` to the loaded numpy arrays.
    """
    if isinstance(categories, str):
        # Avoid iterating a bare split name character by character.
        categories = (categories,)
    data = {}
    for category in categories:
        # np.load on an .npz returns an NpzFile that holds the file open.
        # Indexing it reads the array into memory, so the file can be closed
        # as soon as both arrays have been fetched.
        with np.load(os.path.join(dataset_dir, category + '.npz')) as cat_data:
            data['x_' + category] = cat_data['x']
            data['y_' + category] = cat_data['y']
    return data
def create_raw_table(name, n_lanes=8):
    """Create the raw per-lane measurement table ``name`` unless it already exists.

    Columns: ``id`` (serial primary key), ``time`` (VARCHAR(50)),
    ``sensor_id`` (numeric) and, for each lane ``i`` in ``range(n_lanes)``,
    ``occupancy_lane_i``, ``speed_lane_i`` and ``cars_lane_i`` (all
    VARCHAR(10)).

    Args:
        name: Table name. It is interpolated into the SQL unquoted and
            unescaped. NOTE(review): it must be a trusted plain identifier,
            never user input.
        n_lanes: Number of lanes to create column triples for (default 8).
    """
    # PostgreSQL folds unquoted identifiers to lower case, so pg_tables holds
    # the lower-cased name. Comparing against the raw name would never match
    # a mixed-case name, and CREATE TABLE would be retried on every call.
    exists_query = ("SELECT NOT EXISTS (SELECT 1 FROM pg_tables "
                    "WHERE tablename='{0}')".format(name.lower()))
    table_missing = db.execute_query(exists_query)[0][0]
    if not table_missing:
        return
    columns = ['id serial PRIMARY KEY', 'time VARCHAR (50)', 'sensor_id numeric']
    for lane in range(n_lanes):
        columns.extend([
            'occupancy_lane_{0} VARCHAR(10)'.format(lane),
            'speed_lane_{0} VARCHAR(10)'.format(lane),
            'cars_lane_{0} VARCHAR(10)'.format(lane),
        ])
    query = 'CREATE TABLE {0} ({1})'.format(name, ', '.join(columns))
    db.execute_command(query, 'Successfully created {0}'.format(name))
def create_final_table(name, n_buckets=11):
    """Create the bucketed table ``name`` unless it already exists.

    Columns: ``time`` (VARCHAR(50)), ``sensor_id`` (numeric) and
    ``bucket_0`` .. ``bucket_<n_buckets-1>`` (numeric).

    Args:
        name: Table name. It is interpolated into the SQL unquoted and
            unescaped. NOTE(review): it must be a trusted plain identifier,
            never user input.
        n_buckets: Number of ``bucket_i`` columns to create (default 11).
    """
    # pg_tables stores unquoted identifiers lower-cased (PostgreSQL case
    # folding), so compare against the lower-cased name.
    exists_query = ("SELECT NOT EXISTS (SELECT 1 FROM pg_tables "
                    "WHERE tablename='{0}')".format(name.lower()))
    table_missing = db.execute_query(exists_query)[0][0]
    if not table_missing:
        return
    columns = ['time VARCHAR (50)', 'sensor_id numeric']
    columns.extend('bucket_{0} numeric'.format(i) for i in range(n_buckets))
    query = 'CREATE TABLE {0} ({1})'.format(name, ', '.join(columns))
    # The message previously always named 'd4_raw_table' (copy-paste error).
    db.execute_command(query, 'Successfully created {0}'.format(name))
def create_final_table_deterministic(name):
    """Create the single-value table ``name`` unless it already exists.

    Columns: ``time`` (VARCHAR(50)), ``sensor_id`` (numeric) and
    ``speed`` (numeric).

    Args:
        name: Table name. It is interpolated into the SQL unquoted and
            unescaped. NOTE(review): it must be a trusted plain identifier,
            never user input.
    """
    # pg_tables stores unquoted identifiers lower-cased (PostgreSQL case
    # folding), so compare against the lower-cased name.
    exists_query = ("SELECT NOT EXISTS (SELECT 1 FROM pg_tables "
                    "WHERE tablename='{0}')".format(name.lower()))
    table_missing = db.execute_query(exists_query)[0][0]
    if not table_missing:
        return
    columns = ['time VARCHAR (50)', 'sensor_id numeric', 'speed numeric']
    query = 'CREATE TABLE {0} ({1})'.format(name, ', '.join(columns))
    # The message previously always named 'd4_raw_table' (copy-paste error).
    db.execute_command(query, 'Successfully created {0}'.format(name))
def shuffle_along_axis(a, axis):
    """Independently permute the entries of ``a`` along ``axis``.

    Each 1-D slice taken along ``axis`` gets its own random permutation.
    The permutations are drawn from NumPy's global random state via
    ``np.random.rand``, so seeding ``np.random`` makes the result
    reproducible.

    Args:
        a: numpy array to shuffle.
        axis: Axis along which slices are permuted.

    Returns:
        A new array of the same shape as ``a``.
    """
    random_keys = np.random.rand(*a.shape)
    # Argsorting i.i.d. uniform keys gives a random ordering for every slice
    # along ``axis``.
    permutation = np.argsort(random_keys, axis=axis)
    shuffled = np.take_along_axis(a, permutation, axis=axis)
    return shuffled