Skip to content

Commit

Permalink
Analysis is slow...
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Nov 5, 2024
1 parent 2ee1858 commit 83d93cf
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/dartsort/templates/pairwise_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,8 +894,8 @@ def shift_deduplicated_pairs(
dot = chan_amp_a @ chan_amp_b.T
pair = dot > conv_ignore_threshold
if min_spatial_cosine:
norm_a = torch.sqrt((chan_amp_a * chan_amp_a).sum(1))
norm_b = torch.sqrt((chan_amp_b * chan_amp_b).sum(1))
norm_a = torch.sqrt(chan_amp_a.square().sum(1))
norm_b = torch.sqrt(chan_amp_b.square().sum(1))
cos = dot / (norm_a[:, None] * norm_b[None, :])
pair = pair & (cos > min_spatial_cosine)

Expand Down
5 changes: 3 additions & 2 deletions src/dartsort/util/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class DARTsortAnalysis:
merge_distance_kind: str = "rms"
merge_distance_spatial_radius_a: Optional[float] = None
merge_distance_min_channel_amplitude: float = 0.0
merge_distance_min_spatial_cosine: float = 0.0
merge_distance_min_spatial_cosine: float = 0.5
merge_temporal_upsampling: int = 1
merge_superres_linkage: Callable[[np.ndarray], float] = np.max
compute_distances: bool = "if_hdf5"
Expand Down Expand Up @@ -101,7 +101,8 @@ def from_sorting(
assert model_dir.exists()

featurization_pipeline = torch.load(
model_dir / "featurization_pipeline.pt"
model_dir / "featurization_pipeline.pt",
weights_only=True,
)

have_templates = False
Expand Down

0 comments on commit 83d93cf

Please sign in to comment.