Skip to content

Commit

Permalink
Debug vis edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 21, 2024
1 parent 0944c79 commit 2f5358d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
37 changes: 28 additions & 9 deletions src/dartsort/vis/gmm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
import warnings
import itertools

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -102,9 +103,21 @@ def draw(self, panel, gmm, unit_id):
k = k.item()
if torch.is_tensor(v):
v = v.numpy(force=True)
if v.size == 1:
if isinstance(v, np.ndarray):
if not v.size:
v = "[]"
elif v.size == 1:
v = v.item()
msg += f"{k}: {v}"
elif v.ndim == 1:
vv = [str(v[0])]
for vvv in map(str, v[1:]):
if len(vv[-1]) > 16:
vv[-1] += "\n"
vv.append(vvv)
continue
vv[-1] += "," + vvv
v = "\n".join(vv)
msg += f"{k}:\n{v}"

axis.text(0, 0, msg, fontsize=6.5)

Expand Down Expand Up @@ -581,7 +594,7 @@ def draw(self, panel, gmm, unit_id):
class NeighborBimodalities(GMMPlot):
kind = "bim"
width = 4
height = 8
height = 9

def __init__(self, n_neighbors=5):
self.n_neighbors = n_neighbors
Expand Down Expand Up @@ -627,6 +640,7 @@ def draw(self, panel, gmm, unit_id):
scatter_ax.text(
0.5, 0.5, f"too few spikes", transform=scatter_ax.transAxes, ha='center', va='center'
)
continue
else:
c = np.atleast_2d(glasbey1024[labels[bimod_info["in_pair_kept"]]])
scatter_ax.scatter(bimod_info["xi"], bimod_info["xj"], s=3, lw=0, c=c)
Expand Down Expand Up @@ -677,7 +691,7 @@ def draw(self, panel, gmm, unit_id):
class NeighborInfoCriteria(GMMPlot):
kind = "bim"
width = 4
height = 8
height = 9

def __init__(self, n_neighbors=5, fit_by_avg=False):
self.n_neighbors = n_neighbors
Expand All @@ -693,6 +707,7 @@ def draw(self, panel, gmm, unit_id):
bstr = "BICfull/merged: {bic_full:0.1f} / {bic_merged:0.1f}"
lstr = "LLfull/merged: {full_loglik:0.1f} / {unit_loglik:0.1f}"
cstr = f"{astr}\n{bstr}\n{lstr}\n"
bbox = dict(facecolor='w', alpha=0.5, edgecolor="none")
for ax, other_id in zip(axes, others):
uids = [unit_id, other_id]
res = gmm.unit_group_criterion(uids, gmm.log_liks, debug=True)
Expand All @@ -708,15 +723,19 @@ def draw(self, panel, gmm, unit_id):
if rowsll.numel():
ax.hist(rowsll, color=glasbey1024[uid], **histkw)
ull = res["unit_logliks"]
ax.hist(ull[torch.isfinite(ull)], color="k", **histkw)
if ull is not None:
ax.hist(ull[torch.isfinite(ull)], color="k", **histkw)
s = f"other={other_id}\n" + cstr.format_map(res)
aic_merge = res['aic_merged'] < res['aic_full']
bic_merge = res['bic_merged'] < res['bic_full']
ll_merge = res['unit_loglik'] > res['full_loglik']
s += f"aic: " + ("merge!" if aic_merge else "nope.") + "\n"
s += f"bic: " + ("merge!" if bic_merge else "nope.") + "\n"
s += f"ll: " + ("merge!" if ll_merge else "nope.")
ax.text(0.05, 0.95, s, transform=ax.transAxes)
aicdif = res['aic_full'] - res['aic_merged']
bicdif = res['bic_full'] - res['bic_merged']
lldif = res['full_loglik'] - res['unit_loglik']
s += f"aic: {aicdif:0.1f}, " + ("merge!" if aic_merge else "nope.") + "\n"
s += f"bic: {bicdif:0.1f}, " + ("merge!" if bic_merge else "nope.") + "\n"
s += f"ll: {lldif:0.1f}, " + ("merge!" if ll_merge else "nope.")
ax.text(0.05, 0.95, s, transform=ax.transAxes, va="top", bbox=bbox)
ax.set_xlabel("log lik")
sns.despine(ax=ax, left=True, right=True, top=True)
# ax.set_yticks([])
Expand Down
3 changes: 2 additions & 1 deletion src/dartsort/vis/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def draw(self, panel, sorting_analysis, unit_id):
axis.axis("off")
msg = f"unit {unit_id}\n"

msg += f"feature source: {sorting_analysis.hdf5_path.name}\n"
if getattr(sorting_analysis, 'hdf5_path', None):
msg += f"feature source: {sorting_analysis.hdf5_path.name}\n"

nspikes = sorting_analysis.spike_counts[
sorting_analysis.unit_ids == unit_id
Expand Down

0 comments on commit 2f5358d

Please sign in to comment.