forked from TransylvanianInstituteOfNeuroscience/Superlets
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsuperlet.py
More file actions
320 lines (251 loc) · 10.2 KB
/
superlet.py
File metadata and controls
320 lines (251 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
# Time-frequency analysis with superlets
# Based on 'Time-frequency super-resolution with superlets'
# by Moca et al., 2021 Nature Communications
#
# Implementation by Harald Bârzan and Richard Eugen Ardelean
#
# Note: when transforming multiple batches of data, the SuperletTransform class needs to be instantiated only once;
# this saves the time and memory that would otherwise be spent re-creating the wavelets and buffers
#
import numpy as np
from scipy.signal import fftconvolve
# Total width, in standard deviations, of the sampled Gaussian envelope of a
# Morlet wavelet: the window runs from -SPREAD/2 to +SPREAD/2 standard
# deviations (morlet() passes SPREAD/2 as gausswin's alpha).
MORLET_SD_SPREAD: int = 6
# Number of standard deviations covered by half of a wavelet's cycles:
# computeWaveletSize() sets sd = (nc / 2) / |fc| / MORLET_SD_FACTOR, so the
# nc cycles span +/- MORLET_SD_FACTOR standard deviations.
MORLET_SD_FACTOR: float = 2.5
def computeWaveletSize(fc, nc, fs):
    """
    Return the length, in samples, of a Morlet wavelet.

    The result is always odd, so the window has a single center sample.

    Arguments:
    fc - center frequency in Hz (only its magnitude is used)
    nc - number of cycles
    fs - sampling rate in Hz
    """
    # half of the nc cycles lasts (nc / 2) / |fc| seconds and is mapped onto
    # MORLET_SD_FACTOR standard deviations of the Gaussian envelope
    halfDuration = (nc / 2) * (1 / np.abs(fc))
    sigma = halfDuration / MORLET_SD_FACTOR
    # the sampled window spans MORLET_SD_SPREAD standard deviations
    spanSamples = np.round(sigma * fs * MORLET_SD_SPREAD)
    # largest even count not above the span, plus one -> odd length
    evenPart = 2 * np.floor(spanSamples / 2)
    return int(evenPart + 1)
def computeLongestWaveletSize(fs, foi, c1, ord):
    """
    Estimate the size, in samples, of the longest wavelet that a superlet
    transform with these parameters would build.

    Arguments:
    fs  - sampling rate in Hz
    foi - frequencies of interest in Hz
    c1  - base number of cycles
    ord - a single superlet order (scalar or 1-element sequence), or an order
          range (min, max) spread linearly across foi
    Returns:
    the length of the longest wavelet in samples (0 if foi is empty)
    """
    # accept a scalar order as well as a (min, max) range; len() on a scalar
    # would raise, and wrapping a 1-element list would nest it
    orderRange = np.atleast_1d(ord)
    if orderRange.size == 1:
        orderRange = (orderRange[0], orderRange[0])
    orders = np.linspace(start=orderRange[0], stop=orderRange[1], num=len(foi))

    longest = 0
    for centerFreq, order in zip(foi, orders):
        nWavelets = int(np.ceil(order))
        for iWave in range(nWavelets):
            # computeWaveletSize takes (fc, nc, fs): cycles first, then rate
            wlen = computeWaveletSize(centerFreq, (iWave + 1) * c1, fs)
            longest = max(longest, wlen)
    return longest
def gausswin(size, alpha):
    """
    Create a Gaussian window.

    Arguments:
    size  - number of samples in the window
    alpha - half-width of the window in standard deviations; for odd sizes
            the samples span exactly [-alpha, +alpha]
    Returns:
    a float64 array of `size` samples, equal to 1 at index size // 2
    """
    halfSize = int(np.floor(size / 2))
    if halfSize == 0:
        # 0- or 1-sample window: there is no extent to spread the Gaussian
        # over, and alpha / halfSize would raise ZeroDivisionError.
        return np.ones(size, dtype=np.float64)
    idiv = alpha / halfSize
    t = (np.arange(size, dtype=np.float64) - halfSize) * idiv
    window = np.exp(-(t * t) * 0.5)
    return window
def morlet(fc, nc, fs):
    """
    Build an analytic (complex) Morlet wavelet.

    The wavelet is a complex exponential at fc under a Gaussian envelope whose
    samples are normalized to sum to one. Its length is given by
    computeWaveletSize(fc, nc, fs) and time zero is at the center sample.

    Arguments:
    fc - center frequency in Hz
    nc - number of cycles
    fs - sampling rate in Hz
    """
    nSamples = computeWaveletSize(fc, nc, fs)
    center = int(np.floor(nSamples / 2))
    # envelope covering +/- MORLET_SD_SPREAD / 2 standard deviations
    envelope = gausswin(nSamples, MORLET_SD_SPREAD / 2)
    normalization = 1 / envelope.sum()
    samplePeriod = 1 / fs
    timeAxis = (np.arange(nSamples, dtype=np.float64) - center) * samplePeriod
    carrier = np.exp(2 * np.pi * fc * timeAxis * 1j)
    return envelope * carrier * normalization
def fractional(x):
    """
    Return the fractional part of the scalar x.

    The integer part is found by truncating toward zero, so the result has
    the same sign as x (for example, -1.25 gives -0.25).
    """
    integerPart = int(x)
    return x - integerPart
class SuperletTransform:
    """
    Class used to compute the Superlet Transform of input data.

    The wavelets and work buffers are built once in the constructor, so one
    instance can transform many buffers that share the same input size.
    """

    def __init__(self,
                 inputSize,
                 samplingRate,
                 frequencyRange,
                 frequencyBins,
                 baseCycles,
                 superletOrders,
                 frequencies=None):
        """
        Initialize the superlet transform.

        Arguments:
        inputSize: size of the input in samples
        samplingRate: the sampling rate of the input signal in Hz
        frequencyRange: tuple of ascending frequency points, in Hz
        frequencyBins: number of frequency bins to sample in the interval frequencyRange
        baseCycles: number of cycles of the smallest wavelet (c1 in the paper)
        superletOrders: a tuple containing the range of superlet orders, linearly
            distributed along frequencyRange; a single order (scalar or 1-element
            sequence) is used for every frequency
        frequencies: specific list of frequencies - can be provided instead of
            frequencyRange (frequencyRange and frequencyBins are ignored in this case)
        """
        # clear to reinit
        self.clear()

        # frequencies of interest
        if frequencies is not None:
            frequencyBins = len(frequencies)
            self.frequencies = frequencies
        else:
            self.frequencies = np.linspace(start=frequencyRange[0], stop=frequencyRange[1], num=frequencyBins)

        # a single order means a constant order across all frequencies
        orderRange = np.atleast_1d(superletOrders)
        if orderRange.size == 1:
            orderRange = (orderRange[0], orderRange[0])

        self.inputSize = inputSize
        self.orders = np.linspace(start=orderRange[0], stop=orderRange[1], num=frequencyBins)
        self.convBuffer = np.zeros(inputSize, dtype=np.complex128)
        self.poolBuffer = np.zeros(inputSize, dtype=np.float64)

        # superlets[i] holds ceil(orders[i]) Morlet wavelets centered on
        # frequencies[i]; the k-th (0-based) one has (k + 1) * baseCycles cycles
        self.superlets = []
        for centerFreq, order in zip(self.frequencies, self.orders):
            nWavelets = int(np.ceil(order))
            self.superlets.append([morlet(centerFreq, (iWave + 1) * baseCycles, samplingRate)
                                   for iWave in range(nWavelets)])

    def __del__(self):
        """
        Destructor.
        """
        self.clear()

    def clear(self):
        """
        Clear the transform by dropping references to its wavelets and buffers.
        """
        self.inputSize = None
        self.superlets = None
        self.poolBuffer = None
        self.convBuffer = None
        self.frequencies = None
        self.orders = None

    def longestWaveletSize(self):
        """
        Return the size, in samples, of the longest wavelet (0 if there are none).
        """
        return max((w.shape[0] for s in self.superlets for w in s), default=0)

    def validTimeRegion(self):
        """
        Compute the start and end of the valid spectrum region.

        The first and last longestWaveletSize() // 2 samples are excluded.
        These are the same samples cropSpectrum removes when given that padding.

        Returns:
        start: the first valid sample index
        end: one past the last valid sample index (suitable for slicing)
        """
        pad = self.longestWaveletSize() // 2
        start = pad
        end = self.inputSize - pad
        return start, end

    def transform(self, inputData):
        """
        Apply the transform to a buffer or a batch of buffers.

        Arguments:
        inputData - an array whose last dimension has length inputSize; all
            leading dimensions are treated as separate buffers
        Returns:
        a (number of frequencies, inputSize) float64 array holding the spectrum
        of a single buffer, or the average spectrum over all buffers
        Raises:
        ValueError - if the last dimension of inputData is not inputSize
        """
        inputData = np.asarray(inputData)
        if inputData.ndim == 0 or inputData.shape[-1] != self.inputSize:
            raise ValueError("Input data must meet the defined input size for this transform.")

        # collapse all leading dimensions into one list of buffers: the count
        # is their product (1 for 1-D input), not their sum
        nBuffers = int(np.prod(inputData.shape[:-1]))
        datalist = np.reshape(inputData, (nBuffers, self.inputSize), 'C')

        result = np.zeros((len(self.frequencies), self.inputSize), dtype=np.float64)
        for buffer in datalist:
            self.transformOne(buffer, result)
        return result / nBuffers

    def transformOne(self, inputData, accumulator):
        """
        Apply the superlet transform on a single data buffer and add it to accumulator.

        Arguments:
        inputData: a 1-D array of inputSize samples containing the signal to be transformed
        accumulator: a float spectrum the result is added into; if its shape is not
            (number of frequencies, inputSize) it is resized in place to that shape
        """
        shape = (len(self.frequencies), self.inputSize)
        if accumulator.shape != shape:
            # kept for callers that pass a same-sized buffer with another layout
            accumulator.resize(shape)

        for iFreq, wavelets in enumerate(self.superlets):
            # init pooling buffer
            self.poolBuffer.fill(1)

            if len(wavelets) <= 1:
                # wavelet transform: power of the single wavelet's response
                response = fftconvolve(inputData, wavelets[0], "same")
                accumulator[iFreq, :] += (2 * np.abs(response) ** 2).astype(np.float64)
                continue

            # superlet: product of the powers of the first floor(order) wavelets
            order = self.orders[iFreq]
            nWavelets = int(np.floor(order))
            rfactor = 1.0 / nWavelets
            for wavelet in wavelets[:nWavelets]:
                self.convBuffer = fftconvolve(inputData, wavelet, "same")
                self.poolBuffer *= 2 * np.abs(self.convBuffer) ** 2

            if fractional(order) != 0 and len(wavelets) == nWavelets + 1:
                # apply the fractional wavelet, weighted by the fractional part
                # of the order, and adjust the root of the geometric mean
                exponent = order - nWavelets
                rfactor = 1 / (nWavelets + exponent)
                self.convBuffer = fftconvolve(inputData, wavelets[nWavelets], "same")
                self.poolBuffer *= (2 * np.abs(self.convBuffer) ** 2) ** exponent

            # perform geometric mean
            accumulator[iFreq, :] += self.poolBuffer ** rfactor
def cropSpectrum(spectrum, paddingSize):
    """
    Strip paddingSize samples from both ends of a spectrum's second axis.

    Arguments:
    spectrum: a 2D numpy array; its second axis is cropped (the time axis for
        spectra returned by SuperletTransform.transform)
    paddingSize: number of samples to drop at each end - normally
        longestWaveletSize() // 2 of the SuperletTransform that produced it
    Returns:
    the slice of spectrum with the padding removed
    """
    stop = spectrum.shape[1] - paddingSize
    return spectrum[:, paddingSize:stop]
# main superlet function
def superlets(data,
              fs,
              foi,
              c1,
              ord):
    """
    Perform fractional adaptive superlet transform (FASLT) on a list of trials.

    Arguments:
    data: an array-like of data. The rightmost dimension of the data is the trial
        size. The result will be the average over all the spectra.
    fs: the sampling rate in Hz
    foi: list of frequencies of interest
    c1: base number of cycles parameter
    ord: the order (for SLT, a scalar or 1-element sequence) or order range
        (for FASLT, a (min, max) pair), spanned across the frequencies of interest
    Returns:
    a (len(foi), trial size) matrix containing the average superlet spectrum
    """
    data = np.asarray(data)

    # determine buffer size
    bufferSize = data.shape[-1]

    # make order parameter: a single order is expanded to a constant range;
    # len() on a scalar order would raise, and wrapping a list would nest it
    orderRange = np.atleast_1d(ord)
    if orderRange.size == 1:
        orderRange = (orderRange[0], orderRange[0])

    # build the superlet analyzer
    faslt = SuperletTransform(inputSize=bufferSize,
                              frequencyRange=None,
                              frequencyBins=None,
                              samplingRate=fs,
                              frequencies=foi,
                              baseCycles=c1,
                              superletOrders=orderRange)

    # apply transform, releasing the wavelets and buffers even on failure
    try:
        result = faslt.transform(data)
    finally:
        faslt.clear()
    return result