import os
from collections import OrderedDict

import h5py
import numba
import numpy as np
import yaml
###############################################################################################
# HDF5 and YAML files loading and saving
###############################################################################################
def load_locs(path):
"""Load localizations from HDF5 file."""
with h5py.File(path, "r") as locs_file:
locs = locs_file["locs"][...]
locs = np.rec.array(locs, dtype=locs.dtype)
info = load_info(path)
return locs, info
def load_info(path):
"""Load metadata from the corresponding YAML file."""
path_base = path.rsplit(".", 1)[0]
filename = path_base + ".yaml"
try:
with open(filename, "r") as info_file:
info = list(yaml.load_all(info_file, Loader=yaml.UnsafeLoader))
    except FileNotFoundError:
        print(f"Could not find metadata file: {filename}")
        raise
return info
def save_locs(path, locs, info):
"""Save localizations to an HDF5 file."""
with h5py.File(path, "w") as locs_file:
locs_file.create_dataset("locs", data=locs)
base = path.rsplit(".", 1)[0]
save_info(base + ".yaml", info)
def save_info(path, info):
"""Save metadata information to a YAML file."""
with open(path, "w") as file:
yaml.dump_all(info, file, default_flow_style=False)
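# A minimal round-trip sketch of the I/O helpers above ("example_locs.hdf5" is a
# hypothetical filename, not one produced by this script):
#
#     locs, info = load_locs("example_locs.hdf5")  # reads the HDF5 and its YAML sidecar
#     save_locs("example_copy.hdf5", locs, info)   # writes both files back out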
###############################################################################################
# Linking
###############################################################################################
def process_linking(input_folder, max_frame_gap, nena_multiplier,
input_extension, output_extension, nena_data):
"""Process all HDF5 files in the input folder for linking."""
hdf5_files = []
for root, _, files in os.walk(input_folder):
for f in files:
if f.endswith(f'{input_extension}.hdf5') and output_extension not in f:
hdf5_files.append((root, f))
print(f"Found {len(hdf5_files)} files to process")
for root, filename in hdf5_files:
input_file = os.path.join(root, filename)
# Save in same directory as input file
output_filename = filename.replace('.hdf5', f"{output_extension}.hdf5")
output_file = os.path.join(root, output_filename)
print(f"\nProcessing file: {filename}")
try:
# Check if the current file's name contains any of the base names from NeNA calculation
nena_value = None
for base_name in nena_data['values'].keys():
if base_name in filename:
nena_value = nena_data['values'][base_name]
print(f"Found matching NeNA value for base name: {base_name}")
break
if nena_value is None:
raise ValueError(f"No matching NeNA value found for file {filename}")
locs, info = load_locs(input_file)
max_distance = nena_value * nena_multiplier
linked_locs = link(
locs,
info,
r_max=max_distance,
max_dark_time=max_frame_gap,
combine_mode="average",
remove_ambiguous_lengths=True
)
save_locs(output_file, linked_locs, info)
except Exception as e:
print(f"Error processing {filename}: {str(e)}")
continue
print("\nAll files linked")
def append_to_rec(array, values, field_name):
"""Appends a new field to a structured NumPy array."""
new_dtype = array.dtype.descr + [(field_name, values.dtype)]
new_array = np.empty(array.shape, dtype=new_dtype)
for name in array.dtype.names:
new_array[name] = array[name]
new_array[field_name] = values
return new_array
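# A toy sketch of append_to_rec (field names here are illustrative only):
#
#     locs = np.rec.array([(0, 1.5)], dtype=[("frame", "i4"), ("x", "f4")])
#     lens = np.array([3], dtype=np.int32)
#     locs = append_to_rec(locs, lens, "len")  # result gains a "len" column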
def link(
locs,
info,
r_max=0.05,
max_dark_time=1,
combine_mode="average",
remove_ambiguous_lengths=True,
):
"""Links localizations across frames based on proximity and temporal constraints."""
if len(locs) == 0:
linked_locs = locs.copy()
if "frame" in locs.dtype.names:
linked_locs = append_to_rec(linked_locs, np.array([], dtype=np.int32), "len")
linked_locs = append_to_rec(linked_locs, np.array([], dtype=np.int32), "n")
if "photons" in locs.dtype.names:
linked_locs = append_to_rec(linked_locs, np.array([], dtype=np.float32), "photon_rate")
return linked_locs
locs.sort(kind="mergesort", order="frame")
group = locs["group"] if "group" in locs.dtype.names else np.zeros(len(locs), dtype=np.int32)
link_group = get_link_groups(locs, r_max, max_dark_time, group)
    if combine_mode == "average":
        linked_locs = link_loc_groups(locs, info, link_group, remove_ambiguous_lengths)
    elif combine_mode == "refit":
        raise NotImplementedError("combine_mode='refit' is not implemented yet")
    else:
        raise ValueError(f"Unknown combine_mode: {combine_mode}")
    return linked_locs
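# Example call (a sketch; r_max is in the same units as x/y/lpx, typically
# camera pixels, and max_dark_time is the number of frames a molecule may
# stay dark and still be linked):
#
#     locs, info = load_locs("example_locs.hdf5")  # hypothetical file
#     linked = link(locs, info, r_max=0.05, max_dark_time=1)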
@numba.jit(nopython=True)
def get_link_groups(locs, d_max, max_dark_time, group):
"""Assumes that locs are sorted by frame"""
frame = locs.frame
x = locs.x
y = locs.y
N = len(x)
link_group = -np.ones(N, dtype=np.int32)
current_link_group = -1
for i in range(N):
if link_group[i] == -1: # loc has no group yet
current_link_group += 1
link_group[i] = current_link_group
current_index = i
next_loc_index_in_group = _get_next_loc_index_in_link_group(
current_index,
link_group,
N,
frame,
x,
y,
d_max,
max_dark_time,
group,
)
while next_loc_index_in_group != -1:
link_group[next_loc_index_in_group] = current_link_group
current_index = next_loc_index_in_group
next_loc_index_in_group = _get_next_loc_index_in_link_group(
current_index,
link_group,
N,
frame,
x,
y,
d_max,
max_dark_time,
group,
)
return link_group
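# The loop above performs a greedy forward chaining: every still-unassigned
# localization seeds a new group, and the chain is extended one hop at a time
# to the first unassigned localization within d_max and max_dark_time. For
# example, with max_dark_time=1, localizations at (nearly) the same position
# in frames 0, 1 and 3 all end up in one group, because the frame 1 -> 3 step
# spans exactly one dark frame.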
def link_loc_groups(locs, info, link_group, remove_ambiguous_lengths=True):
    """Combine the localizations of each link group into a single localization."""
n_locs = len(link_group)
n_groups = link_group.max() + 1
n_ = _link_group_count(link_group, n_locs, n_groups)
columns = OrderedDict()
if hasattr(locs, "frame"):
first_frame_, last_frame_ = _link_group_min_max(
locs.frame, link_group, n_locs, n_groups
)
columns["frame"] = first_frame_.astype(np.int32)
if hasattr(locs, "x"):
        # Inverse-variance weights; machine epsilon guards against division by zero
        eps = np.finfo(float).eps
        weights_x = 1 / np.maximum(locs.lpx**2, eps)
columns["x"], sum_weights_x_ = _link_group_weighted_mean(
locs.x, weights_x, link_group, n_locs, n_groups, n_
)
if hasattr(locs, "y"):
        # Inverse-variance weights; machine epsilon guards against division by zero
        eps = np.finfo(float).eps
        weights_y = 1 / np.maximum(locs.lpy**2, eps)
columns["y"], sum_weights_y_ = _link_group_weighted_mean(
locs.y, weights_y, link_group, n_locs, n_groups, n_
)
if hasattr(locs, "photons"):
columns["photons"] = _link_group_sum(locs.photons, link_group, n_locs, n_groups)
if hasattr(locs, "sx"):
columns["sx"] = _link_group_mean(locs.sx, link_group, n_locs, n_groups, n_)
if hasattr(locs, "sy"):
columns["sy"] = _link_group_mean(locs.sy, link_group, n_locs, n_groups, n_)
if hasattr(locs, "bg"):
columns["bg"] = _link_group_sum(locs.bg, link_group, n_locs, n_groups)
if hasattr(locs, "x"):
columns["lpx"] = np.sqrt(1 / sum_weights_x_)
if hasattr(locs, "y"):
columns["lpy"] = np.sqrt(1 / sum_weights_y_)
if hasattr(locs, "ellipticity"):
columns["ellipticity"] = _link_group_mean(
locs.ellipticity, link_group, n_locs, n_groups, n_
)
if hasattr(locs, "net_gradient"):
columns["net_gradient"] = _link_group_mean(
locs.net_gradient, link_group, n_locs, n_groups, n_
)
if hasattr(locs, "likelihood"):
columns["likelihood"] = _link_group_mean(
locs.likelihood, link_group, n_locs, n_groups, n_
)
if hasattr(locs, "iterations"):
columns["iterations"] = _link_group_mean(
locs.iterations, link_group, n_locs, n_groups, n_
)
if hasattr(locs, "z"):
columns["z"] = _link_group_mean(locs.z, link_group, n_locs, n_groups, n_)
if hasattr(locs, "d_zcalib"):
columns["d_zcalib"] = _link_group_mean(
locs.d_zcalib, link_group, n_locs, n_groups, n_
)
if hasattr(locs, "group"):
columns["group"] = _link_group_last(locs.group, link_group, n_locs, n_groups)
if hasattr(locs, "frame"):
columns["len"] = (last_frame_ - first_frame_ + 1).astype(np.int32)
columns["n"] = n_.astype(np.int32)
if hasattr(locs, "photons"):
columns["photon_rate"] = np.float32(columns["photons"] / n_)
linked_locs = np.rec.array(list(columns.values()), names=list(columns.keys()))
if remove_ambiguous_lengths:
valid = np.logical_and(first_frame_ > 0, last_frame_ < info[0]["Frames"])
linked_locs = linked_locs[valid]
return linked_locs
@numba.jit(nopython=True)
def _link_group_count(link_group, n_locs, n_groups):
"""Count number of localizations per group"""
n = np.zeros(n_groups, dtype=np.int32)
for i in range(n_locs):
n[link_group[i]] += 1
return n
@numba.jit(nopython=True)
def _link_group_min_max(values, link_group, n_locs, n_groups):
"""Find min and max values for each group"""
min_values = np.full(n_groups, np.inf)
max_values = np.full(n_groups, -np.inf)
for i in range(n_locs):
group = link_group[i]
value = values[i]
if value < min_values[group]:
min_values[group] = value
if value > max_values[group]:
max_values[group] = value
return min_values, max_values
@numba.jit(nopython=True)
def _link_group_sum(values, link_group, n_locs, n_groups):
"""Sum values within each group"""
sums = np.zeros(n_groups, dtype=values.dtype)
for i in range(n_locs):
sums[link_group[i]] += values[i]
return sums
@numba.jit(nopython=True)
def _link_group_mean(values, link_group, n_locs, n_groups, n):
"""Calculate mean values for each group"""
sums = _link_group_sum(values, link_group, n_locs, n_groups)
means = sums / n
return means
@numba.jit(nopython=True)
def _link_group_weighted_mean(values, weights, link_group, n_locs, n_groups, n):
"""Calculate weighted mean values for each group"""
sum_weights = np.zeros(n_groups, dtype=np.float32)
sum_weighted_values = np.zeros(n_groups, dtype=np.float32)
for i in range(n_locs):
group = link_group[i]
weight = weights[i]
sum_weights[group] += weight
sum_weighted_values[group] += weight * values[i]
means = sum_weighted_values / sum_weights
return means, sum_weights
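# The weighted means above implement inverse-variance averaging: with weights
# w_i = 1 / lpx_i**2, the combined coordinate is sum(w_i * x_i) / sum(w_i),
# and the combined precision stored as "lpx"/"lpy" in link_loc_groups is
# sqrt(1 / sum(w_i)), i.e. linking never degrades the best single-frame
# precision within a group.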
@numba.jit(nopython=True)
def _link_group_last(values, link_group, n_locs, n_groups):
"""Get last value for each group"""
last_values = np.zeros(n_groups, dtype=values.dtype)
for i in range(n_locs):
last_values[link_group[i]] = values[i]
return last_values
@numba.jit(nopython=True)
def _get_next_loc_index_in_link_group(
current_index, link_group, N, frame, x, y, d_max, max_dark_time, group
):
"""Helper function to find the next localization in the same link group."""
current_frame = frame[current_index]
current_x = x[current_index]
current_y = y[current_index]
current_group = group[current_index]
# Convert to basic numeric types for comparison
min_frame = float(current_frame + 1)
max_frame = float(current_frame + max_dark_time + 1)
# Find valid frame range
min_index = current_index + 1
while min_index < N and float(frame[min_index]) < min_frame:
min_index += 1
max_index = min_index
while max_index < N and float(frame[max_index]) <= max_frame:
max_index += 1
# Calculate distances
d_max_2 = d_max * d_max
for j in range(min_index, max_index):
if group[j] == current_group and link_group[j] == -1:
dx2 = (float(current_x) - float(x[j])) ** 2
if dx2 <= d_max_2:
dy2 = (float(current_y) - float(y[j])) ** 2
if dx2 + dy2 <= d_max_2:
return j
return -1
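# Expected config.yaml layout, inferred from the keys read below (all values
# shown are placeholders, not tested defaults):
#
#     paths:
#       input_folder: /data/locs
#     linking:
#       max_frame_gap: 1
#       nena_multiplier: 1.0
#       input_extension: _locs
#       output_extension: _linked
#       nena_data:
#         values:
#           some_base_name: 0.03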
if __name__ == "__main__":
    # Load config file
    with open('config.yaml', 'r') as f:
        config = yaml.safe_load(f)
    # Process files using config parameters; linked files are saved next to the
    # inputs, so no output folder is passed
    process_linking(
        input_folder=config['paths']['input_folder'],
        **config['linking']
    )