From 2f5358db18d7aeebb1e76d4857cc671f3bb48472 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 21 Nov 2024 11:49:36 -0800 Subject: [PATCH] Debug vis edge cases --- src/dartsort/vis/gmm.py | 37 ++++++++++++++++++++++++++++--------- src/dartsort/vis/unit.py | 3 ++- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/dartsort/vis/gmm.py b/src/dartsort/vis/gmm.py index 262f9f07..3c978a51 100644 --- a/src/dartsort/vis/gmm.py +++ b/src/dartsort/vis/gmm.py @@ -1,5 +1,6 @@ from pathlib import Path import warnings +import itertools import matplotlib.pyplot as plt import numpy as np @@ -102,9 +103,21 @@ def draw(self, panel, gmm, unit_id): k = k.item() if torch.is_tensor(v): v = v.numpy(force=True) - if v.size == 1: + if isinstance(v, np.ndarray): + if not v.size: + v = "[]" + elif v.size == 1: v = v.item() - msg += f"{k}: {v}" + elif v.ndim == 1: + vv = [str(v[0])] + for vvv in map(str, v[1:]): + if len(vv[-1]) > 16: + vv[-1] += "\n" + vv.append(vvv) + continue + vv[-1] += "," + vvv + v = "\n".join(vv) + msg += f"{k}:\n{v}" axis.text(0, 0, msg, fontsize=6.5) @@ -581,7 +594,7 @@ def draw(self, panel, gmm, unit_id): class NeighborBimodalities(GMMPlot): kind = "bim" width = 4 - height = 8 + height = 9 def __init__(self, n_neighbors=5): self.n_neighbors = n_neighbors @@ -627,6 +640,7 @@ def draw(self, panel, gmm, unit_id): scatter_ax.text( 0.5, 0.5, f"too few spikes", transform=scatter_ax.transAxes, ha='center', va='center' ) + continue else: c = np.atleast_2d(glasbey1024[labels[bimod_info["in_pair_kept"]]]) scatter_ax.scatter(bimod_info["xi"], bimod_info["xj"], s=3, lw=0, c=c) @@ -677,7 +691,7 @@ def draw(self, panel, gmm, unit_id): class NeighborInfoCriteria(GMMPlot): kind = "bim" width = 4 - height = 8 + height = 9 def __init__(self, n_neighbors=5, fit_by_avg=False): self.n_neighbors = n_neighbors @@ -693,6 +707,7 @@ def draw(self, panel, gmm, unit_id): bstr = "BICfull/merged: {bic_full:0.1f} / {bic_merged:0.1f}" lstr = "LLfull/merged: {full_loglik:0.1f} / {unit_loglik:0.1f}" cstr = f"{astr}\n{bstr}\n{lstr}\n" + bbox = dict(facecolor='w', alpha=0.5, edgecolor="none") for ax, other_id in zip(axes, others): uids = [unit_id, other_id] res = gmm.unit_group_criterion(uids, gmm.log_liks, debug=True) @@ -708,15 +723,19 @@ def draw(self, panel, gmm, unit_id): if rowsll.numel(): ax.hist(rowsll, color=glasbey1024[uid], **histkw) ull = res["unit_logliks"] - ax.hist(ull[torch.isfinite(ull)], color="k", **histkw) + if ull is not None: + ax.hist(ull[torch.isfinite(ull)], color="k", **histkw) s = f"other={other_id}\n" + cstr.format_map(res) aic_merge = res['aic_merged'] < res['aic_full'] bic_merge = res['bic_merged'] < res['bic_full'] ll_merge = res['unit_loglik'] > res['full_loglik'] - s += f"aic: " + ("merge!" if aic_merge else "nope.") + "\n" - s += f"bic: " + ("merge!" if bic_merge else "nope.") + "\n" - s += f"ll: " + ("merge!" if ll_merge else "nope.") - ax.text(0.05, 0.95, s, transform=ax.transAxes) + aicdif = res['aic_full'] - res['aic_merged'] + bicdif = res['bic_full'] - res['bic_merged'] + lldif = res['full_loglik'] - res['unit_loglik'] + s += f"aic: {aicdif:0.1f}, " + ("merge!" if aic_merge else "nope.") + "\n" + s += f"bic: {bicdif:0.1f}, " + ("merge!" if bic_merge else "nope.") + "\n" + s += f"ll: {lldif:0.1f}, " + ("merge!" if ll_merge else "nope.") + ax.text(0.05, 0.95, s, transform=ax.transAxes, va="top", bbox=bbox) ax.set_xlabel("log lik") sns.despine(ax=ax, left=True, right=True, top=True) # ax.set_yticks([]) diff --git a/src/dartsort/vis/unit.py b/src/dartsort/vis/unit.py index 310c20ea..5325d941 100644 --- a/src/dartsort/vis/unit.py +++ b/src/dartsort/vis/unit.py @@ -51,7 +51,8 @@ def draw(self, panel, sorting_analysis, unit_id): axis.axis("off") msg = f"unit {unit_id}\n" - msg += f"feature source: {sorting_analysis.hdf5_path.name}\n" + if getattr(sorting_analysis, 'hdf5_path', None): + msg += f"feature source: {sorting_analysis.hdf5_path.name}\n" nspikes = sorting_analysis.spike_counts[ sorting_analysis.unit_ids == unit_id