Skip to content

Commit

Permalink
Latest vis
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Sep 3, 2024
1 parent 2bfe1d2 commit 6b51274
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 83 deletions.
181 changes: 100 additions & 81 deletions src/dartsort/vis/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class DPCSplitPlot(GMMPlot):
width = 5
height = 5

def __init__(self, spike_kind="train", feature="pca", inherit_chans=True, common_chans=True, dist_vmax=1., cmap=plt.cm.rainbow):
def __init__(self, spike_kind="split", feature="pca", inherit_chans=True, common_chans=True, dist_vmax=1., cmap=plt.cm.rainbow):
self.spike_kind = spike_kind
assert feature in ("pca", "spread_amp")
self.feature = feature
Expand All @@ -58,7 +58,7 @@ def draw(self, panel, gmm, unit_id):
for sl, data in gmm.batches(in_unit):
gmm[unit_id].residual_embed(**data, out=features[sl])
z = features[:, : gmm.dpc_split_kw.rank].numpy(force=True)
elif self.spike_kind == "train":
elif self.spike_kind == "split":
_, in_unit, z = gmm.split_features(unit_id)
elif self.spike_kind == "global":
in_unit, data = gmm.get_training_data(unit_id)
Expand All @@ -76,7 +76,7 @@ def draw(self, panel, gmm, unit_id):
)
z = loadings.numpy(force=True)
elif self.feature == "spread_amp":
assert self.spike_kind in ("train", "global")
assert self.spike_kind in ("split", "global")
in_unit, data = gmm.get_training_data(unit_id)
waveforms = data["waveforms"]
channel_norms = torch.sqrt(torch.nan_to_num(waveforms.square().sum(1)))
Expand Down Expand Up @@ -124,8 +124,8 @@ def draw(self, panel, gmm, unit_id):
axes[0].set_xlabel("spread")
axes[0].set_ylabel("amp")

axes = panel_bottom.subplots(ncols=2)
axes = {'d': axes[0], 'e': axes[1]}
axes = panel_bottom.subplots(ncols=3)
axes = {'d': axes[0], 'f': axes[1], 'e': axes[2]}
labels = dens['labels'][inv]
in_unit = torch.from_numpy(in_unit)
ids = np.unique(labels)
Expand Down Expand Up @@ -160,36 +160,38 @@ def draw(self, panel, gmm, unit_id):
ju = [(j, u) for j, u in enumerate(new_units) if u.n_chans_unit]

# plot new unit maxchan wfs and old one in black
ax = axes["d"]
ax.axhline(0, c="k", lw=0.8)
all_means = []
for j, unit in ju:
if unit.do_interp:
times = unit.interp.grid.squeeze()
if self.fitted_only:
times = times[unit.interp.grid_fitted]
else:
times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device)
times = torch.atleast_1d(times)

chans = torch.full((times.numel(),), unit.max_channel, device=times.device)
means = unit.get_means(times).to(gmm.device)
if j > 0:
all_means.append(means.mean(0))
means = unit.to_waveform_channels(means, waveform_channels=chans[:, None])
means = means[..., 0]
means = gmm.data.tpca._inverse_transform_in_probe(means)
means = means.numpy(force=True)
color = glasbey1024[j]

lines = np.stack(
(np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means),
axis=-1,
)
ax.add_collection(LineCollection(lines, colors=color, lw=1))
ax.autoscale_view()
ax.set_xticks([])
ax.spines[["top", "right", "bottom"]].set_visible(False)
gmc = gmm[unit_id].max_channel
for ax, pick in zip((axes["d"], axes["f"]), ("unit", "shared")):
ax.axhline(0, c="k", lw=0.8)
all_means = []
for j, unit in ju:
if unit.do_interp:
times = unit.interp.grid.squeeze()
if self.fitted_only:
times = times[unit.interp.grid_fitted]
else:
times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device)
times = torch.atleast_1d(times)

chans = torch.full((times.numel(),), unit.max_channel if pick == "unit" else gmc, device=times.device)
means = unit.get_means(times).to(gmm.device)
if j > 0:
all_means.append(means.mean(0))
means = unit.to_waveform_channels(means, waveform_channels=chans[:, None])
means = means[..., 0]
means = gmm.data.tpca._inverse_transform_in_probe(means)
means = means.numpy(force=True)
color = glasbey1024[j]

