-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpower_vs_error.py
More file actions
139 lines (118 loc) · 4 KB
/
power_vs_error.py
File metadata and controls
139 lines (118 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
''' Graph error of TORTE and ASIC models '''
import math
from tkinter import Tk
from tkinter.filedialog import askopenfilename

import matlab.engine
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import scipy.io as scio
import seaborn as sns
from numpy import angle, array
from scipy.signal import butter, filtfilt, hilbert, sosfiltfilt
from scipy.stats import circmean
print("starting...")
# --- Load the MATLAB log file chosen by the user ---------------------------
# Only the file dialog is wanted, so create a hidden Tk root and destroy it
# afterwards instead of leaving an orphaned Tk instance around.
root = Tk()
root.withdraw()
file = askopenfilename()
root.destroy()
if not file:
    # askopenfilename() returns an empty string when the dialog is cancelled;
    # loadmat('') would fail with a confusing error.
    raise SystemExit("No file selected; exiting.")
log_mat = scio.loadmat(file)
print("File: ", file)
# loadmat returns a 2-D array even for a scalar, hence the [0][0].
mat_channel = log_mat["watchChannel"]
chosen_chan = int(mat_channel[0][0])
print("channel: ", chosen_chan)

# --- Pull LFP data for the selected channel via MATLAB ---------------------
print("grabbing raw lfp data for channel....")
eng = matlab.engine.start_matlab()
# Use the channel recorded in the log file (the channel was previously
# hard-coded to 6, leaving chosen_chan unused).
# NOTE(review): assumes watchChannel uses the same channel numbering
# (0- vs 1-based) as single_chan_lfp — confirm.
chan = matlab.double([chosen_chan])
chan_phase = eng.single_chan_lfp(chan, nargout=2)
# Keep the first row of the first output as a plain Python list of samples.
chans = np.asarray(chan_phase[0])[0].tolist()
### Calculating Ground Truth ###
print("calculating ground truth")
# Filter parameters
fs = 30000  # sampling rate in Hz (default: 30000)
lowpass = 4.0  # Hz, LOWER edge of the pass band
low = lowpass / (fs / 2)
highpass = 12.0  # Hz, UPPER edge of the pass band
high = highpass / (fs / 2)
order = 2
data = chans

# Zero-phase Butterworth band-pass, then instantaneous phase via Hilbert.
# The normalized band edges here are < 0.1% of Nyquist; in that regime the
# [b, a] polynomial form is numerically fragile, so use second-order sections
# (the representation recommended by scipy.signal.butter).
sos = butter(order, [low, high], btype='band', output='sos')
filtered_data = sosfiltfilt(sos, data)
analytic_data = hilbert(filtered_data)
phase_data = angle(analytic_data)  # radians, in [-pi, pi]

# Circular mean of the ground-truth phase (circmean default range: [0, 2*pi)).
rad_avg = circmean(phase_data)
avg = math.degrees(rad_avg)
print("ground truth mean phase (deg): ", avg)

# Convert to degrees with the exact constant rather than 180/3.14159.
phase_data = np.degrees(phase_data)

# Keep every 60th sample (30 kHz -> 500 Hz). The error step compares this
# series point-by-point with the TORTE output, so presumably TORTE emits one
# value per 60 input samples — confirm.
phases = phase_data[::60].tolist()
### Calculating TORTE ###
# Run the MATLAB TORTE implementation on the same channel data.
print("calculating TORTE in matlab...")
mat_chan = matlab.double([chans])
buff = matlab.double([[200]])
# Only the first of the three outputs is used below. It is converted from
# radians to degrees, so it is treated as a phase series in radians.
torte = eng.hilbert_transformer_phase(mat_chan, buff, nargout=3)

# Flatten the MATLAB matrix to a 1-D float vector. This replaces the old
# per-element float() loop: float() on the length-1 rows of an (N, 1) array
# is deprecated in recent NumPy.
temp_phases = np.asarray(torte[0], dtype=float).ravel()

# Circular mean of the TORTE phase (circmean flattens, so same as before).
rad_avg = circmean(temp_phases)
avg = math.degrees(rad_avg)
print("TORTE mean phase (deg): ", avg)

torte_phases = np.degrees(temp_phases)  # 1-D ndarray, degrees
# Circular distance between ground-truth and TORTE phase, point by point.
print("calculating error")
gtp = phases          # ground-truth phase, degrees
torte = torte_phases  # TORTE phase, degrees

# The circ_dist output is treated as radians below (converted to degrees), so
# its inputs must be radians as well. Both series are held in degrees at this
# point; passing them unconverted wraps degree-valued differences as if they
# were radians and produces meaningless errors.
# NOTE(review): assumes circ_dist is the CircStat toolbox function
# (angle(exp(1i*x)./exp(1i*y))) — confirm.
mat_gtp = matlab.double([np.radians(gtp).tolist()])
mat_torte = matlab.double([np.radians(torte).tolist()])

# Reuse the engine started for the LFP/TORTE calls rather than launching a
# second MATLAB process (the first one was previously never quit).
mat_dist = eng.circ_dist(mat_gtp, mat_torte, nargout=1)
eng.quit()

# circ_dist returns a 1 x N row; convert the signed distances to degrees.
rad_dist = np.asarray(mat_dist)[0]
dist = np.degrees(rad_dist).tolist()
print("starting power calculations")
# Downsample 30 kHz -> 1 kHz by keeping every 30th sample.
# NOTE(review): no anti-aliasing low-pass is applied before this decimation;
# content above 500 Hz in `chans` will alias into the analysed band.
chan_list = chans[::30]

print("===========structuring data============")
sfreq = 1000  # Hz, sampling rate after the decimation above
# tfr_array_multitaper expects a plain ndarray shaped
# (n_epochs, n_channels, n_times) — here a single epoch of a single channel.
rawData = np.array([[chan_list]])

print("==========CALCULATING POWER========")
frequencies = np.logspace(np.log10(1), np.log10(30), 32)
num_cycles = np.logspace(np.log10(3), np.log10(7), 32)
# Pass the ndarray itself: the array API does not accept an Epochs object.
power = mne.time_frequency.tfr_array_multitaper(
    rawData, sfreq, frequencies,
    time_bandwidth=3.0, output='power', n_cycles=num_cycles,
)
# Output shape is (n_epochs, n_chans, n_freqs, n_times); take epoch 0, chan 0.
power = power[0][0]  # (n_freqs, n_times)

# Mean power across the 32 frequencies (1-30 Hz) at each 1 kHz time sample.
all_means = power.mean(axis=0).tolist()

# Keep every other sample (1 kHz -> 500 Hz). The indices kept (0, 60, 120, ...
# in the original 30 kHz series) line up with the phase/error series.
means = all_means[::2]

if len(means) != len(dist):
    raise ValueError(
        f"power series ({len(means)} samples) and error series "
        f"({len(dist)} samples) have different lengths"
    )

# Pearson correlation between mean power and phase error.
# NOTE(review): `dist` is the SIGNED circular error in degrees; if the goal is
# error magnitude vs. power, np.abs(dist) may be what is intended — confirm.
corr_matrix = np.corrcoef(means, y=dist)
print(corr_matrix)