From f22e1ba9afde539e83c795fa04ed3b2c566efeaa Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Wed, 25 Sep 2024 14:06:40 -0700 Subject: [PATCH] Lates NN stuff + clustering vis debugging + GT metrics --- src/dartsort/peel/grab.py | 10 +- src/dartsort/peel/threshold.py | 2 + src/dartsort/transform/all_transformers.py | 7 +- src/dartsort/transform/decollider.py | 54 ++++--- src/dartsort/transform/transform_base.py | 39 +++++ src/dartsort/transform/vae_localize.py | 165 +++++++++++++++------ src/dartsort/util/nn_util.py | 34 +++++ src/dartsort/vis/gmm.py | 2 +- 8 files changed, 249 insertions(+), 64 deletions(-) diff --git a/src/dartsort/peel/grab.py b/src/dartsort/peel/grab.py index 487bd324..9c2b4029 100644 --- a/src/dartsort/peel/grab.py +++ b/src/dartsort/peel/grab.py @@ -1,7 +1,7 @@ import torch from dartsort.util import spiketorch -from .peel_base import BasePeeler +from .peel_base import BasePeeler, SpikeDataset class GrabAndFeaturize(BasePeeler): @@ -38,6 +38,13 @@ def __init__( self.register_buffer("times_samples", times_samples) self.register_buffer("channels", channels) + def out_datasets(self): + datasets = super().out_datasets() + datasets.append( + SpikeDataset(name="indices", shape_per_spike=(), dtype=int) + ) + return datasets + def process_chunk(self, chunk_start_samples, return_residual=False): """Override process_chunk to skip empties.""" chunk_end_samples = min( @@ -88,6 +95,7 @@ def peel_chunk( return dict( n_spikes=in_chunk.numel(), + indices=in_chunk, times_samples=self.times_samples[in_chunk], channels=channels, collisioncleaned_waveforms=waveforms, diff --git a/src/dartsort/peel/threshold.py b/src/dartsort/peel/threshold.py index 702793ea..f58901eb 100644 --- a/src/dartsort/peel/threshold.py +++ b/src/dartsort/peel/threshold.py @@ -23,6 +23,7 @@ def __init__( relative_peak_radius_samples=5, dedup_temporal_radius_samples=7, n_chunks_fit=40, + n_waveforms_fit=20_000, max_waveforms_fit=50_000, fit_subsampling_random_state=0, dtype=torch.float, @@ -34,6 +35,7 @@ def __init__( chunk_length_samples=chunk_length_samples, chunk_margin_samples=spike_length_samples, n_chunks_fit=n_chunks_fit, + n_waveforms_fit=n_waveforms_fit, max_waveforms_fit=max_waveforms_fit, fit_subsampling_random_state=fit_subsampling_random_state, dtype=dtype, diff --git a/src/dartsort/transform/all_transformers.py b/src/dartsort/transform/all_transformers.py index 59afb12c..bdce9678 100644 --- a/src/dartsort/transform/all_transformers.py +++ b/src/dartsort/transform/all_transformers.py @@ -4,7 +4,8 @@ from .vae_localize import VAELocalization from .single_channel_denoiser import SingleChannelWaveformDenoiser from .temporal_pca import TemporalPCADenoiser, TemporalPCAFeaturizer, TemporalPCA -from .transform_base import Waveform +from .transform_base import Waveform, Passthrough +from .decollider import Decollider all_transformers = [ Waveform, @@ -16,10 +17,12 @@ TemporalPCAFeaturizer, Localization, PointSourceLocalization, - VAELocalization, + VAELocalization, AmplitudeFeatures, TemporalPCA, Voltage, + Decollider, + Passthrough, ] transformers_by_class_name = {cls.__name__: cls for cls in all_transformers} diff --git a/src/dartsort/transform/decollider.py b/src/dartsort/transform/decollider.py index 56047d5a..9a15656d 100644 --- a/src/dartsort/transform/decollider.py +++ b/src/dartsort/transform/decollider.py @@ -25,6 +25,9 @@ def __init__( noisier3noise=False, inference_kind="raw", seed=0, + batch_size=32, + learning_rate=1e-3, + epochs=25, ): assert inference_kind in ("raw", "amortized") @@ -34,7 +37,10 @@ def __init__( self.hidden_dims = hidden_dims self.n_channels = len(geom) self.recording = recording - self.rg = np.random.get_default_rng(seed) + self.batch_size = batch_size + self.learning_rate = learning_rate + self.epochs = epochs + self.rg = np.random.default_rng(seed) super().__init__( geom=geom, channel_index=channel_index, name=name, name_prefix=name_prefix @@ -50,6 +56,15 @@ def __init__( "relative_index", get_relative_index(self.channel_index, self.model_channel_index), ) + # suburban lawns -- janitor + self.register_buffer( + "irrelative_index", + get_relative_index(self.model_channel_index, self.channel_index), + ) + self._needs_fit = True + + def needs_fit(self): + return self._needs_fit def initialize_nets(self, spike_length_samples): self.spike_length_samples = spike_length_samples @@ -83,8 +98,9 @@ def fit(self, waveforms, max_channels): waveforms = reindex(max_channels, waveforms, self.relative_index, pad_value=0.0) with torch.enable_grad(): self._fit(waveforms, max_channels) + self._needs_fit = False - def transform(self, waveforms, max_channels): + def forward(self, waveforms, max_channels): """Called only at inference time.""" n = len(waveforms) waveforms = reindex(max_channels, waveforms, self.relative_index, pad_value=0.0) @@ -92,18 +108,20 @@ def transform(self, waveforms, max_channels): net_input = torch.cat((waveforms.view(n, self.wf_dim), masks), dim=1) if self.inference_kind == "amortized": - pred = self.inf_net(net_input) + pred = self.inf_net(net_input).view(waveforms.shape) elif self.inference_kind == "raw": - pred = self.eyz(net_input) + pred = self.eyz(net_input).view(waveforms.shape) else: assert False + pred = reindex(max_channels, pred, self.irrelative_index) + return pred def get_masks(self, max_channels): return self.model_channel_index[max_channels] < self.n_channels - def forward(self, y, m, mask): + def train_forward(self, y, m, mask): n = len(y) z = y + m z_flat = z.view(n, self.wf_dim) @@ -117,18 +135,18 @@ def forward(self, y, m, mask): # predictions given z if self.noisier3noise: - eyz = self.eyz(z_masked) - emz = self.emz(z_masked) + eyz = self.eyz(z_masked).view(y.shape) + emz = self.emz(z_masked).view(y.shape) exz = eyz - emz else: - eyz = self.eyz(z_masked) + eyz = self.eyz(z_masked).view(y.shape) exz = 2 * eyz - z # predictions given y, if relevant if self.inference_kind == "amortized": y_flat = y.view(n, self.wf_dim) y_masked = torch.cat((y_flat, mask), dim=1) - exy = self.inf_net(y_masked) + exy = self.inf_net(y_masked).view(y.shape) return exz, eyz, emz, exy @@ -157,8 +175,9 @@ def get_noise(self, channels): return torch.from_numpy(noise_waveforms) - def loss(mask, waveforms, m, exz, eyz, emz=None, exy=None): + def loss(self, mask, waveforms, m, exz, eyz, emz=None, exy=None): loss_dict = {} + mask = mask.unsqueeze(1) loss_dict["eyz"] = F.mse_loss(mask * eyz, mask * waveforms) if emz is not None: loss_dict["emz"] = F.mse_loss(mask * emz, mask * m) @@ -167,31 +186,30 @@ def loss(mask, waveforms, m, exz, eyz, emz=None, exy=None): return loss_dict def _fit(self, waveforms, channels): - self.initialize_net(waveforms.shape[1]) + self.initialize_nets(waveforms.shape[1]) dataset = TensorDataset(waveforms, channels) dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True) optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + self.to(waveforms.device) with trange(self.epochs, desc="Epochs", unit="epoch") as pbar: - epoch_losses = {} for epoch in pbar: + epoch_losses = {} for waveform_batch, channels_batch in dataloader: optimizer.zero_grad() # get a batch of noise samples m = self.get_noise(channels_batch).to(waveform_batch) mask = self.get_masks(channels_batch).to(waveform_batch) - exz, eyz, emz, exy = self.forward( - waveforms, - m, - ) + exz, eyz, emz, exy = self.train_forward(waveform_batch, m, mask) loss_dict = self.loss(mask, waveform_batch, m, exz, eyz, emz, exy) loss = sum(loss_dict.values()) loss.backward() optimizer.step() - for k, v in loss_dict: + for k, v in loss_dict.items(): epoch_losses[k] = v.item() + epoch_losses.get(k, 0.0) - loss_str = ", ".join(f"{k}: {v:0.2f}" for k, v in epoch_losses.items()) + epoch_losses = {k: v / len(dataloader) for k, v in epoch_losses.items()} + loss_str = ", ".join(f"{k}: {v:.3f}" for k, v in epoch_losses.items()) pbar.set_description(f"Epochs [{loss_str}]") diff --git a/src/dartsort/transform/transform_base.py b/src/dartsort/transform/transform_base.py index 355e36de..60b3cd71 100644 --- a/src/dartsort/transform/transform_base.py +++ b/src/dartsort/transform/transform_base.py @@ -81,6 +81,45 @@ class BaseWaveformAutoencoder(BaseWaveformDenoiser, BaseWaveformFeaturizer): pass +class Passthrough(BaseWaveformDenoiser, BaseWaveformFeaturizer): + + def __init__(self, pipeline): + feat = [t for t in pipeline if t.is_featurizer] + if not len(feat): + raise ValueError("Passthrough with no featurizer?") + name = f"passthrough_{feat[0].name}" + super().__init__(name=name) + self.pipeline = pipeline + + def needs_precompute(self): + return self.pipeline.needs_precompute() + + def precompute(self): + return self.pipeline.precompute() + + def needs_fit(self): + return self.pipeline.needs_fit() + + def fit(self, waveforms, max_channels): + self.pipeline.fit(waveforms, max_channels) + + def forward(self, waveforms, max_channels=None): + pipeline_waveforms, pipeline_features = self.pipeline(waveforms, max_channels) + return waveforms, pipeline_features + + @property + def spike_datasets(self): + datasets = [] + for t in self.pipeline.transformers: + if t.is_featurizer: + datasets.extend(t.spike_datasets) + return datasets + + def transform(self, waveforms, max_channels=None): + pipeline_waveforms, pipeline_features = self.pipeline(waveforms, max_channels) + return pipeline_features + + class IdentityWaveformDenoiser(BaseWaveformDenoiser): def forward(self, waveforms, max_channels=None): return waveforms diff --git a/src/dartsort/transform/vae_localize.py b/src/dartsort/transform/vae_localize.py index 8aa7ff3e..5859a3b7 100644 --- a/src/dartsort/transform/vae_localize.py +++ b/src/dartsort/transform/vae_localize.py @@ -36,34 +36,41 @@ def __init__( prior_variance=80.0, convergence_eps=0.01, min_epochs=2, - scale_loss_by_mean=False, + scale_loss_by_mean=True, + reference='main_channel', + channelwise_dropout_p=0.0, ): - assert localization_model == "pointsource" + assert localization_model in ("pointsource", "dipole") assert amplitude_kind in ("peak", "ptp") + assert reference in ('main_channel', 'com') super().__init__( geom=geom, channel_index=channel_index, name=name, name_prefix=name_prefix ) + self.amplitude_kind = amplitude_kind self.radius = radius self.localization_model = localization_model - self.latent_dim = 3 + (not alpha_closed_form) + alpha_dim = 1 + 2 * (localization_model == "dipole") + self.latent_dim = 3 + (not alpha_closed_form) * alpha_dim self.epochs = epochs self.learning_rate = learning_rate self.batch_size = batch_size self.encoder = None self.amplitudes_only = amplitudes_only self.use_batchnorm = use_batchnorm - self.register_buffer( - "padded_geom", F.pad(self.geom.to(torch.float), (0, 0, 0, 1)) - ) self.hidden_dims = hidden_dims self.alpha_closed_form = alpha_closed_form self.variational = prior_variance is not None - self.prior_variance = prior_variance + self.prior_variance = torch.tensor(prior_variance) if prior_variance is not None else None self.convergence_eps = convergence_eps self.min_epochs = min_epochs self.scale_loss_by_mean = scale_loss_by_mean + self.channelwise_dropout_p = channelwise_dropout_p + self.reference = reference + self.register_buffer( + "padded_geom", F.pad(self.geom.to(torch.float), (0, 0, 0, 1)) + ) self.register_buffer( "model_channel_index", make_regular_channel_index(geom=self.geom, radius=radius, to_torch=True), @@ -72,6 +79,11 @@ def __init__( "relative_index", get_relative_index(self.channel_index, self.model_channel_index), ) + self.nc = len(self.geom) + self._needs_fit = True + + def needs_fit(self): + return self._needs_fit def initialize_net(self, spike_length_samples): if self.encoder is not None: @@ -81,11 +93,13 @@ def initialize_net(self, spike_length_samples): if self.variational: n_latent *= 2 - self.encoder = nn_util.get_mlp( - (spike_length_samples + 1) * self.model_channel_index.shape[1], + self.encoder = nn_util.get_waveform_mlp( + spike_length_samples, + self.model_channel_index.shape[1], self.hidden_dims, n_latent, use_batchnorm=self.use_batchnorm, + channelwise_dropout_p=self.channelwise_dropout_p, ) self.encoder.to(self.padded_geom.device) @@ -96,10 +110,22 @@ def reparameterize(self, mu, var): eps = torch.randn_like(std) return mu + eps * std - def local_distances(self, z, channels): + def get_reference_points(self, channels, obs_amps=None, neighborhoods=None): + if self.reference == 'main_channel': + return self.padded_geom[channels] + elif self.reference == 'com': + if neighborhoods is None: + neighborhoods = self.padded_geom[self.model_channel_index[channels]] + w = obs_amps / obs_amps.sum(1, keepdims=True) + centers = torch.sum(w.unsqueeze(-1) * neighborhoods, dim=1) + return centers + else: + assert False + + def local_distances(self, z, channels, obs_amps=None): """Return distances from each z to its local geom centered at channels.""" - centers = self.padded_geom[channels] neighbors = self.padded_geom[self.model_channel_index[channels]] + centers = self.get_reference_points(channels, obs_amps=obs_amps, neighborhoods=neighbors) local_geom = neighbors - centers.unsqueeze(1) dx = z[:, 0, None] - local_geom[:, :, 0] dz = z[:, 2, None] - local_geom[:, :, 1] @@ -117,8 +143,8 @@ def get_alphas(self, obs_amps, dists, masks, return_pred=False): return alphas, alphas.unsqueeze(1) * pred_amps_alpha1 return alphas - def decode(self, z, channels, obs_amps, masks): - dists = self.local_distances(z, channels) + def point_source_model(self, z, obs_amps, masks, channels): + dists = self.local_distances(z, channels, obs_amps=obs_amps) if self.alpha_closed_form: alphas, pred_amps = self.get_alphas( obs_amps, dists, masks, return_pred=True @@ -128,10 +154,43 @@ def decode(self, z, channels, obs_amps, masks): pred_amps = alphas.unsqueeze(1) / dists return alphas, pred_amps + def dipole_model(self, z, obs_amps, masks, channels): + neighbors = self.padded_geom[self.model_channel_index[channels]] + centers = self.get_reference_points(channels, obs_amps=obs_amps, neighborhoods=neighbors) + local_geom = neighbors - centers.unsqueeze(1) + + # displacements from probe + dx = z[:, 0, None] - local_geom[:, :, 0] + dz = z[:, 2, None] - local_geom[:, :, 1] + y = F.softplus(z[:, 1]).unsqueeze(1) + duv = torch.stack((dx, y.broadcast_to(dx.shape), dz), dim=2) + + # displacment over distance cubed. (n_spikes, n_chans, 3) + X = duv * duv.square().sum(2, keepdim=True).pow(-1.5) + if self.alpha_closed_form: + # beta = torch.linalg.pinv(X.mT @ X) @ (X.mT @ obs_amps.unsqueeze(2)) + # beta = torch.linalg.lstsq(X.mT @ X, X.mT @ obs_amps.unsqueeze(2)).solution + beta = torch.linalg.lstsq(X, obs_amps.unsqueeze(2)).solution + pred_amps = (X @ beta)[:, :, 0] + beta = beta[:, :, 0] + else: + beta = z[:, 3:] + pred_amps = (X @ beta.unsqueeze(2))[:, :, 0] + + return beta, pred_amps + + def decode(self, z, channels, obs_amps, masks): + if self.localization_model in ("pointsource", "monopole"): + alphas, pred_amps = self.point_source_model(z, obs_amps, masks, channels) + elif self.localization_model == "dipole": + alphas, pred_amps = self.dipole_model(z, obs_amps, masks, channels) + else: + assert False + return alphas, pred_amps + def forward(self, x, mask, obs_amps, channels): - x_flat = x.view(x.size(0), -1) - x_flat_mask = torch.cat((x_flat, mask), dim=1) - mu = self.encoder(x_flat_mask) + x_mask = torch.cat((x, mask.unsqueeze(1)), dim=1) + mu = self.encoder(x_mask) var = None if self.variational: mu, var = mu.chunk(2, dim=-1) @@ -148,20 +207,24 @@ def loss_function(self, recon_x, x, mask, mu, var): rescale = mask.sum(1, keepdim=True) / x_masked.sum(1, keepdim=True) x_masked *= rescale recon_x_masked *= rescale - MSE = F.mse_loss(recon_x_masked, x_masked, reduction="sum") + MSE = F.mse_loss(recon_x_masked, x_masked, reduction="sum") / self.batch_size KLD = 0.0 if self.variational: - KLD = ( - var / self.prior_variance + KLD = 0.5 * ( + torch.log(self.prior_variance / var) + mu.pow(2) / self.prior_variance - 1 - ).sum() + ).sum() / self.batch_size return MSE, KLD def _fit(self, waveforms, channels): # apply channel reindexing before any fitting... - if self.amplitudes_only: + if waveforms.ndim == 2: + assert self.amplitudes_only + waveforms = waveforms.unsqueeze(1) + waveforms = reindex(channels, waveforms, self.relative_index, pad_value=0.0) + amps = waveforms[:, 0] + elif self.amplitudes_only: if self.amplitude_kind == "ptp": waveforms = ptp(waveforms) elif self.amplitude_kind == "peak": @@ -183,7 +246,7 @@ def _fit(self, waveforms, channels): optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) self.train() - loss_history = [] + mse_history = [] with trange(self.epochs, desc="Epochs", unit="epoch") as pbar: for epoch in pbar: total_loss = 0 @@ -209,7 +272,8 @@ def _fit(self, waveforms, channels): total_kld += kld.item() if self.variational else kld loss = total_loss / len(dataloader) - loss_history.append(loss) + mse = total_mse / len(dataloader) + mse_history.append(mse) desc = f"[loss={loss:0.2f},mse={total_mse/len(dataloader):0.2f},kld={total_kld/len(dataloader):0.2f}]" pbar.set_description(f"Epochs {desc}") @@ -217,46 +281,63 @@ def _fit(self, waveforms, channels): if epoch < self.min_epochs: continue - diff = abs(loss - loss_history[-2]) - if diff / loss_history[-2] < self.convergence_eps: + diff = min(mse_history[:-1]) - mse + if diff / min(mse_history[:-1]) < self.convergence_eps: pbar.set_description(f"Converged epoch={epoch} {desc}") break def fit(self, waveforms, max_channels): with torch.enable_grad(): self._fit(waveforms, max_channels) + self._needs_fit = False - def transform(self, waveforms, max_channels, return_amps=False): + def transform(self, waveforms, max_channels, return_extra=False): """ waveforms : torch.tensor, shape (num_waveforms, n_timesteps, n_channels_subset) max_channels : torch.tensor, shape (num_waveforms,) waveform[n] lives on channels self.channel_index[max_channels[n]] """ - if self.amplitudes_only: + # handle getting amplitudes, reindexing channels, and amplitudes_only logic + if waveforms.ndim == 2: + assert self.amplitudes_only + waveforms = waveforms.unsqueeze(1) + elif self.amplitudes_only: if self.amplitude_kind == "ptp": - waveforms = ptp(waveforms) + obs_amps = ptp(waveforms) elif self.amplitude_kind == "peak": - waveforms = waveforms.abs().max(dim=1).values - waveforms = waveforms[:, None] + obs_amps = waveforms.abs().max(dim=1).values + waveforms = obs_amps[:, None] waveforms = reindex(max_channels, waveforms, self.relative_index, pad_value=0.0) - mask = self.model_channel_index[max_channels] < len(self.geom) + if self.amplitudes_only: + obs_amps = waveforms[:, 0] + elif return_extra or self.reference == 'com': + if self.amplitude_kind == "ptp": + obs_amps = ptp(waveforms) + elif self.amplitude_kind == "peak": + obs_amps = waveforms.abs().max(dim=1).values + else: + # in this condition, we don't need the amp vecs + obs_amps = None + + # nn inputs + mask = self.model_channel_index[max_channels] < self.nc mask = mask.to(waveforms) - x_flat = waveforms.view(len(waveforms), -1) - x_flat_mask = torch.cat((x_flat, mask), dim=1) - mu = self.encoder(x_flat_mask) + x_mask = torch.cat((waveforms, mask.unsqueeze(1)), dim=1) + + # encode + mu = self.encoder(x_mask) var = None if self.variational: mu, var = mu.chunk(2, dim=-1) x, y, z = mu[:, :3].T y = F.softplus(y) - mx, mz = self.geom[max_channels].T - if return_amps: - if self.amplitude_kind == "ptp": - obs_amps = ptp(waveforms) - elif self.amplitude_kind == "peak": - obs_amps = waveforms.abs().max(dim=1).values + mx, mz = self.get_reference_points(max_channels, obs_amps=obs_amps).T + x = x + mx + z = z + mz + + if return_extra: alphas, pred_amps = self.decode(mu, max_channels, obs_amps, mask) - return x + mx, y, z + mz, obs_amps, pred_amps, mx, mz + return x, y, z, obs_amps, pred_amps, mx, mz - return x + mx, y, z + mz + return x, y, z diff --git a/src/dartsort/util/nn_util.py b/src/dartsort/util/nn_util.py index c58d02ad..5cc290de 100644 --- a/src/dartsort/util/nn_util.py +++ b/src/dartsort/util/nn_util.py @@ -1,4 +1,5 @@ from torch import nn +import torch.nn.functional as F def get_mlp(input_dim, hidden_dims, output_dim, use_batchnorm=True): @@ -15,3 +16,36 @@ def get_mlp(input_dim, hidden_dims, output_dim, use_batchnorm=True): layers.append(nn.Linear(final_dim, output_dim)) return nn.Sequential(*layers) + + +def get_waveform_mlp( + spike_length_samples, + n_input_channels, + hidden_dims, + output_dim, + input_includes_mask=True, + use_batchnorm=True, + channelwise_dropout_p=0.0, +): + input_dim = n_input_channels * (spike_length_samples + input_includes_mask) + + layers = [] + if channelwise_dropout_p: + layers.append(ChannelwiseDropout(channelwise_dropout_p)) + layers.append(nn.Flatten()) + layers.append(get_mlp(input_dim, hidden_dims, output_dim, use_batchnorm=use_batchnorm)) + return nn.Sequential(*layers) + + +class ChannelwiseDropout(nn.Module): + + def __init__(self, p): + super().__init__() + self.p = p + + def forward(self, waveforms): + return F.dropout1d( + waveforms.permute(0, 2, 1), + p=self.p, + training=self.training, + ).permute(0, 2, 1) \ No newline at end of file diff --git a/src/dartsort/vis/gmm.py b/src/dartsort/vis/gmm.py index 5b237b14..04ba0e68 100644 --- a/src/dartsort/vis/gmm.py +++ b/src/dartsort/vis/gmm.py @@ -1763,7 +1763,7 @@ def draw(self, panel, gmm, unit_id): lw=0, alpha=0.5, ) - row[0].set_title(f"{ns=} {nu=} 1-rho={1.0 - pear.statistic} p={pear.pvalue}", fontsize=6) + row[0].set_title(f"{ns=} {nu=} 1-rho={1.0 - pear.statistic:.2f} p={pear.pvalue:.2f}", fontsize=6) row[0].set_xlabel(f"{u}: {self.badness_kind}") row[0].set_ylabel(f"{unit_id}: {self.badness_kind}")