diff --git a/src/dartsort/cluster/cluster_util.py b/src/dartsort/cluster/cluster_util.py
index abfb2f8c..29e4da9e 100644
--- a/src/dartsort/cluster/cluster_util.py
+++ b/src/dartsort/cluster/cluster_util.py
@@ -16,7 +16,8 @@ def agglomerate(labels, distances, linkage_method="complete", threshold=1.0):
     n = distances.shape[0]
     pdist = distances[np.triu_indices(n, k=1)]
     if pdist.min() > threshold:
-        return labels
+        ids = np.unique(labels)
+        return labels, ids[ids >= 0]
     finite = np.isfinite(pdist)
     if not finite.all():
         inf = max(0, pdist[finite].max()) + threshold + 1.0
@@ -28,7 +29,8 @@
     new_ids -= new_ids.min()
 
     kept = labels >= 0
-    new_labels = np.where(kept, new_ids[labels[kept]], -1)
+    new_labels = np.full_like(labels, -1)
+    new_labels[kept] = new_ids[labels[kept]]
 
     return new_labels, new_ids
 
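
A note on the relabeling change above: np.where needs a value array that
broadcasts against the whole mask, but new_ids[labels[kept]] only has one
entry per kept spike, which is why the patch switches to a fill-then-assign
pattern. A minimal sketch of that pattern on toy labels (values made up for
illustration):

    import numpy as np

    labels = np.array([0, -1, 2, 2, 1])       # -1 marks unassigned spikes
    new_ids = np.array([0, 0, 1])             # old units 0 and 1 merged into 0, unit 2 becomes 1

    kept = labels >= 0
    new_labels = np.full_like(labels, -1)     # everything starts "unassigned"
    new_labels[kept] = new_ids[labels[kept]]  # remap only the kept spikes
    # new_labels -> array([ 0, -1,  1,  1,  0])
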
diff --git a/src/dartsort/cluster/gaussian_mixture.py b/src/dartsort/cluster/gaussian_mixture.py
index fdf3328e..9dc75b5d 100644
--- a/src/dartsort/cluster/gaussian_mixture.py
+++ b/src/dartsort/cluster/gaussian_mixture.py
@@ -232,7 +232,7 @@ def log_likelihoods(
         # log liks as sparse matrix. sparse zeros are not 0 but -inf!!
         log_liks = coo_array(
             (coo_data, (coo_uix, coo_six)),
-            shape=(len(self.units) + with_noise_unit, self.data.n_spikes),
+            shape=(len(unit_ids) + with_noise_unit, self.data.n_spikes),
         )
 
         return log_liks
@@ -325,7 +325,7 @@ def distances(
         nu = len(units)
 
         # stack unit data into one place
-        means, covs, logdets = self.stack_units()
+        means, covs, logdets = self.stack_units(units)
 
         # compute denominator of noised normalized distances
         if noise_normalized:
@@ -494,7 +494,7 @@ def unit_log_likelihoods(
             offset = 0
             log_likelihoods = torch.empty(ns)
         else:
-            log_likelihoods = torch.full(ns, -torch.inf)
+            log_likelihoods = torch.full((ns,), -torch.inf)
 
         jobs = neighbs.items()
         if show_progress:
@@ -532,10 +532,10 @@
 
         return spike_indices, log_likelihoods
 
-    def noise_log_likelihoods(self):
+    def noise_log_likelihoods(self, show_progress=False):
         if self._noise_log_likelihoods is None:
             self._noise_six, self._noise_log_likelihoods = self.unit_log_likelihoods(
-                unit=self.noise_unit, show_progress=True, desc_prefix="Noise "
+                unit=self.noise_unit, show_progress=show_progress, desc_prefix="Noise "
             )
         return self._noise_six, self._noise_log_likelihoods
 
@@ -544,13 +544,13 @@ def kmeans_split_unit(self, unit_id, debug=False):
         # unit's channel set
         unit = self.units[unit_id]
         indices_full, sp = self.random_spike_data(unit_id, return_full_indices=True)
-        X = self.data.interp_to_chans(sp, unit)
+        X = self.data.interp_to_chans(sp, unit.channels)
         if debug:
             debug_info = dict(indices_full=indices_full, sp=sp, X=X)
 
         # run kmeans with kmeans++ initialization
         split_labels, responsibilities = kmeans(
-            X,
+            X.view(len(X), -1),
             n_iter=self.kmeans_n_iter,
             n_components=self.kmeans_k,
             random_state=self.rg,
@@ -593,36 +593,42 @@ def mini_merge(
         units = []
         for label in split_ids:
             (in_label,) = torch.nonzero(labels == label, as_tuple=True)
-            weights = None if weights is None else weights[in_label]
+            w = None if weights is None else weights[in_label, label]
             features = spike_data[in_label]
-            unit = GaussianUnit.from_features(features, weights, **self.unit_args)
+            unit = GaussianUnit.from_features(features, weights=w, **self.unit_args)
             units.append(unit)
 
         # determine their distances
-        distances = self.distances(units=units)
+        distances = self.distances(units=units, show_progress=False)
 
         # determine their bimodalities while at once mini-reassigning
         lls = spike_data.features.new_full((len(units), len(spike_data)), -torch.inf)
         for j, unit in enumerate(units):
-            lls[j] = self.unit_log_likelihoods(
+            inds_, lls_ = self.unit_log_likelihoods(
                 unit=unit, spike_indices=spike_data.indices
             )
+            if lls_ is not None:
+                lls[j] = lls_
         best_liks, labels = lls.max(dim=0)
         labels[torch.isinf(best_liks)] = -1
         labels = labels.numpy(force=True)
         kept = np.flatnonzero(labels >= 0)
-        ids, labels[kept], counts = np.unique(labels[kept])
+        ids, labels[kept], counts = np.unique(labels[kept], return_inverse=True, return_counts=True)
         ids = ids[counts > 0]
-        units = [u for j, u in enumerate(units) if counts[j]]
+        counts_dense = np.zeros(len(units), dtype=counts.dtype)
+        counts_dense[ids] = counts
+        units = [u for j, u in enumerate(units) if counts_dense[j]]
         bimodalities = bimodalities_dense(
             lls.numpy(force=True)[ids], labels, ids=np.arange(ids.size)
         )
+        bimodalities_full = np.full_like(distances, np.inf)
+        bimodalities_full[ids[:, None], ids[None, :]] = bimodalities
 
         # return merged labels
        distances = combine_distances(
             distances,
             self.merge_distance_threshold,
-            bimodalities,
+            bimodalities_full,
             self.merge_bimodality_threshold,
             sym_function=self.merge_sym_function,
         )
@@ -653,15 +659,15 @@ def unit_pair_bimodality(
         masked=True,
         max_spikes=2048,
         dt_s=2.0,
-        debug=False,
         score_kind="tv",
+        debug=False,
     ):
         if in_units is not None:
             ina = in_units[id_a]
             inb = in_units[id_b]
         else:
-            (ina,) = torch.nonzero(self.labels == id_a)
-            (inb,) = torch.nonzero(self.labels == id_b)
+            (ina,) = torch.nonzero(self.labels == id_a, as_tuple=True)
+            (inb,) = torch.nonzero(self.labels == id_b, as_tuple=True)
 
         if masked:
             times_a = self.data.times_seconds[ina]
@@ -755,7 +761,7 @@ def stack_units(self, units=None, distance_metric=None):
         if distance_metric is None:
             kind = self.distance_metric
         nu, rank, nc = len(units), self.data.rank, self.data.n_channels
-        means = self.noise_unit.mean.new_zeros((nu, rank, nc))
+        means = torch.zeros((nu, rank, nc), device=self.data.device)
         covs = logdets = None
         if kind in ("kl_divergence",):
             covs = means.new_zeros((nu, rank * nc, rank * nc))
@@ -765,6 +771,7 @@
             if covs is not None:
                 covs[j] = unit.dense_cov()
                 logdets[j] = unit.logdet
+        return means, covs, logdets
 
 
 # -- modeling class
@@ -961,7 +968,9 @@ def divergence(
         raise ValueError(f"Unknown divergence {kind=}.")
 
     def noise_metric_divergence(self, other_means):
-        dmu = other_means - self.mean
+        dmu = other_means
+        if self.mean_kind != "zero":
+            dmu = dmu - self.mean
         dmu = dmu.view(len(other_means), -1)
         noise_cov = self.noise.marginal_covariance()
         return noise_cov.inv_quad(dmu.T, reduce_inv_quad=False)
@@ -969,7 +978,9 @@ def noise_metric_divergence(self, other_means):
     def kl_divergence(self, other_means, other_covs, other_logdets):
         """DKL(others || self)"""
         n = other_means.shape[0]
-        dmu = other_means - self.mean
+        dmu = other_means
+        if self.mean_kind != "zero":
+            dmu = dmu - self.mean
 
         # compute the inverse quad and self log det terms
         inv_quad, self_logdet = self.cov.inv_quad_logdet(
@@ -1201,7 +1212,8 @@ def qda(
     debug_info=None,
 ):
     # "in b not a"-ness
-    diff = log_liks_b - log_liks_a
+    if diff is None:
+        diff = log_liks_b - log_liks_a
     keep = np.isfinite(diff)
     if not keep.mean() >= min_overlap:
         return np.inf
@@ -1221,7 +1233,7 @@
         sample_weights=sample_weights,
         dipscore_only=True,
         score_kind=score_kind,
-        debug_info=None,
+        debug_info=debug_info,
     )
 
 
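
The mini_merge changes above lean on a few idioms worth spelling out:
torch.full takes a shape tuple (hence torch.full((ns,), -torch.inf)), spikes
whose best log-likelihood is still -inf get the outlier label -1, and
np.unique(..., return_inverse=True, return_counts=True) compacts the surviving
labels. A standalone sketch of that reassignment step on a small made-up
log-likelihood matrix (not the real class state):

    import numpy as np
    import torch

    # log-likelihood of each spike under each candidate unit; -inf = no overlap
    lls = torch.full((3, 5), -torch.inf)
    lls[0, :2] = torch.tensor([-1.0, -2.0])
    lls[2, 2:4] = torch.tensor([-0.5, -1.5])

    best_liks, labels = lls.max(dim=0)
    labels[torch.isinf(best_liks)] = -1   # unexplained spikes -> outlier label
    labels = labels.numpy(force=True)

    kept = np.flatnonzero(labels >= 0)
    ids, inverse, counts = np.unique(
        labels[kept], return_inverse=True, return_counts=True
    )
    labels[kept] = inverse                # relabel to 0 .. len(ids) - 1
    # ids are the surviving unit indices, counts their spike counts
    # labels -> array([ 0,  0,  1,  1, -1])
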
diff --git a/src/dartsort/cluster/kmeans.py b/src/dartsort/cluster/kmeans.py
index 9ec5d4a4..5a3bdc67 100644
--- a/src/dartsort/cluster/kmeans.py
+++ b/src/dartsort/cluster/kmeans.py
@@ -56,11 +56,12 @@ def kmeans(
         kmeanspp_initial=kmeanspp_initial,
     )
+    centroids = X[centroid_ixs]
     # responsibilities, sum to 1 over centroids
-    e = F.softmax(-0.5 * dists, dim=1)
+    dists = torch.cdist(X, centroids).square_()
+    e = F.softmax(-0.5 * dists, dim=1)
     if not n_iter:
         return labels, e
 
-    centroids = X[centroid_ixs]
     proportions = e.mean(0)
 
     for j in range(n_iter):
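
For reference, the responsibilities recomputed above are a softmax over
negative half squared distances to the current centroids. A self-contained
sketch on random toy data (sizes arbitrary, not dartsort defaults):

    import torch
    import torch.nn.functional as F

    X = torch.randn(100, 8)                      # spikes x features
    centroids = X[torch.randperm(100)[:4]]       # 4 provisional centroids

    dists = torch.cdist(X, centroids).square_()  # squared Euclidean distances
    e = F.softmax(-0.5 * dists, dim=1)           # responsibilities, rows sum to 1
    proportions = e.mean(0)                      # soft cluster weights
    labels = e.argmax(1)
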
diff --git a/src/dartsort/cluster/merge.py b/src/dartsort/cluster/merge.py
index 6fd54d85..6faa05f4 100644
--- a/src/dartsort/cluster/merge.py
+++ b/src/dartsort/cluster/merge.py
@@ -338,6 +338,7 @@ def cross_match_distance_matrix(
     template_data, cross_mask, ids_a, ids_b = combine_templates(
         template_data_a, template_data_b
     )
+    print(f"{ids_a.shape=} {ids_b.shape=} {template_data.templates.shape=}")
     units, dists, shifts, template_snrs = calculate_merge_distances(
         template_data,
         superres_linkage=superres_linkage,
@@ -363,6 +364,7 @@ def cross_match_distance_matrix(
     # (with infs on main diag blocks)
     a_mask = np.flatnonzero(np.isin(units, ids_a))
     b_mask = np.flatnonzero(np.isin(units, ids_b))
+    print(f"{a_mask.shape=} {b_mask.shape=} {units.shape=}")
     Dab = dists[a_mask[:, None], b_mask[None, :]]
     Dba = dists[b_mask[:, None], a_mask[None, :]]
     Dstack = np.stack((Dab, Dba.T))
diff --git a/src/dartsort/cluster/stable_features.py b/src/dartsort/cluster/stable_features.py
index 614a1c1c..9ba17065 100644
--- a/src/dartsort/cluster/stable_features.py
+++ b/src/dartsort/cluster/stable_features.py
@@ -53,11 +53,12 @@ def __init__(
         # neighborhoods module, for querying spikes by channel group
         self.core_neighborhoods = core_neighborhoods
 
-        extract_amp_vecs = torch.linalg.vecnorm(extract_features, dim=1)
-        amps = extract_amp_vecs.max(1).values
+        extract_amp_vecs = torch.linalg.vector_norm(extract_features, dim=1)
+        amps = extract_amp_vecs.nan_to_num().max(1).values
 
         # channel neighborhoods and features
         # if not self.features_on_device, .spike_data() will .to(self.device)
+        times_s = torch.asarray(self.original_sorting.times_seconds[kept_indices])
         if self.features_on_device:
             self.register_buffer("core_channels", core_channels)
             self.register_buffer("extract_channels", extract_channels)
@@ -65,6 +66,7 @@
             self.register_buffer("extract_features", extract_features)
             # self.register_buffer("extract_amp_vecs", extract_amp_vecs)
             self.register_buffer("amps", amps)
+            self.register_buffer("times_seconds", times_s)
         else:
             self.core_channels = core_channels
             self.extract_channels = extract_channels
@@ -72,6 +74,7 @@
             self.extract_features = extract_features
             # self.extract_amp_vecs = extract_amp_vecs
             self.amps = amps
+            self.times_seconds = times_s
 
         # always on device
         self.register_buffer("prgeom", prgeom)
@@ -379,7 +382,7 @@ def spike_neighborhoods(self, channels, spike_indices, min_coverage=1.0):
             nhoodv = nhood[nhood < self.n_channels]
             coverage = torch.isin(nhoodv, channels).sum() / nhoodv.numel()
             if coverage >= min_coverage:
-                (member_indices,) = torch.nonzero(spike_ids == j)
+                (member_indices,) = torch.nonzero(spike_ids == j, as_tuple=True)
                 neighborhood_info[j] = (nhood, member_indices)
                 n_spikes += member_indices.numel()
         return neighborhood_info, n_spikes
@@ -397,7 +400,7 @@ def interp_to_chans(
 ):
     source_pos = prgeom[spike_data.channels]
     target_pos = prgeom[channels]
-    shape = spike_data.n_spikes, *target_pos.shape
+    shape = len(spike_data), *target_pos.shape
     target_pos = target_pos[None].broadcast_to(shape)
     return interpolation_util.kernel_interpolate(
         spike_data.features,
diff --git a/src/dartsort/util/comparison.py b/src/dartsort/util/comparison.py
index 48bfa39d..619004f7 100644
--- a/src/dartsort/util/comparison.py
+++ b/src/dartsort/util/comparison.py
@@ -118,9 +118,7 @@ def _calculate_template_distances(self):
         matches = self.comparison.best_match_12.astype(int).values
         matched = np.flatnonzero(matches >= 0)
         matches = matches[matched]
-        print(f"{matched.sum()=} {matched.shape=}")
         tested_td = self.tested_analysis.coarse_template_data[matches]
-        print(f"{tested_td.templates.shape=}")
 
         dists, shifts, snrs_a, snrs_b = merge.cross_match_distance_matrix(
             gt_td,
@@ -129,7 +127,6 @@ def _calculate_template_distances(self):
             n_jobs=self.n_jobs,
             device=self.device,
         )
-        print(f"{dists.shape=}")
 
         self._template_distances = np.full((nugt, nugt), np.inf)
         self._template_distances[np.arange(nugt)[:, None], matched[None, :]] = dists
diff --git a/src/dartsort/util/noise_util.py b/src/dartsort/util/noise_util.py
index c3f74b60..0a5bf4ad 100644
--- a/src/dartsort/util/noise_util.py
+++ b/src/dartsort/util/noise_util.py
@@ -349,9 +349,16 @@ def __init__(
 
         # precompute stuff
         self._full_cov = None
+        self._logdet = None
         self.register_buffer("mean_full", self.mean_rc().clone().detach())
         self.cache = {}
 
+    @property
+    def logdet(self):
+        if self._logdet is None:
+            self.marginal_covariance()
+        return self._logdet
+
     @property
     def device(self):
         return self.global_std.device
@@ -383,7 +390,7 @@ def marginal_covariance(self, channels=slice(None), cache_key=None):
         if channels == slice(None):
             if self._full_cov is None:
                 self._full_cov = self._marginal_covariance()
-                self.logdet = self._full_cov.logdet()
+                self._logdet = self._full_cov.logdet()
             return self._full_cov
         cov = self._marginal_covariance(channels)
         if cache_key is not None:
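
The noise_util.py change stops assigning logdet as a side effect of
marginal_covariance() and caches _logdet behind a read-only property instead.
A stripped-down sketch of that lazy-caching pattern; this toy class is not the
real noise model API:

    import torch

    class LazyLogdet:
        def __init__(self, cov):
            self._cov = cov
            self._logdet = None

        @property
        def logdet(self):
            # computed on first access, then cached
            if self._logdet is None:
                self._logdet = torch.logdet(self._cov)
            return self._logdet

    print(LazyLogdet(torch.eye(4) * 2.0).logdet)  # 4 * log(2)
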
diff --git a/src/dartsort/vis/gmm.py b/src/dartsort/vis/gmm.py
index b865d35c..d75f2370 100644
--- a/src/dartsort/vis/gmm.py
+++ b/src/dartsort/vis/gmm.py
@@ -9,7 +9,7 @@
 from ..cluster import gaussian_mixture
 from ..util import spikeio
 from ..util.multiprocessing_util import (CloudpicklePoolExecutor,
-                                         ThreadPoolExecutor, get_pool)
+                                         ThreadPoolExecutor, get_pool, cloudpickle)
 from . import analysis_plots, gmm_helpers, layout, unit
 from .colors import glasbey1024
 from .waveforms import geomplot
@@ -38,7 +38,7 @@ def __init__(self, bin_ms=0.1, max_ms=5):
 
     def draw(self, panel, gmm, unit_id):
         axis = panel.subplots()
-        times_s = gmm.data.times_seconds[gmm.labels == unit_id].numpy(force=True)
+        times_s = gmm.data.times_seconds[gmm.labels == unit_id]
         dt_ms = np.diff(times_s) * 1000
         bin_edges = np.arange(0, self.max_ms + self.bin_ms, self.bin_ms)
         counts, _ = np.histogram(dt_ms, bin_edges)
@@ -65,13 +65,13 @@ def draw(self, panel, gmm, unit_id):
         s = ax.scatter(*xy[unique_ixs].T, c=counts, lw=0, cmap=self.cmap)
         plt.colorbar(s, ax=ax, shrink=0.3)
         ax.scatter(
-            *xy[gmm[unit_id].channels.numpy(force=True)].T,
+            *xy[gmm.units[unit_id].channels.numpy(force=True)].T,
             color="r",
             lw=1,
             fc="none",
         )
         ax.scatter(
-            *xy[np.atleast_1d(gmm[unit_id].snr.argmax().numpy(force=True))].T,
+            *xy[np.atleast_1d(gmm.units[unit_id].snr.argmax().numpy(force=True))].T,
             color="g",
             lw=0,
         )
@@ -104,7 +104,7 @@ def draw(self, panel, gmm, unit_id, axes=None):
             ax.axis("off")
 
         sp = gmm.random_spike_data(unit_id, with_reconstructions=True)
-        maa = sp.waveforms.abs().max()
+        maa = sp.waveforms.abs().nan_to_num().max()
         lines, chans = geomplot(
             sp.waveforms,
             channels=sp.channels,
@@ -118,7 +118,7 @@ def draw(self, panel, gmm, unit_id, axes=None):
         )
         chans = torch.tensor(list(chans))
         tup = gaussian_mixture.to_full_probe(
-            sp.features, weights=None, n_channels=gmm.data.n_channels, storage=None
+            sp, weights=None, n_channels=gmm.data.n_channels, storage=None
         )
         features_full, weights_full, count_data, weights_normalized = tup
         emp_mean = torch.nanmean(features_full, dim=0)[:, chans]
@@ -158,7 +158,8 @@ def draw(self, panel, gmm, unit_id, axes=None):
         (in_unit,) = torch.nonzero(gmm.labels == unit_id, as_tuple=True)
         if not in_unit.numel():
             return
-        liks = gmm.unit_log_likelihoods(unit_id, spike_indices=in_unit)
+        inds_, liks = gmm.unit_log_likelihoods(unit_id, spike_indices=in_unit)
+        assert torch.equal(inds_, in_unit)
         nliks = gmm.noise_log_likelihoods()[1][in_unit]
         t = gmm.data.times_seconds[in_unit]
         dt_ms = np.diff(t) * 1000
@@ -219,7 +220,7 @@ def draw(self, panel, gmm, unit_id, axes=None):
         ax_dist = analysis_plots.distance_matrix_dendro(
             fig_dist,
             split_info["distances"],
-            unit_ids=split_ids,
+            # unit_ids=split_ids,
             dendrogram_linkage=None,
             show_unit_labels=True,
             vmax=1.0,
@@ -232,7 +233,7 @@ def draw(self, panel, gmm, unit_id, axes=None):
         ax_bimod = analysis_plots.distance_matrix_dendro(
             fig_bimods,
             split_info["bimodalities"],
-            unit_ids=split_ids,
+            # unit_ids=split_ids,
             dendrogram_linkage=None,
             show_unit_labels=True,
             vmax=0.5,
@@ -244,7 +245,7 @@ def draw(self, panel, gmm, unit_id, axes=None):
         ax_centroids = fig_mean.subplots()
         mainchan = gmm.units[unit_id].snr.argmax()
         ax_centroids.axhline(0, color="k", lw=0.8)
-        sns.despine(ax_centroids, left=False, right=True, bottom=True, top=True)
+        sns.despine(ax=ax_centroids, left=False, right=True, bottom=True, top=True)
         for subid, subunit in zip(split_ids, split_info["units"]):
             subm = subunit.mean[:, mainchan]
             subm = gmm.data.tpca._inverse_transform_in_probe(subm[None])[0]
@@ -284,6 +285,7 @@ def draw(self, panel, gmm, unit_id):
         # means on core channels
         chans = gmm.units[unit_id].snr.argmax()
         chans = torch.cdist(gmm.data.prgeom[chans[None]], gmm.data.prgeom)
+        chans = chans.view(-1)
         (chans,) = torch.nonzero(chans <= gmm.data.core_radius, as_tuple=True)
 
         means = [gmm.units[u].mean[:, chans].numpy(force=True) for u in neighbors]
@@ -294,7 +296,7 @@ def draw(self, panel, gmm, unit_id):
             .broadcast_to(len(means), *chans.shape)
             .numpy(force=True),
             geom=gmm.data.prgeom.numpy(force=True),
-            colors=glasbey1024[neighbors],
+            colors=glasbey1024[neighbors.numpy(force=True)],
             show_zero=False,
             ax=ax,
         )
@@ -317,7 +319,7 @@ def draw(self, panel, gmm, unit_id):
         ax = analysis_plots.distance_matrix_dendro(
             panel,
             distances,
-            unit_ids=neighbors,
+            unit_ids=neighbors.numpy(force=True),
             dendrogram_linkage=None,
             show_unit_labels=True,
             vmax=0.5,
@@ -340,8 +342,11 @@ def draw(self, panel, gmm, unit_id):
         assert neighbors[0] == unit_id
         log_liks = gmm.log_likelihoods(unit_ids=neighbors)
         labels = gaussian_mixture.loglik_reassign(log_liks, has_noise_unit=True)
+        log_liks = gaussian_mixture.coo_to_torch(log_liks, torch.float)
         kept = labels >= 0
-        labels = torch.where(kept, neighbors[labels[kept]], -1)
+        labels_ = np.full_like(labels, -1)
+        labels_[kept] = neighbors[labels[kept]].numpy(force=True)
+        labels = labels_
         others = neighbors[1:]
 
         axes = panel.subplots(nrows=len(others), ncols=2)
@@ -353,6 +358,7 @@
                 log_liks,
                 loglik_ix_a=0,
                 loglik_ix_b=j + 1,
+                debug=True,
             )
 
             scatter_ax, bimod_ax = axes_row
@@ -360,6 +366,9 @@
             scatter_ax.scatter(bimod_info["xi"], bimod_info["xj"], s=3, lw=0, c=c)
             scatter_ax.set_xlabel(unit_id, color=glasbey1024[unit_id])
             scatter_ax.set_ylabel(other_id, color=glasbey1024[other_id])
+
+            if "samples" not in bimod_info:
+                continue
             bimod_ax.hist(bimod_info["samples"], color="gray", **histkw)
             bimod_ax.hist(
                 bimod_info["samples"],
@@ -405,7 +414,6 @@ def make_unit_gmm_summary(
     # notify plots of global params
     for p in plots:
         p.notify_global_params(
-            time_range=gmm.t_bounds,
             **other_global_params,
         )
 
@@ -460,7 +468,6 @@
     if use_threads:
         cls = ThreadPoolExecutor
     n_jobs, Executor, context = get_pool(n_jobs, cls=cls)
-    from cloudpickle import dumps
 
     initargs = (
         gmm,
@@ -475,7 +482,7 @@
         global_params,
     )
     if ispar and not use_threads:
-        initargs = (dumps(initargs),)
+        initargs = (cloudpickle.dumps(initargs),)
     with Executor(
         max_workers=n_jobs,
         mp_context=context,
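
The summary driver now uses the cloudpickle module re-exported by
multiprocessing_util instead of importing dumps locally. The pattern it relies
on is serializing the pool initializer's arguments with cloudpickle so worker
processes can receive objects the stdlib pickler may reject. A rough sketch
with a hypothetical worker (assumes cloudpickle is installed):

    from concurrent.futures import ProcessPoolExecutor

    import cloudpickle

    _state = None

    def _init(payload):
        global _state
        _state = cloudpickle.loads(payload)   # rebuild the (model, params) tuple

    def _work(i):
        model, params = _state
        return i * params["scale"]

    if __name__ == "__main__":
        initargs = (cloudpickle.dumps(({"weights": [1, 2, 3]}, {"scale": 10})),)
        with ProcessPoolExecutor(max_workers=2, initializer=_init, initargs=initargs) as ex:
            print(list(ex.map(_work, range(4))))  # [0, 10, 20, 30]
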
diff --git a/src/dartsort/vis/gmm_helpers.py b/src/dartsort/vis/gmm_helpers.py
index da93da17..fc613a72 100644
--- a/src/dartsort/vis/gmm_helpers.py
+++ b/src/dartsort/vis/gmm_helpers.py
@@ -5,7 +5,7 @@
 
 
 def get_neighbors(gmm, unit_id, n_neighbors=5):
-    means, covs, logdets = gmm.stack_units
+    means, covs, logdets = gmm.stack_units()
     dists = gmm.units[unit_id].divergence(means, covs, logdets, kind=gmm.distance_metric)
     dists = dists.view(-1)
     order = torch.argsort(dists)
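
get_neighbors depends on stack_units() actually being called (the old code
referenced the bound method without calling it); the selection itself is just
an argsort over divergences, e.g. with a toy distance vector:

    import torch

    dists = torch.tensor([0.0, 2.5, 0.3, 9.1, 1.2])  # unit 0 vs. every unit
    n_neighbors = 3
    order = torch.argsort(dists)
    neighbors = order[: n_neighbors + 1]             # the unit itself plus its 3 closest
    # neighbors -> tensor([0, 2, 4, 1])
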
diff --git a/src/dartsort/vis/gt.py b/src/dartsort/vis/gt.py
index e0f8c176..d0c49f5a 100644
--- a/src/dartsort/vis/gt.py
+++ b/src/dartsort/vis/gt.py
@@ -552,6 +552,19 @@ def draw(self, panel, comparison):
     TrimmedTemplateDistanceMatrix(),
 )
 
+gt_overview_plots_no_temp_dist = (
+    MetricRegPlot(x="gt_ptp_amplitude", y="accuracy", log_x=True),
+    MetricRegPlot(x="gt_ptp_amplitude", y="recall", color="r", log_x=True),
+    MetricRegPlot(x="gt_ptp_amplitude", y="precision", color="g", log_x=True),
+    MetricRegPlot(x="gt_firing_rate", y="accuracy"),
+    MetricRegPlot(x="gt_firing_rate", y="recall", color="r"),
+    MetricRegPlot(x="gt_firing_rate", y="precision", color="g"),
+    MetricRegPlot(x="gt_ptp_amplitude", y="unsorted_recall", color="purple", log_x=True),
+    box,
+    MetricDistribution(xs=("recall", "accuracy", "temp_dist", "precision")),
+    TrimmedAgreementMatrix(),
+)
+
 # multi comparisons stuff
 
 # box and whisker between sorters
diff --git a/src/dartsort/vis/waveforms.py b/src/dartsort/vis/waveforms.py
index 616b727b..9d28388e 100644
--- a/src/dartsort/vis/waveforms.py
+++ b/src/dartsort/vis/waveforms.py
@@ -55,7 +55,7 @@ def geomplot(
             raise ValueError(f"Bad shapes: {waveforms.shape=}, {max_channels.shape=}, {C=}")
         channels = channel_index[max_channels]
     else:
-        n_channels = geom.shape[0]
+        n_channels = geom[np.isfinite(geom).all(1)].shape[0]
     T = waveforms.shape[1]
     assert channels.shape[0] == waveforms.shape[0]
     assert channels.shape[1] == waveforms.shape[-1]
@@ -64,7 +64,7 @@
     valid = np.isfinite(geom).all(1)
     z_uniq, z_ix = np.unique(geom[valid, 1], return_inverse=True)
     for i in z_ix:
-        x_uniq = np.unique(geom[z_ix == i, 0])
+        x_uniq = np.unique(geom[valid][z_ix == i, 0])
         if x_uniq.size > 1:
             break
     else:
@@ -78,7 +78,7 @@
         T / inter_chan_x / x_extension,
         max_abs_amp / inter_chan_z / z_extension,
     ]
-    geom_plot = geom * geom_scales
+    geom_plot = geom[valid] * geom_scales
     t_domain = np.linspace(-T // 2, T // 2, num=T)
 
     # -- and, plot
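
The waveforms.py fixes all address the same issue: probe geometries can carry
non-finite pad rows, so channel counts, unique site positions, and plot scales
have to be computed from the valid rows only. A minimal sketch of the masking
on a toy geometry with one NaN pad row:

    import numpy as np

    geom = np.array([[0.0, 0.0], [32.0, 0.0], [0.0, 20.0], [np.nan, np.nan]])

    valid = np.isfinite(geom).all(1)   # drop the pad row
    n_channels = valid.sum()           # 3, not 4
    z_uniq, z_ix = np.unique(geom[valid, 1], return_inverse=True)
    x_pitch = np.diff(np.unique(geom[valid, 0])).min()  # pitch from valid sites only
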