Re-do the GMM visualization for the new code
cwindolf committed Oct 23, 2024
1 parent 1f1fc74 commit 2e502f7
Showing 9 changed files with 532 additions and 2,216 deletions.
208 changes: 159 additions & 49 deletions src/dartsort/cluster/gaussian_mixture.py
@@ -134,6 +134,10 @@ def em(self, n_iter=None, show_progress=True):
self.m_step()

for _ in its:
# having this cleanup here doesn't do anything except make the
# reassignment counts make sense
self.cleanup(clean_units=True, min_count=1)

# E step: get responsibilities and update hard assignments
log_liks = self.log_likelihoods(show_progress=show_progress > 1)
reas_count = self.reassign(log_liks)
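The cleanup call above runs before each E step so the reassignment counts refer only to surviving units. A minimal sketch of one iteration, assuming gmm is an instance of this mixture class (the M step argument follows the signature below; the exact loop details are collapsed here):

    gmm.cleanup(clean_units=True, min_count=1)           # drop too-small units first
    log_liks = gmm.log_likelihoods(show_progress=False)  # E step inputs
    reas_count = gmm.reassign(log_liks)                  # update hard assignments
    gmm.m_step(log_liks)                                 # refit unit parameters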
@@ -171,14 +175,18 @@ def m_step(self, likelihoods=None, show_progress=False):
self.units.append(unit)

def log_likelihoods(
self, with_noise_unit=True, use_storage=True, show_progress=False
self, unit_ids=None, with_noise_unit=True, use_storage=True, show_progress=False
):
"""Noise unit last so that rows correspond to unit ids without 1 offset"""
if unit_ids is None:
unit_ids = range(len(self.units))

# determine how much storage space we need by figuring out how many spikes
# are overlapping with each unit
neighb_info = []
ns_total = 0
for j, unit in enumerate(self.units):
for j in unit_ids:
unit = self.units[j]
neighbs, ns_unit = self.data.core_neighborhoods.subset_neighborhoods(
unit.channels
)
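The new unit_ids argument restricts the likelihood computation to a subset of units, which the split and merge helpers below rely on. A hedged usage sketch, with gmm as above:

    # likelihoods for units 0 and 3 only; None (the default) covers all units
    log_liks = gmm.log_likelihoods(unit_ids=[0, 3], show_progress=False)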
@@ -236,7 +244,7 @@ def reassign(self, log_liks):
self.labels.copy_(assignments)
return reassign_count

def cleanup(self, log_liks=None, clean_units=True):
def cleanup(self, log_liks=None, clean_units=True, min_count=None):
"""Remove too-small units
Also handles bookkeeping to throw those units out of the sparse
@@ -247,7 +255,9 @@ def cleanup(self, log_liks=None, clean_units=True, min_count=None):
unit_ids, counts = torch.unique(self.labels, return_counts=True)
counts = counts[unit_ids >= 0]
unit_ids = unit_ids[unit_ids >= 0]
big_enough = counts > self.min_count
if min_count is None:
min_count = self.min_count
big_enough = counts >= min_count
if big_enough.all():
return log_liks
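Note that the size test is now inclusive (counts >= min_count), so min_count=1 keeps every unit that owns at least one spike, matching the EM loop above. A hedged sketch:

    # min_count=None falls back to self.min_count
    log_liks = gmm.cleanup(log_liks, clean_units=True, min_count=1)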

@@ -312,19 +322,10 @@ def distances(

if units is None:
units = self.units
nu = len(units)

# stack unit data into one place
nu, rank, nc = len(units), self.data.rank, self.data.n_channels
means = self.noise_unit.mean.new_zeros((nu, rank, nc))
covs = logdets = None
if kind in ("kl_divergence",):
covs = means.new_zeros((nu, rank * nc, rank * nc))
logdets = means.new_zeros((nu,))
for j, unit in enumerate(units):
means[j] = unit.mean
if covs is not None:
covs[j] = unit.dense_cov()
logdets[j] = unit.logdet
means, covs, logdets = self.stack_units()

# compute denominator of noised normalized distances
if noise_normalized:
@@ -380,32 +381,17 @@ def bimodalities(

@delayed
def bimod_job(i, j):
if not compute_mask[i, j]:
return
ini = in_units[i]
inj = in_units[j]
if masked:
times_i = self.data.times_seconds[ini]
times_j = self.data.times_seconds[inj]
ini = ini[getdt(times_j, times_i) <= dt_s]
inj = inj[getdt(times_i, times_j) <= dt_s]
ini = shrinkfit(ini, max_spikes, self.rg)
inj = shrinkfit(inj, max_spikes, self.rg)

in_pair = torch.concatenate((ini, inj))
ijlabels = torch.zeros(in_pair.shape, dtype=bool)
ijlabels[ini.numel() :] = 1
in_pair, order = in_pair.sort()
ijlabels = ijlabels[order]

log_lik_diff = get_diff_sparse(log_liks, i, j, in_pair)
scores[i, j] = scores[j, i] = qda(
ijlabels.numpy(force=True),
diff=log_lik_diff,
scores[i, j] = scores[j, i] = self.unit_pair_bimodality(
id_a=i,
id_b=j,
log_liks=log_liks,
cut=cut,
weighted=weighted,
min_overlap=min_overlap,
score_kind="tv",
in_units=in_units,
masked=masked,
max_spikes=max_spikes,
dt_s=dt_s,
)

        if compute_mask is None:
@@ -519,10 +505,14 @@ def unit_log_likelihoods(
Expand Down Expand Up @@ -519,10 +505,14 @@ def unit_log_likelihoods(
for neighb_id, (neighb_chans, neighb_member_ixs) in jobs:
if inds_already:
sp = self.data.spike_data(
spike_indices[neighb_member_ixs], with_channels=False, neighborhood="core"
spike_indices[neighb_member_ixs],
with_channels=False,
neighborhood="core",
)
else:
sp = self.data.spike_data(neighb_member_ixs, with_channels=False, neighborhood="core")
sp = self.data.spike_data(
neighb_member_ixs, with_channels=False, neighborhood="core"
)
features = sp.features
chans_valid = neighb_chans < self.data.n_channels
features = features[..., chans_valid]
@@ -549,12 +539,14 @@ def noise_log_likelihoods(self):
)
return self._noise_six, self._noise_log_likelihoods

def kmeans_split_unit(self, unit_id):
def kmeans_split_unit(self, unit_id, debug=False):
# get spike data and use interpolation to fill it out to the
# unit's channel set
unit = self.units[unit_id]
indices_full, sp = self.random_spike_data(unit_id, return_full_indices=True)
X = self.data.interp_to_chans(sp, unit)
if debug:
debug_info = dict(indices_full=indices_full, sp=sp, X=X)

# run kmeans with kmeans++ initialization
split_labels, responsibilities = kmeans(
@@ -566,12 +558,23 @@ def kmeans_split_unit(self, unit_id):
with_proportions=self.kmeans_with_proportions,
drop_prop=self.kmeans_drop_prop,
)
if debug:
debug_info["split_labels"] = split_labels
debug_info["responsibilities"] = responsibilities
if split_labels.unique().numel() <= 1:
if debug:
return debug_info
return 0, []
weights = responsibilities / responsibilities.sum(0)

# avoid oversplitting by doing a mini merge here
split_labels, split_ids = self.mini_merge(sp, split_labels, weights)
split_labels, split_ids = self.mini_merge(
sp, split_labels, weights, debug=debug, debug_info=debug_info
)

if debug:
debug_info["merge_labels"] = split_labels
return debug_info

# tack these new units onto the end
with self.labels_lock:
@@ -581,7 +584,9 @@ def kmeans_split_unit(self, unit_id):
self.labels[indices_full] = -1
self.labels[sp.indices] = split_labels

def mini_merge(self, spike_data, labels, weights=None):
def mini_merge(
self, spike_data, labels, weights=None, debug=False, debug_info=None
):
"""Given labels for a small bag of data, fit and merge."""
# fit sub-units
split_ids = labels.unique()
@@ -626,9 +631,81 @@ def mini_merge(self, spike_data, labels, weights=None):
distances,
linkage_method="complete",
)
if debug:
debug_info["units"] = units
debug_info["lls"] = lls
debug_info["bimodalities"] = bimodalities
debug_info["distances"] = distances

return new_labels, new_ids
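The helper that turns bimodalities and distances into new_labels is collapsed above; per the call it agglomerates with complete linkage. One plausible realization of that grouping step, sketched with SciPy (the function name and threshold are illustrative, not from this diff):

    from scipy.cluster.hierarchy import fcluster, linkage
    from scipy.spatial.distance import squareform

    def agglomerate(distances, threshold=1.0, linkage_method="complete"):
        # condense the symmetric distance matrix, then cut the dendrogram
        Z = linkage(squareform(distances, checks=False), method=linkage_method)
        return fcluster(Z, t=threshold, criterion="distance") - 1  # 0-based ids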

def unit_pair_bimodality(
self,
id_a,
id_b,
log_liks,
loglik_ix_a=None,
loglik_ix_b=None,
cut=None,
weighted=True,
min_overlap=0.95,
in_units=None,
masked=True,
max_spikes=2048,
dt_s=2.0,
debug=False,
score_kind="tv",
):
if in_units is not None:
ina = in_units[id_a]
inb = in_units[id_b]
else:
(ina,) = torch.nonzero(self.labels == id_a)
(inb,) = torch.nonzero(self.labels == id_b)

if masked:
times_a = self.data.times_seconds[ina]
times_b = self.data.times_seconds[inb]
ina = ina[getdt(times_b, times_a) <= dt_s]
inb = inb[getdt(times_a, times_b) <= dt_s]

ina = shrinkfit(ina, max_spikes, self.rg)
inb = shrinkfit(inb, max_spikes, self.rg)

in_pair = torch.concatenate((ina, inb))
is_b = torch.zeros(in_pair.shape, dtype=bool)
is_b[ina.numel() :] = 1
in_pair, order = in_pair.sort()
is_b = is_b[order]

loglik_ix_a = id_a if loglik_ix_a is None else loglik_ix_a
loglik_ix_b = id_b if loglik_ix_b is None else loglik_ix_b
log_lik_diff = get_diff_sparse(log_liks, loglik_ix_a, loglik_ix_b, in_pair, return_extra=debug)

debug_info = None
if debug:
log_lik_diff, extra = log_lik_diff
debug_info = {}
debug_info["log_lik_diff"] = log_lik_diff
# adds keys: xi, xj, keep_inds
debug_info.update(extra)
debug_info["in_pair_kept"] = in_pair[extra["keep_inds"]]
# qda adds keys: domain, alternative_density, cut, score, score_kind,
# uni_density, sample_weights, samples

score = qda(
is_b.numpy(force=True),
diff=log_lik_diff,
cut=cut,
weighted=weighted,
min_overlap=min_overlap,
score_kind=score_kind,
debug_info=debug_info,
)
if debug:
return debug_info
return score
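A hedged usage sketch of the new pairwise scorer; with debug=True it returns the diagnostic dict described above instead of the score:

    score = gmm.unit_pair_bimodality(id_a=0, id_b=1, log_liks=log_liks)
    info = gmm.unit_pair_bimodality(id_a=0, id_b=1, log_liks=log_liks, debug=True)
    # info holds log_lik_diff, xi, xj, keep_inds, in_pair_kept, plus the qda keys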

# -- gizmos

@property
Expand Down Expand Up @@ -672,6 +749,23 @@ def relabel_units(self, old_labels, new_labels=None, flat=False):
self.labels[kept] = new_labels.to(self.labels)[label_indices]
self.labels[torch.logical_not(kept)] = -1

def stack_units(self, units=None, distance_metric=None):
if units is None:
units = self.units
        if distance_metric is None:
            distance_metric = self.distance_metric
        kind = distance_metric
nu, rank, nc = len(units), self.data.rank, self.data.n_channels
means = self.noise_unit.mean.new_zeros((nu, rank, nc))
covs = logdets = None
if kind in ("kl_divergence",):
covs = means.new_zeros((nu, rank * nc, rank * nc))
logdets = means.new_zeros((nu,))
for j, unit in enumerate(units):
means[j] = unit.mean
if covs is not None:
covs[j] = unit.dense_cov()
                logdets[j] = unit.logdet
        return means, covs, logdets

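distances() above now delegates its stacking boilerplate to this helper. A hedged sketch of the contract, with shapes as in the code (means is (n_units, rank, n_channels); covs and logdets stay None unless the metric requires them):

    means, covs, logdets = gmm.stack_units()  # kind defaults to self.distance_metric
    means, covs, logdets = gmm.stack_units(distance_metric="kl_divergence")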

# -- modeling class

@@ -815,8 +909,13 @@ def pick_channels(self, count_data):
if self.channels_strategy == "all":
self.register_buffer("channels", torch.arange(self.n_channels))
elif self.channels_strategy == "snr":
snr = torch.linalg.vector_norm(self.mean, dim=0) * count_data.sqrt().view(-1)
(channels,) = torch.nonzero(snr >= self.channels_strategy_snr_min, as_tuple=True)
snr = torch.linalg.vector_norm(self.mean, dim=0) * count_data.sqrt().view(
-1
)
self.register_buffer("snr", snr)
(channels,) = torch.nonzero(
snr >= self.channels_strategy_snr_min, as_tuple=True
)
self.register_buffer("channels", channels)
else:
assert False
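A hedged sketch of the "snr" strategy in isolation: per-channel SNR is the norm of the mean over the feature rank, scaled by the square root of the per-channel spike count (shapes and threshold invented):

    import torch

    rank, n_channels = 5, 384
    mean = torch.randn(rank, n_channels)
    count_data = torch.full((n_channels,), 100.0)
    snr = torch.linalg.vector_norm(mean, dim=0) * count_data.sqrt().view(-1)
    (channels,) = torch.nonzero(snr >= 10.0, as_tuple=True)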
@@ -891,7 +990,9 @@ def kl_divergence(self, other_means, other_covs, other_logdets):

def to_full_probe(features, weights, n_channels, storage):
n, r, c = features.features.shape
features_full = get_nans(features.features, storage, "features_full", (n, r, n_channels + 1))
features_full = get_nans(
features.features, storage, "features_full", (n, r, n_channels + 1)
)
targ_inds = features.channels.unsqueeze(1).broadcast_to(features.features.shape)
features_full.scatter_(2, targ_inds, features.features)
features_full = features_full[:, :, :-1]
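The n_channels + 1 buffer is a trash-column trick: channel index n_channels means "no channel", so features scattered there land in an extra column that is dropped afterwards. A small self-contained sketch (shapes invented):

    import torch

    n, r, n_channels = 2, 1, 5
    feats = torch.arange(n * r * 3, dtype=torch.float).reshape(n, r, 3)
    chans = torch.tensor([[0, 2, 5], [1, 5, 5]])  # 5 == n_channels marks "no channel"
    full = torch.full((n, r, n_channels + 1), torch.nan)
    targ = chans.unsqueeze(1).broadcast_to(feats.shape)
    full.scatter_(2, targ, feats)
    full = full[:, :, :-1]  # drop the trash column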
@@ -1006,8 +1107,10 @@ def sparse_reassign(liks, match_threshold=None, batch_size=512, return_csc=False
),
numba.void(
numba.int64[:], numba.int64[:], numba.int64[:], numba.float32[:], numba.int64[:]
)
),
]


@numba.njit(
sigs,
error_model="numpy",
@@ -1095,6 +1198,7 @@ def qda(
weighted=True,
min_overlap=0.95,
score_kind="tv",
debug_info=None,
):
# "in b not a"-ness
diff = log_liks_b - log_liks_a
@@ -1117,6 +1221,7 @@ def qda(
sample_weights=sample_weights,
dipscore_only=True,
score_kind=score_kind,
        debug_info=debug_info,
)
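A hedged call sketch mirroring the pair scorer above: the labels mark membership in unit b, and diff holds the per-spike log-likelihood difference between the two units (toy values; NaNs would mark spikes outside a unit's neighborhood):

    import numpy as np

    is_b = np.array([0, 0, 1, 1], dtype=bool)
    log_lik_diff = np.array([-2.0, -1.5, 1.0, 2.0])
    score = qda(is_b, diff=log_lik_diff, cut=None, score_kind="tv")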


@@ -1141,7 +1246,7 @@ def shrinkfit(x, max_size, rg):
return x[choices.to(x.device)]


def get_diff_sparse(sparse_arr, i, j, cols):
def get_diff_sparse(sparse_arr, i, j, cols, return_extra=False):
xi = torch.index_select(sparse_arr[i], 0, cols).coalesce()
indsi = xi.indices().view(-1).numpy()
xi = xi.values().numpy()
@@ -1154,6 +1259,11 @@ def get_diff_sparse(sparse_arr, i, j, cols):
jkeep = np.isin(indsj, indsi)

diff = np.full(cols.shape, np.nan)
diff[indsi[ikeep]] = xj[jkeep] - xi[ikeep]
xj = xj[jkeep]
xi = xi[ikeep]
diff[indsi[ikeep]] = xi - xj

if return_extra:
return diff, dict(xi=xi, xj=xj, keep_inds=indsi[ikeep])

return diff
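A NumPy sketch of the alignment logic above: the difference is defined only where both rows are observed and NaN elsewhere (toy data):

    import numpy as np

    indsi = np.array([1, 3, 4])
    xi = np.array([0.5, 0.2, 0.9])
    indsj = np.array([3, 4, 7])
    xj = np.array([0.1, 0.4, 0.8])
    ikeep = np.isin(indsi, indsj)               # [False, True, True]
    jkeep = np.isin(indsj, indsi)               # [True, True, False]
    diff = np.full(8, np.nan)
    diff[indsi[ikeep]] = xi[ikeep] - xj[jkeep]  # matches the new xi - xj order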
