Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
0f0700b
readme: should be h3
cwindolf May 16, 2026
19a3e32
whiten: save covariance, fitting as a transformer
cwindolf Jun 4, 2026
7d42678
whiten: hook into subtraction
cwindolf Jun 4, 2026
63c6fe9
whiten: debug, handle border nans
cwindolf Jun 4, 2026
f6ced36
Merge branch 'develop' into whiten-initial
cwindolf Jun 4, 2026
4b25b95
clean and add some analysis vis
cwindolf Jun 8, 2026
0dc0f34
cfg
cwindolf Jun 8, 2026
d644dc3
dredge_only cfg broken
cwindolf Jun 8, 2026
9fda04b
fix subtract tests
cwindolf Jun 8, 2026
d8b1b25
cfg
cwindolf Jun 8, 2026
9cc419f
Merge pull request #41 from cwindolf/whiten-initial
cwindolf Jun 8, 2026
aae6b0b
template: batches in grab to cap reduction_template memory
cwindolf Jun 8, 2026
11002ca
eval: full spike label vector
cwindolf Jun 8, 2026
6459d32
left a print
cwindolf Jun 9, 2026
1cfb1c6
eval: debug comparison vis
cwindolf Jun 9, 2026
4b7c09a
clus: raise on numerical issues in SimpleMatrixFeatures
cwindolf Jun 10, 2026
5014755
glom: crash in deduplication for tiny units
cwindolf Jun 10, 2026
97f7675
data: raise for no TPCA
cwindolf Jun 10, 2026
e2a3f64
extend matching debug vis
cwindolf Jun 10, 2026
48cc66d
lint
cwindolf Jun 11, 2026
20124be
glom: numerical case
cwindolf Jun 11, 2026
a265778
kmeans: batched version
cwindolf Jun 11, 2026
8a3ece0
templates: handle case of low numerical rank (probably just a sim thing)
cwindolf Jun 11, 2026
d823719
print
cwindolf Jun 11, 2026
a4957a6
kmeans: reset temperature, gumbel
cwindolf Jun 12, 2026
01b6df8
matching: fix tests, toggle tpca from templates
cwindolf Jun 12, 2026
0b11e8d
kmeans: set beta param
cwindolf Jun 12, 2026
f5e17b8
whitening: impl temporal whitening estimation, config, data plumbing
cwindolf Jun 12, 2026
179968c
whitening: impl needed data transforms
cwindolf Jun 12, 2026
54a529f
subtract: save models earlier so debugging can resume
cwindolf Jun 14, 2026
0185bf2
matching: implement temporal whitening
cwindolf Jun 14, 2026
f1191ce
clus: single split, optional kmeans after dpc
cwindolf Jun 17, 2026
b784f22
glom: fixes from Keshav
cwindolf Jun 17, 2026
8f49f3b
cfg: parametrize count/chan
cwindolf Jun 18, 2026
b9cb7dd
decollider: investigate training in svd land
cwindolf Jun 18, 2026
2e39bfd
clean / fix enfdec accidentally on bug
cwindolf Jun 18, 2026
f9be675
more params needed for demolish vis
cwindolf Jun 20, 2026
f7bae72
gmm: choice to refit within demolish
cwindolf Jun 20, 2026
51dba98
data/glom: add export to SortingAnalyzer, expose some SpikeInterface …
cwindolf Jun 20, 2026
1f32992
Merge branch 'develop' of github.com:cwindolf/dartsort into develop
cwindolf Jun 20, 2026
6aa3c69
vis: impl acg/ccg in ms
cwindolf Jun 20, 2026
7f16e88
glom: gate SpikeInterface merge with coentropy
cwindolf Jun 20, 2026
c6f685a
glom: coentropy mask should be or
cwindolf Jun 21, 2026
608ed5e
vis: debug neighbor ccg
cwindolf Jun 21, 2026
b1bad58
glom: restrict coentropy to rival pairs
cwindolf Jun 22, 2026
198e3e5
cfg: reworking subsampling
cwindolf Jun 22, 2026
2e602a3
vis: better single unit vis
cwindolf Jun 22, 2026
ba4737e
data_util: work on SortingAnalyzer export
cwindolf Jun 22, 2026
5348d56
gmm: can't throw everything away...
cwindolf Jun 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ $ pip install dartsort

If you want to run the test suite or use `dartsort.vis`, you can install the optional dependencies with `pip install dartsort[test,vis]`.

## Setting up a Python environment
### Setting up a Python environment

