-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathMnvReaderSQLite.py
More file actions
97 lines (88 loc) · 3.18 KB
/
MnvReaderSQLite.py
File metadata and controls
97 lines (88 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#!/usr/bin/env python
"""
"""
from six.moves import range
from sqlalchemy import create_engine
from sqlalchemy import Table, Column, Integer, Float
from sqlalchemy import UniqueConstraint
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import and_
class MnvCategoricalSQLiteReader:
    """
    Read per-event segment (or planecode) predictions from a sqlite db.

    The backing table, ``zsegment_prediction``, is keyed by
    (run, subrun, gate, phys_evt), stores the argmax prediction in
    ``segment`` and is followed by ``n_classes`` float probability columns.

    The database connection is opened in the constructor; call ``close()``
    when finished, or use the reader as a context manager::

        with MnvCategoricalSQLiteReader(n_classes, 'base_name') as reader:
            seg = reader.get_argmax_prediction(run, subrun, gate, evt)
    """

    def __init__(self, n_classes, db_base_name, db_prob_column_format='prob%03d'):
        """
        Args:
            n_classes: number of per-class probability columns on the table.
            db_base_name: path to the db file without the ``.db`` suffix.
            db_prob_column_format: %-style format turning a class index into
                its probability column name (default gives ``prob000``, ...).
        """
        self.n_classes = n_classes
        self.db_name = db_base_name + '.db'
        self._db_prob_columns_format = db_prob_column_format
        self._configure_db()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow the caller's exception

    def close(self):
        """
        Close the db connection and dispose of the engine.

        Safe to call more than once, and safe on a partially constructed
        reader (missing attributes are treated as already closed).
        """
        connection = getattr(self, 'connection', None)
        if connection is not None:
            connection.close()
            self.connection = None
        engine = getattr(self, 'engine', None)
        if engine is not None:
            engine.dispose()
            self.engine = None

    def read_data(self, limit=1):
        """
        Return up to ``limit`` rows of the table (a quick test reader).

        Careful calling this with a large limit on anything but tiny dbs!
        """
        s = select([self.table]).limit(limit)
        return self.connection.execute(s).fetchall()

    def read_record(self, run, subrun, gate, evt):
        """
        Return the list of rows matching one event (normally zero or one).

        result structure
            r[0][0] == id
            r[0][1:5] == (run, sub, gate, evt)
            r[0][5] == segment prediction
            r[0][6:] == individual segment probabilities
        """
        s = select([self.table]).where(
            self._event_clause(run, subrun, gate, evt))
        return self.connection.execute(s).fetchall()

    # NOTE: `id` shadows the builtin, but the name is kept so existing
    # keyword callers (read_record_by_id(id=...)) keep working.
    def read_record_by_id(self, id):
        """Return the list of rows whose primary key equals ``id``."""
        s = select([self.table]).where(self.table.c.id == id)
        return self.connection.execute(s).fetchall()

    def get_argmax_prediction(self, run, subrun, gate, evt):
        """
        Return the stored segment / planecode prediction for one event.

        Raises:
            IndexError: if no row matches (run, subrun, gate, evt).
        """
        # Only the prediction column is needed; avoid pulling the
        # n_classes probability columns and the magic row[5] index.
        s = select([self.table.c.segment]).where(
            self._event_clause(run, subrun, gate, evt))
        row = self.connection.execute(s).first()
        if row is None:
            raise IndexError(
                'no prediction for run=%s subrun=%s gate=%s phys_evt=%s'
                % (run, subrun, gate, evt))
        return row[0]

    def _event_clause(self, run, subrun, gate, evt):
        """Build the WHERE clause selecting one (run, subrun, gate, phys_evt)."""
        c = self.table.c
        return and_(
            c.run == run,
            c.subrun == subrun,
            c.gate == gate,
            c.phys_evt == evt,
        )

    def _setup_prediction_table(self):
        """Declare the prediction table: event keys, argmax, per-class probs."""
        self.table = Table('zsegment_prediction', self.metadata,
                           Column('id', Integer(), primary_key=True),
                           Column('run', Integer()),
                           Column('subrun', Integer()),
                           Column('gate', Integer()),
                           Column('phys_evt', Integer()),
                           Column('segment', Integer()),
                           UniqueConstraint(
                               'run', 'subrun', 'gate', 'phys_evt'
                           ))
        # Probability columns are appended after the six fixed columns, so
        # row[6 + i] holds the probability for class index i.
        for i in range(self.n_classes):
            name = self._db_prob_columns_format % i
            self.table.append_column(Column(name, Float()))

    def _configure_db(self):
        """Open the sqlite db and create the table if it does not exist."""
        db = 'sqlite:///' + self.db_name
        self.metadata = MetaData()
        self.engine = create_engine(db)
        self.connection = self.engine.connect()
        self._setup_prediction_table()
        # create_all checks for existing tables first, so an existing db is
        # read as-is; its schema is not altered to match the declaration.
        self.metadata.create_all(self.engine)