diff --git a/README.md b/README.md index 4d6a5902..6b457156 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ $ pip install dartsort If you want to run the test suite or use `dartsort.vis`, you can install the optional dependencies with `pip install dartsort[test,vis]`. -## Setting up a Python environment +### Setting up a Python environment If you need to set up Python or PyTorch, I find that a [`conda-forge`](https://conda-forge.org/)-based distribution is the most reliable at installing the GPU dependencies which PyTorch needs (note: `conda-forge` is different from the non-free Anaconda). diff --git a/src/dartsort/clustering/agglomerate.py b/src/dartsort/clustering/agglomerate.py index 50dd07c5..f6ad4571 100644 --- a/src/dartsort/clustering/agglomerate.py +++ b/src/dartsort/clustering/agglomerate.py @@ -56,7 +56,7 @@ class Agglomeration: def agglomerate( *, sorting: DARTsortSorting, - recording: BaseRecording | None, + recording: BaseRecording, template_merge_cfg: TemplateMergeConfig | None, refinement_cfg: RefinementConfig | None, motion: MotionInfo, @@ -104,7 +104,7 @@ def agglomerate( ) # tdist tells us the possible merges - mask = linkage_mask( + distance_mask = linkage_mask( tdist.distances, linkage_method=template_merge_cfg.linkage, threshold=template_merge_cfg.merge_distance_threshold, @@ -117,9 +117,9 @@ def agglomerate( dt=refinement_cfg.glom_firing_corr_dt, method=refinement_cfg.glom_firing_corr_method, ) - _oldsum = mask[np.triu_indices_from(mask)].sum() + _oldsum = distance_mask[np.triu_indices_from(distance_mask)].sum() fcorr_mask = fcorr <= refinement_cfg.glom_max_firing_corr - mask = np.logical_and(mask, fcorr_mask) + mask = np.logical_and(distance_mask, fcorr_mask) np.fill_diagonal(mask, True) _newsum = mask[np.triu_indices_from(mask)].sum() logger.dartsortdebug( @@ -127,6 +127,7 @@ def agglomerate( ) else: fcorr = fcorr_mask = None + mask = distance_mask # restrict mask by overlap criteria qda_res = qda( @@ -158,19 +159,47 @@ def 
# Step sequences for dartsort's own auto-merge presets. These names are not
# passed to SpikeInterface as presets; they are expanded to an explicit
# ``steps`` list (and ``preset=None``) before calling
# ``compute_merge_unit_groups``.
_DARTSORT_SI_PRESET_STEPS = {
    "dartsort_slay_xc": (
        "num_spikes",
        "remove_contaminated",
        "unit_locations",
        "template_similarity",
        "slay_score",
        "cross_contamination",
        "quality_score",
    ),
    "dartsort_slay_ccg": (
        "num_spikes",
        "remove_contaminated",
        "unit_locations",
        "template_similarity",
        "correlogram",
        "slay_score",
        "quality_score",
    ),
    "dartsort_slay_xc_ccg": (
        "num_spikes",
        "remove_contaminated",
        "unit_locations",
        "template_similarity",
        "correlogram",
        "cross_contamination",
        "slay_score",
        "quality_score",
    ),
}


def spikeinterface_merge_mask(
    *,
    recording: BaseRecording,
    sorting: DARTsortSorting,
    preset: str | None,
    censor_ms: float = 0.0,
    template_data: TemplateData,
    pair_mask: np.ndarray,
    min_count: int = 100,
) -> np.ndarray:
    """Build a unit-pair merge mask from SpikeInterface's auto-merge groups.

    The sorting is optionally deduplicated, wrapped in a sorting analyzer,
    and ``pair_mask`` is registered as an already-computed
    ``template_similarity`` extension so that the auto-merge run uses it
    instead of computing its own similarity.

    Parameters
    ----------
    recording : BaseRecording
        Passed to ``sorting.to_sorting_analyzer``.
    sorting : DARTsortSorting
    preset : str
        Either one of the dartsort presets in ``_DARTSORT_SI_PRESET_STEPS``
        (expanded to explicit steps) or any other name, which is forwarded
        unchanged as the SpikeInterface ``preset``. ``None`` is rejected.
    censor_ms : float
        If truthy, spikes are deduplicated with ``deduplicate_spikes`` first
        and the same value is forwarded to the censoring-related step
        parameters.
    template_data : TemplateData
        Passed to ``sorting.to_sorting_analyzer``.
    pair_mask : np.ndarray
        Square unit-by-unit array; its float32 cast is used as the
        similarity matrix.
    min_count : int
        ``min_spikes`` for the ``num_spikes`` step.

    Returns
    -------
    np.ndarray
        Boolean array with the shape of ``pair_mask``. An entry is True when
        both units belong to the same returned merge group (self-pairs of
        grouped units included); everything else, including the diagonal of
        ungrouped units, is False.

    Raises
    ------
    ValueError
        If ``preset`` is None.
    """
    from spikeinterface.curation.auto_merge import compute_merge_unit_groups
    from spikeinterface.postprocessing import ComputeTemplateSimilarity

    # Fail before the expensive analyzer setup, and with a real exception:
    # an assert would disappear under ``python -O``.
    if preset is None:
        raise ValueError("spikeinterface_merge_mask requires a preset name.")

    # censor first
    if censor_ms:
        sorting = deduplicate_spikes(sorting, censor_ms)

    # analyzer
    analyzer = sorting.to_sorting_analyzer(
        recording=recording, template_data=template_data
    )

    # register the mask as the template similarity extension
    tsim_ext = ComputeTemplateSimilarity(analyzer)
    tsim_ext.data = {"similarity": pair_mask.astype(np.float32)}
    tsim_ext.params = {"method": "dartsort"}
    tsim_ext.run_info = {"run_completed": True}
    analyzer.extensions["template_similarity"] = tsim_ext

    # handle custom presets: expand to explicit steps, and compute the
    # correlograms extension up front as every custom preset did
    custom_steps = _DARTSORT_SI_PRESET_STEPS.get(preset)
    if custom_steps is None:
        si_preset = preset
        steps = None
    else:
        si_preset = None
        steps = list(custom_steps)
        analyzer.compute_one_extension("correlograms")

    # make parameters aware of censorship and other params
    my_step_params = {
        "num_spikes": {"min_spikes": min_count},
        "remove_contaminated": {"censored_period_ms": censor_ms},
        "template_similarity": {"similarity_method": "dartsort"},
        "correlogram": {"censor_correlograms_ms": censor_ms},
        "cross_contamination": {"censored_period_ms": censor_ms},
        "quality_score": {"censored_period_ms": censor_ms},
    }
    groups = compute_merge_unit_groups(
        preset=si_preset,
        steps=steps,
        sorting_analyzer=analyzer,
        steps_params=my_step_params,
        force_copy=False,
    )

    # Explicit bool output; for a boolean pair_mask this matches zeros_like.
    mask = np.zeros(pair_mask.shape, dtype=bool)
    for group in groups:
        # assumes group entries are integer unit ids that are valid row and
        # column indices of pair_mask -- TODO confirm against the analyzer
        ids = np.asarray(group)
        mask[ids[:, None], ids[None, :]] = True
    return mask
@databag
class CoentropyResult:
    coentropy: np.ndarray
    """KxK; reduction of entropy per cooccurrence due to merging pair"""

    cooccurrence: np.ndarray
    """KxK; number of times these units score the same spike"""

    rival_count: np.ndarray
    """KxK; number of times one unit scores a spike where the other is top"""

    occurrence: np.ndarray
    """K; number of times the unit appears in the candidates at all"""

    cov: np.ndarray
    """KxK; rival count / max pair count (rival diag)"""

    iou: np.ndarray
    """KxK; rival count over pair sum"""


def coentropy_merge_mask(
    sorting: DARTsortSorting,
    min_coentropy: float,
    coverage_threshold: float,
    iou_threshold: float,
    gmm_prefix=("merged", "gmm"),
) -> tuple[np.ndarray, CoentropyResult]:
    """Mask of unit pairs whose merge is supported by coentropy statistics.

    A pair is allowed when its coentropy is at least ``min_coentropy`` AND
    it passes at least one of the coverage or IoU tests. The diagonal is
    always True.

    Parameters
    ----------
    sorting : DARTsortSorting
    min_coentropy : float
        Pairs need ``coentropy >= min_coentropy``.
    coverage_threshold : float
        Pairs with ``cov >= coverage_threshold`` pass the coverage test.
        ``cov`` is the symmetrized rival count divided by the larger of the
        two units' top counts, so the ratio must hold for both units.
    iou_threshold : float
        Pairs with ``iou >= iou_threshold`` pass the IoU test.
    gmm_prefix : sequence of str
        Attribute prefixes tried in order by ``coentropy``.

    Returns
    -------
    mask : np.ndarray
        KxK boolean array.
    result : CoentropyResult
        The statistics the mask was derived from.

    Raises
    ------
    ValueError
        If ``sorting`` has no candidates under any prefix in ``gmm_prefix``.
    """
    c = coentropy(sorting, gmm_prefix=gmm_prefix)
    if c is None:
        # explicit error rather than an assert, which vanishes under -O
        raise ValueError(
            f"Sorting has no candidates for any of the prefixes {gmm_prefix}."
        )

    overlap_ok = np.logical_or(c.cov >= coverage_threshold, c.iou >= iou_threshold)
    mask = np.logical_and(c.coentropy >= min_coentropy, overlap_ok)
    np.fill_diagonal(mask, True)
    return mask, c


