
Commit

Well, it doesn't throw an exception. You have to give it that.
cwindolf committed Oct 23, 2024
1 parent 358603a commit f480417
Showing 11 changed files with 97 additions and 53 deletions.
6 changes: 4 additions & 2 deletions src/dartsort/cluster/cluster_util.py
@@ -16,7 +16,8 @@ def agglomerate(labels, distances, linkage_method="complete", threshold=1.0):
n = distances.shape[0]
pdist = distances[np.triu_indices(n, k=1)]
if pdist.min() > threshold:
return labels
ids = np.unique(labels)
return labels, ids[ids >= 0]
finite = np.isfinite(pdist)
if not finite.all():
inf = max(0, pdist[finite].max()) + threshold + 1.0
@@ -28,7 +29,8 @@ def agglomerate(labels, distances, linkage_method="complete", threshold=1.0):
new_ids -= new_ids.min()

kept = labels >= 0
new_labels = np.where(kept, new_ids[labels[kept]], -1)
new_labels = np.full_like(labels, -1)
new_labels[kept] = new_ids[labels[kept]]

return new_labels, new_ids
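A note on the agglomerate hunk above: the replaced np.where call pairs a full-length mask with a values array that only covers the kept spikes, so the operands cannot broadcast. A small standalone sketch of the masked-assignment pattern the new lines use, with made-up arrays:

# Sketch (illustrative arrays, not part of this commit) of why the
# masked-assignment form works where the np.where form did not: np.where
# needs operands that broadcast together, but new_ids[labels[kept]] only has
# one entry per kept spike.
import numpy as np

labels = np.array([0, -1, 1, 1, -1, 0])   # -1 marks unassigned spikes
new_ids = np.array([5, 7])                # relabeling for units 0 and 1
kept = labels >= 0

# np.where(kept, new_ids[labels[kept]], -1) would try to broadcast a
# length-4 array against a length-6 mask and raise a ValueError.
new_labels = np.full_like(labels, -1)     # start everything at "unassigned"
new_labels[kept] = new_ids[labels[kept]]  # relabel only the kept spikes
print(new_labels)                         # [ 5 -1  7  7 -1  5]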

56 changes: 34 additions & 22 deletions src/dartsort/cluster/gaussian_mixture.py
@@ -232,7 +232,7 @@ def log_likelihoods(
# log liks as sparse matrix. sparse zeros are not 0 but -inf!!
log_liks = coo_array(
(coo_data, (coo_uix, coo_six)),
shape=(len(self.units) + with_noise_unit, self.data.n_spikes),
shape=(len(unit_ids) + with_noise_unit, self.data.n_spikes),
)
return log_liks
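For context on the log_likelihoods hunk above, a toy sketch of the sparse layout it builds (array contents here are invented): rows index the evaluated units plus an optional noise row, columns index spikes, and absent entries mean -inf rather than 0, which is why the row count has to match the unit_ids actually scored:

# Illustrative sketch, not from the commit. Sizes and values are made up.
import numpy as np
from scipy.sparse import coo_array

n_spikes = 6
unit_ids = [0, 1]
with_noise_unit = 1

coo_uix = np.array([0, 0, 1, 1, 2, 2])   # unit (row) index per entry
coo_six = np.array([0, 1, 1, 2, 0, 2])   # spike (column) index per entry
coo_data = np.array([-1.0, -2.0, -0.5, -3.0, -4.0, -0.1])

log_liks = coo_array(
    (coo_data, (coo_uix, coo_six)),
    # the fix above: rows follow the units actually evaluated (unit_ids),
    # not the full self.units list
    shape=(len(unit_ids) + with_noise_unit, n_spikes),
)
print(log_liks.toarray())  # zeros here mean "never evaluated", i.e. -inf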

@@ -325,7 +325,7 @@ def distances(
nu = len(units)

# stack unit data into one place
means, covs, logdets = self.stack_units()
means, covs, logdets = self.stack_units(units)

# compute denominator of noised normalized distances
if noise_normalized:
@@ -494,7 +494,7 @@ def unit_log_likelihoods(
offset = 0
log_likelihoods = torch.empty(ns)
else:
log_likelihoods = torch.full(ns, -torch.inf)
log_likelihoods = torch.full((ns,), -torch.inf)

jobs = neighbs.items()
if show_progress:
@@ -532,10 +532,10 @@

return spike_indices, log_likelihoods

def noise_log_likelihoods(self):
def noise_log_likelihoods(self, show_progress=False):
if self._noise_log_likelihoods is None:
self._noise_six, self._noise_log_likelihoods = self.unit_log_likelihoods(
unit=self.noise_unit, show_progress=True, desc_prefix="Noise "
unit=self.noise_unit, show_progress=show_progress, desc_prefix="Noise "
)
return self._noise_six, self._noise_log_likelihoods
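On the torch.full fix in the unit_log_likelihoods hunk above: torch.full takes its size as a sequence, so passing the spike count as a bare int fails. A minimal illustration (not from this diff):

# torch.full expects the size as a tuple/list; an int first argument raises.
import torch

ns = 4
try:
    torch.full(ns, -torch.inf)                   # rejected: size must be a sequence
except TypeError as e:
    print("rejected:", e)

log_likelihoods = torch.full((ns,), -torch.inf)  # correct: 1-D tensor of -inf
print(log_likelihoods)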

@@ -544,13 +544,13 @@ def kmeans_split_unit(self, unit_id, debug=False):
# unit's channel set
unit = self.units[unit_id]
indices_full, sp = self.random_spike_data(unit_id, return_full_indices=True)
X = self.data.interp_to_chans(sp, unit)
X = self.data.interp_to_chans(sp, unit.channels)
if debug:
debug_info = dict(indices_full=indices_full, sp=sp, X=X)

# run kmeans with kmeans++ initialization
split_labels, responsibilities = kmeans(
X,
X.view(len(X), -1),
n_iter=self.kmeans_n_iter,
n_components=self.kmeans_k,
random_state=self.rg,
@@ -593,36 +593,42 @@ def mini_merge(
units = []
for label in split_ids:
(in_label,) = torch.nonzero(labels == labels, as_tuple=True)
weights = None if weights is None else weights[in_label]
w = None if weights is None else weights[in_label, label]
features = spike_data[in_label]
unit = GaussianUnit.from_features(features, weights, **self.unit_args)
unit = GaussianUnit.from_features(features, weights=w, **self.unit_args)
units.append(unit)

# determine their distances
distances = self.distances(units=units)
distances = self.distances(units=units, show_progress=False)

# determine their bimodalities while at once mini-reassigning
lls = spike_data.features.new_full((len(units), len(spike_data)), -torch.inf)
for j, unit in enumerate(units):
lls[j] = self.unit_log_likelihoods(
inds_, lls_ = self.unit_log_likelihoods(
unit=unit, spike_indices=spike_data.indices
)
if lls_ is not None:
lls[j] = lls
best_liks, labels = lls.max(dim=0)
labels[torch.isinf(best_liks)] = -1
labels = labels.numpy(force=True)
kept = np.flatnonzero(labels >= 0)
ids, labels[kept], counts = np.unique(labels[kept])
ids, labels[kept], counts = np.unique(labels[kept], return_inverse=True, return_counts=True)
ids = ids[counts > 0]
units = [u for j, u in enumerate(units) if counts[j]]
counts_dense = np.zeros(len(units), dtype=counts.dtype)
counts_dense[ids] = counts
units = [u for j, u in enumerate(units) if counts_dense[j]]
bimodalities = bimodalities_dense(
lls.numpy(force=True)[ids], labels, ids=np.arange(ids.size)
)
bimodalities_full = np.full_like(distances, np.inf)
bimodalities_full[ids[:, None], ids[None, :]] = bimodalities

# return merged labels
distances = combine_distances(
distances,
self.merge_distance_threshold,
bimodalities,
bimodalities_full,
self.merge_bimodality_threshold,
sym_function=self.merge_sym_function,
)
@@ -653,15 +659,15 @@ def unit_pair_bimodality(
masked=True,
max_spikes=2048,
dt_s=2.0,
debug=False,
score_kind="tv",
debug=False,
):
if in_units is not None:
ina = in_units[id_a]
inb = in_units[id_b]
else:
(ina,) = torch.nonzero(self.labels == id_a)
(inb,) = torch.nonzero(self.labels == id_b)
(ina,) = torch.nonzero(self.labels == id_a, as_tuple=True)
(inb,) = torch.nonzero(self.labels == id_b, as_tuple=True)

if masked:
times_a = self.data.times_seconds[ina]
@@ -755,7 +761,7 @@ def stack_units(self, units=None, distance_metric=None):
if distance_metric is None:
kind = self.distance_metric
nu, rank, nc = len(units), self.data.rank, self.data.n_channels
means = self.noise_unit.mean.new_zeros((nu, rank, nc))
means = torch.zeros((nu, rank, nc), device=self.data.device)
covs = logdets = None
if kind in ("kl_divergence",):
covs = means.new_zeros((nu, rank * nc, rank * nc))
@@ -765,6 +771,7 @@
if covs is not None:
covs[j] = unit.dense_cov()
logdets[j] = unit.logdet
return means, covs, logdets
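The mini_merge hunk above leans on np.unique's return_inverse and return_counts to compact the split labels and drop empty units. A toy sketch of that pattern with invented labels:

# Sketch, not from the commit: return_inverse relabels to 0..n_ids-1 and
# return_counts gives the occupancy used to drop empty units.
import numpy as np

labels = np.array([3, 3, 7, -1, 7, 7, 2])
kept = np.flatnonzero(labels >= 0)

ids, inverse, counts = np.unique(labels[kept], return_inverse=True, return_counts=True)
labels[kept] = inverse   # kept labels are now contiguous: 0, 1, 2, ...
print(ids)               # [2 3 7]   original unit ids, sorted
print(labels)            # [ 1  1  2 -1  2  2  0]
print(counts)            # [1 2 3]   spikes per surviving unit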


# -- modeling class
@@ -961,15 +968,19 @@ def divergence(
raise ValueError(f"Unknown divergence {kind=}.")

def noise_metric_divergence(self, other_means):
dmu = other_means - self.mean
dmu = other_means
if self.mean_kind != "zero":
dmu = dmu - self.mean
dmu = dmu.view(len(other_means), -1)
noise_cov = self.noise.marginal_covariance()
return noise_cov.inv_quad(dmu.T, reduce_inv_quad=False)

def kl_divergence(self, other_means, other_covs, other_logdets):
"""DKL(others || self)"""
n = other_means.shape[0]
dmu = other_means - self.mean
dmu = other_means
if self.mean_kind != "zero":
dmu = dmu - self.mean

# compute the inverse quad and self log det terms
inv_quad, self_logdet = self.cov.inv_quad_logdet(
@@ -1201,7 +1212,8 @@ def qda(
debug_info=None,
):
# "in b not a"-ness
diff = log_liks_b - log_liks_a
if diff is None:
diff = log_liks_b - log_liks_a
keep = np.isfinite(diff)
if not keep.mean() >= min_overlap:
return np.inf
@@ -1221,7 +1233,7 @@
sample_weights=sample_weights,
dipscore_only=True,
score_kind=score_kind,
debug_info=None,
debug_info=debug_info,
)
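For the noise_metric_divergence change above (skipping the mean subtraction when mean_kind is "zero"), the quantity being computed is a noise-whitened squared deviation per candidate mean. A rough dense-tensor sketch of the same math, assuming toy shapes and a plain covariance matrix standing in for the covariance object the code uses:

# Sketch, not from the commit: d_k = (mu_k - mu)^T Sigma_noise^{-1} (mu_k - mu)
import torch

nu, dim = 3, 5
other_means = torch.randn(nu, dim)
self_mean = torch.randn(dim)          # would be 0 when mean_kind == "zero"
sigma = torch.eye(dim) + 0.1 * torch.ones(dim, dim)   # toy noise covariance

dmu = other_means - self_mean                # (nu, dim)
L = torch.linalg.cholesky(sigma)
solved = torch.cholesky_solve(dmu.T, L)      # Sigma^{-1} dmu^T, shape (dim, nu)
d = (dmu.T * solved).sum(dim=0)              # one divergence per candidate mean
print(d)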


3 changes: 2 additions & 1 deletion src/dartsort/cluster/kmeans.py
@@ -56,11 +56,12 @@ def kmeans(
kmeanspp_initial=kmeanspp_initial,
)
# responsibilities, sum to 1 over centroids
e = F.softmax(-0.5 * dists, dim=1)
if not n_iter:
return labels, e

centroids = X[centroid_ixs]
dists = torch.cdist(X, centroids).square_()
e = F.softmax(-0.5 * dists, dim=1)
proportions = e.mean(0)

for j in range(n_iter):
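A quick sketch of the responsibilities computed in the kmeans hunk above: squared distances to the centroids are mapped through a softmax so each spike gets a weight over centroids that sums to one. Toy data, not from the commit:

# Soft responsibilities from squared distances. Data and centroid choice
# here are illustrative (the real code uses kmeans++ initialization).
import torch
import torch.nn.functional as F

X = torch.randn(10, 3)                 # 10 points, 3 features
centroids = X[:2]                      # pretend the first two points were chosen
dists = torch.cdist(X, centroids).square_()
e = F.softmax(-0.5 * dists, dim=1)     # responsibilities, rows sum to 1
print(e.sum(dim=1))                    # all ones
labels = e.argmax(dim=1)               # hard assignment if needed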
2 changes: 2 additions & 0 deletions src/dartsort/cluster/merge.py
@@ -338,6 +338,7 @@ def cross_match_distance_matrix(
template_data, cross_mask, ids_a, ids_b = combine_templates(
template_data_a, template_data_b
)
print(f"{ids_a.shape=} {ids_b.shape=} {template_data.templates.shape=}")
units, dists, shifts, template_snrs = calculate_merge_distances(
template_data,
superres_linkage=superres_linkage,
@@ -363,6 +364,7 @@
# (with infs on main diag blocks)
a_mask = np.flatnonzero(np.isin(units, ids_a))
b_mask = np.flatnonzero(np.isin(units, ids_b))
print(f"{a_mask.shape=} {b_mask.shape=} {units.shape=}")
Dab = dists[a_mask[:, None], b_mask[None, :]]
Dba = dists[b_mask[:, None], a_mask[None, :]]
Dstack = np.stack((Dab, Dba.T))
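The surrounding context in cross_match_distance_matrix pulls the a-to-b and b-to-a blocks out of the stacked distance matrix with outer indexing. A tiny illustration of that indexing pattern with made-up units and distances:

# Sketch, not from the commit: np.isin + flatnonzero find each block's rows,
# and broadcasting the two index vectors extracts the cross block.
import numpy as np

units = np.array([0, 1, 10, 11])        # two "a" units followed by two "b" units
ids_a, ids_b = np.array([0, 1]), np.array([10, 11])
dists = np.arange(16, dtype=float).reshape(4, 4)

a_mask = np.flatnonzero(np.isin(units, ids_a))   # [0 1]
b_mask = np.flatnonzero(np.isin(units, ids_b))   # [2 3]
Dab = dists[a_mask[:, None], b_mask[None, :]]    # a-to-b block, shape (2, 2)
Dba = dists[b_mask[:, None], a_mask[None, :]]    # b-to-a block, shape (2, 2)
print(Dab)
print(Dba.T)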
11 changes: 7 additions & 4 deletions src/dartsort/cluster/stable_features.py
@@ -53,25 +53,28 @@ def __init__(
# neighborhoods module, for querying spikes by channel group
self.core_neighborhoods = core_neighborhoods

extract_amp_vecs = torch.linalg.vecnorm(extract_features, dim=1)
amps = extract_amp_vecs.max(1).values
extract_amp_vecs = torch.linalg.vector_norm(extract_features, dim=1)
amps = extract_amp_vecs.nan_to_num().max(1).values

# channel neighborhoods and features
# if not self.features_on_device, .spike_data() will .to(self.device)
times_s = torch.asarray(self.original_sorting.times_seconds[kept_indices])
if self.features_on_device:
self.register_buffer("core_channels", core_channels)
self.register_buffer("extract_channels", extract_channels)
self.register_buffer("core_features", core_features)
self.register_buffer("extract_features", extract_features)
# self.register_buffer("extract_amp_vecs", extract_amp_vecs)
self.register_buffer("amps", amps)
self.register_buffer("times_seconds", times_s)
else:
self.core_channels = core_channels
self.extract_channels = extract_channels
self.core_features = core_features
self.extract_features = extract_features
# self.extract_amp_vecs = extract_amp_vecs
self.amps = amps
self.times_seconds = times_s

# always on device
self.register_buffer("prgeom", prgeom)
@@ -379,7 +382,7 @@ def spike_neighborhoods(self, channels, spike_indices, min_coverage=1.0):
nhoodv = nhood[nhood < self.n_channels]
coverage = torch.isin(nhoodv, channels).sum() / nhoodv.numel()
if coverage >= min_coverage:
(member_indices,) = torch.nonzero(spike_ids == j)
(member_indices,) = torch.nonzero(spike_ids == j, as_tuple=True)
neighborhood_info[j] = (nhood, member_indices)
n_spikes += member_indices.numel()
return neighborhood_info, n_spikes
@@ -397,7 +400,7 @@ def interp_to_chans(
):
source_pos = prgeom[spike_data.channels]
target_pos = prgeom[channels]
shape = spike_data.n_spikes, *target_pos.shape
shape = len(spike_data), *target_pos.shape
target_pos = target_pos[None].broadcast_to(shape)
return interpolation_util.kernel_interpolate(
spike_data.features,
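On the amplitude fix in the stable_features hunk above: torch.linalg.vector_norm is the current spelling of the per-spike channel-wise norm, and nan_to_num keeps fully padded channels from turning the max into NaN. A sketch with invented shapes:

# Illustrative only; shapes and the NaN-padded channel are made up.
import torch

n_spikes, rank, n_chans = 4, 3, 5
extract_features = torch.randn(n_spikes, rank, n_chans)
extract_features[:, :, -1] = torch.nan             # e.g. a padded channel

extract_amp_vecs = torch.linalg.vector_norm(extract_features, dim=1)  # (n_spikes, n_chans)
amps = extract_amp_vecs.nan_to_num().max(1).values                    # peak amplitude per spike
print(amps.shape)   # torch.Size([4])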
3 changes: 0 additions & 3 deletions src/dartsort/util/comparison.py
@@ -118,9 +118,7 @@ def _calculate_template_distances(self):
matches = self.comparison.best_match_12.astype(int).values
matched = np.flatnonzero(matches >= 0)
matches = matches[matched]
print(f"{matched.sum()=} {matched.shape=}")
tested_td = self.tested_analysis.coarse_template_data[matches]
print(f"{tested_td.templates.shape=}")

dists, shifts, snrs_a, snrs_b = merge.cross_match_distance_matrix(
gt_td,
@@ -129,7 +127,6 @@
n_jobs=self.n_jobs,
device=self.device,
)
print(f"{dists.shape=}")
self._template_distances = np.full((nugt, nugt), np.inf)
self._template_distances[np.arange(nugt)[:, None], matched[None, :]] = dists

9 changes: 8 additions & 1 deletion src/dartsort/util/noise_util.py
@@ -349,9 +349,16 @@ def __init__(

# precompute stuff
self._full_cov = None
self._logdet = None
self.register_buffer("mean_full", self.mean_rc().clone().detach())
self.cache = {}

@property
def logdet(self):
if self._logdet is None:
self.marginal_covariance()
return self._logdet

@property
def device(self):
return self.global_std.device
@@ -383,7 +390,7 @@ def marginal_covariance(self, channels=slice(None), cache_key=None):
if channels == slice(None):
if self._full_cov is None:
self._full_cov = self._marginal_covariance()
self.logdet = self._full_cov.logdet()
self._logdet = self._full_cov.logdet()
return self._full_cov
cov = self._marginal_covariance(channels)
if cache_key is not None:
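The noise_util change above makes logdet a lazily cached property that is filled in the first time the full marginal covariance is built. A simplified standalone sketch of that pattern:

# Sketch, not from the commit: the log-determinant is only available once the
# full covariance exists, so the property triggers the build on first access
# and returns the cached value afterwards.
import torch

class LazyLogdetExample:
    def __init__(self, dim):
        self._full_cov = None
        self._logdet = None
        self.dim = dim

    def marginal_covariance(self):
        if self._full_cov is None:
            self._full_cov = torch.eye(self.dim)   # stand-in for the real build
            self._logdet = self._full_cov.logdet()
        return self._full_cov

    @property
    def logdet(self):
        if self._logdet is None:
            self.marginal_covariance()             # builds and caches _logdet
        return self._logdet

noise = LazyLogdetExample(dim=3)
print(noise.logdet)   # 0.0 for the identity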
