Skip to content

Commit

Permalink
Test
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 15, 2024
1 parent 8fdfc30 commit 1ae0047
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions src/dartsort/cluster/gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,16 +249,16 @@ def m_step(self, likelihoods=None, show_progress=False, prev_means=None):
for j, unit in enumerate(zip(unit_ids, results)):
assert unit.annotations["unit_id"] == j
self.units.append(unit)
if self.log_proportions is not None:
# this is the index of the noise unit. it's got to be larger than
# the largest unit index
maxix = self.log_proportions.numel() - 1
assert (unit_ids < maxix).all()
ixs = torch.cat((unit_ids, torch.tensor([maxix])))
self.log_proportions = self.log_proportions[ixs]
# if self.log_proportions is not None:
# # this is the index of the noise unit. it's got to be larger than
# # the largest unit index
# maxix = self.log_proportions.numel() - 1
# assert (unit_ids < maxix).all()
# ixs = torch.cat((unit_ids, torch.tensor([maxix])))
# self.log_proportions = self.log_proportions[ixs]
if prev_means is not None:
nu = len(unit_ids)
prev_means = prev_means[unit_ids]
# prev_means = prev_means[unit_ids]
new_means, *_ = self.stack_units(mean_only=True)
dmu = (prev_means - new_means).abs_().view(nu, -1)
adif = torch.max(dmu)
Expand Down Expand Up @@ -1091,19 +1091,32 @@ def unit_group_criterion(

if fit_type == "refit_all":
units = []
subunit_logliks = spikes_core.features.new_full((len(unit_ids), len(in_any)), -torch.inf)
subunit_logliks = spikes_core.features.new_full(
(len(unit_ids), len(in_any)), -torch.inf
)
full_loglik = 0.0
for i, k in enumerate(unit_ids):
u = self.fit_unit(unit_id=k, indices=in_any, likelihoods=likelihoods, features=spikes_extract)
u = self.fit_unit(
unit_id=k,
indices=in_any,
likelihoods=likelihoods,
features=spikes_extract,
)
units.append(u)
_, subunit_logliks[i] = self.unit_log_likelihoods(unit=u, spikes=spikes_core)
_, subunit_logliks[i] = self.unit_log_likelihoods(
unit=u, spikes=spikes_core
)
subunit_log_props = F.softmax(subunit_logliks, dim=0).mean(1).log_()
# loglik per spik
full_loglik = torch.logsumexp(subunit_logliks.T + subunit_log_props, dim=1).mean()
full_loglik = torch.logsumexp(
subunit_logliks.T + subunit_log_props, dim=1
).mean()
unit = self.fit_unit(indices=in_any, features=spikes_extract)
likelihoods = None
elif fit_type == "avg_preexisting":
unit = self.units[unit_ids[0]].avg_with(*[self.units[u] for u in unit_ids[1:]])
unit = self.units[unit_ids[0]].avg_with(
*[self.units[u] for u in unit_ids[1:]]
)
if debug:
subunit_logliks = likelihoods[:, in_any][unit_ids]
full_loglik = marginal_loglik(
Expand Down Expand Up @@ -1377,8 +1390,13 @@ def avg_with(self, *others):
n_channels=self.n_channels,
noise=self.noise,
)
new.register_buffer("mean", (self.mean + sum(o.mean for o in others)) / (1 + len(others)))
new.register_buffer("channels", torch.cat([self.channels, *[o.channels for o in others]]).unique())
new.register_buffer(
"mean", (self.mean + sum(o.mean for o in others)) / (1 + len(others))
)
new.register_buffer(
"channels",
torch.cat([self.channels, *[o.channels for o in others]]).unique(),
)
assert self.cov_kind == "zero"
return new

Expand Down Expand Up @@ -1726,7 +1744,9 @@ def coo_to_scipy(coo_tensor):
return coo_array((data, coords), shape=coo_tensor.shape)


def marginal_loglik(indices, log_proportions, log_likelihoods, unit_ids=None, reduce="mean"):
def marginal_loglik(
indices, log_proportions, log_likelihoods, unit_ids=None, reduce="mean"
):
if unit_ids is not None:
# renormalize log props
log_proportions = log_proportions[unit_ids]
Expand Down

0 comments on commit 1ae0047

Please sign in to comment.