From 0f0700b7422a301901b298c9b15c70d7d9238b07 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sat, 16 May 2026 14:03:04 -0400 Subject: [PATCH 01/46] readme: should be h3 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4d6a5902..6b457156 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ $ pip install dartsort If you want to run the test suite or use `dartsort.vis`, you can install the optional dependencies with `pip install dartsort[test,vis]`. -## Setting up a Python environment +### Setting up a Python environment If you need to set up Python or PyTorch, I find that a [`conda-forge`](https://conda-forge.org/)-based distribution is the most reliable at installing the GPU dependencies which PyTorch needs (note: `conda-forge` is different from the non-free Anaconda). From 19a3e3265182b8fa1be6ae48e3a9c16833d3f381 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 4 Jun 2026 11:33:12 -0400 Subject: [PATCH 02/46] whiten: save covariance, fitting as a transformer --- src/dartsort/peel/matching_util/drifty.py | 5 ++- src/dartsort/peel/reduction_template.py | 5 ++- src/dartsort/templates/templates.py | 3 ++ src/dartsort/transform/pipeline.py | 6 +++ src/dartsort/transform/whiten.py | 50 ++++++++++++++++++++++- src/dartsort/util/internal_config.py | 4 ++ src/dartsort/util/noise_util.py | 45 ++++++++++++-------- 7 files changed, 96 insertions(+), 22 deletions(-) diff --git a/src/dartsort/peel/matching_util/drifty.py b/src/dartsort/peel/matching_util/drifty.py index 2bfaa510..e9bfe265 100644 --- a/src/dartsort/peel/matching_util/drifty.py +++ b/src/dartsort/peel/matching_util/drifty.py @@ -236,7 +236,10 @@ def _from_config( whitener = None else: assert template_data.whitener is not None - whitener = SpatialWhitener.from_numpy(template_data.whitener) + assert template_data.covariance is not None + whitener = SpatialWhitener.from_numpy( + template_data.whitener, template_data.covariance + ) if not wh_none and not matching_cfg.whiten_features: assert matching_cfg.whitening.strategy == "prewhiten_postapply" diff --git a/src/dartsort/peel/reduction_template.py b/src/dartsort/peel/reduction_template.py index 14d790ff..b14bc488 100644 --- a/src/dartsort/peel/reduction_template.py +++ b/src/dartsort/peel/reduction_template.py @@ -155,9 +155,9 @@ def _from_config( templates *= msk if whitener is None: - whitener_np = None + whitener_np = covariance_np = None else: - whitener_np = whitener.to_numpy() + whitener_np, covariance_np = whitener.to_numpy() return TemplateData( unit_ids=unit_ids, @@ -169,6 +169,7 @@ def _from_config( trough_offset_samples=trough, tsvd=p.temporal_svd(), whitener=whitener_np, + covariance=covariance_np, sampling_frequency=recording.sampling_frequency, whiten_strategy=template_cfg.whitening.strategy, ) diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 1e60d9f5..58e1e057 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -51,6 +51,7 @@ class TemplateData: properties: dict[str, np.ndarray] | None = None tsvd: TruncatedSVD | PCA | None = None whitener: np.ndarray | None = None + covariance: np.ndarray | None = None whiten_strategy: WhiteningStrategy = "none" featurization_basis: np.ndarray | None = None @@ -160,6 +161,8 @@ def to_npz(self, npz_path): to_save["raw_std_dev"] = self.raw_std_dev if self.whitener is not None: to_save["whitener"] = self.whitener + if self.covariance is not None: + to_save["covariance"] = self.covariance if self.featurization_basis is not None: to_save["featurization_basis"] = self.featurization_basis if not npz_path.parent.exists(): diff --git a/src/dartsort/transform/pipeline.py b/src/dartsort/transform/pipeline.py index ed9a2ac5..dd91585b 100644 --- a/src/dartsort/transform/pipeline.py +++ b/src/dartsort/transform/pipeline.py @@ -494,6 +494,12 @@ def featurization_config_to_class_names_and_kwargs( ) ) + if fc.fit_disabled_whitener: + assert fc.whiten_cfg is not None + class_names_and_kwargs.append( + ("WaveformWhitener", {"disabled": True, "whiten_cfg": fc.whiten_cfg}) + ) + # logic for picking an efficient combo of tpcas and nn denoisers class_names_and_kwargs.extend( _add_tpca_and_nn(featurization_cfg, waveform_cfg, sampling_frequency) diff --git a/src/dartsort/transform/whiten.py b/src/dartsort/transform/whiten.py index c3255f9b..b1c9fd6b 100644 --- a/src/dartsort/transform/whiten.py +++ b/src/dartsort/transform/whiten.py @@ -1,13 +1,21 @@ +from pathlib import Path from typing import TYPE_CHECKING +import torch +from spikeinterface.core import BaseRecording + +from ..util.data_util import DARTsortSorting +from ..util.internal_config import ComputationConfig, WhiteningConfig from .transform_base import BaseWaveformDenoiser if TYPE_CHECKING: from ..util.noise_util import SpatialWhitener + from .pipeline import WaveformPipeline class WaveformWhitener(BaseWaveformDenoiser): default_name = "whiten" + needs_residual = True def __init__( self, @@ -17,6 +25,8 @@ def __init__( name=None, name_prefix=None, whitener: "SpatialWhitener | None" = None, + disabled: bool = True, + whiten_cfg: WhiteningConfig = WhiteningConfig(), ): super().__init__( name=name, name_prefix=name_prefix, geom=geom, channel_index=channel_index @@ -25,10 +35,48 @@ def __init__( "Meant to be used with full-probe data." ) self.whitener = whitener + self.disabled = disabled + self.whiten_cfg = whiten_cfg + self.motion = None + + @property + def needs_fit(self): + return self.whitener is None def forward(self, waveforms, **unused): del unused - if self.whitener is None: + if self.disabled or self.whitener is None: return waveforms else: return self.whitener.whiten(x=waveforms) + + def attach_motion(self, motion): + self.motion = motion + + def fit( + self, + recording: BaseRecording, + waveforms: torch.Tensor, + *, + hdf5_filename: Path | None = None, + computation_cfg: ComputationConfig, + pipeline: "WaveformPipeline | None" = None, + **spike_data: torch.Tensor, + ): + del recording, spike_data, waveforms, pipeline + from ..util.noise_util import SpatialWhitener + + assert hdf5_filename is not None + + if self.motion is None: + assert self.whiten_cfg.strategy != "postwhiten" + + sorting = DARTsortSorting.from_peeling_hdf5( + hdf5_filename, load_simple_features=False + ) + self.whitener = SpatialWhitener.from_config( + sorting=sorting, + motion=self.motion, + whiten_cfg=self.whiten_cfg, + computation_cfg=computation_cfg, + ) diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index 0293c760..ddda3291 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -629,6 +629,10 @@ class FeaturizationConfig: gmm_refinement_cfg: RefinementConfig | None = None gmm_clustering_features_cfg: ClusteringFeaturesConfig | None = None + # helper for fitting whiteners + fit_disabled_whitener: bool = False + whiten_cfg: WhiteningConfig | None = None + # used when naming datasets saved to h5 files input_waveforms_name: str = "collisioncleaned" output_waveforms_name: str = "denoised" diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index 49a793e4..95100744 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -979,7 +979,9 @@ def estimate( cov = torch.cov(x_spatial.T.double()) cov = spiketorch.enforce_posdef(cov, eps=eps) else: - cov = spiketorch.nancov(x_spatial[:, valid].double(), force_posdef=True, eps=eps) + cov = spiketorch.nancov( + x_spatial[:, valid].double(), force_posdef=True, eps=eps + ) assert torch.is_tensor(cov) if shrinkage: cov = F.softshrink(cov, shrinkage) @@ -1449,10 +1451,9 @@ def fp_control_threshold_from_h5( def residual_covariance( sorting: DARTsortSorting, do_interpolation: bool, - motion: MotionInfo, + motion: MotionInfo | None, interp_params: InterpolationParams = tps_interp_clampna_extrap_params, device: torch.device | None = None, - rgeom=None, residual_times_s_dataset_name="residual_times_seconds", residual_dataset_name="residual", seed: int = 0, @@ -1461,14 +1462,7 @@ def residual_covariance( assert sorting.parent_h5_path is not None if do_interpolation: - with h5py.File(sorting.parent_h5_path, "r", locking=False) as h5: - geom = cast(h5py.Dataset, h5["geom"])[:].astype(np.float32) - if rgeom is None: - if motion is None: - rgeom = geom - else: - rgeom = motion.rgeom - rgeom = rgeom.astype(np.float32) + assert motion is not None snipgen = generate_interpolated_residual_snippets( motion=motion, hdf5_path=sorting.parent_h5_path, @@ -1559,24 +1553,27 @@ def sparsechol_whitener( class SpatialWhitener(BModule): - def __init__(self, whitener: Tensor): + def __init__(self, whitener: Tensor, covariance: Tensor): super().__init__() self.register_buffer("whitener", whitener) + self.register_buffer("covariance", covariance) @classmethod - def from_numpy(cls, whitener: np.ndarray): + def from_numpy(cls, whitener: np.ndarray, covariance: np.ndarray): logger.dartsortverbose("Load whitener from numpy.") - return cls(whitener=torch.asarray(whitener)) + return cls( + whitener=torch.asarray(whitener), covariance=torch.asarray(covariance) + ) - def to_numpy(self) -> np.ndarray: - return self.b.whitener.numpy(force=True) + def to_numpy(self) -> tuple[np.ndarray, np.ndarray]: + return self.b.whitener.numpy(force=True), self.b.covariance.numpy(force=True) @classmethod def from_config( cls, *, sorting: DARTsortSorting, - motion: MotionInfo, + motion: MotionInfo | None, whiten_cfg: WhiteningConfig, computation_cfg: ComputationConfig | None = None, ) -> Self: @@ -1599,7 +1596,7 @@ def from_config( cov_np, channel_index=neighbs ) whitener = torch.asarray(whitener).to(cov) - return cls(whitener=whitener) + return cls(whitener=whitener, covariance=cov) def whiten_traces_spatial_major( self, x: Tensor, out: Tensor | None = None @@ -1626,3 +1623,15 @@ def prec_mul(self, x: Tensor) -> Tensor: x = x @ (self.b.whitener.T @ self.b.whitener) x = x.reshape(*shp, c) return x + + def local_whiteners(self, channel_index: Tensor, eps=1e-6): + nc, cloc = channel_index.shape + assert nc == self.b.covariance.shape[0] + w = self.b.covariance.new_zeros((nc, cloc, cloc)) + for j, chans in enumerate(channel_index): + mask = (chans < nc).nonzero()[:, 0] + chans = chans[mask] + cj = self.b.covariance[chans][:, chans] + wj = fullzca_whitener(cj.numpy(force=True).astype(np.float64), eps=eps) + w[j, mask[:, None], mask[None, :]] = torch.asarray(wj).to(w) + return w From 7d426786645ce081fba472ce041f90f6afdf2051 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 4 Jun 2026 11:50:28 -0400 Subject: [PATCH 03/46] whiten: hook into subtraction --- src/dartsort/config.py | 1 + src/dartsort/main.py | 3 +- src/dartsort/peel/peel_base.py | 4 +++ src/dartsort/peel/peel_lib.py | 11 ++++++ src/dartsort/peel/subtract.py | 32 ++++++++++++++++- src/dartsort/util/internal_config.py | 53 ++++++++++++++++++---------- tests/test_config.py | 1 + 7 files changed, 84 insertions(+), 21 deletions(-) diff --git a/src/dartsort/config.py b/src/dartsort/config.py index 9abac469..b4741424 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -254,6 +254,7 @@ class DeveloperConfig(DARTsortUserConfig): first_denoiser_spatial_dedup_radius: float = 100.0 realign_to_denoiser: bool = True use_nn_in_subtraction: bool = True + whiten_in_subtraction: bool = False # matching matching_template_type: Literal["individual_compressed_upsampled", "drifty"] = ( diff --git a/src/dartsort/main.py b/src/dartsort/main.py index 01cc50b3..92b46704 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -118,8 +118,7 @@ def dartsort( - "sorting": `DARTsortSorting` - "motion": MotionInfo """ - output_dir = ensure_path(output_dir) - output_dir.mkdir(exist_ok=True) + output_dir = ensure_path(output_dir, mkdir=True) # convert cfg to internal format and store it for posterity cfg = to_internal_config(cfg) diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index 7f3255bf..b0e4cf8a 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -154,6 +154,7 @@ def load_or_fit_and_save_models( cleanup_and_log_gpu_usage(computation_cfg, "peel: Usage after model fits:") assert not self.needs_precompute() assert not self.needs_fit() + self.post_fit() def peel( self, @@ -431,6 +432,9 @@ def peeling_needs_fit(self) -> bool: def peeling_needs_precompute(self) -> bool: return False + def post_fit(self): + pass + def precompute_peeling_data( self, save_folder, overwrite=False, computation_cfg=None ): diff --git a/src/dartsort/peel/peel_lib.py b/src/dartsort/peel/peel_lib.py index ca4e5793..78ab28f7 100644 --- a/src/dartsort/peel/peel_lib.py +++ b/src/dartsort/peel/peel_lib.py @@ -80,10 +80,18 @@ def check_residual_decrease( threshold=10.0, save_residnorm_decrease=False, overwrite_orig_waveforms: bool = False, + local_whiteners: Tensor | None = None, + channels: Tensor | None = None, ): if not threshold: return None, {} + if local_whiteners is not None: + assert channels is not None + W = local_whiteners[channels].mT + orig_wfs = orig_wfs.bmm(W) + dn_wfs = dn_wfs.bmm(W) + if decrease_objective == "deconv": if overwrite_orig_waveforms: buf = orig_wfs.mul_(dn_wfs).nan_to_num_() @@ -154,6 +162,7 @@ def subtract_chunk( dedup_rel_inds=None, residnorm_decrease_threshold=16.0, decrease_objective: Literal["norm", "normsq", "deconv"] = "deconv", + local_whiteners: Tensor | None = None, relative_peak_radius=5, dedup_temporal_radius=7, remove_exact_duplicates=True, @@ -353,6 +362,8 @@ def subtract_chunk( decrease_objective=decrease_objective, threshold=residnorm_decrease_threshold, save_residnorm_decrease=save_residnorm_decrease, + local_whiteners=local_whiteners, + channels=channels, ) features.update(new_feats) if resid_keep is not None: diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py index 44fa3279..180d1027 100644 --- a/src/dartsort/peel/subtract.py +++ b/src/dartsort/peel/subtract.py @@ -9,7 +9,7 @@ import torch from spikeinterface.core import BaseRecording -from ..transform import Voltage, Waveform, WaveformPipeline +from ..transform import Voltage, Waveform, WaveformPipeline, WaveformWhitener from ..util import job_util from ..util.data_util import SpikeDataset, subsample_waveforms from ..util.internal_config import ( @@ -125,6 +125,9 @@ def __init__( can_thin = recording.get_total_duration() > fit_sampling_cfg.n_seconds_fit / _p self.first_denoiser_thinning = p.first_denoiser_thinning if can_thin else 0.0 + # this may be overwritten after featurization fit + self.register_buffer_or_none("local_whiteners", None) + def out_datasets(self): datasets = super().out_datasets() @@ -148,6 +151,23 @@ def peeling_needs_fit(self): def peeling_needs_precompute(self): return self.subtraction_denoising_pipeline.needs_precompute() + def post_fit(self): + if not self.p.whiten: + return + assert self.featurization_pipeline is not None + assert not self.featurization_pipeline.needs_fit() + whitener = [ + f + for f in self.featurization_pipeline.transformers + if isinstance(f, WaveformWhitener) + ] + assert len(whitener) == 1 + whitener = whitener[0].whitener + assert whitener is not None + local_whiteners = whitener.local_whiteners(self.b.subtract_channel_index) + self.del_none_buffer("local_whiteners") + self.register_buffer("local_whiteners", local_whiteners) + def save_models(self, save_folder): super().save_models(save_folder) sub_denoise_pt = Path(save_folder) / "subtraction_denoising_pipeline.pt" @@ -179,6 +199,15 @@ def from_config( geom, subtraction_cfg.subtract_radius_um, to_torch=True ) + # handle whitener fitting + if subtraction_cfg.whiten: + assert subtraction_cfg.whiten_cfg is not None + featurization_cfg = replace( + featurization_cfg, + fit_disabled_whitener=True, + whiten_cfg=subtraction_cfg.whiten_cfg, + ) + # construct denoising and featurization pipelines subtraction_denoising_pipeline = WaveformPipeline.from_config( geom=geom, @@ -261,6 +290,7 @@ def peel_chunk( denoiser_realignment_shift=self.p.denoiser_realignment_shift, denoiser_realignment_channel=self.p.denoiser_realignment_channel, compute_collidedness=self.save_collidedness, + local_whiteners=self.b.local_whiteners, ) # add in chunk_start_samples diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index ddda3291..47db4e3e 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -129,6 +129,7 @@ def pad(self, padding_ms: float) -> Self: @cfg_dataclass class InterpolationParams: """Spatial waveform or feature interpolation parameters""" + method: InterpMethod = "kriging" kernel: InterpKernel = "thinplate" extrap_method: InterpMethod | None = None @@ -213,6 +214,7 @@ def normalize(self) -> Self: @cfg_dataclass class FitSamplingConfig: """Data sampling parameters for model fitting""" + max_waveforms_fit: int = 50_000 n_waveforms_fit: int = 40_000 more_waveforms_fit: int = 2000 * 1024 @@ -235,6 +237,7 @@ class FitSamplingConfig: @cfg_dataclass class ClusteringFeaturesConfig: """Parameters to control which features are used for initial clustering""" + # simple matrix feature controls use_x: bool = True use_z: bool = True @@ -269,6 +272,7 @@ class ClusteringFeaturesConfig: @cfg_dataclass class ClusteringConfig: """Initial clustering parameters""" + cluster_strategy: str = "dpc" sampling_cfg: FitSamplingConfig = default_clustering_fit_sampling_cfg @@ -318,6 +322,7 @@ class ClusteringConfig: @cfg_dataclass class WhiteningConfig: """Whitening parameters""" + strategy: WhiteningStrategy = "none" estimator: WhiteningEstimator = "localzca" interp_params: InterpolationParams = tps_interp_clampna_extrap_params @@ -332,6 +337,7 @@ class WhiteningConfig: @cfg_dataclass class TemplateConfig: """Template waveform estimation parameters""" + spikes_per_unit: int = 500 with_raw_std_dev: bool = False reduction: Literal["median", "mean"] = "median" @@ -410,6 +416,7 @@ def actual_algorithm(self) -> str: @cfg_dataclass class TemplateRealignmentConfig: """Template waveform alignment parameters""" + realign_peaks: bool = True realign_strategy: RealignStrategy = "snr_weighted_trough_factor" realign_shift_ms: float = 1.5 @@ -421,6 +428,7 @@ class TemplateRealignmentConfig: @cfg_dataclass class TemplateMergeConfig: """Parameters describing how to judge whether to merge groups of templates""" + distance_kind: Literal[ "scaled_normeuc", "deconv", "max", "weighted_scaled_normeuc" ] = "weighted_scaled_normeuc" @@ -459,6 +467,7 @@ def to_template_config(self, template_cfg: TemplateConfig | None = None): @cfg_dataclass class RefinementConfig: """Parameters for clustering refinement""" + refinement_strategy: str = "tmm" sampling_cfg: FitSamplingConfig = default_clustering_fit_sampling_cfg @@ -644,6 +653,7 @@ class FeaturizationConfig: @cfg_dataclass class SubtractionConfig: """Parameters for neural-net based spike detection""" + # peeling common chunk_length_samples: int = 30_000 fit_only: bool = False @@ -669,6 +679,8 @@ class SubtractionConfig: convexity_threshold: float | None = None convexity_radius: int = 7 max_iter: int = 100 + whiten: bool = False + whiten_cfg: WhiteningConfig | None = None # how will waveforms be denoised before subtraction? # users can also save waveforms/features during subtraction @@ -698,6 +710,7 @@ class SubtractionConfig: @cfg_dataclass class ThresholdingConfig: """Parameters for threshold-crossing spike detection""" + # peeling common chunk_length_samples: int = 30_000 @@ -724,6 +737,7 @@ class ThresholdingConfig: @cfg_dataclass class MatchingConfig: """Template matching pursuit parameters""" + # peeling common chunk_length_samples: int = 30_000 max_spikes_per_second: int = 16384 @@ -818,6 +832,7 @@ class MotionEstimationConfig: @cfg_dataclass class ComputationConfig: """Multiprocessing or threading parameters""" + n_jobs_cpu: int = 0 n_jobs_gpu: int = 0 n_jobs_small: int = -2 @@ -993,6 +1008,24 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: save_input_waveforms=cfg.save_collisioncleaned_waveforms, save_collidedness=save_collidedness, ) + if cfg.template_interp_kind == "tps": + temp_interp_params = tps_interp_clampna_extrap_params + elif cfg.template_interp_kind == "clampna": + temp_interp_params = clampna_interp_params + else: + assert False + if cfg.matching_interp_kind == "tps": + match_interp_params = tps_interp_clampna_extrap_params + elif cfg.matching_interp_kind == "clampna": + match_interp_params = clampna_interp_params + else: + assert False + whiten_cfg = WhiteningConfig( + strategy=cfg.whiten_strategy, + estimator=cfg.whiten_estimator, + radius=cfg.subtraction_radius_um, + interp_params=temp_interp_params, + ) if cfg.dredge_only: n_residual_snips = 0 else: @@ -1032,6 +1065,8 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: first_denoiser_noise_snips=cfg.nn_denoiser_noise_waveforms, first_denoiser_spatial_dedup_radius=cfg.first_denoiser_spatial_dedup_radius, subtraction_denoising_cfg=subtraction_denoising_cfg, + whiten=cfg.whiten_in_subtraction, + whiten_cfg=whiten_cfg, ) elif cfg.detection_type == "threshold": initial_detection_cfg = ThresholdingConfig( @@ -1060,24 +1095,6 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: else: raise ValueError(f"Unknown detection_type {cfg.detection_type}.") - if cfg.template_interp_kind == "tps": - temp_interp_params = tps_interp_clampna_extrap_params - elif cfg.template_interp_kind == "clampna": - temp_interp_params = clampna_interp_params - else: - assert False - if cfg.matching_interp_kind == "tps": - match_interp_params = tps_interp_clampna_extrap_params - elif cfg.matching_interp_kind == "clampna": - match_interp_params = clampna_interp_params - else: - assert False - whiten_cfg = WhiteningConfig( - strategy=cfg.whiten_strategy, - estimator=cfg.whiten_estimator, - radius=cfg.subtraction_radius_um, - interp_params=temp_interp_params, - ) template_cfg = TemplateConfig( denoising_fit_radius=cfg.fit_radius_um, spikes_per_unit=cfg.template_spikes_per_unit, diff --git a/tests/test_config.py b/tests/test_config.py index 6063ee83..4e66d68e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,5 @@ import dataclasses + import dartsort From 63c6fe94f0056b48624207f08c763ea6b8f09e80 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 4 Jun 2026 15:37:10 -0400 Subject: [PATCH 04/46] whiten: debug, handle border nans --- src/dartsort/peel/peel_lib.py | 8 +++++++- src/dartsort/peel/subtract.py | 2 +- src/dartsort/transform/whiten.py | 28 +++++++++++++++++++++++----- src/dartsort/util/noise_util.py | 6 ++++++ 4 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/dartsort/peel/peel_lib.py b/src/dartsort/peel/peel_lib.py index 78ab28f7..cb0d65a4 100644 --- a/src/dartsort/peel/peel_lib.py +++ b/src/dartsort/peel/peel_lib.py @@ -89,8 +89,14 @@ def check_residual_decrease( if local_whiteners is not None: assert channels is not None W = local_whiteners[channels].mT + if overwrite_orig_waveforms: + orig_wfs = orig_wfs.nan_to_num_() + else: + orig_wfs = orig_wfs.nan_to_num() + buf = orig_wfs orig_wfs = orig_wfs.bmm(W) - dn_wfs = dn_wfs.bmm(W) + dn_wfs = dn_wfs.nan_to_num() + dn_wfs = torch.bmm(dn_wfs, W, out=buf) if decrease_objective == "deconv": if overwrite_orig_waveforms: diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py index 180d1027..ba413715 100644 --- a/src/dartsort/peel/subtract.py +++ b/src/dartsort/peel/subtract.py @@ -164,7 +164,7 @@ def post_fit(self): assert len(whitener) == 1 whitener = whitener[0].whitener assert whitener is not None - local_whiteners = whitener.local_whiteners(self.b.subtract_channel_index) + local_whiteners = whitener.local_whiteners(self.b.sub_channel_index) self.del_none_buffer("local_whiteners") self.register_buffer("local_whiteners", local_whiteners) diff --git a/src/dartsort/transform/whiten.py b/src/dartsort/transform/whiten.py index b1c9fd6b..a590f1d7 100644 --- a/src/dartsort/transform/whiten.py +++ b/src/dartsort/transform/whiten.py @@ -5,7 +5,12 @@ from spikeinterface.core import BaseRecording from ..util.data_util import DARTsortSorting -from ..util.internal_config import ComputationConfig, WhiteningConfig +from ..util.internal_config import ( + ComputationConfig, + WaveformConfig, + WhiteningConfig, + default_waveform_cfg, +) from .transform_base import BaseWaveformDenoiser if TYPE_CHECKING: @@ -24,22 +29,20 @@ def __init__( channel_index, name=None, name_prefix=None, + waveform_cfg: WaveformConfig | None = default_waveform_cfg, whitener: "SpatialWhitener | None" = None, disabled: bool = True, whiten_cfg: WhiteningConfig = WhiteningConfig(), + sampling_frequency: float = 30_000.0, ): super().__init__( name=name, name_prefix=name_prefix, geom=geom, channel_index=channel_index ) - assert channel_index.shape[1] == geom.shape[0], ( - "Meant to be used with full-probe data." - ) self.whitener = whitener self.disabled = disabled self.whiten_cfg = whiten_cfg self.motion = None - @property def needs_fit(self): return self.whitener is None @@ -53,6 +56,13 @@ def forward(self, waveforms, **unused): def attach_motion(self, motion): self.motion = motion + def _other_pre_load_state(self, state_dict, prefix): + if self.whitener is not None: + return + from ..util.noise_util import SpatialWhitener + + self.whitener = SpatialWhitener.blank(len(self.b.geom), self.b.geom.device) + def fit( self, recording: BaseRecording, @@ -63,6 +73,14 @@ def fit( pipeline: "WaveformPipeline | None" = None, **spike_data: torch.Tensor, ): + super().fit( + recording, + waveforms, + hdf5_filename=hdf5_filename, + computation_cfg=computation_cfg, + pipeline=pipeline, + **spike_data, + ) del recording, spike_data, waveforms, pipeline from ..util.noise_util import SpatialWhitener diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index 95100744..9c766736 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -1558,6 +1558,11 @@ def __init__(self, whitener: Tensor, covariance: Tensor): self.register_buffer("whitener", whitener) self.register_buffer("covariance", covariance) + @classmethod + def blank(cls, n_channels: int, device: torch.device): + w = torch.zeros((n_channels, n_channels), device=device) + return cls(w, torch.zeros_like(w)) + @classmethod def from_numpy(cls, whitener: np.ndarray, covariance: np.ndarray): logger.dartsortverbose("Load whitener from numpy.") @@ -1625,6 +1630,7 @@ def prec_mul(self, x: Tensor) -> Tensor: return x def local_whiteners(self, channel_index: Tensor, eps=1e-6): + channel_index = torch.asarray(channel_index) nc, cloc = channel_index.shape assert nc == self.b.covariance.shape[0] w = self.b.covariance.new_zeros((nc, cloc, cloc)) From 4b25b95d4798e72792536fd9681bd39dac65a343 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 8 Jun 2026 09:58:52 -0400 Subject: [PATCH 05/46] clean and add some analysis vis --- src/dartsort/config.py | 6 +- src/dartsort/evaluate/analysis.py | 4 +- src/dartsort/peel/peel_lib.py | 60 ++++++------ src/dartsort/peel/subtract.py | 15 ++- src/dartsort/util/data_util.py | 9 ++ src/dartsort/util/internal_config.py | 7 +- src/dartsort/vis/analysis_plots.py | 137 ++++++++++++++++++++++++++- 7 files changed, 196 insertions(+), 42 deletions(-) diff --git a/src/dartsort/config.py b/src/dartsort/config.py index b4741424..b6630a88 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -126,7 +126,6 @@ class DARTsortUserConfig: deduplication_ms: Annotated[float, Field(gt=0)] = 0.5 """As a final postprocessing step, only the higher-scoring of any spikes within this time radius of each other are kept. - If this is negative, it does nothing. If it's 0, exact duplicates are dropped. """ @@ -146,7 +145,7 @@ class DARTsortUserConfig: threshold in Kilosort and other sorters, and it represents reduction in Euclidean norm of standardized data due to matching a new event.""" - initial_threshold: Annotated[float, Field(gt=0)] = 10.0 + initial_threshold: Annotated[float, Field(gt=0)] = 9.0 """Initial detection's neural net matching threshold. Same as matching_threshold, except that a neural net is trying to guess the true waveforms here, rather than using cluster templates.""" @@ -254,7 +253,8 @@ class DeveloperConfig(DARTsortUserConfig): first_denoiser_spatial_dedup_radius: float = 100.0 realign_to_denoiser: bool = True use_nn_in_subtraction: bool = True - whiten_in_subtraction: bool = False + whiten_in_subtraction: bool = True + threshold_before_whitening: float = 9.0 # matching matching_template_type: Literal["individual_compressed_upsampled", "drifty"] = ( diff --git a/src/dartsort/evaluate/analysis.py b/src/dartsort/evaluate/analysis.py index bf36a6c7..c12497e3 100644 --- a/src/dartsort/evaluate/analysis.py +++ b/src/dartsort/evaluate/analysis.py @@ -641,7 +641,9 @@ def nearby_coarse_templates(self, unit_id, n_neighbors=5): assert td is not None assert self.merge_distances is not None - unit_ix = np.searchsorted(td.unit_ids, unit_id) + unit_ix = np.flatnonzero(td.unit_ids == unit_id) + assert unit_ix.shape[0] == 1 + unit_ix = unit_ix[0] unit_dists = self.merge_distances[unit_ix] distance_order = np.argsort(unit_dists) distance_order = np.concatenate( diff --git a/src/dartsort/peel/peel_lib.py b/src/dartsort/peel/peel_lib.py index cb0d65a4..863cc10a 100644 --- a/src/dartsort/peel/peel_lib.py +++ b/src/dartsort/peel/peel_lib.py @@ -25,19 +25,20 @@ if TYPE_CHECKING: from ..transform.pipeline import WaveformPipeline + from ..util.internal_config import PeakSign def denoiser_time_shifts( - waveforms, - channels, - voltages, - subtract_rel_inds, - trough_offset_samples, - spike_length_samples, - peak_sign, - denoiser_realignment_shift, - denoiser_realignment_channel, -): + waveforms: Tensor, + channels: Tensor, + voltages: Tensor, + subtract_rel_inds: Tensor | None, + trough_offset_samples: int, + spike_length_samples: int, + peak_sign: "PeakSign", + denoiser_realignment_shift: int, + denoiser_realignment_channel: Literal["detection", "denoised"], +) -> Tensor: # extract main channel traces if denoiser_realignment_channel == "detection": assert subtract_rel_inds is not None @@ -74,17 +75,18 @@ def denoiser_time_shifts( def check_residual_decrease( - orig_wfs, - dn_wfs, - decrease_objective="deconv", + orig_wfs: Tensor | None, + dn_wfs: Tensor, + decrease_objective: Literal["deconv", "norm", "normsq"] = "deconv", threshold=10.0, save_residnorm_decrease=False, overwrite_orig_waveforms: bool = False, local_whiteners: Tensor | None = None, channels: Tensor | None = None, -): +) -> tuple[Tensor | None, dict[str, Tensor]]: if not threshold: return None, {} + assert orig_wfs is not None if local_whiteners is not None: assert channels is not None @@ -106,7 +108,7 @@ def check_residual_decrease( norm = buf.nan_to_num_().sum(dim=(1, 2)) else: dn_wfs = dn_wfs.nan_to_num() - conv = (orig_wfs * dn_wfs).sum(dim=(1, 2)) + conv = (orig_wfs * dn_wfs).nan_to_num_().sum(dim=(1, 2)) norm = dn_wfs.square_().sum(dim=(1, 2)) reduction = conv.mul_(2.0).sub_(norm) threshold = threshold**2 @@ -146,11 +148,11 @@ def check_residual_decrease( def subtract_chunk( - traces, - channel_index, - denoising_pipeline, - extract_index=None, - extract_mask=None, + traces: Tensor, + channel_index: Tensor, + denoising_pipeline: WaveformPipeline, + extract_index: Tensor | None = None, + extract_mask: Tensor | None = None, trough_offset_samples=42, spike_length_samples=121, left_margin=0, @@ -162,10 +164,10 @@ def subtract_chunk( denoiser_realignment_channel="detection", convexity_threshold=None, convexity_radius=3, - peak_channel_index=None, - dedup_channel_index=None, - subtract_rel_inds=None, - dedup_rel_inds=None, + peak_channel_index: Tensor | None = None, + dedup_channel_index: Tensor | None = None, + subtract_rel_inds: Tensor | None = None, + dedup_rel_inds: Tensor | None = None, residnorm_decrease_threshold=16.0, decrease_objective: Literal["norm", "normsq", "deconv"] = "deconv", local_whiteners: Tensor | None = None, @@ -176,13 +178,13 @@ def subtract_chunk( dedup_batch_size=512, no_subtraction=False, max_iter=100, - trough_priority=None, - growth_tolerance=None, + trough_priority: float | None = None, + growth_tolerance: float | None = None, cumulant_order=None, save_iteration=False, save_residnorm_decrease=False, compute_collidedness=False, -): +) -> ChunkSubtractionResult: """Core peeling routine for subtraction""" if no_subtraction: threshold_res = threshold_chunk( @@ -238,7 +240,7 @@ def subtract_chunk( if growth_tolerance is not None: gtol = traces.abs().add_(growth_tolerance) else: - gtol = 0.0 + gtol = None # initialize residual, it needs to be padded to support # our channel indexing convention. this copies the input. @@ -262,7 +264,7 @@ def subtract_chunk( for it in range(max_iter): residual_det = residual[:, :-1] - if it and growth_tolerance is not None: + if it and gtol is not None: residual_det = residual_det.clamp(-gtol, gtol) times_samples, channels = detect_and_deduplicate( diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py index ba413715..ea9f0cba 100644 --- a/src/dartsort/peel/subtract.py +++ b/src/dartsort/peel/subtract.py @@ -72,6 +72,10 @@ def __init__( self.save_residnorm_decrease = save_residnorm_decrease self.save_collidedness = save_collidedness self.dedup_batch_size = self.nearest_batch_length() + if self.p.whiten: + self.threshold = self.p.threshold_before_whitening + else: + self.threshold = self.p.residnorm_decrease_threshold geom = recording.get_channel_locations() sub_channel_index = make_channel_index( @@ -167,6 +171,7 @@ def post_fit(self): local_whiteners = whitener.local_whiteners(self.b.sub_channel_index) self.del_none_buffer("local_whiteners") self.register_buffer("local_whiteners", local_whiteners) + self.threshold = self.p.residnorm_decrease_threshold def save_models(self, save_folder): super().save_models(save_folder) @@ -252,12 +257,12 @@ def peel_chunk( ): del return_waveforms # always done here - extract_index = None if self.extract_subtract_same else self.channel_index + extract_index = None if self.extract_subtract_same else self.b.channel_index traces = traces.to(self.dtype) subtraction_result = subtract_chunk( traces, - self.sub_channel_index, + self.b.sub_channel_index, self.subtraction_denoising_pipeline, extract_index=extract_index, extract_mask=self.extract_subtract_mask, @@ -274,7 +279,7 @@ def peel_chunk( dedup_temporal_radius=self.p.temporal_dedup_radius_samples, remove_exact_duplicates=self.p.remove_exact_duplicates, pos_dedup_temporal_radius=self.p.positive_temporal_dedup_radius_samples, - residnorm_decrease_threshold=self.p.residnorm_decrease_threshold, + residnorm_decrease_threshold=self.threshold, decrease_objective=self.p.decrease_objective, trough_priority=self.p.trough_priority, growth_tolerance=self.p.growth_tolerance, @@ -284,8 +289,8 @@ def peel_chunk( save_iteration=self.save_iteration, save_residnorm_decrease=self.save_residnorm_decrease, max_iter=self.p.max_iter, - subtract_rel_inds=self.subtract_index_rel_inds, - dedup_rel_inds=self.dedup_rel_inds, + subtract_rel_inds=self.b.subtract_index_rel_inds, + dedup_rel_inds=self.b.dedup_rel_inds, realign_to_denoiser=self.p.realign_to_denoiser, denoiser_realignment_shift=self.p.denoiser_realignment_shift, denoiser_realignment_channel=self.p.denoiser_realignment_channel, diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 497cf57f..0809d55d 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -87,6 +87,10 @@ def __init__( if ephemeral_features is not None: for k, v in ephemeral_features.items(): check_shape = not self._no_check_needed(k) + if k in self._persistent_features: + assert np.array_equal(v, self._persistent_features[k]) + assert hasattr(self, k) + continue self.add_ephemeral_feature(k, v, check_shape=check_shape) @property @@ -207,6 +211,11 @@ def to_tsgroup( trains[unit_id] = Tsd(t=ut, d=uw) return TsGroup(trains, metadata=metadata) + def permute_labels(self, seed=0): + assert self.labels is not None + rg = np.random.default_rng(seed) + return self.ephemeral_replace(labels=rg.permutation(self.labels)) + def copy(self) -> Self: """Shallow copy. Doesn't copy data, but copies references and internal state.""" other = copy(self) diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index 47db4e3e..bf469b14 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -671,7 +671,7 @@ class SubtractionConfig: remove_exact_duplicates: bool = True positive_temporal_dedup_radius_samples: int = 41 subtract_radius_um: float = 200.0 - residnorm_decrease_threshold: float = 10.0 + residnorm_decrease_threshold: float = 9.0 decrease_objective: Literal["norm", "normsq", "deconv"] = "deconv" growth_tolerance: float | None = None trough_priority: float | None = 2.0 @@ -679,8 +679,9 @@ class SubtractionConfig: convexity_threshold: float | None = None convexity_radius: int = 7 max_iter: int = 100 - whiten: bool = False - whiten_cfg: WhiteningConfig | None = None + whiten: bool = True + threshold_before_whitening: float = 9.0 + whiten_cfg: WhiteningConfig | None = WhiteningConfig() # how will waveforms be denoised before subtraction? # users can also save waveforms/features during subtraction diff --git a/src/dartsort/vis/analysis_plots.py b/src/dartsort/vis/analysis_plots.py index 77a090ed..7aa9837f 100644 --- a/src/dartsort/vis/analysis_plots.py +++ b/src/dartsort/vis/analysis_plots.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import numpy as np +import pandas as pd import scipy.cluster.hierarchy import seaborn as sns import torch @@ -10,7 +11,11 @@ from ..clustering.cluster_util import leafsets from ..util import spikeio -from ..util.data_util import DARTsortSorting, try_get_denoising_pipeline +from ..util.data_util import ( + DARTsortSorting, + get_featurization_pipeline, + try_get_denoising_pipeline, +) from .colors import glasbey1024 @@ -525,3 +530,133 @@ def visualize_denoiser( if suptitle: fig.suptitle(suptitle) return fig + + +def plot_denoiser_scores( + recording: BaseRecording, + vis_sorting: DARTsortSorting, + load_denoiser_from_sorting: DARTsortSorting | None = None, + count_per_unit: int = 128, + figscale: float = 2.0, + decrease_objective="deconv", + seed: int = 0, + vmax=50.0, + dv=0.5, +): + from ..peel.peel_lib import check_residual_decrease + from ..transform import WaveformWhitener + + # load denoiser + if load_denoiser_from_sorting is None: + load_denoiser_from_sorting = vis_sorting + dn, geom, channel_index = try_get_denoising_pipeline(load_denoiser_from_sorting) + assert dn is not None + assert channel_index is not None + assert geom is not None + + # try load whitener + fp = get_featurization_pipeline(load_denoiser_from_sorting) + whitener = [f for f in fp.transformers if isinstance(f, WaveformWhitener)] + assert len(whitener) <= 1 + if len(whitener) == 0: + local_whiteners = None + else: + whitener = whitener[0].whitener + local_whiteners = whitener.local_whiteners(channel_index) + + # choose and load examples + rg = np.random.default_rng(seed) + times_samples = [] + channels = [] + labels = [] + scores_unwhitened = [] + scores_whitened = None if local_whiteners is None else [] + assert vis_sorting.labels is not None + for unit_id in np.unique(vis_sorting.unit_ids): + if unit_id < 0: + continue + + in_unit = np.flatnonzero(vis_sorting.labels == unit_id) + if in_unit.size <= count_per_unit: + choices = in_unit + else: + choices = rg.choice(in_unit, size=count_per_unit, replace=False) + choices.sort() + + tt = vis_sorting.times_samples[choices] + cc = vis_sorting.channels[choices] + x = spikeio.read_waveforms_channel_index( + recording, + times_samples=tt, + main_channels=cc, + channel_index=channel_index.numpy(force=True), + ) + x = torch.asarray(x, dtype=torch.float) + y, _ = dn(x, channels=torch.asarray(cc)) + + _, sc_res_a = check_residual_decrease( + x, + y, + decrease_objective=decrease_objective, + threshold=-1.0, + save_residnorm_decrease=True, + ) + + times_samples.append(tt) + channels.append(cc) + labels.append(vis_sorting.labels[choices]) + scores_unwhitened.append(sc_res_a["residnorm_decreases"]) + + if local_whiteners is None: + continue + assert scores_whitened is not None + + _, sc_res_b = check_residual_decrease( + x, + y, + decrease_objective=decrease_objective, + threshold=-1.0, + save_residnorm_decrease=True, + local_whiteners=local_whiteners, + channels=torch.asarray(cc), + ) + scores_whitened.append(sc_res_b["residnorm_decreases"]) + + data = dict( + time_samples=np.concatenate(times_samples), + channel=np.concatenate(channels), + label=np.concatenate(labels), + score_unwhitened=np.sqrt(np.maximum(0.0, np.concatenate(scores_unwhitened))), + ) + if scores_whitened is not None: + data["score_whitened"] = np.sqrt(np.maximum(0.0, np.concatenate(scores_whitened))) + df = pd.DataFrame(data) + + bins = np.arange(0.0, vmax, step=dv) + ncols = 1 + int(scores_whitened is not None) + fig, axes = plt.subplots( + nrows=1, ncols=ncols, squeeze=False, figsize=(2.0 * figscale * ncols, figscale) + ) + + for unit_id, sub_df in df.groupby("label"): + unit_id = int(unit_id) # type: ignore + axes[0, 0].hist( + sub_df.score_unwhitened, + bins=bins, + histtype="step", + color=glasbey1024[unit_id % len(glasbey1024)], + ) + if scores_whitened is None: + continue + axes[0, 1].hist( + sub_df.score_whitened, + bins=bins, + histtype="step", + color=glasbey1024[unit_id % len(glasbey1024)], + ) + + for ax, name in zip(axes.flat, ["original", "whitened"]): + ax.grid() + ax.set_xlabel(f"{name} score") + + return fig, axes, df From 0dc0f342ba03f5ec00aecf2f3a0b401155642a8f Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 8 Jun 2026 10:17:35 -0400 Subject: [PATCH 06/46] cfg --- src/dartsort/config.py | 2 +- src/dartsort/util/internal_config.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dartsort/config.py b/src/dartsort/config.py index b6630a88..4a4f79c7 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -254,7 +254,7 @@ class DeveloperConfig(DARTsortUserConfig): realign_to_denoiser: bool = True use_nn_in_subtraction: bool = True whiten_in_subtraction: bool = True - threshold_before_whitening: float = 9.0 + threshold_before_whitening: float = 10.0 # matching matching_template_type: Literal["individual_compressed_upsampled", "drifty"] = ( diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index bf469b14..ae70fc55 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -680,8 +680,8 @@ class SubtractionConfig: convexity_radius: int = 7 max_iter: int = 100 whiten: bool = True - threshold_before_whitening: float = 9.0 - whiten_cfg: WhiteningConfig | None = WhiteningConfig() + threshold_before_whitening: float = 10.0 + whiten_cfg: WhiteningConfig | None = WhiteningConfig(strategy="prewhiten_postapply") # how will waveforms be denoised before subtraction? # users can also save waveforms/features during subtraction @@ -1217,13 +1217,13 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: detection_threshold=cfg.motion_voltage_threshold, chunk_length_samples=cfg.chunk_length_samples, peak_sign=cfg.peak_sign, - shave_score=cfg.initial_threshold, + shave_score=cfg.threshold_before_whitening, ) motion_estimation_cfg = MotionEstimationConfig( **motion_kw, tpca_rank=cfg.temporal_pca_rank, threshold_cfg=motion_threshold_cfg, - spike_denoising_score=cfg.initial_threshold, + spike_denoising_score=cfg.threshold_before_whitening, ) matching_cfg = MatchingConfig( threshold="fp_control" if cfg.matching_fp_control else cfg.matching_threshold, From d644dc368c2ca088ab025324ef7ab46d090f4407 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 8 Jun 2026 11:06:15 -0400 Subject: [PATCH 07/46] dredge_only cfg broken --- src/dartsort/util/internal_config.py | 5 +++-- src/dartsort/util/noise_util.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index ae70fc55..df1d6209 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -1027,7 +1027,8 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: radius=cfg.subtraction_radius_um, interp_params=temp_interp_params, ) - if cfg.dredge_only: + # TODO: dredge_only is a bad name for this. + if cfg.dredge_only and not cfg.whiten_in_subtraction: n_residual_snips = 0 else: n_residual_snips = cfg.n_residual_snips @@ -1223,7 +1224,7 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: **motion_kw, tpca_rank=cfg.temporal_pca_rank, threshold_cfg=motion_threshold_cfg, - spike_denoising_score=cfg.threshold_before_whitening, + spike_denoising_score=cfg.initial_threshold, ) matching_cfg = MatchingConfig( threshold="fp_control" if cfg.matching_fp_control else cfg.matching_threshold, diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index 9c766736..70889ea8 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -1458,7 +1458,7 @@ def residual_covariance( residual_dataset_name="residual", seed: int = 0, batch_size=256, -): +) -> Tensor: assert sorting.parent_h5_path is not None if do_interpolation: @@ -1497,6 +1497,7 @@ def residual_covariance( N += n w = n / N cov += scov.sub_(cov).mul_(w) + assert cov is not None return cov From 9fda04bb584a4f983ee6b172e126ae553b08aa75 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 8 Jun 2026 11:42:06 -0400 Subject: [PATCH 08/46] fix subtract tests --- src/dartsort/peel/peel_base.py | 2 +- src/dartsort/peel/subtract.py | 2 +- src/dartsort/util/noise_util.py | 2 +- tests/test_subtract.py | 17 ++++++++--------- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index b0e4cf8a..0e154320 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -243,7 +243,7 @@ def peel( resids_remaining = total_residual_snips - resids_so_far chunks_remaining = len(chunks_to_do) chunks_done = n_chunks_orig - chunks_remaining - chunks_cover = int(np.floor(ensure_coverage * n_chunks_orig)) + chunks_cover = int(np.ceil(ensure_coverage * n_chunks_orig)) chunks_cover_remaining = chunks_cover - chunks_done if chunks_cover_remaining == 0: assert resids_remaining == 0 diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py index ea9f0cba..421a7c70 100644 --- a/src/dartsort/peel/subtract.py +++ b/src/dartsort/peel/subtract.py @@ -168,7 +168,7 @@ def post_fit(self): assert len(whitener) == 1 whitener = whitener[0].whitener assert whitener is not None - local_whiteners = whitener.local_whiteners(self.b.sub_channel_index) + local_whiteners = whitener.local_whiteners(self.sub_channel_index) # type: ignore self.del_none_buffer("local_whiteners") self.register_buffer("local_whiteners", local_whiteners) self.threshold = self.p.residnorm_decrease_threshold diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index 70889ea8..fcf84c2f 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -1497,6 +1497,7 @@ def residual_covariance( N += n w = n / N cov += scov.sub_(cov).mul_(w) + assert N > 0 assert cov is not None return cov @@ -1514,7 +1515,6 @@ def fullzca_whitener( def localzca_whitener( cov: np.ndarray, channel_index: np.ndarray, eps=1e-6 ) -> np.ndarray: - """""" w = np.zeros_like(cov) for j, chans in enumerate(channel_index): chans = chans[chans < len(channel_index)] diff --git a/tests/test_subtract.py b/tests/test_subtract.py index 4c9d98f6..cb9cdb96 100644 --- a/tests/test_subtract.py +++ b/tests/test_subtract.py @@ -18,7 +18,6 @@ FeaturizationConfig, FitSamplingConfig, SubtractionConfig, - default_pretrained_path ) fixedlenkeys = ( @@ -127,7 +126,6 @@ def test_fakedata_nonn(fakedata, tmp_path): with tempfile.TemporaryDirectory( dir=tmp_path, ignore_cleanup_errors=True ) as tempdir: - print("first one") torch.manual_seed(0) st = subtract( recording=rec, @@ -194,7 +192,7 @@ def test_fakedata_nonn(fakedata, tmp_path): recording=rec, output_dir=tempdir, featurization_cfg=featconf, - sampling_cfg=FitSamplingConfig(n_residual_snips=0), + sampling_cfg=sampconf, subtraction_cfg=subconf, overwrite=True, ) @@ -216,17 +214,17 @@ def test_fakedata_nonn(fakedata, tmp_path): channel_index.shape[1], ) - for ccfg in (two_jobs_cfg, two_jobs_cfg_spawn): + for cname, ccfg in (("", two_jobs_cfg), ("spawn", two_jobs_cfg_spawn)): with tempfile.TemporaryDirectory( dir=tmp_path, ignore_cleanup_errors=True ) as tempdir: - print("parallel first one") + print("parallel first one", cname) torch.manual_seed(0) st = subtract( recording=rec, output_dir=tempdir, featurization_cfg=nolocfeatconf, - sampling_cfg=FitSamplingConfig(n_residual_snips=0), + sampling_cfg=sampconf, subtraction_cfg=subconf, overwrite=True, computation_cfg=ccfg, @@ -284,7 +282,7 @@ def test_fakedata_nonn(fakedata, tmp_path): recording=rec, output_dir=tempdir, featurization_cfg=nolocfeatconf, - sampling_cfg=FitSamplingConfig(n_residual_snips=0), + sampling_cfg=sampconf, subtraction_cfg=subconf, overwrite=True, computation_cfg=ccfg, @@ -314,12 +312,12 @@ def test_resume(fakedata, tmp_path): detection_threshold=20.0, peak_sign="both", subtraction_denoising_cfg=FeaturizationConfig( - do_nn_denoise=False, - denoise_only=True + do_nn_denoise=False, denoise_only=True ), first_denoiser_thinning=0.0, first_denoiser_spatial_jitter=0, first_denoiser_temporal_jitter=0, + whiten=False, ) featconf = FeaturizationConfig(skip=True) sampconf = FitSamplingConfig(n_residual_snips=0) @@ -413,6 +411,7 @@ def test_small_nonn(tmp_path, nn_localization): subtraction_denoising_cfg=FeaturizationConfig( do_nn_denoise=False, denoise_only=True ), + whiten=False, ) featconf = FeaturizationConfig(do_nn_denoise=False, nn_localization=nn_localization) From d8b1b2517a6c8a6ca4fd5d7cf0de6e959c4f49fb Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 8 Jun 2026 12:10:56 -0400 Subject: [PATCH 09/46] cfg --- src/dartsort/util/internal_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index df1d6209..227acba6 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -1224,7 +1224,7 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: **motion_kw, tpca_rank=cfg.temporal_pca_rank, threshold_cfg=motion_threshold_cfg, - spike_denoising_score=cfg.initial_threshold, + spike_denoising_score=cfg.threshold_before_whitening, ) matching_cfg = MatchingConfig( threshold="fp_control" if cfg.matching_fp_control else cfg.matching_threshold, From aae6b0bc42669e3c87723c21a3aedb7ab4115086 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 8 Jun 2026 16:55:02 -0400 Subject: [PATCH 10/46] template: batches in grab to cap reduction_template memory --- src/dartsort/clustering/mixture.py | 21 +----- src/dartsort/config.py | 5 +- src/dartsort/peel/grab.py | 105 +++++++++++++++++++++++---- src/dartsort/peel/peel_base.py | 66 +++++++++++++---- src/dartsort/util/internal_config.py | 3 + src/dartsort/util/spiketorch.py | 29 ++++++-- tests/test_matching.py | 8 +- 7 files changed, 176 insertions(+), 61 deletions(-) diff --git a/src/dartsort/clustering/mixture.py b/src/dartsort/clustering/mixture.py index 89f3497a..9c0a771d 100644 --- a/src/dartsort/clustering/mixture.py +++ b/src/dartsort/clustering/mixture.py @@ -55,7 +55,6 @@ import numpy as np import torch import torch.nn.functional as F -from packaging.version import Version from scipy.sparse.csgraph import connected_components from sympy.utilities.iterables import multiset_partitions, subsets from torch import Tensor @@ -88,6 +87,7 @@ from ..util.noise_util import EmbeddedNoise from ..util.py_util import databag from ..util.spiketorch import ( + _nonzero_static, cosine_distance, ecl, entropy, @@ -102,13 +102,6 @@ from .clustering_features import StableWaveformFeatures from .kmeans import kmeans -TORCH_IS_OLD = Version(torch.__version__) < Version("2.6.0") -if TORCH_IS_OLD and torch.cuda.is_available(): - warnings.warn( - f"Your PyTorch version ({torch.__version__}) is supported by dartsort, " - "but dartsort would be faster if you had >= 2.6.0." - ) - if TYPE_CHECKING: from ..transform.temporal_pca import BaseTemporalPCA @@ -6194,18 +6187,6 @@ def _count_candidates(candidates, batch_candidate_counts, batch_size): batch_candidate_counts.copy_(counts.cpu()) -if TORCH_IS_OLD: - - def _nonzero_static(x: Tensor, size: int): - nz = x.nonzero() - assert nz.shape[0] == size - return nz -else: - - def _nonzero_static(x: Tensor, size: int): - return x.nonzero_static(size=size) - - @torch_compiler(fullgraph=False) def _combine_similar_resps(resps: Tensor, keep_mask: Tensor, n_keep: int) -> Tensor: n_discard = resps.shape[1] - n_keep diff --git a/src/dartsort/config.py b/src/dartsort/config.py index 4a4f79c7..3a2baf42 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -123,7 +123,7 @@ class DARTsortUserConfig: alignment_ms: Annotated[float, Field(gt=0)] = 1.5 """Largest time shift allowed when re-aligning events.""" - deduplication_ms: Annotated[float, Field(gt=0)] = 0.5 + deduplication_ms: Annotated[float, Field(gt=0)] = 0.5 """As a final postprocessing step, only the higher-scoring of any spikes within this time radius of each other are kept. If this is negative, it does nothing. If it's 0, exact duplicates are dropped. @@ -255,6 +255,9 @@ class DeveloperConfig(DARTsortUserConfig): use_nn_in_subtraction: bool = True whiten_in_subtraction: bool = True threshold_before_whitening: float = 10.0 + temporal_dedup_radius_samples: int = 11 + positive_temporal_dedup_radius_samples: int = 41 + spatial_dedup_radius_um: float | None = 50.0 # matching matching_template_type: Literal["individual_compressed_upsampled", "drifty"] = ( diff --git a/src/dartsort/peel/grab.py b/src/dartsort/peel/grab.py index b7e9e6bf..5219080b 100644 --- a/src/dartsort/peel/grab.py +++ b/src/dartsort/peel/grab.py @@ -1,4 +1,5 @@ """Grab and featurize events at known times.""" + from typing import Mapping import numpy as np @@ -6,7 +7,6 @@ from spikeinterface import BaseRecording from ..transform import WaveformPipeline -from ..util import spiketorch from ..util.data_util import DARTsortSorting from ..util.internal_config import ( FeaturizationConfig, @@ -14,6 +14,7 @@ WaveformConfig, default_waveform_cfg, ) +from ..util.spiketorch import _nonzero_static, grab_spikes from ..util.waveform_util import make_channel_index from .peel_base import ( BasePeeler, @@ -39,6 +40,7 @@ def __init__( chunk_length_samples=30_000, fit_sampling_cfg: FitSamplingConfig = FitSamplingConfig(n_residual_snips=0), waveform_cfg: WaveformConfig = default_waveform_cfg, + batch_size: int = 2048, dtype=torch.float, ): fixed_properties = fixed_properties or {} @@ -49,6 +51,7 @@ def __init__( spike_length_samples = waveform_cfg.spike_length_samples( recording.sampling_frequency ) + self.batch_size = batch_size assert not fit_sampling_cfg.n_residual_snips super().__init__( recording=recording, @@ -94,17 +97,44 @@ def process_chunk( ) t_clip = self.b.times_samples.clip(chunk_start_samples, chunk_end_samples - 1) in_chunk = self.b.times_samples == t_clip - if not in_chunk.any(): + n_in_chunk = int(in_chunk.sum().item()) + if not n_in_chunk: return peeling_empty_result - res = super().process_chunk( - chunk_start_samples, - n_resid_snips=n_resid_snips, - chunk_end_samples=chunk_end_samples, - return_residual=return_residual, - skip_features=skip_features, - to_cpu=to_cpu, + in_chunk = _nonzero_static(in_chunk, size=n_in_chunk) + assert in_chunk.shape == (n_in_chunk, 1) + in_chunk = in_chunk[:, 0] + return_waveforms = not skip_features and bool(self.featurization_pipeline) + + chunk, chunk_end_samples_, left_margin, right_margin = self.get_chunk( + chunk_start_samples, chunk_end_samples ) + assert chunk_end_samples == chunk_end_samples_ + + batch_results = [] + for i0 in range(0, n_in_chunk, self.batch_size): + i1 = min(n_in_chunk, i0 + self.batch_size) + batch_in_chunk = in_chunk[i0:i1] + batch_peel_result = self.peel_chunk( + traces=chunk, + chunk_start_samples=chunk_start_samples, + left_margin=left_margin, + right_margin=right_margin, + return_residual=return_residual, + in_chunk=batch_in_chunk, + ) + assert batch_peel_result["n_spikes"] == i1 - i0 + batch_res = self.featurize_chunk_result( + peel_result=batch_peel_result, + to_cpu=to_cpu, + return_waveforms=return_waveforms, + chunk_start_samples=chunk_start_samples, + chunk_end_samples=chunk_end_samples, + device=chunk.device, + n_resid_snips=n_resid_snips, + ) + batch_results.append(batch_res) + res = _cat_results(batch_results) return res @classmethod @@ -150,21 +180,23 @@ def from_config( def peel_chunk( self, - traces, + traces: torch.Tensor, *, chunk_start_samples=0, left_margin=0, right_margin=0, return_residual=False, return_waveforms=True, + in_chunk: torch.Tensor | None = None, ) -> PeelingBatchResult: assert not return_residual - max_t = chunk_start_samples + self.chunk_length_samples - 1 - in_chunk = self.b.times_samples == self.b.times_samples.clip( - chunk_start_samples, max_t - ) - (in_chunk,) = in_chunk.nonzero(as_tuple=True) + if in_chunk is None: + max_t = chunk_start_samples + self.chunk_length_samples - 1 + in_chunk = self.b.times_samples == self.b.times_samples.clip( + chunk_start_samples, max_t + ) + (in_chunk,) = in_chunk.nonzero(as_tuple=True) if not in_chunk.numel(): return peeling_empty_result @@ -184,7 +216,7 @@ def peel_chunk( channels=channels, ) if return_waveforms: - res["collisioncleaned_waveforms"] = spiketorch.grab_spikes( + res["collisioncleaned_waveforms"] = grab_spikes( traces, times_rel, channels, @@ -198,3 +230,44 @@ def peel_chunk( for k in self.fixed_property_keys: res[k] = getattr(self.b, k)[in_chunk].to(device=dev) return res + + +def _cat_results(results): + # not ideal, should work on typing these things better, but no time now. + assert len(results) > 0 + if len(results) == 1: + return results[0] + rdict: dict[str, int | float | list[torch.Tensor] | list[np.ndarray]] = { + "n_spikes": 0 + } + for res in results: + for k, v in res.items(): + if k == "n_spikes": + rdict[k] += v + elif k == "chunk_center_s": + rdict[k] = v + elif isinstance(v, (int, float)): + if k not in rdict: + rdict[k] = v + else: + assert rdict[k] == v + elif isinstance(v, (torch.Tensor, np.ndarray)): + if k not in rdict: + rdict[k] = [] + rdict[k].append(v) # type: ignore + else: + assert False + res = {} + for k, v in rdict.items(): + if k in ("n_spikes", "chunk_center_s"): + res[k] = v + elif isinstance(v, list): + if isinstance(v[0], torch.Tensor): + res[k] = torch.concatenate(v) # type: ignore + elif isinstance(v[0], np.ndarray): + res[k] = np.concatenate(v) + else: + assert False + else: + res[k] = v + return res diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index 0e154320..8f8ddff5 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -1,4 +1,5 @@ """Home of the BasePeeler, where shared logic for other modules here lives.""" + import gc import tempfile from concurrent.futures import CancelledError @@ -403,7 +404,7 @@ def peel( def peel_chunk( self, - traces, + traces: torch.Tensor, *, chunk_start_samples=0, left_margin=0, @@ -500,9 +501,36 @@ def process_chunk( Main function called in peeling workers """ + chunk, chunk_end_samples, left_margin, right_margin = self.get_chunk( + chunk_start_samples, chunk_end_samples + ) + return_waveforms = not skip_features and bool(self.featurization_pipeline) + peel_result = self.peel_chunk( + chunk, + chunk_start_samples=chunk_start_samples, + left_margin=left_margin, + right_margin=right_margin, + return_waveforms=return_waveforms, + return_residual=return_residual or bool(n_resid_snips), + ) + chunk_result = self.featurize_chunk_result( + peel_result=peel_result, + to_cpu=to_cpu, + return_waveforms=return_waveforms, + chunk_start_samples=chunk_start_samples, + chunk_end_samples=chunk_end_samples, + device=chunk.device, + n_resid_snips=n_resid_snips, + ) + return chunk_result + + def get_chunk(self, chunk_start_samples: int, chunk_end_samples: int | None = None): + Ts = self.recording.get_num_samples() if chunk_end_samples is None: chunk_end_samples = chunk_start_samples + self.chunk_length_samples - chunk_end_samples = min(self.recording.get_num_samples(), chunk_end_samples) + chunk_end_samples = min(Ts, chunk_end_samples) + print(f"{chunk_end_samples=} {Ts=}") + assert chunk_end_samples <= Ts chunk, left_margin, right_margin = get_chunk_with_margin( self.recording._recording_segments[0], start_frame=chunk_start_samples, @@ -512,15 +540,19 @@ def process_chunk( ) device = self.b.channel_index.device chunk = torch.tensor(chunk, device=device, dtype=self.dtype) - return_waveforms = not skip_features and bool(self.featurization_pipeline) - peel_result = self.peel_chunk( - chunk, - chunk_start_samples=chunk_start_samples, - left_margin=left_margin, - right_margin=right_margin, - return_waveforms=return_waveforms, - return_residual=return_residual or bool(n_resid_snips), - ) + return chunk, chunk_end_samples, left_margin, right_margin + + def featurize_chunk_result( + self, + *, + peel_result, + to_cpu: bool, + return_waveforms: bool, + chunk_start_samples: int, + chunk_end_samples: int, + device: torch.device, + n_resid_snips: int | None, + ): if peel_result["n_spikes"] > 0 and to_cpu: t_s = self.recording.sample_index_to_time( peel_result["times_samples"].numpy(force=True) @@ -543,12 +575,13 @@ def process_chunk( # a user who wants these must featurize with a waveform node # then they'll end up in `features` if "collisioncleaned_waveforms" in peel_result: - del peel_result["collisioncleaned_waveforms"] # type: ignore + del peel_result["collisioncleaned_waveforms"] chunk_result = peel_result | features if to_cpu: chunk_result = { - k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in chunk_result.items() + k: v.cpu() if isinstance(v, torch.Tensor) else v + for k, v in chunk_result.items() } if "residual" in peel_result: if torch.is_tensor(peel_result["residual"]): @@ -571,7 +604,7 @@ def process_chunk( self.recording.sample_index_to_time(resid_times_samples) ) - return chunk_result # type: ignore + return chunk_result def gather_chunk_result( self, @@ -1179,6 +1212,7 @@ def __init__( self.skip_features = skip_features self.chunk_length_samples = chunk_length_samples self.to_cpu = to_cpu + self.Ts = peeler.recording.get_num_samples() # this state will be set on each thread globally @@ -1224,10 +1258,12 @@ def _peeler_process_job(chunk_start_samples__n_resid_snips): # by returning here, we are implicitly relying on pickle # TODO: replace with manual np.saves with torch.no_grad(): - chunk_end_samples = None chlen = _peeler_process_context.ctx.chunk_length_samples if chlen is not None: chunk_end_samples = chunk_start_samples + chlen + chunk_end_samples = min(chunk_end_samples, _peeler_process_context.ctx.Ts) + else: + chunk_end_samples = None return _peeler_process_context.ctx.peeler.process_chunk( chunk_start_samples, n_resid_snips=n_resid_snips, diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index 227acba6..6ab011dc 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -1067,6 +1067,8 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: first_denoiser_noise_snips=cfg.nn_denoiser_noise_waveforms, first_denoiser_spatial_dedup_radius=cfg.first_denoiser_spatial_dedup_radius, subtraction_denoising_cfg=subtraction_denoising_cfg, + temporal_dedup_radius_samples=cfg.temporal_dedup_radius_samples, + positive_temporal_dedup_radius_samples=cfg.positive_temporal_dedup_radius_samples, whiten=cfg.whiten_in_subtraction, whiten_cfg=whiten_cfg, ) @@ -1076,6 +1078,7 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: detection_threshold=cfg.voltage_threshold, spatial_dedup_radius_um=cfg.deduplication_radius_um, chunk_length_samples=cfg.chunk_length_samples, + temporal_dedup_radius_samples=cfg.temporal_dedup_radius_samples, ) elif cfg.detection_type == "match": assert cfg.precomputed_templates_npz is not None diff --git a/src/dartsort/util/spiketorch.py b/src/dartsort/util/spiketorch.py index 2f56449e..de8f4d4c 100644 --- a/src/dartsort/util/spiketorch.py +++ b/src/dartsort/util/spiketorch.py @@ -1,6 +1,5 @@ import math import warnings -from logging import getLogger from typing import overload import linear_operator @@ -8,16 +7,24 @@ import torch import torch.nn.functional as F from linear_operator.utils.cholesky import psd_safe_cholesky +from packaging.version import Version from scipy.fftpack import next_fast_len from scipy.spatial.distance import squareform from sklearn.utils.extmath import svd_flip from torch import Tensor from torch.fft import irfft, rfft -from .logging_util import progrange +from .logging_util import get_logger, progrange from .torch_util import torch_compile, torch_compiler -logger = getLogger(__name__) +TORCH_IS_OLD = Version(torch.__version__) < Version("2.6.0") +if TORCH_IS_OLD and torch.cuda.is_available(): + warnings.warn( + f"Your PyTorch version ({torch.__version__}) is supported by dartsort, " + "but dartsort would be faster if you had >= 2.6.0." + ) + +logger = get_logger(__name__) log2pi = torch.log(torch.tensor(2 * np.pi)) _1 = torch.tensor(1.0) _0 = torch.tensor(0.0) @@ -104,6 +111,18 @@ def ptp(waveforms, dim=1, keepdims=False): return v.numpy() +if TORCH_IS_OLD: + + def _nonzero_static(x: Tensor, size: int): + nz = x.nonzero() + assert nz.shape[0] == size + return nz +else: + + def _nonzero_static(x: Tensor, size: int): + return x.nonzero_static(size=size) + + @torch_compile def mean_elbo_dim1(Q: Tensor, log_liks: Tensor) -> Tensor: logQ = Q.log().nan_to_num_(nan=None, neginf=0.0) @@ -752,7 +771,7 @@ def argrelmax( ): msk = _argrelmax_mask(x=x, radius=radius, threshold=threshold, arange=arange) return msk.nonzero()[:, 0] - + @torch_compile def _argrelmax_mask( @@ -819,7 +838,7 @@ def argrelmax_dedup( padding=padding, ) return msk.nonzero()[:, 0] - + @torch_compile def _argrelmax_dedup_mask( diff --git a/tests/test_matching.py b/tests/test_matching.py index 67b4c20b..f36f336c 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -214,7 +214,7 @@ def test_no_crumbs(subtests, refractory_sim, method, cd_iter, channel_selection_ assert gt_shift is not None gt_shift = gt_shift[gt_in_chunk] if "scalings" in res: - match_scale = res["scalings"].numpy(force=True) + match_scale = res["scalings"].numpy(force=True) # type: ignore else: match_scale = np.ones(gt_n_spikes, dtype=np.float32) @@ -288,7 +288,7 @@ def test_no_crumbs(subtests, refractory_sim, method, cd_iter, channel_selection_ ) if "up_inds" in res: - match_up = res["up_inds"].numpy(force=True) + match_up = res["up_inds"].numpy(force=True) # type: ignore np.testing.assert_array_equal(gt_up, match_up, err_msg="sorting: up_inds 1") else: np.testing.assert_array_equal( @@ -300,7 +300,7 @@ def test_no_crumbs(subtests, refractory_sim, method, cd_iter, channel_selection_ ) if "time_shifts" in res: - match_shift = res["time_shifts"].numpy(force=True) + match_shift = res["time_shifts"].numpy(force=True) # type: ignore np.testing.assert_array_equal(match_shift, gt_shift) else: assert (gt_shift == 0).all() @@ -546,7 +546,7 @@ def test_tiny_up(tiny_up_sim, tmp_path, up_factor, scaling, cd_iter, up_offset): # ) assert np.array_equal(res["labels"].numpy(force=True), labels) if "up_inds" in res: - assert np.array_equal(res["up_inds"].numpy(force=True), upsampling_indices) + assert np.array_equal(res["up_inds"].numpy(force=True), upsampling_indices) # type: ignore else: assert up_factor == 1 resid_rms = torch.square(res["residual"]).mean().numpy(force=True) From 11002ca5a348ab9be02ec77176fef61b6610edc6 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 8 Jun 2026 16:55:18 -0400 Subject: [PATCH 11/46] eval: full spike label vector --- src/dartsort/evaluate/comparison.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/dartsort/evaluate/comparison.py b/src/dartsort/evaluate/comparison.py index 4a9bfab7..22cd940f 100644 --- a/src/dartsort/evaluate/comparison.py +++ b/src/dartsort/evaluate/comparison.py @@ -264,6 +264,19 @@ def nearby_tested_templates(self, gt_unit_id, n_neighbors=5): return neighb_ixs, neighb_ids, neighb_dists, neighb_coarse_templates + def full_tested_labels(self): + labels = np.full(len(self.tested_analysis.sorting), "NA", dtype=" Date: Tue, 9 Jun 2026 11:36:48 -0400 Subject: [PATCH 12/46] left a print --- src/dartsort/peel/peel_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dartsort/peel/peel_base.py b/src/dartsort/peel/peel_base.py index 8f8ddff5..6b5ca81e 100644 --- a/src/dartsort/peel/peel_base.py +++ b/src/dartsort/peel/peel_base.py @@ -529,7 +529,6 @@ def get_chunk(self, chunk_start_samples: int, chunk_end_samples: int | None = No if chunk_end_samples is None: chunk_end_samples = chunk_start_samples + self.chunk_length_samples chunk_end_samples = min(Ts, chunk_end_samples) - print(f"{chunk_end_samples=} {Ts=}") assert chunk_end_samples <= Ts chunk, left_margin, right_margin = get_chunk_with_margin( self.recording._recording_segments[0], From 1cfb1c62e1570963d20569d45e5b5dfc03a3b774 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 9 Jun 2026 16:58:51 -0400 Subject: [PATCH 13/46] eval: debug comparison vis --- src/dartsort/evaluate/analysis.py | 13 ++++++++++--- src/dartsort/evaluate/comparison.py | 24 ++++++++++++++++++++---- src/dartsort/vis/gt.py | 3 --- src/dartsort/vis/mixture.py | 2 -- src/dartsort/vis/unit_comparison.py | 18 ++++++++++++------ 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/dartsort/evaluate/analysis.py b/src/dartsort/evaluate/analysis.py index c12497e3..f81ab423 100644 --- a/src/dartsort/evaluate/analysis.py +++ b/src/dartsort/evaluate/analysis.py @@ -413,10 +413,14 @@ def unit_raw_waveforms( read_chans=read_chans, main_channel=main_channel, ) - channels = None + channels = self.vis_channel_index[main_channel] + channels = np.broadcast_to( + channels[None], (waveforms.shape[0], waveforms.shape[2]) + ) else: channels = read_channel_index[read_chans] - main_channel = self.unit_max_channel(unit_id) + if main_channel is None: + main_channel = self.unit_max_channel(unit_id) return WaveformsBag( which=which, waveforms=waveforms, @@ -558,7 +562,10 @@ def unit_pca_features( def unit_max_channel(self, unit_id) -> int: assert self.coarse_template_data is not None temp = self.coarse_template_data.unit_templates(unit_id) - assert temp.ndim == 3 and temp.shape[0] == np.atleast_1d(unit_id).size + assert temp.ndim == 3, f"{self.name}: {unit_id=} {temp.shape=}" + assert temp.shape[0] == np.atleast_1d(unit_id).size, ( + f"{self.name}: {unit_id=} {temp.shape=}" + ) which = self.in_unit(unit_id) if self.motion.drifting and hasattr(self.sorting, "channel_index"): diff --git a/src/dartsort/evaluate/comparison.py b/src/dartsort/evaluate/comparison.py index 22cd940f..3da1dfc1 100644 --- a/src/dartsort/evaluate/comparison.py +++ b/src/dartsort/evaluate/comparison.py @@ -516,6 +516,7 @@ def get_raw_waveforms_by_category( # load TP waveforms # which, waveforms, max_chan, show_geom, show_channel_index tp_waves = self.gt_analysis.unit_raw_waveforms( + unit_id=gt_unit, which=ind_groups["matched_gt_indices"], **waveform_kw, # type: ignore ) @@ -524,42 +525,54 @@ def get_raw_waveforms_by_category( w["tp"] = None w["geom"] = self.gt_analysis.registered_geom w["channel_index"] = self.gt_analysis.vis_channel_index + w["channels_tp"] = None else: w["which_tp"] = tp_waves.which w["tp"] = tp_waves.waveforms w["geom"] = tp_waves.geom w["channel_index"] = tp_waves.channel_index + w["channels_tp"] = tp_waves.channels # load FN waveforms # which, waveforms, max_chan, show_geom, show_channel_index fn_waves = self.gt_analysis.unit_raw_waveforms( + unit_id=gt_unit, which=ind_groups["only_gt_indices"], **waveform_kw, # type: ignore ) if fn_waves is None: w["which_fn"] = None w["fn"] = None + w["channels_fn"] = None else: w["which_fn"] = fn_waves.which w["fn"] = fn_waves.waveforms + w["channels_fn"] = fn_waves.channels # load FP waveforms # which, waveforms, max_chan, show_geom, show_channel_index - fp_waves = self.tested_analysis.unit_raw_waveforms( - which=ind_groups["only_tested_indices"], - **waveform_kw, # type: ignore - ) + if tested_unit is not None and tested_unit >= 0: + fp_waves = self.tested_analysis.unit_raw_waveforms( + unit_id=tested_unit, + which=ind_groups["only_tested_indices"], + **waveform_kw, # type: ignore + to_main_channel=True, + ) + else: + fp_waves = None if fp_waves is None: w["which_fp"] = None w["fp"] = None else: w["which_fp"] = fp_waves.which w["fp"] = fp_waves.waveforms + w["channels_fp"] = fp_waves.channels if self.unsorted_detection is None: w["unsorted_tp"] = w["unsorted_fn"] = None else: utp_waves = self.gt_analysis.unit_raw_waveforms( + unit_id=gt_unit, which=ind_groups["unsorted_tp_indices"], **waveform_kw, # type: ignore ) @@ -569,7 +582,9 @@ def get_raw_waveforms_by_category( else: w["which_unsorted_tp"] = utp_waves.which w["unsorted_tp"] = utp_waves.waveforms + w["channels_unsorted_tp"] = utp_waves.channels ufn_waves = self.gt_analysis.unit_raw_waveforms( + unit_id=gt_unit, which=ind_groups["unsorted_fn_indices"], **waveform_kw, # type: ignore ) @@ -579,6 +594,7 @@ def get_raw_waveforms_by_category( else: w["which_unsorted_fn"] = ufn_waves.which w["unsorted_fn"] = ufn_waves.waveforms + w["channels_unsorted_fn"] = ufn_waves.channels return w diff --git a/src/dartsort/vis/gt.py b/src/dartsort/vis/gt.py index a4bad362..fff4ab10 100644 --- a/src/dartsort/vis/gt.py +++ b/src/dartsort/vis/gt.py @@ -85,9 +85,6 @@ def draw(self, panel, comparison): col_order = col_order[col_order < comparison.template_distances.shape[1]] dist = comparison.template_distances[row_order, :][:, col_order] dist = dist.astype(np.float64) - print(f"{dist.shape=}") - print(f"{dist.min()=}") - print(f"{dist.max()=}") ax = panel.subplots() log1p_norm = FuncNorm((np.log1p, np.expm1), vmin=0) diff --git a/src/dartsort/vis/mixture.py b/src/dartsort/vis/mixture.py index 192f1329..4a63ce3c 100644 --- a/src/dartsort/vis/mixture.py +++ b/src/dartsort/vis/mixture.py @@ -620,7 +620,6 @@ def compute(self, mix_data: MixtureVisData, unit_id: int): return group_res def draw(self, panel, mix_data: MixtureVisData, unit_id: int): - print(f"{unit_id=}") demo_res = self.compute(mix_data, unit_id) ax = panel.subplots() @@ -633,7 +632,6 @@ def draw(self, panel, mix_data: MixtureVisData, unit_id: int): else: ds = ",".join([str(uu.item())[:1] for uu in demo_res.demolished.cpu()]) msg = f"units: {us}\n{ims}\ndemo: {ds}" - print(f"{msg=}") ax.text( 0.5, diff --git a/src/dartsort/vis/unit_comparison.py b/src/dartsort/vis/unit_comparison.py index 52051a6a..ef6d7d14 100644 --- a/src/dartsort/vis/unit_comparison.py +++ b/src/dartsort/vis/unit_comparison.py @@ -279,7 +279,7 @@ def __init__( unsorted_missed_color=None, count=50, single_channel=False, - order="tprandom", + order="tpfpfn", show_sorted_matches=True, show_unsorted_matches=False, average=False, @@ -328,6 +328,7 @@ def _draw(self, panel, comparison, unit_id, tested_unit_id): waveforms = [] colors = [] + channels = [] maa = 1.0 for kind, color in self.colors.items(): if w[kind] is None or not w[kind].size or not np.isfinite(w[kind]).any(): @@ -337,11 +338,15 @@ def _draw(self, panel, comparison, unit_id, tested_unit_id): maxs = np.nanmax(np.abs(w[kind]), axis=(1, 2)) maxs = np.nan_to_num(maxs, nan=1.0) maa = max(maa, np.percentile(maxs, 90)) + ch = w[f"channels_{kind}"] if self.average: avg = w[kind].mean(0, keepdims=True) + assert np.all(ch == ch[0]), "Chans mismatch in avg" waveforms.append(avg) + channels.append(ch[:1]) else: waveforms.append(w[kind]) + channels.append(ch) colors.append(np.broadcast_to([color], waveforms[-1].shape[:1])) if not len(waveforms): @@ -350,25 +355,26 @@ def _draw(self, panel, comparison, unit_id, tested_unit_id): if self.order == "tprandom" and not self.average and len(waveforms) > 1: wlast = np.concatenate(waveforms[1:]) clast = np.concatenate(colors[1:]) + chanlast = np.concatenate(channels[1:]) n = len(wlast) shuf = np.random.default_rng(0).permutation(n) waveforms = [waveforms[0], wlast[shuf]] colors = [colors[0], clast[shuf]] + channels = [channels[0], chanlast[shuf]] waveforms = np.concatenate(waveforms) colors = np.concatenate(colors) - max_channels = np.broadcast_to([w["max_chan"]], colors.shape) + channels = np.concatenate(channels) + chans = comparison.gt_analysis.vis_channel_index[w["max_chan"]] + chans = chans[chans < comparison.gt_analysis.vis_channel_index.shape[0]] ax = panel.subplots() max_abs_amp = maa - chans = w["channel_index"][w["max_chan"]] - chans = chans[chans < len(w["geom"])] geomplot( waveforms, - max_channels=max_channels, - channel_index=w["channel_index"], geom=w["geom"], ax=ax, + channels=channels, show_zero=False, max_abs_amp=max_abs_amp, annotate_z=True, From 4b7c09ad4c6d913fa86336ba5492d833fa7ed1ae Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 10 Jun 2026 10:14:33 -0400 Subject: [PATCH 14/46] clus: raise on numerical issues in SimpleMatrixFeatures --- .../clustering/clustering_features.py | 49 ++++++++++++++----- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/src/dartsort/clustering/clustering_features.py b/src/dartsort/clustering/clustering_features.py index a53789a1..5b633ec3 100644 --- a/src/dartsort/clustering/clustering_features.py +++ b/src/dartsort/clustering/clustering_features.py @@ -25,15 +25,6 @@ from ..util.waveform_util import single_channel_index from . import cluster_util -default_clustering_features_cfg = ClusteringFeaturesConfig() -minimal_features_cfg = ClusteringFeaturesConfig( - n_main_channel_pcs=0, - use_amplitude=False, - use_signed_amplitude=False, - use_x=False, - use_z=False, -) - logger = get_logger(__name__) @@ -77,8 +68,14 @@ def from_config( ) if xyza is not None: x = xyza[:, 0] + if not _allfinite(x): + raise ValueError(_numbers_error_str("x", x)) z = xyza[:, 2] + if not _allfinite(z): + raise ValueError(_numbers_error_str("z", z)) z_reg = motion.correct_s(t_s, z) + if not _allfinite(z_reg): + raise ValueError(_numbers_error_str("z_reg", z_reg)) else: x = z = z_reg = None @@ -99,10 +96,14 @@ def from_config( amp = getattr(sorting, clustering_features_cfg.amplitudes_dataset_name) if clustering_features_cfg.use_amplitude: assert amp is not None + if not _allfinite(amp): + raise ValueError(_numbers_error_str("amp", amp)) ampft = amp.copy() if clustering_features_cfg.log_transform_amplitude: ampft = np.log(clustering_features_cfg.amp_log_c + ampft) ampft *= clustering_features_cfg.amp_scale + if not _allfinite(ampft): + raise ValueError(_numbers_error_str("ampft", ampft)) features.append(ampft[:, None]) v = getattr(sorting, clustering_features_cfg.voltages_dataset_name, None) @@ -123,6 +124,8 @@ def from_config( rank=clustering_features_cfg.n_main_channel_pcs, dataset_name=clustering_features_cfg.pca_dataset_name, ) + if not _allfinite(pcs): + raise ValueError(_numbers_error_str("No motion pcs", pcs)) elif do_pcs and clustering_features_cfg.motion_aware: shifts, n_pitches_shift = motion.pitch_shifts( sorting=sorting, @@ -141,6 +144,8 @@ def from_config( mask = np.broadcast_to(mask, len(schan)) if hasattr(sorting, clustering_features_cfg.pca_dataset_name): pcs = getattr(sorting, clustering_features_cfg.pca_dataset_name) + if not _allfinite(pcs): + raise ValueError(_numbers_error_str("sorting pcs", pcs)) erp, pcs = interpolate_by_chunk( mask=mask, dataset=pcs, @@ -153,6 +158,8 @@ def from_config( params=clustering_features_cfg.interp_params, ) pcs = pcs[:, : clustering_features_cfg.n_main_channel_pcs, 0] + if not _allfinite(pcs): + raise ValueError(_numbers_error_str("sorting interp pcs", pcs)) else: assert sorting.parent_h5_path is not None with h5py.File(sorting.parent_h5_path, "r", locking=False) as h5: @@ -168,20 +175,25 @@ def from_config( params=clustering_features_cfg.interp_params, ) pcs = pcs[:, : clustering_features_cfg.n_main_channel_pcs, 0] + if not _allfinite(pcs): + raise ValueError(_numbers_error_str("h5 interp pcs", pcs)) if do_pcs: assert pcs is not None - if clustering_features_cfg.pc_transform == "log": + pctf = clustering_features_cfg.pc_transform + if pctf == "log": pcs = signed_log1p( pcs, pre_scale=clustering_features_cfg.pc_pre_transform_scale ) - elif clustering_features_cfg.pc_transform == "sqrt": + elif pctf == "sqrt": pcs = signed_sqrt_transform( pcs, pre_scale=clustering_features_cfg.pc_pre_transform_scale ) else: - assert clustering_features_cfg.pc_transform in ("none", None) + assert pctf in ("none", None) pcs *= clustering_features_cfg.pc_scale + if not _allfinite(pcs): + raise ValueError(_numbers_error_str(f"{pctf} pcs", pcs)) if torch.is_tensor(pcs): pcs = pcs.numpy(force=True) features.append(pcs) @@ -310,3 +322,16 @@ def signed_sqrt_transform(x, pre_scale=1.0): xx.sub_(1.0) xx.mul_(torch.sign(x)) return xx + + +def _allfinite(x): + if isinstance(x, torch.Tensor): + return x.isfinite().all() + else: + return np.isfinite(x).all() + + +def _numbers_error_str(name: str, x: np.ndarray): + if isinstance(x, torch.Tensor): + x = x.numpy(force=True) + return f"{name}: {np.isposinf(x).sum()} +inf, {np.isneginf(x).sum()} -inf, {np.isnan(x).sum()} nan." From 5014755f0102e1d8034068736084306e3388a667 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 10 Jun 2026 12:47:26 -0400 Subject: [PATCH 15/46] glom: crash in deduplication for tiny units --- src/dartsort/clustering/agglomerate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dartsort/clustering/agglomerate.py b/src/dartsort/clustering/agglomerate.py index 50dd07c5..8746e0d0 100644 --- a/src/dartsort/clustering/agglomerate.py +++ b/src/dartsort/clustering/agglomerate.py @@ -799,6 +799,8 @@ def deduplicate_spikes( ndrop = 0 for unit_id in unit_ids: in_unit = np.flatnonzero(new_labels == unit_id) + if in_unit.size <= 1: + continue t = times_samples[in_unit] dt = np.diff(t) if dt.min() > radius_samples: From 97f76753eb4152a10cbc09f6b995ec7d868ee61f Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 10 Jun 2026 12:47:52 -0400 Subject: [PATCH 16/46] data: raise for no TPCA --- src/dartsort/evaluate/analysis.py | 9 ++++++--- src/dartsort/util/data_util.py | 2 ++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/dartsort/evaluate/analysis.py b/src/dartsort/evaluate/analysis.py index f81ab423..263f6c87 100644 --- a/src/dartsort/evaluate/analysis.py +++ b/src/dartsort/evaluate/analysis.py @@ -121,9 +121,12 @@ def from_sorting( motion = MotionInfo.from_motion_est(geom=recording.get_channel_locations()) if has_hdf5: - tpca = get_tpca( - sorting, featurization_pipeline_pt=featurization_pipeline_pt - ) + try: + tpca = get_tpca( + sorting, featurization_pipeline_pt=featurization_pipeline_pt + ) + except ValueError: + tpca = None else: tpca = None if has_hdf5 and vis_radius and tpca is not None: diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 0809d55d..46d81a95 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -911,6 +911,8 @@ def get_tpca(sorting, name_prefix="collisioncleaned", featurization_pipeline_pt= if k in ("TemporalPCA", "TemporalPCAFeaturizer") and v["name_prefix"] == name_prefix ] + if len(tpca_kw) == 0: + raise ValueError(f"No TPCA found in {featurization_pipeline_pt}.") assert len(tpca_kw) == 1 ix, clsname, kw = tpca_kw[0] tpca = transformers_by_class_name[clsname]( From e2a3f643e83a4ed49fa4b1526cf7a22458ff4556 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 10 Jun 2026 12:48:02 -0400 Subject: [PATCH 17/46] extend matching debug vis --- .../util/testing_util/matching_debug_util.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/src/dartsort/util/testing_util/matching_debug_util.py b/src/dartsort/util/testing_util/matching_debug_util.py index 1c26140c..4cf29183 100644 --- a/src/dartsort/util/testing_util/matching_debug_util.py +++ b/src/dartsort/util/testing_util/matching_debug_util.py @@ -162,6 +162,8 @@ def visualize_step_results( chunk_vis_style: Literal["im", "trace"] = "im", gt_sorting: DARTsortSorting | None = None, vis_only_last_step: bool = False, + vline_at=None, + objline_at=None, ): import matplotlib.pyplot as plt @@ -205,6 +207,7 @@ def visualize_step_results( if vis_only_last_step: iterator = list(iterator) + resid = None for it, resid, pre_conv, conv, times_samples, labels, channels in iterator: v = np.flatnonzero(times_samples == times_samples.clip(vis_start, vis_end - 1)) times_samples = times_samples[v] - vis_start @@ -274,6 +277,7 @@ def visualize_step_results( lw=1, ) + vmax = max(np.nanmax(pre_conv[:, obj_sl]), np.nanmax(conv[:, obj_sl])) for j, c in enumerate(pre_conv): axes[-2].plot( obj_domain, c[obj_sl], color=glasbey1024[j % len(glasbey1024)], lw=0.5 @@ -288,16 +292,31 @@ def visualize_step_results( axes[-1].set_ylabel("post-step " + ("obj" if obj_mode else "conv")) for ax in axes[-2:]: ax.grid() - if obj_mode: - vmin = max(-100, pre_conv[:, obj_sl].min(), conv[:, obj_sl].min()) - for ax in axes[-2:]: - ax.set_ylim([vmin, pre_conv[:, obj_sl].max() * 1.05]) + vmin = max(-100, pre_conv[:, obj_sl].min(), conv[:, obj_sl].min()) + for ax in axes[-2:]: + ax.set_ylim([vmin, vmax * 1.05]) panel.suptitle(f"iteration {it}", fontsize=12) + if vline_at is not None: + for ax in axes[:3]: + ax.axvline(vline_at, color='w', ls='--', lw=0.8) + for ax in axes[-2:]: + ax.axvline(vline_at, color='k', ls='--', lw=0.8) + if objline_at is not None: + for ax in axes[-2:]: + ax.axhline(objline_at, color='k', ls='--', lw=0.8) + plt.show() plt.close(panel) + return dict( + resid=resid, + times=t_full[:n], + channels=c_full[:n], + labels=l_full[:n], + ) + # -- reference implementation for upsampled matching From 48cc66d2bcf456212f9e88688b7b471c4b20706b Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 11 Jun 2026 10:32:55 -0400 Subject: [PATCH 18/46] lint --- src/dartsort/peel/matching.py | 8 ++++++-- src/dartsort/peel/peel_lib.py | 2 +- tests/test_matching.py | 8 ++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index ad9ddeb5..eb6eb21a 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -279,7 +279,7 @@ def peel_chunk( # process spike times and create return result if match_results["n_spikes"]: - match_results["times_samples"] += chunk_start_samples - left_margin # type: ignore + match_results["times_samples"] += chunk_start_samples - left_margin if match_results["n_spikes"] > self.p.max_spikes_per_second: raise ValueError( f"Too many spikes {match_results['n_spikes']} > {self.p.max_spikes_per_second}." @@ -522,7 +522,11 @@ def pick_threshold(self): from scipy.stats import norm # TODO: remove? - if self.is_scaling and self.p.scale_adjusts_threshold and self.p.amplitude_scaling_variance < torch.inf: + if ( + self.is_scaling + and self.p.scale_adjusts_threshold + and self.p.amplitude_scaling_variance < torch.inf + ): # adjust threshold by the scaling prior's constant term # nb, everything is x2 so halves are gone. scstd = np.sqrt(self.p.amplitude_scaling_variance) diff --git a/src/dartsort/peel/peel_lib.py b/src/dartsort/peel/peel_lib.py index 863cc10a..87db97d5 100644 --- a/src/dartsort/peel/peel_lib.py +++ b/src/dartsort/peel/peel_lib.py @@ -150,7 +150,7 @@ def check_residual_decrease( def subtract_chunk( traces: Tensor, channel_index: Tensor, - denoising_pipeline: WaveformPipeline, + denoising_pipeline: "WaveformPipeline", extract_index: Tensor | None = None, extract_mask: Tensor | None = None, trough_offset_samples=42, diff --git a/tests/test_matching.py b/tests/test_matching.py index f36f336c..67b4c20b 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -214,7 +214,7 @@ def test_no_crumbs(subtests, refractory_sim, method, cd_iter, channel_selection_ assert gt_shift is not None gt_shift = gt_shift[gt_in_chunk] if "scalings" in res: - match_scale = res["scalings"].numpy(force=True) # type: ignore + match_scale = res["scalings"].numpy(force=True) else: match_scale = np.ones(gt_n_spikes, dtype=np.float32) @@ -288,7 +288,7 @@ def test_no_crumbs(subtests, refractory_sim, method, cd_iter, channel_selection_ ) if "up_inds" in res: - match_up = res["up_inds"].numpy(force=True) # type: ignore + match_up = res["up_inds"].numpy(force=True) np.testing.assert_array_equal(gt_up, match_up, err_msg="sorting: up_inds 1") else: np.testing.assert_array_equal( @@ -300,7 +300,7 @@ def test_no_crumbs(subtests, refractory_sim, method, cd_iter, channel_selection_ ) if "time_shifts" in res: - match_shift = res["time_shifts"].numpy(force=True) # type: ignore + match_shift = res["time_shifts"].numpy(force=True) np.testing.assert_array_equal(match_shift, gt_shift) else: assert (gt_shift == 0).all() @@ -546,7 +546,7 @@ def test_tiny_up(tiny_up_sim, tmp_path, up_factor, scaling, cd_iter, up_offset): # ) assert np.array_equal(res["labels"].numpy(force=True), labels) if "up_inds" in res: - assert np.array_equal(res["up_inds"].numpy(force=True), upsampling_indices) # type: ignore + assert np.array_equal(res["up_inds"].numpy(force=True), upsampling_indices) else: assert up_factor == 1 resid_rms = torch.square(res["residual"]).mean().numpy(force=True) From 20124be54191fd0714d871c2d002d93481ec5819 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 11 Jun 2026 17:08:44 -0400 Subject: [PATCH 19/46] glom: numerical case --- src/dartsort/clustering/agglomerate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/dartsort/clustering/agglomerate.py b/src/dartsort/clustering/agglomerate.py index 8746e0d0..59e2cad8 100644 --- a/src/dartsort/clustering/agglomerate.py +++ b/src/dartsort/clustering/agglomerate.py @@ -665,7 +665,8 @@ def combine_gmm_scores( # check invariants at the top if responsibilities.shape[1] > 2: - assert np.all(np.diff(responsibilities[:, :-1], axis=1) <= 0) + _maxdiff = np.diff(responsibilities[:, :-1], axis=1).max() + assert _maxdiff <= 1e-5, _maxdiff assert np.greater_equal(np.isneginf(logliks[:, :-1]), candidates == -1).all() if sorting.labels is not None: assert np.all( @@ -698,7 +699,8 @@ def combine_gmm_scores( # check invariants at the bottom if mergedr.shape[1] > 2: - assert np.all(np.diff(mergedr[:, : cand.shape[1]], axis=1) <= 0) + _maxdiff = np.diff(mergedr[:, : cand.shape[1]], axis=1).max() + assert _maxdiff <= 1e-5, _maxdiff assert np.greater_equal(np.isneginf(mergedl[:, : cand.shape[1]]), cand == -1).all() assert (cand < 0).sum() >= nbye if sorting.labels is not None: From a26577829f3de1cbb711625194ce54cab08c4acc Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 11 Jun 2026 17:09:39 -0400 Subject: [PATCH 20/46] kmeans: batched version --- src/dartsort/clustering/kmeans.py | 154 ++++++++++++++++++++++++++- src/dartsort/clustering/mixture.py | 42 +++++--- src/dartsort/util/internal_config.py | 2 +- tests/test_kmeans.py | 11 +- 4 files changed, 186 insertions(+), 23 deletions(-) diff --git a/src/dartsort/clustering/kmeans.py b/src/dartsort/clustering/kmeans.py index 52e38327..3195b605 100644 --- a/src/dartsort/clustering/kmeans.py +++ b/src/dartsort/clustering/kmeans.py @@ -3,6 +3,8 @@ import torch.nn.functional as F from torch import Tensor +from dartsort.util.py_util import databag + try: import cupy # type: ignore # ty: ignore[x] @@ -33,7 +35,6 @@ def kmeanspp( n_components=10, random_state: np.random.Generator | torch.Generator | int = 0, kmeanspp_initial="random", - mode_dim=2, skip_assignment=False, min_distance=None, Xnormsq: Tensor | None = None, @@ -412,6 +413,14 @@ def kmeans_inner( return assignments, e, centroids, dists +@databag +class KMeansResult: + labels: Tensor | None + responsibilities: Tensor | None + centroids: Tensor | None + dists: Tensor | None + + def kmeans( X: Tensor, n_kmeans_tries=5, @@ -421,11 +430,11 @@ def kmeans( random_state: np.random.Generator | torch.Generator | int = 0, kmeanspp_initial="random", with_proportions=False, - drop_prop=0.025, - drop_sum=5.0, + drop_prop=0.0, + drop_sum=0.0, weights: Tensor | None = None, test_convergence_every=10, -): +) -> KMeansResult: best_phi = np.inf if isinstance(random_state, int): random_state = np.random.default_rng(random_state) @@ -461,7 +470,7 @@ def kmeans( assignments = aa.clone() e = ee centroids = cc - return dict( + return KMeansResult( labels=assignments, responsibilities=e, centroids=centroids, dists=dists ) @@ -614,3 +623,138 @@ def _kmeans_main_loop( break dists = _sqeuc(X, Xnormsq, centroids, dists) return e, centroids, dists, proportions + + +def batched_kmeans( + X: Tensor, + n_components: int, + seed: torch.Generator | np.random.Generator | int = 0, + n_iter: int = 100, + kmeanspp_seeds_per_try: int = 5, + n_tries: int = 10, + test_convergence_every=10, + atol=1e-5, + with_labels=True, + with_proportions=True, +) -> KMeansResult: + """ + Compared to above: + - with_proportions = True + - drop_prop = 0 + - no weights allowed for now + - kmeanspp always random initial + - n_iter > 0 + """ + k = n_components + del n_components + assert n_iter > 0 + assert n_tries > 0 + assert k > 1 + assert kmeanspp_seeds_per_try >= 1 + + dev = X.device + gen = spawn_torch_rg(seed, device=dev) + Xnormsq = torch.linalg.norm(X, dim=1).square_() + n, dim = X.shape + k = min(n, k) + assert n >= k > 1 + ntries_k = n_tries * k + n_kmeanspps = n_tries * kmeanspp_seeds_per_try + + # -- kmeanspp stage: initialization + centroid_ixs = torch.full((k, n_kmeanspps), n, dtype=torch.long, device=dev) + centroid_ixs[0] = torch.randint(n, size=(n_kmeanspps,), device=dev, generator=gen) + dists = X.new_empty((n_kmeanspps, n)) + Y = X[centroid_ixs[0]] + Ynormsq = Xnormsq[centroid_ixs[0]] + dists = sqeuc_cdist_known_norm(Y, Ynormsq, X, Xnormsq, dists) + + # -- kmeanspp stage: loop + # buf for random sampling with Gumbel trick and new distance storage + _buf = torch.empty_like(dists) + for j in range(1, k): + # sample new centroid indices wppt dists (which is squared) + _buf.uniform_(generator=gen) + _buf.log_().neg_().log_().neg_().add_(dists.log()) + # torch.divide(dists, _buf, out=_buf) + # cix_j = torch.argmin(_buf, dim=1, out=centroid_ixs[j]) + print(f"{dists.shape=} {_buf.shape=}") + cix_j = torch.argmax(_buf, dim=1, out=centroid_ixs[j]) + + # grab jth centroid data + torch.index_select(X, dim=0, index=cix_j, out=Y) + torch.index_select(Xnormsq, dim=0, index=cix_j, out=Ynormsq) + + # update distances + newdists = sqeuc_cdist_known_norm(Y, Ynormsq, X, Xnormsq, _buf) + torch.minimum(dists, newdists, out=dists) + + # -- kmeanspp finish: pick best by phi + if kmeanspp_seeds_per_try > 1: + phi = dists.mean(1).view(n_tries, kmeanspp_seeds_per_try) + best_kmpp = phi.argmin(1, keepdim=True) + centroid_ixs = centroid_ixs.view(k, n_tries, kmeanspp_seeds_per_try) + centroid_ixs = centroid_ixs.take_along_dim(dim=2, indices=best_kmpp[None, :, :]) + assert centroid_ixs.shape == (k, n_tries, 1) + centroid_ixs = centroid_ixs[:, :, 0] + assert centroid_ixs.shape == (k, n_tries) + centroid_ixs = centroid_ixs.T.contiguous() + + # -- kmeans stage: initialization + Y = X[centroid_ixs].view(ntries_k, dim) + Ynormsq = Xnormsq[centroid_ixs].view(ntries_k) + dists = dists.resize_(n, ntries_k) + e = X.new_empty((n, n_tries, k)) + N = X.new_ones((n_tries, k)) + Ntot = X.new_full((), float(n)) + log_props = X.new_zeros((n_tries, k)) + phi = phi_ = e.new_full((n_tries,), torch.nan) + + # -- kmeans stage: loop + check = False + for j in range(n_iter): + jmod = j % test_convergence_every + check = (j in (0, n_iter - 1)) or (jmod in (0, test_convergence_every - 1)) + + # e step + dists = sqeuc_cdist_known_norm(X, Xnormsq, Y, Ynormsq, dists.view(n, ntries_k)) + dists = dists.view(n, n_tries, k) + e = torch.add(log_props, dists, alpha=-100.0, out=e) + e = F.softmax(e, dim=2) + if check: + phi_ = dists.mul_(e).mean(0).sum(1) + assert phi_.shape == phi.shape + + # m step + N = torch.sum(e, dim=0, out=N) + if with_proportions: + torch.div(N, Ntot, out=log_props).log_() + w = e.div_(N) + Y = torch.mm(w.view(n, ntries_k).t(), X, out=Y) + Ynormsq = torch.linalg.vector_norm(Y, dim=1, out=Ynormsq) + Ynormsq.square_() + + # check convergence + if check: + done = torch.allclose(phi, phi_, atol=atol) or phi_.max() < atol + phi = phi_ + if done: + break + assert check # => phi is up to date + + # -- kmeans finish: pick best kmeans by phi, update responsibilities + best = phi.argmin() + Y = Y.view(n_tries, k, dim)[best] + Ynormsq = Ynormsq.view(n_tries, k)[best] + log_props = log_props[best] + dists = dists.resize_(n, k) + e = e.resize_(n, k) + dists = sqeuc_cdist_known_norm(X, Xnormsq, Y, Ynormsq, dists) + e = torch.add(log_props, dists, alpha=-0.5, out=e) + e = F.softmax(e, dim=1) + if with_labels: + labels = e.argmax(dim=1) + else: + labels = None + + return KMeansResult(centroids=Y, responsibilities=e, labels=labels, dists=dists) diff --git a/src/dartsort/clustering/mixture.py b/src/dartsort/clustering/mixture.py index 9c0a771d..5ccf430d 100644 --- a/src/dartsort/clustering/mixture.py +++ b/src/dartsort/clustering/mixture.py @@ -100,7 +100,7 @@ from ..util.torch_util import BModule, torch_compiler from .cluster_util import linkage, maximal_leaf_groups from .clustering_features import StableWaveformFeatures -from .kmeans import kmeans +from .kmeans import batched_kmeans, kmeans if TYPE_CHECKING: from ..transform.temporal_pca import BaseTemporalPCA @@ -5347,19 +5347,35 @@ def try_kmeans( x_ret = x if debug else None # kmeans - kres = kmeans( - x, - n_components=k, - random_state=gen, - n_iter=n_iter, - with_proportions=with_proportions, - drop_prop=drop_prop, - kmeanspp_initial=kmeanspp_initial, - n_kmeans_tries=n_kmeans_tries, - n_kmeanspp_tries=n_kmeanspp_tries, - weights=weights.to(x) if weights is not None else None, + _can_batch = ( + weights is None + and with_proportions + and not drop_prop + and kmeanspp_initial == "random" ) - resps = kres["responsibilities"] + if _can_batch: + kres = batched_kmeans( + x, + k, + seed=gen, + n_iter=n_iter, + kmeanspp_seeds_per_try=n_kmeanspp_tries, + n_tries=n_kmeans_tries, + ) + else: + kres = kmeans( + x, + n_components=k, + random_state=gen, + n_iter=n_iter, + with_proportions=with_proportions, + drop_prop=drop_prop, + kmeanspp_initial=kmeanspp_initial, + n_kmeans_tries=n_kmeans_tries, + n_kmeanspp_tries=n_kmeanspp_tries, + weights=weights.to(x) if weights is not None else None, + ) + resps = kres.responsibilities if resps is None: return None, x_ret, channels assert resps.shape[1] <= k diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index 6ab011dc..a8ca351a 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -515,7 +515,7 @@ class RefinementConfig: mixture_steps: Sequence[MixtureStep] = ("split", "merge", "demolish") prior_pseudocount: float = 0.0 kmeansk: int = 4 - kmeans_tries: int = 5 + kmeans_tries: int = 10 kmeanspp_tries: int = 5 full_proposal_every: int = 10 main_min_iters: int = 20 diff --git a/tests/test_kmeans.py b/tests/test_kmeans.py index 92b0e5a4..36e01d1e 100644 --- a/tests/test_kmeans.py +++ b/tests/test_kmeans.py @@ -28,12 +28,15 @@ def blobs(): @pytest.mark.parametrize( "algorithm", - [kmeans.kmeans, kmeans.truncated_kmeans], + [kmeans.kmeans, kmeans.truncated_kmeans, kmeans.batched_kmeans], ) def test_kmeans(blobs, algorithm): res = algorithm(blobs["X"], n_components=blobs["K"]) - order = np.lexsort(np.asarray(res["centroids"]).T) - labels = np.argsort(order)[np.asarray(res["labels"])] + centroids = res["centroids"] if isinstance(res, dict) else res.centroids + labels = res["labels"] if isinstance(res, dict) else res.labels - assert np.allclose(res["centroids"][order], blobs["centroids"], atol=0.25) + order = np.lexsort(np.asarray(centroids).T) + labels = np.argsort(order)[np.asarray(labels)] + + assert np.allclose(centroids[order], blobs["centroids"], atol=0.25) assert np.array_equal(labels, blobs["labels"]) From 8a3ece00928f7f79bdb51f4c5aa97ee49ba3ecde Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 11 Jun 2026 17:10:18 -0400 Subject: [PATCH 21/46] templates: handle case of low numerical rank (probably just a sim thing) --- src/dartsort/peel/reduction_template.py | 11 +++++++---- src/dartsort/templates/template_util.py | 10 +++++++--- src/dartsort/templates/templib.py | 5 +++++ src/dartsort/transform/temporal_pca.py | 9 ++++++++- src/dartsort/util/internal_config.py | 1 + 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/dartsort/peel/reduction_template.py b/src/dartsort/peel/reduction_template.py index b14bc488..3c188a97 100644 --- a/src/dartsort/peel/reduction_template.py +++ b/src/dartsort/peel/reduction_template.py @@ -206,12 +206,14 @@ def from_config( # type: ignore pad_spike_len = padded_waveform_cfg.spike_length_samples( recording.sampling_frequency ) + rank = template_cfg.denoising_rank if template_cfg.use_svd and tsvd is not None: if isinstance(tsvd, FullProbeTemporalPCAEmbedder): if do_align: raise ValueError("Haven't handled svd alignment in this case.") else: - assert tsvd.components_.shape[0] == template_cfg.denoising_rank + assert tsvd.components_.shape[0] <= rank + rank = tsvd.components_.shape[0] tsvd = FullProbeTemporalPCAEmbedder.from_sklearn( channel_index=channel_index, pca=tsvd, @@ -230,7 +232,8 @@ def from_config( # type: ignore waveform_cfg=waveform_cfg, computation_cfg=computation_cfg, ) - assert tsvd.components_.shape[0] == template_cfg.denoising_rank + assert tsvd.components_.shape[0] <= rank + rank = tsvd.components_.shape[0] tsvd = FullProbeTemporalPCAEmbedder.from_sklearn( channel_index=channel_index, pca=tsvd, @@ -243,7 +246,7 @@ def from_config( # type: ignore elif template_cfg.use_svd: tsvd = FullProbeTemporalPCAEmbedder( channel_index=channel_index, - rank=template_cfg.denoising_rank, + rank=rank, geom=geom, fit_radius=template_cfg.denoising_fit_radius, max_waveforms=template_cfg.denoising_fit_sampling_cfg.n_waveforms_fit, @@ -321,7 +324,7 @@ def from_config( # type: ignore name_prefix="svd", with_raw_std_dev=False, n_units=sorting.n_units, - feature_dim=template_cfg.denoising_rank, + feature_dim=rank, output_channels=len(rgeom), reduction=template_cfg.reduction, ) diff --git a/src/dartsort/templates/template_util.py b/src/dartsort/templates/template_util.py index b2ef1253..36cb5b02 100644 --- a/src/dartsort/templates/template_util.py +++ b/src/dartsort/templates/template_util.py @@ -275,6 +275,7 @@ def shared_basis_compress_templates( rank=5, computation_cfg=None, precomputed_basis: np.ndarray | None = None, + min_explained_variance: float = 5e-3, with_r2: bool = False, ) -> SharedBasisTemplates: computation_cfg = ensure_computation_config(computation_cfg) @@ -292,7 +293,7 @@ def shared_basis_compress_templates( rank = min(rank, t) if precomputed_basis is None: temporal_comps = get_shared_temporal_basis( - templates, rank, dev, min_channel_amplitude + templates, rank, dev, min_channel_amplitude, min_explained_variance ) assert temporal_comps.shape[1] == t assert temporal_comps.ndim == 2 @@ -341,6 +342,7 @@ def get_shared_temporal_basis( rank: int, device: torch.device, min_channel_amplitude: float, + min_explained_variance: float, eps=1e-6, ) -> np.ndarray: n, t, c = templates.shape @@ -370,9 +372,11 @@ def get_shared_temporal_basis( cov += m * m.T if nvis < t: logger.warning(f"Had {nvis=} smaller than {t=} in shared basis compression.") - cov.diagonal().add_(1e-5) + cov.diagonal().add_(eps) vals, U = torch.linalg.eigh(cov) - big_enough = vals > eps + min_explained_variance = max(eps, min_explained_variance) + exvar = vals / vals.sum() + big_enough = exvar > min_explained_variance nbig = big_enough.sum() if nbig < rank: logger.dartsortdebug(f"Shared basis only needed rank {nbig}.") diff --git a/src/dartsort/templates/templib.py b/src/dartsort/templates/templib.py index 46b591d6..9d456f3c 100644 --- a/src/dartsort/templates/templib.py +++ b/src/dartsort/templates/templib.py @@ -63,6 +63,7 @@ def fit_tsvd( min_channel_amplitude=template_cfg.template_min_channel_amplitude, random_seed=random_seed, computation_cfg=computation_cfg, + min_explained_variance=template_cfg.svd_min_explained_variance, ) return pca @@ -128,6 +129,7 @@ def pca_from_templates( rank: int, min_channel_amplitude: float = 1.0, random_seed: int = 0, + min_explained_variance: float = 5e-3, computation_cfg: ComputationConfig | None = None, ) -> PCA: from .template_util import shared_basis_compress_templates @@ -137,8 +139,11 @@ def pca_from_templates( rank=rank, computation_cfg=computation_cfg, min_channel_amplitude=min_channel_amplitude, + min_explained_variance=min_explained_variance, ) basis = tdc.temporal_components + assert 0 < tdc.temporal_components.shape[0] <= rank + rank = tdc.temporal_components.shape[0] pca = PCA( n_components=rank, random_state=random_seed, diff --git a/src/dartsort/transform/temporal_pca.py b/src/dartsort/transform/temporal_pca.py index 1afbfc1b..4819b771 100644 --- a/src/dartsort/transform/temporal_pca.py +++ b/src/dartsort/transform/temporal_pca.py @@ -364,6 +364,9 @@ def initialize_spike_length_dependent_params(self): self.register_buffer("whitener", torch.zeros(self.rank)) else: assert self.b.mean.shape == (nt,) + if self.b.components.shape[0] < self.rank: + self.rank = self.b.components.shape[0] + self.shape = (self.rank, self.b.channel_index.shape[1]) assert self.b.components.shape == (self.rank, nt) self.to(self.b.channel_index.device) @@ -399,8 +402,12 @@ def initialize_from_templates(self, td): else: dt = self.temporal_slice.stop - self.temporal_slice.start assert basis.shape[1] == dt + self.rank = min(self.rank, basis.shape[0]) + self.shape = (self.rank, self.b.channel_index.shape[1]) self.b.mean.zero_() - self.b.components.copy_(torch.asarray(basis[: self.rank])) + comps = torch.asarray(basis[: self.rank]) + self.b.components.resize_(comps.shape) + self.b.components.copy_(comps) self._needs_fit = False diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index a8ca351a..9bb5adcb 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -373,6 +373,7 @@ class TemplateConfig: try_reload_svd: bool = True svd_alignment_iterations: int = 0 svd_alignment_ms: float = 0.75 + svd_min_explained_variance: float = 5e-3 # exp weight denoising exp_weight_snr_threshold: float = 50.0 From d823719005a95b9b2296107b8bf08144685e2b10 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 11 Jun 2026 18:03:48 -0400 Subject: [PATCH 22/46] print --- src/dartsort/clustering/kmeans.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dartsort/clustering/kmeans.py b/src/dartsort/clustering/kmeans.py index 3195b605..53d5e40d 100644 --- a/src/dartsort/clustering/kmeans.py +++ b/src/dartsort/clustering/kmeans.py @@ -678,7 +678,6 @@ def batched_kmeans( _buf.log_().neg_().log_().neg_().add_(dists.log()) # torch.divide(dists, _buf, out=_buf) # cix_j = torch.argmin(_buf, dim=1, out=centroid_ixs[j]) - print(f"{dists.shape=} {_buf.shape=}") cix_j = torch.argmax(_buf, dim=1, out=centroid_ixs[j]) # grab jth centroid data From a4957a6dfbe51d3eeabe377ce3170f449d778538 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 11 Jun 2026 20:51:00 -0400 Subject: [PATCH 23/46] kmeans: reset temperature, gumbel --- src/dartsort/clustering/kmeans.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/dartsort/clustering/kmeans.py b/src/dartsort/clustering/kmeans.py index 53d5e40d..637b755e 100644 --- a/src/dartsort/clustering/kmeans.py +++ b/src/dartsort/clustering/kmeans.py @@ -674,11 +674,11 @@ def batched_kmeans( _buf = torch.empty_like(dists) for j in range(1, k): # sample new centroid indices wppt dists (which is squared) - _buf.uniform_(generator=gen) - _buf.log_().neg_().log_().neg_().add_(dists.log()) - # torch.divide(dists, _buf, out=_buf) - # cix_j = torch.argmin(_buf, dim=1, out=centroid_ixs[j]) - cix_j = torch.argmax(_buf, dim=1, out=centroid_ixs[j]) + # gumbel max: argmax [log(d) + -log(-log(u))] + # = argmax d / (-log u) = argmin d / (log u) + u = _buf.uniform_(generator=gen).log_() + u = torch.div(dists, u, out=u) + cix_j = torch.argmin(u, dim=1, out=centroid_ixs[j]) # grab jth centroid data torch.index_select(X, dim=0, index=cix_j, out=Y) @@ -718,7 +718,7 @@ def batched_kmeans( # e step dists = sqeuc_cdist_known_norm(X, Xnormsq, Y, Ynormsq, dists.view(n, ntries_k)) dists = dists.view(n, n_tries, k) - e = torch.add(log_props, dists, alpha=-100.0, out=e) + e = torch.add(log_props, dists, alpha=-0.5, out=e) e = F.softmax(e, dim=2) if check: phi_ = dists.mul_(e).mean(0).sum(1) From 01b6df8b7edd6294983a757e70bb020be7841d11 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 11 Jun 2026 20:55:40 -0400 Subject: [PATCH 24/46] matching: fix tests, toggle tpca from templates --- src/dartsort/config.py | 1 + src/dartsort/peel/matching_util/drifty.py | 1 + src/dartsort/util/internal_config.py | 2 ++ tests/test_matching.py | 1 + 4 files changed, 5 insertions(+) diff --git a/src/dartsort/config.py b/src/dartsort/config.py index 3a2baf42..efea26a2 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -352,6 +352,7 @@ class DeveloperConfig(DARTsortUserConfig): robust_df: float = 4.0 demolish_during_selection: bool = False em_after_demolish: bool = False + tpca_from_templates: bool = True # agglomeration agg_kind: Literal["none", "template_distance", "qda"] = "qda" diff --git a/src/dartsort/peel/matching_util/drifty.py b/src/dartsort/peel/matching_util/drifty.py index 10a1f730..eb7c9bf2 100644 --- a/src/dartsort/peel/matching_util/drifty.py +++ b/src/dartsort/peel/matching_util/drifty.py @@ -226,6 +226,7 @@ def _from_config( min_channel_amplitude=matching_cfg.template_min_channel_amplitude, rank=matching_cfg.template_svd_compression_rank, computation_cfg=computation_cfg, + min_explained_variance=matching_cfg.template_svd_compression_min_explained_variance, ) temporal_comps = torch.asarray(shared_basis_temps.temporal_components) spatial_sing = torch.asarray(shared_basis_temps.spatial_singular) diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index 9bb5adcb..e36af64e 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -749,6 +749,7 @@ class MatchingConfig: # template matching parameters threshold: float | Literal["fp_control"] = 8.0 template_svd_compression_rank: int = 5 + template_svd_compression_min_explained_variance: float = 5e-3 up_factor: int = 4 upsampling_radius: int = 8 template_min_channel_amplitude: float = 1.0 @@ -1009,6 +1010,7 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: tpca_max_waveforms=cfg.n_waveforms_fit, save_input_waveforms=cfg.save_collisioncleaned_waveforms, save_collidedness=save_collidedness, + tpca_from_templates=cfg.tpca_from_templates, ) if cfg.template_interp_kind == "tps": temp_interp_params = tps_interp_clampna_extrap_params diff --git a/tests/test_matching.py b/tests/test_matching.py index 67b4c20b..1bd747eb 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -116,6 +116,7 @@ def test_no_crumbs(subtests, refractory_sim, method, cd_iter, channel_selection_ threshold=threshold, cd_iter=cd_iter, channel_selection_radius=channel_selection_radius, + template_svd_compression_min_explained_variance=0.0, ) if method == "upcomp": cfg_kw["template_type"] = "individual_compressed_upsampled" From 0b11e8d349f79afb1bad070937d35455996dd421 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 12 Jun 2026 15:19:48 -0400 Subject: [PATCH 25/46] kmeans: set beta param --- src/dartsort/clustering/kmeans.py | 5 +++-- src/dartsort/clustering/mixture.py | 6 ++++++ src/dartsort/util/internal_config.py | 1 + 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/dartsort/clustering/kmeans.py b/src/dartsort/clustering/kmeans.py index 637b755e..ff02fac8 100644 --- a/src/dartsort/clustering/kmeans.py +++ b/src/dartsort/clustering/kmeans.py @@ -636,6 +636,7 @@ def batched_kmeans( atol=1e-5, with_labels=True, with_proportions=True, + beta: float = 1.0, ) -> KMeansResult: """ Compared to above: @@ -718,7 +719,7 @@ def batched_kmeans( # e step dists = sqeuc_cdist_known_norm(X, Xnormsq, Y, Ynormsq, dists.view(n, ntries_k)) dists = dists.view(n, n_tries, k) - e = torch.add(log_props, dists, alpha=-0.5, out=e) + e = torch.add(log_props, dists, alpha=-0.5 * beta, out=e) e = F.softmax(e, dim=2) if check: phi_ = dists.mul_(e).mean(0).sum(1) @@ -749,7 +750,7 @@ def batched_kmeans( dists = dists.resize_(n, k) e = e.resize_(n, k) dists = sqeuc_cdist_known_norm(X, Xnormsq, Y, Ynormsq, dists) - e = torch.add(log_props, dists, alpha=-0.5, out=e) + e = torch.add(log_props, dists, alpha=-0.5 * beta, out=e) e = F.softmax(e, dim=1) if with_labels: labels = e.argmax(dim=1) diff --git a/src/dartsort/clustering/mixture.py b/src/dartsort/clustering/mixture.py index 5ccf430d..a3eae6fc 100644 --- a/src/dartsort/clustering/mixture.py +++ b/src/dartsort/clustering/mixture.py @@ -807,6 +807,7 @@ class TMMParams: demolition_min_resp_ratio: float demolish_during_selection: bool kmeans_tries: int + kmeans_beta: float kmeanspp_tries: int whiten_split: bool scale_dist_args: tuple[float, float, float] @@ -824,6 +825,7 @@ def from_refinement_cfg(cls, refinement_cfg: RefinementConfig): split_min_count=refinement_cfg.split_min_count, split_k=refinement_cfg.kmeansk, kmeans_tries=refinement_cfg.kmeans_tries, + kmeans_beta=refinement_cfg.kmeans_beta, kmeanspp_tries=refinement_cfg.kmeanspp_tries, min_channel_count=refinement_cfg.channels_count_min, em_iters=refinement_cfg.n_em_iters, @@ -2947,6 +2949,7 @@ def split_group( min_count=self.p.split_min_count, min_channel_count=self.p.min_channel_count, n_kmeans_tries=self.p.kmeans_tries, + kmeans_beta=self.p.kmeans_beta, n_kmeanspp_tries=self.p.kmeanspp_tries, bail_at=int(single), weights=split_data.duties, @@ -5330,6 +5333,7 @@ def try_kmeans( drop_prop: float = 0.0, kmeanspp_initial="random", n_kmeans_tries: int = 25, + kmeans_beta: float = 1.0, n_kmeanspp_tries: int = 25, weights: Tensor | None = None, debug: bool = False, @@ -5361,8 +5365,10 @@ def try_kmeans( n_iter=n_iter, kmeanspp_seeds_per_try=n_kmeanspp_tries, n_tries=n_kmeans_tries, + beta=kmeans_beta, ) else: + assert kmeans_beta == 1.0 kres = kmeans( x, n_components=k, diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index e36af64e..a6ceb409 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -517,6 +517,7 @@ class RefinementConfig: prior_pseudocount: float = 0.0 kmeansk: int = 4 kmeans_tries: int = 10 + kmeans_beta: float = 50.0 kmeanspp_tries: int = 5 full_proposal_every: int = 10 main_min_iters: int = 20 From f5e17b836a1b214a8d8474958c1312366e55d56e Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 12 Jun 2026 15:22:44 -0400 Subject: [PATCH 26/46] whitening: impl temporal whitening estimation, config, data plumbing --- src/dartsort/config.py | 1 + src/dartsort/main.py | 4 +- src/dartsort/peel/matching_util/drifty.py | 24 ++++- src/dartsort/peel/reduction_template.py | 11 ++- src/dartsort/templates/get_templates.py | 4 +- src/dartsort/templates/postprocess_util.py | 4 +- src/dartsort/templates/templates.py | 11 ++- src/dartsort/transform/whiten.py | 14 +-- src/dartsort/util/internal_config.py | 2 + src/dartsort/util/noise_util.py | 101 +++++++++++++++++++-- 10 files changed, 140 insertions(+), 36 deletions(-) diff --git a/src/dartsort/config.py b/src/dartsort/config.py index efea26a2..d96d27e9 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -282,6 +282,7 @@ class DeveloperConfig(DARTsortUserConfig): trough_factor: float = 3.0 whiten_strategy: WhiteningStrategy = "prewhiten_postapply" whiten_estimator: WhiteningEstimator = "localzca" + whiten_temporal_length: int | None = None whiten_features: bool = False matching_fp_control: bool = False refractory_radius_frames: int = 0 diff --git a/src/dartsort/main.py b/src/dartsort/main.py index 92b46704..a880dc8c 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -59,7 +59,7 @@ motion_needs_peaks, ) from .util.motion import MotionInfo, get_motion_info -from .util.noise_util import SpatialWhitener +from .util.noise_util import Whitener from .util.peel_util import run_peeler from .util.preprocess_util import preprocess from .util.py_util import dartcopytree, ensure_path, timer @@ -571,7 +571,7 @@ def match( template_npz="template_data.npz", computation_cfg: ComputationConfig | None = None, template_denoising_tsvd=None, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, ) -> DARTsortSorting: output_dir = ensure_path(output_dir) model_dir = output_dir / model_subdir diff --git a/src/dartsort/peel/matching_util/drifty.py b/src/dartsort/peel/matching_util/drifty.py index eb7c9bf2..81aec58f 100644 --- a/src/dartsort/peel/matching_util/drifty.py +++ b/src/dartsort/peel/matching_util/drifty.py @@ -55,7 +55,7 @@ from ...util.job_util import ensure_computation_config from ...util.logging_util import get_logger from ...util.motion import MotionInfo -from ...util.noise_util import SpatialWhitener +from ...util.noise_util import Whitener from ...util.py_util import databag from ...util.spiketorch import full_shared_pconv, shared_temporal_pconv from ...util.torch_util import torch_compiler @@ -84,7 +84,7 @@ def __init__( trough_offset_samples: int, unit_ids: Tensor | None = None, whiten_strategy: WhiteningStrategy = "none", - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, whiten_features: bool = True, up_factor: int = 1, up_method: Literal["interpolation", "keys3", "keys4", "direct"] = "keys4", @@ -238,8 +238,12 @@ def _from_config( else: assert template_data.whitener is not None assert template_data.covariance is not None - whitener = SpatialWhitener.from_numpy( - template_data.whitener, template_data.covariance + if template_data.temporal_kernel is not None: + tk = template_data.temporal_kernel + else: + tk = None + whitener = Whitener.from_numpy( + template_data.whitener, template_data.covariance, tk ) if not wh_none and not matching_cfg.whiten_features: @@ -268,6 +272,10 @@ def interp_at_time(self, t_s: float, x: Tensor) -> Tensor: return self.erp.interp_at_time(t_s=t_s, waveforms=x) def spatial_at_time(self, t_s: float) -> tuple[Tensor, ...]: + """Get spatial components and norms at current chunk + + This handles the spatial whitening strategy logic. + """ if self.whiten_strategy == "postwhiten": spatial_sing = self.interp_at_time(t_s, self.b.spatial_sing) normsq_spatial_sing = spatial_sing @@ -277,11 +285,17 @@ def spatial_at_time(self, t_s: float) -> tuple[Tensor, ...]: conv_spatial_sing = normsq_spatial_sing = spatial_sing elif self.whiten_strategy == "prewhiten_postapply": assert self.whitener is not None + + # features / cc waveforms just use interpolated plain template spatial_sing = self.interp_at_time(t_s, self.b.spatial_sing) + + # for convolution, whitening is applied if self.interpolating: conv_spatial_sing = self.whitener.whiten(spatial_sing) else: conv_spatial_sing = self.b.conv_spatial_sing + + # this is usually false if self.whiten_features: spatial_sing = conv_spatial_sing normsq_spatial_sing = conv_spatial_sing @@ -367,7 +381,7 @@ class DriftyChunkTemplateData(ChunkTemplateData): spatial_sing: Tensor padded_spatial_sing: Tensor pconv: Tensor - spatial_whitener: SpatialWhitener | None + spatial_whitener: Whitener | None time_ix: Tensor chan_ix: Tensor diff --git a/src/dartsort/peel/reduction_template.py b/src/dartsort/peel/reduction_template.py index 3c188a97..8ddd81ea 100644 --- a/src/dartsort/peel/reduction_template.py +++ b/src/dartsort/peel/reduction_template.py @@ -30,7 +30,7 @@ from ..util.job_util import ensure_computation_config from ..util.logging_util import get_logger from ..util.motion import MotionInfo -from ..util.noise_util import SpatialWhitener +from ..util.noise_util import Whitener from ..util.py_util import ensure_path from ..util.waveform_util import full_channel_index from .grab import GrabAndFeaturize @@ -54,7 +54,7 @@ def _from_config( waveform_cfg: WaveformConfig = default_waveform_cfg, motion: MotionInfo, tsvd=None, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, computation_cfg: ComputationConfig | None = None, show_progress: bool = True, ) -> TemplateData: @@ -155,9 +155,9 @@ def _from_config( templates *= msk if whitener is None: - whitener_np = covariance_np = None + whitener_np = covariance_np = tk_np = None else: - whitener_np, covariance_np = whitener.to_numpy() + whitener_np, covariance_np, tk_np = whitener.to_numpy() return TemplateData( unit_ids=unit_ids, @@ -170,6 +170,7 @@ def _from_config( tsvd=p.temporal_svd(), whitener=whitener_np, covariance=covariance_np, + temporal_kernel=tk_np, sampling_frequency=recording.sampling_frequency, whiten_strategy=template_cfg.whitening.strategy, ) @@ -190,7 +191,7 @@ def from_config( # type: ignore waveform_cfg: WaveformConfig, template_cfg: TemplateConfig, computation_cfg: ComputationConfig, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, ): # geom processing rgeom = torch.asarray(motion.rgeom) diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py index 84a93b31..3df6965b 100644 --- a/src/dartsort/templates/get_templates.py +++ b/src/dartsort/templates/get_templates.py @@ -27,7 +27,7 @@ from ..util.logging_util import get_logger, progbar from ..util.motion import MotionInfo from ..util.multiprocessing_util import get_pool -from ..util.noise_util import SpatialWhitener +from ..util.noise_util import Whitener from ..util.spiketorch import fast_nanmedian, nanmean, ptp from .templib import denoising_weights, fit_tsvd @@ -49,7 +49,7 @@ def _from_config( waveform_cfg: WaveformConfig = default_waveform_cfg, motion: MotionInfo, tsvd=None, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, computation_cfg: ComputationConfig | None = None, show_progress: bool = True, ) -> TemplateData: diff --git a/src/dartsort/templates/postprocess_util.py b/src/dartsort/templates/postprocess_util.py index e73d9841..19a06126 100644 --- a/src/dartsort/templates/postprocess_util.py +++ b/src/dartsort/templates/postprocess_util.py @@ -25,7 +25,7 @@ from ..util.job_util import ensure_computation_config from ..util.logging_util import get_logger from ..util.motion import MotionInfo -from ..util.noise_util import SpatialWhitener +from ..util.noise_util import Whitener from ..util.py_util import ensure_path from ..util.spiketorch import ptp from . import TemplateData, realign @@ -49,7 +49,7 @@ def estimate_template_library( realign_cfg: TemplateRealignmentConfig | None = None, template_merge_cfg: TemplateMergeConfig | None = None, tsvd: PCA | TruncatedSVD | None = None, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, computation_cfg: ComputationConfig | None = None, fit_featurization_tsvd: bool = False, featurization_cfg: FeaturizationConfig | None = None, diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 58e1e057..cb3b0a26 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -19,7 +19,7 @@ ) from ..util.logging_util import get_logger from ..util.motion import MotionInfo -from ..util.noise_util import SpatialWhitener +from ..util.noise_util import Whitener from ..util.py_util import databag from .template_util import weighted_average @@ -52,6 +52,7 @@ class TemplateData: tsvd: TruncatedSVD | PCA | None = None whitener: np.ndarray | None = None covariance: np.ndarray | None = None + temporal_kernel: np.ndarray | None = None whiten_strategy: WhiteningStrategy = "none" featurization_basis: np.ndarray | None = None @@ -163,6 +164,8 @@ def to_npz(self, npz_path): to_save["whitener"] = self.whitener if self.covariance is not None: to_save["covariance"] = self.covariance + if self.temporal_kernel is not None: + to_save["temporal_kernel"] = self.temporal_kernel if self.featurization_basis is not None: to_save["featurization_basis"] = self.featurization_basis if not npz_path.parent.exists(): @@ -251,7 +254,7 @@ def from_config( save_folder: Path | None = None, overwrite=False, motion: MotionInfo | None = None, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, save_npz_name: str | None = "template_data.npz", tsvd=None, featurization_basis=None, @@ -273,7 +276,7 @@ def from_config( motion = MotionInfo.from_motion_est(geom=recording.get_channel_locations()) if template_cfg.whitening.strategy != "none" and whitener is None: assert sorting is not None - whitener = SpatialWhitener.from_config( + whitener = Whitener.from_config( sorting=sorting, motion=motion, whiten_cfg=template_cfg.whitening, @@ -322,7 +325,7 @@ def _from_config( template_cfg: TemplateConfig, waveform_cfg: WaveformConfig = default_waveform_cfg, motion: MotionInfo, - whitener: SpatialWhitener | None = None, + whitener: Whitener | None = None, tsvd=None, computation_cfg: ComputationConfig | None = None, ) -> "TemplateData": diff --git a/src/dartsort/transform/whiten.py b/src/dartsort/transform/whiten.py index a590f1d7..c82eb43a 100644 --- a/src/dartsort/transform/whiten.py +++ b/src/dartsort/transform/whiten.py @@ -14,7 +14,7 @@ from .transform_base import BaseWaveformDenoiser if TYPE_CHECKING: - from ..util.noise_util import SpatialWhitener + from ..util.noise_util import Whitener from .pipeline import WaveformPipeline @@ -30,7 +30,7 @@ def __init__( name=None, name_prefix=None, waveform_cfg: WaveformConfig | None = default_waveform_cfg, - whitener: "SpatialWhitener | None" = None, + whitener: "Whitener | None" = None, disabled: bool = True, whiten_cfg: WhiteningConfig = WhiteningConfig(), sampling_frequency: float = 30_000.0, @@ -59,9 +59,11 @@ def attach_motion(self, motion): def _other_pre_load_state(self, state_dict, prefix): if self.whitener is not None: return - from ..util.noise_util import SpatialWhitener + from ..util.noise_util import Whitener - self.whitener = SpatialWhitener.blank(len(self.b.geom), self.b.geom.device) + self.whitener = Whitener.blank( + len(self.b.geom), self.b.geom.device, self.whiten_cfg.temporal_length + ) def fit( self, @@ -82,7 +84,7 @@ def fit( **spike_data, ) del recording, spike_data, waveforms, pipeline - from ..util.noise_util import SpatialWhitener + from ..util.noise_util import Whitener assert hdf5_filename is not None @@ -92,7 +94,7 @@ def fit( sorting = DARTsortSorting.from_peeling_hdf5( hdf5_filename, load_simple_features=False ) - self.whitener = SpatialWhitener.from_config( + self.whitener = Whitener.from_config( sorting=sorting, motion=self.motion, whiten_cfg=self.whiten_cfg, diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index a6ceb409..22b685ab 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -327,6 +327,7 @@ class WhiteningConfig: estimator: WhiteningEstimator = "localzca" interp_params: InterpolationParams = tps_interp_clampna_extrap_params radius: float = 200.0 + temporal_length: int | None = None TemplateSVDMethod = Literal[ @@ -1030,6 +1031,7 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: estimator=cfg.whiten_estimator, radius=cfg.subtraction_radius_um, interp_params=temp_interp_params, + temporal_length=cfg.whiten_temporal_length, ) # TODO: dredge_only is a bad name for this. if cfg.dredge_only and not cfg.whiten_in_subtraction: diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index fcf84c2f..1105fc69 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -1503,6 +1503,48 @@ def residual_covariance( return cov +def residual_welch_whitener( + sorting: DARTsortSorting, + device: torch.device | None = None, + residual_dataset_name="residual", + batch_size=1024, + temporal_length: int = 11, + spatial_whitener: torch.Tensor | None = None, +): + assert sorting.parent_h5_path is not None + snipgen = sorting._yield_dataset(residual_dataset_name, batch_size=batch_size) + + if spatial_whitener is not None: + W = torch.asarray(spatial_whitener, device=device) + else: + W = None + + # Welch's method to estimate residual PSD + snip_psds = [] + block_len = temporal_length + for snip in snipgen: + block_len = next_fast_len(snip.shape[1]) + snip = torch.asarray(snip).to(device=device, non_blocking=True) + if W is not None: + snip = torch.einsum("ntc,cd->ntd", snip, W) + snip = snip.permute(0, 2, 1).view(-1, snip.shape[1]).contiguous() + periodogram = torch.fft.rfft(snip, n=block_len, norm="ortho") + dens = (periodogram * periodogram.conj()).mean(dim=0) + snip_psds.append(dens) + snip_psds = torch.stack(snip_psds, dim=0) + spectral_density = snip_psds.mean(0).sqrt_() + + # estimate 0-phase FIR whitener + wkernel = torch.fft.fftshift(torch.fft.irfft(1.0 / spectral_density, n=block_len)) + + # trim to requested length + assert temporal_length <= wkernel.shape[0] + i0 = wkernel.shape[0] // 2 - temporal_length // 2 + i1 = i0 + temporal_length + wkernel = wkernel[i0:i1].copy() + return wkernel + + def fullzca_whitener( cov: np.ndarray, channel_index: np.ndarray | None = None, eps=1e-6 ) -> np.ndarray: @@ -1553,26 +1595,48 @@ def sparsechol_whitener( } -class SpatialWhitener(BModule): - def __init__(self, whitener: Tensor, covariance: Tensor): +class Whitener(BModule): + def __init__( + self, whitener: Tensor, covariance: Tensor, temporal_kernel: Tensor | None + ): super().__init__() self.register_buffer("whitener", whitener) self.register_buffer("covariance", covariance) + self.register_buffer_or_none("temporal_kernel", temporal_kernel) @classmethod - def blank(cls, n_channels: int, device: torch.device): + def blank(cls, n_channels: int, device: torch.device, temporal_length: int | None): w = torch.zeros((n_channels, n_channels), device=device) - return cls(w, torch.zeros_like(w)) + if temporal_length: + k = torch.zeros((temporal_length,), device=device) + else: + k = None + return cls(w, torch.zeros_like(w), k) @classmethod - def from_numpy(cls, whitener: np.ndarray, covariance: np.ndarray): + def from_numpy( + cls, + whitener: np.ndarray, + covariance: np.ndarray, + temporal_kernel: np.ndarray | None, + ): logger.dartsortverbose("Load whitener from numpy.") + tk = None if temporal_kernel is None else torch.asarray(temporal_kernel) return cls( - whitener=torch.asarray(whitener), covariance=torch.asarray(covariance) + whitener=torch.asarray(whitener), + covariance=torch.asarray(covariance), + temporal_kernel=tk, ) - def to_numpy(self) -> tuple[np.ndarray, np.ndarray]: - return self.b.whitener.numpy(force=True), self.b.covariance.numpy(force=True) + def to_numpy(self) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]: + tk = self.b.temporal_kernel + if tk is not None: + tk = tk.numpy(force=True) + return ( + self.b.whitener.numpy(force=True), + self.b.covariance.numpy(force=True), + tk, + ) @classmethod def from_config( @@ -1584,7 +1648,12 @@ def from_config( computation_cfg: ComputationConfig | None = None, ) -> Self: logger.dartsortdebug( - "Estimating %s-%s whitener.", whiten_cfg.strategy, whiten_cfg.estimator + "Estimating %s-%s whitener%s.", + whiten_cfg.strategy, + whiten_cfg.estimator, + "" + if not whiten_cfg.temporal_length + else f"temporal length: {whiten_cfg.temporal_length}", ) device = ensure_computation_config(computation_cfg).actual_device() cov = residual_covariance( @@ -1602,7 +1671,19 @@ def from_config( cov_np, channel_index=neighbs ) whitener = torch.asarray(whitener).to(cov) - return cls(whitener=whitener, covariance=cov) + + if whiten_cfg.temporal_length: + assert whiten_cfg.strategy != "postwhiten" + temporal_kernel = residual_welch_whitener( + sorting=sorting, + device=device, + temporal_length=whiten_cfg.temporal_length, + spatial_whitener=whitener, + ) + else: + temporal_kernel = None + + return cls(whitener=whitener, covariance=cov, temporal_kernel=temporal_kernel) def whiten_traces_spatial_major( self, x: Tensor, out: Tensor | None = None From 179968c3d1fda5e49857c385183740dfb383413a Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 12 Jun 2026 15:53:35 -0400 Subject: [PATCH 27/46] whitening: impl needed data transforms --- src/dartsort/peel/peel_lib.py | 24 ++++++++++++++++++---- src/dartsort/peel/subtract.py | 8 ++++++++ src/dartsort/util/noise_util.py | 36 ++++++++++++++++++++++++++++++++- 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/src/dartsort/peel/peel_lib.py b/src/dartsort/peel/peel_lib.py index 87db97d5..6ee35a7d 100644 --- a/src/dartsort/peel/peel_lib.py +++ b/src/dartsort/peel/peel_lib.py @@ -82,6 +82,7 @@ def check_residual_decrease( save_residnorm_decrease=False, overwrite_orig_waveforms: bool = False, local_whiteners: Tensor | None = None, + whitening_kernel: Tensor | None = None, channels: Tensor | None = None, ) -> tuple[Tensor | None, dict[str, Tensor]]: if not threshold: @@ -90,15 +91,28 @@ def check_residual_decrease( if local_whiteners is not None: assert channels is not None - W = local_whiteners[channels].mT + W = local_whiteners[channels] + + # remove nans if overwrite_orig_waveforms: orig_wfs = orig_wfs.nan_to_num_() else: orig_wfs = orig_wfs.nan_to_num() - buf = orig_wfs - orig_wfs = orig_wfs.bmm(W) dn_wfs = dn_wfs.nan_to_num() - dn_wfs = torch.bmm(dn_wfs, W, out=buf) + + # spatial mul -- putting temporal dim last here + buf = orig_wfs + orig_wfs = W.bmm(orig_wfs.mT) + dn_wfs = torch.bmm(W, dn_wfs.mT, out=buf) + + # temporal conv if needed + if whitening_kernel is not None: + *shp, t = orig_wfs.shape + k = whitening_kernel[None, None] + orig_wfs = F.conv1d(orig_wfs.view(-1, 1, t), k, padding="same") + dn_wfs = F.conv1d(dn_wfs.view(-1, 1, t), k, padding="same") + orig_wfs = orig_wfs.view(*shp, t) + dn_wfs = dn_wfs.view(*shp, t) if decrease_objective == "deconv": if overwrite_orig_waveforms: @@ -171,6 +185,7 @@ def subtract_chunk( residnorm_decrease_threshold=16.0, decrease_objective: Literal["norm", "normsq", "deconv"] = "deconv", local_whiteners: Tensor | None = None, + whitening_kernel: Tensor | None = None, relative_peak_radius=5, dedup_temporal_radius=7, remove_exact_duplicates=True, @@ -371,6 +386,7 @@ def subtract_chunk( threshold=residnorm_decrease_threshold, save_residnorm_decrease=save_residnorm_decrease, local_whiteners=local_whiteners, + whitening_kernel=whitening_kernel, channels=channels, ) features.update(new_feats) diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py index 421a7c70..30cd18d1 100644 --- a/src/dartsort/peel/subtract.py +++ b/src/dartsort/peel/subtract.py @@ -131,6 +131,7 @@ def __init__( # this may be overwritten after featurization fit self.register_buffer_or_none("local_whiteners", None) + self.register_buffer_or_none("whitening_kernel", None) def out_datasets(self): datasets = super().out_datasets() @@ -171,6 +172,12 @@ def post_fit(self): local_whiteners = whitener.local_whiteners(self.sub_channel_index) # type: ignore self.del_none_buffer("local_whiteners") self.register_buffer("local_whiteners", local_whiteners) + if whitener.temporal: + self.del_none_buffer("whitening_kernel") + self.register_buffer( + "whitening_kernel", + self.whitener.b.temporal_kernel.clone(), # type: ignore + ) self.threshold = self.p.residnorm_decrease_threshold def save_models(self, save_folder): @@ -296,6 +303,7 @@ def peel_chunk( denoiser_realignment_channel=self.p.denoiser_realignment_channel, compute_collidedness=self.save_collidedness, local_whiteners=self.b.local_whiteners, + whitening_kernel=self.b.whitening_kernel, ) # add in chunk_start_samples diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index 1105fc69..3d4293af 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -1602,7 +1602,13 @@ def __init__( super().__init__() self.register_buffer("whitener", whitener) self.register_buffer("covariance", covariance) + self.temporal = temporal_kernel is not None self.register_buffer_or_none("temporal_kernel", temporal_kernel) + if temporal_kernel is not None: + tk_twice = self._convolve(temporal_kernel) + else: + tk_twice = None + self.register_buffer_or_none("temporal_kernel_twice", tk_twice) @classmethod def blank(cls, n_channels: int, device: torch.device, temporal_length: int | None): @@ -1685,16 +1691,40 @@ def from_config( return cls(whitener=whitener, covariance=cov, temporal_kernel=temporal_kernel) + def _convolve(self, x: Tensor, twice=False): + if not self.temporal: + return x + *shp, t = x.shape + x = x.reshape(-1, 1, t) + if twice: + k = self.b.temporal_kernel_twice + else: + k = self.b.temporal_kernel + res = F.conv1d( + input=x, + weight=k[None, None], + padding="same", + groups=x.shape[1], + ) + assert res.shape[-1] == t + res = res.reshape(*shp, t) + return res + def whiten_traces_spatial_major( self, x: Tensor, out: Tensor | None = None ) -> Tensor: - return torch.mm(self.b.whitener, x.T, out=out) + assert x.ndim == 2 + out = torch.mm(self.b.whitener, x.T, out=out) + out = self._convolve(out) + return out def whiten(self, x: Tensor, out: Tensor | None = None) -> Tensor: *shp, c = x.shape x = x.reshape(-1, c) x = torch.mm(x, self.b.whitener.T, out=out) x = x.reshape(*shp, c) + if self.temporal: + x = self._convolve(x.mT).mT return x def transpose_whiten(self, x: Tensor, out: Tensor | None = None) -> Tensor: @@ -1702,6 +1732,8 @@ def transpose_whiten(self, x: Tensor, out: Tensor | None = None) -> Tensor: x = x.reshape(-1, c) x = torch.mm(x, self.b.whitener, out=out) x = x.reshape(*shp, c) + if self.temporal: + x = self._convolve(x.mT).mT return x def prec_mul(self, x: Tensor) -> Tensor: @@ -1709,6 +1741,8 @@ def prec_mul(self, x: Tensor) -> Tensor: x = x.reshape(-1, c) x = x @ (self.b.whitener.T @ self.b.whitener) x = x.reshape(*shp, c) + if self.temporal: + x = self._convolve(x.mT, twice=True).mT return x def local_whiteners(self, channel_index: Tensor, eps=1e-6): From 54a529f3987bee23b2b7851f5c3db04beb7e2ceb Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sun, 14 Jun 2026 16:33:44 -0400 Subject: [PATCH 28/46] subtract: save models earlier so debugging can resume --- src/dartsort/peel/matching_util/drifty.py | 2 ++ src/dartsort/peel/subtract.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/dartsort/peel/matching_util/drifty.py b/src/dartsort/peel/matching_util/drifty.py index 81aec58f..8363eec1 100644 --- a/src/dartsort/peel/matching_util/drifty.py +++ b/src/dartsort/peel/matching_util/drifty.py @@ -331,6 +331,8 @@ def data_at_time( ) return DriftyChunkTemplateData( spike_length_samples=self.spike_length_samples, + filter_length_samples=self.b.conv_temporal_comps.shape[1], + resid_offset=resid_offset, unit_ids=self.b.unit_ids, main_channels=main_channels, obj_normsq=normsq, diff --git a/src/dartsort/peel/subtract.py b/src/dartsort/peel/subtract.py index 30cd18d1..26759406 100644 --- a/src/dartsort/peel/subtract.py +++ b/src/dartsort/peel/subtract.py @@ -174,23 +174,20 @@ def post_fit(self): self.register_buffer("local_whiteners", local_whiteners) if whitener.temporal: self.del_none_buffer("whitening_kernel") - self.register_buffer( - "whitening_kernel", - self.whitener.b.temporal_kernel.clone(), # type: ignore - ) + self.register_buffer("whitening_kernel", whitener.b.temporal_kernel.clone()) self.threshold = self.p.residnorm_decrease_threshold - def save_models(self, save_folder): - super().save_models(save_folder) + def save_models(self, save_folder: str | Path): sub_denoise_pt = Path(save_folder) / "subtraction_denoising_pipeline.pt" torch.save(self.subtraction_denoising_pipeline.state_dict(), sub_denoise_pt) + super().save_models(save_folder) - def load_models(self, save_folder): - super().load_models(save_folder) + def load_models(self, save_folder: str | Path): sub_denoise_pt = Path(save_folder) / "subtraction_denoising_pipeline.pt" if sub_denoise_pt.exists(): state_dict = torch.load(sub_denoise_pt, weights_only=True) self.subtraction_denoising_pipeline.load_state_dict(state_dict) + super().load_models(save_folder) @classmethod def from_config( @@ -355,6 +352,7 @@ def fit_peeler_models(self, save_folder, tmp_dir=None, computation_cfg=None): gc.collect() torch.cuda.empty_cache() + self.save_models(save_folder=save_folder) def _fit_subtraction_transformers( self, save_folder, tmp_dir=None, computation_cfg=None, which="denoisers" From 0185bf2be9299c086d8c2a08fdff044ac51ed4da Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sun, 14 Jun 2026 16:33:59 -0400 Subject: [PATCH 29/46] matching: implement temporal whitening --- src/dartsort/peel/matching.py | 22 ++++- .../matching_util/compressed_upsampled.py | 13 ++- src/dartsort/peel/matching_util/drifty.py | 88 ++++++++++++------- .../peel/matching_util/matching_base.py | 9 +- src/dartsort/peel/peel_lib.py | 3 +- src/dartsort/util/noise_util.py | 52 +++++++---- .../util/testing_util/matching_debug_util.py | 25 ++++-- tests/test_matching.py | 1 + 8 files changed, 151 insertions(+), 62 deletions(-) diff --git a/src/dartsort/peel/matching.py b/src/dartsort/peel/matching.py index eb6eb21a..5cf95819 100644 --- a/src/dartsort/peel/matching.py +++ b/src/dartsort/peel/matching.py @@ -51,7 +51,8 @@ def __init__( fpctrl_spike_counts=None, fit_sampling_cfg: FitSamplingConfig = default_peeling_fit_sampling_cfg, save_collidedness=False, - whiten_features=True, + whiten_features=False, + whiten_kernel_length=0, parent_sorting_hdf5_path: str | Path | None = None, dtype=torch.float, ): @@ -109,7 +110,9 @@ def __init__( ) self.amp_scale_max = 1.0 + p.amplitude_scaling_boundary self.amp_scale_min = 1.0 / self.amp_scale_max - self.obj_pad_len = max(p.refractory_radius_frames, self.spike_length_samples) + self.whiten_pad = max(0, whiten_kernel_length - 1) + pt = self.spike_length_samples + self.whiten_pad + self.obj_pad_len = max(p.refractory_radius_frames, pt) conv_len = ( self.chunk_length_samples + 2 * self.chunk_margin_samples @@ -192,6 +195,10 @@ def from_config( matching_cfg.precomputed_templates_npz ) assert trough_offset_samples == template_data.trough_offset_samples + if template_data.temporal_kernel is None: + whiten_kernel_length = 0 + else: + whiten_kernel_length = template_data.temporal_kernel.shape[0] if motion is None: motion = MotionInfo.from_motion_est(geom=geom.numpy()) @@ -237,6 +244,7 @@ def from_config( parent_sorting_hdf5_path=parent_sorting_hdf5_path, save_collidedness=save_collidedness, whiten_features=matching_cfg.whiten_features, + whiten_kernel_length=whiten_kernel_length, fpctrl_spike_counts=template_data.spike_counts if matching_cfg.threshold == "fp_control" else None, @@ -258,12 +266,17 @@ def peel_chunk( chunk_center_samples = chunk_start_samples + self.chunk_length_samples // 2 segment = self.recording._recording_segments[0] chunk_center_seconds = float(segment.sample_index_to_time(chunk_center_samples)) + if self.whiten_features: + resid_offset = self.whiten_pad + else: + resid_offset = 0 chunk_template_data = self.matching_templates.data_at_time( t_s=chunk_center_seconds, scaling=self.is_scaling, inv_lambda=self.inv_lambda, scale_min=self.amp_scale_min, scale_max=self.amp_scale_max, + resid_offset=resid_offset, ) # deconvolve @@ -314,8 +327,8 @@ def match_chunk( # name objective variables so that we can update them in-place later # padded objective has an extra unit (for group_index) and refractory # padding (for easier implementation of enforce_refractory) - valid_len = traces.shape[0] - self.spike_length_samples + 1 - padded_obj_len = valid_len + 2 * self.obj_pad_len + valid_len = traces.shape[0] - self.spike_length_samples - self.whiten_pad + 1 + padded_obj_len = valid_len + 2 * self.obj_pad_len + self.whiten_pad padded_conv = traces.new_zeros( chunk_template_data.obj_n_templates, padded_obj_len ) @@ -419,6 +432,7 @@ def match_chunk( if not chunk_template_data.needs_residual: chunk_template_data.subtract(residual_padded, peaks) + assert residual.shape[0] == traces.shape[0] if not peaks.n_spikes: res = PeelingBatchResult(n_spikes=0) if return_residual: diff --git a/src/dartsort/peel/matching_util/compressed_upsampled.py b/src/dartsort/peel/matching_util/compressed_upsampled.py index 2d4ad3ad..d01738c7 100644 --- a/src/dartsort/peel/matching_util/compressed_upsampled.py +++ b/src/dartsort/peel/matching_util/compressed_upsampled.py @@ -260,7 +260,9 @@ def data_at_time( inv_lambda: float, scale_min: float, scale_max: float, + resid_offset: int, ) -> "CompressedUpsampledChunkTemplateData": + assert not resid_offset if self.drifting: shifts, padded_spatial_sing = templates_at_time( t_s=t_s, @@ -300,6 +302,7 @@ def data_at_time( return CompressedUpsampledChunkTemplateData( coarse_objective=self.coarse_objective, + resid_offset=resid_offset, grouping=self.have_groups, upsampling=self.upsampling, scaling=scaling, @@ -308,6 +311,7 @@ def data_at_time( n_templates=self.n_templates, obj_n_templates=self.obj_n_templates, spike_length_samples=self.spike_length_samples, + filter_length_samples=self.spike_length_samples, up_factor=self.b.cup_index.shape[1], inv_lambda=torch.tensor(inv_lambda, device=normsq.device), scale_min=torch.tensor(scale_min, device=normsq.device), @@ -354,6 +358,8 @@ class CompressedUpsampledChunkTemplateData(ChunkTemplateData): inv_lambda: Tensor scale_min: Tensor scale_max: Tensor + resid_offset: int + filter_length_samples: int # objective props obj_normsq: Tensor @@ -458,7 +464,12 @@ def subtract(self, traces, peaks, sign=-1): ) def fine_match( - self, *, peaks: MatchingPeaks, residual: Tensor | None, conv: Tensor, padding: int = 0 + self, + *, + peaks: MatchingPeaks, + residual: Tensor | None, + conv: Tensor, + padding: int = 0, ): """Determine superres ids, temporal upsampling, and scaling diff --git a/src/dartsort/peel/matching_util/drifty.py b/src/dartsort/peel/matching_util/drifty.py index 8363eec1..c7195d86 100644 --- a/src/dartsort/peel/matching_util/drifty.py +++ b/src/dartsort/peel/matching_util/drifty.py @@ -138,20 +138,31 @@ def __init__( up_temporal_comps = upsample_singlechan_torch( temporal_comps, temporal_jitter=up_factor ) - tconv = shared_temporal_pconv(temporal_comps, up_temporal_comps) + if whitener is not None: + conv_temporal_comps = whitener._convolve(temporal_comps, padding="full") + conv_up_temporal_comps = whitener._convolve( + up_temporal_comps, padding="full" + ) + norm_discount = torch.linalg.vector_norm( + conv_temporal_comps, dim=1, keepdim=True + ) + else: + conv_temporal_comps = temporal_comps + conv_up_temporal_comps = up_temporal_comps + norm_discount = None + tconv = shared_temporal_pconv(conv_temporal_comps, conv_up_temporal_comps) assert temporal_comps.shape == (rank, self.spike_length_samples) + self.register_buffer("conv_temporal_comps", conv_temporal_comps.contiguous()) + self.register_buffer_or_none("norm_discount", norm_discount) self.register_buffer("temporal_comps", temporal_comps.contiguous()) assert up_temporal_comps.shape == (rank, up_factor, self.spike_length_samples) self.register_buffer_or_none( "up_temporal_comps", up_temporal_comps.contiguous() ) - if up_temporal_comps is not None: - up_major_temporal_comps = up_temporal_comps.permute(1, 0, 2).contiguous() - else: - up_major_temporal_comps = None + up_major_temporal_comps = up_temporal_comps.permute(1, 0, 2).contiguous() self.register_buffer_or_none("up_major_temporal_comps", up_major_temporal_comps) - self.register_buffer("spatial_sing", spatial_sing) + self.register_buffer("spatial_sing", spatial_sing.contiguous()) self.register_buffer("tconv", tconv) if unit_ids is None: unit_ids = torch.arange(self.n_units, device=device) @@ -160,11 +171,13 @@ def __init__( if self.whiten_strategy == "postwhiten": assert whitener is not None self.whitener = whitener.to(spatial_sing.device) - conv_spatial_sing = self.whitener.transpose_whiten(spatial_sing) + conv_spatial_sing = self.whitener.transpose_whiten( + spatial_sing, spatial_only=True + ) elif self.whiten_strategy == "prewhiten_postapply" and not self.interpolating: assert whitener is not None self.whitener = whitener.to(spatial_sing.device) - conv_spatial_sing = self.whitener.whiten(spatial_sing) + conv_spatial_sing = self.whitener.whiten(spatial_sing, spatial_only=True) elif self.whiten_strategy == "prewhiten_postapply": assert whitener is not None self.whitener = whitener.to(spatial_sing.device) @@ -177,6 +190,8 @@ def __init__( self.whitener = conv_spatial_sing = None else: assert False + if conv_spatial_sing is not None: + conv_spatial_sing = conv_spatial_sing.contiguous() self.register_buffer_or_none("conv_spatial_sing", conv_spatial_sing) # full pconv can be precomputed when not interpolating @@ -194,11 +209,16 @@ def __init__( # indexing helpers t = self.spike_length_samples rr = refractory_radius_frames + if self.whitener is not None: + ct = t + max(0, self.whitener.temporal_length - 1) + else: + ct = t self.register_buffer("refrac_ix", torch.arange(-rr, rr + 1, device=device)) - self.register_buffer("time_ix", torch.arange(t, device=device)) + self.register_buffer("time_ix", torch.arange(ct, device=device)) + self.register_buffer("sub_time_ix", torch.arange(t, device=device)) self.register_buffer("chan_ix", torch.arange(n_channels, device=device)) self.register_buffer("rank_ix", torch.arange(rank, device=device)) - self.register_buffer("conv_lags", torch.arange(-t + 1, t, device=device)) + self.register_buffer("conv_lags", torch.arange(-ct + 1, ct, device=device)) offset = torch.asarray(trough_offset_samples, device=device) self.register_buffer("trough_offset_samples", offset) @@ -291,7 +311,9 @@ def spatial_at_time(self, t_s: float) -> tuple[Tensor, ...]: # for convolution, whitening is applied if self.interpolating: - conv_spatial_sing = self.whitener.whiten(spatial_sing) + conv_spatial_sing = self.whitener.whiten( + spatial_sing, spatial_only=True + ) else: conv_spatial_sing = self.b.conv_spatial_sing @@ -303,7 +325,11 @@ def spatial_at_time(self, t_s: float) -> tuple[Tensor, ...]: assert False # normsq for channel selection from original - normsq_by_chan = normsq_spatial_sing.square().sum(dim=1) + norm_discount = self.b.norm_discount + if norm_discount is None: + normsq_by_chan = normsq_spatial_sing.square().sum(dim=1) + else: + normsq_by_chan = (norm_discount * normsq_spatial_sing).square_().sum(dim=1) main_channels = normsq_by_chan.argmax(dim=1) normsq = normsq_by_chan.sum(dim=1) @@ -325,6 +351,7 @@ def data_at_time( inv_lambda: float, scale_min: float, scale_max: float, + resid_offset: int, ) -> ChunkTemplateData: spatial_sing, normsq, main_channels, padded_spatial_sing, pconv = ( self.spatial_at_time(t_s=t_s) @@ -346,11 +373,12 @@ def data_at_time( inv_lambda=torch.asarray(inv_lambda).to(normsq, non_blocking=True), scale_min=torch.asarray(scale_min).to(normsq, non_blocking=True), scale_max=torch.asarray(scale_max).to(normsq, non_blocking=True), - temporal_comps=self.b.temporal_comps, + conv_temporal_comps=self.b.conv_temporal_comps, up_major_temporal_comps=self.b.up_major_temporal_comps, spatial_sing=spatial_sing, pconv=pconv, time_ix=self.b.time_ix, + sub_time_ix=self.b.sub_time_ix + resid_offset, chan_ix=self.b.chan_ix, rank_ix=self.b.rank_ix, conv_lags=self.b.conv_lags, @@ -363,6 +391,8 @@ def data_at_time( @databag class DriftyChunkTemplateData(ChunkTemplateData): spike_length_samples: int + filter_length_samples: int + resid_offset: int # for the full templates unit_ids: Tensor main_channels: Tensor @@ -378,7 +408,7 @@ class DriftyChunkTemplateData(ChunkTemplateData): scale_min: Tensor scale_max: Tensor - temporal_comps: Tensor + conv_temporal_comps: Tensor up_major_temporal_comps: Tensor spatial_sing: Tensor padded_spatial_sing: Tensor @@ -386,6 +416,7 @@ class DriftyChunkTemplateData(ChunkTemplateData): spatial_whitener: Whitener | None time_ix: Tensor + sub_time_ix: Tensor chan_ix: Tensor rank_ix: Tensor refrac_ix: Tensor @@ -401,13 +432,13 @@ def convolve(self, traces: Tensor, padding: int = 0, out: Tensor | None = None): Odd to have batch come second, but it makes things simpler here (and it's only something used for debugging: usually traces.ndim == 2). """ - out_len = traces.shape[-1] + 2 * padding - self.spike_length_samples + 1 + out_len = traces.shape[-1] + 2 * padding - self.filter_length_samples + 1 if out is not None: assert out.shape == (self.obj_n_templates, out_len) conv = convolve_lowrank_shared( traces=traces, spatial_singular=self.spatial_sing, - temporal_components=self.temporal_comps, + temporal_components=self.conv_temporal_comps, padding=padding, out=out, ) @@ -419,8 +450,8 @@ def score(self, spikes: Tensor) -> Tensor: n, t, c = spikes.shape spikes = self.whiten_traces(spikes) spikes_t = spikes.mT.reshape(n * c, t) - rank = self.temporal_comps.shape[0] - tconv = spikes_t @ self.temporal_comps.T + rank = self.conv_temporal_comps.shape[0] + tconv = spikes_t @ self.conv_temporal_comps.T spatial_t = self.spatial_sing.permute(2, 1, 0) # now chan,rank,unit conv = tconv.view(n, c * rank) @ spatial_t.reshape( c * rank, self.obj_n_templates @@ -438,27 +469,16 @@ def subtract(self, traces: Tensor, peaks: "MatchingPeaks", sign: int = -1): return assert peaks.times is not None assert peaks.template_inds is not None - - if peaks.up_inds is None: - tempc = self.temporal_comps[None] - else: - tempc = self.up_major_temporal_comps - assert sign in (-1, 1) - tempc = ( - self.temporal_comps - if peaks.up_inds is None - else self.up_major_temporal_comps - ) _subtract_templates_loop( traces=traces, up_inds=peaks.up_inds, scalings=peaks.scalings, template_inds=peaks.template_inds, - tempc=tempc, + tempc=self.up_major_temporal_comps, spatc=self.padded_spatial_sing, times=peaks.times, - time_ix=self.time_ix, + time_ix=self.sub_time_ix, neg=sign == -1, ) @@ -496,7 +516,7 @@ def get_clean_waveforms( if peaks.up_inds is not None: tempc = self.up_major_temporal_comps[peaks.up_inds] else: - tempc = self.temporal_comps + tempc = self.up_major_temporal_comps[0] tempc = tempc[None].broadcast_to(peaks.n_spikes, *tempc.shape) if add_into is None: return tempc.mT.bmm(spatial) @@ -568,7 +588,9 @@ def reconstruct_up_templates(self): def whiten_traces(self, traces: Tensor, out: Tensor | None = None): if self.prewhiten: assert self.spatial_whitener is not None - return self.spatial_whitener.whiten_traces_spatial_major(traces, out=out) + return self.spatial_whitener.whiten_traces_spatial_major( + traces, out=out, padding="full" + ) elif out is not None: return out.copy_(traces.T) else: diff --git a/src/dartsort/peel/matching_util/matching_base.py b/src/dartsort/peel/matching_util/matching_base.py index 9f5e6389..f487dc1d 100644 --- a/src/dartsort/peel/matching_util/matching_base.py +++ b/src/dartsort/peel/matching_util/matching_base.py @@ -83,6 +83,7 @@ def data_at_time( inv_lambda: float, scale_min: float, scale_max: float, + resid_offset: int, ) -> "ChunkTemplateData": raise NotImplementedError @@ -129,6 +130,9 @@ def spike_length_samples(self) -> int: class ChunkTemplateData: # -- subclasses must assign the following properties that the matcher uses. spike_length_samples: int + filter_length_samples: int + resid_offset: int + # for the full templates unit_ids: Tensor main_channels: Tensor @@ -260,7 +264,7 @@ def coarse_match( assert nt > 2 * padding times = argrelmax_dedup( x=objective_max, - dedup_radius=self.spike_length_samples, + dedup_radius=self.filter_length_samples, threshold=thresholdsq, arange=obj_arange[:nt], padding=padding, @@ -329,12 +333,13 @@ def get_collisioncleaned_waveforms( assert times is not None # get noise + # TODO check the offset is correct waveforms = grab_spikes( residual_padded, times, channels, sel_ci, - trough_offset=0, + trough_offset=self.resid_offset, spike_length_samples=self.spike_length_samples, buffer=0, already_padded=True, diff --git a/src/dartsort/peel/peel_lib.py b/src/dartsort/peel/peel_lib.py index 6ee35a7d..39d8af54 100644 --- a/src/dartsort/peel/peel_lib.py +++ b/src/dartsort/peel/peel_lib.py @@ -101,9 +101,8 @@ def check_residual_decrease( dn_wfs = dn_wfs.nan_to_num() # spatial mul -- putting temporal dim last here - buf = orig_wfs orig_wfs = W.bmm(orig_wfs.mT) - dn_wfs = torch.bmm(W, dn_wfs.mT, out=buf) + dn_wfs = W.bmm(dn_wfs.mT) # temporal conv if needed if whitening_kernel is not None: diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py index 3d4293af..267d13c3 100644 --- a/src/dartsort/util/noise_util.py +++ b/src/dartsort/util/noise_util.py @@ -1511,6 +1511,7 @@ def residual_welch_whitener( temporal_length: int = 11, spatial_whitener: torch.Tensor | None = None, ): + """Estimate a 0-phase whitening convolution kernel with Welch's method""" assert sorting.parent_h5_path is not None snipgen = sorting._yield_dataset(residual_dataset_name, batch_size=batch_size) @@ -1527,7 +1528,7 @@ def residual_welch_whitener( snip = torch.asarray(snip).to(device=device, non_blocking=True) if W is not None: snip = torch.einsum("ntc,cd->ntd", snip, W) - snip = snip.permute(0, 2, 1).view(-1, snip.shape[1]).contiguous() + snip = snip.permute(0, 2, 1).reshape(-1, snip.shape[1]) periodogram = torch.fft.rfft(snip, n=block_len, norm="ortho") dens = (periodogram * periodogram.conj()).mean(dim=0) snip_psds.append(dens) @@ -1535,13 +1536,14 @@ def residual_welch_whitener( spectral_density = snip_psds.mean(0).sqrt_() # estimate 0-phase FIR whitener - wkernel = torch.fft.fftshift(torch.fft.irfft(1.0 / spectral_density, n=block_len)) + wkernel = torch.fft.irfft(1.0 / spectral_density, n=block_len) + wkernel = torch.fft.fftshift(wkernel) # trim to requested length assert temporal_length <= wkernel.shape[0] i0 = wkernel.shape[0] // 2 - temporal_length // 2 i1 = i0 + temporal_length - wkernel = wkernel[i0:i1].copy() + wkernel = wkernel[i0:i1].clone() return wkernel @@ -1605,8 +1607,10 @@ def __init__( self.temporal = temporal_kernel is not None self.register_buffer_or_none("temporal_kernel", temporal_kernel) if temporal_kernel is not None: + self.temporal_length = temporal_kernel.shape[0] tk_twice = self._convolve(temporal_kernel) else: + self.temporal_length = 0 tk_twice = None self.register_buffer_or_none("temporal_kernel_twice", tk_twice) @@ -1626,8 +1630,13 @@ def from_numpy( covariance: np.ndarray, temporal_kernel: np.ndarray | None, ): - logger.dartsortverbose("Load whitener from numpy.") - tk = None if temporal_kernel is None else torch.asarray(temporal_kernel) + if temporal_kernel is None: + tk = None + tmsg = "" + else: + tmsg = f" with temporal length {temporal_kernel.shape[0]}" + tk = torch.asarray(temporal_kernel) + logger.dartsortverbose("Load whitener%s from numpy.", tmsg) return cls( whitener=torch.asarray(whitener), covariance=torch.asarray(covariance), @@ -1686,12 +1695,13 @@ def from_config( temporal_length=whiten_cfg.temporal_length, spatial_whitener=whitener, ) + assert temporal_kernel.shape == (whiten_cfg.temporal_length,) else: temporal_kernel = None return cls(whitener=whitener, covariance=cov, temporal_kernel=temporal_kernel) - def _convolve(self, x: Tensor, twice=False): + def _convolve(self, x: Tensor, twice=False, padding="same"): if not self.temporal: return x *shp, t = x.shape @@ -1700,39 +1710,51 @@ def _convolve(self, x: Tensor, twice=False): k = self.b.temporal_kernel_twice else: k = self.b.temporal_kernel + if padding == "full": + padding = k.shape[0] - 1 + k = k.to(device=x.device) res = F.conv1d( input=x, weight=k[None, None], - padding="same", + padding=padding, groups=x.shape[1], ) - assert res.shape[-1] == t - res = res.reshape(*shp, t) + if padding == "same": + assert res.shape[-1] == t + ot = t + else: + assert isinstance(padding, int) + ot = t + padding + res = res.reshape(*shp, ot) return res def whiten_traces_spatial_major( - self, x: Tensor, out: Tensor | None = None + self, x: Tensor, out: Tensor | None = None, padding="same" ) -> Tensor: assert x.ndim == 2 out = torch.mm(self.b.whitener, x.T, out=out) - out = self._convolve(out) + out = self._convolve(out, padding=padding) return out - def whiten(self, x: Tensor, out: Tensor | None = None) -> Tensor: + def whiten( + self, x: Tensor, out: Tensor | None = None, spatial_only: bool = False + ) -> Tensor: *shp, c = x.shape x = x.reshape(-1, c) x = torch.mm(x, self.b.whitener.T, out=out) x = x.reshape(*shp, c) - if self.temporal: + if self.temporal and not spatial_only: x = self._convolve(x.mT).mT return x - def transpose_whiten(self, x: Tensor, out: Tensor | None = None) -> Tensor: + def transpose_whiten( + self, x: Tensor, out: Tensor | None = None, spatial_only: bool = False + ) -> Tensor: *shp, c = x.shape x = x.reshape(-1, c) x = torch.mm(x, self.b.whitener, out=out) x = x.reshape(*shp, c) - if self.temporal: + if self.temporal and not spatial_only: x = self._convolve(x.mT).mT return x diff --git a/src/dartsort/util/testing_util/matching_debug_util.py b/src/dartsort/util/testing_util/matching_debug_util.py index 4cf29183..7a9432d8 100644 --- a/src/dartsort/util/testing_util/matching_debug_util.py +++ b/src/dartsort/util/testing_util/matching_debug_util.py @@ -86,20 +86,22 @@ def yield_step_results( device = matcher.b.channel_index.device chunk = torch.asarray(chunk, device=device) assert matcher.matching_templates is not None + assert not matcher.whiten_features chunk_data = matcher.matching_templates.data_at_time( t_s, scaling=matcher.is_scaling, inv_lambda=matcher.inv_lambda, scale_min=matcher.amp_scale_min, scale_max=matcher.amp_scale_max, + resid_offset=0,#matcher.whiten_pad, ) - cur_residual = chunk.clone() for it in ( progrange(max_iter, desc="Match steps") if show_progress else range(max_iter) ): - pre_conv = chunk_data.convolve(cur_residual.T, padding=matcher.obj_pad_len) + cur_traces_wh = chunk_data.whiten_traces(cur_residual) + pre_conv = chunk_data.convolve(cur_traces_wh, padding=matcher.obj_pad_len) if obj_mode: pre_conv = chunk_data.obj_from_conv( conv=pre_conv, @@ -162,6 +164,7 @@ def visualize_step_results( chunk_vis_style: Literal["im", "trace"] = "im", gt_sorting: DARTsortSorting | None = None, vis_only_last_step: bool = False, + vline_new_peaks=True, vline_at=None, objline_at=None, ): @@ -256,6 +259,11 @@ def visualize_step_results( ) ax.set_ylabel(name) + if vline_new_peaks: + for ax in axes.flat: + for ts, ll in zip(times_samples, labels): + ax.axvline(ts, c=glasbey1024[ll % len(glasbey1024)], lw=0.5, ls=":") + if gt_t is not None: axes[-3].scatter(gt_t, gt_chan, c=gt_c, s=4 * s, lw=0, marker="o") @@ -276,6 +284,7 @@ def visualize_step_results( ec="w", lw=1, ) + axes[-3].set_ylim([0, chunk.shape[1]]) vmax = max(np.nanmax(pre_conv[:, obj_sl]), np.nanmax(conv[:, obj_sl])) for j, c in enumerate(pre_conv): @@ -300,12 +309,12 @@ def visualize_step_results( if vline_at is not None: for ax in axes[:3]: - ax.axvline(vline_at, color='w', ls='--', lw=0.8) + ax.axvline(vline_at, color="w", ls="--", lw=0.8) for ax in axes[-2:]: - ax.axvline(vline_at, color='k', ls='--', lw=0.8) + ax.axvline(vline_at, color="k", ls="--", lw=0.8) if objline_at is not None: for ax in axes[-2:]: - ax.axhline(objline_at, color='k', ls='--', lw=0.8) + ax.axhline(objline_at, color="k", ls="--", lw=0.8) plt.show() plt.close(panel) @@ -381,9 +390,13 @@ def data_at_time( inv_lambda: float, scale_min: float, scale_max: float, + resid_offset: int = 0, ) -> ChunkTemplateData: + assert not resid_offset return DebugChunkTemplateData( spike_length_samples=self.b.templates_up.shape[2], + filter_length_samples=self.b.templates_up.shape[2], + resid_offset=resid_offset, unit_ids=torch.arange( self.b.templates_up.shape[0], device=self.b.pconv.device ), @@ -408,6 +421,8 @@ def data_at_time( @databag class DebugChunkTemplateData(ChunkTemplateData): spike_length_samples: int + filter_length_samples: int + resid_offset: int unit_ids: Tensor main_channels: Tensor obj_normsq: Tensor diff --git a/tests/test_matching.py b/tests/test_matching.py index 1bd747eb..89b589ad 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -164,6 +164,7 @@ def test_no_crumbs(subtests, refractory_sim, method, cd_iter, channel_selection_ inv_lambda=matcher.inv_lambda, scale_min=matcher.amp_scale_min, scale_max=matcher.amp_scale_max, + resid_offset=0, ) conv = res["conv"].numpy(force=True) From f1191ce4748acceff4b857807a8c38264f96c30d Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 17 Jun 2026 08:46:16 -0400 Subject: [PATCH 30/46] clus: single split, optional kmeans after dpc --- src/dartsort/clustering/clustering.py | 36 ++++++- src/dartsort/clustering/kmeans.py | 147 ++++++++++++++++++++++++++ src/dartsort/clustering/mixture.py | 26 ++++- src/dartsort/util/internal_config.py | 4 +- src/dartsort/vis/mixture.py | 8 +- 5 files changed, 210 insertions(+), 11 deletions(-) diff --git a/src/dartsort/clustering/clustering.py b/src/dartsort/clustering/clustering.py index b4eaf3ae..57ed20e4 100644 --- a/src/dartsort/clustering/clustering.py +++ b/src/dartsort/clustering/clustering.py @@ -19,7 +19,16 @@ from ..util.main_util import ds_save_intermediate_labels from ..util.motion import MotionInfo from ..util.multiprocessing_util import handle_negative_jobs -from . import agglomerate, cluster_util, density, forward_backward, mixture, refine_util +from ..util.torch_util import cleanup_and_log_gpu_usage +from . import ( + agglomerate, + cluster_util, + density, + forward_backward, + kmeans, + mixture, + refine_util, +) from .clustering_features import SimpleMatrixFeatures, StableWaveformFeatures if TYPE_CHECKING: @@ -292,6 +301,10 @@ def __init__( workers=-1, uhdversion=False, random_seed=0, + kmeans_cleanup=True, + kmeans_max_sigma: float = 5.0, + kmeans_iter=100, + device=None, **kwargs, ): super().__init__(**kwargs) @@ -308,6 +321,10 @@ def __init__( self.workers = workers self.uhdversion = uhdversion self.random_seed = random_seed + self.kmeans_cleanup = kmeans_cleanup + self.kmeans_max_sigma = kmeans_max_sigma + self.kmeans_iter = kmeans_iter + self.device = device @classmethod def from_config( @@ -335,6 +352,7 @@ def from_config( outlier_radius=clustering_cfg.outlier_radius, outlier_neighbor_count=clustering_cfg.outlier_neighbor_count, workers=workers, + device=computation_cfg.actual_device(), uhdversion=uhdversion, computation_cfg=computation_cfg, waveform_cfg=waveform_cfg, @@ -342,6 +360,9 @@ def from_config( save_labels_dir=save_labels_dir, labels_fmt=labels_fmt, sampling_cfg=clustering_cfg.sampling_cfg, + kmeans_cleanup=clustering_cfg.dpc_kmeans_cleanup, + kmeans_iter=clustering_cfg.kmeans_iter, + kmeans_max_sigma=clustering_cfg.gmmdpc_max_sigma, ) def _cluster( @@ -425,6 +446,19 @@ def _cluster_extra( kdtree = None labels = res["labels"] + if self.kmeans_cleanup: + kres = kmeans.truncated_kmeans_from_labels( + X=X, + labels=labels, + device=self.device, + max_sigma=self.kmeans_max_sigma, + n_iter=self.kmeans_iter, + ) + assert kres.labels is not None + labels = kres.labels.numpy(force=True) + del kres + cleanup_and_log_gpu_usage(self.computation_cfg, "DPC->kmeans") + labels = cluster_util.decrumb( labels, min_size=self.remove_clusters_smaller_than, in_place=True ) diff --git a/src/dartsort/clustering/kmeans.py b/src/dartsort/clustering/kmeans.py index ff02fac8..0809b82c 100644 --- a/src/dartsort/clustering/kmeans.py +++ b/src/dartsort/clustering/kmeans.py @@ -758,3 +758,150 @@ def batched_kmeans( labels = None return KMeansResult(centroids=Y, responsibilities=e, labels=labels, dists=dists) + + +def truncated_kmeans_from_labels( + X: Tensor | np.ndarray, + labels: Tensor | np.ndarray, + device=None, + atol=1e-3, + max_sigma=5.0, + dirichlet_alpha=1.0, + n_iter=100, + show_progress: bool = True, + batch_size: int = 4096, + centroid_dist_batch_size: int = 128, + min_log_prop=-25.0, + trunc_guess=20, + initial_undershoot=2.0, +) -> KMeansResult: + X = torch.asarray(X, device=device) + labels = torch.asarray(labels, device=device) + n, dim = X.shape + assert labels.shape == (n,) + is_gpu = X.device.type == "cuda" + + # flatten and count labels + ulabels, labels = labels.unique(return_inverse=True) + k = ulabels.shape[0] + del ulabels + + # initialize parameters + e = F.one_hot(labels, k).to(X) + log_props = e.mean(0).log_() + N = e.sum(dim=0) + w = e.div_(N) + Y = torch.mm(w.view(n, k).t(), X) + Ynormsq = torch.linalg.vector_norm(Y, dim=1).square_() + + # initialize sigma + sigmasq = X.new_zeros(()) + bY = X.new_empty((batch_size, dim)) + for i0 in range(0, n, batch_size): + i1 = min(n, i0 + batch_size) + bY = torch.index_select(Y, 0, labels[i0:i1], out=bY[: i1 - i0]) + bsigsq = bY.sub_(X[i0:i1]).square_().mean() + sigmasq += (bsigsq - sigmasq) * ((i1 - i0) / i1) + del bY + assert sigmasq > 0 + sigma = initial_undershoot * sigmasq.sqrt_() + assert sigma.isfinite().item() + prev_sigma = sigma.clone() + + # storage + dYY = X.new_zeros((k, k)) + dYYmask = X.new_zeros((k, k), dtype=torch.bool) + new_Y = torch.empty_like(Y) + distsq_buf = torch.zeros_like(X[: min(trunc_guess, k) * batch_size]) + distsq_buf = (distsq_buf, torch.zeros_like(distsq_buf)) + + if show_progress: + it = progrange(n_iter, desc=f"kmeans σ={sigma:0.4f}") + else: + it = range(n_iter) + + done = False + for j in it: + done = done or j == n_iter - 1 + max_distance_sq = max_sigma * sigmasq * dim + + new_Y.fill_(0.0) + N.fill_(0.0) + new_sigmasq = 0.0 + weight = 0.0 + + # update centroid dists + for i0 in range(0, k, centroid_dist_batch_size): + i1 = min(k, i0 + centroid_dist_batch_size) + sqeuc_cdist_known_norm(Y[i0:i1], Ynormsq[i0:i1], Y, Ynormsq, out=dYY[i0:i1]) + torch.lt(dYY, max_distance_sq, out=dYYmask) + + for i0 in range(0, n, batch_size): + i1 = min(n, i0 + batch_size) + distsq_coo, distsq_buf = sparse_centroid_distsq( + X[i0:i1], + Y, + labels=labels[i0:i1], + centroid_mask=dYYmask, + dbufs=distsq_buf, + ) + assert distsq_coo.shape == (i1 - i0, k) + distsq_values = distsq_coo.values().clone() + liks = distsq_to_lik_coo(distsq_coo, sigmasq, log_props, in_place=True) + del distsq_coo + + resps = torch.sparse.softmax(liks, dim=1) + # update labels... torch sparse has no argmax(), so need scipy + # or cupy. scipy is a big slowdown here, so cupy if possible. + if is_gpu and HAVE_CUPY: + resps_cupy = coo_to_cupy(resps).tocsc() + batch_labels = resps_cupy.argmax(axis=1) + else: + resps_scipy = coo_to_scipy(resps) + batch_labels = resps_scipy.argmax(axis=1, explicit=True) + labels[i0:i1] = torch.as_tensor(batch_labels).to(labels).squeeze() + + # get sigmasq + w = resps.values().clone() + batch_w = w.sum() + w /= batch_w + batch_sigmasq = torch.sum(distsq_values.mul_(w)) / dim + + # get N and centroids + batch_N = resps.sum(dim=0).to_dense() + resps.values().div_(batch_N[resps.indices()[1]]) + batch_centroids = resps.T @ X[i0:i1] + + # update counts + N += batch_N + weight += batch_w + + # update Welford running means + n1_n01 = batch_N.div_(N.clip(min=1e-5))[:, None] + w1_w01 = batch_w / weight + new_Y += batch_centroids.sub_(new_Y).mul_(n1_n01) + new_sigmasq += batch_sigmasq.sub_(new_sigmasq).mul_(w1_w01) + + # update state + logN = N.log_() + dirichlet_alpha + log_props = F.log_softmax(logN, dim=0).to(X) + log_props = log_props.clamp_(min=min_log_prop) + Y, new_Y = new_Y, Y + Ynormsq = torch.linalg.vector_norm(Y, dim=1).square_() + sigmasq = new_sigmasq + + # check convergence + sigma = torch.sqrt(sigmasq).numpy(force=True).item() # type: ignore + if abs(sigma - prev_sigma) < atol: + break + + prev_sigma = sigma + if show_progress: + it.set_description(f"kmeans σ={sigma:0.4f}") # type: ignore + + return KMeansResult( + labels=labels, + responsibilities=None, + centroids=Y, + dists=None, + ) diff --git a/src/dartsort/clustering/mixture.py b/src/dartsort/clustering/mixture.py index a3eae6fc..8d699f73 100644 --- a/src/dartsort/clustering/mixture.py +++ b/src/dartsort/clustering/mixture.py @@ -178,8 +178,14 @@ def tmm_demix( allow_blanks = False for outer_it in range(refinement_cfg.n_total_iters): for inner_it, step_type in enumerate(refinement_cfg.mixture_steps): - if step_type == "split": - run_split(tmm, train_data, val_data, prog_level) + if step_type.endswith("split"): + run_split( + tmm, + train_data, + val_data, + prog_level, + single=step_type.startswith("single"), + ) tmm.em( train_data, show_progress=prog_level, @@ -788,6 +794,7 @@ class TMMParams: split_max_distance: float merge_max_distance: float split_k: int + single_split_k: int distance_kind: ComponentDistanceMetric em_iters: int min_em_iters: int @@ -824,6 +831,7 @@ def from_refinement_cfg(cls, refinement_cfg: RefinementConfig): min_count=refinement_cfg.min_count, split_min_count=refinement_cfg.split_min_count, split_k=refinement_cfg.kmeansk, + single_split_k=refinement_cfg.single_split_k, kmeans_tries=refinement_cfg.kmeans_tries, kmeans_beta=refinement_cfg.kmeans_beta, kmeanspp_tries=refinement_cfg.kmeanspp_tries, @@ -1808,7 +1816,7 @@ def dense_slice_by_unit( gen: torch.Generator, labels: Tensor | None, min_count: int = 0, - ): + ) -> DenseSpikeData | None: assert self.candidates is not None assert unit_ids is not None unit_ids = torch.as_tensor(unit_ids, device=self.candidates.device) @@ -2928,7 +2936,10 @@ def split_group( group_size = group.numel() single = group_size == 1 - k = min(self.p.max_group_size, self.p.split_k) + if single: + k = self.p.single_split_k + else: + k = self.p.split_k # get dense train set slice in group split_data = train_data.dense_slice_by_unit( @@ -3180,11 +3191,14 @@ def split( train_scores: Scores, eval_scores: Scores, show_progress: bool = True, + friend_distance: float | None = None, _stop_after: int | None = None, _dry_run: bool = False, ) -> SplitResult: + if friend_distance is None: + friend_distance = self.p.split_friend_distance split_groups = self.group_units_by_distance( - distance=self.p.split_friend_distance, + distance=friend_distance, max_group_size=max(1, self.p.split_k - 1), ) if _stop_after: @@ -4465,6 +4479,7 @@ def run_split( train_data: TruncatedSpikeData, val_data: TruncatedSpikeData | None, prog_level: int, + single: bool = False, _stop_after: int | None = None, _dry_run: bool = False, ): @@ -4489,6 +4504,7 @@ def run_split( eval_scores=eval_scores, train_scores=train_scores, show_progress=prog_level > 0, + friend_distance=0.0 if single else None, _stop_after=_stop_after, _dry_run=_dry_run, ) diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index 22b685ab..d503fc0f 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -289,6 +289,7 @@ class ClusteringConfig: noise_density: float = 0.0 outlier_radius: float | None = 25.0 outlier_neighbor_count: int = 10 + dpc_kmeans_cleanup: bool = False # gmm density peaks additional parameters kmeanspp_initializations: int = 10 @@ -462,7 +463,7 @@ def to_template_config(self, template_cfg: TemplateConfig | None = None): ) -MixtureStep = Literal["split", "merge", "demolish"] +MixtureStep = Literal["split", "singlesplit", "merge", "demolish"] ComponentDistanceMetric = Literal["cosine", "normeuc", "scaled_normeuc"] @@ -517,6 +518,7 @@ class RefinementConfig: mixture_steps: Sequence[MixtureStep] = ("split", "merge", "demolish") prior_pseudocount: float = 0.0 kmeansk: int = 4 + single_split_k: int = 3 kmeans_tries: int = 10 kmeans_beta: float = 50.0 kmeanspp_tries: int = 5 diff --git a/src/dartsort/vis/mixture.py b/src/dartsort/vis/mixture.py index 4a63ce3c..ca41d1c7 100644 --- a/src/dartsort/vis/mixture.py +++ b/src/dartsort/vis/mixture.py @@ -487,12 +487,12 @@ def draw(self, panel, mix_data: MixtureVisData, unit_id: int): for split, linestyle in zip(["full", "eval"], "-:"): if split == "full": ssco = mix_data.full_scores - in_unit_id = mix_data.full_inunits[unit_id] - in_nid = mix_data.full_inunits[nid] + in_unit_id = mix_data.full_inunits[int(unit_id)] + in_nid = mix_data.full_inunits[int(nid)] elif split == "eval": ssco = mix_data.eval_scores - in_nid = mix_data.eval_inunits.get(nid, empty) - in_unit_id = mix_data.eval_inunits.get(unit_id, empty) + in_nid = mix_data.eval_inunits.get(int(nid), empty) + in_unit_id = mix_data.eval_inunits.get(int(unit_id), empty) else: assert False From b784f22b387e411fc968fa3efc7e45f9aacd0f4b Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 17 Jun 2026 11:58:24 -0400 Subject: [PATCH 31/46] glom: fixes from Keshav --- src/dartsort/clustering/agglomerate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dartsort/clustering/agglomerate.py b/src/dartsort/clustering/agglomerate.py index 59e2cad8..856163c5 100644 --- a/src/dartsort/clustering/agglomerate.py +++ b/src/dartsort/clustering/agglomerate.py @@ -666,7 +666,7 @@ def combine_gmm_scores( # check invariants at the top if responsibilities.shape[1] > 2: _maxdiff = np.diff(responsibilities[:, :-1], axis=1).max() - assert _maxdiff <= 1e-5, _maxdiff + assert _maxdiff <= 1e-3, _maxdiff assert np.greater_equal(np.isneginf(logliks[:, :-1]), candidates == -1).all() if sorting.labels is not None: assert np.all( @@ -700,7 +700,7 @@ def combine_gmm_scores( # check invariants at the bottom if mergedr.shape[1] > 2: _maxdiff = np.diff(mergedr[:, : cand.shape[1]], axis=1).max() - assert _maxdiff <= 1e-5, _maxdiff + assert _maxdiff <= 1e-3, _maxdiff assert np.greater_equal(np.isneginf(mergedl[:, : cand.shape[1]]), cand == -1).all() assert (cand < 0).sum() >= nbye if sorting.labels is not None: @@ -735,7 +735,7 @@ def _combine_loop( continue eq_ncandj = rcand[j + 1 :] == ncandj - if eq_ncandj.sum() <= 1: + if eq_ncandj.sum() < 1: continue rsum = mergedr[s, j] From 8f49f3b41b50084697601615dfe1f671e6136928 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 18 Jun 2026 18:01:48 -0400 Subject: [PATCH 32/46] cfg: parametrize count/chan --- src/dartsort/config.py | 4 ++-- src/dartsort/main.py | 2 +- src/dartsort/util/internal_config.py | 12 ++++++++++-- tests/test_config.py | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/dartsort/config.py b/src/dartsort/config.py index d96d27e9..c8b40a17 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -47,7 +47,7 @@ class DARTsortUserConfig: relevant if `preprocessing != 'none'`. If the recording isn't getting saved, stick to float32.""" - subsampling_spikes: int | None = 2_048_000 + subsampling_spikes_per_channel: int | None = 5000 """Detection steps before the final matching round will run until at least this many spikes are found or the whole recording is covered, to make sure that there is enough data for clustering. See also subsampling_fraction. @@ -242,7 +242,7 @@ class DeveloperConfig(DARTsortUserConfig): n_waveforms_fit: int = 40_000 max_waveforms_fit: int = 50_000 fit_sampling: Literal["random", "amp_reweighted"] = "amp_reweighted" - n_residual_snips: int = 4 * 4096 + n_residual_snips: int = 2 * 4096 # initial detection nn_denoiser_max_waveforms_fit: int = 512_000 diff --git a/src/dartsort/main.py b/src/dartsort/main.py index a880dc8c..fffdcb0e 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -121,7 +121,7 @@ def dartsort( output_dir = ensure_path(output_dir, mkdir=True) # convert cfg to internal format and store it for posterity - cfg = to_internal_config(cfg) + cfg = to_internal_config(cfg, recording.get_num_channels()) ds_dump_config(cfg, output_dir) # in benchmarking, it can be useful to resume from initial detection diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index d503fc0f..e12c5db6 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -961,13 +961,16 @@ class DARTsortInternalConfig: save_everything_on_error: bool = False -def to_internal_config(cfg) -> DARTsortInternalConfig: +def to_internal_config(cfg, n_channels: int) -> DARTsortInternalConfig: """Laundromat of configuration formats Parameters ---------- cfg : str | Path | DARTsortUserConfig | DeveloperConfig If str or Path, it should point to a .toml file. + n_channels : int + The number of channels in the input recording, used to adjust + parameters which are specified as counts per channel. Returns ------- @@ -1327,6 +1330,11 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: else: assert False + if cfg.subsampling_spikes_per_channel is not None: + subsampling_spikes = cfg.subsampling_spikes_per_channel * n_channels + else: + subsampling_spikes = None + return DARTsortInternalConfig( waveform_cfg=waveform_cfg, featurization_cfg=featurization_cfg, @@ -1360,7 +1368,7 @@ def to_internal_config(cfg) -> DARTsortInternalConfig: save_everything_on_error=cfg.save_everything_on_error, link_from=cfg.link_from, link_step=cfg.link_step, - subsampling_spikes=cfg.subsampling_spikes, + subsampling_spikes=subsampling_spikes, subsampling_presence=cfg.subsampling_presence, always_save_final_tpca_feature=cfg.always_save_final_tpca_feature, ) diff --git a/tests/test_config.py b/tests/test_config.py index 4e66d68e..116ab769 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,7 +5,7 @@ def test_cfg_consistency(): """Ensure config.py and internal_config.py don't diverge.""" - cfg0 = dartsort.to_internal_config(dartsort.DeveloperConfig()) + cfg0 = dartsort.to_internal_config(dartsort.DeveloperConfig(), 10) cfg1 = dartsort.DARTsortInternalConfig() # can just do assert cfg0 == cfg1, but pytest gives a better From b9cb7dd7b813165019e213632ab0f2eb4bbec83f Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 18 Jun 2026 18:02:40 -0400 Subject: [PATCH 33/46] decollider: investigate training in svd land --- .../transform/_multichan_denoiser_kit.py | 14 +++- src/dartsort/transform/decollider.py | 74 ++++++++++++++++--- src/dartsort/transform/temporal_pca.py | 44 +++++++---- src/dartsort/transform/transform_base.py | 25 ++++++- 4 files changed, 127 insertions(+), 30 deletions(-) diff --git a/src/dartsort/transform/_multichan_denoiser_kit.py b/src/dartsort/transform/_multichan_denoiser_kit.py index 99a65622..da0a9e2e 100644 --- a/src/dartsort/transform/_multichan_denoiser_kit.py +++ b/src/dartsort/transform/_multichan_denoiser_kit.py @@ -32,6 +32,7 @@ def __init__( n_epochs=75, pad_depth_only=True, channelwise_dropout_p=0.0, + svd_projection_rank: int | None = None, with_conv_fullheight=False, val_split_p=0.0, min_epochs=10, @@ -85,6 +86,7 @@ def __init__( self.res_type = res_type self.inference_batch_size = inference_batch_size self.epoch_size = epoch_size + self.svd_projection_rank = svd_projection_rank model_channel_index = regularize_channel_index( geom=self.geom, channel_index=channel_index, depth_only=pad_depth_only @@ -127,7 +129,11 @@ def initialize_shapes(self): ) # we don't know these dimensions til we see a spike assert self.spike_length_samples is not None - self.wf_dim = self.spike_length_samples * self.b.model_channel_index.shape[1] + if self.svd_projection_rank: + dim0 = self.svd_projection_rank + else: + dim0 = self.spike_length_samples + self.wf_dim = dim0 * self.b.model_channel_index.shape[1] self.output_dim = self.wf_dim def get_optimizer(self): @@ -196,8 +202,12 @@ def get_mlp( hidden_dims = self.hidden_dims if output_layer is None: output_layer = "gated_linear" if self.signal_gates else "linear" + if self.svd_projection_rank: + dim0 = self.svd_projection_rank + else: + dim0 = self.spike_length_samples return nn_util.get_waveform_mlp( - self.spike_length_samples, + dim0, self.b.model_channel_index.shape[1], hidden_dims, self.output_dim, diff --git a/src/dartsort/transform/decollider.py b/src/dartsort/transform/decollider.py index bcde2cf0..ba391042 100644 --- a/src/dartsort/transform/decollider.py +++ b/src/dartsort/transform/decollider.py @@ -28,6 +28,7 @@ class Decollider(BaseMultichannelDenoiser): """Unsupervised spike waveform denoising.""" + default_name = "decollider" needs_residual = True @@ -48,6 +49,7 @@ def __init__( pad_depth_only=True, channelwise_dropout_p=0.0, with_conv_fullheight=False, + svd_projection_rank: int | None = None, val_split_p=0.0, min_epochs=10, earlystop_eps=None, @@ -105,6 +107,7 @@ def __init__( pad_depth_only=pad_depth_only, channelwise_dropout_p=channelwise_dropout_p, with_conv_fullheight=with_conv_fullheight, + svd_projection_rank=svd_projection_rank, val_split_p=val_split_p, min_epochs=min_epochs, earlystop_eps=earlystop_eps, @@ -143,6 +146,8 @@ def __init__( self.detach_cycle_loss = detach_cycle_loss self.clip_value = clip_value self.clip_norm = clip_norm + if self.svd_projection_rank: + self.submodule_names = ["tpca"] if separate_cycle_net: assert cycle_loss_alpha > 0 @@ -176,6 +181,19 @@ def initialize_spike_length_dependent_params(self): ) else: self.den_net: torch.nn.Module = self.inf_net + if self.svd_projection_rank: + from .temporal_pca import BaseTemporalPCA + + self.tpca = BaseTemporalPCA( + self.b.channel_index, + geom=self.b.geom, + waveform_cfg=self.waveform_cfg, + rank=self.svd_projection_rank, + ) + self.tpca.spike_length_samples = self.spike_length_samples + self.tpca.initialize_spike_length_dependent_params() + else: + self.tpca = None self.to(self.device) def fit( @@ -192,17 +210,27 @@ def fit( super().fit( recording, waveforms, computation_cfg=computation_cfg, channels=channels ) + if self.tpca is not None and self.tpca.needs_fit(): + self.tpca.fit( + recording=recording, + waveforms=waveforms, + computation_cfg=computation_cfg, + channels=channels, + ) train_data, val_data = self._construct_datasets_from_waveforms( waveforms, channels, recording, weights, hdf5_filename=hdf5_filename ) with torch.enable_grad(): + if self.tpca is not None: + self.tpca.eval() res = self._fit(train_data, val_data) self._needs_fit = False return res def forward_unbatched(self, waveforms, channels): """Called only at inference time.""" - # TODO: batch all of this. + if self.tpca is not None: + waveforms = self.tpca.force_embed(waveforms) waveforms, masks = self.to_nn_channels(waveforms, channels) net_input = waveforms, masks.unsqueeze(1) @@ -266,6 +294,9 @@ def forward_unbatched(self, waveforms, channels): pred = self.to_orig_channels(pred, channels) + if self.tpca is not None: + pred = self.tpca.force_reconstruct(pred) + return pred def train_forward(self, y, m, ell, mask): @@ -559,15 +590,6 @@ def _run_train_loop(self, train_data, val_data): train_losses = {} for waveform_b, channels_b, noise_b, cnoise_b in train_data: waveform_b = waveform_b.to(device=self.device, non_blocking=True) - channels_b = channels_b.to(device=self.device, non_blocking=True) - waveform_b = reindex( - channels_b, - waveform_b, - self.relative_index, - pad_value=0.0, - ) - - optimizer.zero_grad() m = noise_b.to( dtype=waveform_b.dtype, device=self.device, non_blocking=True ) @@ -579,6 +601,24 @@ def _run_train_loop(self, train_data, val_data): ) else: ell = None + + if self.tpca is not None: + with torch.no_grad(): + waveform_b = self.tpca.force_embed(waveform_b) + m = self.tpca.force_embed(m) + if ell is not None: + ell = self.tpca.force_embed(ell) + + channels_b = channels_b.to(device=self.device, non_blocking=True) + waveform_b = reindex( + channels_b, + waveform_b, + self.relative_index, + pad_value=0.0, + ) + + optimizer.zero_grad() + mask = self.get_masks(channels_b).to( dtype=waveform_b.dtype, device=self.device, non_blocking=True ) @@ -596,9 +636,13 @@ def _run_train_loop(self, train_data, val_data): loss.backward() if self.clip_value is not None: - torch.nn.utils.clip_grad_value_(self.parameters(), self.clip_value) + torch.nn.utils.clip_grad_value_( + self.parameters(), self.clip_value + ) if self.clip_norm is not None: - torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_norm) + torch.nn.utils.clip_grad_norm_( + self.parameters(), self.clip_norm + ) optimizer.step() for k, v in loss_dict.items(): @@ -622,6 +666,12 @@ def _run_train_loop(self, train_data, val_data): channels_b = channels_b.to(self.device) noise_b = noise_b.to(self.device) ell_b = None if ell_b is None else ell_b.to(self.device) + if self.tpca is not None: + with torch.no_grad(): + waveform_b = self.tpca.force_embed(waveform_b) + noise_b = self.tpca.force_embed(noise_b) + if ell_b is not None: + ell_b = self.tpca.force_embed(ell_b) waveform_b, mask = self.to_nn_channels( waveform_b, channels_b diff --git a/src/dartsort/transform/temporal_pca.py b/src/dartsort/transform/temporal_pca.py index 4819b771..34a48b77 100644 --- a/src/dartsort/transform/temporal_pca.py +++ b/src/dartsort/transform/temporal_pca.py @@ -22,6 +22,7 @@ class BaseTemporalPCA(BaseWaveformModule): """Base class for PCA featurizers.""" + default_name = "basis" def __init__( @@ -93,14 +94,13 @@ def fit( super().fit( recording, waveforms, computation_cfg=computation_cfg, channels=channels ) + del spike_data rg = np.random.default_rng(self.random_state) if weights is not None and waveforms.shape[0] > self.max_waveforms: weights = weights.numpy(force=True) if torch.is_tensor(weights) else weights weights = weights.astype(np.float64) weights = weights / weights.sum() - choices = rg.choice( - len(weights), p=weights, size=self.max_waveforms - ) + choices = rg.choice(len(weights), p=weights, size=self.max_waveforms) choices.sort() choices = torch.from_numpy(choices) waveforms = waveforms[choices] @@ -248,24 +248,39 @@ def force_reconstruct(self, features): if ndim == 2: features = features.unsqueeze(0) n, r, c = features.shape - waveforms = features.permute(0, 2, 1).reshape(n * c, r) - waveforms = self._inverse_transform_in_probe(waveforms) - waveforms = waveforms.reshape(n, c, -1).permute(0, 2, 1) + if self.whiten: + W = self.b.components / self.b.whitener + else: + W = self.b.components + Wt = W.t() + assert Wt.shape[1] == r + out = features.new_empty((n, Wt.shape[0], c)) + bs = self.batch_size + for i0 in range(0, n, bs): + i1 = i0 + bs + torch.matmul(Wt, features[i0:i1], out=out[i0:i1]) if ndim == 2: - waveforms = waveforms[0] - return waveforms + out = out[0] + return out def force_embed(self, waveforms): ndim = waveforms.ndim if ndim == 2: waveforms = waveforms.unsqueeze(0) n, t, c = waveforms.shape - waveforms = waveforms.mT.reshape(n * c, t) - waveforms = self._transform_in_probe(waveforms) - waveforms = waveforms.view(n, c, self.rank).mT + if self.whiten: + W = self.b.components / self.b.whitener + else: + W = self.b.components + assert W.shape[1] == t + out = waveforms.new_empty((n, W.shape[0], c)) + bs = self.batch_size + for i0 in range(0, n, bs): + i1 = i0 + bs + torch.matmul(W, waveforms[i0:i1], out=out[i0:i1]) if ndim == 2: - waveforms = waveforms[0] - return waveforms + out = out[0] + return out def force_project(self, waveforms): ndim = waveforms.ndim @@ -413,6 +428,7 @@ def initialize_from_templates(self, td): class TemporalPCADenoiser(BaseWaveformDenoiser, BaseTemporalPCA): """Spike waveform denoising with PCA.""" + default_name = "temporal_pca" def forward(self, waveforms, *, channels, time_shifts=None, **unused): @@ -467,6 +483,7 @@ def forward(self, waveforms, *, channels, **spike_data): class TemporalPCAFeaturizer(BaseWaveformFeaturizer, BaseTemporalPCA): """Spike featurization with PCA.""" + default_name = "tpca_features" def transform( @@ -527,6 +544,7 @@ def inverse_transform(self, features, channels, channel_index=None): class TemporalPCA(BaseWaveformAutoencoder, TemporalPCAFeaturizer): """Combined spike featurization and denoising with PCA.""" + default_name = "tpca_features" def forward(self, waveforms, *, channels, time_shifts=None, **unused): diff --git a/src/dartsort/transform/transform_base.py b/src/dartsort/transform/transform_base.py index 0defad80..606e8588 100644 --- a/src/dartsort/transform/transform_base.py +++ b/src/dartsort/transform/transform_base.py @@ -41,6 +41,7 @@ def __init__( if name_prefix: name = f"{name_prefix}_{name}" self.name = name + self.submodule_names = None # these buffers below need to be copied, else they share references # across all the transformers which seems to cause problems! if channel_index is not None: @@ -131,9 +132,22 @@ def _other_pre_load_state(self, state_dict, prefix): def _pre_load_state(self, state_dict, prefix, *args, **kwargs): # wish torch would strip the prefix for us? extra_state_keys = [k for k in state_dict.keys() if k.endswith("_extra_state")] - assert len(extra_state_keys) <= 1 - if extra_state_keys: - extra_state = state_dict[extra_state_keys[0]] + + all_submodule_keys = [] + if self.submodule_names: + for sn in self.submodule_names: + sn_keys = [ + k for k in extra_state_keys if k.endswith(f"{sn}._extra_state") + ] + assert len(sn_keys) == 1 + all_submodule_keys.append(sn_keys[0]) + + my_extra_state_keys = [ + k for k in extra_state_keys if k not in all_submodule_keys + ] + assert len(my_extra_state_keys) <= 1 + if my_extra_state_keys: + extra_state = state_dict[my_extra_state_keys[0]] # some modules want to know the spike length before loading the state dict # and unfortunately set_extra_state usually runs after. doesn't hurt to run now. @@ -147,6 +161,11 @@ def _pre_load_state(self, state_dict, prefix, *args, **kwargs): self._other_pre_load_state(state_dict, prefix) + if self.submodule_names: + for sn, smk in zip(self.submodule_names, all_submodule_keys): + sn_dict = {smk: state_dict[smk]} + getattr(self, sn)._pre_load_state(sn_dict, prefix, *args, **kwargs) + def initialize_spike_length_dependent_params(self): pass From 2e39bfd94939be6475e650e56c9671e55513a0a8 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 18 Jun 2026 18:02:57 -0400 Subject: [PATCH 34/46] clean / fix enfdec accidentally on bug --- src/dartsort/templates/postprocess_util.py | 4 ++-- src/dartsort/transform/pipeline.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dartsort/templates/postprocess_util.py b/src/dartsort/templates/postprocess_util.py index 19a06126..11aaa10d 100644 --- a/src/dartsort/templates/postprocess_util.py +++ b/src/dartsort/templates/postprocess_util.py @@ -384,8 +384,6 @@ def _handle_merge( if merge_cfg is None or not merge_cfg.merge_distance_threshold: return sorting, template_data - from ..clustering.merge import merge_templates - if template_cfg.denoising_method == "svd": # use new shared basis stuff from ..clustering.agglomerate import agglomerate @@ -406,6 +404,8 @@ def _handle_merge( del agg else: # TODO: remove old impl? + from ..clustering.merge import merge_templates + merge_shift_samples = waveform_cfg.ms_to_samples(merge_cfg.max_shift_ms) merge_res = merge_templates( sorting=sorting, diff --git a/src/dartsort/transform/pipeline.py b/src/dartsort/transform/pipeline.py index dd91585b..05d12ad1 100644 --- a/src/dartsort/transform/pipeline.py +++ b/src/dartsort/transform/pipeline.py @@ -638,7 +638,7 @@ def _add_localization_and_ampvec(fc): ) ) - if fc.do_enforce_decrease == "loc_only" and fc.do_localization: + if (fc.do_enforce_decrease == "loc_only") and (fc.do_localization and do_feats): more.append(("EnforceDecrease", {})) if do_feats and fc.do_localization and fc.nn_localization: From f9be6753c6ecd7aa18ea470242bff0840e05d796 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sat, 20 Jun 2026 16:45:49 -0400 Subject: [PATCH 35/46] more params needed for demolish vis --- src/dartsort/vis/mixture.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dartsort/vis/mixture.py b/src/dartsort/vis/mixture.py index ca41d1c7..4809c89d 100644 --- a/src/dartsort/vis/mixture.py +++ b/src/dartsort/vis/mixture.py @@ -615,6 +615,8 @@ def compute(self, mix_data: MixtureVisData, unit_id: int): mean_eval_resp=mix_data.mean_eval_resp, train_scores=mix_data.train_scores, eval_scores=mix_data.eval_scores, + train_data=mix_data.train_data, + eval_data=mix_data.val_data, # type: ignore cur_crit=None, ) return group_res From f7bae7260732b0f8dad479db47178cd92820aea8 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sat, 20 Jun 2026 16:46:20 -0400 Subject: [PATCH 36/46] gmm: choice to refit within demolish --- src/dartsort/clustering/mixture.py | 152 +++++++++++++++++++++++++---- 1 file changed, 132 insertions(+), 20 deletions(-) diff --git a/src/dartsort/clustering/mixture.py b/src/dartsort/clustering/mixture.py index 8d699f73..bcee89d1 100644 --- a/src/dartsort/clustering/mixture.py +++ b/src/dartsort/clustering/mixture.py @@ -813,6 +813,7 @@ class TMMParams: robust_strategy: RobustnessStrategy demolition_min_resp_ratio: float demolish_during_selection: bool + refit_in_demolition: bool kmeans_tries: int kmeans_beta: float kmeanspp_tries: int @@ -850,6 +851,7 @@ def from_refinement_cfg(cls, refinement_cfg: RefinementConfig): robust_strategy=refinement_cfg.robust_strategy, demolition_min_resp_ratio=refinement_cfg.demolition_min_resp_ratio, demolish_during_selection=refinement_cfg.demolish_during_selection, + refit_in_demolition=refinement_cfg.refit_in_demolition, whiten_split=refinement_cfg.whiten_split, scale_dist_args=refinement_cfg.scale_dist_args, whiten_dist=refinement_cfg.whiten_dist, @@ -2349,6 +2351,7 @@ def initialize_from_dense_data_with_fixed_responsibilities( total_log_proportion: float, noise_log_prop: Tensor | float = -torch.inf, p: TMMParams, + min_channel_count: int | None = None, ) -> tuple[Self, Tensor, DenseSpikeData, bool, Tensor, Tensor]: """Fit units with fixed label posterior @@ -2361,14 +2364,16 @@ def initialize_from_dense_data_with_fixed_responsibilities( data=data, rank=signal_rank, erp=erp, - min_channel_count=p.min_channel_count, + min_channel_count=p.min_channel_count + if min_channel_count is None + else min_channel_count, noise=noise, weights=responsibilities, latent_prior_std=p.latent_prior_std, prior_pseudocount=p.prior_pseudocount, ) assert chan_coverage is not None - valid = torch.tensor([r is not None for r in initialization]) + valid = torch.tensor([r is not None for r in initialization], dtype=torch.bool) initialization = [r for r in initialization if r is not None] responsibilities = responsibilities[:, valid] K = responsibilities.shape[1] @@ -2677,13 +2682,14 @@ def score( *, skip_noise: bool = False, allow_blanks: bool = False, + skip_responsibility: bool = True, ) -> Scores: scores = [] for batch in data.to_batches(self.unit_ids, self.lut): batch_scores = self.score_batch( batch=batch, n_candidates=batch.candidates.shape[1], - skip_responsibility=True, + skip_responsibility=skip_responsibility, skip_noise=skip_noise, allow_blanks=allow_blanks, ) @@ -3470,6 +3476,8 @@ def demolish( train_scores=train_scores, eval_scores=val_scores, cur_crit=cur_crit, + train_data=train_data, + eval_data=val_data, ) if _stop_after and j > _stop_after: # for profiling @@ -5110,7 +5118,7 @@ def _fit_subset_models( def _score_subset_models( mm: BaseMixtureModel, subset_models: BaseMixtureModel, - train_full_scores: Scores, + train_full_scores: Scores | None, cur_scores: Scores, train_data: DenseSpikeData, eval_data: DenseSpikeData | None, @@ -5149,6 +5157,7 @@ def _score_subset_models( crit_subset_scores = subset_models.score(eval_data, skip_noise=True) else: assert eval_data is None + assert train_full_scores is not None crit_subset_scores = train_subset_scores crit_full_scores = train_full_scores @@ -5411,8 +5420,11 @@ def try_kmeans( def evaluate_group_demolitions( + *, mm: TruncatedMixtureModel, group: Tensor, + train_data: TruncatedSpikeData, + eval_data: TruncatedSpikeData, mean_train_resp: Tensor | None, mean_eval_resp: Tensor | None, cur_crit: float | None, @@ -5427,7 +5439,7 @@ def evaluate_group_demolitions( if mean_eval_resp is None: mean_eval_resp = mean_responsibilities(scores=eval_scores, n_units=mm.n_units) ratio = mean_train_resp[group] / mean_eval_resp[group] - can_demolish = ratio > mm.p.demolition_min_resp_ratio + can_demolish = ratio > 0 # mm.p.demolition_min_resp_ratio if not can_demolish.any(): return GroupDemolition(unit_ids=group, improvement=0.0, demolished=None) @@ -5446,16 +5458,46 @@ def evaluate_group_demolitions( unit_ids=group, improvement=0.0, demolished=torch.zeros_like(can_demolish) ) best_imp = 0.0 - for demo_mask in submasks(can_demolish): - crit = _evaluate_single_demolition( - orig_log_props=mm.b.log_proportions, - noise_log_prop=mm.b.noise_log_prop, - cl_alpha=alpha, - group=group_, - demolish_mask=demo_mask, - train_scores=train_scores, - eval_scores=eval_scores, + + if mm.p.refit_in_demolition: + group = group.to(device=train_scores.candidates.device) + (train_ixs,) = ( + torch.isin(train_scores.candidates, group).any(dim=1).nonzero(as_tuple=True) ) + group_train_data = train_data.dense_slice(train_ixs) + assert group_train_data is not None + group_train_scores = train_scores.slice(group_train_data.indices) + + (eval_ixs,) = ( + torch.isin(eval_scores.candidates, group).any(dim=1).nonzero(as_tuple=True) + ) + group_eval_data = eval_data.dense_slice(eval_ixs) + assert group_eval_data is not None + group_eval_scores = eval_scores.slice(group_eval_data.indices) + nongroup_eval_scores = remove_units_from_scores(group_eval_scores, group) + + for demo_mask in submasks(can_demolish, skip_full=True): + if mm.p.refit_in_demolition: + crit = _evaluate_single_refit_demolition( + group=group, + demolish_mask=demo_mask, + group_train_scores=group_train_scores, # ty: ignore[possibly-unresolved-reference] + nongroup_eval_scores=nongroup_eval_scores, # ty: ignore[possibly-unresolved-reference] + group_eval_scores=group_eval_scores, # ty: ignore[possibly-unresolved-reference] + group_eval_data=group_eval_data, # ty: ignore[possibly-unresolved-reference] + group_train_data=group_train_data, # ty: ignore[possibly-unresolved-reference] + mm=mm, + ) + else: + crit = _evaluate_single_demolition( + orig_log_props=mm.b.log_proportions, + noise_log_prop=mm.b.noise_log_prop, + cl_alpha=alpha, + group=group_, + demolish_mask=demo_mask, + train_scores=train_scores, + eval_scores=eval_scores, + ) imp = crit - cur_crit if imp > best_imp: best_demo = GroupDemolition( @@ -5507,13 +5549,83 @@ def _evaluate_single_demolition( ) -def submasks(mask: Tensor, skip_empty=True): +def _evaluate_single_refit_demolition( + group: Tensor, + demolish_mask: Tensor, + group_train_scores: Scores, + nongroup_eval_scores: Scores, + group_eval_scores: Scores, + group_eval_data: DenseSpikeData, + group_train_data: DenseSpikeData, + mm: TruncatedMixtureModel, +) -> float: + assert demolish_mask.shape == group.shape + + # determine mean responsibility after demolition on train set + demolish_mask = demolish_mask.cpu() + chopping_block = group[demolish_mask] + nchop = chopping_block.numel() + if not nchop: + return ecl( + resps=group_eval_scores.responsibilities, + log_liks=group_eval_scores.log_liks, + cl_alpha=mm.p.cl_alpha, + ) + group_remain = group[demolish_mask.logical_not()] + n_remain = group.numel() - nchop + train_scores_adj = remove_units_from_scores(group_train_scores, chopping_block) + assert mm.erp is not None + assert mm.noise is not None + + # responsibilities for re-fitting rest of group + resp0 = train_scores_adj.responsibilities + assert resp0 is not None + cand0 = train_scores_adj.candidates + chopped_responsibilities = mm.b.log_proportions.new_zeros( + (resp0.shape[0], n_remain) + ) + for j, cand in enumerate(group_remain): + cii, cjj = (cand0 == cand).nonzero(as_tuple=True) + chopped_responsibilities[cii, j] = resp0[cii, cjj] + chopped_responsibilities.clamp_(min=1e-8) + + # re-fit model to train_scores_adj + chopped_model, s0valid, _, s0discard, s0mask, _ = ( + TruncatedMixtureModel.initialize_from_dense_data_with_fixed_responsibilities( + data=group_train_data, + responsibilities=chopped_responsibilities, + signal_rank=mm.signal_rank, + erp=mm.erp, + noise=mm.noise, + neighb_cov=mm.neighb_cov, + total_log_proportion=mm.non_noise_log_proportion(), + p=mm.p, + min_channel_count=0, + ) + ) + + # re-score eval data and combine with bystander units + chopped_score = chopped_model.score(data=group_eval_data, skip_responsibility=False) + assert chopped_score.responsibilities is not None + final_score = concatenate_scores([chopped_score, nongroup_eval_scores], dim=1) + assert final_score.responsibilities is not None + + return ecl( + resps=final_score.responsibilities, + log_liks=final_score.log_liks, + cl_alpha=mm.p.cl_alpha, + ) + + +def submasks(mask: Tensor, skip_empty=True, skip_full=False): mask_ = mask.cpu() (on,) = mask_.nonzero(as_tuple=True) on = on.tolist() for subset in subsets(on): if skip_empty and not len(subset): continue + if skip_full and len(subset) == len(mask): + continue m = torch.zeros_like(mask) m[list(subset)] = True yield m @@ -6241,22 +6353,22 @@ def _combine_similar_resps(resps: Tensor, keep_mask: Tensor, n_keep: int) -> Ten return kept_resp -def concatenate_scores(scoress: list[Scores]) -> Scores: +def concatenate_scores(scoress: list[Scores], dim=0) -> Scores: assert len(scoress) > 0 if len(scoress) == 1: return scoress[0] - log_liks = torch.concatenate([s.log_liks for s in scoress], dim=0) - candidates = torch.concatenate([s.candidates for s in scoress], dim=0) + log_liks = torch.concatenate([s.log_liks for s in scoress], dim=dim) + candidates = torch.concatenate([s.candidates for s in scoress], dim=dim) if scoress[0].responsibilities is None: responsibilities = None else: responsibilities = torch.concatenate( - [cast(Tensor, s.responsibilities) for s in scoress], dim=0 + [cast(Tensor, s.responsibilities) for s in scoress], dim=dim ) if scoress[0].duties is None: duties = None else: - duties = torch.concatenate([cast(Tensor, s.duties) for s in scoress], dim=0) + duties = scoress[0].duties return Scores( log_liks=log_liks, responsibilities=responsibilities, From 51dba98f00f38b96990afbc9d32fc20f643ec2dd Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sat, 20 Jun 2026 16:46:53 -0400 Subject: [PATCH 37/46] data/glom: add export to SortingAnalyzer, expose some SpikeInterface merging tools within agglomerate() --- src/dartsort/clustering/agglomerate.py | 129 +++++++++++++++++++++++-- src/dartsort/clustering/clustering.py | 1 + src/dartsort/config.py | 3 +- src/dartsort/util/data_util.py | 121 ++++++++++++++++++++++- src/dartsort/util/internal_config.py | 10 +- 5 files changed, 252 insertions(+), 12 deletions(-) diff --git a/src/dartsort/clustering/agglomerate.py b/src/dartsort/clustering/agglomerate.py index 856163c5..ec218c76 100644 --- a/src/dartsort/clustering/agglomerate.py +++ b/src/dartsort/clustering/agglomerate.py @@ -56,7 +56,7 @@ class Agglomeration: def agglomerate( *, sorting: DARTsortSorting, - recording: BaseRecording | None, + recording: BaseRecording, template_merge_cfg: TemplateMergeConfig | None, refinement_cfg: RefinementConfig | None, motion: MotionInfo, @@ -104,7 +104,7 @@ def agglomerate( ) # tdist tells us the possible merges - mask = linkage_mask( + distance_mask = linkage_mask( tdist.distances, linkage_method=template_merge_cfg.linkage, threshold=template_merge_cfg.merge_distance_threshold, @@ -117,9 +117,9 @@ def agglomerate( dt=refinement_cfg.glom_firing_corr_dt, method=refinement_cfg.glom_firing_corr_method, ) - _oldsum = mask[np.triu_indices_from(mask)].sum() + _oldsum = distance_mask[np.triu_indices_from(distance_mask)].sum() fcorr_mask = fcorr <= refinement_cfg.glom_max_firing_corr - mask = np.logical_and(mask, fcorr_mask) + mask = np.logical_and(distance_mask, fcorr_mask) np.fill_diagonal(mask, True) _newsum = mask[np.triu_indices_from(mask)].sum() logger.dartsortdebug( @@ -127,6 +127,7 @@ def agglomerate( ) else: fcorr = fcorr_mask = None + mask = distance_mask # restrict mask by overlap criteria qda_res = qda( @@ -158,19 +159,37 @@ def agglomerate( assert np.all(qda_mask <= mask) if fcorr_mask is not None: assert np.all(np.logical_and(qda_mask, fcorr_mask) <= mask) + + if refinement_cfg.spikeinterface_merge_preset is not None: + si_mask = spikeinterface_merge_mask( + recording=recording, + sorting=sorting, + preset=refinement_cfg.spikeinterface_merge_preset, + censor_ms=refinement_cfg.censor_ms, + template_data=tdist.template_data, + pair_mask=tdist.distances < refinement_cfg.spikeinterface_merge_max_distance, + ) + else: + si_mask = None + + # force merges for very close neighbors force_mask = linkage_mask( tdist.distances, linkage_method=template_merge_cfg.linkage, threshold=refinement_cfg.qda_force_merge_for_temp_dist_below, ) - qda_mask = np.logical_or(qda_mask, force_mask) - np.fill_diagonal(qda_mask, True) - qda_as_dist = np.logical_not(qda_mask).astype(np.float32) + + # extract final mask + final_mask = np.logical_or(qda_mask, force_mask) + if si_mask is not None: + final_mask = np.logical_or(final_mask, si_mask) + np.fill_diagonal(final_mask, True) + final_mask_as_distance = np.logical_not(final_mask).astype(np.float32) agg_sorting, new_ids = recluster( sorting=sorting, unit_ids=tdist.template_data.unit_ids, - dists=qda_as_dist, + dists=final_mask_as_distance, shifts=tdist.shifts, unit_snrs=tdist.template_data.snrs_by_channel().max(1), threshold=0.5, @@ -352,6 +371,100 @@ def _get_scores(sorting: DARTsortSorting) -> tuple[np.ndarray, Scores]: return labels, scores +def spikeinterface_merge_mask( + *, + recording: BaseRecording, + sorting: DARTsortSorting, + preset: str | None, + censor_ms: float = 0.0, + template_data: TemplateData, + pair_mask: np.ndarray, + min_count: int = 100, +): + from spikeinterface.curation.auto_merge import compute_merge_unit_groups + from spikeinterface.postprocessing import ComputeTemplateSimilarity + + # censor first + if censor_ms: + sorting = deduplicate_spikes(sorting, censor_ms) + + # analyzer + analyzer = sorting.to_sorting_analyzer( + recording=recording, template_data=template_data + ) + + # register the mask as the template similarity extension + tsim_ext = ComputeTemplateSimilarity(analyzer) + tsim_ext.data = {"similarity": pair_mask.astype(np.float32)} + tsim_ext.params = {"method": "dartsort"} + tsim_ext.run_info = {"run_completed": True} + analyzer.extensions["template_similarity"] = tsim_ext + + # handle custom presets + if preset == "dartsort_slay_xc": + steps = [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "cross_contamination", + "slay_score", + "quality_score", + ] + preset = None + analyzer.compute_one_extension("correlograms") + elif preset == "dartsort_slay_ccg": + steps = [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "correlogram", + "slay_score", + "quality_score", + ] + preset = None + analyzer.compute_one_extension("correlograms") + elif preset == "dartsort_slay_xc_ccg": + steps = [ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "correlogram", + "cross_contamination", + "slay_score", + "quality_score", + ] + preset = None + analyzer.compute_one_extension("correlograms") + else: + assert preset is not None + steps = None + + # make parameters aware of censorship and other params + my_step_params = { + "num_spikes": {"min_spikes": min_count}, + "remove_contaminated": {"censored_period_ms": censor_ms}, + "template_similarity": {"similarity_method": "dartsort"}, + "correlogram": {"censor_correlograms_ms": censor_ms}, + "cross_contamination": {"censored_period_ms": censor_ms}, + "quality_score": {"censored_period_ms": censor_ms}, + } + groups = compute_merge_unit_groups( + preset=preset, + steps=steps, + sorting_analyzer=analyzer, + steps_params=my_step_params, + force_copy=False, + ) + mask = np.zeros_like(pair_mask) + for g in groups: + g = np.array(g) + mask[g[:, None], g[None, :]] = True + return mask + + @databag class QDAResult: """Unit pair QDA metrics diff --git a/src/dartsort/clustering/clustering.py b/src/dartsort/clustering/clustering.py index 57ed20e4..b940a3fa 100644 --- a/src/dartsort/clustering/clustering.py +++ b/src/dartsort/clustering/clustering.py @@ -855,6 +855,7 @@ def _refine( recording: BaseRecording | None, motion: MotionInfo, ): + assert recording is not None return agglomerate.agglomerate( recording=recording, sorting=sorting, diff --git a/src/dartsort/config.py b/src/dartsort/config.py index c8b40a17..a4150e8a 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -231,7 +231,7 @@ class DeveloperConfig(DARTsortUserConfig): """Additional parameters for experiments. This API will never be stable.""" # high level behavior - initial_steps: Sequence[MixtureStep] = ("split", "demolish") + initial_steps: Sequence[MixtureStep] = ("split", "demolish", "demolish") later_steps: Sequence[MixtureStep] = ("split", "merge", "demolish") detection_type: Literal["subtract", "match", "threshold"] = "subtract" cluster_strategy: str = "dpc" @@ -258,6 +258,7 @@ class DeveloperConfig(DARTsortUserConfig): temporal_dedup_radius_samples: int = 11 positive_temporal_dedup_radius_samples: int = 41 spatial_dedup_radius_um: float | None = 50.0 + spikeinterface_merge_preset: str | None = None # matching matching_template_type: Literal["individual_compressed_upsampled", "drifty"] = ( diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 46d81a95..aa3f6af1 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -11,14 +11,21 @@ BaseRecording, BaseSorting, NumpySorting, + SortingAnalyzer, + create_sorting_analyzer, get_random_data_chunks, ) from ..detect import detect_and_deduplicate -from .internal_config import WaveformConfig, default_waveform_cfg +from .internal_config import ( + WaveformConfig, + default_clustering_features_cfg, + default_waveform_cfg, +) from .logging_util import get_logger, progbar if TYPE_CHECKING: + from ..templates.templates import TemplateData from .motion import MotionInfo from .job_util import ensure_computation_config from .py_util import ensure_path @@ -130,6 +137,7 @@ def to_numpy_sorting( labels_list=labels, sampling_frequency=st.sampling_frequency, ) + numpy_sorting._compute_and_cache_spike_vector() if return_kept_indices: # kept_indices[i] is the original index of numpy_sorting's ith spike kept_indices = st.mask_indices[order] @@ -211,6 +219,96 @@ def to_tsgroup( trains[unit_id] = Tsd(t=ut, d=uw) return TsGroup(trains, metadata=metadata) + def to_sorting_analyzer( + self, + recording: BaseRecording, + template_data: "TemplateData | None" = None, + drop_doubles: bool = True, + features_cfg=default_clustering_features_cfg, + ) -> SortingAnalyzer: + """Export dartsort's internal data to a SortingAnalyzer + + This will first call to_numpy_sorting() and then register some of the sorting's + features as extensions for the analyzer. + + If template_data is supplied, a templates extension will be registered. + + The implementation is based on SpikeInterface's `read_kilosort_as_analyzer()`, + thanks to Chris Halcrow for that. + + TODO: This doesn't handle gain_to_uV or random waveforms, sparsity, or probably + other important things. + + Parameters + ---------- + recording : BaseRecording + template_data : TemplateData | None, optional + features_cfg : ClusteringFeaturesConfig + Stores attribute dataset names + + Returns + ------- + analyzer : SortingAnalyzer + """ + from spikeinterface.core import ComputeTemplates + from spikeinterface.postprocessing import ( + ComputeSpikeLocations, + ComputeUnitLocations, + ) + + sorting, kept_indices = self.to_numpy_sorting(drop_doubles=drop_doubles, return_kept_indices=True) # type: ignore + + analyzer = create_sorting_analyzer( + sorting=sorting, recording=recording, sparse=False, return_in_uV=False + ) + + loc_name = features_cfg.localizations_dataset_name + if (locs := self.localizations_as_structured_array(loc_name)) is not None: + locs = locs[kept_indices] + loc_ext = ComputeSpikeLocations(analyzer) + loc_ext.data = {"spike_locations": locs} + loc_ext.params = {} + loc_ext.run_info = {"run_completed": True} + analyzer.extensions["spike_locations"] = loc_ext + + if template_data is not None: + td_ext = ComputeTemplates(analyzer) + assert np.array_equal( + template_data.unit_ids, np.arange(len(template_data.unit_ids)) + ) + s_before = template_data.trough_offset_samples + s_after = template_data.spike_length_samples - s_before + ms_per_sample = 1000 / self.sampling_frequency + + td_ext.data = {"average": template_data.templates} + td_ext.params = { + "operators": ["average"], + "ms_before": s_before * ms_per_sample, + "ms_after": s_after * ms_per_sample, + "peak_sign": "both", + } + td_ext.run_info = {"run_completed": True} + analyzer.extensions["templates"] = td_ext + + uloc_ext = ComputeUnitLocations(analyzer) + uloc_ext.params = {"method": "monopolar_triangulation"} + # turns out they don't want a structured array in this extension + ulocs = template_data.template_locations(mode="localization") + uloc_ext.data = {"unit_locations": ulocs} + uloc_ext.run_info = {"run_completed": True} + analyzer.extensions["unit_locations"] = uloc_ext + + return analyzer + + def localizations_as_structured_array( + self, property_name="point_source_localizations" + ) -> np.ndarray | None: + locs = getattr(self, property_name, None) + if locs is None: + return None + + return si_structured_localizations_array(locs) + def permute_labels(self, seed=0): assert self.labels is not None rg = np.random.default_rng(seed) @@ -1038,6 +1136,27 @@ def sorting_from_spikeinterface( ) +def si_structured_localizations_array(locs: np.ndarray) -> np.ndarray: + """Convert our localization format to SpikeInterface's""" + # NB: spikeinterface's y is our z. I like theirs better, I'm sorry, it's not my fault. + if locs.shape[1] >= 3: + column_names = ["x", "y", "z"] + elif locs.shape[1] == 2: + column_names = ["x", "y"] + else: + assert False + dtype = [(name, locs.dtype) for name in column_names] + + structured_array = np.zeros(len(locs), dtype=dtype) + structured_array["x"] = locs[:, 0] + if locs.shape[1] >= 3: + structured_array["y"] = locs[:, 2] + structured_array["z"] = locs[:, 1] + elif locs.shape[1] == 2: + structured_array["y"] = locs[:, 1] + + return structured_array + def filter_link_h5(in_h5_path: str | Path, out_h5_path: str | Path, keep_filter): in_h5_path = ensure_path(in_h5_path, strict=True) out_h5_path = ensure_path(out_h5_path) diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index e12c5db6..c1ca7e47 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -529,8 +529,9 @@ class RefinementConfig: robust_fixed_std_dataset: str = "collidedness" robust_fixed_power: float = 40.0 robust_df: float = 4.0 - demolition_min_resp_ratio: float = 1.1 + demolition_min_resp_ratio: float = 0.9 demolish_during_selection: bool = False + refit_in_demolition: bool = False em_after_demolish: bool = False whiten_split: bool = True scale_dist_args: tuple[float, float, float] = (0.01, 3.0 / 4.0, 4.0 / 3.0) @@ -550,6 +551,8 @@ class RefinementConfig: qda_min_coverage: float = 0.35 qda_min_iou: float = 0.5 qda_force_merge_for_temp_dist_below: float = 0.3 + spikeinterface_merge_preset: str | None = None + spikeinterface_merge_max_distance: float = 0.5 # forward_backward parameters chunk_size_s: float = 300.0 @@ -565,6 +568,7 @@ class RefinementConfig: # deduplication control dedup_ms: float = 0.0 + censor_ms: float = 0.3 @cfg_dataclass @@ -898,7 +902,7 @@ def is_multi_gpu(self): default_motion_estimation_cfg = MotionEstimationConfig() default_computation_cfg = ComputationConfig() default_refinement_cfg = RefinementConfig() -default_initial_refinement_cfg = RefinementConfig(mixture_steps=("split", "demolish")) +default_initial_refinement_cfg = RefinementConfig(mixture_steps=("split", "demolish", "demolish")) default_pre_refinement_cfg = RefinementConfig(refinement_strategy="pcmerge") default_agglomerate_cfg = RefinementConfig( refinement_strategy="agglomerate", @@ -1306,6 +1310,7 @@ def to_internal_config(cfg, n_channels: int) -> DARTsortInternalConfig: template_merge_cfg=agg_tmcfg, qda_threshold=0.0, dedup_ms=cfg.deduplication_ms, + spikeinterface_merge_preset=cfg.spikeinterface_merge_preset, ) elif cfg.agg_kind == "qda": agg_whiten_cfg = WhiteningConfig( @@ -1326,6 +1331,7 @@ def to_internal_config(cfg, n_channels: int) -> DARTsortInternalConfig: template_merge_cfg=agg_tmcfg, qda_force_merge_for_temp_dist_below=cfg.agg_no_qda_template_distance, dedup_ms=cfg.deduplication_ms, + spikeinterface_merge_preset=cfg.spikeinterface_merge_preset, ) else: assert False From 6aa3c69a43119c92ff433c5a4b60e18d3ac6a70a Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sat, 20 Jun 2026 17:20:10 -0400 Subject: [PATCH 38/46] vis: impl acg/ccg in ms --- src/dartsort/vis/analysis_plots.py | 35 ++++++++++++++-- src/dartsort/vis/unit.py | 66 ++++++++++++++++++++---------- 2 files changed, 76 insertions(+), 25 deletions(-) diff --git a/src/dartsort/vis/analysis_plots.py b/src/dartsort/vis/analysis_plots.py index 7aa9837f..008b32c3 100644 --- a/src/dartsort/vis/analysis_plots.py +++ b/src/dartsort/vis/analysis_plots.py @@ -437,11 +437,36 @@ def stackbar(ax, x, y, colors, labels, fill=True): def plot_correlogram( - axis, times_a, times_b=None, max_lag=50, color="k", fill=True, **stairs_kwargs + axis, + times_a, + times_b=None, + max_lag=50, + samples_per_ms: float = 30.0, + color="k", + fill=True, + bin=1, + to_ms=False, + **stairs_kwargs, ): lags, ccg = correlogram(times_a, times_b=times_b, max_lag=max_lag) - axis.set_xlabel("lag (samples)") - return bar(axis, lags, ccg, fill=fill, color=color, **stairs_kwargs) + assert lags.shape == ccg.shape == (2 * max_lag + 1,) + if not to_ms: + axis.set_xlabel("lag (samples)") + return bar(axis, lags, ccg, fill=fill, color=color, **stairs_kwargs) + + max_lag_ms = max_lag / samples_per_ms + ms_lags = np.arange(-max_lag_ms, max_lag_ms + bin / 2, bin) + ms_ccg = np.zeros_like(ms_lags) + ctr = ms_lags.shape[0] // 2 + assert ms_lags.shape == (2 * ctr + 1,) + for j in range(max_lag): + binix = (j / samples_per_ms) // bin + ms_ccg[ctr + binix] += ccg[max_lag + j] + if j: + ms_ccg[ctr - binix] += ccg[max_lag - j] + + axis.set_xlabel("lag (ms)") + return bar(axis, ms_lags, ms_ccg, fill=fill, color=color, **stairs_kwargs) def visualize_denoiser( @@ -629,7 +654,9 @@ def plot_denoiser_scores( score_unwhitened=np.sqrt(np.maximum(0.0, np.concatenate(scores_unwhitened))), ) if scores_whitened is not None: - data["score_whitened"] = np.sqrt(np.maximum(0.0, np.concatenate(scores_whitened))) + data["score_whitened"] = np.sqrt( + np.maximum(0.0, np.concatenate(scores_whitened)) + ) df = pd.DataFrame(data) bins = np.arange(0.0, vmax, step=dv) diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index e7c349d5..cde272fa 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -86,15 +86,31 @@ class ACG(UnitPlot): kind = "histogram" height = 0.75 - def __init__(self, max_lag=50): + def __init__(self, max_lag=50, bin=1, unit="samples"): super().__init__() self.max_lag = max_lag + self.bin = bin + self.unit = unit def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): axis = panel.subplots() which = sorting_analysis.in_unit(unit_id) - times_samples = sorting_analysis.sorting.times_samples[which] - plot_correlogram(axis, times_samples, max_lag=self.max_lag) + t = sorting_analysis.sorting.times_samples[which] + samples_per_ms = sorting_analysis.sorting.sampling_frequency / 1000 + if self.unit == "samples": + max_lag_samples = self.max_lag + elif self.unit == "ms": + max_lag_samples = int(np.ceil(self.max_lag * samples_per_ms)) + else: + assert False + plot_correlogram( + axis, + t, + max_lag=max_lag_samples, + bin=self.bin, + samples_per_ms=samples_per_ms, + to_ms=self.unit == "ms", + ) axis.set_ylabel("acg") @@ -664,17 +680,20 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id): class NeighborCCGPlot(UnitPlot): kind = "medium" - def __init__(self, n_neighbors=3, max_lag=50, with_merged_acg=False): + def __init__( + self, n_neighbors=3, max_lag=50, bin=1, unit="samples" + ): super().__init__() self.n_neighbors = n_neighbors self.max_lag = max_lag - self.with_merged_acg = with_merged_acg if self.with_merged_acg: self.height = 1.0 self.width = 3.0 else: self.height = 1.75 self.width = 2 + self.unit = unit + self.bin = bin def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): ( @@ -716,24 +735,27 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): squeeze=False, ) axes = axes.T - for j in range(len(neighb_sts)): - clags, ccg = correlogram(my_st, neighb_sts[j], max_lag=self.max_lag) - bar(axes[0, j], clags, ccg, fill=True, fc=colors[j]) # , ec="k", lw=1) - axes[0, j].set_title(f"unit {neighbor_ids[j]}") - if not self.with_merged_acg: - continue - - merged_st = np.concatenate((my_st, neighb_sts[j])) - merged_st.sort() - alags, acg = correlogram(merged_st, max_lag=self.max_lag) - bar(axes[1, j], alags, acg, fill=True, fc=colors[j]) # , ec="k", lw=1) - if self.with_merged_acg: - axes[0, 0].set_ylabel("ccg") - axes[1, 0].set_ylabel("merged acg") - axes[-1, len(neighb_sts) // 2].set_xlabel("lag (samples)") + samples_per_ms = sorting_analysis.sorting.sampling_frequency / 1000 + if self.unit == "samples": + max_lag_samples = self.max_lag + elif self.unit == "ms": + max_lag_samples = int(np.ceil(self.max_lag * samples_per_ms)) else: - axes[-1, -1].set_xlabel("ccg") + assert False + + for j in range(len(neighb_sts)): + plot_correlogram( + axes[0, j], + my_st, + neighb_sts[j], + max_lag=max_lag_samples, + bin=self.bin, + samples_per_ms=samples_per_ms, + to_ms=self.unit == "ms", + fc=colors[j], + ) + axes[0, j].set_title(f"ccg vs. unit {neighbor_ids[j]}") class NeighborQDAPlot(UnitPlot): @@ -877,6 +899,7 @@ def default_plots(sorting_analysis=None): p = [ UnitTextInfo(), ACG(), + ACG(max_lag=50.0, bin=0.5, unit="ms"), ISIHistogram(), ISIHistogram(bin_ms=0.25, max_ms=50.0), XZScatter(), @@ -885,6 +908,7 @@ def default_plots(sorting_analysis=None): NearbyCoarseTemplatesPlot(), CoarseTemplateDistancePlot(), NeighborCCGPlot(), + NeighborCCGPlot(max_lag=50.0, bin=0.5, unit="ms"), ] if sorting_analysis is not None and sorting_analysis.has_localizations(): p.extend([TimeZScatter(), TimeRegZScatter()]) From 7f16e8850c5087384f0d8552fec6774211c3ff93 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sat, 20 Jun 2026 18:43:08 -0400 Subject: [PATCH 39/46] glom: gate SpikeInterface merge with coentropy --- src/dartsort/clustering/agglomerate.py | 149 ++++++++++++++++++++++++- src/dartsort/util/internal_config.py | 3 + 2 files changed, 151 insertions(+), 1 deletion(-) diff --git a/src/dartsort/clustering/agglomerate.py b/src/dartsort/clustering/agglomerate.py index ec218c76..a76e874f 100644 --- a/src/dartsort/clustering/agglomerate.py +++ b/src/dartsort/clustering/agglomerate.py @@ -161,13 +161,23 @@ def agglomerate( assert np.all(np.logical_and(qda_mask, fcorr_mask) <= mask) if refinement_cfg.spikeinterface_merge_preset is not None: + pair_mask = tdist.distances < refinement_cfg.spikeinterface_merge_max_distance + if refinement_cfg.spikeinterface_merge_min_coentropy is not None: + cmask, _ = coentropy_merge_mask( + sorting=sorting, + min_coentropy=refinement_cfg.spikeinterface_merge_min_coentropy, + coverage_threshold=refinement_cfg.spikeinterface_merge_coent_coverage, + iou_threshold=refinement_cfg.spikeinterface_merge_coent_iou, + ) + pair_mask = np.logical_and(cmask, pair_mask) + si_mask = spikeinterface_merge_mask( recording=recording, sorting=sorting, preset=refinement_cfg.spikeinterface_merge_preset, censor_ms=refinement_cfg.censor_ms, template_data=tdist.template_data, - pair_mask=tdist.distances < refinement_cfg.spikeinterface_merge_max_distance, + pair_mask=pair_mask, ) else: si_mask = None @@ -986,3 +996,140 @@ def _dedup_unit_loop( break i0 = i1 + + +@databag +class CoentropyResult: + coentropy: np.ndarray + """KxK; reduction of entropy per cooccurrence due to merging pair""" + + cooccurrence: np.ndarray + """KxK; number of times these units score the same spike""" + + rival_count: np.ndarray + """KxK; number of times one unit scores a spike where the other is top""" + + occurrence: np.ndarray + """K; number of times the unit appears in the candidates at all""" + + +def coentropy_merge_mask( + sorting: DARTsortSorting, + min_coentropy: float, + coverage_threshold: float, + iou_threshold: float, + gmm_prefix=("merged", "gmm"), +) -> tuple[np.ndarray, CoentropyResult]: + """ + Parameters + ---------- + sorting : DARTsortSorting + min_coentropy : float + Must be met by pair for mask=True + min_coverage : float + Pairs such that at least one unit in each pair has + rival_count/count > mincov are allowed + iou_threshold: float + Pairs with rival iou > iouthresh are allowed + """ + c = coentropy(sorting, gmm_prefix=gmm_prefix) + assert c is not None + + # rival count diagonal is just unit top count (not exactly label count, + # since it doesn't account for noise assignments) + counts = np.diagonal(c.rival_count) + + cov = c.rival_count / counts + cov = np.minimum(cov, cov.T) + + union = counts[:, None] + counts[None, :] + union -= c.rival_count + iou = c.rival_count / union + + mask = np.logical_and.reduce( + [cov >= coverage_threshold, iou >= iou_threshold, c.coentropy >= min_coentropy] + ) + np.fill_diagonal(mask, True) + return mask, c + + +def coentropy( + sorting: DARTsortSorting, + gmm_prefix=("merged", "gmm"), +) -> CoentropyResult | None: + """Calculate entropy reduction due to merging pairs.""" + for k in gmm_prefix: + cands = getattr(sorting, f"{k}_candidates", None) + resps = getattr(sorting, f"{k}_responsibilities", None) + if cands is not None: + assert resps is not None + break + else: + return None + + k = sorting.n_units + resps = resps[:, : cands.shape[1]].astype(np.float64) + coentropy = np.zeros((k, k)) + cooccurrence = np.zeros((k, k), dtype=np.int64) + rival_count = np.zeros((k, k), dtype=np.int64) + occurrence = np.zeros((k,), dtype=np.int64) + _calc_coentropy(coentropy, cooccurrence, rival_count, occurrence, cands, resps) + rival_count += rival_count.T + np.fill_diagonal(rival_count, np.diagonal(rival_count) // 2) + coentropy += coentropy.T + cooccurrence += cooccurrence.T + + return CoentropyResult( + coentropy=coentropy, + cooccurrence=cooccurrence, + rival_count=rival_count, + occurrence=occurrence, + ) + + +@numba.njit(parallel=True) +def _calc_coentropy(coentropy, cooccurrence, rival_count, occurrence, cands, resps): + for i in numba.prange(cands.shape[0]): # ty: ignore + u = cands[i] + q = resps[i] + log_q = np.log(q) + np.nan_to_num(log_q, copy=False, neginf=0.0) + nh = -q * log_q + + ui0 = u[0] + + for j in range(0, cands.shape[1]): + uj = u[j] + if uj < 0: + break + + qj = q[j] + nhj = nh[j] + + occurrence[uj] += 1 + rival_count[ui0, uj] += 1 + + for k in range(j + 1, cands.shape[1]): + uk = u[k] + if uk < 0: + break + + ii = min(uk, uj) + jj = max(uk, uj) + + cij = cooccurrence[ii, jj] + 1 + cooccurrence[ii, jj] = cij + + # change in entropy due to merging uj, uk: + # subtract their current contribution, add the new contribution + qk = q[k] + nhk = nh[k] + qjk = qk + qj + dh = nhj + nhk + if qjk > 0: + log_qjk = np.log(qjk) + dh += qjk * log_qjk + + # Welford mean of -dh + cur_coent = coentropy[ii, jj] + coentropy[ii, jj] = cur_coent + (-dh - cur_coent) / cij diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index c1ca7e47..e72358a3 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -553,6 +553,9 @@ class RefinementConfig: qda_force_merge_for_temp_dist_below: float = 0.3 spikeinterface_merge_preset: str | None = None spikeinterface_merge_max_distance: float = 0.5 + spikeinterface_merge_min_coentropy: float | None = 0.1 + spikeinterface_merge_coent_coverage: float = 0.9 + spikeinterface_merge_coent_iou: float = 0.6 # forward_backward parameters chunk_size_s: float = 300.0 From c6f685aeb0ca0259f15ef9b9826a57b4a66b8724 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sun, 21 Jun 2026 15:05:56 -0400 Subject: [PATCH 40/46] glom: coentropy mask should be or --- src/dartsort/clustering/agglomerate.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dartsort/clustering/agglomerate.py b/src/dartsort/clustering/agglomerate.py index a76e874f..ff456064 100644 --- a/src/dartsort/clustering/agglomerate.py +++ b/src/dartsort/clustering/agglomerate.py @@ -1046,9 +1046,8 @@ def coentropy_merge_mask( union -= c.rival_count iou = c.rival_count / union - mask = np.logical_and.reduce( - [cov >= coverage_threshold, iou >= iou_threshold, c.coentropy >= min_coentropy] - ) + mask = np.logical_or(cov >= coverage_threshold, iou >= iou_threshold) + mask = np.logical_and(c.coentropy >= min_coentropy, mask) np.fill_diagonal(mask, True) return mask, c From 608ed5ee7daada252217e1feb76c23de97295f65 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Sun, 21 Jun 2026 16:35:18 -0400 Subject: [PATCH 41/46] vis: debug neighbor ccg --- src/dartsort/vis/unit.py | 46 +++++++++++----------------------------- 1 file changed, 12 insertions(+), 34 deletions(-) diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index cde272fa..8ff315c9 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -21,14 +21,7 @@ from ..util.job_util import get_global_computation_config from ..util.multiprocessing_util import CloudpicklePoolExecutor, cloudpickle, get_pool from . import layout -from .analysis_plots import ( - bar, - bimod_stats, - centered_bins, - correlogram, - isi_hist, - plot_correlogram, -) +from .analysis_plots import bimod_stats, centered_bins, isi_hist, plot_correlogram from .colors import glasbey1024 from .waveforms import geomplot, geomplot_templates @@ -680,18 +673,12 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id): class NeighborCCGPlot(UnitPlot): kind = "medium" - def __init__( - self, n_neighbors=3, max_lag=50, bin=1, unit="samples" - ): + def __init__(self, n_neighbors=3, max_lag=50, bin=1, unit="samples"): super().__init__() self.n_neighbors = n_neighbors self.max_lag = max_lag - if self.with_merged_acg: - self.height = 1.0 - self.width = 3.0 - else: - self.height = 1.75 - self.width = 2 + self.height = 1.75 + self.width = 2 self.unit = unit self.bin = bin @@ -718,23 +705,14 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): for nid in neighbor_ids ] - if self.with_merged_acg: - axes = panel.subplots( - nrows=1 + self.with_merged_acg, - ncols=len(neighb_sts), - sharey="row", - sharex=True, - squeeze=False, - ) - else: - axes = panel.subplots( - ncols=1 + self.with_merged_acg, - nrows=len(neighb_sts), - sharey="row", - sharex=True, - squeeze=False, - ) - axes = axes.T + axes = panel.subplots( + ncols=1, + nrows=len(neighb_sts), + sharey="row", + sharex=True, + squeeze=False, + ) + axes = axes.T samples_per_ms = sorting_analysis.sorting.sampling_frequency / 1000 if self.unit == "samples": From b1bad583e0b234045039ebeeb0f95bac02aede1b Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 22 Jun 2026 11:17:57 -0400 Subject: [PATCH 42/46] glom: restrict coentropy to rival pairs --- src/dartsort/clustering/agglomerate.py | 101 ++++++++++++++----------- 1 file changed, 57 insertions(+), 44 deletions(-) diff --git a/src/dartsort/clustering/agglomerate.py b/src/dartsort/clustering/agglomerate.py index ff456064..f6ad4571 100644 --- a/src/dartsort/clustering/agglomerate.py +++ b/src/dartsort/clustering/agglomerate.py @@ -169,7 +169,7 @@ def agglomerate( coverage_threshold=refinement_cfg.spikeinterface_merge_coent_coverage, iou_threshold=refinement_cfg.spikeinterface_merge_coent_iou, ) - pair_mask = np.logical_and(cmask, pair_mask) + pair_mask = np.logical_or(cmask, pair_mask) si_mask = spikeinterface_merge_mask( recording=recording, @@ -417,8 +417,8 @@ def spikeinterface_merge_mask( "remove_contaminated", "unit_locations", "template_similarity", - "cross_contamination", "slay_score", + "cross_contamination", "quality_score", ] preset = None @@ -1012,6 +1012,12 @@ class CoentropyResult: occurrence: np.ndarray """K; number of times the unit appears in the candidates at all""" + cov: np.ndarray + """KxK; rival count / max pair count (rival diag)""" + + iou: np.ndarray + """KxK; rival count over pair sum""" + def coentropy_merge_mask( sorting: DARTsortSorting, @@ -1035,18 +1041,7 @@ def coentropy_merge_mask( c = coentropy(sorting, gmm_prefix=gmm_prefix) assert c is not None - # rival count diagonal is just unit top count (not exactly label count, - # since it doesn't account for noise assignments) - counts = np.diagonal(c.rival_count) - - cov = c.rival_count / counts - cov = np.minimum(cov, cov.T) - - union = counts[:, None] + counts[None, :] - union -= c.rival_count - iou = c.rival_count / union - - mask = np.logical_or(cov >= coverage_threshold, iou >= iou_threshold) + mask = np.logical_or(c.cov >= coverage_threshold, c.iou >= iou_threshold) mask = np.logical_and(c.coentropy >= min_coentropy, mask) np.fill_diagonal(mask, True) return mask, c @@ -1074,61 +1069,79 @@ def coentropy( occurrence = np.zeros((k,), dtype=np.int64) _calc_coentropy(coentropy, cooccurrence, rival_count, occurrence, cands, resps) rival_count += rival_count.T - np.fill_diagonal(rival_count, np.diagonal(rival_count) // 2) + cdiag = np.diagonal(rival_count) + assert (cdiag % 2 == 0).all() + np.fill_diagonal(rival_count, cdiag // 2) coentropy += coentropy.T cooccurrence += cooccurrence.T + # rival count diagonal is just unit top count (not exactly label count, + # since it doesn't account for noise assignments) + counts = np.diagonal(rival_count) + counts = np.maximum(counts, 1) + + cov = rival_count / counts + cov = np.minimum(cov, cov.T) + + # this is a disjoint union, since it's the top-label count + union = counts[:, None] + counts[None, :] + iou = rival_count / union + return CoentropyResult( coentropy=coentropy, cooccurrence=cooccurrence, rival_count=rival_count, occurrence=occurrence, + cov=cov, + iou=iou, ) @numba.njit(parallel=True) -def _calc_coentropy(coentropy, cooccurrence, rival_count, occurrence, cands, resps): +def _calc_coentropy( + coentropy: np.ndarray, + cooccurrence: np.ndarray, + rival_count: np.ndarray, + occurrence: np.ndarray, + cands: np.ndarray, + resps: np.ndarray, +): for i in numba.prange(cands.shape[0]): # ty: ignore u = cands[i] q = resps[i] log_q = np.log(q) np.nan_to_num(log_q, copy=False, neginf=0.0) - nh = -q * log_q + dh = q * log_q ui0 = u[0] + qi0 = q[0] + dhi0 = dh[0] - for j in range(0, cands.shape[1]): + occurrence[ui0] += 1 + rival_count[ui0, ui0] += 1 + + for j in range(1, cands.shape[1]): uj = u[j] if uj < 0: break - qj = q[j] - nhj = nh[j] + ii = min(ui0, uj) + jj = max(ui0, uj) occurrence[uj] += 1 rival_count[ui0, uj] += 1 - for k in range(j + 1, cands.shape[1]): - uk = u[k] - if uk < 0: - break - - ii = min(uk, uj) - jj = max(uk, uj) - - cij = cooccurrence[ii, jj] + 1 - cooccurrence[ii, jj] = cij - - # change in entropy due to merging uj, uk: - # subtract their current contribution, add the new contribution - qk = q[k] - nhk = nh[k] - qjk = qk + qj - dh = nhj + nhk - if qjk > 0: - log_qjk = np.log(qjk) - dh += qjk * log_qjk - - # Welford mean of -dh - cur_coent = coentropy[ii, jj] - coentropy[ii, jj] = cur_coent + (-dh - cur_coent) / cij + cij = cooccurrence[ii, jj] + 1 + cooccurrence[ii, jj] = cij + + # change in entropy due to merging uj, uk: + # subtract their current contribution, add the new contribution + # we want reduction of entropy, so this is the negative of that! + qij = q[j] + qi0 + dhij = dh[j] + dhi0 + if qij > 0: + dhij -= qij * np.log(qij) + + # Welford mean of -dh + cur_coent = coentropy[ii, jj] + coentropy[ii, jj] = cur_coent + (-dhij - cur_coent) / cij From 198e3e57e284e4858a0e5ee65818cb1869ab875c Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 22 Jun 2026 11:18:18 -0400 Subject: [PATCH 43/46] cfg: reworking subsampling --- src/dartsort/config.py | 5 +++-- src/dartsort/main.py | 18 +++++++++++++----- src/dartsort/util/internal_config.py | 23 +++++++++++------------ src/dartsort/util/main_util.py | 2 +- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/dartsort/config.py b/src/dartsort/config.py index a4150e8a..1f83d1dc 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -255,10 +255,11 @@ class DeveloperConfig(DARTsortUserConfig): use_nn_in_subtraction: bool = True whiten_in_subtraction: bool = True threshold_before_whitening: float = 10.0 - temporal_dedup_radius_samples: int = 11 + temporal_dedup_radius_samples: int = 7 positive_temporal_dedup_radius_samples: int = 41 - spatial_dedup_radius_um: float | None = 50.0 + spatial_dedup_radius_um: float | None = 35.0 spikeinterface_merge_preset: str | None = None + spikeinterface_merge_max_distance: float = 0.5 # matching matching_template_type: Literal["individual_compressed_upsampled", "drifty"] = ( diff --git a/src/dartsort/main.py b/src/dartsort/main.py index fffdcb0e..3385de07 100644 --- a/src/dartsort/main.py +++ b/src/dartsort/main.py @@ -1,4 +1,5 @@ """High-level spike sorting toolbox functions.""" + import traceback from pathlib import Path from tempfile import TemporaryDirectory @@ -251,7 +252,7 @@ def _dartsort_impl( ) ret["motion"] = motion - is_subsampling = cfg.subsampling_spikes is not None + is_subsampling = cfg.subsampling_spikes_per_channel is not None is_subsampling = is_subsampling and cfg.subsampling_presence != 1.0 if next_step == 0: @@ -340,7 +341,10 @@ def _dartsort_impl( else: previous_detection_cfg = cfg.matching_cfg - _nspk = None if is_final else cfg.subsampling_spikes + if is_final or cfg.subsampling_spikes_per_channel is None: + _nspk = None + else: + _nspk = cfg.subsampling_spikes_per_channel * motion.geom.shape[0] _pres = 1.0 if is_final else cfg.subsampling_presence step_clus_cfg, step_clfeat_cfg, step_ref_cfgs, step_feat_cfg, samp_cfg = ( _matching_step_cfgs(is_final, is_subsampling, cfg) @@ -443,6 +447,10 @@ def initial_detection( ------- DARTsortSorting """ + if cfg.subsampling_spikes_per_channel is None: + _nspk = None + else: + _nspk = cfg.subsampling_spikes_per_channel * recording.get_num_channels() if cfg.detection_type == "subtract": assert isinstance(cfg.initial_detection_cfg, SubtractionConfig) return subtract( @@ -453,7 +461,7 @@ def initial_detection( subtraction_cfg=cfg.initial_detection_cfg, sampling_cfg=cfg.peeler_sampling_cfg, computation_cfg=cfg.computation_cfg, - stop_after_n_spikes=cfg.subsampling_spikes, + stop_after_n_spikes=_nspk, ensure_coverage=cfg.subsampling_presence, overwrite=overwrite, show_progress=show_progress, @@ -467,7 +475,7 @@ def initial_detection( thresholding_cfg=cfg.initial_detection_cfg, sampling_cfg=cfg.peeler_sampling_cfg, featurization_cfg=cfg.featurization_cfg, - stop_after_n_spikes=cfg.subsampling_spikes, + stop_after_n_spikes=_nspk, ensure_coverage=cfg.subsampling_presence, overwrite=overwrite, show_progress=show_progress, @@ -484,7 +492,7 @@ def initial_detection( matching_cfg=cfg.initial_detection_cfg, sampling_cfg=cfg.peeler_sampling_cfg, motion=motion, - stop_after_n_spikes=cfg.subsampling_spikes, + stop_after_n_spikes=_nspk, ensure_coverage=cfg.subsampling_presence, overwrite=overwrite, show_progress=show_progress, diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index e72358a3..d72043e2 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -218,7 +218,7 @@ class FitSamplingConfig: max_waveforms_fit: int = 50_000 n_waveforms_fit: int = 40_000 more_waveforms_fit: int = 2000 * 1024 - n_residual_snips: int = 4 * 4096 + n_residual_snips: int = 2 * 4096 residual_snip_ms: float | None = None residual_sampling_target_density: float = 0.25 seed: int = 0 @@ -553,9 +553,9 @@ class RefinementConfig: qda_force_merge_for_temp_dist_below: float = 0.3 spikeinterface_merge_preset: str | None = None spikeinterface_merge_max_distance: float = 0.5 - spikeinterface_merge_min_coentropy: float | None = 0.1 - spikeinterface_merge_coent_coverage: float = 0.9 - spikeinterface_merge_coent_iou: float = 0.6 + spikeinterface_merge_min_coentropy: float | None = 0.01 + spikeinterface_merge_coent_coverage: float = 0.8 + spikeinterface_merge_coent_iou: float = 0.5 # forward_backward parameters chunk_size_s: float = 300.0 @@ -905,7 +905,9 @@ def is_multi_gpu(self): default_motion_estimation_cfg = MotionEstimationConfig() default_computation_cfg = ComputationConfig() default_refinement_cfg = RefinementConfig() -default_initial_refinement_cfg = RefinementConfig(mixture_steps=("split", "demolish", "demolish")) +default_initial_refinement_cfg = RefinementConfig( + mixture_steps=("split", "demolish", "demolish") +) default_pre_refinement_cfg = RefinementConfig(refinement_strategy="pcmerge") default_agglomerate_cfg = RefinementConfig( refinement_strategy="agglomerate", @@ -949,7 +951,7 @@ class DARTsortInternalConfig: recluster_after_first_matching: bool = False # subsampling: intermediate peels will continue until both criteria satisfied # need at least this many spikes - subsampling_spikes: int | None = 2_048_000 + subsampling_spikes_per_channel: int | None = 5000 # need to cover at least this fraction of chunks subsampling_presence: float = 0.1 @@ -1314,6 +1316,7 @@ def to_internal_config(cfg, n_channels: int) -> DARTsortInternalConfig: qda_threshold=0.0, dedup_ms=cfg.deduplication_ms, spikeinterface_merge_preset=cfg.spikeinterface_merge_preset, + spikeinterface_merge_max_distance=cfg.spikeinterface_merge_max_distance, ) elif cfg.agg_kind == "qda": agg_whiten_cfg = WhiteningConfig( @@ -1335,15 +1338,11 @@ def to_internal_config(cfg, n_channels: int) -> DARTsortInternalConfig: qda_force_merge_for_temp_dist_below=cfg.agg_no_qda_template_distance, dedup_ms=cfg.deduplication_ms, spikeinterface_merge_preset=cfg.spikeinterface_merge_preset, + spikeinterface_merge_max_distance=cfg.spikeinterface_merge_max_distance, ) else: assert False - if cfg.subsampling_spikes_per_channel is not None: - subsampling_spikes = cfg.subsampling_spikes_per_channel * n_channels - else: - subsampling_spikes = None - return DARTsortInternalConfig( waveform_cfg=waveform_cfg, featurization_cfg=featurization_cfg, @@ -1377,7 +1376,7 @@ def to_internal_config(cfg, n_channels: int) -> DARTsortInternalConfig: save_everything_on_error=cfg.save_everything_on_error, link_from=cfg.link_from, link_step=cfg.link_step, - subsampling_spikes=subsampling_spikes, + subsampling_spikes_per_channel=cfg.subsampling_spikes_per_channel, subsampling_presence=cfg.subsampling_presence, always_save_final_tpca_feature=cfg.always_save_final_tpca_feature, ) diff --git a/src/dartsort/util/main_util.py b/src/dartsort/util/main_util.py index 2146960e..679c63fb 100644 --- a/src/dartsort/util/main_util.py +++ b/src/dartsort/util/main_util.py @@ -143,7 +143,7 @@ def motion_needs_peaks( ): if cfg.subsampling_presence == 1.0: return False - if cfg.subsampling_spikes is None: + if cfg.subsampling_spikes_per_channel is None: return False # assert sorting's chunk starts, sorted, match full recording's From 2e602a3a9a3c9fc2ef77acabd7961892cb1374ad Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 22 Jun 2026 11:18:30 -0400 Subject: [PATCH 44/46] vis: better single unit vis --- src/dartsort/vis/analysis_plots.py | 3 +- src/dartsort/vis/unit.py | 145 ++++++++++++++++++++++++----- 2 files changed, 124 insertions(+), 24 deletions(-) diff --git a/src/dartsort/vis/analysis_plots.py b/src/dartsort/vis/analysis_plots.py index 008b32c3..bdbeb2ff 100644 --- a/src/dartsort/vis/analysis_plots.py +++ b/src/dartsort/vis/analysis_plots.py @@ -360,7 +360,6 @@ def isi_hist( dt_ms, bin_edges, color=color, label=label, histtype=histtype, alpha=alpha ) axis.set_xlabel("isi (ms)") - axis.set_ylabel(f"count (out of {dt_ms.size} total isis)") def centered_bins(x, dx=1.0): @@ -460,7 +459,7 @@ def plot_correlogram( ctr = ms_lags.shape[0] // 2 assert ms_lags.shape == (2 * ctr + 1,) for j in range(max_lag): - binix = (j / samples_per_ms) // bin + binix = int((j / samples_per_ms) // bin) ms_ccg[ctr + binix] += ccg[max_lag + j] if j: ms_ccg[ctr - binix] += ccg[max_lag - j] diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index 8ff315c9..494ea5d8 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -52,20 +52,20 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): h5_path = sorting_analysis.sorting.parent_h5_path if h5_path: - msg += f"feature source: {h5_path.name}\n" + msg += f"from: {h5_path.name}\n" nspikes = cast(np.ndarray, sorting_analysis.sorting.labels == unit_id).sum() - msg += f"n spikes: {nspikes}\n" + msg += f"count: {nspikes}\n" assert sorting_analysis.template_data is not None temps = sorting_analysis.template_data.unit_templates(unit_id) if not temps.size: - msg += "no template (too few spikes)" + msg += "no template\n(too few spikes)" elif temps.shape[0] == 1: ptp = np.ptp(temps, 1).max(1)[0] - msg += f"maxptp: {ptp:0.2f} su\n" + msg += f"ptp: {ptp:0.2f} su\n" snr = ptp * np.sqrt(nspikes) - msg += f"template snr: {snr:.1f}" + msg += f"snr: {snr:.1f}" else: assert False @@ -84,6 +84,8 @@ def __init__(self, max_lag=50, bin=1, unit="samples"): self.max_lag = max_lag self.bin = bin self.unit = unit + if unit == "ms": + self.width = 2 def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): axis = panel.subplots() @@ -104,6 +106,7 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): samples_per_ms=samples_per_ms, to_ms=self.unit == "ms", ) + axis.grid(which="both") axis.set_ylabel("acg") @@ -138,6 +141,7 @@ def draw( color=color, label=label, ) + axis.grid(which="both") class XZScatter(UnitPlot): @@ -235,41 +239,67 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id, axis=None): class TimeFeatScatter(UnitPlot): - kind = "medium" + kind = "ctimefeat" width = 2 height = 0.75 def __init__( self, feat_name, + color_by_template_if_possible=False, color_by_amplitude=True, amplitude_color_cutoff=15, alpha=1.0, label=None, + cbar=True, ): super().__init__() self.feat_name = feat_name self.amplitude_color_cutoff = amplitude_color_cutoff self.color_by_amplitude = color_by_amplitude + self.color_by_template_if_possible = color_by_template_if_possible self.alpha = alpha self.label = label or feat_name + self.cbar = cbar def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): axis = panel.subplots() assert sorting_analysis.times_seconds is not None assert sorting_analysis.amplitudes is not None + in_unit = sorting_analysis.in_unit(unit_id, at_most=50_000) t = sorting_analysis.times_seconds[in_unit] feat = sorting_analysis.named_feature(self.feat_name, which=in_unit) c = None - if self.color_by_amplitude: + cbar = self.cbar + did_by_template = False + if c is None and self.color_by_template_if_possible: + temp_ix = getattr(sorting_analysis.sorting, "template_inds", None) + if temp_ix is not None: + temp_ix = temp_ix[in_unit] + c = glasbey1024[temp_ix % len(glasbey1024)] + did_by_template = True + cbar = False + if c is None and self.color_by_amplitude: amps = sorting_analysis.amplitudes[in_unit] c = np.minimum(amps, self.amplitude_color_cutoff) s = axis.scatter(t, feat, c=c, lw=0, s=3, alpha=self.alpha, rasterized=True) axis.set_xlabel("time (s)") axis.set_ylabel(self.label) - if self.color_by_amplitude: - plt.colorbar(s, ax=axis, shrink=0.5, label="amp (su)") + axis.grid() + axis.set_axisbelow(True) + if cbar and self.color_by_amplitude: + plt.colorbar(s, ax=axis, shrink=0.5, pad=0.01, label="amp (su)") + if did_by_template: + axis.text( + 0.97, + 0.97, + "color: template", + ha="right", + va="top", + transform=axis.transAxes, + fontsize="small", + ) class TimeZScatter(TimeFeatScatter): @@ -283,8 +313,64 @@ def __init__(self, **kwargs): class TimeAmpScatter(TimeFeatScatter): - def __init__(self, **kwargs): - super().__init__(feat_name="amplitudes", label="amp (su)", **kwargs) + def __init__(self, color_by_template_if_possible=True, **kwargs): + super().__init__( + feat_name="amplitudes", + label="amp (su)", + color_by_template_if_possible=color_by_template_if_possible, + **kwargs, + ) + + +class AmplitudeHistogramByDiscreteVariable(UnitPlot): + kind = "camphist" + width = 2 + height = 0.75 + + def __init__(self, var="channels"): + self.var = var + + def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): + axis = panel.subplots() + assert sorting_analysis.amplitudes is not None + z = getattr(sorting_analysis.sorting, self.var, None) + if z is None: + axis.axis("off") + return + in_unit = sorting_analysis.in_unit(unit_id, at_most=50_000) + a = sorting_analysis.amplitudes[in_unit] + z = z[in_unit] + bins = np.arange(np.floor(a.min()), np.ceil(a.max()) + 0.1) + uqz = np.unique(z) + for uz in uqz: + axis.hist( + a[z == uz], + color=glasbey1024[uz % len(glasbey1024)], + histtype="step", + lw=1, + label=uz, + bins=bins, + ) + if uqz.size < 4: + axis.legend(title=self.var, fancybox=False, loc="upper right") + else: + msg = f"{self.var}, {uqz.size} uniques" + if uqz.size < 8: + msg += ":\n" + msg += ",".join(list(map(str, uqz.tolist()))) + axis.text( + 0.97, + 0.97, + msg, + fontsize="small", + ha="right", + va="top", + transform=axis.transAxes, + ) + axis.grid() + axis.set_axisbelow(True) + axis.semilogy() + axis.set_xlabel("amplitude (s.u.)") # -- waveform plots @@ -451,9 +537,9 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id, axis=None): shift_str = "shifted " * sorting_analysis.shifting if self.title is None: - axis.set_title(shift_str + self.wfs_kind + rmsg) + axis.set_title(shift_str + self.wfs_kind + rmsg, fontsize="small") else: - axis.set_title(self.title + rmsg) + axis.set_title(self.title + rmsg, fontsize="small") axis.set_xticks([]) axis.set_yticks([]) @@ -477,7 +563,7 @@ def get_waveforms( class TPCAWaveformPlot(WaveformPlot): - wfs_kind = "coll.-cl. tpca wfs" + wfs_kind = "c-c tpca wfs" def get_waveforms( self, sorting_analysis: DARTsortAnalysis, unit_id: int @@ -605,7 +691,7 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id, axis=None): ): tx.set_color(colors[i]) ty.set_color(colors[i]) - axis.set_title(self.title) + axis.set_title(self.title, fontsize="small") class NeighborQDAMatrices(UnitPlot): @@ -667,17 +753,17 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id): ): tx.set_color(colors[i]) ty.set_color(colors[i]) - ax.set_title(title) + ax.set_title(title, fontsize="small") class NeighborCCGPlot(UnitPlot): - kind = "medium" + kind = "bneighborccg" - def __init__(self, n_neighbors=3, max_lag=50, bin=1, unit="samples"): + def __init__(self, n_neighbors=5, max_lag=50, bin=1, unit="samples"): super().__init__() self.n_neighbors = n_neighbors self.max_lag = max_lag - self.height = 1.75 + self.height = 2 self.width = 2 self.unit = unit self.bin = bin @@ -733,7 +819,16 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): to_ms=self.unit == "ms", fc=colors[j], ) - axes[0, j].set_title(f"ccg vs. unit {neighbor_ids[j]}") + axes[0, j].grid(which="both") + axes[0, j].text( + 0.97, + 0.97, + f"vs. unit {neighbor_ids[j]}", + ha="right", + va="top", + transform=axes[0, j].transAxes, + fontsize="small", + ) class NeighborQDAPlot(UnitPlot): @@ -809,6 +904,7 @@ def draw(self, panel, sorting_analysis: DARTsortAnalysis, unit_id: int): ax.axvline(0, color="k", lw=0.8) ax.grid() + ax.set_axisbelow(True) hstat = kstat = "" if self.kind == "hist": @@ -881,6 +977,7 @@ def default_plots(sorting_analysis=None): ISIHistogram(), ISIHistogram(bin_ms=0.25, max_ms=50.0), XZScatter(), + AmplitudeHistogramByDiscreteVariable(), TimeAmpScatter(), RawWaveformPlot(), NearbyCoarseTemplatesPlot(), @@ -894,6 +991,10 @@ def default_plots(sorting_analysis=None): p.extend([PCAScatter(), TPCAWaveformPlot()]) if sorting_analysis is not None and sorting_analysis.qda is not None: p.extend([NeighborQDAMatrices(), NeighborQDAPlot()]) + if sorting_analysis is not None and hasattr( + sorting_analysis.sorting, "template_inds" + ): + p.extend([AmplitudeHistogramByDiscreteVariable("template_inds")]) return p @@ -922,7 +1023,7 @@ def make_unit_summary( pca_radius_um=75.0, plots=None, max_height=4, - figsize=(16, 8.5), + figsize=(18, 8.5), figure=None, gizmo_name="sorting_analysis", **other_global_params, @@ -956,7 +1057,7 @@ def make_all_summaries( amplitude_color_cutoff=15.0, pca_radius_um=75.0, max_height=4, - figsize=(16, 8.5), + figsize=(18, 8.5), dpi=200, image_ext="png", n_jobs=None, From ba4737e3c7dfa382fb2d5b02f80e3428481694d1 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 22 Jun 2026 12:39:40 -0400 Subject: [PATCH 45/46] data_util: work on SortingAnalyzer export --- src/dartsort/config.py | 1 - src/dartsort/util/data_util.py | 50 +++++++++++++++++++++++++--- src/dartsort/util/internal_config.py | 2 +- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/src/dartsort/config.py b/src/dartsort/config.py index 1f83d1dc..72695fa9 100644 --- a/src/dartsort/config.py +++ b/src/dartsort/config.py @@ -257,7 +257,6 @@ class DeveloperConfig(DARTsortUserConfig): threshold_before_whitening: float = 10.0 temporal_dedup_radius_samples: int = 7 positive_temporal_dedup_radius_samples: int = 41 - spatial_dedup_radius_um: float | None = 35.0 spikeinterface_merge_preset: str | None = None spikeinterface_merge_max_distance: float = 0.5 diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index aa3f6af1..0e366f10 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -15,9 +15,11 @@ create_sorting_analyzer, get_random_data_chunks, ) +from spikeinterface.core.sparsity import estimate_sparsity from ..detect import detect_and_deduplicate from .internal_config import ( + TemplateConfig, WaveformConfig, default_clustering_features_cfg, default_waveform_cfg, @@ -223,7 +225,10 @@ def to_sorting_analyzer( self, recording: BaseRecording, template_data: "TemplateData | None" = None, + template_cfg: TemplateConfig | None = None, + motion: "MotionInfo | None" = None, drop_doubles: bool = True, + compute_extensions: Sequence[str] | None = ("random_spikes", "waveforms"), features_cfg=default_clustering_features_cfg, ) -> SortingAnalyzer: """Export dartsort's internal data to a SortingAnalyzer @@ -231,18 +236,36 @@ def to_sorting_analyzer( This will first call to_numpy_sorting() and then register some of the sorting's features as extensions for the analyzer. - If template_data is supplied, a templates extension will be registered. + If template_data is supplied, a templates extension will be registered. Or, you + can supply template_cfg and motion to compute that. The implementation is based on SpikeInterface's `read_kilosort_as_analyzer()`, thanks to Chris Halcrow for that. - TODO: This doesn't handle gain_to_uV or random waveforms, sparsity, or probably - other important things. + The return value here can be passed into SpikeInterface's Phy export machine + `export_to_phy()`. + + TODO: This doesn't handle gain_to_uV... should I not be doing my own amps here? Parameters ---------- recording : BaseRecording template_data : TemplateData | None, optional + Templates to register with SortingAnalyzer as the templates + extension + template_cfg : TemplateConfig | None, optional + If template_data is not supplied but this is (together with motion), + templates will be estimated using dartsort machinery + motion: MotionInfo | None, optional + drop_doubles : bool + Call .drop_doubles(). This will probably do nothing if the sorting had + dedup_ms > 0 in the parameters. + compute_extensions : Sequence[str] | None + Extra analyzer extensions to compute. The default set is what's needed for + SpikeInterface's `export_to_phy()` to run. These are included here because + SpikeInterface is picky about extension order and will evict the template + extension if some of these are computed after the templates, which would then + cause the Phy export to fail. features_cfg : ClusteringFeaturesConfig Stores attribute dataset names @@ -256,12 +279,18 @@ def to_sorting_analyzer( ComputeUnitLocations, ) - sorting, kept_indices = self.to_numpy_sorting(drop_doubles=drop_doubles, return_kept_indices=True) # type: ignore + sorting, kept_indices = self.to_numpy_sorting( + drop_doubles=drop_doubles, return_kept_indices=True + ) # type: ignore + sparsity = estimate_sparsity(sorting, recording) analyzer = create_sorting_analyzer( - sorting=sorting, recording=recording, sparse=False, return_in_uV=False + sorting=sorting, recording=recording, sparsity=sparsity, return_in_uV=False ) + for ext in compute_extensions or []: + analyzer.compute_one_extension(ext) + loc_name = features_cfg.localizations_dataset_name if (locs := self.localizations_as_structured_array(loc_name)) is not None: locs = locs[kept_indices] @@ -271,6 +300,16 @@ def to_sorting_analyzer( loc_ext.run_info = {"run_completed": True} analyzer.extensions["spike_locations"] = loc_ext + if template_data is None and template_cfg is not None: + from ..templates.postprocess_util import estimate_template_library + + _, template_data = estimate_template_library( + recording=recording, + sorting=self, + motion=motion, + template_cfg=template_cfg, + ) + if template_data is not None: td_ext = ComputeTemplates(analyzer) assert np.array_equal( @@ -1157,6 +1196,7 @@ def si_structured_localizations_array(locs: np.ndarray) -> np.ndarray: return structured_array + def filter_link_h5(in_h5_path: str | Path, out_h5_path: str | Path, keep_filter): in_h5_path = ensure_path(in_h5_path, strict=True) out_h5_path = ensure_path(out_h5_path) diff --git a/src/dartsort/util/internal_config.py b/src/dartsort/util/internal_config.py index d72043e2..c008803b 100644 --- a/src/dartsort/util/internal_config.py +++ b/src/dartsort/util/internal_config.py @@ -679,7 +679,7 @@ class SubtractionConfig: relative_peak_radius_samples: int = 5 relative_peak_radius_um: float | None = 35.0 spatial_dedup_radius_um: float | None = 50.0 - temporal_dedup_radius_samples: int = 11 + temporal_dedup_radius_samples: int = 7 remove_exact_duplicates: bool = True positive_temporal_dedup_radius_samples: int = 41 subtract_radius_um: float = 200.0 From 5348d56aa2d458fa8426b3a3d092620aa24aa0a9 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 22 Jun 2026 16:25:29 -0400 Subject: [PATCH 46/46] gmm: can't throw everything away... --- src/dartsort/clustering/mixture.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/dartsort/clustering/mixture.py b/src/dartsort/clustering/mixture.py index bcee89d1..352b9dda 100644 --- a/src/dartsort/clustering/mixture.py +++ b/src/dartsort/clustering/mixture.py @@ -5049,10 +5049,13 @@ def all_demolished_partitions( can_demolish_mask.shape == part.unit_ids.shape == part.group_ids.shape ) single_ixs = torch.tensor(part.single_ixs, dtype=torch.long) + npart = single_ixs.numel() (part_demo_ix,) = can_demolish_mask[single_ixs].nonzero(as_tuple=True) for demo_ixs in subsets(part_demo_ix.tolist()): demo_ixs = list(demo_ixs) + if len(demo_ixs) == npart: + continue demo_group_ids = part.group_ids.clone() demo_group_ids[single_ixs[demo_ixs]] = -1 demo_part = replace(