From 2bfe1d23c83973ef26fa3d03eef7280de41ca8bb Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Tue, 3 Sep 2024 08:30:48 -0700 Subject: [PATCH] Template std devs --- src/dartsort/templates/get_templates.py | 1 + src/dartsort/templates/templates.py | 26 ++++++++++++++++++------- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/dartsort/templates/get_templates.py b/src/dartsort/templates/get_templates.py index 47fdf041..14255ce4 100644 --- a/src/dartsort/templates/get_templates.py +++ b/src/dartsort/templates/get_templates.py @@ -42,6 +42,7 @@ def get_templates( n_jobs=0, dtype=np.float32, show_progress=True, + with_std_dev=False, device=None, ): """Raw, denoised, and shifted templates diff --git a/src/dartsort/templates/templates.py b/src/dartsort/templates/templates.py index 802f0034..72cfbb1c 100644 --- a/src/dartsort/templates/templates.py +++ b/src/dartsort/templates/templates.py @@ -27,7 +27,9 @@ class TemplateData: # (n_templates,) spike count for each template spike_counts: np.ndarray # (n_templates, n_registered_channels or n_channels) spike count for each channel - spike_counts_by_channel: np.ndarray + spike_counts_by_channel: Optional[np.ndarray] = None + # (n_templates, spike_length_samples, n_registered_channels or n_channels) + raw_std_dev: Optional[np.ndarray] = None registered_geom: Optional[np.ndarray] = None registered_template_depths_um: Optional[np.ndarray] = None @@ -52,6 +54,14 @@ def to_npz(self, npz_path): to_save["registered_template_depths_um"] = ( self.registered_template_depths_um ) + if self.spike_counts_by_channel is not None: + to_save["spike_counts_by_channel"] = ( + self.spike_counts_by_channel + ) + if self.raw_std_dev is not None: + to_save["raw_std_dev"] = ( + self.raw_std_dev + ) if not npz_path.parent.exists(): npz_path.parent.mkdir() np.savez(npz_path, **to_save) @@ -109,7 +119,7 @@ def from_config( motion_est=None, save_npz_name="template_data.npz", localizations_dataset_name="point_source_localizations", - with_locs=True, + with_locs=False, n_jobs=0, units_per_job=8, tsvd=None, @@ -212,11 +222,13 @@ def from_config( # handle registered templates if template_config.registered_templates and motion_est is not None: - registered_template_depths_um = get_template_depths( - results["templates"], - kwargs["registered_geom"], - localization_radius_um=template_config.registered_template_localization_radius_um, - ) + registered_template_depths_um = None + if with_locs: + registered_template_depths_um = get_template_depths( + results["templates"], + kwargs["registered_geom"], + localization_radius_um=template_config.registered_template_localization_radius_um, + ) obj = cls( results["templates"], unit_ids=results["unit_ids"],