Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 289 additions & 8 deletions src/dartsort/detect/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,25 @@
def detect_and_deduplicate(
traces,
threshold,
dedup_channel_index=None,
peak_sign="neg",
relative_peak_radius=5,
dedup_temporal_radius=7,
relative_peak_channel_index=None,
dedup_temporal_radius=11,
dedup_channel_index=None,
spatial_dedup_batch_size=512,
exclude_edges=True,
return_energies=False,
detection_mask=None,
trough_priority=None,
cumulant_order=None,
):
"""Detect and deduplicate peaks

torch-based peak detection and deduplication, relying
on max pooling and scatter operations

TODO: reuse bufs and pre-pad.

Arguments
---------
traces : time by channels tensor
Expand All @@ -43,6 +47,26 @@ def detect_and_deduplicate(
peak times in samples relative to start of traces, along
with corresponding channels
"""
relative_peak_radius = 5
cumulant_win_size= 11
if cumulant_order is not None:
# TODO: combine.
return detect_and_deduplicate_2d_filters(
traces,
cumulant_order=cumulant_order,
threshold=threshold,
cumulant_win_size=cumulant_win_size,
dedup_channel_index=dedup_channel_index,
peak_sign=peak_sign,
relative_peak_radius=relative_peak_radius,
spatial_dedup_batch_size=spatial_dedup_batch_size,
exclude_edges=exclude_edges,
return_energies=return_energies,
detection_mask=detection_mask,
trough_priority=trough_priority,
)