lines = np.stack(
(np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means),
axis=-1,
)
ax.add_collection(LineCollection(lines, colors=color, lw=1))
ax.autoscale_view()
ax.set_xticks([])
ax.spines[["top", "right", "bottom"]].set_visible(False)
ax.set_title(pick)

# plot distance matrix
kind = gmm.merge_metric
Expand Down Expand Up @@ -424,6 +426,7 @@ def __init__(
n_iter=50,
common_chans=True,
inherit_chans=True,
impute_before_center=False,
min_overlap=0.0,
):
self.cmap = cmap
Expand All @@ -439,6 +442,7 @@ def __init__(
self.n_iter = n_iter
self.common_chans = common_chans
self.inherit_chans = inherit_chans
self.impute_before_center = impute_before_center
self.min_overlap = min_overlap
self.merge_on_waveform_radius = merge_on_waveform_radius
if self.scaled:
Expand All @@ -458,7 +462,7 @@ def draw(self, panel, gmm, unit_id):
if ids.size > 1:
top, bottom = panel.subfigures(nrows=2)
ax_top = top.subplots()
axes = bottom.subplot_mosaic("de")
axes = bottom.subplot_mosaic("dfe")
else:
ax_top = panel.subplots()

Expand All @@ -482,7 +486,7 @@ def draw(self, panel, gmm, unit_id):
for j, label in enumerate(ids):
u = spike_interp.InterpUnit(
do_interp=False,
**gmm.unit_kw,
**gmm.unit_kw | dict(impute_before_center=self.impute_before_center),
)
inu = in_unit[np.flatnonzero(labels == label)]
w = None if weights is None else weights[labels == label, j]
Expand All @@ -504,36 +508,38 @@ def draw(self, panel, gmm, unit_id):
ju = [(j, u) for j, u in enumerate(new_units) if u.n_chans_unit]

# plot new unit maxchan wfs and old one in black
ax = axes["d"]
ax.axhline(0, c="k", lw=0.8)
all_means = []
for j, unit in ju:
if unit.do_interp:
times = unit.interp.grid.squeeze()
if self.fitted_only:
times = times[unit.interp.grid_fitted]
else:
times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device)
times = torch.atleast_1d(times)

chans = torch.full((times.numel(),), unit.max_channel, device=times.device)
means = unit.get_means(times).to(gmm.device)
if j > 0:
all_means.append(means.mean(0))
means = unit.to_waveform_channels(means, waveform_channels=chans[:, None])
means = means[..., 0]
means = gmm.data.tpca._inverse_transform_in_probe(means)
means = means.numpy(force=True)
color = glasbey1024[j]

lines = np.stack(
(np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means),
axis=-1,
)
ax.add_collection(LineCollection(lines, colors=color, lw=1))
ax.autoscale_view()
ax.set_xticks([])
ax.spines[["top", "right", "bottom"]].set_visible(False)
gmc = gmm[unit_id].max_channel
for ax, pick in zip((axes["d"], axes["f"]), ("unit", "shared")):
ax.axhline(0, c="k", lw=0.8)
all_means = []
for j, unit in ju:
if unit.do_interp:
times = unit.interp.grid.squeeze()
if self.fitted_only:
times = times[unit.interp.grid_fitted]
else:
times = torch.tensor([sum(gmm.t_bounds) / 2]).to(gmm.device)
times = torch.atleast_1d(times)

chans = torch.full((times.numel(),), unit.max_channel if pick == "unit" else gmc, device=times.device)
means = unit.get_means(times).to(gmm.device)
if j > 0:
all_means.append(means.mean(0))
means = unit.to_waveform_channels(means, waveform_channels=chans[:, None])
means = means[..., 0]
means = gmm.data.tpca._inverse_transform_in_probe(means)
means = means.numpy(force=True)
color = glasbey1024[j]

