Skip to content

Commit

Permalink
Work on cov vis
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 7, 2024
1 parent 07d095b commit 6fe2005
Showing 1 changed file with 58 additions and 13 deletions.
71 changes: 58 additions & 13 deletions src/dartsort/vis/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..cluster import gaussian_mixture
from ..util.multiprocessing_util import (CloudpicklePoolExecutor,
ThreadPoolExecutor, get_pool, cloudpickle)
from ..util import spiketorch
from . import analysis_plots, gmm_helpers, layout
from .colors import glasbey1024
from .waveforms import geomplot
Expand Down Expand Up @@ -97,36 +98,49 @@ def draw(self, panel, gmm, unit_id):

class MStep(GMMPlot):
kind = "waveform"
width = 4
height = 5
width = 5
height = 9
alpha = 0.05
n_show = 64

def draw(self, panel, gmm, unit_id, axes=None):
ax = panel.subplots()
panel_top, panel_bottom = panel.subfigures(nrows=2, height_ratios=[1.5, 1])
ax = panel_top.subplots()
ax.axis("off")

# panel_bottom, panel_cbar = panel_bottom.subfigures(ncols=2, width_ratios=[5, 0.5])
cov_axes = panel_bottom.subplots(
nrows=2, ncols=2, sharey=True, sharex=True
)
# cax = panel_cbar.add_subplot(3, 1, 2)

# get spike data and determine channel set by plotting
sp = gmm.random_spike_data(unit_id, max_size=self.n_show, with_reconstructions=True)
maa = sp.waveforms.abs().nan_to_num().max()
geomplot_kw = dict(
max_abs_amp=maa,
geom=gmm.data.prgeom.numpy(force=True),
show_zero=False,
return_chans=True,
)
lines, chans = geomplot(
sp.waveforms,
channels=sp.channels,
geom=gmm.data.prgeom.numpy(force=True),
max_abs_amp=maa,
color="k",
alpha=self.alpha,
return_chans=True,
show_zero=False,
ax=ax,
**geomplot_kw,
)
chans = torch.tensor(list(chans))
tup = gaussian_mixture.to_full_probe(
sp, weights=None, n_channels=gmm.data.n_channels, storage=None
)
features_full, weights_full, count_data, weights_normalized = tup
emp_mean = torch.nanmean(features_full, dim=0)[:, chans]
print(f"{features_full.shape=}")
feats = features_full[:, :, chans]
n, r, c = feats.shape
emp_mean = torch.nanmean(feats, dim=0)
emp_mean = gmm.data.tpca.force_reconstruct(emp_mean.nan_to_num_())

model_mean = gmm.units[unit_id].mean[:, chans]
model_mean = gmm.data.tpca.force_reconstruct(model_mean)

Expand All @@ -142,6 +156,37 @@ def draw(self, panel, gmm, unit_id, axes=None):
ax.axis("off")
ax.set_title("reconstructed mean and example inputs")

# covariance vis
feats = features_full[:, :, gmm.units[unit_id].channels]
model_mean = gmm.units[unit_id].mean[:, gmm.units[unit_id].channels]
n, r, c = feats.shape
emp_cov, nobs = spiketorch.nancov(feats.view(n, r * c), return_nobs=True)
denom = nobs + gmm.units[unit_id].prior_pseudocount
emp_cov = (nobs / denom) * emp_cov
noise_cov = gmm.noise.marginal_covariance(channels=gmm.units[unit_id].channels).to_dense()
m = model_mean.abs().reshape(-1)
mmt = m[:, None] @ m[None, :]
covs = (emp_cov, noise_cov, mmt)
vmax = max(c.abs().max() for c in covs)
names = ("regemp", "noise", "|temptempT|")
print(f"{feats.shape=} {gmm.units[unit_id].channels.shape=}")
print(f"{vmax=}")
print(f"{emp_cov.abs().max()=}")
print(f"{noise_cov.abs().max()=}")
print(f"{mmt.abs().max()=}")
print(f"{emp_cov.shape=}")
print(f"{noise_cov.shape=}")
print(f"{mmt.shape=}")

for ax, cov, name in zip(cov_axes.flat, covs, names):
vmax = cov.abs().triu(diagonal=1)
vmax = vmax[vmax>0].quantile(.95)
im = ax.imshow(cov.numpy(force=True), vmin=-vmax, vmax=vmax, cmap=plt.cm.seismic)
ax.axis("off")
ax.set_title(name, fontsize="small")
plt.colorbar(im, ax=ax, shrink=0.5)
# plt.colorbar(im, cax=cax, shrink=0.5)


class Likelihoods(GMMPlot):
kind = "widescatter"
Expand Down Expand Up @@ -324,8 +369,8 @@ def draw(self, panel, gmm, unit_id, split_info=None):

class NeighborMeans(GMMPlot):
kind = "merge"
width = 3
height = 4
width = 4
height = 3

def __init__(self, n_neighbors=5):
self.n_neighbors = n_neighbors
Expand All @@ -348,8 +393,8 @@ def draw(self, panel, gmm, unit_id):

class NeighborDistances(GMMPlot):
kind = "merge"
width = 3
height = 3
width = 4
height = 2

def __init__(self, n_neighbors=5, dist_vmax=1.0):
self.n_neighbors = n_neighbors
Expand Down

0 comments on commit 6fe2005

Please sign in to comment.