nsamples, nchans = traces.shape
all_dedup = isinstance(dedup_channel_index, str) and dedup_channel_index == "all"
if not all_dedup and dedup_channel_index is not None:
Expand All @@ -58,16 +82,26 @@ def detect_and_deduplicate(
# no need to copy since max pooling will
energies = traces

# -- torch temporal relative maxima as pooling operation
# we used to implement with max_pool2d -> unique, but
# we can use max_unpool2d to speed up the second step
# temporal max pooling
# we used to implement with max_pool -> unique, but we can use max_unpool
# to speed up the second step temporal max pooling
energies, indices = F.max_pool1d_with_indices(
energies.T.unsqueeze(0),
kernel_size=2 * relative_peak_radius + 1,
stride=1,
padding=relative_peak_radius,
)
# spatial peak criterion
if relative_peak_channel_index is not None:
# we are in 1CT right now
max_energies = F.pad(energies[0], (0, 0, 0, 1))
for batch_start in range(0, nsamples, spatial_dedup_batch_size):
batch_end = batch_start + spatial_dedup_batch_size
torch.amax(
max_energies[relative_peak_channel_index, batch_start:batch_end],
dim=1,
out=max_energies[:nchans, batch_start:batch_end,],
)
energies.masked_fill_(max_energies[:nchans] > energies[0], 0.0)
# unpool will set non-maxima to 0
energies = F.max_unpool1d(
energies,
Expand All @@ -85,7 +119,7 @@ def detect_and_deduplicate(

# -- temporal deduplication
if detection_mask is not None:
energies.mul_(detection_mask.to(energies))
energies.mul_(detection_mask.T.to(energies))
if dedup_temporal_radius:
max_energies = F.max_pool1d(
energies,
Expand All @@ -102,7 +136,7 @@ def detect_and_deduplicate(
# -- spatial deduplication
# this is max pooling within the channel index's neighborhood's
if all_dedup:
max_energies = max_energies.max(dim=1, keepdim=True).values
max_energies = max_energies.amax(dim=1, keepdim=True)
elif dedup_channel_index is not None:
# pad channel axis with extra chan of 0s
max_energies = F.pad(max_energies, (0, 1))
Expand Down Expand Up @@ -136,6 +170,7 @@ def singlechan_template_detect_and_deduplicate(
singlechan_templates,
threshold=40.0,
trough_offset_samples=42,
relative_peak_channel_index=None,
dedup_channel_index=None,
relative_peak_radius=5,
dedup_temporal_radius=7,
Expand Down Expand Up @@ -167,6 +202,7 @@ def singlechan_template_detect_and_deduplicate(
times, chans = detect_and_deduplicate(
obj,
threshold=threshold,
relative_peak_channel_index=relative_peak_channel_index,
dedup_channel_index=dedup_channel_index,
peak_sign="pos",
relative_peak_radius=relative_peak_radius,
Expand All @@ -181,3 +217,248 @@ def singlechan_template_detect_and_deduplicate(
return times, chans, traces[times, chans]

return times, chans

@torch.no_grad() # remove if you need gradients
def compute_sliding_2d_cumulant(data: torch.Tensor, order: int, win_size: int):
"""
Efficient sliding cumulant over 2D windows (mean or variance only) without unfold/view.
Args:
data: (C, H, W) tensor
order: 1 for mean, 2 for variance
win_size: odd kernel size
Returns:
(C, H, W) tensor
"""
if order == 0:
return data
if win_size % 2 == 0:
raise ValueError("win_size must be odd for symmetric padding")
if order not in (1, 2):
raise ValueError("This fast path supports only order=1 (mean) or order=2 (variance).")

C, H, W = data.shape
pad = win_size // 2

# Pad spatial dims with reflect to match your original behavior
x = F.pad(data, (pad, pad, pad, pad), mode='reflect') # (C, H+2p, W+2p)

# avg_pool2d expects (N, C, H, W); use a dummy batch dim
x = x.unsqueeze(0) # (1, C, H+2p, W+2p)

# Local mean via average pooling (stride=1, valid after our manual padding)
mean = F.avg_pool2d(x, kernel_size=win_size, stride=1) # (1, C, H, W)

if order == 1:
return mean.squeeze(0)

# order == 2: variance = E[x^2] - (E[x])^2
ex2 = F.avg_pool2d(x * x, kernel_size=win_size, stride=1) # (1, C, H, W)
var = ex2 - mean * mean

# Numerical guard: tiny negative values to zero due to FP roundoff
var = torch.clamp(var, min=0.0)

return var.squeeze(0)

# def compute_sliding_2d_cumulant(radiality, order, win_size, chunk_size=256):
# """
# Compute sliding cumulant statistics (mean, variance, skewness, kurtosis) over spatial 2D windows,
# processing the W-axis in manageable chunks.
#
# Args:
# radiality: (C, H, W) tensor
# order: cumulant order (1=mean, 2=variance, 3=skewness, 4=kurtosis)
# win_size: size of spatial window (must be odd for symmetry)
# chunk_size: number of W-axis columns to process per chunk (default: 30,000)
#
# Returns:
# Tensor of shape (C, H, W) with the cumulant statistic at each spatial location
# """
# C, H, W = radiality.shape
#
# if win_size % 2 == 0:
# raise ValueError("win_size must be odd for symmetric padding")
#
# padding = win_size // 2
#
# # Pad spatial dimensions
# padded = F.pad(radiality.unsqueeze(1), (padding, padding, padding, padding), mode='reflect') # (C, 1, H+2p, W+2p)
#
# results = []
# start = 0
# while start < W:
# end = min(start + chunk_size, W)
#
# # Extract current chunk with extra padding
# # Pad adds padding to both sides, so for columns start:end in the original,
# # we need columns start:end + 2*padding in padded space
# padded_start = start
# padded_end = end + 2 * padding
#
# chunk = padded[:, :, :, padded_start:padded_end] # (C, 1, H+2p, chunk_width + 2p)
#
# # Unfold to extract sliding windows
# windows = chunk.unfold(2, win_size, 1).unfold(3, win_size, 1) # (C, 1, H, chunk_width, win_size, win_size)
# windows = windows.contiguous().view(C, H, end - start, -1) # (C, H, chunk_width, win_size*win_size)
#
# # Cumulant calculations
# if order == 1:
# result = windows.mean(dim=-1)
# elif order == 2:
# result = windows.var(dim=-1, unbiased=False)
# elif order == 3:
# mean = windows.mean(dim=-1, keepdim=True)
# std = windows.std(dim=-1, unbiased=False, keepdim=True) + 1e-8
# result = (((windows - mean) / std) ** 3).mean(dim=-1)
# elif order == 4:
# mean = windows.mean(dim=-1, keepdim=True)
# std = windows.std(dim=-1, unbiased=False, keepdim=True) + 1e-8
# result = (((windows - mean) / std) ** 4).mean(dim=-1) - 3
# else:
# raise ValueError(f"Unsupported order: {order}")
#
# results.append(result) # (C, H, chunk_width)
# start = end
#
# return torch.cat(results, dim=-1) # (C, H, W)


def detect_and_deduplicate_2d_filters(
traces,
cum_traces=None,
cumulant_order=2,
cumulant_win_size=11,
threshold=2.0,
dedup_channel_index=None, # kept for API; unused here
peak_sign="neg", # "neg" or "both" supported; localization is on troughs
relative_peak_radius=5, # spatial+temporal NMS radius
spatial_dedup_batch_size=512, # kept for API; unused here
exclude_edges=True,
return_energies=False,
detection_mask=None,
trough_priority=None, # kept for API; unused here
# NEW knobs (sane defaults):
future_peak_window=(10, 5), # lookahead (time, space) for the “followed by a peak” criterion
min_contrast_abs=2.0, # require peak - trough >= this (if set)
):
"""
Localize trough minima efficiently while still using a robust energy to gate candidates.
- Gating: cum_traces (std/radiality) >= threshold
- Localization: trough = local 2D minimum of 'traces'
- Validation: a future positive peak exists within 'future_peak_window'
- Dedup: NMS on contrast map within 'relative_peak_radius'
"""
assert traces.dim() == 2, "traces must be (T, C)"
T, C = traces.shape

# Build (1,1,T,C) tensors
tr4 = traces.unsqueeze(0).unsqueeze(0) # (1,1,T,C)

# cumulant computed if needed
if cumulant_order and cum_traces is None:
# expects (C,H,W), so adapt; we just need a fast local std-like map
ct = compute_sliding_2d_cumulant(traces.unsqueeze(0), cumulant_order, cumulant_win_size)
cum_traces = ct.unsqueeze(0) # (1,1,T,C)
elif cum_traces is not None:
cum_traces = cum_traces # assume already (1,1,T,C)
else:
# fall back to a light local absolute deviation proxy using pooling (optional)
pad = cumulant_win_size // 2
mean = F.avg_pool2d(tr4, kernel_size=cumulant_win_size, stride=1, padding=pad)
cum_traces = F.avg_pool2d((tr4 - mean).abs(), kernel_size=cumulant_win_size, stride=1, padding=pad)

# Threshold the cum_traces
thresh_mask = cum_traces >= threshold # (1,1,T,C)
if detection_mask is not None:
dm = detection_mask.to(thresh_mask.dtype).unsqueeze(0).unsqueeze(0)
thresh_mask = thresh_mask & (dm > 0)

# Local trough candidates via 2D min-pooling (i.e., max-pooling on -traces)
k = 2 * relative_peak_radius + 1
neg_tr4 = -tr4
local_neg_max = F.max_pool2d(neg_tr4, kernel_size=k, stride=1, padding=relative_peak_radius)
trough_mask = (neg_tr4 == local_neg_max) # equal to local minimum in original
# Only keep troughs inside the gated regions
trough_mask = trough_mask & thresh_mask

# Positive pixel mask
pos_mask = (tr4 > 0).float() # (1,1,T,C)
neigh_radius = 3

# Neighborhood size
ksize = 2 * neigh_radius + 1
neigh_area = ksize * ksize

# Count positives in each neighborhood (fast via max_pool2d over floats with stride=1)
# For counting, we use avg_pool2d multiplied by area instead of max_pool
pos_frac = F.avg_pool2d(pos_mask, kernel_size=ksize, stride=1, padding=neigh_radius)

# Keep only candidates with <= frac_thresh positive fraction
keep_mask = pos_frac <= 0.3

trough_mask = trough_mask & keep_mask

# 5) “Followed by a peak” + contrast check (fast, causal window)
# future max of positive part within [t, t+W]
pos_tr4 = tr4.clamp_min(0)
fut_max = F.max_pool2d(pos_tr4, kernel_size=future_peak_window, stride=1, padding=(0, 0))
# Align back to (T,C): pad bottom with window to keep same length
fut_max = F.pad(fut_max, (0, future_peak_window[1]-1, 0, future_peak_window[0]-1)) # (1,1,T,C)

# Peak–trough contrast at each (t,c): peak_future - trough_value
trough_val = tr4 # (1,1,T,C), typically negative at troughs
contrast = fut_max - trough_val # higher is better

# Contrast thresholds
contrast_mask = torch.ones_like(trough_mask, dtype=torch.bool)
if min_contrast_abs is not None:
contrast_mask = contrast_mask & (contrast >= min_contrast_abs)

# Keep only troughs that pass contrast checks
trough_mask = trough_mask & contrast_mask

# 6) Non-maximum suppression on contrast to deduplicate troughs
# We only score positions that are troughs; everything else is zeroed.
scored = contrast * trough_mask
local_max = F.max_pool2d(scored, kernel_size=k, stride=1, padding=relative_peak_radius)
keep = (scored == local_max) & (scored > 0)

# 7) Edges
if exclude_edges:
keep[..., 0, :] = False
keep[..., -1, :] = False

# 8) Return indices (+ optional energies = contrast at kept minima)
times, chans = torch.nonzero(keep[0,0], as_tuple=True)

if return_energies:
return times, chans, scored[0,0][times, chans]

# import matplotlib.pyplot as plt
# import matplotlib
# import numpy as np
# matplotlib.use("TkAgg")
# fig, ax = plt.subplots(2, 2, figsize=(12, 12), sharex=True, sharey=True)
# ax[0, 0].set_title("Raw data")
# ax[0, 0].imshow(traces.T, aspect='auto', interpolation='nearest', cmap='seismic', vmin=-5, vmax=5)
# ax[0, 0].set_xlabel('Time (ms)')
# ax[0, 0].set_ylabel('Amplitude (V)')
# ax[0, 1].imshow(cum_traces[0,0].T, aspect='auto', interpolation='nearest', cmap='seismic', vmin=-15, vmax=15)
# ax[0, 1].set_title("Cumulant of raw data")
# ax[0, 1].set_xlabel('Time (ms)')
# ax[0, 1].set_ylabel('Amplitude (V)')
# ax[1, 0].set_title("Localizations over raw data")
# ax[1, 0].imshow(traces.T, aspect='auto', interpolation='nearest', cmap='seismic', vmin=-5, vmax=5)
# ax[1, 0].set_xlabel('Time (ms)')
# ax[1, 0].set_ylabel('Amplitude (V)')
# ax[1, 0].plot(times, chans, 'yo', markersize=6)
# if(detection_mask is not None):
# ax[1, 1].set_title("Localizations over detection mask")
# ax[1, 1].imshow(detection_mask.T, aspect='auto', interpolation='nearest')
# ax[1, 1].set_xlabel('Time (ms)')
# ax[1, 1].set_ylabel('Amplitude (V)')
# ax[1, 1].plot(times, chans, 'yo', markersize=6)
# plt.tight_layout()
# plt.show()

return times, chans
Loading