"""Library for flavors of kernel interpolation and data interp utilities"""

import numpy as np
import torch
import torch.nn.functional as F

# Names of the supported interpolation flavors; see kernel_interpolate().
interp_kinds = (
    "nearest",
    "rbf",
    "normalized",
    "kriging",
    "kriging_normalized",
)


def interpolate_by_chunk(
    mask,
    dataset,
    geom,
    channel_index,
    channels,
    shifts,
    registered_geom,
    target_channels,
    sigma=10.0,
    interpolation_method="normalized",
    device=None,
    store_on_device=False,
    show_progress=True,
):
    """Interpolate data living in an HDF5 file

    If dataset is a h5py.Dataset and mask is a boolean array indicating
    positions of data to load, this iterates over the HDF5 chunks to
    quickly scan through the data, applying interpolation to all the
    features.

    Arguments
    ---------
    mask : boolean np.ndarray
        Load and interpolate these entries. Shape should be
        (n_spikes_full,), and let's say it has n_spikes nonzero entries.
    dataset : h5py.Dataset
        Chunked dataset, shape (n_spikes_full, feature_dim, n_source_channels)
        Can only be chunked on the first axis
    geom : array or tensor
        Source probe geometry. A NaN row is appended so that an index one
        past the last channel resolves to a missing position.
    channel_index : int array or tensor
        Row i lists the source channels of features extracted at channel i.
    channels : int array or tensor
        Shape (n_spikes,)
    shifts : array or tensor
        Shape (n_spikes,) or (n_spikes, n_source_channels). Added to the
        depth coordinate (column 1) of the source positions.
    registered_geom : array or tensor
        Target probe geometry, padded the same way as geom.
    target_channels : int array or tensor
        (n_spikes, n_target_channels)
    sigma : float
        Kernel bandwidth
    interpolation_method : str
        One of interp_kinds.
    device : torch device
        Where to compute. Defaults to cuda when available, else cpu.
    store_on_device : bool
        Allocate the output tensor on gpu?
    show_progress : bool

    Returns
    -------
    out : torch.Tensor
        (n_spikes, feature_dim, n_target_chans)
    """
    # Project-local chunk iterator, imported at call time so that the
    # pure-torch helpers in this module stay importable on their own.
    from dartsort.util.data_util import yield_masked_chunks

    # devices, dtypes, shapes
    assert interpolation_method in interp_kinds
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    dtype = torch.from_numpy(np.empty((), dtype=dataset.dtype)).dtype
    n_spikes = int(mask.sum())
    assert channels.shape == (n_spikes,)
    n_target_chans = target_channels.shape[1]
    assert target_channels.shape == (n_spikes, n_target_chans)
    feature_dim = dataset.shape[1]
    assert channel_index.shape[1] == dataset.shape[2]

    # allocate output
    storage_device = device if store_on_device else "cpu"
    out_shape = n_spikes, feature_dim, n_target_chans
    out = torch.empty(out_shape, dtype=dtype, device=storage_device)

    # build data needed for interpolation
    source_geom = pad_geom(geom, dtype=dtype, device=device)
    target_geom = pad_geom(registered_geom, dtype=dtype, device=device)
    shifts = torch.as_tensor(shifts, dtype=dtype).to(device)
    target_channels = torch.as_tensor(target_channels, device=device)
    channel_index = torch.as_tensor(channel_index, device=device)
    channels = torch.as_tensor(channels, device=device)

    for ixs, chunk_features in yield_masked_chunks(
        mask, dataset, show_progress=show_progress, desc_prefix="Interpolating"
    ):
        # NOTE(review): chunks presumably arrive as numpy arrays read from
        # the HDF5 file; as_tensor makes tensors and arrays both acceptable
        # and puts the features next to the positions they are paired with.
        chunk_features = torch.as_tensor(chunk_features, dtype=dtype, device=device)

        # where are the spikes?
        source_channels = channel_index[channels[ixs]]
        source_shifts = shifts[ixs]
        if source_shifts.ndim == 1:
            # allows per-channel shifts
            source_shifts = source_shifts.unsqueeze(1)
        # Indexing with a tensor copies, so the in-place add below is safe.
        # The shift is one number per spike (or per spike and channel), so
        # it cannot be broadcast onto whole positions; it moves the depth
        # coordinate only.
        # NOTE(review): depth is taken to be column 1 of the geometry --
        # confirm against callers.
        source_pos = source_geom[source_channels]
        source_pos[..., 1] += source_shifts

        # where are they going?
        target_pos = target_geom[target_channels[ixs]]

        # interpolate, store
        chunk_res = kernel_interpolate(
            chunk_features,
            source_pos,
            target_pos,
            sigma=sigma,
            allow_destroy=True,
            interpolation_method=interpolation_method,
        )
        out[ixs] = chunk_res.to(out)

    return out


