From a0e85346e6b855f84c6ccac3c08112b51c1f0023 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 16 Aug 2024 09:34:06 -0700 Subject: [PATCH 1/8] Unreg bug --- src/dartsort/templates/get_templates.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py index 8c5bd1bc..47fdf041 100644 --- a/src/dartsort/templates/get_templates.py +++ b/src/dartsort/templates/get_templates.py @@ -701,8 +701,10 @@ def _template_job(unit_ids): p.reducer(waveforms[in_unit], axis=0).numpy(force=True) ) counts.append(in_unit.size) - snrs_by_chan = [ptp(rt, 0) * c for rt, c in zip(raw_templates, counts)] + snrs_by_chan = np.array([ptp(rt, 0) * c for rt, c in zip(raw_templates, counts)]) counts_by_chan = np.array(counts) + if counts_by_chan.ndim == 1: + counts_by_chan = np.broadcast_to(counts_by_chan[:, None], snrs_by_chan.shape) raw_templates = np.array(raw_templates) if p.denoising_tsvd is None: From 11faf6782da84fe756aec45147cb5258b862b20d Mon Sep 17 00:00:00 2001 From: Christopher Langfield Date: Tue, 20 Aug 2024 11:07:34 -0700 Subject: [PATCH 2/8] piecewise spiketrain --- src/dartsort/util/hybrid_util.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/dartsort/util/hybrid_util.py b/src/dartsort/util/hybrid_util.py index 7fae8356..c3b21a0b 100644 --- a/src/dartsort/util/hybrid_util.py +++ b/src/dartsort/util/hybrid_util.py @@ -201,6 +201,23 @@ def refractory_poisson_spike_train( return spike_samples +def piecewise_refractory_poisson_spike_train(rates, bins, binsize_samples, **kwargs): + """ + Returns a spike train with variable firing rate using refractory_poisson_spike_train(). + + :param rates: list of firing rates in Hz + :param bins: bin starting samples (same shape as rates) + :param binsize_samples: number of samples per bin + :param **kwargs: kwargs to feed to refractory_poisson_spike_train() + """ + sp_tr = np.concatenate( + [ + refractory_poisson_spike_train(r, binsize_samples, **kwargs) + bins[i] if r > 0.1 else [] + for i, r in enumerate(rates) + ] + ) + return sp_tr + def precompute_displaced_registered_templates( template_data: TemplateData, From c9cd3226fed3ad700a67e5353132eb66b31b5e08 Mon Sep 17 00:00:00 2001 From: Christopher Langfield Date: Tue, 20 Aug 2024 19:56:28 -0700 Subject: [PATCH 3/8] feed amps vector directly to hybrid recording gen --- src/dartsort/util/hybrid_util.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/dartsort/util/hybrid_util.py b/src/dartsort/util/hybrid_util.py index c3b21a0b..8974cea6 100644 --- a/src/dartsort/util/hybrid_util.py +++ b/src/dartsort/util/hybrid_util.py @@ -26,6 +26,7 @@ def get_drifty_hybrid_recording( firing_rates=None, peak_channels=None, amplitude_scale_std=0.1, + amplitude_factor=None ): """ :param: recording @@ -33,7 +34,9 @@ def get_drifty_hybrid_recording( :param: motion estimate object :param: firing_rates :param: peak_channels - :param: amplitude_factor + :param: amplitude_scale_std -- std of gamma distributed amplitude variation if + amplitude_factor is None + :param: amplitude_factor array of length n_spikes with amplitude factors """ num_units = templates.num_units rg = np.random.default_rng(seed=seed) @@ -50,11 +53,12 @@ def get_drifty_hybrid_recording( n_spikes = sorting.count_total_num_spikes() # Default amplitude scalings for spikes drawn from gamma - if amplitude_scale_std: - shape = 1. / (amplitude_scale_std ** 1.5) - amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes) - else: - amplitude_factor = np.ones(n_spikes) + if not amplitude_factor: + if amplitude_scale_std: + shape = 1. / (amplitude_scale_std ** 1.5) + amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes) + else: + amplitude_factor = np.ones(n_spikes) depths = recording.get_probe().contact_positions[:, 1][peak_channels] t_start = recording.sample_index_to_time(0) From ef0d5ce0b77d668b11495e332c2bb6dbfed14da3 Mon Sep 17 00:00:00 2001 From: Christopher Langfield Date: Tue, 20 Aug 2024 20:22:21 -0700 Subject: [PATCH 4/8] none check --- src/dartsort/util/hybrid_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dartsort/util/hybrid_util.py b/src/dartsort/util/hybrid_util.py index 8974cea6..55cd2980 100644 --- a/src/dartsort/util/hybrid_util.py +++ b/src/dartsort/util/hybrid_util.py @@ -53,7 +53,7 @@ def get_drifty_hybrid_recording( n_spikes = sorting.count_total_num_spikes() # Default amplitude scalings for spikes drawn from gamma - if not amplitude_factor: + if amplitude_factor is None: if amplitude_scale_std: shape = 1. / (amplitude_scale_std ** 1.5) amplitude_factor = rg.gamma(shape, scale=1./(shape-1), size=n_spikes) From bee08d7a3c8da10da7a53429a416d788e9b275b3 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 21 Aug 2024 08:09:37 -0700 Subject: [PATCH 5/8] UB in reg chans --- src/dartsort/util/drift_util.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/dartsort/util/drift_util.py b/src/dartsort/util/drift_util.py index 739d46d3..dc5e5561 100644 --- a/src/dartsort/util/drift_util.py +++ b/src/dartsort/util/drift_util.py @@ -131,16 +131,17 @@ def registered_geometry( return registered_geom -def registered_channels(channels, geom, n_pitches_shift, registered_geom): +def registered_channels(channels, geom, n_pitches_shift, registered_geom, distance_upper_bound=None): """What registered channels do `channels` land on after shifting by `n_pitches_shift`?""" pitch = get_pitch(geom) shifted_positions = geom.copy()[channels] shifted_positions[:, 1] += n_pitches_shift * pitch registered_kdtree = KDTree(registered_geom) - min_distance = pdist(registered_geom).min() / 2 + if distance_upper_bound is None: + distance_upper_bound = pdist(registered_geom).min() / 2 distances, registered_channels = registered_kdtree.query( - shifted_positions, distance_upper_bound=min_distance + shifted_positions, distance_upper_bound=distance_upper_bound ) # make sure there were no unmatched points assert np.all(registered_channels < len(registered_geom)) From 6dc371b61da828426f1692987974f5c7f149734c Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 21 Aug 2024 08:09:51 -0700 Subject: [PATCH 6/8] Job control --- src/dartsort/util/comparison.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/dartsort/util/comparison.py b/src/dartsort/util/comparison.py index 9157434a..27050e27 100644 --- a/src/dartsort/util/comparison.py +++ b/src/dartsort/util/comparison.py @@ -24,6 +24,7 @@ class DARTsortGroundTruthComparison: match_mode: str = "hungarian" compute_labels: bool = True verbose: bool = True + device: Optional[str] = None compute_distances: bool = True compute_unsorted_recall: bool = True @@ -120,7 +121,8 @@ def _calculate_template_distances(self): gt_td, tested_td, sym_function=np.maximum, - n_jobs=max(self.gt_analysis.n_jobs, self.tested_analysis.n_jobs), + n_jobs=self.n_jobs, + device=self.device, ) self._template_distances = dists From 73cbed06c5e54627b66b7771690afe580c65de5d Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 21 Aug 2024 08:10:06 -0700 Subject: [PATCH 7/8] Merge propagate chan counts --- src/dartsort/cluster/merge.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/dartsort/cluster/merge.py b/src/dartsort/cluster/merge.py index 00cb1169..bc194d0e 100644 --- a/src/dartsort/cluster/merge.py +++ b/src/dartsort/cluster/merge.py @@ -632,12 +632,16 @@ def combine_templates(template_data_a, template_data_b): spike_counts = np.concatenate( (template_data_a.spike_counts, template_data_b.spike_counts) ) + spike_counts_by_channel = np.concatenate( + (template_data_a.spike_counts_by_channel, template_data_b.spike_counts_by_channel) + ) template_data = TemplateData( templates=templates, unit_ids=unit_ids, spike_counts=spike_counts, registered_geom=rgeom, registered_template_depths_um=locs, + spike_counts_by_channel=spike_counts_by_channel, ) cross_mask = np.logical_and(np.isin(unit_ids, ids_a)[:, None], np.isin(unit_ids, ids_b)[None]) From 62972a28a4afc2da04cf620835a1731b63d3ae33 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Mon, 26 Aug 2024 15:21:17 -0700 Subject: [PATCH 8/8] Clustering api --- src/dartsort/cluster/initial.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/dartsort/cluster/initial.py b/src/dartsort/cluster/initial.py index b5244937..c52c5892 100644 --- a/src/dartsort/cluster/initial.py +++ b/src/dartsort/cluster/initial.py @@ -210,6 +210,7 @@ def cluster_chunks( clustering_config, sorting=None, motion_est=None, + amplitudes_dataset_name='denoised_ptp_amplitudes', ): """Divide the recording into chunks, and cluster each chunk @@ -239,6 +240,7 @@ def cluster_chunks( chunk_time_range_s=chunk_range, motion_est=motion_est, recording=recording, + amplitudes_dataset_name=amplitudes_dataset_name, ) for chunk_range in chunk_time_ranges_s ] @@ -253,6 +255,7 @@ def ensemble_chunks( sorting=None, computation_config=None, motion_est=None, + **kwargs, ): """Initial clustering combined across chunks of time @@ -283,6 +286,7 @@ def ensemble_chunks( clustering_config, sorting=sorting, motion_est=motion_est, + **kwargs, ) if len(chunk_sortings) == 1: return chunk_sortings[0] @@ -320,6 +324,7 @@ def initial_clustering( clustering_config=None, computation_config=None, motion_est=None, + **kwargs, ): if sorting is None: sorting = DARTsortSorting.from_peeling_hdf5(peeling_hdf5_filename) @@ -333,6 +338,7 @@ def initial_clustering( sorting=sorting, computation_config=computation_config, motion_est=motion_est, + **kwargs, ) \ No newline at end of file