Skip to content

Commit

Permalink
Template std devs
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Sep 3, 2024
1 parent d7d875f commit 2bfe1d2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/dartsort/templates/get_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get_templates(
n_jobs=0,
dtype=np.float32,
show_progress=True,
with_std_dev=False,
device=None,
):
"""Raw, denoised, and shifted templates
Expand Down
26 changes: 19 additions & 7 deletions src/dartsort/templates/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ class TemplateData:
# (n_templates,) spike count for each template
spike_counts: np.ndarray
# (n_templates, n_registered_channels or n_channels) spike count for each channel
spike_counts_by_channel: np.ndarray
spike_counts_by_channel: Optional[np.ndarray] = None
# (n_templates, spike_length_samples, n_registered_channels or n_channels)
raw_std_dev: Optional[np.ndarray] = None

registered_geom: Optional[np.ndarray] = None
registered_template_depths_um: Optional[np.ndarray] = None
Expand All @@ -52,6 +54,14 @@ def to_npz(self, npz_path):
to_save["registered_template_depths_um"] = (
self.registered_template_depths_um
)
if self.spike_counts_by_channel is not None:
to_save["spike_counts_by_channel"] = (
self.spike_counts_by_channel
)
if self.raw_std_dev is not None:
to_save["raw_std_dev"] = (
self.raw_std_dev
)
if not npz_path.parent.exists():
npz_path.parent.mkdir()
np.savez(npz_path, **to_save)
Expand Down Expand Up @@ -109,7 +119,7 @@ def from_config(
motion_est=None,
save_npz_name="template_data.npz",
localizations_dataset_name="point_source_localizations",
with_locs=True,
with_locs=False,
n_jobs=0,
units_per_job=8,
tsvd=None,
Expand Down Expand Up @@ -212,11 +222,13 @@ def from_config(

# handle registered templates
if template_config.registered_templates and motion_est is not None:
registered_template_depths_um = get_template_depths(
results["templates"],
kwargs["registered_geom"],
localization_radius_um=template_config.registered_template_localization_radius_um,
)
registered_template_depths_um = None
if with_locs:
registered_template_depths_um = get_template_depths(
results["templates"],
kwargs["registered_geom"],
localization_radius_um=template_config.registered_template_localization_radius_um,
)
obj = cls(
results["templates"],
unit_ids=results["unit_ids"],
Expand Down

0 comments on commit 2bfe1d2

Please sign in to comment.