def pad_geom(geom, dtype=torch.float, device=None):
    """Return geom as a tensor with one extra all-NaN row appended

    The extra row gives the index len(geom) a "missing" position.
    """
    geom = torch.as_tensor(geom, dtype=dtype, device=device)
    # F.pad lists pads from the last dim backwards: nothing on the
    # coordinate axis, one row after the last channel.
    geom = F.pad(geom, (0, 0, 0, 1), value=torch.nan)
    return geom


def kernel_interpolate(
    features,
    source_pos,
    target_pos,
    source_kernel_invs=None,
    sigma=10.0,
    allow_destroy=False,
    interpolation_method="normalized",
    out=None,
):
    """Kernel interpolation of multi-channel features or waveforms

    Arguments
    ---------
    features : torch.Tensor
        n_spikes, feature_dim, n_source_channels
        These can be masked, indicated by nans here and in the same
        places of source_pos
    source_pos : torch.Tensor
        n_spikes, n_source_channels, spatial_dim
    target_pos : torch.Tensor
        n_spikes, n_target_channels, spatial_dim
        These can also be masked, indicate with nans and you will
        get nans in those positions
    source_kernel_invs : optional torch.Tensor
        Precomputed inverses of source-to-source kernel matrices,
        if you have them, for use in kriging. When omitted, kriging
        falls back to pseudo-inverses computed here.
    sigma : float
        Spatial bandwidth of RBF kernels
    allow_destroy : bool
        We need to overwrite nans in the features with 0s. If you
        allow me, I'll do that in-place.
    interpolation_method : str
        One of interp_kinds.
    out : torch.Tensor
        Storage for target

    Returns
    -------
    features : torch.Tensor
        n_spikes, feature_dim, n_target_channels
    """
    assert interpolation_method in interp_kinds

    # -- build kernel, shape (n_spikes, n_source_channels, n_target_channels)
    if interpolation_method == "nearest":
        dists = torch.cdist(source_pos, target_pos)
        # Distances involving a missing position are NaN. Push them to +inf
        # so that they can never be picked as the nearest neighbor.
        dists.masked_fill_(dists.isnan(), torch.inf)
        # One source channel per target channel: one-hot along the source axis.
        nearest = dists.argmin(dim=1, keepdim=True)
        kernel = torch.zeros_like(dists).scatter_(1, nearest, 1.0)
    else:
        kernel = log_rbf(source_pos, target_pos, sigma)
        if interpolation_method == "normalized":
            kernel = F.softmax(kernel, dim=1)
            # targets which see no source at all softmax to nan
            kernel.nan_to_num_()
        elif interpolation_method.startswith("kriging"):
            kernel = kernel.exp_()
            if source_kernel_invs is None:
                # Missing source channels leave all-zero rows and columns in
                # the source kernel, so it is singular in general: use the
                # pseudo-inverse, which gives those channels zero weight.
                source_kernel = log_rbf(source_pos, sigma=sigma).exp_()
                source_kernel_invs = torch.linalg.pinv(source_kernel, hermitian=True)
            kernel = source_kernel_invs @ kernel
            if interpolation_method == "kriging_normalized":
                kernel = kernel / kernel.sum(1, keepdim=True)
        elif interpolation_method == "rbf":
            kernel = kernel.exp_()
        else:
            assert False

    # -- apply kernel
    # nan features would poison whole output columns in the matmul
    features = features.nan_to_num_() if allow_destroy else features.nan_to_num()
    features = torch.bmm(features, kernel, out=out)

    # nan-ify nonexistent chans
    needs_nan = torch.isnan(target_pos).all(2).unsqueeze(1)
    needs_nan = needs_nan.broadcast_to(features.shape)
    features[needs_nan] = torch.nan

    return features