def coentropy(
    sorting: DARTsortSorting,
    gmm_prefix=("merged", "gmm"),
) -> CoentropyResult | None:
    """Calculate entropy reduction due to merging pairs.

    Looks up ``{prefix}_candidates`` and ``{prefix}_responsibilities`` on
    ``sorting`` for each prefix in order and uses the first prefix whose
    candidates exist. Returns None when no prefix has candidates.
    """
    cands = resps = None
    for prefix in gmm_prefix:
        cands = getattr(sorting, f"{prefix}_candidates", None)
        resps = getattr(sorting, f"{prefix}_responsibilities", None)
        if cands is not None:
            assert resps is not None
            break
    else:
        return None

    n_units = sorting.n_units
    # keep only the responsibility columns that have a matching candidate
    resps = resps[:, : cands.shape[1]].astype(np.float64)
    coent = np.zeros((n_units, n_units))
    cooccurrence = np.zeros((n_units, n_units), dtype=np.int64)
    rival_count = np.zeros((n_units, n_units), dtype=np.int64)
    occurrence = np.zeros((n_units,), dtype=np.int64)
    _calc_coentropy(coent, cooccurrence, rival_count, occurrence, cands, resps)

    # The kernel fills rival_count as [top unit, rival unit] and the other
    # two matrices in the upper triangle only; symmetrize all three.
    rival_count = rival_count + rival_count.T
    doubled_diag = np.diagonal(rival_count)
    assert (doubled_diag % 2 == 0).all()
    # symmetrizing doubled the diagonal; undo that
    np.fill_diagonal(rival_count, doubled_diag // 2)
    coent = coent + coent.T
    cooccurrence = cooccurrence + cooccurrence.T

    # rival count diagonal is just unit top count (not exactly label count,
    # since it doesn't account for noise assignments)
    counts = np.maximum(np.diagonal(rival_count), 1)

    # divide by each unit's count, then keep the smaller ratio of the two
    # directions, i.e. rival count over the larger count of the pair
    cov = rival_count / counts
    cov = np.minimum(cov, cov.T)

    # this is a disjoint union, since it's the top-label count
    union = counts[:, None] + counts[None, :]
    iou = rival_count / union

    return CoentropyResult(
        coentropy=coent,
        cooccurrence=cooccurrence,
        rival_count=rival_count,
        occurrence=occurrence,
        cov=cov,
        iou=iou,
    )


@numba.njit
def _calc_coentropy(
    coentropy: np.ndarray,
    cooccurrence: np.ndarray,
    rival_count: np.ndarray,
    occurrence: np.ndarray,
    cands: np.ndarray,
    resps: np.ndarray,
):
    """Accumulate pairwise candidate statistics over spikes, in place.

    For each spike, the first candidate is treated as the top unit and every
    later non-negative candidate as a rival. All four output arrays are
    updated in place; ``coentropy`` and ``cooccurrence`` are written in the
    upper triangle (min index, max index) only.

    This loop is deliberately serial. Different spikes update the same
    array elements (``occurrence[u] += 1``, the running mean below), and
    such indexed read-modify-write updates are not safe inside
    ``numba.prange``: concurrent threads can lose increments.
    """
    for i in range(cands.shape[0]):
        u = cands[i]
        q = resps[i]
        # q * log(q), with log(0) replaced by 0 so that zero-probability
        # entries contribute nothing
        log_q = np.log(q)
        np.nan_to_num(log_q, copy=False, neginf=0.0)
        dh = q * log_q

        ui0 = u[0]
        if ui0 < 0:
            # NOTE(review): a negative top candidate would otherwise wrap
            # around and be counted for the last unit. Presumably this never
            # happens for valid input -- confirm.
            continue
        qi0 = q[0]
        dhi0 = dh[0]

        occurrence[ui0] += 1
        rival_count[ui0, ui0] += 1

        for j in range(1, cands.shape[1]):
            uj = u[j]
            if uj < 0:
                # negative ids mark the end of this spike's candidates
                break

            ii = min(ui0, uj)
            jj = max(ui0, uj)

            occurrence[uj] += 1
            rival_count[ui0, uj] += 1

            cij = cooccurrence[ii, jj] + 1
            cooccurrence[ii, jj] = cij

            # change in entropy due to merging ui0, uj:
            # subtract their current contribution, add the new contribution
            # we want reduction of entropy, so this is the negative of that!
            qij = q[j] + qi0
            dhij = dh[j] + dhi0
            if qij > 0:
                dhij -= qij * np.log(qij)

            # Welford mean of -dh
            cur_coent = coentropy[ii, jj]
            coentropy[ii, jj] = cur_coent + (-dhij - cur_coent) / cij
import agglomerate, cluster_util, density, forward_backward, mixture, refine_util +from ..util.torch_util import cleanup_and_log_gpu_usage +from . import ( + agglomerate, + cluster_util, + density, + forward_backward, + kmeans, + mixture, + refine_util, +) from .clustering_features import SimpleMatrixFeatures, StableWaveformFeatures if TYPE_CHECKING: @@ -292,6 +301,10 @@ def __init__( workers=-1, uhdversion=False, random_seed=0, + kmeans_cleanup=True, + kmeans_max_sigma: float = 5.0, + kmeans_iter=100, + device=None, **kwargs, ): super().__init__(**kwargs) @@ -308,6 +321,10 @@ def __init__( self.workers = workers self.uhdversion = uhdversion self.random_seed = random_seed + self.kmeans_cleanup = kmeans_cleanup + self.kmeans_max_sigma = kmeans_max_sigma + self.kmeans_iter = kmeans_iter + self.device = device @classmethod def from_config( @@ -335,6 +352,7 @@ def from_config( outlier_radius=clustering_cfg.outlier_radius, outlier_neighbor_count=clustering_cfg.outlier_neighbor_count, workers=workers, + device=computation_cfg.actual_device(), uhdversion=uhdversion, computation_cfg=computation_cfg, waveform_cfg=waveform_cfg, @@ -342,6 +360,9 @@ def from_config( save_labels_dir=save_labels_dir, labels_fmt=labels_fmt, sampling_cfg=clustering_cfg.sampling_cfg, + kmeans_cleanup=clustering_cfg.dpc_kmeans_cleanup, + kmeans_iter=clustering_cfg.kmeans_iter, + kmeans_max_sigma=clustering_cfg.gmmdpc_max_sigma, ) def _cluster( @@ -425,6 +446,19 @@ def _cluster_extra( kdtree = None labels = res["labels"] + if self.kmeans_cleanup: + kres = kmeans.truncated_kmeans_from_labels( + X=X, + labels=labels, + device=self.device, + max_sigma=self.kmeans_max_sigma, + n_iter=self.kmeans_iter, + ) + assert kres.labels is not None + labels = kres.labels.numpy(force=True) + del kres + cleanup_and_log_gpu_usage(self.computation_cfg, "DPC->kmeans") + labels = cluster_util.decrumb( labels, min_size=self.remove_clusters_smaller_than, in_place=True ) @@ -821,6 +855,7 @@ def _refine( recording: 
BaseRecording | None, motion: MotionInfo, ): + assert recording is not None return agglomerate.agglomerate( recording=recording, sorting=sorting, diff --git a/src/dartsort/clustering/clustering_features.py b/src/dartsort/clustering/clustering_features.py index a53789a1..5b633ec3 100644 --- a/src/dartsort/clustering/clustering_features.py +++ b/src/dartsort/clustering/clustering_features.py @@ -25,15 +25,6 @@ from ..util.waveform_util import single_channel_index from . import cluster_util -default_clustering_features_cfg = ClusteringFeaturesConfig() -minimal_features_cfg = ClusteringFeaturesConfig( - n_main_channel_pcs=0, - use_amplitude=False, - use_signed_amplitude=False, - use_x=False, - use_z=False, -) - logger = get_logger(__name__) @@ -77,8 +68,14 @@ def from_config( ) if xyza is not None: x = xyza[:, 0] + if not _allfinite(x): + raise ValueError(_numbers_error_str("x", x)) z = xyza[:, 2] + if not _allfinite(z): + raise ValueError(_numbers_error_str("z", z)) z_reg = motion.correct_s(t_s, z) + if not _allfinite(z_reg): + raise ValueError(_numbers_error_str("z_reg", z_reg)) else: x = z = z_reg = None @@ -99,10 +96,14 @@ def from_config( amp = getattr(sorting, clustering_features_cfg.amplitudes_dataset_name) if clustering_features_cfg.use_amplitude: assert amp is not None + if not _allfinite(amp): + raise ValueError(_numbers_error_str("amp", amp)) ampft = amp.copy() if clustering_features_cfg.log_transform_amplitude: ampft = np.log(clustering_features_cfg.amp_log_c + ampft) ampft *= clustering_features_cfg.amp_scale + if not _allfinite(ampft): + raise ValueError(_numbers_error_str("ampft", ampft)) features.append(ampft[:, None]) v = getattr(sorting, clustering_features_cfg.voltages_dataset_name, None) @@ -123,6 +124,8 @@ def from_config( rank=clustering_features_cfg.n_main_channel_pcs, dataset_name=clustering_features_cfg.pca_dataset_name, ) + if not _allfinite(pcs): + raise ValueError(_numbers_error_str("No motion pcs", pcs)) elif do_pcs and 
clustering_features_cfg.motion_aware: shifts, n_pitches_shift = motion.pitch_shifts( sorting=sorting, @@ -141,6 +144,8 @@ def from_config( mask = np.broadcast_to(mask, len(schan)) if hasattr(sorting, clustering_features_cfg.pca_dataset_name): pcs = getattr(sorting, clustering_features_cfg.pca_dataset_name) + if not _allfinite(pcs): + raise ValueError(_numbers_error_str("sorting pcs", pcs)) erp, pcs = interpolate_by_chunk( mask=mask, dataset=pcs, @@ -153,6 +158,8 @@ def from_config( params=clustering_features_cfg.interp_params, ) pcs = pcs[:, : clustering_features_cfg.n_main_channel_pcs, 0] + if not _allfinite(pcs): + raise ValueError(_numbers_error_str("sorting interp pcs", pcs)) else: assert sorting.parent_h5_path is not None with h5py.File(sorting.parent_h5_path, "r", locking=False) as h5: @@ -168,20 +175,25 @@ def from_config( params=clustering_features_cfg.interp_params, ) pcs = pcs[:, : clustering_features_cfg.n_main_channel_pcs, 0] + if not _allfinite(pcs): + raise ValueError(_numbers_error_str("h5 interp pcs", pcs)) if do_pcs: assert pcs is not None - if clustering_features_cfg.pc_transform == "log": + pctf = clustering_features_cfg.pc_transform + if pctf == "log": pcs = signed_log1p( pcs, pre_scale=clustering_features_cfg.pc_pre_transform_scale ) - elif clustering_features_cfg.pc_transform == "sqrt": + elif pctf == "sqrt": pcs = signed_sqrt_transform( pcs, pre_scale=clustering_features_cfg.pc_pre_transform_scale ) else: - assert clustering_features_cfg.pc_transform in ("none", None) + assert pctf in ("none", None) pcs *= clustering_features_cfg.pc_scale + if not _allfinite(pcs): + raise ValueError(_numbers_error_str(f"{pctf} pcs", pcs)) if torch.is_tensor(pcs): pcs = pcs.numpy(force=True) features.append(pcs) @@ -310,3 +322,16 @@ def signed_sqrt_transform(x, pre_scale=1.0): xx.sub_(1.0) xx.mul_(torch.sign(x)) return xx + + +def _allfinite(x): + if isinstance(x, torch.Tensor): + return x.isfinite().all() + else: + return np.isfinite(x).all() + + +def 
_numbers_error_str(name: str, x: np.ndarray): + if isinstance(x, torch.Tensor): + x = x.numpy(force=True) + return f"{name}: {np.isposinf(x).sum()} +inf, {np.isneginf(x).sum()} -inf, {np.isnan(x).sum()} nan." diff --git a/src/dartsort/clustering/kmeans.py b/src/dartsort/clustering/kmeans.py index 52e38327..0809b82c 100644 --- a/src/dartsort/clustering/kmeans.py +++ b/src/dartsort/clustering/kmeans.py @@ -3,6 +3,8 @@ import torch.nn.functional as F from torch import Tensor +from dartsort.util.py_util import databag + try: import cupy # type: ignore # ty: ignore[x] @@ -33,7 +35,6 @@ def kmeanspp( n_components=10, random_state: np.random.Generator | torch.Generator | int = 0, kmeanspp_initial="random", - mode_dim=2, skip_assignment=False, min_distance=None, Xnormsq: Tensor | None = None, @@ -412,6 +413,14 @@ def kmeans_inner( return assignments, e, centroids, dists +@databag +class KMeansResult: + labels: Tensor | None + responsibilities: Tensor | None + centroids: Tensor | None + dists: Tensor | None + + def kmeans( X: Tensor, n_kmeans_tries=5, @@ -421,11 +430,11 @@ def kmeans( random_state: np.random.Generator | torch.Generator | int = 0, kmeanspp_initial="random", with_proportions=False, - drop_prop=0.025, - drop_sum=5.0, + drop_prop=0.0, + drop_sum=0.0, weights: Tensor | None = None, test_convergence_every=10, -): +) -> KMeansResult: best_phi = np.inf if isinstance(random_state, int): random_state = np.random.default_rng(random_state) @@ -461,7 +470,7 @@ def kmeans( assignments = aa.clone() e = ee centroids = cc - return dict( + return KMeansResult( labels=assignments, responsibilities=e, centroids=centroids, dists=dists ) @@ -614,3 +623,285 @@ def _kmeans_main_loop( break dists = _sqeuc(X, Xnormsq, centroids, dists) return e, centroids, dists, proportions + + +def batched_kmeans( + X: Tensor, + n_components: int, + seed: torch.Generator | np.random.Generator | int = 0, + n_iter: int = 100, + kmeanspp_seeds_per_try: int = 5, + n_tries: int = 10, + 
test_convergence_every=10, + atol=1e-5, + with_labels=True, + with_proportions=True, + beta: float = 1.0, +) -> KMeansResult: + """ + Compared to above: + - with_proportions = True + - drop_prop = 0 + - no weights allowed for now + - kmeanspp always random initial + - n_iter > 0 + """ + k = n_components + del n_components + assert n_iter > 0 + assert n_tries > 0 + assert k > 1 + assert kmeanspp_seeds_per_try >= 1 + + dev = X.device + gen = spawn_torch_rg(seed, device=dev) + Xnormsq = torch.linalg.norm(X, dim=1).square_() + n, dim = X.shape + k = min(n, k) + assert n >= k > 1 + ntries_k = n_tries * k + n_kmeanspps = n_tries * kmeanspp_seeds_per_try + + # -- kmeanspp stage: initialization + centroid_ixs = torch.full((k, n_kmeanspps), n, dtype=torch.long, device=dev) + centroid_ixs[0] = torch.randint(n, size=(n_kmeanspps,), device=dev, generator=gen) + dists = X.new_empty((n_kmeanspps, n)) + Y = X[centroid_ixs[0]] + Ynormsq = Xnormsq[centroid_ixs[0]] + dists = sqeuc_cdist_known_norm(Y, Ynormsq, X, Xnormsq, dists) + + # -- kmeanspp stage: loop + # buf for random sampling with Gumbel trick and new distance storage + _buf = torch.empty_like(dists) + for j in range(1, k): + # sample new centroid indices wppt dists (which is squared) + # gumbel max: argmax [log(d) + -log(-log(u))] + # = argmax d / (-log u) = argmin d / (log u) + u = _buf.uniform_(generator=gen).log_() + u = torch.div(dists, u, out=u) + cix_j = torch.argmin(u, dim=1, out=centroid_ixs[j]) + + # grab jth centroid data + torch.index_select(X, dim=0, index=cix_j, out=Y) + torch.index_select(Xnormsq, dim=0, index=cix_j, out=Ynormsq) + + # update distances + newdists = sqeuc_cdist_known_norm(Y, Ynormsq, X, Xnormsq, _buf) + torch.minimum(dists, newdists, out=dists) + + # -- kmeanspp finish: pick best by phi + if kmeanspp_seeds_per_try > 1: + phi = dists.mean(1).view(n_tries, kmeanspp_seeds_per_try) + best_kmpp = phi.argmin(1, keepdim=True) + centroid_ixs = centroid_ixs.view(k, n_tries, kmeanspp_seeds_per_try) + 
centroid_ixs = centroid_ixs.take_along_dim(dim=2, indices=best_kmpp[None, :, :]) + assert centroid_ixs.shape == (k, n_tries, 1) + centroid_ixs = centroid_ixs[:, :, 0] + assert centroid_ixs.shape == (k, n_tries) + centroid_ixs = centroid_ixs.T.contiguous() + + # -- kmeans stage: initialization + Y = X[centroid_ixs].view(ntries_k, dim) + Ynormsq = Xnormsq[centroid_ixs].view(ntries_k) + dists = dists.resize_(n, ntries_k) + e = X.new_empty((n, n_tries, k)) + N = X.new_ones((n_tries, k)) + Ntot = X.new_full((), float(n)) + log_props = X.new_zeros((n_tries, k)) + phi = phi_ = e.new_full((n_tries,), torch.nan) + + # -- kmeans stage: loop + check = False + for j in range(n_iter): + jmod = j % test_convergence_every + check = (j in (0, n_iter - 1)) or (jmod in (0, test_convergence_every - 1)) + + # e step + dists = sqeuc_cdist_known_norm(X, Xnormsq, Y, Ynormsq, dists.view(n, ntries_k)) + dists = dists.view(n, n_tries, k) + e = torch.add(log_props, dists, alpha=-0.5 * beta, out=e) + e = F.softmax(e, dim=2) + if check: + phi_ = dists.mul_(e).mean(0).sum(1) + assert phi_.shape == phi.shape + + # m step + N = torch.sum(e, dim=0, out=N) + if with_proportions: + torch.div(N, Ntot, out=log_props).log_() + w = e.div_(N) + Y = torch.mm(w.view(n, ntries_k).t(), X, out=Y) + Ynormsq = torch.linalg.vector_norm(Y, dim=1, out=Ynormsq) + Ynormsq.square_() + + # check convergence + if check: + done = torch.allclose(phi, phi_, atol=atol) or phi_.max() < atol + phi = phi_ + if done: + break + assert check # => phi is up to date + + # -- kmeans finish: pick best kmeans by phi, update responsibilities + best = phi.argmin() + Y = Y.view(n_tries, k, dim)[best] + Ynormsq = Ynormsq.view(n_tries, k)[best] + log_props = log_props[best] + dists = dists.resize_(n, k) + e = e.resize_(n, k) + dists = sqeuc_cdist_known_norm(X, Xnormsq, Y, Ynormsq, dists) + e = torch.add(log_props, dists, alpha=-0.5 * beta, out=e) + e = F.softmax(e, dim=1) + if with_labels: + labels = e.argmax(dim=1) + else: + labels = 
def truncated_kmeans_from_labels(
    X: Tensor | np.ndarray,
    labels: Tensor | np.ndarray,
    device=None,
    atol=1e-3,
    max_sigma=5.0,
    dirichlet_alpha=1.0,
    n_iter=100,
    show_progress: bool = True,
    batch_size: int = 4096,
    centroid_dist_batch_size: int = 128,
    min_log_prop=-25.0,
    trunc_guess=20,
    initial_undershoot=2.0,
) -> KMeansResult:
    """Refine hard labels with batched, truncated soft k-means updates.

    Centroids, log proportions and one shared scalar variance are
    initialized from ``labels``. Each iteration then:

    1. computes centroid-to-centroid squared distances and thresholds them
       into a boolean mask of centroid pairs,
    2. for each batch of points, gets sparse squared distances restricted by
       that mask and the point's current label (``sparse_centroid_distsq``),
       turns them into sparse responsibilities with a softmax, and relabels
       each point by the responsibility argmax,
    3. updates centroids, the variance and log proportions from running
       means over the batches.

    Iteration stops after ``n_iter`` rounds or once sigma changes by less
    than ``atol`` between rounds.

    Parameters
    ----------
    X : Tensor or np.ndarray
        Two-dimensional (n, dim) data.
    labels : Tensor or np.ndarray
        Length-n initial labels; they are remapped to ``0..k-1`` internally,
        so returned labels index the returned centroids, not the input ids.
    device
        Passed to ``torch.asarray`` for both inputs.
    atol : float
        Convergence tolerance on sigma.
    max_sigma : float
        Scales the centroid-pair distance threshold.
    dirichlet_alpha : float
        Added to the log counts before the log-softmax (see review note in
        the body).
    n_iter : int
        Maximum number of iterations.
    show_progress : bool
        Wrap the iteration in ``progrange`` and report sigma.
    batch_size, centroid_dist_batch_size : int
        Batch sizes over points and over centroids.
    min_log_prop : float
        Lower clamp for the log proportions.
    trunc_guess : int
        Sizes the initial sparse-distance buffers.
    initial_undershoot : float
        Multiplies the initial sigma (see review note in the body).

    Returns
    -------
    KMeansResult
        With ``labels`` and ``centroids`` set; ``responsibilities`` and
        ``dists`` are None.
    """
    X = torch.asarray(X, device=device)
    labels = torch.asarray(labels, device=device)
    n, dim = X.shape
    assert labels.shape == (n,)
    is_gpu = X.device.type == "cuda"

    # flatten and count labels
    ulabels, labels = labels.unique(return_inverse=True)
    k = ulabels.shape[0]
    del ulabels

    # initialize parameters
    e = F.one_hot(labels, k).to(X)
    log_props = e.mean(0).log_()
    N = e.sum(dim=0)
    w = e.div_(N)
    Y = torch.mm(w.view(n, k).t(), X)
    Ynormsq = torch.linalg.vector_norm(Y, dim=1).square_()

    # initialize sigma: running mean over batches of the mean squared
    # deviation of each point from its own centroid
    sigmasq = X.new_zeros(())
    bY = X.new_empty((batch_size, dim))
    for i0 in range(0, n, batch_size):
        i1 = min(n, i0 + batch_size)
        bY = torch.index_select(Y, 0, labels[i0:i1], out=bY[: i1 - i0])
        bsigsq = bY.sub_(X[i0:i1]).square_().mean()
        sigmasq += (bsigsq - sigmasq) * ((i1 - i0) / i1)
    del bY
    assert sigmasq > 0
    # Out-of-place sqrt on purpose. An in-place ``sqrt_()`` here overwrites
    # the variance with its square root, and the first iteration below would
    # then use a standard deviation where a variance is expected.
    # NOTE(review): initial_undershoot only scales the sigma that is
    # displayed and used as the first convergence baseline; it never reaches
    # sigmasq. Confirm whether it was meant to inflate the initial variance.
    sigma = initial_undershoot * sigmasq.sqrt()
    assert sigma.isfinite().item()
    prev_sigma = sigma.clone()

    # storage
    dYY = X.new_zeros((k, k))
    dYYmask = X.new_zeros((k, k), dtype=torch.bool)
    new_Y = torch.empty_like(Y)
    distsq_buf = torch.zeros_like(X[: min(trunc_guess, k) * batch_size])
    distsq_buf = (distsq_buf, torch.zeros_like(distsq_buf))

    if show_progress:
        it = progrange(n_iter, desc=f"kmeans σ={sigma:0.4f}")
    else:
        it = range(n_iter)

    for _ in it:
        # NOTE(review): max_sigma enters this squared-distance threshold
        # linearly rather than squared -- confirm that is intended.
        max_distance_sq = max_sigma * sigmasq * dim

        new_Y.fill_(0.0)
        N.fill_(0.0)
        new_sigmasq = 0.0
        weight = 0.0

        # update centroid dists
        for i0 in range(0, k, centroid_dist_batch_size):
            i1 = min(k, i0 + centroid_dist_batch_size)
            sqeuc_cdist_known_norm(Y[i0:i1], Ynormsq[i0:i1], Y, Ynormsq, out=dYY[i0:i1])
        torch.lt(dYY, max_distance_sq, out=dYYmask)

        for i0 in range(0, n, batch_size):
            i1 = min(n, i0 + batch_size)
            distsq_coo, distsq_buf = sparse_centroid_distsq(
                X[i0:i1],
                Y,
                labels=labels[i0:i1],
                centroid_mask=dYYmask,
                dbufs=distsq_buf,
            )
            assert distsq_coo.shape == (i1 - i0, k)
            # copy the distances before the in_place=True call below, which
            # presumably overwrites the values of distsq_coo
            distsq_values = distsq_coo.values().clone()
            liks = distsq_to_lik_coo(distsq_coo, sigmasq, log_props, in_place=True)
            del distsq_coo

            resps = torch.sparse.softmax(liks, dim=1)
            # update labels... torch sparse has no argmax(), so need scipy
            # or cupy. scipy is a big slowdown here, so cupy if possible.
            if is_gpu and HAVE_CUPY:
                resps_cupy = coo_to_cupy(resps).tocsc()
                batch_labels = resps_cupy.argmax(axis=1)
            else:
                resps_scipy = coo_to_scipy(resps)
                batch_labels = resps_scipy.argmax(axis=1, explicit=True)
            labels[i0:i1] = torch.as_tensor(batch_labels).to(labels).squeeze()

            # get sigmasq: responsibility-weighted mean squared distance,
            # per dimension
            w = resps.values().clone()
            batch_w = w.sum()
            w /= batch_w
            batch_sigmasq = torch.sum(distsq_values.mul_(w)) / dim

            # get N and centroids
            batch_N = resps.sum(dim=0).to_dense()
            resps.values().div_(batch_N[resps.indices()[1]])
            batch_centroids = resps.T @ X[i0:i1]

            # update counts
            N += batch_N
            weight += batch_w

            # update Welford running means
            n1_n01 = batch_N.div_(N.clip(min=1e-5))[:, None]
            w1_w01 = batch_w / weight
            new_Y += batch_centroids.sub_(new_Y).mul_(n1_n01)
            new_sigmasq += batch_sigmasq.sub_(new_sigmasq).mul_(w1_w01)

        # update state
        # NOTE(review): dirichlet_alpha shifts every log count by the same
        # constant, which log_softmax cancels, so it currently has no effect
        # on log_props. If a pseudo-count was intended this would be
        # log(N + dirichlet_alpha) -- confirm before changing.
        logN = N.log_() + dirichlet_alpha
        log_props = F.log_softmax(logN, dim=0).to(X)
        log_props = log_props.clamp_(min=min_log_prop)
        # swap buffers so the old centroid storage is reused next round
        Y, new_Y = new_Y, Y
        Ynormsq = torch.linalg.vector_norm(Y, dim=1).square_()
        sigmasq = new_sigmasq

        # check convergence
        sigma = torch.sqrt(sigmasq).numpy(force=True).item()  # type: ignore
        if abs(sigma - prev_sigma) < atol:
            break

        prev_sigma = sigma
        if show_progress:
            it.set_description(f"kmeans σ={sigma:0.4f}")  # type: ignore

    return KMeansResult(
        labels=labels,
        responsibilities=None,
        centroids=Y,
        dists=None,
    )
- ) +from .kmeans import batched_kmeans, kmeans if TYPE_CHECKING: from ..transform.temporal_pca import BaseTemporalPCA @@ -185,8 +178,14 @@ def tmm_demix( allow_blanks = False for outer_it in range(refinement_cfg.n_total_iters): for inner_it, step_type in enumerate(refinement_cfg.mixture_steps): - if step_type == "split": - run_split(tmm, train_data, val_data, prog_level) + if step_type.endswith("split"): + run_split( + tmm, + train_data, + val_data, + prog_level, + single=step_type.startswith("single"), + ) tmm.em( train_data, show_progress=prog_level, @@ -795,6 +794,7 @@ class TMMParams: split_max_distance: float merge_max_distance: float split_k: int + single_split_k: int distance_kind: ComponentDistanceMetric em_iters: int min_em_iters: int @@ -813,7 +813,9 @@ class TMMParams: robust_strategy: RobustnessStrategy demolition_min_resp_ratio: float demolish_during_selection: bool + refit_in_demolition: bool kmeans_tries: int + kmeans_beta: float kmeanspp_tries: int whiten_split: bool scale_dist_args: tuple[float, float, float] @@ -830,7 +832,9 @@ def from_refinement_cfg(cls, refinement_cfg: RefinementConfig): min_count=refinement_cfg.min_count, split_min_count=refinement_cfg.split_min_count, split_k=refinement_cfg.kmeansk, + single_split_k=refinement_cfg.single_split_k, kmeans_tries=refinement_cfg.kmeans_tries, + kmeans_beta=refinement_cfg.kmeans_beta, kmeanspp_tries=refinement_cfg.kmeanspp_tries, min_channel_count=refinement_cfg.channels_count_min, em_iters=refinement_cfg.n_em_iters, @@ -847,6 +851,7 @@ def from_refinement_cfg(cls, refinement_cfg: RefinementConfig): robust_strategy=refinement_cfg.robust_strategy, demolition_min_resp_ratio=refinement_cfg.demolition_min_resp_ratio, demolish_during_selection=refinement_cfg.demolish_during_selection, + refit_in_demolition=refinement_cfg.refit_in_demolition, whiten_split=refinement_cfg.whiten_split, scale_dist_args=refinement_cfg.scale_dist_args, whiten_dist=refinement_cfg.whiten_dist, @@ -1813,7 +1818,7 @@ def 
dense_slice_by_unit( gen: torch.Generator, labels: Tensor | None, min_count: int = 0, - ): + ) -> DenseSpikeData | None: assert self.candidates is not None assert unit_ids is not None unit_ids = torch.as_tensor(unit_ids, device=self.candidates.device) @@ -2346,6 +2351,7 @@ def initialize_from_dense_data_with_fixed_responsibilities( total_log_proportion: float, noise_log_prop: Tensor | float = -torch.inf, p: TMMParams, + min_channel_count: int | None = None, ) -> tuple[Self, Tensor, DenseSpikeData, bool, Tensor, Tensor]: """Fit units with fixed label posterior @@ -2358,14 +2364,16 @@ def initialize_from_dense_data_with_fixed_responsibilities( data=data, rank=signal_rank, erp=erp, - min_channel_count=p.min_channel_count, + min_channel_count=p.min_channel_count + if min_channel_count is None + else min_channel_count, noise=noise, weights=responsibilities, latent_prior_std=p.latent_prior_std, prior_pseudocount=p.prior_pseudocount, ) assert chan_coverage is not None - valid = torch.tensor([r is not None for r in initialization]) + valid = torch.tensor([r is not None for r in initialization], dtype=torch.bool) initialization = [r for r in initialization if r is not None] responsibilities = responsibilities[:, valid] K = responsibilities.shape[1] @@ -2674,13 +2682,14 @@ def score( *, skip_noise: bool = False, allow_blanks: bool = False, + skip_responsibility: bool = True, ) -> Scores: scores = [] for batch in data.to_batches(self.unit_ids, self.lut): batch_scores = self.score_batch( batch=batch, n_candidates=batch.candidates.shape[1], - skip_responsibility=True, + skip_responsibility=skip_responsibility, skip_noise=skip_noise, allow_blanks=allow_blanks, ) @@ -2933,7 +2942,10 @@ def split_group( group_size = group.numel() single = group_size == 1 - k = min(self.p.max_group_size, self.p.split_k) + if single: + k = self.p.single_split_k + else: + k = self.p.split_k # get dense train set slice in group split_data = train_data.dense_slice_by_unit( @@ -2954,6 +2966,7 @@ def 
split_group( min_count=self.p.split_min_count, min_channel_count=self.p.min_channel_count, n_kmeans_tries=self.p.kmeans_tries, + kmeans_beta=self.p.kmeans_beta, n_kmeanspp_tries=self.p.kmeanspp_tries, bail_at=int(single), weights=split_data.duties, @@ -3184,11 +3197,14 @@ def split( train_scores: Scores, eval_scores: Scores, show_progress: bool = True, + friend_distance: float | None = None, _stop_after: int | None = None, _dry_run: bool = False, ) -> SplitResult: + if friend_distance is None: + friend_distance = self.p.split_friend_distance split_groups = self.group_units_by_distance( - distance=self.p.split_friend_distance, + distance=friend_distance, max_group_size=max(1, self.p.split_k - 1), ) if _stop_after: @@ -3460,6 +3476,8 @@ def demolish( train_scores=train_scores, eval_scores=val_scores, cur_crit=cur_crit, + train_data=train_data, + eval_data=val_data, ) if _stop_after and j > _stop_after: # for profiling @@ -4469,6 +4487,7 @@ def run_split( train_data: TruncatedSpikeData, val_data: TruncatedSpikeData | None, prog_level: int, + single: bool = False, _stop_after: int | None = None, _dry_run: bool = False, ): @@ -4493,6 +4512,7 @@ def run_split( eval_scores=eval_scores, train_scores=train_scores, show_progress=prog_level > 0, + friend_distance=0.0 if single else None, _stop_after=_stop_after, _dry_run=_dry_run, ) @@ -5029,10 +5049,13 @@ def all_demolished_partitions( can_demolish_mask.shape == part.unit_ids.shape == part.group_ids.shape ) single_ixs = torch.tensor(part.single_ixs, dtype=torch.long) + npart = single_ixs.numel() (part_demo_ix,) = can_demolish_mask[single_ixs].nonzero(as_tuple=True) for demo_ixs in subsets(part_demo_ix.tolist()): demo_ixs = list(demo_ixs) + if len(demo_ixs) == npart: + continue demo_group_ids = part.group_ids.clone() demo_group_ids[single_ixs[demo_ixs]] = -1 demo_part = replace( @@ -5098,7 +5121,7 @@ def _fit_subset_models( def _score_subset_models( mm: BaseMixtureModel, subset_models: BaseMixtureModel, - train_full_scores: 
Scores, + train_full_scores: Scores | None, cur_scores: Scores, train_data: DenseSpikeData, eval_data: DenseSpikeData | None, @@ -5137,6 +5160,7 @@ def _score_subset_models( crit_subset_scores = subset_models.score(eval_data, skip_noise=True) else: assert eval_data is None + assert train_full_scores is not None crit_subset_scores = train_subset_scores crit_full_scores = train_full_scores @@ -5337,6 +5361,7 @@ def try_kmeans( drop_prop: float = 0.0, kmeanspp_initial="random", n_kmeans_tries: int = 25, + kmeans_beta: float = 1.0, n_kmeanspp_tries: int = 25, weights: Tensor | None = None, debug: bool = False, @@ -5354,19 +5379,37 @@ def try_kmeans( x_ret = x if debug else None # kmeans - kres = kmeans( - x, - n_components=k, - random_state=gen, - n_iter=n_iter, - with_proportions=with_proportions, - drop_prop=drop_prop, - kmeanspp_initial=kmeanspp_initial, - n_kmeans_tries=n_kmeans_tries, - n_kmeanspp_tries=n_kmeanspp_tries, - weights=weights.to(x) if weights is not None else None, + _can_batch = ( + weights is None + and with_proportions + and not drop_prop + and kmeanspp_initial == "random" ) - resps = kres["responsibilities"] + if _can_batch: + kres = batched_kmeans( + x, + k, + seed=gen, + n_iter=n_iter, + kmeanspp_seeds_per_try=n_kmeanspp_tries, + n_tries=n_kmeans_tries, + beta=kmeans_beta, + ) + else: + assert kmeans_beta == 1.0 + kres = kmeans( + x, + n_components=k, + random_state=gen, + n_iter=n_iter, + with_proportions=with_proportions, + drop_prop=drop_prop, + kmeanspp_initial=kmeanspp_initial, + n_kmeans_tries=n_kmeans_tries, + n_kmeanspp_tries=n_kmeanspp_tries, + weights=weights.to(x) if weights is not None else None, + ) + resps = kres.responsibilities if resps is None: return None, x_ret, channels assert resps.shape[1] <= k @@ -5380,8 +5423,11 @@ def try_kmeans( def evaluate_group_demolitions( + *, mm: TruncatedMixtureModel, group: Tensor, + train_data: TruncatedSpikeData, + eval_data: TruncatedSpikeData, mean_train_resp: Tensor | None, mean_eval_resp: 
Tensor | None, cur_crit: float | None, @@ -5396,7 +5442,7 @@ def evaluate_group_demolitions( if mean_eval_resp is None: mean_eval_resp = mean_responsibilities(scores=eval_scores, n_units=mm.n_units) ratio = mean_train_resp[group] / mean_eval_resp[group] - can_demolish = ratio > mm.p.demolition_min_resp_ratio + can_demolish = ratio > 0 # mm.p.demolition_min_resp_ratio if not can_demolish.any(): return GroupDemolition(unit_ids=group, improvement=0.0, demolished=None) @@ -5415,16 +5461,46 @@ def evaluate_group_demolitions( unit_ids=group, improvement=0.0, demolished=torch.zeros_like(can_demolish) ) best_imp = 0.0 - for demo_mask in submasks(can_demolish): - crit = _evaluate_single_demolition( - orig_log_props=mm.b.log_proportions, - noise_log_prop=mm.b.noise_log_prop, - cl_alpha=alpha, - group=group_, - demolish_mask=demo_mask, - train_scores=train_scores, - eval_scores=eval_scores, + + if mm.p.refit_in_demolition: + group = group.to(device=train_scores.candidates.device) + (train_ixs,) = ( + torch.isin(train_scores.candidates, group).any(dim=1).nonzero(as_tuple=True) + ) + group_train_data = train_data.dense_slice(train_ixs) + assert group_train_data is not None + group_train_scores = train_scores.slice(group_train_data.indices) + + (eval_ixs,) = ( + torch.isin(eval_scores.candidates, group).any(dim=1).nonzero(as_tuple=True) ) + group_eval_data = eval_data.dense_slice(eval_ixs) + assert group_eval_data is not None + group_eval_scores = eval_scores.slice(group_eval_data.indices) + nongroup_eval_scores = remove_units_from_scores(group_eval_scores, group) + + for demo_mask in submasks(can_demolish, skip_full=True): + if mm.p.refit_in_demolition: + crit = _evaluate_single_refit_demolition( + group=group, + demolish_mask=demo_mask, + group_train_scores=group_train_scores, # ty: ignore[possibly-unresolved-reference] + nongroup_eval_scores=nongroup_eval_scores, # ty: ignore[possibly-unresolved-reference] + group_eval_scores=group_eval_scores, # ty: 
ignore[possibly-unresolved-reference] + group_eval_data=group_eval_data, # ty: ignore[possibly-unresolved-reference] + group_train_data=group_train_data, # ty: ignore[possibly-unresolved-reference] + mm=mm, + ) + else: + crit = _evaluate_single_demolition( + orig_log_props=mm.b.log_proportions, + noise_log_prop=mm.b.noise_log_prop, + cl_alpha=alpha, + group=group_, + demolish_mask=demo_mask, + train_scores=train_scores, + eval_scores=eval_scores, + ) imp = crit - cur_crit if imp > best_imp: best_demo = GroupDemolition( @@ -5476,13 +5552,83 @@ def _evaluate_single_demolition( ) -def submasks(mask: Tensor, skip_empty=True): +def _evaluate_single_refit_demolition( + group: Tensor, + demolish_mask: Tensor, + group_train_scores: Scores, + nongroup_eval_scores: Scores, + group_eval_scores: Scores, + group_eval_data: DenseSpikeData, + group_train_data: DenseSpikeData, + mm: TruncatedMixtureModel, +) -> float: + assert demolish_mask.shape == group.shape + + # determine mean responsibility after demolition on train set + demolish_mask = demolish_mask.cpu() + chopping_block = group[demolish_mask] + nchop = chopping_block.numel() + if not nchop: + return ecl( + resps=group_eval_scores.responsibilities, + log_liks=group_eval_scores.log_liks, + cl_alpha=mm.p.cl_alpha, + ) + group_remain = group[demolish_mask.logical_not()] + n_remain = group.numel() - nchop + train_scores_adj = remove_units_from_scores(group_train_scores, chopping_block) + assert mm.erp is not None + assert mm.noise is not None + + # responsibilities for re-fitting rest of group + resp0 = train_scores_adj.responsibilities + assert resp0 is not None + cand0 = train_scores_adj.candidates + chopped_responsibilities = mm.b.log_proportions.new_zeros( + (resp0.shape[0], n_remain) + ) + for j, cand in enumerate(group_remain): + cii, cjj = (cand0 == cand).nonzero(as_tuple=True) + chopped_responsibilities[cii, j] = resp0[cii, cjj] + chopped_responsibilities.clamp_(min=1e-8) + + # re-fit model to train_scores_adj + 
chopped_model, s0valid, _, s0discard, s0mask, _ = ( + TruncatedMixtureModel.initialize_from_dense_data_with_fixed_responsibilities( + data=group_train_data, + responsibilities=chopped_responsibilities, + signal_rank=mm.signal_rank, + erp=mm.erp, + noise=mm.noise, + neighb_cov=mm.neighb_cov, + total_log_proportion=mm.non_noise_log_proportion(), + p=mm.p, + min_channel_count=0, + ) + ) + + # re-score eval data and combine with bystander units + chopped_score = chopped_model.score(data=group_eval_data, skip_responsibility=False) + assert chopped_score.responsibilities is not None + final_score = concatenate_scores([chopped_score, nongroup_eval_scores], dim=1) + assert final_score.responsibilities is not None + + return ecl( + resps=final_score.responsibilities, + log_liks=final_score.log_liks, + cl_alpha=mm.p.cl_alpha, + ) + + +def submasks(mask: Tensor, skip_empty=True, skip_full=False): mask_ = mask.cpu() (on,) = mask_.nonzero(as_tuple=True) on = on.tolist() for subset in subsets(on): if skip_empty and not len(subset): continue + if skip_full and len(subset) == len(mask): + continue m = torch.zeros_like(mask) m[list(subset)] = True yield m @@ -6194,18 +6340,6 @@ def _count_candidates(candidates, batch_candidate_counts, batch_size): batch_candidate_counts.copy_(counts.cpu()) -if TORCH_IS_OLD: - - def _nonzero_static(x: Tensor, size: int): - nz = x.nonzero() - assert nz.shape[0] == size - return nz -else: - - def _nonzero_static(x: Tensor, size: int): - return x.nonzero_static(size=size) - - @torch_compiler(fullgraph=False) def _combine_similar_resps(resps: Tensor, keep_mask: Tensor, n_keep: int) -> Tensor: n_discard = resps.shape[1] - n_keep @@ -6222,22 +6356,22 @@ def _combine_similar_resps(resps: Tensor, keep_mask: Tensor, n_keep: int) -> Ten return kept_resp -def concatenate_scores(scoress: list[Scores]) -> Scores: +def concatenate_scores(scoress: list[Scores], dim=0) -> Scores: assert len(scoress) > 0 if len(scoress) == 1: return scoress[0] - log_liks = 
torch.concatenate([s.log_liks for s in scoress], dim=0) - candidates = torch.concatenate([s.candidates for s in scoress], dim=0) + log_liks = torch.concatenate([s.log_liks for s in scoress], dim=dim) + candidates = torch.concatenate([s.candidates for s in scoress], dim=dim) if scoress[0].responsibilities is None: responsibilities = None else: responsibilities = torch.concatenate( - [cast(Tensor, s.responsibilities) for s in scoress], dim=0 + [cast(Tensor, s.responsibilities) for s in scoress], dim=dim ) if scoress[0].duties is None: duties = None else: - duties = torch.concatenate([cast(Tensor, s.duties) for s in scoress], dim=0) + duties = scoress[0].duties return Scores( log_liks=log_liks, responsibilities=responsibilities, diff --git a/src/dartsort/config.py b/src/dartsort/config.py index 9abac469..72695fa9 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -47,7 +47,7 @@ class DARTsortUserConfig: relevant if `preprocessing != 'none'`. If the recording isn't getting saved, stick to float32.""" - subsampling_spikes: int | None = 2_048_000 + subsampling_spikes_per_channel: int | None = 5000 """Detection steps before the final matching round will run until at least this many spikes are found or the whole recording is covered, to make sure that there is enough data for clustering. See also subsampling_fraction. @@ -123,10 +123,9 @@ class DARTsortUserConfig: alignment_ms: Annotated[float, Field(gt=0)] = 1.5 """Largest time shift allowed when re-aligning events.""" - deduplication_ms: Annotated[float, Field(gt=0)] = 0.5 + deduplication_ms: Annotated[float, Field(gt=0)] = 0.5 """As a final postprocessing step, only the higher-scoring of any spikes within this time radius of each other are kept. - If this is negative, it does nothing. If it's 0, exact duplicates are dropped. 
""" @@ -146,7 +145,7 @@ class DARTsortUserConfig: threshold in Kilosort and other sorters, and it represents reduction in Euclidean norm of standardized data due to matching a new event.""" - initial_threshold: Annotated[float, Field(gt=0)] = 10.0 + initial_threshold: Annotated[float, Field(gt=0)] = 9.0 """Initial detection's neural net matching threshold. Same as matching_threshold, except that a neural net is trying to guess the true waveforms here, rather than using cluster templates.""" @@ -232,7 +231,7 @@ class DeveloperConfig(DARTsortUserConfig): """Additional parameters for experiments. This API will never be stable.""" # high level behavior - initial_steps: Sequence[MixtureStep] = ("split", "demolish") + initial_steps: Sequence[MixtureStep] = ("split", "demolish", "demolish") later_steps: Sequence[MixtureStep] = ("split", "merge", "demolish") detection_type: Literal["subtract", "match", "threshold"] = "subtract" cluster_strategy: str = "dpc" @@ -243,7 +242,7 @@ class DeveloperConfig(DARTsortUserConfig): n_waveforms_fit: int = 40_000 max_waveforms_fit: int = 50_000 fit_sampling: Literal["random", "amp_reweighted"] = "amp_reweighted" - n_residual_snips: int = 4 * 4096 + n_residual_snips: int = 2 * 4096 # initial detection nn_denoiser_max_waveforms_fit: int = 512_000 @@ -254,6 +253,12 @@ class DeveloperConfig(DARTsortUserConfig): first_denoiser_spatial_dedup_radius: float = 100.0 realign_to_denoiser: bool = True use_nn_in_subtraction: bool = True + whiten_in_subtraction: bool = True + threshold_before_whitening: float = 10.0 + temporal_dedup_radius_samples: int = 7 + positive_temporal_dedup_radius_samples: int = 41 + spikeinterface_merge_preset: str | None = None + spikeinterface_merge_max_distance: float = 0.5 # matching matching_template_type: Literal["individual_compressed_upsampled", "drifty"] = ( @@ -278,6 +283,7 @@ class DeveloperConfig(DARTsortUserConfig): trough_factor: float = 3.0 whiten_strategy: WhiteningStrategy = "prewhiten_postapply" 
whiten_estimator: WhiteningEstimator = "localzca" + whiten_temporal_length: int | None = None whiten_features: bool = False matching_fp_control: bool = False refractory_radius_frames: int = 0 @@ -348,6 +354,7 @@ class DeveloperConfig(DARTsortUserConfig): robust_df: float = 4.0 demolish_during_selection: bool = False em_after_demolish: bool = False + tpca_from_templates: bool = True # agglomeration agg_kind: Literal["none", "template_distance", "qda"] = "qda" diff --git a/src/dartsort/evaluate/analysis.py b/src/dartsort/evaluate/analysis.py index bf36a6c7..263f6c87 100644 --- a/src/dartsort/evaluate/analysis.py +++ b/src/dartsort/evaluate/analysis.py @@ -121,9 +121,12 @@ def from_sorting( motion = MotionInfo.from_motion_est(geom=recording.get_channel_locations()) if has_hdf5: - tpca = get_tpca( - sorting, featurization_pipeline_pt=featurization_pipeline_pt - ) + try: + tpca = get_tpca( + sorting, featurization_pipeline_pt=featurization_pipeline_pt + ) + except ValueError: + tpca = None else: tpca = None if has_hdf5 and vis_radius and tpca is not None: @@ -413,10 +416,14 @@ def unit_raw_waveforms( read_chans=read_chans, main_channel=main_channel, ) - channels = None + channels = self.vis_channel_index[main_channel] + channels = np.broadcast_to( + channels[None], (waveforms.shape[0], waveforms.shape[2]) + ) else: channels = read_channel_index[read_chans] - main_channel = self.unit_max_channel(unit_id) + if main_channel is None: + main_channel = self.unit_max_channel(unit_id) return WaveformsBag( which=which, waveforms=waveforms, @@ -558,7 +565,10 @@ def unit_pca_features( def unit_max_channel(self, unit_id) -> int: assert self.coarse_template_data is not None temp = self.coarse_template_data.unit_templates(unit_id) - assert temp.ndim == 3 and temp.shape[0] == np.atleast_1d(unit_id).size + assert temp.ndim == 3, f"{self.name}: {unit_id=} {temp.shape=}" + assert temp.shape[0] == np.atleast_1d(unit_id).size, ( + f"{self.name}: {unit_id=} {temp.shape=}" + ) which = 
self.in_unit(unit_id) if self.motion.drifting and hasattr(self.sorting, "channel_index"): @@ -641,7 +651,9 @@ def nearby_coarse_templates(self, unit_id, n_neighbors=5): assert td is not None assert self.merge_distances is not None - unit_ix = np.searchsorted(td.unit_ids, unit_id) + unit_ix = np.flatnonzero(td.unit_ids == unit_id) + assert unit_ix.shape[0] == 1 + unit_ix = unit_ix[0] unit_dists = self.merge_distances[unit_ix] distance_order = np.argsort(unit_dists) distance_order = np.concatenate( diff --git a/src/dartsort/evaluate/comparison.py b/src/dartsort/evaluate/comparison.py index 4a9bfab7..3da1dfc1 100644 --- a/src/dartsort/evaluate/comparison.py +++ b/src/dartsort/evaluate/comparison.py @@ -264,6 +264,19 @@ def nearby_tested_templates(self, gt_unit_id, n_neighbors=5): return neighb_ixs, neighb_ids, neighb_dists, neighb_coarse_templates + def full_tested_labels(self): + labels = np.full(len(self.tested_analysis.sorting), "NA", dtype="= 0: + fp_waves = self.tested_analysis.unit_raw_waveforms( + unit_id=tested_unit, + which=ind_groups["only_tested_indices"], + **waveform_kw, # type: ignore + to_main_channel=True, + ) + else: + fp_waves = None if fp_waves is None: w["which_fp"] = None w["fp"] = None else: w["which_fp"] = fp_waves.which w["fp"] = fp_waves.waveforms + w["channels_fp"] = fp_waves.channels if self.unsorted_detection is None: w["unsorted_tp"] = w["unsorted_fn"] = None else: utp_waves = self.gt_analysis.unit_raw_waveforms( + unit_id=gt_unit, which=ind_groups["unsorted_tp_indices"], **waveform_kw, # type: ignore ) @@ -556,7 +582,9 @@ def get_raw_waveforms_by_category( else: w["which_unsorted_tp"] = utp_waves.which w["unsorted_tp"] = utp_waves.waveforms + w["channels_unsorted_tp"] = utp_waves.channels ufn_waves = self.gt_analysis.unit_raw_waveforms( + unit_id=gt_unit, which=ind_groups["unsorted_fn_indices"], **waveform_kw, # type: ignore ) @@ -566,6 +594,7 @@ def get_raw_waveforms_by_category( else: w["which_unsorted_fn"] = ufn_waves.which 
w["unsorted_fn"] = ufn_waves.waveforms + w["channels_unsorted_fn"] = ufn_waves.channels return w diff --git a/src/dartsort/main.py b/src/dartsort/main.py index 01cc50b3..3385de07 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -1,4 +1,5 @@ """High-level spike sorting toolbox functions.""" + import traceback from pathlib import Path from tempfile import TemporaryDirectory @@ -59,7 +60,7 @@ motion_needs_peaks, ) from .util.motion import MotionInfo, get_motion_info -from .util.noise_util import SpatialWhitener +from .util.noise_util import Whitener from .util.peel_util import run_peeler from .util.preprocess_util import preprocess from .util.py_util import dartcopytree, ensure_path, timer @@ -118,11 +119,10 @@ def dartsort( - "sorting": `DARTsortSorting` - "motion": MotionInfo """ - output_dir = ensure_path(output_dir) - output_dir.mkdir(exist_ok=True) + output_dir = ensure_path(output_dir, mkdir=True) # convert cfg to internal format and store it for posterity - cfg = to_internal_config(cfg) + cfg = to_internal_config(cfg, recording.get_num_channels()) ds_dump_config(cfg, output_dir) # in benchmarking, it can be useful to resume from initial detection @@ -252,7 +252,7 @@ def _dartsort_impl( ) ret["motion"] = motion - is_subsampling = cfg.subsampling_spikes is not None + is_subsampling = cfg.subsampling_spikes_per_channel is not None is_subsampling = is_subsampling and cfg.subsampling_presence != 1.0 if next_step == 0: @@ -341,7 +341,10 @@ def _dartsort_impl( else: previous_detection_cfg = cfg.matching_cfg - _nspk = None if is_final else cfg.subsampling_spikes + if is_final or cfg.subsampling_spikes_per_channel is None: + _nspk = None + else: + _nspk = cfg.subsampling_spikes_per_channel * motion.geom.shape[0] _pres = 1.0 if is_final else cfg.subsampling_presence step_clus_cfg, step_clfeat_cfg, step_ref_cfgs, step_feat_cfg, samp_cfg = ( _matching_step_cfgs(is_final, is_subsampling, cfg) @@ -444,6 +447,10 @@ def initial_detection( ------- DARTsortSorting 
""" + if cfg.subsampling_spikes_per_channel is None: + _nspk = None + else: + _nspk = cfg.subsampling_spikes_per_channel * recording.get_num_channels() if cfg.detection_type == "subtract": assert isinstance(cfg.initial_detection_cfg, SubtractionConfig) return subtract( @@ -454,7 +461,7 @@ def initial_detection( subtraction_cfg=cfg.initial_detection_cfg, sampling_cfg=cfg.peeler_sampling_cfg, computation_cfg=cfg.computation_cfg, - stop_after_n_spikes=cfg.subsampling_spikes, + stop_after_n_spikes=_nspk, ensure_coverage=cfg.subsampling_presence, overwrite=overwrite, show_progress=show_progress, @@ -468,7 +475,7 @@ def initial_detection( thresholding_cfg=cfg.initial_detection_cfg, sampling_cfg=cfg.peeler_sampling_cfg, featurization_cfg=cfg.featurization_cfg, - stop_after_n_spikes=cfg.subsampling_spikes, + stop_after_n_spikes=_nspk, ensure_coverage=cfg.subsampling_presence, overwrite=overwrite, show_progress=show_progress, @@ -485,7 +492,7 @@ def initial_detection( matching_cfg=cfg.initial_detection_cfg, sampling_cfg=cfg.peeler_sampling_cfg, motion=motion, - stop_after_n_spikes=cfg.subsampling_spikes, + stop_after_n_spikes=_nspk, ensure_coverage=cfg.subsampling_presence, overwrite=overwrite, show_progress=show_progress, @@ -572,7 +579,7 @@ def match( template_npz="template_data.npz", computation_cfg: ComputationConfig | None = None, template_denoising_tsvd=None, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, ) -> DARTsortSorting: output_dir = ensure_path(output_dir) model_dir = output_dir / model_subdir diff --git a/src/dartsort/peel/grab.py b/src/dartsort/peel/grab.py index b7e9e6bf..5219080b 100644 --- a/src/dartsort/peel/grab.py +++ b/src/dartsort/peel/grab.py @@ -1,4 +1,5 @@ """Grab and featurize events at known times.""" + from typing import Mapping import numpy as np @@ -6,7 +7,6 @@ from spikeinterface import BaseRecording from ..transform import WaveformPipeline -from ..util import spiketorch from ..util.data_util import 
DARTsortSorting from ..util.internal_config import ( FeaturizationConfig, @@ -14,6 +14,7 @@ WaveformConfig, default_waveform_cfg, ) +from ..util.spiketorch import _nonzero_static, grab_spikes from ..util.waveform_util import make_channel_index from .peel_base import ( BasePeeler, @@ -39,6 +40,7 @@ def __init__( chunk_length_samples=30_000, fit_sampling_cfg: FitSamplingConfig = FitSamplingConfig(n_residual_snips=0), waveform_cfg: WaveformConfig = default_waveform_cfg, + batch_size: int = 2048, dtype=torch.float, ): fixed_properties = fixed_properties or {} @@ -49,6 +51,7 @@ def __init__( spike_length_samples = waveform_cfg.spike_length_samples( recording.sampling_frequency ) + self.batch_size = batch_size assert not fit_sampling_cfg.n_residual_snips super().__init__( recording=recording, @@ -94,17 +97,44 @@ def process_chunk( ) t_clip = self.b.times_samples.clip(chunk_start_samples, chunk_end_samples - 1) in_chunk = self.b.times_samples == t_clip - if not in_chunk.any(): + n_in_chunk = int(in_chunk.sum().item()) + if not n_in_chunk: return peeling_empty_result - res = super().process_chunk( - chunk_start_samples, - n_resid_snips=n_resid_snips, - chunk_end_samples=chunk_end_samples, - return_residual=return_residual, - skip_features=skip_features, - to_cpu=to_cpu, + in_chunk = _nonzero_static(in_chunk, size=n_in_chunk) + assert in_chunk.shape == (n_in_chunk, 1) + in_chunk = in_chunk[:, 0] + return_waveforms = not skip_features and bool(self.featurization_pipeline) + + chunk, chunk_end_samples_, left_margin, right_margin = self.get_chunk( + chunk_start_samples, chunk_end_samples ) + assert chunk_end_samples == chunk_end_samples_ + + batch_results = [] + for i0 in range(0, n_in_chunk, self.batch_size): + i1 = min(n_in_chunk, i0 + self.batch_size) + batch_in_chunk = in_chunk[i0:i1] + batch_peel_result = self.peel_chunk( + traces=chunk, + chunk_start_samples=chunk_start_samples, + left_margin=left_margin, + right_margin=right_margin, + return_residual=return_residual, + 
in_chunk=batch_in_chunk, + ) + assert batch_peel_result["n_spikes"] == i1 - i0 + batch_res = self.featurize_chunk_result( + peel_result=batch_peel_result, + to_cpu=to_cpu, + return_waveforms=return_waveforms, + chunk_start_samples=chunk_start_samples, + chunk_end_samples=chunk_end_samples, + device=chunk.device, + n_resid_snips=n_resid_snips, + ) + batch_results.append(batch_res) + res = _cat_results(batch_results) return res @classmethod @@ -150,21 +180,23 @@ def from_config( def peel_chunk( self, - traces, + traces: torch.Tensor, *, chunk_start_samples=0, left_margin=0, right_margin=0, return_residual=False, return_waveforms=True, + in_chunk: torch.Tensor | None = None, ) -> PeelingBatchResult: assert not return_residual - max_t = chunk_start_samples + self.chunk_length_samples - 1 - in_chunk = self.b.times_samples == self.b.times_samples.clip( - chunk_start_samples, max_t - ) - (in_chunk,) = in_chunk.nonzero(as_tuple=True) + if in_chunk is None: + max_t = chunk_start_samples + self.chunk_length_samples - 1 + in_chunk = self.b.times_samples == self.b.times_samples.clip( + chunk_start_samples, max_t + ) + (in_chunk,) = in_chunk.nonzero(as_tuple=True) if not in_chunk.numel(): return peeling_empty_result @@ -184,7 +216,7 @@ def peel_chunk( channels=channels, ) if return_waveforms: - res["collisioncleaned_waveforms"] = spiketorch.grab_spikes( + res["collisioncleaned_waveforms"] = grab_spikes( traces, times_rel, channels, @@ -198,3 +230,44 @@ def peel_chunk( for k in self.fixed_property_keys: res[k] = getattr(self.b, k)[in_chunk].to(device=dev) return res + + +def _cat_results(results): + # not ideal, should work on typing these things better, but no time now. 
+ assert len(results) > 0 + if len(results) == 1: + return results[0] + rdict: dict[str, int | float | list[torch.Tensor] | list[np.ndarray]] = { + "n_spikes": 0 + } + for res in results: + for k, v in res.items(): + if k == "n_spikes": + rdict[k] += v + elif k == "chunk_center_s": + rdict[k] = v + elif isinstance(v, (int, float)): + if k not in rdict: + rdict[k] = v + else: + assert rdict[k] == v + elif isinstance(v, (torch.Tensor, np.ndarray)): + if k not in rdict: + rdict[k] = [] + rdict[k].append(v) # type: ignore + else: + assert False + res = {} + for k, v in rdict.items(): + if k in ("n_spikes", "chunk_center_s"): + res[k] = v + elif isinstance(v, list): + if isinstance(v[0], torch.Tensor): + res[k] = torch.concatenate(v) # type: ignore + elif isinstance(v[0], np.ndarray): + res[k] = np.concatenate(v) + else: + assert False + else: + res[k] = v + return res diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index ad9ddeb5..5cf95819 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -51,7 +51,8 @@ def __init__( fpctrl_spike_counts=None, fit_sampling_cfg: FitSamplingConfig = default_peeling_fit_sampling_cfg, save_collidedness=False, - whiten_features=True, + whiten_features=False, + whiten_kernel_length=0, parent_sorting_hdf5_path: str | Path | None = None, dtype=torch.float, ): @@ -109,7 +110,9 @@ def __init__( ) self.amp_scale_max = 1.0 + p.amplitude_scaling_boundary self.amp_scale_min = 1.0 / self.amp_scale_max - self.obj_pad_len = max(p.refractory_radius_frames, self.spike_length_samples) + self.whiten_pad = max(0, whiten_kernel_length - 1) + pt = self.spike_length_samples + self.whiten_pad + self.obj_pad_len = max(p.refractory_radius_frames, pt) conv_len = ( self.chunk_length_samples + 2 * self.chunk_margin_samples @@ -192,6 +195,10 @@ def from_config( matching_cfg.precomputed_templates_npz ) assert trough_offset_samples == template_data.trough_offset_samples + if template_data.temporal_kernel is 
None: + whiten_kernel_length = 0 + else: + whiten_kernel_length = template_data.temporal_kernel.shape[0] if motion is None: motion = MotionInfo.from_motion_est(geom=geom.numpy()) @@ -237,6 +244,7 @@ def from_config( parent_sorting_hdf5_path=parent_sorting_hdf5_path, save_collidedness=save_collidedness, whiten_features=matching_cfg.whiten_features, + whiten_kernel_length=whiten_kernel_length, fpctrl_spike_counts=template_data.spike_counts if matching_cfg.threshold == "fp_control" else None, @@ -258,12 +266,17 @@ def peel_chunk( chunk_center_samples = chunk_start_samples + self.chunk_length_samples // 2 segment = self.recording._recording_segments[0] chunk_center_seconds = float(segment.sample_index_to_time(chunk_center_samples)) + if self.whiten_features: + resid_offset = self.whiten_pad + else: + resid_offset = 0 chunk_template_data = self.matching_templates.data_at_time( t_s=chunk_center_seconds, scaling=self.is_scaling, inv_lambda=self.inv_lambda, scale_min=self.amp_scale_min, scale_max=self.amp_scale_max, + resid_offset=resid_offset, ) # deconvolve @@ -279,7 +292,7 @@ def peel_chunk( # process spike times and create return result if match_results["n_spikes"]: - match_results["times_samples"] += chunk_start_samples - left_margin # type: ignore + match_results["times_samples"] += chunk_start_samples - left_margin if match_results["n_spikes"] > self.p.max_spikes_per_second: raise ValueError( f"Too many spikes {match_results['n_spikes']} > {self.p.max_spikes_per_second}." 
@@ -314,8 +327,8 @@ def match_chunk( # name objective variables so that we can update them in-place later # padded objective has an extra unit (for group_index) and refractory # padding (for easier implementation of enforce_refractory) - valid_len = traces.shape[0] - self.spike_length_samples + 1 - padded_obj_len = valid_len + 2 * self.obj_pad_len + valid_len = traces.shape[0] - self.spike_length_samples - self.whiten_pad + 1 + padded_obj_len = valid_len + 2 * self.obj_pad_len + self.whiten_pad padded_conv = traces.new_zeros( chunk_template_data.obj_n_templates, padded_obj_len ) @@ -419,6 +432,7 @@ def match_chunk( if not chunk_template_data.needs_residual: chunk_template_data.subtract(residual_padded, peaks) + assert residual.shape[0] == traces.shape[0] if not peaks.n_spikes: res = PeelingBatchResult(n_spikes=0) if return_residual: @@ -522,7 +536,11 @@ def pick_threshold(self): from scipy.stats import norm # TODO: remove? - if self.is_scaling and self.p.scale_adjusts_threshold and self.p.amplitude_scaling_variance < torch.inf: + if ( + self.is_scaling + and self.p.scale_adjusts_threshold + and self.p.amplitude_scaling_variance < torch.inf + ): # adjust threshold by the scaling prior's constant term # nb, everything is x2 so halves are gone. 
scstd = np.sqrt(self.p.amplitude_scaling_variance) diff --git a/src/dartsort/peel/matching_util/compressed_upsampled.py b/src/dartsort/peel/matching_util/compressed_upsampled.py index 2d4ad3ad..d01738c7 100644 --- a/src/dartsort/peel/matching_util/compressed_upsampled.py +++ b/src/dartsort/peel/matching_util/compressed_upsampled.py @@ -260,7 +260,9 @@ def data_at_time( inv_lambda: float, scale_min: float, scale_max: float, + resid_offset: int, ) -> "CompressedUpsampledChunkTemplateData": + assert not resid_offset if self.drifting: shifts, padded_spatial_sing = templates_at_time( t_s=t_s, @@ -300,6 +302,7 @@ def data_at_time( return CompressedUpsampledChunkTemplateData( coarse_objective=self.coarse_objective, + resid_offset=resid_offset, grouping=self.have_groups, upsampling=self.upsampling, scaling=scaling, @@ -308,6 +311,7 @@ def data_at_time( n_templates=self.n_templates, obj_n_templates=self.obj_n_templates, spike_length_samples=self.spike_length_samples, + filter_length_samples=self.spike_length_samples, up_factor=self.b.cup_index.shape[1], inv_lambda=torch.tensor(inv_lambda, device=normsq.device), scale_min=torch.tensor(scale_min, device=normsq.device), @@ -354,6 +358,8 @@ class CompressedUpsampledChunkTemplateData(ChunkTemplateData): inv_lambda: Tensor scale_min: Tensor scale_max: Tensor + resid_offset: int + filter_length_samples: int # objective props obj_normsq: Tensor @@ -458,7 +464,12 @@ def subtract(self, traces, peaks, sign=-1): ) def fine_match( - self, *, peaks: MatchingPeaks, residual: Tensor | None, conv: Tensor, padding: int = 0 + self, + *, + peaks: MatchingPeaks, + residual: Tensor | None, + conv: Tensor, + padding: int = 0, ): """Determine superres ids, temporal upsampling, and scaling diff --git a/src/dartsort/peel/matching_util/drifty.py b/src/dartsort/peel/matching_util/drifty.py index 426426cd..c7195d86 100644 --- a/src/dartsort/peel/matching_util/drifty.py +++ b/src/dartsort/peel/matching_util/drifty.py @@ -55,7 +55,7 @@ from 
...util.job_util import ensure_computation_config from ...util.logging_util import get_logger from ...util.motion import MotionInfo -from ...util.noise_util import SpatialWhitener +from ...util.noise_util import Whitener from ...util.py_util import databag from ...util.spiketorch import full_shared_pconv, shared_temporal_pconv from ...util.torch_util import torch_compiler @@ -84,7 +84,7 @@ def __init__( trough_offset_samples: int, unit_ids: Tensor | None = None, whiten_strategy: WhiteningStrategy = "none", - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, whiten_features: bool = True, up_factor: int = 1, up_method: Literal["interpolation", "keys3", "keys4", "direct"] = "keys4", @@ -138,20 +138,31 @@ def __init__( up_temporal_comps = upsample_singlechan_torch( temporal_comps, temporal_jitter=up_factor ) - tconv = shared_temporal_pconv(temporal_comps, up_temporal_comps) + if whitener is not None: + conv_temporal_comps = whitener._convolve(temporal_comps, padding="full") + conv_up_temporal_comps = whitener._convolve( + up_temporal_comps, padding="full" + ) + norm_discount = torch.linalg.vector_norm( + conv_temporal_comps, dim=1, keepdim=True + ) + else: + conv_temporal_comps = temporal_comps + conv_up_temporal_comps = up_temporal_comps + norm_discount = None + tconv = shared_temporal_pconv(conv_temporal_comps, conv_up_temporal_comps) assert temporal_comps.shape == (rank, self.spike_length_samples) + self.register_buffer("conv_temporal_comps", conv_temporal_comps.contiguous()) + self.register_buffer_or_none("norm_discount", norm_discount) self.register_buffer("temporal_comps", temporal_comps.contiguous()) assert up_temporal_comps.shape == (rank, up_factor, self.spike_length_samples) self.register_buffer_or_none( "up_temporal_comps", up_temporal_comps.contiguous() ) - if up_temporal_comps is not None: - up_major_temporal_comps = up_temporal_comps.permute(1, 0, 2).contiguous() - else: - up_major_temporal_comps = None + up_major_temporal_comps 
= up_temporal_comps.permute(1, 0, 2).contiguous() self.register_buffer_or_none("up_major_temporal_comps", up_major_temporal_comps) - self.register_buffer("spatial_sing", spatial_sing) + self.register_buffer("spatial_sing", spatial_sing.contiguous()) self.register_buffer("tconv", tconv) if unit_ids is None: unit_ids = torch.arange(self.n_units, device=device) @@ -160,11 +171,13 @@ def __init__( if self.whiten_strategy == "postwhiten": assert whitener is not None self.whitener = whitener.to(spatial_sing.device) - conv_spatial_sing = self.whitener.transpose_whiten(spatial_sing) + conv_spatial_sing = self.whitener.transpose_whiten( + spatial_sing, spatial_only=True + ) elif self.whiten_strategy == "prewhiten_postapply" and not self.interpolating: assert whitener is not None self.whitener = whitener.to(spatial_sing.device) - conv_spatial_sing = self.whitener.whiten(spatial_sing) + conv_spatial_sing = self.whitener.whiten(spatial_sing, spatial_only=True) elif self.whiten_strategy == "prewhiten_postapply": assert whitener is not None self.whitener = whitener.to(spatial_sing.device) @@ -177,6 +190,8 @@ def __init__( self.whitener = conv_spatial_sing = None else: assert False + if conv_spatial_sing is not None: + conv_spatial_sing = conv_spatial_sing.contiguous() self.register_buffer_or_none("conv_spatial_sing", conv_spatial_sing) # full pconv can be precomputed when not interpolating @@ -194,11 +209,16 @@ def __init__( # indexing helpers t = self.spike_length_samples rr = refractory_radius_frames + if self.whitener is not None: + ct = t + max(0, self.whitener.temporal_length - 1) + else: + ct = t self.register_buffer("refrac_ix", torch.arange(-rr, rr + 1, device=device)) - self.register_buffer("time_ix", torch.arange(t, device=device)) + self.register_buffer("time_ix", torch.arange(ct, device=device)) + self.register_buffer("sub_time_ix", torch.arange(t, device=device)) self.register_buffer("chan_ix", torch.arange(n_channels, device=device)) self.register_buffer("rank_ix", 
torch.arange(rank, device=device)) - self.register_buffer("conv_lags", torch.arange(-t + 1, t, device=device)) + self.register_buffer("conv_lags", torch.arange(-ct + 1, ct, device=device)) offset = torch.asarray(trough_offset_samples, device=device) self.register_buffer("trough_offset_samples", offset) @@ -226,6 +246,7 @@ def _from_config( min_channel_amplitude=matching_cfg.template_min_channel_amplitude, rank=matching_cfg.template_svd_compression_rank, computation_cfg=computation_cfg, + min_explained_variance=matching_cfg.template_svd_compression_min_explained_variance, ) temporal_comps = torch.asarray(shared_basis_temps.temporal_components) spatial_sing = torch.asarray(shared_basis_temps.spatial_singular) @@ -236,7 +257,14 @@ def _from_config( whitener = None else: assert template_data.whitener is not None - whitener = SpatialWhitener.from_numpy(template_data.whitener) + assert template_data.covariance is not None + if template_data.temporal_kernel is not None: + tk = template_data.temporal_kernel + else: + tk = None + whitener = Whitener.from_numpy( + template_data.whitener, template_data.covariance, tk + ) if not wh_none and not matching_cfg.whiten_features: assert matching_cfg.whitening.strategy == "prewhiten_postapply" @@ -264,6 +292,10 @@ def interp_at_time(self, t_s: float, x: Tensor) -> Tensor: return self.erp.interp_at_time(t_s=t_s, waveforms=x) def spatial_at_time(self, t_s: float) -> tuple[Tensor, ...]: + """Get spatial components and norms at current chunk + + This handles the spatial whitening strategy logic. 
+ """ if self.whiten_strategy == "postwhiten": spatial_sing = self.interp_at_time(t_s, self.b.spatial_sing) normsq_spatial_sing = spatial_sing @@ -273,11 +305,19 @@ def spatial_at_time(self, t_s: float) -> tuple[Tensor, ...]: conv_spatial_sing = normsq_spatial_sing = spatial_sing elif self.whiten_strategy == "prewhiten_postapply": assert self.whitener is not None + + # features / cc waveforms just use interpolated plain template spatial_sing = self.interp_at_time(t_s, self.b.spatial_sing) + + # for convolution, whitening is applied if self.interpolating: - conv_spatial_sing = self.whitener.whiten(spatial_sing) + conv_spatial_sing = self.whitener.whiten( + spatial_sing, spatial_only=True + ) else: conv_spatial_sing = self.b.conv_spatial_sing + + # this is usually false if self.whiten_features: spatial_sing = conv_spatial_sing normsq_spatial_sing = conv_spatial_sing @@ -285,7 +325,11 @@ def spatial_at_time(self, t_s: float) -> tuple[Tensor, ...]: assert False # normsq for channel selection from original - normsq_by_chan = normsq_spatial_sing.square().sum(dim=1) + norm_discount = self.b.norm_discount + if norm_discount is None: + normsq_by_chan = normsq_spatial_sing.square().sum(dim=1) + else: + normsq_by_chan = (norm_discount * normsq_spatial_sing).square_().sum(dim=1) main_channels = normsq_by_chan.argmax(dim=1) normsq = normsq_by_chan.sum(dim=1) @@ -307,12 +351,15 @@ def data_at_time( inv_lambda: float, scale_min: float, scale_max: float, + resid_offset: int, ) -> ChunkTemplateData: spatial_sing, normsq, main_channels, padded_spatial_sing, pconv = ( self.spatial_at_time(t_s=t_s) ) return DriftyChunkTemplateData( spike_length_samples=self.spike_length_samples, + filter_length_samples=self.b.conv_temporal_comps.shape[1], + resid_offset=resid_offset, unit_ids=self.b.unit_ids, main_channels=main_channels, obj_normsq=normsq, @@ -326,11 +373,12 @@ def data_at_time( inv_lambda=torch.asarray(inv_lambda).to(normsq, non_blocking=True), 
scale_min=torch.asarray(scale_min).to(normsq, non_blocking=True), scale_max=torch.asarray(scale_max).to(normsq, non_blocking=True), - temporal_comps=self.b.temporal_comps, + conv_temporal_comps=self.b.conv_temporal_comps, up_major_temporal_comps=self.b.up_major_temporal_comps, spatial_sing=spatial_sing, pconv=pconv, time_ix=self.b.time_ix, + sub_time_ix=self.b.sub_time_ix + resid_offset, chan_ix=self.b.chan_ix, rank_ix=self.b.rank_ix, conv_lags=self.b.conv_lags, @@ -343,6 +391,8 @@ def data_at_time( @databag class DriftyChunkTemplateData(ChunkTemplateData): spike_length_samples: int + filter_length_samples: int + resid_offset: int # for the full templates unit_ids: Tensor main_channels: Tensor @@ -358,14 +408,15 @@ class DriftyChunkTemplateData(ChunkTemplateData): scale_min: Tensor scale_max: Tensor - temporal_comps: Tensor + conv_temporal_comps: Tensor up_major_temporal_comps: Tensor spatial_sing: Tensor padded_spatial_sing: Tensor pconv: Tensor - spatial_whitener: SpatialWhitener | None + spatial_whitener: Whitener | None time_ix: Tensor + sub_time_ix: Tensor chan_ix: Tensor rank_ix: Tensor refrac_ix: Tensor @@ -381,13 +432,13 @@ def convolve(self, traces: Tensor, padding: int = 0, out: Tensor | None = None): Odd to have batch come second, but it makes things simpler here (and it's only something used for debugging: usually traces.ndim == 2). 
""" - out_len = traces.shape[-1] + 2 * padding - self.spike_length_samples + 1 + out_len = traces.shape[-1] + 2 * padding - self.filter_length_samples + 1 if out is not None: assert out.shape == (self.obj_n_templates, out_len) conv = convolve_lowrank_shared( traces=traces, spatial_singular=self.spatial_sing, - temporal_components=self.temporal_comps, + temporal_components=self.conv_temporal_comps, padding=padding, out=out, ) @@ -399,8 +450,8 @@ def score(self, spikes: Tensor) -> Tensor: n, t, c = spikes.shape spikes = self.whiten_traces(spikes) spikes_t = spikes.mT.reshape(n * c, t) - rank = self.temporal_comps.shape[0] - tconv = spikes_t @ self.temporal_comps.T + rank = self.conv_temporal_comps.shape[0] + tconv = spikes_t @ self.conv_temporal_comps.T spatial_t = self.spatial_sing.permute(2, 1, 0) # now chan,rank,unit conv = tconv.view(n, c * rank) @ spatial_t.reshape( c * rank, self.obj_n_templates @@ -418,27 +469,16 @@ def subtract(self, traces: Tensor, peaks: "MatchingPeaks", sign: int = -1): return assert peaks.times is not None assert peaks.template_inds is not None - - if peaks.up_inds is None: - tempc = self.temporal_comps[None] - else: - tempc = self.up_major_temporal_comps - assert sign in (-1, 1) - tempc = ( - self.temporal_comps - if peaks.up_inds is None - else self.up_major_temporal_comps - ) _subtract_templates_loop( traces=traces, up_inds=peaks.up_inds, scalings=peaks.scalings, template_inds=peaks.template_inds, - tempc=tempc, + tempc=self.up_major_temporal_comps, spatc=self.padded_spatial_sing, times=peaks.times, - time_ix=self.time_ix, + time_ix=self.sub_time_ix, neg=sign == -1, ) @@ -476,7 +516,7 @@ def get_clean_waveforms( if peaks.up_inds is not None: tempc = self.up_major_temporal_comps[peaks.up_inds] else: - tempc = self.temporal_comps + tempc = self.up_major_temporal_comps[0] tempc = tempc[None].broadcast_to(peaks.n_spikes, *tempc.shape) if add_into is None: return tempc.mT.bmm(spatial) @@ -548,7 +588,9 @@ def reconstruct_up_templates(self): 
def whiten_traces(self, traces: Tensor, out: Tensor | None = None): if self.prewhiten: assert self.spatial_whitener is not None - return self.spatial_whitener.whiten_traces_spatial_major(traces, out=out) + return self.spatial_whitener.whiten_traces_spatial_major( + traces, out=out, padding="full" + ) elif out is not None: return out.copy_(traces.T) else: diff --git a/src/dartsort/peel/matching_util/matching_base.py b/src/dartsort/peel/matching_util/matching_base.py index 9f5e6389..f487dc1d 100644 --- a/src/dartsort/peel/matching_util/matching_base.py +++ b/src/dartsort/peel/matching_util/matching_base.py @@ -83,6 +83,7 @@ def data_at_time( inv_lambda: float, scale_min: float, scale_max: float, + resid_offset: int, ) -> "ChunkTemplateData": raise NotImplementedError @@ -129,6 +130,9 @@ def spike_length_samples(self) -> int: class ChunkTemplateData: # -- subclasses must assign the following properties that the matcher uses. spike_length_samples: int + filter_length_samples: int + resid_offset: int + # for the full templates unit_ids: Tensor main_channels: Tensor @@ -260,7 +264,7 @@ def coarse_match( assert nt > 2 * padding times = argrelmax_dedup( x=objective_max, - dedup_radius=self.spike_length_samples, + dedup_radius=self.filter_length_samples, threshold=thresholdsq, arange=obj_arange[:nt], padding=padding, @@ -329,12 +333,13 @@ def get_collisioncleaned_waveforms( assert times is not None # get noise + # TODO check the offset is correct waveforms = grab_spikes( residual_padded, times, channels, sel_ci, - trough_offset=0, + trough_offset=self.resid_offset, spike_length_samples=self.spike_length_samples, buffer=0, already_padded=True, diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index 7f3255bf..6b5ca81e 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -1,4 +1,5 @@ """Home of the BasePeeler, where shared logic for other modules here lives.""" + import gc import tempfile from concurrent.futures import 
CancelledError @@ -154,6 +155,7 @@ def load_or_fit_and_save_models( cleanup_and_log_gpu_usage(computation_cfg, "peel: Usage after model fits:") assert not self.needs_precompute() assert not self.needs_fit() + self.post_fit() def peel( self, @@ -242,7 +244,7 @@ def peel( resids_remaining = total_residual_snips - resids_so_far chunks_remaining = len(chunks_to_do) chunks_done = n_chunks_orig - chunks_remaining - chunks_cover = int(np.floor(ensure_coverage * n_chunks_orig)) + chunks_cover = int(np.ceil(ensure_coverage * n_chunks_orig)) chunks_cover_remaining = chunks_cover - chunks_done if chunks_cover_remaining == 0: assert resids_remaining == 0 @@ -402,7 +404,7 @@ def peel( def peel_chunk( self, - traces, + traces: torch.Tensor, *, chunk_start_samples=0, left_margin=0, @@ -431,6 +433,9 @@ def peeling_needs_fit(self) -> bool: def peeling_needs_precompute(self) -> bool: return False + def post_fit(self): + pass + def precompute_peeling_data( self, save_folder, overwrite=False, computation_cfg=None ): @@ -496,9 +501,35 @@ def process_chunk( Main function called in peeling workers """ + chunk, chunk_end_samples, left_margin, right_margin = self.get_chunk( + chunk_start_samples, chunk_end_samples + ) + return_waveforms = not skip_features and bool(self.featurization_pipeline) + peel_result = self.peel_chunk( + chunk, + chunk_start_samples=chunk_start_samples, + left_margin=left_margin, + right_margin=right_margin, + return_waveforms=return_waveforms, + return_residual=return_residual or bool(n_resid_snips), + ) + chunk_result = self.featurize_chunk_result( + peel_result=peel_result, + to_cpu=to_cpu, + return_waveforms=return_waveforms, + chunk_start_samples=chunk_start_samples, + chunk_end_samples=chunk_end_samples, + device=chunk.device, + n_resid_snips=n_resid_snips, + ) + return chunk_result + + def get_chunk(self, chunk_start_samples: int, chunk_end_samples: int | None = None): + Ts = self.recording.get_num_samples() if chunk_end_samples is None: chunk_end_samples = 
chunk_start_samples + self.chunk_length_samples - chunk_end_samples = min(self.recording.get_num_samples(), chunk_end_samples) + chunk_end_samples = min(Ts, chunk_end_samples) + assert chunk_end_samples <= Ts chunk, left_margin, right_margin = get_chunk_with_margin( self.recording._recording_segments[0], start_frame=chunk_start_samples, @@ -508,15 +539,19 @@ def process_chunk( ) device = self.b.channel_index.device chunk = torch.tensor(chunk, device=device, dtype=self.dtype) - return_waveforms = not skip_features and bool(self.featurization_pipeline) - peel_result = self.peel_chunk( - chunk, - chunk_start_samples=chunk_start_samples, - left_margin=left_margin, - right_margin=right_margin, - return_waveforms=return_waveforms, - return_residual=return_residual or bool(n_resid_snips), - ) + return chunk, chunk_end_samples, left_margin, right_margin + + def featurize_chunk_result( + self, + *, + peel_result, + to_cpu: bool, + return_waveforms: bool, + chunk_start_samples: int, + chunk_end_samples: int, + device: torch.device, + n_resid_snips: int | None, + ): if peel_result["n_spikes"] > 0 and to_cpu: t_s = self.recording.sample_index_to_time( peel_result["times_samples"].numpy(force=True) @@ -539,12 +574,13 @@ def process_chunk( # a user who wants these must featurize with a waveform node # then they'll end up in `features` if "collisioncleaned_waveforms" in peel_result: - del peel_result["collisioncleaned_waveforms"] # type: ignore + del peel_result["collisioncleaned_waveforms"] chunk_result = peel_result | features if to_cpu: chunk_result = { - k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in chunk_result.items() + k: v.cpu() if isinstance(v, torch.Tensor) else v + for k, v in chunk_result.items() } if "residual" in peel_result: if torch.is_tensor(peel_result["residual"]): @@ -567,7 +603,7 @@ def process_chunk( self.recording.sample_index_to_time(resid_times_samples) ) - return chunk_result # type: ignore + return chunk_result def gather_chunk_result( 
self, @@ -1175,6 +1211,7 @@ def __init__( self.skip_features = skip_features self.chunk_length_samples = chunk_length_samples self.to_cpu = to_cpu + self.Ts = peeler.recording.get_num_samples() # this state will be set on each thread globally @@ -1220,10 +1257,12 @@ def _peeler_process_job(chunk_start_samples__n_resid_snips): # by returning here, we are implicitly relying on pickle # TODO: replace with manual np.saves with torch.no_grad(): - chunk_end_samples = None chlen = _peeler_process_context.ctx.chunk_length_samples if chlen is not None: chunk_end_samples = chunk_start_samples + chlen + chunk_end_samples = min(chunk_end_samples, _peeler_process_context.ctx.Ts) + else: + chunk_end_samples = None return _peeler_process_context.ctx.peeler.process_chunk( chunk_start_samples, n_resid_snips=n_resid_snips, diff --git a/src/dartsort/peel/peel_lib.py b/src/dartsort/peel/peel_lib.py index ca4e5793..39d8af54 100644 --- a/src/dartsort/peel/peel_lib.py +++ b/src/dartsort/peel/peel_lib.py @@ -25,19 +25,20 @@ if TYPE_CHECKING: from ..transform.pipeline import WaveformPipeline + from ..util.internal_config import PeakSign def denoiser_time_shifts( - waveforms, - channels, - voltages, - subtract_rel_inds, - trough_offset_samples, - spike_length_samples, - peak_sign, - denoiser_realignment_shift, - denoiser_realignment_channel, -): + waveforms: Tensor, + channels: Tensor, + voltages: Tensor, + subtract_rel_inds: Tensor | None, + trough_offset_samples: int, + spike_length_samples: int, + peak_sign: "PeakSign", + denoiser_realignment_shift: int, + denoiser_realignment_channel: Literal["detection", "denoised"], +) -> Tensor: # extract main channel traces if denoiser_realignment_channel == "detection": assert subtract_rel_inds is not None @@ -74,15 +75,43 @@ def denoiser_time_shifts( def check_residual_decrease( - orig_wfs, - dn_wfs, - decrease_objective="deconv", + orig_wfs: Tensor | None, + dn_wfs: Tensor, + decrease_objective: Literal["deconv", "norm", "normsq"] = "deconv", 
threshold=10.0, save_residnorm_decrease=False, overwrite_orig_waveforms: bool = False, -): + local_whiteners: Tensor | None = None, + whitening_kernel: Tensor | None = None, + channels: Tensor | None = None, +) -> tuple[Tensor | None, dict[str, Tensor]]: if not threshold: return None, {} + assert orig_wfs is not None + + if local_whiteners is not None: + assert channels is not None + W = local_whiteners[channels] + + # remove nans + if overwrite_orig_waveforms: + orig_wfs = orig_wfs.nan_to_num_() + else: + orig_wfs = orig_wfs.nan_to_num() + dn_wfs = dn_wfs.nan_to_num() + + # spatial mul -- putting temporal dim last here + orig_wfs = W.bmm(orig_wfs.mT) + dn_wfs = W.bmm(dn_wfs.mT) + + # temporal conv if needed + if whitening_kernel is not None: + *shp, t = orig_wfs.shape + k = whitening_kernel[None, None] + orig_wfs = F.conv1d(orig_wfs.view(-1, 1, t), k, padding="same") + dn_wfs = F.conv1d(dn_wfs.view(-1, 1, t), k, padding="same") + orig_wfs = orig_wfs.view(*shp, t) + dn_wfs = dn_wfs.view(*shp, t) if decrease_objective == "deconv": if overwrite_orig_waveforms: @@ -92,7 +121,7 @@ def check_residual_decrease( norm = buf.nan_to_num_().sum(dim=(1, 2)) else: dn_wfs = dn_wfs.nan_to_num() - conv = (orig_wfs * dn_wfs).sum(dim=(1, 2)) + conv = (orig_wfs * dn_wfs).nan_to_num_().sum(dim=(1, 2)) norm = dn_wfs.square_().sum(dim=(1, 2)) reduction = conv.mul_(2.0).sub_(norm) threshold = threshold**2 @@ -132,11 +161,11 @@ def check_residual_decrease( def subtract_chunk( - traces, - channel_index, - denoising_pipeline, - extract_index=None, - extract_mask=None, + traces: Tensor, + channel_index: Tensor, + denoising_pipeline: "WaveformPipeline", + extract_index: Tensor | None = None, + extract_mask: Tensor | None = None, trough_offset_samples=42, spike_length_samples=121, left_margin=0, @@ -148,12 +177,14 @@ def subtract_chunk( denoiser_realignment_channel="detection", convexity_threshold=None, convexity_radius=3, - peak_channel_index=None, - dedup_channel_index=None, - 
subtract_rel_inds=None, - dedup_rel_inds=None, + peak_channel_index: Tensor | None = None, + dedup_channel_index: Tensor | None = None, + subtract_rel_inds: Tensor | None = None, + dedup_rel_inds: Tensor | None = None, residnorm_decrease_threshold=16.0, decrease_objective: Literal["norm", "normsq", "deconv"] = "deconv", + local_whiteners: Tensor | None = None, + whitening_kernel: Tensor | None = None, relative_peak_radius=5, dedup_temporal_radius=7, remove_exact_duplicates=True, @@ -161,13 +192,13 @@ def subtract_chunk( dedup_batch_size=512, no_subtraction=False, max_iter=100, - trough_priority=None, - growth_tolerance=None, + trough_priority: float | None = None, + growth_tolerance: float | None = None, cumulant_order=None, save_iteration=False, save_residnorm_decrease=False, compute_collidedness=False, -): +) -> ChunkSubtractionResult: """Core peeling routine for subtraction""" if no_subtraction: threshold_res = threshold_chunk( @@ -223,7 +254,7 @@ def subtract_chunk( if growth_tolerance is not None: gtol = traces.abs().add_(growth_tolerance) else: - gtol = 0.0 + gtol = None # initialize residual, it needs to be padded to support # our channel indexing convention. this copies the input. 
@@ -247,7 +278,7 @@ def subtract_chunk( for it in range(max_iter): residual_det = residual[:, :-1] - if it and growth_tolerance is not None: + if it and gtol is not None: residual_det = residual_det.clamp(-gtol, gtol) times_samples, channels = detect_and_deduplicate( @@ -353,6 +384,9 @@ def subtract_chunk( decrease_objective=decrease_objective, threshold=residnorm_decrease_threshold, save_residnorm_decrease=save_residnorm_decrease, + local_whiteners=local_whiteners, + whitening_kernel=whitening_kernel, + channels=channels, ) features.update(new_feats) if resid_keep is not None: diff --git a/src/dartsort/peel/reduction_template.py b/src/dartsort/peel/reduction_template.py index 14d790ff..8ddd81ea 100644 --- a/src/dartsort/peel/reduction_template.py +++ b/src/dartsort/peel/reduction_template.py @@ -30,7 +30,7 @@ from ..util.job_util import ensure_computation_config from ..util.logging_util import get_logger from ..util.motion import MotionInfo -from ..util.noise_util import SpatialWhitener +from ..util.noise_util import Whitener from ..util.py_util import ensure_path from ..util.waveform_util import full_channel_index from .grab import GrabAndFeaturize @@ -54,7 +54,7 @@ def _from_config( waveform_cfg: WaveformConfig = default_waveform_cfg, motion: MotionInfo, tsvd=None, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, computation_cfg: ComputationConfig | None = None, show_progress: bool = True, ) -> TemplateData: @@ -155,9 +155,9 @@ def _from_config( templates *= msk if whitener is None: - whitener_np = None + whitener_np = covariance_np = tk_np = None else: - whitener_np = whitener.to_numpy() + whitener_np, covariance_np, tk_np = whitener.to_numpy() return TemplateData( unit_ids=unit_ids, @@ -169,6 +169,8 @@ def _from_config( trough_offset_samples=trough, tsvd=p.temporal_svd(), whitener=whitener_np, + covariance=covariance_np, + temporal_kernel=tk_np, sampling_frequency=recording.sampling_frequency, 
whiten_strategy=template_cfg.whitening.strategy, ) @@ -189,7 +191,7 @@ def from_config( # type: ignore waveform_cfg: WaveformConfig, template_cfg: TemplateConfig, computation_cfg: ComputationConfig, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, ): # geom processing rgeom = torch.asarray(motion.rgeom) @@ -205,12 +207,14 @@ def from_config( # type: ignore pad_spike_len = padded_waveform_cfg.spike_length_samples( recording.sampling_frequency ) + rank = template_cfg.denoising_rank if template_cfg.use_svd and tsvd is not None: if isinstance(tsvd, FullProbeTemporalPCAEmbedder): if do_align: raise ValueError("Haven't handled svd alignment in this case.") else: - assert tsvd.components_.shape[0] == template_cfg.denoising_rank + assert tsvd.components_.shape[0] <= rank + rank = tsvd.components_.shape[0] tsvd = FullProbeTemporalPCAEmbedder.from_sklearn( channel_index=channel_index, pca=tsvd, @@ -229,7 +233,8 @@ def from_config( # type: ignore waveform_cfg=waveform_cfg, computation_cfg=computation_cfg, ) - assert tsvd.components_.shape[0] == template_cfg.denoising_rank + assert tsvd.components_.shape[0] <= rank + rank = tsvd.components_.shape[0] tsvd = FullProbeTemporalPCAEmbedder.from_sklearn( channel_index=channel_index, pca=tsvd, @@ -242,7 +247,7 @@ def from_config( # type: ignore elif template_cfg.use_svd: tsvd = FullProbeTemporalPCAEmbedder( channel_index=channel_index, - rank=template_cfg.denoising_rank, + rank=rank, geom=geom, fit_radius=template_cfg.denoising_fit_radius, max_waveforms=template_cfg.denoising_fit_sampling_cfg.n_waveforms_fit, @@ -320,7 +325,7 @@ def from_config( # type: ignore name_prefix="svd", with_raw_std_dev=False, n_units=sorting.n_units, - feature_dim=template_cfg.denoising_rank, + feature_dim=rank, output_channels=len(rgeom), reduction=template_cfg.reduction, ) diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py index 44fa3279..26759406 100644 --- a/src/dartsort/peel/subtract.py +++ 
b/src/dartsort/peel/subtract.py @@ -9,7 +9,7 @@ import torch from spikeinterface.core import BaseRecording -from ..transform import Voltage, Waveform, WaveformPipeline +from ..transform import Voltage, Waveform, WaveformPipeline, WaveformWhitener from ..util import job_util from ..util.data_util import SpikeDataset, subsample_waveforms from ..util.internal_config import ( @@ -72,6 +72,10 @@ def __init__( self.save_residnorm_decrease = save_residnorm_decrease self.save_collidedness = save_collidedness self.dedup_batch_size = self.nearest_batch_length() + if self.p.whiten: + self.threshold = self.p.threshold_before_whitening + else: + self.threshold = self.p.residnorm_decrease_threshold geom = recording.get_channel_locations() sub_channel_index = make_channel_index( @@ -125,6 +129,10 @@ def __init__( can_thin = recording.get_total_duration() > fit_sampling_cfg.n_seconds_fit / _p self.first_denoiser_thinning = p.first_denoiser_thinning if can_thin else 0.0 + # this may be overwritten after featurization fit + self.register_buffer_or_none("local_whiteners", None) + self.register_buffer_or_none("whitening_kernel", None) + def out_datasets(self): datasets = super().out_datasets() @@ -148,17 +156,38 @@ def peeling_needs_fit(self): def peeling_needs_precompute(self): return self.subtraction_denoising_pipeline.needs_precompute() - def save_models(self, save_folder): - super().save_models(save_folder) + def post_fit(self): + if not self.p.whiten: + return + assert self.featurization_pipeline is not None + assert not self.featurization_pipeline.needs_fit() + whitener = [ + f + for f in self.featurization_pipeline.transformers + if isinstance(f, WaveformWhitener) + ] + assert len(whitener) == 1 + whitener = whitener[0].whitener + assert whitener is not None + local_whiteners = whitener.local_whiteners(self.sub_channel_index) # type: ignore + self.del_none_buffer("local_whiteners") + self.register_buffer("local_whiteners", local_whiteners) + if whitener.temporal: + 
self.del_none_buffer("whitening_kernel") + self.register_buffer("whitening_kernel", whitener.b.temporal_kernel.clone()) + self.threshold = self.p.residnorm_decrease_threshold + + def save_models(self, save_folder: str | Path): sub_denoise_pt = Path(save_folder) / "subtraction_denoising_pipeline.pt" torch.save(self.subtraction_denoising_pipeline.state_dict(), sub_denoise_pt) + super().save_models(save_folder) - def load_models(self, save_folder): - super().load_models(save_folder) + def load_models(self, save_folder: str | Path): sub_denoise_pt = Path(save_folder) / "subtraction_denoising_pipeline.pt" if sub_denoise_pt.exists(): state_dict = torch.load(sub_denoise_pt, weights_only=True) self.subtraction_denoising_pipeline.load_state_dict(state_dict) + super().load_models(save_folder) @classmethod def from_config( @@ -179,6 +208,15 @@ def from_config( geom, subtraction_cfg.subtract_radius_um, to_torch=True ) + # handle whitener fitting + if subtraction_cfg.whiten: + assert subtraction_cfg.whiten_cfg is not None + featurization_cfg = replace( + featurization_cfg, + fit_disabled_whitener=True, + whiten_cfg=subtraction_cfg.whiten_cfg, + ) + # construct denoising and featurization pipelines subtraction_denoising_pipeline = WaveformPipeline.from_config( geom=geom, @@ -223,12 +261,12 @@ def peel_chunk( ): del return_waveforms # always done here - extract_index = None if self.extract_subtract_same else self.channel_index + extract_index = None if self.extract_subtract_same else self.b.channel_index traces = traces.to(self.dtype) subtraction_result = subtract_chunk( traces, - self.sub_channel_index, + self.b.sub_channel_index, self.subtraction_denoising_pipeline, extract_index=extract_index, extract_mask=self.extract_subtract_mask, @@ -245,7 +283,7 @@ def peel_chunk( dedup_temporal_radius=self.p.temporal_dedup_radius_samples, remove_exact_duplicates=self.p.remove_exact_duplicates, pos_dedup_temporal_radius=self.p.positive_temporal_dedup_radius_samples, - 
residnorm_decrease_threshold=self.p.residnorm_decrease_threshold, + residnorm_decrease_threshold=self.threshold, decrease_objective=self.p.decrease_objective, trough_priority=self.p.trough_priority, growth_tolerance=self.p.growth_tolerance, @@ -255,12 +293,14 @@ def peel_chunk( save_iteration=self.save_iteration, save_residnorm_decrease=self.save_residnorm_decrease, max_iter=self.p.max_iter, - subtract_rel_inds=self.subtract_index_rel_inds, - dedup_rel_inds=self.dedup_rel_inds, + subtract_rel_inds=self.b.subtract_index_rel_inds, + dedup_rel_inds=self.b.dedup_rel_inds, realign_to_denoiser=self.p.realign_to_denoiser, denoiser_realignment_shift=self.p.denoiser_realignment_shift, denoiser_realignment_channel=self.p.denoiser_realignment_channel, compute_collidedness=self.save_collidedness, + local_whiteners=self.b.local_whiteners, + whitening_kernel=self.b.whitening_kernel, ) # add in chunk_start_samples @@ -312,6 +352,7 @@ def fit_peeler_models(self, save_folder, tmp_dir=None, computation_cfg=None): gc.collect() torch.cuda.empty_cache() + self.save_models(save_folder=save_folder) def _fit_subtraction_transformers( self, save_folder, tmp_dir=None, computation_cfg=None, which="denoisers" diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py index 84a93b31..3df6965b 100644 --- a/src/dartsort/templates/get_templates.py +++ b/src/dartsort/templates/get_templates.py @@ -27,7 +27,7 @@ from ..util.logging_util import get_logger, progbar from ..util.motion import MotionInfo from ..util.multiprocessing_util import get_pool -from ..util.noise_util import SpatialWhitener +from ..util.noise_util import Whitener from ..util.spiketorch import fast_nanmedian, nanmean, ptp from .templib import denoising_weights, fit_tsvd @@ -49,7 +49,7 @@ def _from_config( waveform_cfg: WaveformConfig = default_waveform_cfg, motion: MotionInfo, tsvd=None, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, computation_cfg: 
ComputationConfig | None = None, show_progress: bool = True, ) -> TemplateData: diff --git a/src/dartsort/templates/postprocess_util.py b/src/dartsort/templates/postprocess_util.py index e73d9841..11aaa10d 100644 --- a/src/dartsort/templates/postprocess_util.py +++ b/src/dartsort/templates/postprocess_util.py @@ -25,7 +25,7 @@ from ..util.job_util import ensure_computation_config from ..util.logging_util import get_logger from ..util.motion import MotionInfo -from ..util.noise_util import SpatialWhitener +from ..util.noise_util import Whitener from ..util.py_util import ensure_path from ..util.spiketorch import ptp from . import TemplateData, realign @@ -49,7 +49,7 @@ def estimate_template_library( realign_cfg: TemplateRealignmentConfig | None = None, template_merge_cfg: TemplateMergeConfig | None = None, tsvd: PCA | TruncatedSVD | None = None, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, computation_cfg: ComputationConfig | None = None, fit_featurization_tsvd: bool = False, featurization_cfg: FeaturizationConfig | None = None, @@ -384,8 +384,6 @@ def _handle_merge( if merge_cfg is None or not merge_cfg.merge_distance_threshold: return sorting, template_data - from ..clustering.merge import merge_templates - if template_cfg.denoising_method == "svd": # use new shared basis stuff from ..clustering.agglomerate import agglomerate @@ -406,6 +404,8 @@ def _handle_merge( del agg else: # TODO: remove old impl? 
+ from ..clustering.merge import merge_templates + merge_shift_samples = waveform_cfg.ms_to_samples(merge_cfg.max_shift_ms) merge_res = merge_templates( sorting=sorting, diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index b2ef1253..36cb5b02 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -275,6 +275,7 @@ def shared_basis_compress_templates( rank=5, computation_cfg=None, precomputed_basis: np.ndarray | None = None, + min_explained_variance: float = 5e-3, with_r2: bool = False, ) -> SharedBasisTemplates: computation_cfg = ensure_computation_config(computation_cfg) @@ -292,7 +293,7 @@ def shared_basis_compress_templates( rank = min(rank, t) if precomputed_basis is None: temporal_comps = get_shared_temporal_basis( - templates, rank, dev, min_channel_amplitude + templates, rank, dev, min_channel_amplitude, min_explained_variance ) assert temporal_comps.shape[1] == t assert temporal_comps.ndim == 2 @@ -341,6 +342,7 @@ def get_shared_temporal_basis( rank: int, device: torch.device, min_channel_amplitude: float, + min_explained_variance: float, eps=1e-6, ) -> np.ndarray: n, t, c = templates.shape @@ -370,9 +372,11 @@ def get_shared_temporal_basis( cov += m * m.T if nvis < t: logger.warning(f"Had {nvis=} smaller than {t=} in shared basis compression.") - cov.diagonal().add_(1e-5) + cov.diagonal().add_(eps) vals, U = torch.linalg.eigh(cov) - big_enough = vals > eps + min_explained_variance = max(eps, min_explained_variance) + exvar = vals / vals.sum() + big_enough = exvar > min_explained_variance nbig = big_enough.sum() if nbig < rank: logger.dartsortdebug(f"Shared basis only needed rank {nbig}.") diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 1e60d9f5..cb3b0a26 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -19,7 +19,7 @@ ) from ..util.logging_util import get_logger from 
..util.motion import MotionInfo -from ..util.noise_util import SpatialWhitener +from ..util.noise_util import Whitener from ..util.py_util import databag from .template_util import weighted_average @@ -51,6 +51,8 @@ class TemplateData: properties: dict[str, np.ndarray] | None = None tsvd: TruncatedSVD | PCA | None = None whitener: np.ndarray | None = None + covariance: np.ndarray | None = None + temporal_kernel: np.ndarray | None = None whiten_strategy: WhiteningStrategy = "none" featurization_basis: np.ndarray | None = None @@ -160,6 +162,10 @@ def to_npz(self, npz_path): to_save["raw_std_dev"] = self.raw_std_dev if self.whitener is not None: to_save["whitener"] = self.whitener + if self.covariance is not None: + to_save["covariance"] = self.covariance + if self.temporal_kernel is not None: + to_save["temporal_kernel"] = self.temporal_kernel if self.featurization_basis is not None: to_save["featurization_basis"] = self.featurization_basis if not npz_path.parent.exists(): @@ -248,7 +254,7 @@ def from_config( save_folder: Path | None = None, overwrite=False, motion: MotionInfo | None = None, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, save_npz_name: str | None = "template_data.npz", tsvd=None, featurization_basis=None, @@ -270,7 +276,7 @@ def from_config( motion = MotionInfo.from_motion_est(geom=recording.get_channel_locations()) if template_cfg.whitening.strategy != "none" and whitener is None: assert sorting is not None - whitener = SpatialWhitener.from_config( + whitener = Whitener.from_config( sorting=sorting, motion=motion, whiten_cfg=template_cfg.whitening, @@ -319,7 +325,7 @@ def _from_config( template_cfg: TemplateConfig, waveform_cfg: WaveformConfig = default_waveform_cfg, motion: MotionInfo, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, tsvd=None, computation_cfg: ComputationConfig | None = None, ) -> "TemplateData": diff --git a/src/dartsort/templates/templib.py 
b/src/dartsort/templates/templib.py index 46b591d6..9d456f3c 100644 --- a/src/dartsort/templates/templib.py +++ b/src/dartsort/templates/templib.py @@ -63,6 +63,7 @@ def fit_tsvd( min_channel_amplitude=template_cfg.template_min_channel_amplitude, random_seed=random_seed, computation_cfg=computation_cfg, + min_explained_variance=template_cfg.svd_min_explained_variance, ) return pca @@ -128,6 +129,7 @@ def pca_from_templates( rank: int, min_channel_amplitude: float = 1.0, random_seed: int = 0, + min_explained_variance: float = 5e-3, computation_cfg: ComputationConfig | None = None, ) -> PCA: from .template_util import shared_basis_compress_templates @@ -137,8 +139,11 @@ def pca_from_templates( rank=rank, computation_cfg=computation_cfg, min_channel_amplitude=min_channel_amplitude, + min_explained_variance=min_explained_variance, ) basis = tdc.temporal_components + assert 0 < tdc.temporal_components.shape[0] <= rank + rank = tdc.temporal_components.shape[0] pca = PCA( n_components=rank, random_state=random_seed, diff --git a/src/dartsort/transform/_multichan_denoiser_kit.py b/src/dartsort/transform/_multichan_denoiser_kit.py index 99a65622..da0a9e2e 100644 --- a/src/dartsort/transform/_multichan_denoiser_kit.py +++ b/src/dartsort/transform/_multichan_denoiser_kit.py @@ -32,6 +32,7 @@ def __init__( n_epochs=75, pad_depth_only=True, channelwise_dropout_p=0.0, + svd_projection_rank: int | None = None, with_conv_fullheight=False, val_split_p=0.0, min_epochs=10, @@ -85,6 +86,7 @@ def __init__( self.res_type = res_type self.inference_batch_size = inference_batch_size self.epoch_size = epoch_size + self.svd_projection_rank = svd_projection_rank model_channel_index = regularize_channel_index( geom=self.geom, channel_index=channel_index, depth_only=pad_depth_only @@ -127,7 +129,11 @@ def initialize_shapes(self): ) # we don't know these dimensions til we see a spike assert self.spike_length_samples is not None - self.wf_dim = self.spike_length_samples * 
self.b.model_channel_index.shape[1] + if self.svd_projection_rank: + dim0 = self.svd_projection_rank + else: + dim0 = self.spike_length_samples + self.wf_dim = dim0 * self.b.model_channel_index.shape[1] self.output_dim = self.wf_dim def get_optimizer(self): @@ -196,8 +202,12 @@ def get_mlp( hidden_dims = self.hidden_dims if output_layer is None: output_layer = "gated_linear" if self.signal_gates else "linear" + if self.svd_projection_rank: + dim0 = self.svd_projection_rank + else: + dim0 = self.spike_length_samples return nn_util.get_waveform_mlp( - self.spike_length_samples, + dim0, self.b.model_channel_index.shape[1], hidden_dims, self.output_dim, diff --git a/src/dartsort/transform/decollider.py b/src/dartsort/transform/decollider.py index bcde2cf0..ba391042 100644 --- a/src/dartsort/transform/decollider.py +++ b/src/dartsort/transform/decollider.py @@ -28,6 +28,7 @@ class Decollider(BaseMultichannelDenoiser): """Unsupervised spike waveform denoising.""" + default_name = "decollider" needs_residual = True @@ -48,6 +49,7 @@ def __init__( pad_depth_only=True, channelwise_dropout_p=0.0, with_conv_fullheight=False, + svd_projection_rank: int | None = None, val_split_p=0.0, min_epochs=10, earlystop_eps=None, @@ -105,6 +107,7 @@ def __init__( pad_depth_only=pad_depth_only, channelwise_dropout_p=channelwise_dropout_p, with_conv_fullheight=with_conv_fullheight, + svd_projection_rank=svd_projection_rank, val_split_p=val_split_p, min_epochs=min_epochs, earlystop_eps=earlystop_eps, @@ -143,6 +146,8 @@ def __init__( self.detach_cycle_loss = detach_cycle_loss self.clip_value = clip_value self.clip_norm = clip_norm + if self.svd_projection_rank: + self.submodule_names = ["tpca"] if separate_cycle_net: assert cycle_loss_alpha > 0 @@ -176,6 +181,19 @@ def initialize_spike_length_dependent_params(self): ) else: self.den_net: torch.nn.Module = self.inf_net + if self.svd_projection_rank: + from .temporal_pca import BaseTemporalPCA + + self.tpca = BaseTemporalPCA( + 
self.b.channel_index, + geom=self.b.geom, + waveform_cfg=self.waveform_cfg, + rank=self.svd_projection_rank, + ) + self.tpca.spike_length_samples = self.spike_length_samples + self.tpca.initialize_spike_length_dependent_params() + else: + self.tpca = None self.to(self.device) def fit( @@ -192,17 +210,27 @@ def fit( super().fit( recording, waveforms, computation_cfg=computation_cfg, channels=channels ) + if self.tpca is not None and self.tpca.needs_fit(): + self.tpca.fit( + recording=recording, + waveforms=waveforms, + computation_cfg=computation_cfg, + channels=channels, + ) train_data, val_data = self._construct_datasets_from_waveforms( waveforms, channels, recording, weights, hdf5_filename=hdf5_filename ) with torch.enable_grad(): + if self.tpca is not None: + self.tpca.eval() res = self._fit(train_data, val_data) self._needs_fit = False return res def forward_unbatched(self, waveforms, channels): """Called only at inference time.""" - # TODO: batch all of this. + if self.tpca is not None: + waveforms = self.tpca.force_embed(waveforms) waveforms, masks = self.to_nn_channels(waveforms, channels) net_input = waveforms, masks.unsqueeze(1) @@ -266,6 +294,9 @@ def forward_unbatched(self, waveforms, channels): pred = self.to_orig_channels(pred, channels) + if self.tpca is not None: + pred = self.tpca.force_reconstruct(pred) + return pred def train_forward(self, y, m, ell, mask): @@ -559,15 +590,6 @@ def _run_train_loop(self, train_data, val_data): train_losses = {} for waveform_b, channels_b, noise_b, cnoise_b in train_data: waveform_b = waveform_b.to(device=self.device, non_blocking=True) - channels_b = channels_b.to(device=self.device, non_blocking=True) - waveform_b = reindex( - channels_b, - waveform_b, - self.relative_index, - pad_value=0.0, - ) - - optimizer.zero_grad() m = noise_b.to( dtype=waveform_b.dtype, device=self.device, non_blocking=True ) @@ -579,6 +601,24 @@ def _run_train_loop(self, train_data, val_data): ) else: ell = None + + if self.tpca is not 
None: + with torch.no_grad(): + waveform_b = self.tpca.force_embed(waveform_b) + m = self.tpca.force_embed(m) + if ell is not None: + ell = self.tpca.force_embed(ell) + + channels_b = channels_b.to(device=self.device, non_blocking=True) + waveform_b = reindex( + channels_b, + waveform_b, + self.relative_index, + pad_value=0.0, + ) + + optimizer.zero_grad() + mask = self.get_masks(channels_b).to( dtype=waveform_b.dtype, device=self.device, non_blocking=True ) @@ -596,9 +636,13 @@ def _run_train_loop(self, train_data, val_data): loss.backward() if self.clip_value is not None: - torch.nn.utils.clip_grad_value_(self.parameters(), self.clip_value) + torch.nn.utils.clip_grad_value_( + self.parameters(), self.clip_value + ) if self.clip_norm is not None: - torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_norm) + torch.nn.utils.clip_grad_norm_( + self.parameters(), self.clip_norm + ) optimizer.step() for k, v in loss_dict.items(): @@ -622,6 +666,12 @@ def _run_train_loop(self, train_data, val_data): channels_b = channels_b.to(self.device) noise_b = noise_b.to(self.device) ell_b = None if ell_b is None else ell_b.to(self.device) + if self.tpca is not None: + with torch.no_grad(): + waveform_b = self.tpca.force_embed(waveform_b) + noise_b = self.tpca.force_embed(noise_b) + if ell_b is not None: + ell_b = self.tpca.force_embed(ell_b) waveform_b, mask = self.to_nn_channels( waveform_b, channels_b diff --git a/src/dartsort/transform/pipeline.py b/src/dartsort/transform/pipeline.py index ed9a2ac5..05d12ad1 100644 --- a/src/dartsort/transform/pipeline.py +++ b/src/dartsort/transform/pipeline.py @@ -494,6 +494,12 @@ def featurization_config_to_class_names_and_kwargs( ) ) + if fc.fit_disabled_whitener: + assert fc.whiten_cfg is not None + class_names_and_kwargs.append( + ("WaveformWhitener", {"disabled": True, "whiten_cfg": fc.whiten_cfg}) + ) + # logic for picking an efficient combo of tpcas and nn denoisers class_names_and_kwargs.extend( 
_add_tpca_and_nn(featurization_cfg, waveform_cfg, sampling_frequency) @@ -632,7 +638,7 @@ def _add_localization_and_ampvec(fc): ) ) - if fc.do_enforce_decrease == "loc_only" and fc.do_localization: + if (fc.do_enforce_decrease == "loc_only") and (fc.do_localization and do_feats): more.append(("EnforceDecrease", {})) if do_feats and fc.do_localization and fc.nn_localization: diff --git a/src/dartsort/transform/temporal_pca.py b/src/dartsort/transform/temporal_pca.py index 1afbfc1b..34a48b77 100644 --- a/src/dartsort/transform/temporal_pca.py +++ b/src/dartsort/transform/temporal_pca.py @@ -22,6 +22,7 @@ class BaseTemporalPCA(BaseWaveformModule): """Base class for PCA featurizers.""" + default_name = "basis" def __init__( @@ -93,14 +94,13 @@ def fit( super().fit( recording, waveforms, computation_cfg=computation_cfg, channels=channels ) + del spike_data rg = np.random.default_rng(self.random_state) if weights is not None and waveforms.shape[0] > self.max_waveforms: weights = weights.numpy(force=True) if torch.is_tensor(weights) else weights weights = weights.astype(np.float64) weights = weights / weights.sum() - choices = rg.choice( - len(weights), p=weights, size=self.max_waveforms - ) + choices = rg.choice(len(weights), p=weights, size=self.max_waveforms) choices.sort() choices = torch.from_numpy(choices) waveforms = waveforms[choices] @@ -248,24 +248,39 @@ def force_reconstruct(self, features): if ndim == 2: features = features.unsqueeze(0) n, r, c = features.shape - waveforms = features.permute(0, 2, 1).reshape(n * c, r) - waveforms = self._inverse_transform_in_probe(waveforms) - waveforms = waveforms.reshape(n, c, -1).permute(0, 2, 1) + if self.whiten: + W = self.b.components / self.b.whitener + else: + W = self.b.components + Wt = W.t() + assert Wt.shape[1] == r + out = features.new_empty((n, Wt.shape[0], c)) + bs = self.batch_size + for i0 in range(0, n, bs): + i1 = i0 + bs + torch.matmul(Wt, features[i0:i1], out=out[i0:i1]) if ndim == 2: - waveforms = 
waveforms[0] - return waveforms + out = out[0] + return out def force_embed(self, waveforms): ndim = waveforms.ndim if ndim == 2: waveforms = waveforms.unsqueeze(0) n, t, c = waveforms.shape - waveforms = waveforms.mT.reshape(n * c, t) - waveforms = self._transform_in_probe(waveforms) - waveforms = waveforms.view(n, c, self.rank).mT + if self.whiten: + W = self.b.components / self.b.whitener + else: + W = self.b.components + assert W.shape[1] == t + out = waveforms.new_empty((n, W.shape[0], c)) + bs = self.batch_size + for i0 in range(0, n, bs): + i1 = i0 + bs + torch.matmul(W, waveforms[i0:i1], out=out[i0:i1]) if ndim == 2: - waveforms = waveforms[0] - return waveforms + out = out[0] + return out def force_project(self, waveforms): ndim = waveforms.ndim @@ -364,6 +379,9 @@ def initialize_spike_length_dependent_params(self): self.register_buffer("whitener", torch.zeros(self.rank)) else: assert self.b.mean.shape == (nt,) + if self.b.components.shape[0] < self.rank: + self.rank = self.b.components.shape[0] + self.shape = (self.rank, self.b.channel_index.shape[1]) assert self.b.components.shape == (self.rank, nt) self.to(self.b.channel_index.device) @@ -399,13 +417,18 @@ def initialize_from_templates(self, td): else: dt = self.temporal_slice.stop - self.temporal_slice.start assert basis.shape[1] == dt + self.rank = min(self.rank, basis.shape[0]) + self.shape = (self.rank, self.b.channel_index.shape[1]) self.b.mean.zero_() - self.b.components.copy_(torch.asarray(basis[: self.rank])) + comps = torch.asarray(basis[: self.rank]) + self.b.components.resize_(comps.shape) + self.b.components.copy_(comps) self._needs_fit = False class TemporalPCADenoiser(BaseWaveformDenoiser, BaseTemporalPCA): """Spike waveform denoising with PCA.""" + default_name = "temporal_pca" def forward(self, waveforms, *, channels, time_shifts=None, **unused): @@ -460,6 +483,7 @@ def forward(self, waveforms, *, channels, **spike_data): class TemporalPCAFeaturizer(BaseWaveformFeaturizer, 
BaseTemporalPCA): """Spike featurization with PCA.""" + default_name = "tpca_features" def transform( @@ -520,6 +544,7 @@ def inverse_transform(self, features, channels, channel_index=None): class TemporalPCA(BaseWaveformAutoencoder, TemporalPCAFeaturizer): """Combined spike featurization and denoising with PCA.""" + default_name = "tpca_features" def forward(self, waveforms, *, channels, time_shifts=None, **unused): diff --git a/src/dartsort/transform/transform_base.py b/src/dartsort/transform/transform_base.py index 0defad80..606e8588 100644 --- a/src/dartsort/transform/transform_base.py +++ b/src/dartsort/transform/transform_base.py @@ -41,6 +41,7 @@ def __init__( if name_prefix: name = f"{name_prefix}_{name}" self.name = name + self.submodule_names = None # these buffers below need to be copied, else they share references # across all the transformers which seems to cause problems! if channel_index is not None: @@ -131,9 +132,22 @@ def _other_pre_load_state(self, state_dict, prefix): def _pre_load_state(self, state_dict, prefix, *args, **kwargs): # wish torch would strip the prefix for us? extra_state_keys = [k for k in state_dict.keys() if k.endswith("_extra_state")] - assert len(extra_state_keys) <= 1 - if extra_state_keys: - extra_state = state_dict[extra_state_keys[0]] + + all_submodule_keys = [] + if self.submodule_names: + for sn in self.submodule_names: + sn_keys = [ + k for k in extra_state_keys if k.endswith(f"{sn}._extra_state") + ] + assert len(sn_keys) == 1 + all_submodule_keys.append(sn_keys[0]) + + my_extra_state_keys = [ + k for k in extra_state_keys if k not in all_submodule_keys + ] + assert len(my_extra_state_keys) <= 1 + if my_extra_state_keys: + extra_state = state_dict[my_extra_state_keys[0]] # some modules want to know the spike length before loading the state dict # and unfortunately set_extra_state usually runs after. doesn't hurt to run now. 
@@ -147,6 +161,11 @@ def _pre_load_state(self, state_dict, prefix, *args, **kwargs): self._other_pre_load_state(state_dict, prefix) + if self.submodule_names: + for sn, smk in zip(self.submodule_names, all_submodule_keys): + sn_dict = {smk: state_dict[smk]} + getattr(self, sn)._pre_load_state(sn_dict, prefix, *args, **kwargs) + def initialize_spike_length_dependent_params(self): pass diff --git a/src/dartsort/transform/whiten.py b/src/dartsort/transform/whiten.py index c3255f9b..c82eb43a 100644 --- a/src/dartsort/transform/whiten.py +++ b/src/dartsort/transform/whiten.py @@ -1,13 +1,26 @@ +from pathlib import Path from typing import TYPE_CHECKING +import torch +from spikeinterface.core import BaseRecording + +from ..util.data_util import DARTsortSorting +from ..util.internal_config import ( + ComputationConfig, + WaveformConfig, + WhiteningConfig, + default_waveform_cfg, +) from .transform_base import BaseWaveformDenoiser if TYPE_CHECKING: - from ..util.noise_util import SpatialWhitener + from ..util.noise_util import Whitener + from .pipeline import WaveformPipeline class WaveformWhitener(BaseWaveformDenoiser): default_name = "whiten" + needs_residual = True def __init__( self, @@ -16,19 +29,74 @@ def __init__( channel_index, name=None, name_prefix=None, - whitener: "SpatialWhitener | None" = None, + waveform_cfg: WaveformConfig | None = default_waveform_cfg, + whitener: "Whitener | None" = None, + disabled: bool = True, + whiten_cfg: WhiteningConfig = WhiteningConfig(), + sampling_frequency: float = 30_000.0, ): super().__init__( name=name, name_prefix=name_prefix, geom=geom, channel_index=channel_index ) - assert channel_index.shape[1] == geom.shape[0], ( - "Meant to be used with full-probe data." 
- ) self.whitener = whitener + self.disabled = disabled + self.whiten_cfg = whiten_cfg + self.motion = None + + def needs_fit(self): + return self.whitener is None def forward(self, waveforms, **unused): del unused - if self.whitener is None: + if self.disabled or self.whitener is None: return waveforms else: return self.whitener.whiten(x=waveforms) + + def attach_motion(self, motion): + self.motion = motion + + def _other_pre_load_state(self, state_dict, prefix): + if self.whitener is not None: + return + from ..util.noise_util import Whitener + + self.whitener = Whitener.blank( + len(self.b.geom), self.b.geom.device, self.whiten_cfg.temporal_length + ) + + def fit( + self, + recording: BaseRecording, + waveforms: torch.Tensor, + *, + hdf5_filename: Path | None = None, + computation_cfg: ComputationConfig, + pipeline: "WaveformPipeline | None" = None, + **spike_data: torch.Tensor, + ): + super().fit( + recording, + waveforms, + hdf5_filename=hdf5_filename, + computation_cfg=computation_cfg, + pipeline=pipeline, + **spike_data, + ) + del recording, spike_data, waveforms, pipeline + from ..util.noise_util import Whitener + + assert hdf5_filename is not None + + if self.motion is None: + assert self.whiten_cfg.strategy != "postwhiten" + + sorting = DARTsortSorting.from_peeling_hdf5( + hdf5_filename, load_simple_features=False + ) + self.whitener = Whitener.from_config( + sorting=sorting, + motion=self.motion, + whiten_cfg=self.whiten_cfg, + computation_cfg=computation_cfg, + ) diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 497cf57f..0e366f10 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -11,14 +11,23 @@ BaseRecording, BaseSorting, NumpySorting, + SortingAnalyzer, + create_sorting_analyzer, get_random_data_chunks, ) +from spikeinterface.core.sparsity import estimate_sparsity from ..detect import detect_and_deduplicate -from .internal_config import WaveformConfig, default_waveform_cfg +from 
.internal_config import ( + TemplateConfig, + WaveformConfig, + default_clustering_features_cfg, + default_waveform_cfg, +) from .logging_util import get_logger, progbar if TYPE_CHECKING: + from ..templates.templates import TemplateData from .motion import MotionInfo from .job_util import ensure_computation_config from .py_util import ensure_path @@ -87,6 +96,10 @@ def __init__( if ephemeral_features is not None: for k, v in ephemeral_features.items(): check_shape = not self._no_check_needed(k) + if k in self._persistent_features: + assert np.array_equal(v, self._persistent_features[k]) + assert hasattr(self, k) + continue self.add_ephemeral_feature(k, v, check_shape=check_shape) @property @@ -126,6 +139,7 @@ def to_numpy_sorting( labels_list=labels, sampling_frequency=st.sampling_frequency, ) + numpy_sorting._compute_and_cache_spike_vector() if return_kept_indices: # kept_indices[i] is the original index of numpy_sorting's ith spike kept_indices = st.mask_indices[order] @@ -207,6 +221,138 @@ def to_tsgroup( trains[unit_id] = Tsd(t=ut, d=uw) return TsGroup(trains, metadata=metadata) + def to_sorting_analyzer( + self, + recording: BaseRecording, + template_data: "TemplateData | None" = None, + template_cfg: TemplateConfig | None = None, + motion: "MotionInfo | None" = None, + drop_doubles: bool = True, + compute_extensions: Sequence[str] | None = ("random_spikes", "waveforms"), + features_cfg=default_clustering_features_cfg, + ) -> SortingAnalyzer: + """Export dartsort's internal data to a SortingAnalyzer + + This will first call to_numpy_sorting() and then register some of the sorting's + features as extensions for the analyzer. + + If template_data is supplied, a templates extension will be registered. Or, you + can supply template_cfg and motion to compute that. + + The implementation is based on SpikeInterface's `read_kilosort_as_analyzer()`, + thanks to Chris Halcrow for that. 
+ + The return value here can be passed into SpikeInterface's Phy export machine + `export_to_phy()`. + + TODO: This doesn't handle gain_to_uV... should I not be doing my own amps here? + + Parameters + ---------- + recording : BaseRecording + template_data : TemplateData | None, optional + Templates to register with SortingAnalyzer as the templates + extension + template_cfg : TemplateConfig | None, optional + If template_data is not supplied but this is (together with motion), + templates will be estimated using dartsort machinery + motion: MotionInfo | None, optional + drop_doubles : bool + Call .drop_doubles(). This will probably do nothing if the sorting had + dedup_ms > 0 in the parameters. + compute_extensions : Sequence[str] | None + Extra analyzer extensions to compute. The default set is what's needed for + SpikeInterface's `export_to_phy()` to run. These are included here because + SpikeInterface is picky about extension order and will evict the template + extension if some of these are computed after the templates, which would then + cause the Phy export to fail. 
+ features_cfg : ClusteringFeaturesConfig + Stores attribute dataset names + + Returns + ------- + analyzer : SortingAnalyzer + """ + from spikeinterface.core import ComputeTemplates + from spikeinterface.postprocessing import ( + ComputeSpikeLocations, + ComputeUnitLocations, + ) + + sorting, kept_indices = self.to_numpy_sorting( + drop_doubles=drop_doubles, return_kept_indices=True + ) # type: ignore + + sparsity = estimate_sparsity(sorting, recording) + analyzer = create_sorting_analyzer( + sorting=sorting, recording=recording, sparsity=sparsity, return_in_uV=False + ) + + for ext in compute_extensions or []: + analyzer.compute_one_extension(ext) + + loc_name = features_cfg.localizations_dataset_name + if (locs := self.localizations_as_structured_array(loc_name)) is not None: + locs = locs[kept_indices] + loc_ext = ComputeSpikeLocations(analyzer) + loc_ext.data = {"spike_locations": locs} + loc_ext.params = {} + loc_ext.run_info = {"run_completed": True} + analyzer.extensions["spike_locations"] = loc_ext + + if template_data is None and template_cfg is not None: + from ..templates.postprocess_util import estimate_template_library + + _, template_data = estimate_template_library( + recording=recording, + sorting=self, + motion=motion, + template_cfg=template_cfg, + ) + + if template_data is not None: + td_ext = ComputeTemplates(analyzer) + assert np.array_equal( + template_data.unit_ids, np.arange(len(template_data.unit_ids)) + ) + s_before = template_data.trough_offset_samples + s_after = template_data.spike_length_samples - s_before + ms_per_sample = 1000 / self.sampling_frequency + + td_ext.data = {"average": template_data.templates} + td_ext.params = { + "operators": ["average"], + "ms_before": s_before * ms_per_sample, + "ms_after": s_after * ms_per_sample, + "peak_sign": "both", + } + td_ext.run_info = {"run_completed": True} + analyzer.extensions["templates"] = td_ext + + uloc_ext = ComputeUnitLocations(analyzer) + uloc_ext.params = {"method": 
"monopolar_triangulation"} + # turns out they don't want a structured array in this extension + ulocs = template_data.template_locations(mode="localization") + uloc_ext.data = {"unit_locations": ulocs} + uloc_ext.run_info = {"run_completed": True} + analyzer.extensions["unit_locations"] = uloc_ext + + return analyzer + + def localizations_as_structured_array( + self, property_name="point_source_localizations" + ) -> np.ndarray | None: + locs = getattr(self, property_name, None) + if locs is None: + return None + + return si_structured_localizations_array(locs) + + def permute_labels(self, seed=0): + assert self.labels is not None + rg = np.random.default_rng(seed) + return self.ephemeral_replace(labels=rg.permutation(self.labels)) + def copy(self) -> Self: """Shallow copy. Doesn't copy data, but copies references and internal state.""" other = copy(self) @@ -902,6 +1048,8 @@ def get_tpca(sorting, name_prefix="collisioncleaned", featurization_pipeline_pt= if k in ("TemporalPCA", "TemporalPCAFeaturizer") and v["name_prefix"] == name_prefix ] + if len(tpca_kw) == 0: + raise ValueError(f"No TPCA found in {featurization_pipeline_pt}.") assert len(tpca_kw) == 1 ix, clsname, kw = tpca_kw[0] tpca = transformers_by_class_name[clsname]( @@ -1027,6 +1175,28 @@ def sorting_from_spikeinterface( ) +def si_structured_localizations_array(locs: np.ndarray) -> np.ndarray: + """Convert our localization format to SpikeInterface's""" + # NB: spikeinterface's y is our z. I like theirs better, I'm sorry, it's not my fault. 
+ if locs.shape[1] >= 3: + column_names = ["x", "y", "z"] + elif locs.shape[1] == 2: + column_names = ["x", "y"] + else: + assert False + dtype = [(name, locs.dtype) for name in column_names] + + structured_array = np.zeros(len(locs), dtype=dtype) + structured_array["x"] = locs[:, 0] + if locs.shape[1] >= 3: + structured_array["y"] = locs[:, 2] + structured_array["z"] = locs[:, 1] + elif locs.shape[1] == 2: + structured_array["y"] = locs[:, 1] + + return structured_array + + def filter_link_h5(in_h5_path: str | Path, out_h5_path: str | Path, keep_filter): in_h5_path = ensure_path(in_h5_path, strict=True) out_h5_path = ensure_path(out_h5_path) diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index 0293c760..c008803b 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -129,6 +129,7 @@ def pad(self, padding_ms: float) -> Self: @cfg_dataclass class InterpolationParams: """Spatial waveform or feature interpolation parameters""" + method: InterpMethod = "kriging" kernel: InterpKernel = "thinplate" extrap_method: InterpMethod | None = None @@ -213,10 +214,11 @@ def normalize(self) -> Self: @cfg_dataclass class FitSamplingConfig: """Data sampling parameters for model fitting""" + max_waveforms_fit: int = 50_000 n_waveforms_fit: int = 40_000 more_waveforms_fit: int = 2000 * 1024 - n_residual_snips: int = 4 * 4096 + n_residual_snips: int = 2 * 4096 residual_snip_ms: float | None = None residual_sampling_target_density: float = 0.25 seed: int = 0 @@ -235,6 +237,7 @@ class FitSamplingConfig: @cfg_dataclass class ClusteringFeaturesConfig: """Parameters to control which features are used for initial clustering""" + # simple matrix feature controls use_x: bool = True use_z: bool = True @@ -269,6 +272,7 @@ class ClusteringFeaturesConfig: @cfg_dataclass class ClusteringConfig: """Initial clustering parameters""" + cluster_strategy: str = "dpc" sampling_cfg: FitSamplingConfig = 
default_clustering_fit_sampling_cfg @@ -285,6 +289,7 @@ class ClusteringConfig: noise_density: float = 0.0 outlier_radius: float | None = 25.0 outlier_neighbor_count: int = 10 + dpc_kmeans_cleanup: bool = False # gmm density peaks additional parameters kmeanspp_initializations: int = 10 @@ -318,10 +323,12 @@ class ClusteringConfig: @cfg_dataclass class WhiteningConfig: """Whitening parameters""" + strategy: WhiteningStrategy = "none" estimator: WhiteningEstimator = "localzca" interp_params: InterpolationParams = tps_interp_clampna_extrap_params radius: float = 200.0 + temporal_length: int | None = None TemplateSVDMethod = Literal[ @@ -332,6 +339,7 @@ class WhiteningConfig: @cfg_dataclass class TemplateConfig: """Template waveform estimation parameters""" + spikes_per_unit: int = 500 with_raw_std_dev: bool = False reduction: Literal["median", "mean"] = "median" @@ -367,6 +375,7 @@ class TemplateConfig: try_reload_svd: bool = True svd_alignment_iterations: int = 0 svd_alignment_ms: float = 0.75 + svd_min_explained_variance: float = 5e-3 # exp weight denoising exp_weight_snr_threshold: float = 50.0 @@ -410,6 +419,7 @@ def actual_algorithm(self) -> str: @cfg_dataclass class TemplateRealignmentConfig: """Template waveform alignment parameters""" + realign_peaks: bool = True realign_strategy: RealignStrategy = "snr_weighted_trough_factor" realign_shift_ms: float = 1.5 @@ -421,6 +431,7 @@ class TemplateRealignmentConfig: @cfg_dataclass class TemplateMergeConfig: """Parameters describing how to judge whether to merge groups of templates""" + distance_kind: Literal[ "scaled_normeuc", "deconv", "max", "weighted_scaled_normeuc" ] = "weighted_scaled_normeuc" @@ -452,13 +463,14 @@ def to_template_config(self, template_cfg: TemplateConfig | None = None): ) -MixtureStep = Literal["split", "merge", "demolish"] +MixtureStep = Literal["split", "singlesplit", "merge", "demolish"] ComponentDistanceMetric = Literal["cosine", "normeuc", "scaled_normeuc"] @cfg_dataclass class 
RefinementConfig: """Parameters for clustering refinement""" + refinement_strategy: str = "tmm" sampling_cfg: FitSamplingConfig = default_clustering_fit_sampling_cfg @@ -506,7 +518,9 @@ class RefinementConfig: mixture_steps: Sequence[MixtureStep] = ("split", "merge", "demolish") prior_pseudocount: float = 0.0 kmeansk: int = 4 - kmeans_tries: int = 5 + single_split_k: int = 3 + kmeans_tries: int = 10 + kmeans_beta: float = 50.0 kmeanspp_tries: int = 5 full_proposal_every: int = 10 main_min_iters: int = 20 @@ -515,8 +529,9 @@ class RefinementConfig: robust_fixed_std_dataset: str = "collidedness" robust_fixed_power: float = 40.0 robust_df: float = 4.0 - demolition_min_resp_ratio: float = 1.1 + demolition_min_resp_ratio: float = 0.9 demolish_during_selection: bool = False + refit_in_demolition: bool = False em_after_demolish: bool = False whiten_split: bool = True scale_dist_args: tuple[float, float, float] = (0.01, 3.0 / 4.0, 4.0 / 3.0) @@ -536,6 +551,11 @@ class RefinementConfig: qda_min_coverage: float = 0.35 qda_min_iou: float = 0.5 qda_force_merge_for_temp_dist_below: float = 0.3 + spikeinterface_merge_preset: str | None = None + spikeinterface_merge_max_distance: float = 0.5 + spikeinterface_merge_min_coentropy: float | None = 0.01 + spikeinterface_merge_coent_coverage: float = 0.8 + spikeinterface_merge_coent_iou: float = 0.5 # forward_backward parameters chunk_size_s: float = 300.0 @@ -551,6 +571,7 @@ class RefinementConfig: # deduplication control dedup_ms: float = 0.0 + censor_ms: float = 0.3 @cfg_dataclass @@ -629,6 +650,10 @@ class FeaturizationConfig: gmm_refinement_cfg: RefinementConfig | None = None gmm_clustering_features_cfg: ClusteringFeaturesConfig | None = None + # helper for fitting whiteners + fit_disabled_whitener: bool = False + whiten_cfg: WhiteningConfig | None = None + # used when naming datasets saved to h5 files input_waveforms_name: str = "collisioncleaned" output_waveforms_name: str = "denoised" @@ -640,6 +665,7 @@ class 
FeaturizationConfig: @cfg_dataclass class SubtractionConfig: """Parameters for neural-net based spike detection""" + # peeling common chunk_length_samples: int = 30_000 fit_only: bool = False @@ -653,11 +679,11 @@ class SubtractionConfig: relative_peak_radius_samples: int = 5 relative_peak_radius_um: float | None = 35.0 spatial_dedup_radius_um: float | None = 50.0 - temporal_dedup_radius_samples: int = 11 + temporal_dedup_radius_samples: int = 7 remove_exact_duplicates: bool = True positive_temporal_dedup_radius_samples: int = 41 subtract_radius_um: float = 200.0 - residnorm_decrease_threshold: float = 10.0 + residnorm_decrease_threshold: float = 9.0 decrease_objective: Literal["norm", "normsq", "deconv"] = "deconv" growth_tolerance: float | None = None trough_priority: float | None = 2.0 @@ -665,6 +691,9 @@ class SubtractionConfig: convexity_threshold: float | None = None convexity_radius: int = 7 max_iter: int = 100 + whiten: bool = True + threshold_before_whitening: float = 10.0 + whiten_cfg: WhiteningConfig | None = WhiteningConfig(strategy="prewhiten_postapply") # how will waveforms be denoised before subtraction? 
# users can also save waveforms/features during subtraction @@ -694,6 +723,7 @@ class SubtractionConfig: @cfg_dataclass class ThresholdingConfig: """Parameters for threshold-crossing spike detection""" + # peeling common chunk_length_samples: int = 30_000 @@ -720,6 +750,7 @@ class ThresholdingConfig: @cfg_dataclass class MatchingConfig: """Template matching pursuit parameters""" + # peeling common chunk_length_samples: int = 30_000 max_spikes_per_second: int = 16384 @@ -729,6 +760,7 @@ class MatchingConfig: # template matching parameters threshold: float | Literal["fp_control"] = 8.0 template_svd_compression_rank: int = 5 + template_svd_compression_min_explained_variance: float = 5e-3 up_factor: int = 4 upsampling_radius: int = 8 template_min_channel_amplitude: float = 1.0 @@ -814,6 +846,7 @@ class MotionEstimationConfig: @cfg_dataclass class ComputationConfig: """Multiprocessing or threading parameters""" + n_jobs_cpu: int = 0 n_jobs_gpu: int = 0 n_jobs_small: int = -2 @@ -872,7 +905,9 @@ def is_multi_gpu(self): default_motion_estimation_cfg = MotionEstimationConfig() default_computation_cfg = ComputationConfig() default_refinement_cfg = RefinementConfig() -default_initial_refinement_cfg = RefinementConfig(mixture_steps=("split", "demolish")) +default_initial_refinement_cfg = RefinementConfig( + mixture_steps=("split", "demolish", "demolish") +) default_pre_refinement_cfg = RefinementConfig(refinement_strategy="pcmerge") default_agglomerate_cfg = RefinementConfig( refinement_strategy="agglomerate", @@ -916,7 +951,7 @@ class DARTsortInternalConfig: recluster_after_first_matching: bool = False # subsampling: intermediate peels will continue until both criteria satisfied # need at least this many spikes - subsampling_spikes: int | None = 2_048_000 + subsampling_spikes_per_channel: int | None = 5000 # need to cover at least this fraction of chunks subsampling_presence: float = 0.1 @@ -935,13 +970,16 @@ class DARTsortInternalConfig: save_everything_on_error: bool = 
False -def to_internal_config(cfg) -> DARTsortInternalConfig: +def to_internal_config(cfg, n_channels: int) -> DARTsortInternalConfig: """Laundromat of configuration formats Parameters ---------- cfg : str | Path | DARTsortUserConfig | DeveloperConfig If str or Path, it should point to a .toml file. + n_channels : int + The number of channels in the input recording, used to adjust + parameters which are specified as counts per channel. Returns ------- @@ -988,8 +1026,29 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: tpca_max_waveforms=cfg.n_waveforms_fit, save_input_waveforms=cfg.save_collisioncleaned_waveforms, save_collidedness=save_collidedness, + tpca_from_templates=cfg.tpca_from_templates, + ) + if cfg.template_interp_kind == "tps": + temp_interp_params = tps_interp_clampna_extrap_params + elif cfg.template_interp_kind == "clampna": + temp_interp_params = clampna_interp_params + else: + assert False + if cfg.matching_interp_kind == "tps": + match_interp_params = tps_interp_clampna_extrap_params + elif cfg.matching_interp_kind == "clampna": + match_interp_params = clampna_interp_params + else: + assert False + whiten_cfg = WhiteningConfig( + strategy=cfg.whiten_strategy, + estimator=cfg.whiten_estimator, + radius=cfg.subtraction_radius_um, + interp_params=temp_interp_params, + temporal_length=cfg.whiten_temporal_length, ) - if cfg.dredge_only: + # TODO: dredge_only is a bad name for this. 
+ if cfg.dredge_only and not cfg.whiten_in_subtraction: n_residual_snips = 0 else: n_residual_snips = cfg.n_residual_snips @@ -1028,6 +1087,10 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: first_denoiser_noise_snips=cfg.nn_denoiser_noise_waveforms, first_denoiser_spatial_dedup_radius=cfg.first_denoiser_spatial_dedup_radius, subtraction_denoising_cfg=subtraction_denoising_cfg, + temporal_dedup_radius_samples=cfg.temporal_dedup_radius_samples, + positive_temporal_dedup_radius_samples=cfg.positive_temporal_dedup_radius_samples, + whiten=cfg.whiten_in_subtraction, + whiten_cfg=whiten_cfg, ) elif cfg.detection_type == "threshold": initial_detection_cfg = ThresholdingConfig( @@ -1035,6 +1098,7 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: detection_threshold=cfg.voltage_threshold, spatial_dedup_radius_um=cfg.deduplication_radius_um, chunk_length_samples=cfg.chunk_length_samples, + temporal_dedup_radius_samples=cfg.temporal_dedup_radius_samples, ) elif cfg.detection_type == "match": assert cfg.precomputed_templates_npz is not None @@ -1056,24 +1120,6 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: else: raise ValueError(f"Unknown detection_type {cfg.detection_type}.") - if cfg.template_interp_kind == "tps": - temp_interp_params = tps_interp_clampna_extrap_params - elif cfg.template_interp_kind == "clampna": - temp_interp_params = clampna_interp_params - else: - assert False - if cfg.matching_interp_kind == "tps": - match_interp_params = tps_interp_clampna_extrap_params - elif cfg.matching_interp_kind == "clampna": - match_interp_params = clampna_interp_params - else: - assert False - whiten_cfg = WhiteningConfig( - strategy=cfg.whiten_strategy, - estimator=cfg.whiten_estimator, - radius=cfg.subtraction_radius_um, - interp_params=temp_interp_params, - ) template_cfg = TemplateConfig( denoising_fit_radius=cfg.fit_radius_um, spikes_per_unit=cfg.template_spikes_per_unit, @@ -1195,13 +1241,13 @@ def to_internal_config(cfg) -> 
DARTsortInternalConfig: detection_threshold=cfg.motion_voltage_threshold, chunk_length_samples=cfg.chunk_length_samples, peak_sign=cfg.peak_sign, - shave_score=cfg.initial_threshold, + shave_score=cfg.threshold_before_whitening, ) motion_estimation_cfg = MotionEstimationConfig( **motion_kw, tpca_rank=cfg.temporal_pca_rank, threshold_cfg=motion_threshold_cfg, - spike_denoising_score=cfg.initial_threshold, + spike_denoising_score=cfg.threshold_before_whitening, ) matching_cfg = MatchingConfig( threshold="fp_control" if cfg.matching_fp_control else cfg.matching_threshold, @@ -1269,6 +1315,8 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: template_merge_cfg=agg_tmcfg, qda_threshold=0.0, dedup_ms=cfg.deduplication_ms, + spikeinterface_merge_preset=cfg.spikeinterface_merge_preset, + spikeinterface_merge_max_distance=cfg.spikeinterface_merge_max_distance, ) elif cfg.agg_kind == "qda": agg_whiten_cfg = WhiteningConfig( @@ -1289,6 +1337,8 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: template_merge_cfg=agg_tmcfg, qda_force_merge_for_temp_dist_below=cfg.agg_no_qda_template_distance, dedup_ms=cfg.deduplication_ms, + spikeinterface_merge_preset=cfg.spikeinterface_merge_preset, + spikeinterface_merge_max_distance=cfg.spikeinterface_merge_max_distance, ) else: assert False @@ -1326,7 +1376,7 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: save_everything_on_error=cfg.save_everything_on_error, link_from=cfg.link_from, link_step=cfg.link_step, - subsampling_spikes=cfg.subsampling_spikes, + subsampling_spikes_per_channel=cfg.subsampling_spikes_per_channel, subsampling_presence=cfg.subsampling_presence, always_save_final_tpca_feature=cfg.always_save_final_tpca_feature, ) diff --git a/src/dartsort/util/main_util.py b/src/dartsort/util/main_util.py index 2146960e..679c63fb 100644 --- a/src/dartsort/util/main_util.py +++ b/src/dartsort/util/main_util.py @@ -143,7 +143,7 @@ def motion_needs_peaks( ): if cfg.subsampling_presence == 1.0: return False - if 
cfg.subsampling_spikes is None: + if cfg.subsampling_spikes_per_channel is None: return False # assert sorting's chunk starts, sorted, match full recording's diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index 49a793e4..267d13c3 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -979,7 +979,9 @@ def estimate( cov = torch.cov(x_spatial.T.double()) cov = spiketorch.enforce_posdef(cov, eps=eps) else: - cov = spiketorch.nancov(x_spatial[:, valid].double(), force_posdef=True, eps=eps) + cov = spiketorch.nancov( + x_spatial[:, valid].double(), force_posdef=True, eps=eps + ) assert torch.is_tensor(cov) if shrinkage: cov = F.softshrink(cov, shrinkage) @@ -1449,26 +1451,18 @@ def fp_control_threshold_from_h5( def residual_covariance( sorting: DARTsortSorting, do_interpolation: bool, - motion: MotionInfo, + motion: MotionInfo | None, interp_params: InterpolationParams = tps_interp_clampna_extrap_params, device: torch.device | None = None, - rgeom=None, residual_times_s_dataset_name="residual_times_seconds", residual_dataset_name="residual", seed: int = 0, batch_size=256, -): +) -> Tensor: assert sorting.parent_h5_path is not None if do_interpolation: - with h5py.File(sorting.parent_h5_path, "r", locking=False) as h5: - geom = cast(h5py.Dataset, h5["geom"])[:].astype(np.float32) - if rgeom is None: - if motion is None: - rgeom = geom - else: - rgeom = motion.rgeom - rgeom = rgeom.astype(np.float32) + assert motion is not None snipgen = generate_interpolated_residual_snippets( motion=motion, hdf5_path=sorting.parent_h5_path, @@ -1503,10 +1497,56 @@ def residual_covariance( N += n w = n / N cov += scov.sub_(cov).mul_(w) + assert N > 0 + assert cov is not None return cov +def residual_welch_whitener( + sorting: DARTsortSorting, + device: torch.device | None = None, + residual_dataset_name="residual", + batch_size=1024, + temporal_length: int = 11, + spatial_whitener: torch.Tensor | None = None, +): + """Estimate 
a 0-phase whitening convolution kernel with Welch's method""" + assert sorting.parent_h5_path is not None + snipgen = sorting._yield_dataset(residual_dataset_name, batch_size=batch_size) + + if spatial_whitener is not None: + W = torch.asarray(spatial_whitener, device=device) + else: + W = None + + # Welch's method to estimate residual PSD + snip_psds = [] + block_len = temporal_length + for snip in snipgen: + block_len = next_fast_len(snip.shape[1]) + snip = torch.asarray(snip).to(device=device, non_blocking=True) + if W is not None: + snip = torch.einsum("ntc,cd->ntd", snip, W) + snip = snip.permute(0, 2, 1).reshape(-1, snip.shape[1]) + periodogram = torch.fft.rfft(snip, n=block_len, norm="ortho") + dens = (periodogram * periodogram.conj()).mean(dim=0) + snip_psds.append(dens) + snip_psds = torch.stack(snip_psds, dim=0) + spectral_density = snip_psds.mean(0).sqrt_() + + # estimate 0-phase FIR whitener + wkernel = torch.fft.irfft(1.0 / spectral_density, n=block_len) + wkernel = torch.fft.fftshift(wkernel) + + # trim to requested length + assert temporal_length <= wkernel.shape[0] + i0 = wkernel.shape[0] // 2 - temporal_length // 2 + i1 = i0 + temporal_length + wkernel = wkernel[i0:i1].clone() + return wkernel + + def fullzca_whitener( cov: np.ndarray, channel_index: np.ndarray | None = None, eps=1e-6 ) -> np.ndarray: @@ -1519,7 +1559,6 @@ def fullzca_whitener( def localzca_whitener( cov: np.ndarray, channel_index: np.ndarray, eps=1e-6 ) -> np.ndarray: - """""" w = np.zeros_like(cov) for j, chans in enumerate(channel_index): chans = chans[chans < len(channel_index)] @@ -1558,30 +1597,78 @@ def sparsechol_whitener( } -class SpatialWhitener(BModule): - def __init__(self, whitener: Tensor): +class Whitener(BModule): + def __init__( + self, whitener: Tensor, covariance: Tensor, temporal_kernel: Tensor | None + ): super().__init__() self.register_buffer("whitener", whitener) + self.register_buffer("covariance", covariance) + self.temporal = temporal_kernel is not None + 
self.register_buffer_or_none("temporal_kernel", temporal_kernel) + if temporal_kernel is not None: + self.temporal_length = temporal_kernel.shape[0] + tk_twice = self._convolve(temporal_kernel) + else: + self.temporal_length = 0 + tk_twice = None + self.register_buffer_or_none("temporal_kernel_twice", tk_twice) @classmethod - def from_numpy(cls, whitener: np.ndarray): - logger.dartsortverbose("Load whitener from numpy.") - return cls(whitener=torch.asarray(whitener)) + def blank(cls, n_channels: int, device: torch.device, temporal_length: int | None): + w = torch.zeros((n_channels, n_channels), device=device) + if temporal_length: + k = torch.zeros((temporal_length,), device=device) + else: + k = None + return cls(w, torch.zeros_like(w), k) - def to_numpy(self) -> np.ndarray: - return self.b.whitener.numpy(force=True) + @classmethod + def from_numpy( + cls, + whitener: np.ndarray, + covariance: np.ndarray, + temporal_kernel: np.ndarray | None, + ): + if temporal_kernel is None: + tk = None + tmsg = "" + else: + tmsg = f" with temporal length {temporal_kernel.shape[0]}" + tk = torch.asarray(temporal_kernel) + logger.dartsortverbose("Load whitener%s from numpy.", tmsg) + return cls( + whitener=torch.asarray(whitener), + covariance=torch.asarray(covariance), + temporal_kernel=tk, + ) + + def to_numpy(self) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]: + tk = self.b.temporal_kernel + if tk is not None: + tk = tk.numpy(force=True) + return ( + self.b.whitener.numpy(force=True), + self.b.covariance.numpy(force=True), + tk, + ) @classmethod def from_config( cls, *, sorting: DARTsortSorting, - motion: MotionInfo, + motion: MotionInfo | None, whiten_cfg: WhiteningConfig, computation_cfg: ComputationConfig | None = None, ) -> Self: logger.dartsortdebug( - "Estimating %s-%s whitener.", whiten_cfg.strategy, whiten_cfg.estimator + "Estimating %s-%s whitener%s.", + whiten_cfg.strategy, + whiten_cfg.estimator, + "" + if not whiten_cfg.temporal_length + else f"temporal 
length: {whiten_cfg.temporal_length}", ) device = ensure_computation_config(computation_cfg).actual_device() cov = residual_covariance( @@ -1599,25 +1686,76 @@ def from_config( cov_np, channel_index=neighbs ) whitener = torch.asarray(whitener).to(cov) - return cls(whitener=whitener) + + if whiten_cfg.temporal_length: + assert whiten_cfg.strategy != "postwhiten" + temporal_kernel = residual_welch_whitener( + sorting=sorting, + device=device, + temporal_length=whiten_cfg.temporal_length, + spatial_whitener=whitener, + ) + assert temporal_kernel.shape == (whiten_cfg.temporal_length,) + else: + temporal_kernel = None + + return cls(whitener=whitener, covariance=cov, temporal_kernel=temporal_kernel) + + def _convolve(self, x: Tensor, twice=False, padding="same"): + if not self.temporal: + return x + *shp, t = x.shape + x = x.reshape(-1, 1, t) + if twice: + k = self.b.temporal_kernel_twice + else: + k = self.b.temporal_kernel + if padding == "full": + padding = k.shape[0] - 1 + k = k.to(device=x.device) + res = F.conv1d( + input=x, + weight=k[None, None], + padding=padding, + groups=x.shape[1], + ) + if padding == "same": + assert res.shape[-1] == t + ot = t + else: + assert isinstance(padding, int) + ot = t + padding + res = res.reshape(*shp, ot) + return res def whiten_traces_spatial_major( - self, x: Tensor, out: Tensor | None = None + self, x: Tensor, out: Tensor | None = None, padding="same" ) -> Tensor: - return torch.mm(self.b.whitener, x.T, out=out) + assert x.ndim == 2 + out = torch.mm(self.b.whitener, x.T, out=out) + out = self._convolve(out, padding=padding) + return out - def whiten(self, x: Tensor, out: Tensor | None = None) -> Tensor: + def whiten( + self, x: Tensor, out: Tensor | None = None, spatial_only: bool = False + ) -> Tensor: *shp, c = x.shape x = x.reshape(-1, c) x = torch.mm(x, self.b.whitener.T, out=out) x = x.reshape(*shp, c) + if self.temporal and not spatial_only: + x = self._convolve(x.mT).mT return x - def transpose_whiten(self, x: Tensor, 
out: Tensor | None = None) -> Tensor: + def transpose_whiten( + self, x: Tensor, out: Tensor | None = None, spatial_only: bool = False + ) -> Tensor: *shp, c = x.shape x = x.reshape(-1, c) x = torch.mm(x, self.b.whitener, out=out) x = x.reshape(*shp, c) + if self.temporal and not spatial_only: + x = self._convolve(x.mT).mT return x def prec_mul(self, x: Tensor) -> Tensor: @@ -1625,4 +1763,19 @@ def prec_mul(self, x: Tensor) -> Tensor: x = x.reshape(-1, c) x = x @ (self.b.whitener.T @ self.b.whitener) x = x.reshape(*shp, c) + if self.temporal: + x = self._convolve(x.mT, twice=True).mT return x + + def local_whiteners(self, channel_index: Tensor, eps=1e-6): + channel_index = torch.asarray(channel_index) + nc, cloc = channel_index.shape + assert nc == self.b.covariance.shape[0] + w = self.b.covariance.new_zeros((nc, cloc, cloc)) + for j, chans in enumerate(channel_index): + mask = (chans < nc).nonzero()[:, 0] + chans = chans[mask] + cj = self.b.covariance[chans][:, chans] + wj = fullzca_whitener(cj.numpy(force=True).astype(np.float64), eps=eps) + w[j, mask[:, None], mask[None, :]] = torch.asarray(wj).to(w) + return w diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index 2f56449e..de8f4d4c 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -1,6 +1,5 @@ import math import warnings -from logging import getLogger from typing import overload import linear_operator @@ -8,16 +7,24 @@ import torch import torch.nn.functional as F from linear_operator.utils.cholesky import psd_safe_cholesky +from packaging.version import Version from scipy.fftpack import next_fast_len from scipy.spatial.distance import squareform from sklearn.utils.extmath import svd_flip from torch import Tensor from torch.fft import irfft, rfft -from .logging_util import progrange +from .logging_util import get_logger, progrange from .torch_util import torch_compile, torch_compiler -logger = getLogger(__name__) +TORCH_IS_OLD = 
Version(torch.__version__) < Version("2.6.0") +if TORCH_IS_OLD and torch.cuda.is_available(): + warnings.warn( + f"Your PyTorch version ({torch.__version__}) is supported by dartsort, " + "but dartsort would be faster if you had >= 2.6.0." + ) + +logger = get_logger(__name__) log2pi = torch.log(torch.tensor(2 * np.pi)) _1 = torch.tensor(1.0) _0 = torch.tensor(0.0) @@ -104,6 +111,18 @@ def ptp(waveforms, dim=1, keepdims=False): return v.numpy() +if TORCH_IS_OLD: + + def _nonzero_static(x: Tensor, size: int): + nz = x.nonzero() + assert nz.shape[0] == size + return nz +else: + + def _nonzero_static(x: Tensor, size: int): + return x.nonzero_static(size=size) + + @torch_compile def mean_elbo_dim1(Q: Tensor, log_liks: Tensor) -> Tensor: logQ = Q.log().nan_to_num_(nan=None, neginf=0.0) @@ -752,7 +771,7 @@ def argrelmax( ): msk = _argrelmax_mask(x=x, radius=radius, threshold=threshold, arange=arange) return msk.nonzero()[:, 0] - + @torch_compile def _argrelmax_mask( @@ -819,7 +838,7 @@ def argrelmax_dedup( padding=padding, ) return msk.nonzero()[:, 0] - + @torch_compile def _argrelmax_dedup_mask( diff --git a/src/dartsort/util/testing_util/matching_debug_util.py b/src/dartsort/util/testing_util/matching_debug_util.py index 1c26140c..7a9432d8 100644 --- a/src/dartsort/util/testing_util/matching_debug_util.py +++ b/src/dartsort/util/testing_util/matching_debug_util.py @@ -86,20 +86,22 @@ def yield_step_results( device = matcher.b.channel_index.device chunk = torch.asarray(chunk, device=device) assert matcher.matching_templates is not None + assert not matcher.whiten_features chunk_data = matcher.matching_templates.data_at_time( t_s, scaling=matcher.is_scaling, inv_lambda=matcher.inv_lambda, scale_min=matcher.amp_scale_min, scale_max=matcher.amp_scale_max, + resid_offset=0,#matcher.whiten_pad, ) - cur_residual = chunk.clone() for it in ( progrange(max_iter, desc="Match steps") if show_progress else range(max_iter) ): - pre_conv = chunk_data.convolve(cur_residual.T, 
padding=matcher.obj_pad_len) + cur_traces_wh = chunk_data.whiten_traces(cur_residual) + pre_conv = chunk_data.convolve(cur_traces_wh, padding=matcher.obj_pad_len) if obj_mode: pre_conv = chunk_data.obj_from_conv( conv=pre_conv, @@ -162,6 +164,9 @@ def visualize_step_results( chunk_vis_style: Literal["im", "trace"] = "im", gt_sorting: DARTsortSorting | None = None, vis_only_last_step: bool = False, + vline_new_peaks=True, + vline_at=None, + objline_at=None, ): import matplotlib.pyplot as plt @@ -205,6 +210,7 @@ def visualize_step_results( if vis_only_last_step: iterator = list(iterator) + resid = None for it, resid, pre_conv, conv, times_samples, labels, channels in iterator: v = np.flatnonzero(times_samples == times_samples.clip(vis_start, vis_end - 1)) times_samples = times_samples[v] - vis_start @@ -253,6 +259,11 @@ def visualize_step_results( ) ax.set_ylabel(name) + if vline_new_peaks: + for ax in axes.flat: + for ts, ll in zip(times_samples, labels): + ax.axvline(ts, c=glasbey1024[ll % len(glasbey1024)], lw=0.5, ls=":") + if gt_t is not None: axes[-3].scatter(gt_t, gt_chan, c=gt_c, s=4 * s, lw=0, marker="o") @@ -273,7 +284,9 @@ def visualize_step_results( ec="w", lw=1, ) + axes[-3].set_ylim([0, chunk.shape[1]]) + vmax = max(np.nanmax(pre_conv[:, obj_sl]), np.nanmax(conv[:, obj_sl])) for j, c in enumerate(pre_conv): axes[-2].plot( obj_domain, c[obj_sl], color=glasbey1024[j % len(glasbey1024)], lw=0.5 @@ -288,16 +301,31 @@ def visualize_step_results( axes[-1].set_ylabel("post-step " + ("obj" if obj_mode else "conv")) for ax in axes[-2:]: ax.grid() - if obj_mode: - vmin = max(-100, pre_conv[:, obj_sl].min(), conv[:, obj_sl].min()) - for ax in axes[-2:]: - ax.set_ylim([vmin, pre_conv[:, obj_sl].max() * 1.05]) + vmin = max(-100, pre_conv[:, obj_sl].min(), conv[:, obj_sl].min()) + for ax in axes[-2:]: + ax.set_ylim([vmin, vmax * 1.05]) panel.suptitle(f"iteration {it}", fontsize=12) + if vline_at is not None: + for ax in axes[:3]: + ax.axvline(vline_at, color="w", 
ls="--", lw=0.8) + for ax in axes[-2:]: + ax.axvline(vline_at, color="k", ls="--", lw=0.8) + if objline_at is not None: + for ax in axes[-2:]: + ax.axhline(objline_at, color="k", ls="--", lw=0.8) + plt.show() plt.close(panel) + return dict( + resid=resid, + times=t_full[:n], + channels=c_full[:n], + labels=l_full[:n], + ) + # -- reference implementation for upsampled matching @@ -362,9 +390,13 @@ def data_at_time( inv_lambda: float, scale_min: float, scale_max: float, + resid_offset: int = 0, ) -> ChunkTemplateData: + assert not resid_offset return DebugChunkTemplateData( spike_length_samples=self.b.templates_up.shape[2], + filter_length_samples=self.b.templates_up.shape[2], + resid_offset=resid_offset, unit_ids=torch.arange( self.b.templates_up.shape[0], device=self.b.pconv.device ), @@ -389,6 +421,8 @@ def data_at_time( @databag class DebugChunkTemplateData(ChunkTemplateData): spike_length_samples: int + filter_length_samples: int + resid_offset: int unit_ids: Tensor main_channels: Tensor obj_normsq: Tensor diff --git a/src/dartsort/vis/analysis_plots.py b/src/dartsort/vis/analysis_plots.py index 77a090ed..bdbeb2ff 100644 --- a/src/dartsort/vis/analysis_plots.py +++ b/src/dartsort/vis/analysis_plots.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import numpy as np +import pandas as pd import scipy.cluster.hierarchy import seaborn as sns import torch @@ -10,7 +11,11 @@ from ..clustering.cluster_util import leafsets from ..util import spikeio -from ..util.data_util import DARTsortSorting, try_get_denoising_pipeline +from ..util.data_util import ( + DARTsortSorting, + get_featurization_pipeline, + try_get_denoising_pipeline, +) from .colors import glasbey1024 @@ -355,7 +360,6 @@ def isi_hist( dt_ms, bin_edges, color=color, label=label, histtype=histtype, alpha=alpha ) axis.set_xlabel("isi (ms)") - axis.set_ylabel(f"count (out of {dt_ms.size} total isis)") def centered_bins(x, dx=1.0): @@ -432,11 +436,36 @@ def stackbar(ax, x, y, colors, labels, fill=True): def 
plot_correlogram( - axis, times_a, times_b=None, max_lag=50, color="k", fill=True, **stairs_kwargs + axis, + times_a, + times_b=None, + max_lag=50, + samples_per_ms: float = 30.0, + color="k", + fill=True, + bin=1, + to_ms=False, + **stairs_kwargs, ): lags, ccg = correlogram(times_a, times_b=times_b, max_lag=max_lag) - axis.set_xlabel("lag (samples)") - return bar(axis, lags, ccg, fill=fill, color=color, **stairs_kwargs) + assert lags.shape == ccg.shape == (2 * max_lag + 1,) + if not to_ms: + axis.set_xlabel("lag (samples)") + return bar(axis, lags, ccg, fill=fill, color=color, **stairs_kwargs) + + max_lag_ms = max_lag / samples_per_ms + ms_lags = np.arange(-max_lag_ms, max_lag_ms + bin / 2, bin) + ms_ccg = np.zeros_like(ms_lags) + ctr = ms_lags.shape[0] // 2 + assert ms_lags.shape == (2 * ctr + 1,) + for j in range(max_lag): + binix = int((j / samples_per_ms) // bin) + ms_ccg[ctr + binix] += ccg[max_lag + j] + if j: + ms_ccg[ctr - binix] += ccg[max_lag - j] + + axis.set_xlabel("lag (ms)") + return bar(axis, ms_lags, ms_ccg, fill=fill, color=color, **stairs_kwargs) def visualize_denoiser( @@ -525,3 +554,135 @@ def visualize_denoiser( if suptitle: fig.suptitle(suptitle) return fig + + +def plot_denoiser_scores( + recording: BaseRecording, + vis_sorting: DARTsortSorting, + load_denoiser_from_sorting: DARTsortSorting | None = None, + count_per_unit: int = 128, + figscale: float = 2.0, + decrease_objective="deconv", + seed: int = 0, + vmax=50.0, + dv=0.5, +): + from ..peel.peel_lib import check_residual_decrease + from ..transform import WaveformWhitener + + # load denoiser + if load_denoiser_from_sorting is None: + load_denoiser_from_sorting = vis_sorting + dn, geom, channel_index = try_get_denoising_pipeline(load_denoiser_from_sorting) + assert dn is not None + assert channel_index is not None + assert geom is not None + + # try load whitener + fp = get_featurization_pipeline(load_denoiser_from_sorting) + whitener = [f for f in fp.transformers if isinstance(f, 
WaveformWhitener)] + assert len(whitener) <= 1 + if len(whitener) == 0: + local_whiteners = None + else: + whitener = whitener[0].whitener + local_whiteners = whitener.local_whiteners(channel_index) + + # choose and load examples + rg = np.random.default_rng(seed) + times_samples = [] + channels = [] + labels = [] + scores_unwhitened = [] + scores_whitened = None if local_whiteners is None else [] + assert vis_sorting.labels is not None + for unit_id in np.unique(vis_sorting.unit_ids): + if unit_id < 0: + continue + + in_unit = np.flatnonzero(vis_sorting.labels == unit_id) + if in_unit.size <= count_per_unit: + choices = in_unit + else: + choices = rg.choice(in_unit, size=count_per_unit, replace=False) + choices.sort() + + tt = vis_sorting.times_samples[choices] + cc = vis_sorting.channels[choices] + x = spikeio.read_waveforms_channel_index( + recording, + times_samples=tt, + main_channels=cc, + channel_index=channel_index.numpy(force=True), + ) + x = torch.asarray(x, dtype=torch.float) + y, _ = dn(x, channels=torch.asarray(cc)) + + _, sc_res_a = check_residual_decrease( + x, + y, + decrease_objective=decrease_objective, + threshold=-1.0, + save_residnorm_decrease=True, + ) + + times_samples.append(tt) + channels.append(cc) + labels.append(vis_sorting.labels[choices]) + scores_unwhitened.append(sc_res_a["residnorm_decreases"]) + + if local_whiteners is None: + continue + assert scores_whitened is not None + + _, sc_res_b = check_residual_decrease( + x, + y, + decrease_objective=decrease_objective, + threshold=-1.0, + save_residnorm_decrease=True, + local_whiteners=local_whiteners, + channels=torch.asarray(cc), + ) + scores_whitened.append(sc_res_b["residnorm_decreases"]) + + data = dict( + time_samples=np.concatenate(times_samples), + channel=np.concatenate(channels), + label=np.concatenate(labels), + score_unwhitened=np.sqrt(np.maximum(0.0, np.concatenate(scores_unwhitened))), + ) + if scores_whitened is not None: + data["score_whitened"] = np.sqrt( + 
np.maximum(0.0, np.concatenate(scores_whitened)) + ) + df = pd.DataFrame(data) + + bins = np.arange(0.0, vmax, step=dv) + ncols = 1 + int(scores_whitened is not None) + fig, axes = plt.subplots( + nrows=1, ncols=ncols, squeeze=False, figsize=(2.0 * figscale * ncols, figscale) + ) + + for unit_id, sub_df in df.groupby("label"): + unit_id = int(unit_id) # type: ignore + axes[0, 0].hist( + sub_df.score_unwhitened, + bins=bins, + histtype="step", + color=glasbey1024[unit_id % len(glasbey1024)], + ) + if scores_whitened is None: + continue + axes[0, 1].hist( + sub_df.score_whitened, + bins=bins, + histtype="step", + color=glasbey1024[unit_id % len(glasbey1024)], + ) + + for ax, name in zip(axes.flat, ["original", "whitened"]): + ax.grid() + ax.set_xlabel(f"{name} score") + + return fig, axes, df diff --git a/src/dartsort/vis/gt.py b/src/dartsort/vis/gt.py index a4bad362..fff4ab10 100644 --- a/src/dartsort/vis/gt.py +++ b/src/dartsort/vis/gt.py @@ -85,9 +85,6 @@ def draw(self, panel, comparison): col_order = col_order[col_order < comparison.template_distances.shape[1]] dist = comparison.template_distances[row_order, :][:, col_order] dist = dist.astype(np.float64) - print(f"{dist.shape=}") - print(f"{dist.min()=}") - print(f"{dist.max()=}") ax = panel.subplots() log1p_norm = FuncNorm((np.log1p, np.expm1), vmin=0) diff --git a/src/dartsort/vis/mixture.py b/src/dartsort/vis/mixture.py index 192f1329..4809c89d 100644 --- a/src/dartsort/vis/mixture.py +++ b/src/dartsort/vis/mixture.py @@ -487,12 +487,12 @@ def draw(self, panel, mix_data: MixtureVisData, unit_id: int): for split, linestyle in zip(["full", "eval"], "-:"): if split == "full": ssco = mix_data.full_scores - in_unit_id = mix_data.full_inunits[unit_id] - in_nid = mix_data.full_inunits[nid] + in_unit_id = mix_data.full_inunits[int(unit_id)] + in_nid = mix_data.full_inunits[int(nid)] elif split == "eval": ssco = mix_data.eval_scores - in_nid = mix_data.eval_inunits.get(nid, empty) - in_unit_id = 
mix_data.eval_inunits.get(unit_id, empty) + in_nid = mix_data.eval_inunits.get(int(nid), empty) + in_unit_id = mix_data.eval_inunits.get(int(unit_id), empty) else: assert False @@ -615,12 +615,13 @@ def compute(self, mix_data: MixtureVisData, unit_id: int): mean_eval_resp=mix_data.mean_eval_resp, train_scores=mix_data.train_scores, eval_scores=mix_data.eval_scores, + train_data=mix_data.train_data, + eval_data=mix_data.val_data, # type: ignore cur_crit=None, ) return group_res def draw(self, panel, mix_data: MixtureVisData, unit_id: int): - print(f"{unit_id=}") demo_res = self.compute(mix_data, unit_id) ax = panel.subplots() @@ -633,7 +634,6 @@ def draw(self, panel, mix_data: MixtureVisData, unit_id: int): else: ds = ",".join([str(uu.item())[:1] for uu in demo_res.demolished.cpu()]) msg = f"units: {us}\n{ims}\ndemo: {ds}" - print(f"{msg=}") ax.text( 0.5, diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index e7c349d5..494ea5d8 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -21,14 +21,7 @@ from ..util.job_util import get_global_computation_config from ..util.multiprocessing_util import CloudpicklePoolExecutor, cloudpickle, get_pool from . 
import layout -from .analysis_plots import ( - bar, - bimod_stats, - centered_bins, - correlogram, - isi_hist, - plot_correlogram, -) +from .analysis_plots import bimod_stats, centered_bins, isi_hist, plot_correlogram from .colors import glasbey1024 from .waveforms import geomplot, geomplot_templates @@ -59,20 +52,20 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): h5_path = sorting_analysis.sorting.parent_h5_path if h5_path: - msg += f"feature source: {h5_path.name}\n" + msg += f"from: {h5_path.name}\n" nspikes = cast(np.ndarray, sorting_analysis.sorting.labels == unit_id).sum() - msg += f"n spikes: {nspikes}\n" + msg += f"count: {nspikes}\n" assert sorting_analysis.template_data is not None temps = sorting_analysis.template_data.unit_templates(unit_id) if not temps.size: - msg += "no template (too few spikes)" + msg += "no template\n(too few spikes)" elif temps.shape[0] == 1: ptp = np.ptp(temps, 1).max(1)[0] - msg += f"maxptp: {ptp:0.2f} su\n" + msg += f"ptp: {ptp:0.2f} su\n" snr = ptp * np.sqrt(nspikes) - msg += f"template snr: {snr:.1f}" + msg += f"snr: {snr:.1f}" else: assert False @@ -86,15 +79,34 @@ class ACG(UnitPlot): kind = "histogram" height = 0.75 - def __init__(self, max_lag=50): + def __init__(self, max_lag=50, bin=1, unit="samples"): super().__init__() self.max_lag = max_lag + self.bin = bin + self.unit = unit + if unit == "ms": + self.width = 2 def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): axis = panel.subplots() which = sorting_analysis.in_unit(unit_id) - times_samples = sorting_analysis.sorting.times_samples[which] - plot_correlogram(axis, times_samples, max_lag=self.max_lag) + t = sorting_analysis.sorting.times_samples[which] + samples_per_ms = sorting_analysis.sorting.sampling_frequency / 1000 + if self.unit == "samples": + max_lag_samples = self.max_lag + elif self.unit == "ms": + max_lag_samples = int(np.ceil(self.max_lag * samples_per_ms)) + else: + assert False + plot_correlogram( + axis, 
+ t, + max_lag=max_lag_samples, + bin=self.bin, + samples_per_ms=samples_per_ms, + to_ms=self.unit == "ms", + ) + axis.grid(which="both") axis.set_ylabel("acg") @@ -129,6 +141,7 @@ def draw( color=color, label=label, ) + axis.grid(which="both") class XZScatter(UnitPlot): @@ -226,41 +239,67 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id, axis=None): class TimeFeatScatter(UnitPlot): - kind = "medium" + kind = "ctimefeat" width = 2 height = 0.75 def __init__( self, feat_name, + color_by_template_if_possible=False, color_by_amplitude=True, amplitude_color_cutoff=15, alpha=1.0, label=None, + cbar=True, ): super().__init__() self.feat_name = feat_name self.amplitude_color_cutoff = amplitude_color_cutoff self.color_by_amplitude = color_by_amplitude + self.color_by_template_if_possible = color_by_template_if_possible self.alpha = alpha self.label = label or feat_name + self.cbar = cbar def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): axis = panel.subplots() assert sorting_analysis.times_seconds is not None assert sorting_analysis.amplitudes is not None + in_unit = sorting_analysis.in_unit(unit_id, at_most=50_000) t = sorting_analysis.times_seconds[in_unit] feat = sorting_analysis.named_feature(self.feat_name, which=in_unit) c = None - if self.color_by_amplitude: + cbar = self.cbar + did_by_template = False + if c is None and self.color_by_template_if_possible: + temp_ix = getattr(sorting_analysis.sorting, "template_inds", None) + if temp_ix is not None: + temp_ix = temp_ix[in_unit] + c = glasbey1024[temp_ix % len(glasbey1024)] + did_by_template = True + cbar = False + if c is None and self.color_by_amplitude: amps = sorting_analysis.amplitudes[in_unit] c = np.minimum(amps, self.amplitude_color_cutoff) s = axis.scatter(t, feat, c=c, lw=0, s=3, alpha=self.alpha, rasterized=True) axis.set_xlabel("time (s)") axis.set_ylabel(self.label) - if self.color_by_amplitude: - plt.colorbar(s, ax=axis, shrink=0.5, label="amp (su)") + 
axis.grid() + axis.set_axisbelow(True) + if cbar and self.color_by_amplitude: + plt.colorbar(s, ax=axis, shrink=0.5, pad=0.01, label="amp (su)") + if did_by_template: + axis.text( + 0.97, + 0.97, + "color: template", + ha="right", + va="top", + transform=axis.transAxes, + fontsize="small", + ) class TimeZScatter(TimeFeatScatter): @@ -274,8 +313,64 @@ def __init__(self, **kwargs): class TimeAmpScatter(TimeFeatScatter): - def __init__(self, **kwargs): - super().__init__(feat_name="amplitudes", label="amp (su)", **kwargs) + def __init__(self, color_by_template_if_possible=True, **kwargs): + super().__init__( + feat_name="amplitudes", + label="amp (su)", + color_by_template_if_possible=color_by_template_if_possible, + **kwargs, + ) + + +class AmplitudeHistogramByDiscreteVariable(UnitPlot): + kind = "camphist" + width = 2 + height = 0.75 + + def __init__(self, var="channels"): + self.var = var + + def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): + axis = panel.subplots() + assert sorting_analysis.amplitudes is not None + z = getattr(sorting_analysis.sorting, self.var, None) + if z is None: + axis.axis("off") + return + in_unit = sorting_analysis.in_unit(unit_id, at_most=50_000) + a = sorting_analysis.amplitudes[in_unit] + z = z[in_unit] + bins = np.arange(np.floor(a.min()), np.ceil(a.max()) + 0.1) + uqz = np.unique(z) + for uz in uqz: + axis.hist( + a[z == uz], + color=glasbey1024[uz % len(glasbey1024)], + histtype="step", + lw=1, + label=uz, + bins=bins, + ) + if uqz.size < 4: + axis.legend(title=self.var, fancybox=False, loc="upper right") + else: + msg = f"{self.var}, {uqz.size} uniques" + if uqz.size < 8: + msg += ":\n" + msg += ",".join(list(map(str, uqz.tolist()))) + axis.text( + 0.97, + 0.97, + msg, + fontsize="small", + ha="right", + va="top", + transform=axis.transAxes, + ) + axis.grid() + axis.set_axisbelow(True) + axis.semilogy() + axis.set_xlabel("amplitude (s.u.)") # -- waveform plots @@ -442,9 +537,9 @@ def draw(self, panel, 
sorting_analysis: DARTsortAnalysis, unit_id, axis=None): shift_str = "shifted " * sorting_analysis.shifting if self.title is None: - axis.set_title(shift_str + self.wfs_kind + rmsg) + axis.set_title(shift_str + self.wfs_kind + rmsg, fontsize="small") else: - axis.set_title(self.title + rmsg) + axis.set_title(self.title + rmsg, fontsize="small") axis.set_xticks([]) axis.set_yticks([]) @@ -468,7 +563,7 @@ def get_waveforms( class TPCAWaveformPlot(WaveformPlot): - wfs_kind = "coll.-cl. tpca wfs" + wfs_kind = "c-c tpca wfs" def get_waveforms( self, sorting_analysis: DARTsortAnalysis, unit_id: int @@ -596,7 +691,7 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id, axis=None): ): tx.set_color(colors[i]) ty.set_color(colors[i]) - axis.set_title(self.title) + axis.set_title(self.title, fontsize="small") class NeighborQDAMatrices(UnitPlot): @@ -658,23 +753,20 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id): ): tx.set_color(colors[i]) ty.set_color(colors[i]) - ax.set_title(title) + ax.set_title(title, fontsize="small") class NeighborCCGPlot(UnitPlot): - kind = "medium" + kind = "bneighborccg" - def __init__(self, n_neighbors=3, max_lag=50, with_merged_acg=False): + def __init__(self, n_neighbors=5, max_lag=50, bin=1, unit="samples"): super().__init__() self.n_neighbors = n_neighbors self.max_lag = max_lag - self.with_merged_acg = with_merged_acg - if self.with_merged_acg: - self.height = 1.0 - self.width = 3.0 - else: - self.height = 1.75 - self.width = 2 + self.height = 2 + self.width = 2 + self.unit = unit + self.bin = bin def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): ( @@ -699,41 +791,44 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): for nid in neighbor_ids ] - if self.with_merged_acg: - axes = panel.subplots( - nrows=1 + self.with_merged_acg, - ncols=len(neighb_sts), - sharey="row", - sharex=True, - squeeze=False, - ) - else: - axes = panel.subplots( - ncols=1 + 
self.with_merged_acg, - nrows=len(neighb_sts), - sharey="row", - sharex=True, - squeeze=False, - ) - axes = axes.T - for j in range(len(neighb_sts)): - clags, ccg = correlogram(my_st, neighb_sts[j], max_lag=self.max_lag) - bar(axes[0, j], clags, ccg, fill=True, fc=colors[j]) # , ec="k", lw=1) - axes[0, j].set_title(f"unit {neighbor_ids[j]}") - - if not self.with_merged_acg: - continue + axes = panel.subplots( + ncols=1, + nrows=len(neighb_sts), + sharey="row", + sharex=True, + squeeze=False, + ) + axes = axes.T - merged_st = np.concatenate((my_st, neighb_sts[j])) - merged_st.sort() - alags, acg = correlogram(merged_st, max_lag=self.max_lag) - bar(axes[1, j], alags, acg, fill=True, fc=colors[j]) # , ec="k", lw=1) - if self.with_merged_acg: - axes[0, 0].set_ylabel("ccg") - axes[1, 0].set_ylabel("merged acg") - axes[-1, len(neighb_sts) // 2].set_xlabel("lag (samples)") + samples_per_ms = sorting_analysis.sorting.sampling_frequency / 1000 + if self.unit == "samples": + max_lag_samples = self.max_lag + elif self.unit == "ms": + max_lag_samples = int(np.ceil(self.max_lag * samples_per_ms)) else: - axes[-1, -1].set_xlabel("ccg") + assert False + + for j in range(len(neighb_sts)): + plot_correlogram( + axes[0, j], + my_st, + neighb_sts[j], + max_lag=max_lag_samples, + bin=self.bin, + samples_per_ms=samples_per_ms, + to_ms=self.unit == "ms", + fc=colors[j], + ) + axes[0, j].grid(which="both") + axes[0, j].text( + 0.97, + 0.97, + f"vs. 
unit {neighbor_ids[j]}", + ha="right", + va="top", + transform=axes[0, j].transAxes, + fontsize="small", + ) class NeighborQDAPlot(UnitPlot): @@ -809,6 +904,7 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): ax.axvline(0, color="k", lw=0.8) ax.grid() + ax.set_axisbelow(True) hstat = kstat = "" if self.kind == "hist": @@ -877,14 +973,17 @@ def default_plots(sorting_analysis=None): p = [ UnitTextInfo(), ACG(), + ACG(max_lag=50.0, bin=0.5, unit="ms"), ISIHistogram(), ISIHistogram(bin_ms=0.25, max_ms=50.0), XZScatter(), + AmplitudeHistogramByDiscreteVariable(), TimeAmpScatter(), RawWaveformPlot(), NearbyCoarseTemplatesPlot(), CoarseTemplateDistancePlot(), NeighborCCGPlot(), + NeighborCCGPlot(max_lag=50.0, bin=0.5, unit="ms"), ] if sorting_analysis is not None and sorting_analysis.has_localizations(): p.extend([TimeZScatter(), TimeRegZScatter()]) @@ -892,6 +991,10 @@ def default_plots(sorting_analysis=None): p.extend([PCAScatter(), TPCAWaveformPlot()]) if sorting_analysis is not None and sorting_analysis.qda is not None: p.extend([NeighborQDAMatrices(), NeighborQDAPlot()]) + if sorting_analysis is not None and hasattr( + sorting_analysis.sorting, "template_inds" + ): + p.extend([AmplitudeHistogramByDiscreteVariable("template_inds")]) return p @@ -920,7 +1023,7 @@ def make_unit_summary( pca_radius_um=75.0, plots=None, max_height=4, - figsize=(16, 8.5), + figsize=(18, 8.5), figure=None, gizmo_name="sorting_analysis", **other_global_params, @@ -954,7 +1057,7 @@ def make_all_summaries( amplitude_color_cutoff=15.0, pca_radius_um=75.0, max_height=4, - figsize=(16, 8.5), + figsize=(18, 8.5), dpi=200, image_ext="png", n_jobs=None, diff --git a/src/dartsort/vis/unit_comparison.py b/src/dartsort/vis/unit_comparison.py index 52051a6a..ef6d7d14 100644 --- a/src/dartsort/vis/unit_comparison.py +++ b/src/dartsort/vis/unit_comparison.py @@ -279,7 +279,7 @@ def __init__( unsorted_missed_color=None, count=50, single_channel=False, - order="tprandom", + 
order="tpfpfn", show_sorted_matches=True, show_unsorted_matches=False, average=False, @@ -328,6 +328,7 @@ def _draw(self, panel, comparison, unit_id, tested_unit_id): waveforms = [] colors = [] + channels = [] maa = 1.0 for kind, color in self.colors.items(): if w[kind] is None or not w[kind].size or not np.isfinite(w[kind]).any(): @@ -337,11 +338,15 @@ def _draw(self, panel, comparison, unit_id, tested_unit_id): maxs = np.nanmax(np.abs(w[kind]), axis=(1, 2)) maxs = np.nan_to_num(maxs, nan=1.0) maa = max(maa, np.percentile(maxs, 90)) + ch = w[f"channels_{kind}"] if self.average: avg = w[kind].mean(0, keepdims=True) + assert np.all(ch == ch[0]), "Chans mismatch in avg" waveforms.append(avg) + channels.append(ch[:1]) else: waveforms.append(w[kind]) + channels.append(ch) colors.append(np.broadcast_to([color], waveforms[-1].shape[:1])) if not len(waveforms): @@ -350,25 +355,26 @@ def _draw(self, panel, comparison, unit_id, tested_unit_id): if self.order == "tprandom" and not self.average and len(waveforms) > 1: wlast = np.concatenate(waveforms[1:]) clast = np.concatenate(colors[1:]) + chanlast = np.concatenate(channels[1:]) n = len(wlast) shuf = np.random.default_rng(0).permutation(n) waveforms = [waveforms[0], wlast[shuf]] colors = [colors[0], clast[shuf]] + channels = [channels[0], chanlast[shuf]] waveforms = np.concatenate(waveforms) colors = np.concatenate(colors) - max_channels = np.broadcast_to([w["max_chan"]], colors.shape) + channels = np.concatenate(channels) + chans = comparison.gt_analysis.vis_channel_index[w["max_chan"]] + chans = chans[chans < comparison.gt_analysis.vis_channel_index.shape[0]] ax = panel.subplots() max_abs_amp = maa - chans = w["channel_index"][w["max_chan"]] - chans = chans[chans < len(w["geom"])] geomplot( waveforms, - max_channels=max_channels, - channel_index=w["channel_index"], geom=w["geom"], ax=ax, + channels=channels, show_zero=False, max_abs_amp=max_abs_amp, annotate_z=True, diff --git a/tests/test_config.py b/tests/test_config.py 
index 6063ee83..116ab769 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,10 +1,11 @@ import dataclasses + import dartsort def test_cfg_consistency(): """Ensure config.py and internal_config.py don't diverge.""" - cfg0 = dartsort.to_internal_config(dartsort.DeveloperConfig()) + cfg0 = dartsort.to_internal_config(dartsort.DeveloperConfig(), 10) cfg1 = dartsort.DARTsortInternalConfig() # can just do assert cfg0 == cfg1, but pytest gives a better diff --git a/tests/test_kmeans.py b/tests/test_kmeans.py index 92b0e5a4..36e01d1e 100644 --- a/tests/test_kmeans.py +++ b/tests/test_kmeans.py @@ -28,12 +28,15 @@ def blobs(): @pytest.mark.parametrize( "algorithm", - [kmeans.kmeans, kmeans.truncated_kmeans], + [kmeans.kmeans, kmeans.truncated_kmeans, kmeans.batched_kmeans], ) def test_kmeans(blobs, algorithm): res = algorithm(blobs["X"], n_components=blobs["K"]) - order = np.lexsort(np.asarray(res["centroids"]).T) - labels = np.argsort(order)[np.asarray(res["labels"])] + centroids = res["centroids"] if isinstance(res, dict) else res.centroids + labels = res["labels"] if isinstance(res, dict) else res.labels - assert np.allclose(res["centroids"][order], blobs["centroids"], atol=0.25) + order = np.lexsort(np.asarray(centroids).T) + labels = np.argsort(order)[np.asarray(labels)] + + assert np.allclose(centroids[order], blobs["centroids"], atol=0.25) assert np.array_equal(labels, blobs["labels"]) diff --git a/tests/test_matching.py b/tests/test_matching.py index 67b4c20b..89b589ad 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -116,6 +116,7 @@ def test_no_crumbs(subtests, refractory_sim, method, cd_iter, channel_selection_ threshold=threshold, cd_iter=cd_iter, channel_selection_radius=channel_selection_radius, + template_svd_compression_min_explained_variance=0.0, ) if method == "upcomp": cfg_kw["template_type"] = "individual_compressed_upsampled" @@ -163,6 +164,7 @@ def test_no_crumbs(subtests, refractory_sim, method, cd_iter, 
channel_selection_ inv_lambda=matcher.inv_lambda, scale_min=matcher.amp_scale_min, scale_max=matcher.amp_scale_max, + resid_offset=0, ) conv = res["conv"].numpy(force=True) diff --git a/tests/test_subtract.py b/tests/test_subtract.py index 4c9d98f6..cb9cdb96 100644 --- a/tests/test_subtract.py +++ b/tests/test_subtract.py @@ -18,7 +18,6 @@ FeaturizationConfig, FitSamplingConfig, SubtractionConfig, - default_pretrained_path ) fixedlenkeys = ( @@ -127,7 +126,6 @@ def test_fakedata_nonn(fakedata, tmp_path): with tempfile.TemporaryDirectory( dir=tmp_path, ignore_cleanup_errors=True ) as tempdir: - print("first one") torch.manual_seed(0) st = subtract( recording=rec, @@ -194,7 +192,7 @@ def test_fakedata_nonn(fakedata, tmp_path): recording=rec, output_dir=tempdir, featurization_cfg=featconf, - sampling_cfg=FitSamplingConfig(n_residual_snips=0), + sampling_cfg=sampconf, subtraction_cfg=subconf, overwrite=True, ) @@ -216,17 +214,17 @@ def test_fakedata_nonn(fakedata, tmp_path): channel_index.shape[1], ) - for ccfg in (two_jobs_cfg, two_jobs_cfg_spawn): + for cname, ccfg in (("", two_jobs_cfg), ("spawn", two_jobs_cfg_spawn)): with tempfile.TemporaryDirectory( dir=tmp_path, ignore_cleanup_errors=True ) as tempdir: - print("parallel first one") + print("parallel first one", cname) torch.manual_seed(0) st = subtract( recording=rec, output_dir=tempdir, featurization_cfg=nolocfeatconf, - sampling_cfg=FitSamplingConfig(n_residual_snips=0), + sampling_cfg=sampconf, subtraction_cfg=subconf, overwrite=True, computation_cfg=ccfg, @@ -284,7 +282,7 @@ def test_fakedata_nonn(fakedata, tmp_path): recording=rec, output_dir=tempdir, featurization_cfg=nolocfeatconf, - sampling_cfg=FitSamplingConfig(n_residual_snips=0), + sampling_cfg=sampconf, subtraction_cfg=subconf, overwrite=True, computation_cfg=ccfg, @@ -314,12 +312,12 @@ def test_resume(fakedata, tmp_path): detection_threshold=20.0, peak_sign="both", subtraction_denoising_cfg=FeaturizationConfig( - do_nn_denoise=False, - 
denoise_only=True + do_nn_denoise=False, denoise_only=True ), first_denoiser_thinning=0.0, first_denoiser_spatial_jitter=0, first_denoiser_temporal_jitter=0, + whiten=False, ) featconf = FeaturizationConfig(skip=True) sampconf = FitSamplingConfig(n_residual_snips=0) @@ -413,6 +411,7 @@ def test_small_nonn(tmp_path, nn_localization): subtraction_denoising_cfg=FeaturizationConfig( do_nn_denoise=False, denoise_only=True ), + whiten=False, ) featconf = FeaturizationConfig(do_nn_denoise=False, nn_localization=nn_localization)