If you need to set up Python or PyTorch, I find that a [`conda-forge`](https://conda-forge.org/)-based distribution is the most reliable at installing the GPU dependencies which PyTorch needs (note: `conda-forge` is different from the non-free Anaconda).

Expand Down
298 changes: 287 additions & 11 deletions src/dartsort/clustering/agglomerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Agglomeration:
def agglomerate(
*,
sorting: DARTsortSorting,
recording: BaseRecording | None,
recording: BaseRecording,
template_merge_cfg: TemplateMergeConfig | None,
refinement_cfg: RefinementConfig | None,
motion: MotionInfo,
Expand Down Expand Up @@ -104,7 +104,7 @@ def agglomerate(
)

# tdist tells us the possible merges
mask = linkage_mask(
distance_mask = linkage_mask(
tdist.distances,
linkage_method=template_merge_cfg.linkage,
threshold=template_merge_cfg.merge_distance_threshold,
Expand All @@ -117,16 +117,17 @@ def agglomerate(
dt=refinement_cfg.glom_firing_corr_dt,
method=refinement_cfg.glom_firing_corr_method,
)
_oldsum = mask[np.triu_indices_from(mask)].sum()
_oldsum = distance_mask[np.triu_indices_from(distance_mask)].sum()
fcorr_mask = fcorr <= refinement_cfg.glom_max_firing_corr
mask = np.logical_and(mask, fcorr_mask)
mask = np.logical_and(distance_mask, fcorr_mask)
np.fill_diagonal(mask, True)
_newsum = mask[np.triu_indices_from(mask)].sum()
logger.dartsortdebug(
f"Firing corr dropped QDA candidate count from {_oldsum} -> {_newsum}."
)
else:
fcorr = fcorr_mask = None
mask = distance_mask

# restrict mask by overlap criteria
qda_res = qda(
Expand Down Expand Up @@ -158,19 +159,47 @@ def agglomerate(
assert np.all(qda_mask <= mask)
if fcorr_mask is not None:
assert np.all(np.logical_and(qda_mask, fcorr_mask) <= mask)

if refinement_cfg.spikeinterface_merge_preset is not None:
pair_mask = tdist.distances < refinement_cfg.spikeinterface_merge_max_distance
if refinement_cfg.spikeinterface_merge_min_coentropy is not None:
cmask, _ = coentropy_merge_mask(
sorting=sorting,
min_coentropy=refinement_cfg.spikeinterface_merge_min_coentropy,
coverage_threshold=refinement_cfg.spikeinterface_merge_coent_coverage,
iou_threshold=refinement_cfg.spikeinterface_merge_coent_iou,
)
pair_mask = np.logical_or(cmask, pair_mask)

si_mask = spikeinterface_merge_mask(
recording=recording,
sorting=sorting,
preset=refinement_cfg.spikeinterface_merge_preset,
censor_ms=refinement_cfg.censor_ms,
template_data=tdist.template_data,
pair_mask=pair_mask,
)
else:
si_mask = None

# force merges for very close neighbors
force_mask = linkage_mask(
tdist.distances,
linkage_method=template_merge_cfg.linkage,
threshold=refinement_cfg.qda_force_merge_for_temp_dist_below,
)
qda_mask = np.logical_or(qda_mask, force_mask)
np.fill_diagonal(qda_mask, True)
qda_as_dist = np.logical_not(qda_mask).astype(np.float32)

# extract final mask
final_mask = np.logical_or(qda_mask, force_mask)
if si_mask is not None:
final_mask = np.logical_or(final_mask, si_mask)
np.fill_diagonal(final_mask, True)
final_mask_as_distance = np.logical_not(final_mask).astype(np.float32)

agg_sorting, new_ids = recluster(
sorting=sorting,
unit_ids=tdist.template_data.unit_ids,
dists=qda_as_dist,
dists=final_mask_as_distance,
shifts=tdist.shifts,
unit_snrs=tdist.template_data.snrs_by_channel().max(1),
threshold=0.5,
Expand Down Expand Up @@ -352,6 +381,100 @@ def _get_scores(sorting: DARTsortSorting) -> tuple[np.ndarray, Scores]:
return labels, scores


def spikeinterface_merge_mask(
    *,
    recording: BaseRecording,
    sorting: DARTsortSorting,
    preset: str | None,
    censor_ms: float = 0.0,
    template_data: TemplateData,
    pair_mask: np.ndarray,
    min_count: int = 100,
) -> np.ndarray:
    """Turn SpikeInterface auto-merge groups into a unit-pair merge mask.

    The sorting is optionally censored with ``deduplicate_spikes``, converted
    to a sorting analyzer, and handed to SpikeInterface's
    ``compute_merge_unit_groups``. ``pair_mask`` is installed as the analyzer's
    "template_similarity" extension data (cast to float32), so the
    "template_similarity" step reads this mask instead of a similarity
    computed by SpikeInterface.

    Parameters
    ----------
    recording : BaseRecording
        Passed to ``sorting.to_sorting_analyzer``.
    sorting : DARTsortSorting
        Sorting whose units are considered for merging.
    preset : str
        Either one of the custom step lists defined here
        ("dartsort_slay_xc", "dartsort_slay_ccg", "dartsort_slay_xc_ccg") or
        any other string, which is forwarded unchanged as SpikeInterface's
        ``preset``. ``None`` is rejected.
    censor_ms : float
        If nonzero, spikes are deduplicated with this radius first, and the
        same value is forwarded to the censoring-related step parameters.
    template_data : TemplateData
        Passed to ``sorting.to_sorting_analyzer``.
    pair_mask : np.ndarray
        Square unit-by-unit array of candidate pairs.
    min_count : int
        Forwarded as ``min_spikes`` of the "num_spikes" step.

    Returns
    -------
    np.ndarray
        Boolean array shaped like ``pair_mask``. Entries are True for every
        pair of units (including a unit with itself) that share a merge group.
        Units which are in no group keep a False diagonal entry.

    Raises
    ------
    ValueError
        If ``preset`` is None.
    """
    # Validate up front with a real exception: an `assert` is stripped under
    # `python -O`, and failing here avoids building the analyzer for nothing.
    if preset is None:
        raise ValueError(
            "spikeinterface_merge_mask needs a preset name, but got preset=None."
        )

    from spikeinterface.curation.auto_merge import compute_merge_unit_groups
    from spikeinterface.postprocessing import ComputeTemplateSimilarity

    # censor first
    if censor_ms:
        sorting = deduplicate_spikes(sorting, censor_ms)

    # analyzer
    analyzer = sorting.to_sorting_analyzer(
        recording=recording, template_data=template_data
    )

    # register the mask as the template similarity extension
    tsim_ext = ComputeTemplateSimilarity(analyzer)
    tsim_ext.data = {"similarity": pair_mask.astype(np.float32)}
    tsim_ext.params = {"method": "dartsort"}
    tsim_ext.run_info = {"run_completed": True}
    analyzer.extensions["template_similarity"] = tsim_ext

    # custom presets: explicit step lists which are passed as `steps=` in
    # place of a named SpikeInterface preset
    custom_preset_steps = {
        "dartsort_slay_xc": [
            "num_spikes",
            "remove_contaminated",
            "unit_locations",
            "template_similarity",
            "slay_score",
            "cross_contamination",
            "quality_score",
        ],
        "dartsort_slay_ccg": [
            "num_spikes",
            "remove_contaminated",
            "unit_locations",
            "template_similarity",
            "correlogram",
            "slay_score",
            "quality_score",
        ],
        "dartsort_slay_xc_ccg": [
            "num_spikes",
            "remove_contaminated",
            "unit_locations",
            "template_similarity",
            "correlogram",
            "cross_contamination",
            "slay_score",
            "quality_score",
        ],
    }
    steps = custom_preset_steps.get(preset)
    if steps is not None:
        # explicit steps replace the preset; all custom step lists get the
        # correlograms extension computed ahead of time
        preset = None
        analyzer.compute_one_extension("correlograms")

    # make parameters aware of censorship and other params
    my_step_params = {
        "num_spikes": {"min_spikes": min_count},
        "remove_contaminated": {"censored_period_ms": censor_ms},
        "template_similarity": {"similarity_method": "dartsort"},
        "correlogram": {"censor_correlograms_ms": censor_ms},
        "cross_contamination": {"censored_period_ms": censor_ms},
        "quality_score": {"censored_period_ms": censor_ms},
    }
    groups = compute_merge_unit_groups(
        preset=preset,
        steps=steps,
        sorting_analyzer=analyzer,
        steps_params=my_step_params,
        force_copy=False,
    )

    # Always boolean, whatever dtype the caller's pair_mask has.
    mask = np.zeros(pair_mask.shape, dtype=bool)
    for group in groups:
        # NOTE(review): group entries are used directly as row/column indices,
        # which assumes unit ids coincide with positions in pair_mask -- confirm.
        members = np.array(group)
        mask[members[:, None], members[None, :]] = True
    return mask


@databag
class QDAResult:
"""Unit pair QDA metrics
Expand Down Expand Up @@ -665,7 +788,8 @@ def combine_gmm_scores(

# check invariants at the top
if responsibilities.shape[1] > 2:
assert np.all(np.diff(responsibilities[:, :-1], axis=1) <= 0)
_maxdiff = np.diff(responsibilities[:, :-1], axis=1).max()
assert _maxdiff <= 1e-3, _maxdiff
assert np.greater_equal(np.isneginf(logliks[:, :-1]), candidates == -1).all()
if sorting.labels is not None:
assert np.all(
Expand Down Expand Up @@ -698,7 +822,8 @@ def combine_gmm_scores(

# check invariants at the bottom
if mergedr.shape[1] > 2:
assert np.all(np.diff(mergedr[:, : cand.shape[1]], axis=1) <= 0)
_maxdiff = np.diff(mergedr[:, : cand.shape[1]], axis=1).max()
assert _maxdiff <= 1e-3, _maxdiff
assert np.greater_equal(np.isneginf(mergedl[:, : cand.shape[1]]), cand == -1).all()
assert (cand < 0).sum() >= nbye
if sorting.labels is not None:
Expand Down Expand Up @@ -733,7 +858,7 @@ def _combine_loop(
continue

eq_ncandj = rcand[j + 1 :] == ncandj
if eq_ncandj.sum() <= 1:
if eq_ncandj.sum() < 1:
continue

rsum = mergedr[s, j]
Expand Down Expand Up @@ -799,6 +924,8 @@ def deduplicate_spikes(
ndrop = 0
for unit_id in unit_ids:
in_unit = np.flatnonzero(new_labels == unit_id)
if in_unit.size <= 1:
continue
t = times_samples[in_unit]
dt = np.diff(t)
if dt.min() > radius_samples:
Expand Down Expand Up @@ -869,3 +996,152 @@ def _dedup_unit_loop(
break

i0 = i1


@databag
class CoentropyResult:
    """Pairwise unit statistics returned by `coentropy` (K is the number of units)."""

    coentropy: np.ndarray
    """KxK; reduction of entropy per cooccurrence due to merging pair"""

    cooccurrence: np.ndarray
    """KxK; number of times these units score the same spike"""

    rival_count: np.ndarray
    """KxK; number of times one unit scores a spike where the other is top"""

    occurrence: np.ndarray
    """K; number of times the unit appears in the candidates at all"""

    cov: np.ndarray
    """KxK; rival count / max pair count (rival diag)"""

    iou: np.ndarray
    """KxK; rival count over pair sum"""


def coentropy_merge_mask(
    sorting: DARTsortSorting,
    min_coentropy: float,
    coverage_threshold: float,
    iou_threshold: float,
    gmm_prefix=("merged", "gmm"),
) -> tuple[np.ndarray, CoentropyResult]:
    """Mask the unit pairs whose coentropy statistics allow a merge.

    A pair is allowed when its coentropy is large enough AND it passes at
    least one of the two rival-overlap tests (coverage or IoU).

    Parameters
    ----------
    sorting : DARTsortSorting
        Forwarded to `coentropy`.
    min_coentropy : float
        Pairs need ``coentropy >= min_coentropy`` for mask=True.
    coverage_threshold : float
        Pairs with ``cov >= coverage_threshold`` pass the rival test. `cov`
        is the rival count divided by the larger of the two units' counts,
        so the rival_count/count ratio of both units must reach the threshold.
    iou_threshold : float
        Pairs with rival ``iou >= iou_threshold`` pass the rival test.
    gmm_prefix : tuple of str
        Forwarded to `coentropy`.

    Returns
    -------
    mask : np.ndarray
        Boolean unit-by-unit array; the diagonal is always True.
    c : CoentropyResult
        The statistics from which the mask was derived.

    Raises
    ------
    ValueError
        If `coentropy` returns None for this sorting.
    """
    c = coentropy(sorting, gmm_prefix=gmm_prefix)
    if c is None:
        # explicit raise rather than assert, so the check survives `python -O`
        raise ValueError(
            f"Could not compute coentropy for this sorting with {gmm_prefix=}."
        )

    rival_ok = np.logical_or(c.cov >= coverage_threshold, c.iou >= iou_threshold)
    mask = np.logical_and(c.coentropy >= min_coentropy, rival_ok)
    np.fill_diagonal(mask, True)
    return mask, c


def coentropy(
    sorting: DARTsortSorting,
    gmm_prefix=("merged", "gmm"),
) -> CoentropyResult | None:
    """Compute pairwise entropy-reduction and rival statistics for all units.

    The first prefix in `gmm_prefix` for which the sorting has a
    ``{prefix}_candidates`` attribute is used, together with the matching
    ``{prefix}_responsibilities``. Returns None if no prefix has candidates.
    """
    cands = resps = None
    for prefix in gmm_prefix:
        cands = getattr(sorting, f"{prefix}_candidates", None)
        resps = getattr(sorting, f"{prefix}_responsibilities", None)
        if cands is not None:
            assert resps is not None
            break
    if cands is None:
        return None

    n_units = sorting.n_units
    pair_shape = (n_units, n_units)

    # only the responsibility columns that have a candidate column
    resps = resps[:, : cands.shape[1]].astype(np.float64)

    # accumulators, filled in place by the numba kernel
    coent = np.zeros(pair_shape)
    cooc = np.zeros(pair_shape, dtype=np.int64)
    rivals = np.zeros(pair_shape, dtype=np.int64)
    occ = np.zeros((n_units,), dtype=np.int64)
    _calc_coentropy(coent, cooc, rivals, occ, cands, resps)

    # symmetrize the pairwise accumulators
    coent += coent.T
    cooc += cooc.T
    rivals += rivals.T

    # adding the transpose doubled the rival diagonal; halve it again
    doubled_diag = np.diagonal(rivals)
    assert (doubled_diag % 2 == 0).all()
    np.fill_diagonal(rivals, doubled_diag // 2)

    # rival count diagonal is just unit top count (not exactly label count,
    # since it doesn't account for noise assignments); clip to avoid 0-division
    top_counts = np.maximum(np.diagonal(rivals), 1)

    # taking the min with the transpose divides by the larger count of the pair
    cov = rivals / top_counts
    cov = np.minimum(cov, cov.T)

    # this is a disjoint union, since it's the top-label count
    iou = rivals / (top_counts[:, None] + top_counts[None, :])

    return CoentropyResult(
        coentropy=coent,
        cooccurrence=cooc,
        rival_count=rivals,
        occurrence=occ,
        cov=cov,
        iou=iou,
    )


@numba.njit
def _calc_coentropy(
    coentropy: np.ndarray,
    cooccurrence: np.ndarray,
    rival_count: np.ndarray,
    occurrence: np.ndarray,
    cands: np.ndarray,
    resps: np.ndarray,
):
    """Accumulate coentropy statistics over spikes, in place.

    For each spike (row of `cands`/`resps`), the first candidate is treated as
    the top unit and is paired with every following valid candidate.

    Parameters
    ----------
    coentropy : np.ndarray
        Output. Entry [min(a, b), max(a, b)] holds the running mean, over the
        spikes where top unit a and candidate b co-occur, of the entropy
        reduction obtained by merging the pair's two responsibilities.
    cooccurrence : np.ndarray
        Output. Entry [min(a, b), max(a, b)] counts those co-occurrences.
    rival_count : np.ndarray
        Output. Entry [a, b] counts spikes with top unit a and candidate b;
        entry [a, a] counts spikes whose top unit is a.
    occurrence : np.ndarray
        Output. Number of times each unit appears among a spike's candidates.
    cands : np.ndarray
        Candidate unit indices per spike; a negative entry ends the row.
    resps : np.ndarray
        Responsibilities aligned with `cands`.

    Notes
    -----
    This loop is deliberately serial. Different spikes update the same output
    entries (same unit or same pair), and numba's ``prange`` does not make
    such element-wise read-modify-write updates atomic, so a parallel version
    can lose counts and corrupt the running means.
    """
    for i in range(cands.shape[0]):
        u = cands[i]
        ui0 = u[0]
        if ui0 < 0:
            # no valid top candidate: skip, rather than letting the negative
            # index wrap around and be counted towards the last unit
            continue

        q = resps[i]
        log_q = np.log(q)
        # log(0) = -inf; zero it so that q * log(q) is 0 rather than nan
        np.nan_to_num(log_q, copy=False, neginf=0.0)
        dh = q * log_q

        qi0 = q[0]
        dhi0 = dh[0]

        occurrence[ui0] += 1
        rival_count[ui0, ui0] += 1

        for j in range(1, cands.shape[1]):
            uj = u[j]
            if uj < 0:
                break

            # pairwise accumulators only use the upper triangle here
            ii = min(ui0, uj)
            jj = max(ui0, uj)

            occurrence[uj] += 1
            rival_count[ui0, uj] += 1

            cij = cooccurrence[ii, jj] + 1
            cooccurrence[ii, jj] = cij

            # change in entropy due to merging ui0, uj:
            # subtract their current contribution, add the new contribution
            # we want reduction of entropy, so this is the negative of that!
            qij = q[j] + qi0
            dhij = dh[j] + dhi0
            if qij > 0:
                dhij -= qij * np.log(qij)

            # Welford mean of -dh
            cur_coent = coentropy[ii, jj]
            coentropy[ii, jj] = cur_coent + (-dhij - cur_coent) / cij
Loading
Loading