def log_rbf(source_pos, target_pos=None, sigma=None):
    """Log of RBF kernel

    This handles missing values in source_pos or target_pos, indicated by
    nans, by replacing them with -inf so that they exp to 0.

    Arguments
    ---------
    source_pos : torch.tensor
        n source locations
    target_pos : torch.tensor
        m target locations. Defaults to source_pos.
    sigma : float
        Kernel bandwidth. Required, despite the None default.

    Returns
    -------
    kernel : torch.tensor
        n by m

    Raises
    ------
    TypeError
        If sigma is not given.
    """
    if sigma is None:
        raise TypeError("log_rbf() needs a kernel bandwidth: pass sigma.")
    if target_pos is None:
        target_pos = source_pos
    kernel = torch.cdist(source_pos, target_pos)
    kernel = kernel.square_().mul_(-1.0 / (2 * sigma**2))
    kernel.masked_fill_(kernel.isnan(), -torch.inf)
    return kernel
See .estimate().""" def __init__(self, spatial_std, vt_spatial, temporal_std, vt_temporal): super().__init__() @@ -115,18 +113,22 @@ def estimate(cls, snippets): n, t, c = snippets.shape sqrt_nt_minus_1 = torch.tensor(n * t - 1, dtype=snippets.dtype).sqrt() sqrt_nc_minus_1 = torch.tensor(n * c - 1, dtype=snippets.dtype).sqrt() - assert n * t > c ** 2 - assert n * c > t ** 2 + assert n * t > c**2 + assert n * c > t**2 # estimate spatial covariance x_spatial = snippets.view(n * t, c) - u_spatial, spatial_sing, vt_spatial = torch.linalg.svd(x_spatial, full_matrices=False) + u_spatial, spatial_sing, vt_spatial = torch.linalg.svd( + x_spatial, full_matrices=False + ) spatial_std = spatial_sing / sqrt_nt_minus_1 # extract whitened temporal snips x_temporal = u_spatial.view(n, t, c).permute(0, 2, 1).reshape(n * c, t) x_temporal.mul_(sqrt_nt_minus_1) - _, temporal_sing, vt_temporal = torch.linalg.svd(x_temporal, full_matrices=False) + _, temporal_sing, vt_temporal = torch.linalg.svd( + x_temporal, full_matrices=False + ) del _ temporal_std = temporal_sing / sqrt_nc_minus_1 @@ -160,7 +162,11 @@ def simulate(self, size=1, t=None, generator=None): t_padded = t + self.t - 1 noise = torch.randn(size * c, t_padded, generator=generator, device=device) noise = spiketorch.single_inv_oaconv1d( - noise, s2=self.t, f2=self.kernel_fft, block_size=self.block_size, norm="ortho" + noise, + s2=self.t, + f2=self.kernel_fft, + block_size=self.block_size, + norm="ortho", ) noise = noise.view(size, c, t) spatial_part = self.spatial_std[:, None] * self.vt_spatial @@ -188,12 +194,14 @@ def estimate(cls, snippets): n, t, c = snippets.shape sqrt_nt_minus_1 = torch.tensor(n * t - 1, dtype=snippets.dtype).sqrt() - assert n * t > c ** 2 + assert n * t > c**2 assert n * c > t # estimate spatial covariance x_spatial = snippets.view(n * t, c) - u_spatial, spatial_sing, vt_spatial = torch.linalg.svd(x_spatial, full_matrices=False) + u_spatial, spatial_sing, vt_spatial = torch.linalg.svd( + x_spatial, 
class EmbeddedNoise(torch.nn.Module):
    """Handles computations related to noise in TPCA space.

    Can have a couple of kinds of mean. mean_kind == ...
     - "zero": noise was already centered, my mean is 0
     - "by_rank": same mean on all channels
     - "full": value per rank, chan

    And cov_kind = ...
     - "scalar": one global variance
     - "diagonal_by_rank": same variance across chans, varies by rank
     - "diagonal": value per rank, chan
     - "factorized": kronecker prod of dense rank and chan factors
     - "factorized_by_rank": same, but chan factor varies by rank
     - "factorized_rank_diag" : factorized, but rank factor is diagonal
     - "factorized_by_rank_rank_diag" : factorized_by_rank, but rank factor is diagonal
       this one is block diagonal and therefore nicer than factorized_by_rank.

    The statistics are held as buffers, so that they follow the module
    through .to(...) and end up in its state_dict.
    """

    def __init__(
        self,
        rank,
        n_channels,
        mean_kind="zero",
        cov_kind="scalar",
        mean=None,
        global_std=None,
        rank_std=None,
        full_std=None,
        rank_vt=None,
        channel_std=None,
        channel_vt=None,
    ):
        """Store the already-estimated statistics; see .estimate()

        Which of the optional statistics are needed depends on mean_kind
        and cov_kind. Those which are not needed can be left as None.
        """
        # Without this, torch.nn.Module's machinery is never set up.
        super().__init__()
        self.rank = rank
        self.n_channels = n_channels
        self.mean_kind = mean_kind
        self.cov_kind = cov_kind

        stats = dict(
            mean=mean,
            global_std=global_std,
            rank_std=rank_std,
            channel_std=channel_std,
            full_std=full_std,
            rank_vt=rank_vt,
            channel_vt=channel_vt,
        )
        for name, value in stats.items():
            if value is not None:
                value = torch.as_tensor(value)
            # A None buffer still makes self.<name> resolve (to None).
            self.register_buffer(name, value)

    @property
    def device(self):
        """Device of the stored statistics (cpu if there are none)"""
        for buffer in self.buffers():
            return buffer.device
        return torch.device("cpu")

    def mean_rc(self):
        """Return noise mean as a rank x channels tensor"""
        shape = self.rank, self.n_channels
        if self.mean_kind == "zero":
            dtype = None if self.global_std is None else self.global_std.dtype
            return torch.zeros(shape, dtype=dtype, device=self.device)
        if self.mean_kind == "by_rank":
            return self.mean[:, None].broadcast_to(shape).contiguous()
        if self.mean_kind == "full":
            return self.mean
        raise ValueError(f"Unknown mean_kind={self.mean_kind!r}.")

    def marginal_precision(self, channels):
        """Inverse of marginal_covariance(channels)"""
        return self.marginal_covariance(channels).inverse()

    def marginal_covariance(self, channels):
        """Noise covariance restricted to a subset of channels

        Arguments
        ---------
        channels : int torch.Tensor
            Channels to keep.

        Returns
        -------
        A linear operator of size rank * channels.numel(), ordered
        rank-major like a flattened (rank, channels) array.

        Raises
        ------
        NotImplementedError
            For the cov_kinds whose channel factor varies by rank.
        """
        nc = channels.numel()

        if self.cov_kind == "scalar":
            eye = operators.IdentityLinearOperator(self.rank * nc, device=self.device)
            return eye * self.global_std.square()

        if self.cov_kind == "diagonal_by_rank":
            rank_diag = operators.DiagLinearOperator(self.rank_std**2)
            chans_eye = operators.IdentityLinearOperator(nc, device=self.device)
            return operators.KroneckerProductLinearOperator(rank_diag, chans_eye)

        if self.cov_kind == "diagonal":
            # keep just the requested channels, flattened rank-major
            marginal_var = self.full_std[:, channels].square().reshape(-1)
            return operators.DiagLinearOperator(marginal_var)

        if self.cov_kind in ("factorized", "factorized_rank_diag"):
            if self.cov_kind == "factorized":
                rank_root = self.rank_vt.T * self.rank_std
                rank_cov = operators.RootLinearOperator(rank_root)
            else:
                # the rank factor is diagonal, so there is no rank_vt here
                rank_cov = operators.DiagLinearOperator(self.rank_std**2)
            # Rows of the root index channels, so picking rows of the root
            # picks the matching rows and columns of root @ root.T.
            chan_root = (self.channel_vt.T * self.channel_std)[channels]
            chan_cov = operators.RootLinearOperator(chan_root)
            return operators.KroneckerProductLinearOperator(rank_cov, chan_cov)

        raise NotImplementedError(
            f"marginal_covariance is not implemented for cov_kind={self.cov_kind!r}."
        )

    @classmethod
    def estimate(cls, snippets, mean_kind="zero", cov_kind="scalar"):
        """Factory method to estimate noise model from TPCA snippets

        Arguments
        ---------
        snippets : torch.Tensor
            (n, rank, c) array of tpca-embedded noise snippets
            missing values are okay, indicate by NaN please
            This is not modified.
        mean_kind : str
        cov_kind : str
            See the class docstring. "by_rank" is accepted as an alias
            for "diagonal_by_rank".

        Returns
        -------
        noise : EmbeddedNoise

        Raises
        ------
        ValueError
            If mean_kind or cov_kind is not recognized.
        """
        n, rank, n_channels = snippets.shape
        if cov_kind == "by_rank":
            # store the name that marginal_covariance() understands
            cov_kind = "diagonal_by_rank"
        init_kw = dict(
            rank=rank, n_channels=n_channels, mean_kind=mean_kind, cov_kind=cov_kind
        )
        # Stay (n, rank, c): everything below reduces over named axes.
        x = torch.asarray(snippets)
        x = x.to(torch.promote_types(x.dtype, torch.float))

        # estimate mean and center data
        if mean_kind == "zero":
            mean = None
        elif mean_kind == "by_rank":
            mean = torch.nanmean(x, dim=(0, 2))
            assert mean.isfinite().all()
            x = x - mean.unsqueeze(1)
        elif mean_kind == "full":
            mean = torch.nanmean(x, dim=0)
            # channels which were never observed get their rank's mean
            mean = torch.where(
                mean.isnan(),
                torch.nanmean(mean, dim=1, keepdim=True),
                mean,
            )
            x = x - mean
        else:
            raise ValueError(f"Unknown mean_kind={mean_kind!r}.")

        # estimate covs
        full_var = torch.nanmean(x.square(), dim=0)
        rank_var = torch.nanmean(full_var, dim=1)
        assert rank_var.isfinite().all()
        global_std = torch.nanmean(rank_var).sqrt()
        rank_std = rank_var.sqrt()

        if cov_kind == "scalar":
            return cls(mean=mean, global_std=global_std, **init_kw)

        if cov_kind == "diagonal_by_rank":
            return cls(mean=mean, global_std=global_std, rank_std=rank_std, **init_kw)

        if cov_kind == "diagonal":
            # channels which were never observed get their rank's variance
            full_var = torch.where(full_var.isnan(), rank_var.unsqueeze(1), full_var)
            full_std = full_var.sqrt()
            return cls(mean=mean, global_std=global_std, full_std=full_std, **init_kw)

        if not cov_kind.startswith("factorized"):
            raise ValueError(f"Unknown cov_kind={cov_kind!r}.")
        # handle rank part first, then the spatial part

        # start by getting the rank part of the cov, if necessary, and whitening
        # the ranks to leave behind x_spatial
        if "rank_diag" in cov_kind:
            # rank part is diagonal
            rank_vt = None
            # out of place: x can share memory with the caller's snippets
            x_spatial = x / rank_std.unsqueeze(1)
        else:
            # full cov estimation for rank part via svd
            # we have NaNs, but we can get rid of them because channels are either all
            # NaN or not. below, for the spatial part, no such luck and we have to
            # evaluate the covariance in a masked way
            x_rank = x.permute(0, 2, 1).reshape(n * n_channels, rank)
            valid = x_rank.isfinite().all(dim=1)
            x_rankv = x_rank[valid]
            u_rankv, rank_sing, rank_vt = torch.linalg.svd(x_rankv, full_matrices=False)
            correction = float(len(x_rankv) - 1) ** 0.5
            rank_std = rank_sing / correction

            # whitened spatial part. a fresh NaN-filled tensor is used, so
            # that nothing is written into memory shared with the snippets
            x_spatial = torch.full_like(x_rank, torch.nan)
            x_spatial[valid] = u_rankv * correction
            x_spatial = x_spatial.reshape(n, n_channels, rank).permute(0, 2, 1)
            del x_rank, x_rankv, u_rankv
        del x

        # spatial part could be "by rank" or same for all ranks
        # either way, there are nans afoot
        if "by_rank" in cov_kind:
            channel_std = torch.zeros_like(x_spatial[0])
            channel_vt = torch.zeros((rank, n_channels, n_channels)).to(channel_std)
            for q in range(rank):
                covq = spiketorch.nancov(x_spatial[:, q])
                qeig, qv = torch.linalg.eigh(covq)
                # clamp: round-off can leave tiny negative eigenvalues
                channel_std[q] = qeig.clamp_min(0).sqrt()
                channel_vt[q] = qv.T
        else:
            x_spatial = x_spatial.reshape(n * rank, n_channels)
            cov_spatial = spiketorch.nancov(x_spatial)
            channel_eig, channel_v = torch.linalg.eigh(cov_spatial)
            channel_std = channel_eig.clamp_min(0).sqrt()
            channel_vt = channel_v.T.contiguous()

        return cls(
            mean=mean,
            global_std=global_std,
            rank_std=rank_std,
            rank_vt=rank_vt,
            channel_std=channel_std,
            channel_vt=channel_vt,
            **init_kw,
        )


def interpolate_residual_snippets(
    tpca,
    motion_est,
    hdf5_path,
    geom,
    registered_geom,
    sigma=10.0,
    mean_kind="zero",
    cov_kind="scalar",
    residual_times_s_dataset_name="residual_times_seconds",
    residual_dataset_name="residual",
    channels_mode="round",
    interpolation_method="normalized",
):
    """PCA-embed and interpolate residual snippets to the registered probe

    Arguments
    ---------
    tpca
        Temporal PCA object. Its components tensor sets the device and
        dtype used here.
    motion_est
        Motion estimate whose correct_s(times, depths) gives registered
        depths.
    hdf5_path : str or Path
        File holding the residual snippets and their times.
    geom, registered_geom : array or tensor
        Probe geometry and registered probe geometry.
    sigma : float
        Kernel bandwidth
    mean_kind, cov_kind, channels_mode
        Currently unused.
    residual_times_s_dataset_name, residual_dataset_name : str
        Names of the datasets to read from the HDF5 file.
    interpolation_method : str

    Returns
    -------
    snippets : torch.Tensor
        (n, rank, n_registered_channels)

    Raises
    ------
    ValueError
        If the snippets and geom disagree on the number of channels.
    """
    from dartsort.util import interpolation_util

    with h5py.File(hdf5_path, "r") as h5:
        snippets = h5[residual_dataset_name][:]
        times_s = h5[residual_times_s_dataset_name][:]
    snippets = torch.from_numpy(snippets).to(tpca.components)

    # tpca project
    snippets = snippets[:, tpca.temporal_slice]
    n, t, c = snippets.shape
    snippets = snippets.permute(0, 2, 1).reshape(n * c, t)
    snippets = tpca._transform_in_probe(snippets)
    snippets = snippets.reshape(n, c, -1).permute(0, 2, 1)

    # -- interpolate
    # source positions. these are not padded: a padding row would give the
    # kernel one more source channel than the snippets have
    source_geom = torch.as_tensor(geom).to(snippets)
    if len(source_geom) != c:
        raise ValueError(
            f"Snippets have {c} channels, but geom has {len(source_geom)}."
        )
    # a real copy per snippet, since depths are overwritten below
    source_pos = source_geom[None].repeat(n, 1, 1)
    source_depths = source_pos[:, :, 1].reshape(-1)
    # one time per (snippet, channel) pair, snippet-major like source_depths
    source_t = np.repeat(np.asarray(times_s), c)
    # NOTE(review): motion_est.correct_s presumably takes array-likes --
    # numpy arrays are handed over here; confirm against the motion
    # estimate class in use.
    source_reg_depths = motion_est.correct_s(source_t, source_depths.cpu().numpy())
    source_reg_depths = torch.as_tensor(source_reg_depths).reshape(n, c)
    source_pos[:, :, 1] = source_reg_depths.to(source_pos)

    # target positions -- these are just the full reg probe
    target_pos = torch.as_tensor(registered_geom).to(source_geom)
    target_pos = target_pos[None].broadcast_to(n, *target_pos.shape)
    # this is how it would be done sparsely... which we are not doing.
    # pitch_shifts = drift_util.get_spike_pitch_shifts(
    #     source_depths,
    #     geom=geom,
    #     motion_est=motion_est,
    #     times_s=source_t,
    #     registered_depths_um=source_reg_depths,
    #     mode=channels_mode,
    # )
    # target_channels = drift_util.static_channel_neighborhoods(
    #     geom,
    #     channels=np.zeros(n, dtype=int),
    #     channel_index=waveform_util.full_channel_index(len(geom)),
    #     n_pitches_shift=pitch_shifts,
    #     registered_geom=registered_geom,
    # )
    # target_geom = interpolation_util.pad_geom(registered_geom)
    # target_pos = target_geom[target_channels]

    snippets = interpolation_util.kernel_interpolate(
        snippets,
        source_pos,
        target_pos,
        sigma=sigma,
        allow_destroy=True,
        interpolation_method=interpolation_method,
    )
    return snippets
max(count_control_thresholds.max(), fnr_control_thresholds.min()). """ - pass \ No newline at end of file + pass