lines = np.stack(
(np.broadcast_to(np.arange(means.shape[1])[None], means.shape), means),
axis=-1,
)
ax.add_collection(LineCollection(lines, colors=color, lw=1))
ax.autoscale_view()
ax.set_xticks([])
ax.spines[["top", "right", "bottom"]].set_visible(False)
ax.set_title(pick)

# plot distance matrix
kind = gmm.merge_metric
Expand Down Expand Up @@ -597,7 +603,7 @@ class HDBScanSplitPlot(GMMPlot):
width = 2
height = 2

def __init__(self, spike_kind="train"):
def __init__(self, spike_kind="split"):
self.spike_kind = spike_kind

def draw(self, panel, gmm, unit_id):
Expand All @@ -609,7 +615,7 @@ def draw(self, panel, gmm, unit_id):
for sl, data in gmm.batches(in_unit):
gmm[unit_id].residual_embed(**data, out=features[sl])
z = features[:, : gmm.dpc_split_kw.rank].numpy(force=True)
elif self.spike_kind == "train":
elif self.spike_kind == "split":
_, in_unit, z = gmm.split_features(unit_id)
elif self.spike_kind == "global":
in_unit, data = gmm.get_training_data(unit_id)
Expand Down Expand Up @@ -694,7 +700,7 @@ class AmplitudesOverTimePlot(GMMPlot):
width = 5
height = 3

def __init__(self, kinds=("recon", 'model'), colors="bkrg"):
def __init__(self, kinds=("recon", "maxchan_energy", 'model'), colors="bkrg"):
self.kinds = kinds
self.colors = dict(zip(kinds, colors))

Expand Down Expand Up @@ -734,6 +740,7 @@ def draw(self, panel, gmm, unit_id):
t = utd["times"].numpy(force=True)
for kind, a in show.items():
ax.scatter(t, a, c=self.colors[kind], s=3, lw=0, label=kind)
ax.legend(loc="upper left")
ax.set_ylabel("amplitude")


Expand Down Expand Up @@ -788,19 +795,28 @@ def draw(self, panel, gmm, unit_id):
overlaps = overlaps.numpy(force=True)
times = utd["times"][spike_ix].numpy(force=True)

ax = panel.subplots()
ax, ay = panel.subplots(ncols=2, width_ratios=[2, 1])
ay.grid(True)
for j, (kind, b) in enumerate(badnesses.items()):
b = b.numpy(force=True)
c = self.colors[j]
ax.scatter(
times,
b.numpy(force=True),
b,
alpha=overlaps,
s=3,
c=self.colors[j],
c=c,
lw=0,
label=kind,
)
ay.ecdf(b, lw=1, color=c)
ay.axvline(np.mean(b), color=c, lw=1)
ay.axvline(np.median(b), color=c, ls="--", lw=1)
ay.set_xlabel(kind)
ay.set_ylabel("cdf")
ax.legend(loc="upper left")
ax.set_ylabel("badness")
ay.set_yticks([0, 0.25, 0.5, 0.75, 1])


class FeaturesVsBadnessesPlot(GMMPlot):
Expand Down Expand Up @@ -1352,9 +1368,9 @@ def __init__(self, n_neighbors=5):
self.n_neighbors = n_neighbors

def get_neighbors(self, gmm, unit_id, reversed=False):
unit_dists = gmm.central_divergences(units_a=[unit_id])[0]
unit_dists = gmm.central_divergences(units_a=torch.tensor([unit_id]))[0]
unit_ids = gmm.unit_ids()
neighbors = torch.argsort(unit_dists)
neighbors = torch.argsort((unit_ids != unit_id).to(unit_dists) + unit_dists)
assert unit_ids[neighbors[0]] == unit_id
neighbors = neighbors[: self.n_neighbors + 1]
neighbors = neighbors[torch.isfinite(unit_dists[neighbors])]
Expand Down Expand Up @@ -1561,12 +1577,13 @@ class NeighborBimodality(GMMMergePlot):
width = 3
height = 10

def __init__(self, n_neighbors=5, badness_kind="1-r^2", do_reg=False, masked=False, mask_radius_s=5.0):
def __init__(self, n_neighbors=5, badness_kind="1-r^2", do_reg=False, masked=False, mask_radius_s=5.0, impute_missing=False):
self.n_neighbors = n_neighbors
self.badness_kind = badness_kind
self.do_reg = do_reg
self.masked = masked
self.mask_radius_s = mask_radius_s
self.impute_missing = impute_missing

