diff --git a/src/dartsort/vis/gmm.py b/src/dartsort/vis/gmm.py index 9b91feda..70176803 100644 --- a/src/dartsort/vis/gmm.py +++ b/src/dartsort/vis/gmm.py @@ -38,7 +38,7 @@ class DPCSplitPlot(GMMPlot): width = 5 height = 5 - def __init__(self, spike_kind="train", feature="pca", inherit_chans=True, common_chans=True, dist_vmax=1., cmap=plt.cm.rainbow): + def __init__(self, spike_kind="split", feature="pca", inherit_chans=True, common_chans=True, dist_vmax=1., cmap=plt.cm.rainbow): self.spike_kind = spike_kind assert feature in ("pca", "spread_amp") self.feature = feature @@ -58,7 +58,7 @@ def draw(self, panel, gmm, unit_id): for sl, data in gmm.batches(in_unit): gmm[unit_id].residual_embed(**data, out=features[sl]) z = features[:, : gmm.dpc_split_kw.rank].numpy(force=True) - elif self.spike_kind == "train": + elif self.spike_kind == "split": _, in_unit, z = gmm.split_features(unit_id) elif self.spike_kind == "global": in_unit, data = gmm.get_training_data(unit_id) @@ -76,7 +76,7 @@ def draw(self, panel, gmm, unit_id): ) z = loadings.numpy(force=True) elif self.feature == "spread_amp": - assert self.spike_kind in ("train", "global") + assert self.spike_kind in ("split", "global") in_unit, data = gmm.get_training_data(unit_id) waveforms = data["waveforms"] channel_norms = torch.sqrt(torch.nan_to_num(waveforms.square().sum(1))) @@ -124,8 +124,8 @@ def draw(self, panel, gmm, unit_id): axes[0].set_xlabel("spread") axes[0].set_ylabel("amp") - axes = panel_bottom.subplots(ncols=2) - axes = {'d': axes[0], 'e': axes[1]} + axes = panel_bottom.subplots(ncols=3) + axes = {'d': axes[0], 'f': axes[1], 'e': axes[2]} labels = dens['labels'][inv] in_unit = torch.from_numpy(in_unit) ids = np.unique(labels) @@ -160,36 +160,38 @@ def draw(self, panel, gmm, unit_id): ju = [(j, u) for j, u in enumerate(new_units) if u.n_chans_unit] # plot new unit maxchan wfs and old one in black - ax = axes["d"] - ax.axhline(0, c="k", lw=0.8) - all_means = [] - for j, unit in ju: - if unit.do_interp: - times = unit.interp.grid.squeeze() - if self.fitted_only: - times = times[unit.interp.grid_fitted] - else: - times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device) - times = torch.atleast_1d(times) - - chans = torch.full((times.numel(),), unit.max_channel, device=times.device) - means = unit.get_means(times).to(gmm.device) - if j > 0: - all_means.append(means.mean(0)) - means = unit.to_waveform_channels(means, waveform_channels=chans[:, None]) - means = means[..., 0] - means = gmm.data.tpca._inverse_transform_in_probe(means) - means = means.numpy(force=True) - color = glasbey1024[j] - - lines = np.stack( - (np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means), - axis=-1, - ) - ax.add_collection(LineCollection(lines, colors=color, lw=1)) - ax.autoscale_view() - ax.set_xticks([]) - ax.spines[["top", "right", "bottom"]].set_visible(False) + gmc = gmm[unit_id].max_channel + for ax, pick in zip((axes["d"], axes["f"]), ("unit", "shared")): + ax.axhline(0, c="k", lw=0.8) + all_means = [] + for j, unit in ju: + if unit.do_interp: + times = unit.interp.grid.squeeze() + if self.fitted_only: + times = times[unit.interp.grid_fitted] + else: + times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device) + times = torch.atleast_1d(times) + + chans = torch.full((times.numel(),), unit.max_channel if pick == "unit" else gmc, device=times.device) + means = unit.get_means(times).to(gmm.device) + if j > 0: + all_means.append(means.mean(0)) + means = unit.to_waveform_channels(means, waveform_channels=chans[:, None]) + means = means[..., 0] + means = gmm.data.tpca._inverse_transform_in_probe(means) + means = means.numpy(force=True) + color = glasbey1024[j] + + lines = np.stack( + (np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means), + axis=-1, + ) + ax.add_collection(LineCollection(lines, colors=color, lw=1)) + ax.autoscale_view() + ax.set_xticks([]) + ax.spines[["top", "right", "bottom"]].set_visible(False) + ax.set_title(pick) # plot distance matrix kind = gmm.merge_metric @@ -424,6 +426,7 @@ def __init__( n_iter=50, common_chans=True, inherit_chans=True, + impute_before_center=False, min_overlap=0.0, ): self.cmap = cmap @@ -439,6 +442,7 @@ def __init__( self.n_iter = n_iter self.common_chans = common_chans self.inherit_chans = inherit_chans + self.impute_before_center = impute_before_center self.min_overlap = min_overlap self.merge_on_waveform_radius = merge_on_waveform_radius if self.scaled: @@ -458,7 +462,7 @@ def draw(self, panel, gmm, unit_id): if ids.size > 1: top, bottom = panel.subfigures(nrows=2) ax_top = top.subplots() - axes = bottom.subplot_mosaic("de") + axes = bottom.subplot_mosaic("dfe") else: ax_top = panel.subplots() @@ -482,7 +486,7 @@ def draw(self, panel, gmm, unit_id): for j, label in enumerate(ids): u = spike_interp.InterpUnit( do_interp=False, - **gmm.unit_kw, + **gmm.unit_kw | dict(impute_before_center=self.impute_before_center), ) inu = in_unit[np.flatnonzero(labels == label)] w = None if weights is None else weights[labels == label, j] @@ -504,36 +508,38 @@ def draw(self, panel, gmm, unit_id): ju = [(j, u) for j, u in enumerate(new_units) if u.n_chans_unit] # plot new unit maxchan wfs and old one in black - ax = axes["d"] - ax.axhline(0, c="k", lw=0.8) - all_means = [] - for j, unit in ju: - if unit.do_interp: - times = unit.interp.grid.squeeze() - if self.fitted_only: - times = times[unit.interp.grid_fitted] - else: - times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device) - times = torch.atleast_1d(times) - - chans = torch.full((times.numel(),), unit.max_channel, device=times.device) - means = unit.get_means(times).to(gmm.device) - if j > 0: - all_means.append(means.mean(0)) - means = unit.to_waveform_channels(means, waveform_channels=chans[:, None]) - means = means[..., 0] - means = gmm.data.tpca._inverse_transform_in_probe(means) - means = means.numpy(force=True) - color = glasbey1024[j] - - lines = np.stack( - (np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means), - axis=-1, - ) - ax.add_collection(LineCollection(lines, colors=color, lw=1)) - ax.autoscale_view() - ax.set_xticks([]) - ax.spines[["top", "right", "bottom"]].set_visible(False) + gmc = gmm[unit_id].max_channel + for ax, pick in zip((axes["d"], axes["f"]), ("unit", "shared")): + ax.axhline(0, c="k", lw=0.8) + all_means = [] + for j, unit in ju: + if unit.do_interp: + times = unit.interp.grid.squeeze() + if self.fitted_only: + times = times[unit.interp.grid_fitted] + else: + times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device) + times = torch.atleast_1d(times) + + chans = torch.full((times.numel(),), unit.max_channel if pick == "unit" else gmc, device=times.device) + means = unit.get_means(times).to(gmm.device) + if j > 0: + all_means.append(means.mean(0)) + means = unit.to_waveform_channels(means, waveform_channels=chans[:, None]) + means = means[..., 0] + means = gmm.data.tpca._inverse_transform_in_probe(means) + means = means.numpy(force=True) + color = glasbey1024[j] + + lines = np.stack( + (np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means), + axis=-1, + ) + ax.add_collection(LineCollection(lines, colors=color, lw=1)) + ax.autoscale_view() + ax.set_xticks([]) + ax.spines[["top", "right", "bottom"]].set_visible(False) + ax.set_title(pick) # plot distance matrix kind = gmm.merge_metric @@ -597,7 +603,7 @@ class HDBScanSplitPlot(GMMPlot): width = 2 height = 2 - def __init__(self, spike_kind="train"): + def __init__(self, spike_kind="split"): self.spike_kind = spike_kind def draw(self, panel, gmm, unit_id): @@ -609,7 +615,7 @@ def draw(self, panel, gmm, unit_id): for sl, data in gmm.batches(in_unit): gmm[unit_id].residual_embed(**data, out=features[sl]) z = features[:, : gmm.dpc_split_kw.rank].numpy(force=True) - elif self.spike_kind == "train": + elif self.spike_kind == "split": _, in_unit, z = gmm.split_features(unit_id) elif self.spike_kind == "global": in_unit, data = gmm.get_training_data(unit_id) @@ -694,7 +700,7 @@ class AmplitudesOverTimePlot(GMMPlot): width = 5 height = 3 - def __init__(self, kinds=("recon", 'model'), colors="bkrg"): + def __init__(self, kinds=("recon", "maxchan_energy", 'model'), colors="bkrg"): self.kinds = kinds self.colors = dict(zip(kinds, colors)) @@ -734,6 +740,7 @@ def draw(self, panel, gmm, unit_id): t = utd["times"].numpy(force=True) for kind, a in show.items(): ax.scatter(t, a, c=self.colors[kind], s=3, lw=0, label=kind) + ax.legend(loc="upper left") ax.set_ylabel("amplitude") @@ -788,19 +795,28 @@ def draw(self, panel, gmm, unit_id): overlaps = overlaps.numpy(force=True) times = utd["times"][spike_ix].numpy(force=True) - ax = panel.subplots() + ax, ay = panel.subplots(ncols=2, width_ratios=[2, 1]) + ay.grid(True) for j, (kind, b) in enumerate(badnesses.items()): + b = b.numpy(force=True) + c = self.colors[j] ax.scatter( times, - b.numpy(force=True), + b, alpha=overlaps, s=3, - c=self.colors[j], + c=c, lw=0, label=kind, ) + ay.ecdf(b, lw=1, color=c) + ay.axvline(np.mean(b), color=c, lw=1) + ay.axvline(np.median(b), color=c, ls="--", lw=1) + ay.set_xlabel(kind) + ay.set_ylabel("cdf") ax.legend(loc="upper left") ax.set_ylabel("badness") + ay.set_yticks([0, 0.25, 0.5, 0.75, 1]) class FeaturesVsBadnessesPlot(GMMPlot): @@ -1352,9 +1368,9 @@ def __init__(self, n_neighbors=5): self.n_neighbors = n_neighbors def get_neighbors(self, gmm, unit_id, reversed=False): - unit_dists = gmm.central_divergences(units_a=[unit_id])[0] + unit_dists = gmm.central_divergences(units_a=torch.tensor([unit_id]))[0] unit_ids = gmm.unit_ids() - neighbors = torch.argsort(unit_dists) + neighbors = torch.argsort((unit_ids != unit_id).to(unit_dists) + unit_dists) assert unit_ids[neighbors[0]] == unit_id neighbors = neighbors[: self.n_neighbors + 1] neighbors = neighbors[torch.isfinite(unit_dists[neighbors])] @@ -1561,12 +1577,13 @@ class NeighborBimodality(GMMMergePlot): width = 3 height = 10 - def __init__(self, n_neighbors=5, badness_kind="1-r^2", do_reg=False, masked=False, mask_radius_s=5.0): + def __init__(self, n_neighbors=5, badness_kind="1-r^2", do_reg=False, masked=False, mask_radius_s=5.0, impute_missing=False): self.n_neighbors = n_neighbors self.badness_kind = badness_kind self.do_reg = do_reg self.masked = masked self.mask_radius_s = mask_radius_s + self.impute_missing = impute_missing def draw(self, panel, gmm, unit_id): from isosplit import isocut, dipscore_at @@ -1625,6 +1642,7 @@ def draw(self, panel, gmm, unit_id): unit_ids=[unit_id, u], show_progress=False, kind=self.badness_kind, + impute_missing=self.impute_missing, ) a = np.full(badness.shape, np.inf) a[badness.coords] = badness.data @@ -1666,7 +1684,8 @@ def draw(self, panel, gmm, unit_id): ds_ud = f"{ds_ud:0.3f}".lstrip("0").rstrip("0") ds_udw = f"{ds_udw:0.3f}".lstrip("0").rstrip("0") mstr = "masked " if self.masked else "" - row[1].set_title(f"{mstr} u{ds_ud} uw{ds_udw}", fontsize=7) + istr = "imp " if self.impute_missing else "" + row[1].set_title(f"{mstr}{istr} u{ds_ud} uw{ds_udw}", fontsize=7) if self.do_reg: sns.regplot( @@ -1864,7 +1883,7 @@ def draw(self, panel, gmm, unit_id): unit.bar(axes[0, 1], alags, acg, fill=True, fc=colors[0]) axes[0, 0].axis("off") axes[0, 1].set_ylabel(f"my acg {neighbor_ids[0]}") - + j = 0 for j, ub in enumerate(neighbor_ids[1:], start=1): their_st = gmm.data.times_samples[gmm.labels == ub] @@ -1893,7 +1912,7 @@ def draw(self, panel, gmm, unit_id): # HDBScanSplitPlot(spike_kind="residual_full"), # HDBScanSplitPlot(), # ZipperSplitPlot(), - KMeansPPSPlitPlot(), + KMeansPPSPlitPlot(inherit_chans=True, n_clust=10, impute_before_center=True), GridMeansSingleChanPlot(), InputWaveformsSingleChanPlot(), # InputWaveformsSingleChanOverTimePlot(channel="unit"), @@ -1903,7 +1922,7 @@ def draw(self, panel, gmm, unit_id): BadnessesOverTimePlot(), EmbedsOverTimePlot(), # DPCSplitPlot(spike_kind="residual_full"), - DPCSplitPlot(spike_kind="train"), + DPCSplitPlot(spike_kind="split"), # DPCSplitPlot(spike_kind="global"), # DPCSplitPlot(spike_kind="global", feature="spread_amp"), FeaturesVsBadnessesPlot(), @@ -1928,8 +1947,8 @@ def draw(self, panel, gmm, unit_id): ISICorner(bin_ms=0.25), ISICorner(bin_ms=0.5, max_ms=8, tick_step=2), # NeighborBimodality(), - NeighborBimodality(badness_kind="diagz", masked=True), NeighborBimodality(badness_kind="1-r^2", masked=True), + NeighborBimodality(badness_kind="1-r^2", masked=True, impute_missing=True), CCGColumn(), # NeighborBimodality(badness_kind="1-scaledr^2", masked=True), ) diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index 25349af4..227d7cf4 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -421,10 +421,10 @@ def draw(self, panel, sorting_analysis, unit_id, axis=None): ) trough_offset_samples = self.trough_offset_samples spike_length_samples = self.spike_length_samples - if tslice.start is not None: + if tslice is not None and tslice.start is not None: trough_offset_samples = self.trough_offset_samples - tslice.start spike_length_samples = self.spike_length_samples - tslice.start - if tslice.stop is not None: + if tslice is not None and tslice.stop is not None: spike_length_samples = tslice.stop - tslice.start max_abs_amp = None