def draw(self, panel, gmm, unit_id):
from isosplit import isocut, dipscore_at
Expand Down Expand Up @@ -1625,6 +1642,7 @@ def draw(self, panel, gmm, unit_id):
unit_ids=[unit_id, u],
show_progress=False,
kind=self.badness_kind,
impute_missing=self.impute_missing,
)
a = np.full(badness.shape, np.inf)
a[badness.coords] = badness.data
Expand Down Expand Up @@ -1666,7 +1684,8 @@ def draw(self, panel, gmm, unit_id):
ds_ud = f"{ds_ud:0.3f}".lstrip("0").rstrip("0")
ds_udw = f"{ds_udw:0.3f}".lstrip("0").rstrip("0")
mstr = "masked " if self.masked else ""
row[1].set_title(f"{mstr} u{ds_ud} uw{ds_udw}", fontsize=7)
istr = "imp " if self.impute_missing else ""
row[1].set_title(f"{mstr}{istr} u{ds_ud} uw{ds_udw}", fontsize=7)

if self.do_reg:
sns.regplot(
Expand Down Expand Up @@ -1864,7 +1883,7 @@ def draw(self, panel, gmm, unit_id):
unit.bar(axes[0, 1], alags, acg, fill=True, fc=colors[0])
axes[0, 0].axis("off")
axes[0, 1].set_ylabel(f"my acg {neighbor_ids[0]}")

j = 0
for j, ub in enumerate(neighbor_ids[1:], start=1):
their_st = gmm.data.times_samples[gmm.labels == ub]
Expand Down Expand Up @@ -1893,7 +1912,7 @@ def draw(self, panel, gmm, unit_id):
# HDBScanSplitPlot(spike_kind="residual_full"),
# HDBScanSplitPlot(),
# ZipperSplitPlot(),
KMeansPPSPlitPlot(),
KMeansPPSPlitPlot(inherit_chans=True, n_clust=10, impute_before_center=True),
GridMeansSingleChanPlot(),
InputWaveformsSingleChanPlot(),
# InputWaveformsSingleChanOverTimePlot(channel="unit"),
Expand All @@ -1903,7 +1922,7 @@ def draw(self, panel, gmm, unit_id):
BadnessesOverTimePlot(),
EmbedsOverTimePlot(),
# DPCSplitPlot(spike_kind="residual_full"),
DPCSplitPlot(spike_kind="train"),
DPCSplitPlot(spike_kind="split"),
# DPCSplitPlot(spike_kind="global"),
# DPCSplitPlot(spike_kind="global", feature="spread_amp"),
FeaturesVsBadnessesPlot(),
Expand All @@ -1928,8 +1947,8 @@ def draw(self, panel, gmm, unit_id):
ISICorner(bin_ms=0.25),
ISICorner(bin_ms=0.5, max_ms=8, tick_step=2),
# NeighborBimodality(),
NeighborBimodality(badness_kind="diagz", masked=True),
NeighborBimodality(badness_kind="1-r^2", masked=True),
NeighborBimodality(badness_kind="1-r^2", masked=True, impute_missing=True),
CCGColumn(),
# NeighborBimodality(badness_kind="1-scaledr^2", masked=True),
)
Expand Down
4 changes: 2 additions & 2 deletions src/dartsort/vis/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,10 +421,10 @@ def draw(self, panel, sorting_analysis, unit_id, axis=None):
)
trough_offset_samples = self.trough_offset_samples
spike_length_samples = self.spike_length_samples
if tslice.start is not None:
if tslice is not None and tslice.start is not None:
trough_offset_samples = self.trough_offset_samples - tslice.start
spike_length_samples = self.spike_length_samples - tslice.start
if tslice.stop is not None:
if tslice is not None and tslice.stop is not None:
spike_length_samples = tslice.stop - tslice.start

max_abs_amp = None
Expand Down

0 comments on commit 6b51274

Please sign in to comment.