Skip to content

Commit

Permalink
Trying to silence a new torch.load warning
Browse files Browse the repository at this point in the history
  • Loading branch information
cwindolf committed Oct 17, 2024
1 parent 622f885 commit faa8601
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/dartsort/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .transform import WaveformPipeline
from .util.analysis import DARTsortAnalysis
from .util.data_util import DARTsortSorting
from .util.drift_util import registered_geometry
from .util.waveform_util import make_channel_index
from .cluster import merge, postprocess, density
from . import util
Expand Down
6 changes: 6 additions & 0 deletions src/dartsort/transform/all_transformers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
from .amplitudes import AmplitudeFeatures, AmplitudeVector, MaxAmplitude, Voltage
from .enforce_decrease import EnforceDecrease
from .localize import Localization, PointSourceLocalization
Expand All @@ -6,6 +7,7 @@
from .temporal_pca import TemporalPCADenoiser, TemporalPCAFeaturizer, TemporalPCA
from .transform_base import Waveform, Passthrough
from .decollider import Decollider
from .pipeline import WaveformPipeline

all_transformers = [
Waveform,
Expand All @@ -26,3 +28,7 @@
]

transformers_by_class_name = {cls.__name__: cls for cls in all_transformers}

# serialization
others = [WaveformPipeline, set, torch.nn.ModuleList, slice, torch.nn.Sequential]
torch.serialization.add_safe_globals(all_transformers + others)
4 changes: 1 addition & 3 deletions src/dartsort/transform/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
"""
import torch

from .all_transformers import transformers_by_class_name


class WaveformPipeline(torch.nn.Module):
def __init__(self, transformers):
super().__init__()
Expand All @@ -15,6 +12,7 @@ def __init__(self, transformers):
def from_class_names_and_kwargs(
cls, geom, channel_index, class_names_and_kwargs
):
from .all_transformers import transformers_by_class_name
return cls(
[
transformers_by_class_name[name](
Expand Down
6 changes: 5 additions & 1 deletion src/dartsort/util/data_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,11 @@ def get_tpca(sorting):
"""Look for the TemporalPCAFeaturizer in the usual place."""
base_dir = sorting.parent_h5_path.parent
model_dir = base_dir / f"{sorting.parent_h5_path.stem}_models"
pipeline = torch.load(model_dir / "featurization_pipeline.pt")
pipeline = torch.load(
model_dir / "featurization_pipeline.pt",
weights_only=True,
map_location="cpu",
)
tpca = pipeline.transformers[0]
return tpca

Expand Down
6 changes: 6 additions & 0 deletions src/dartsort/util/nn_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,9 @@ def __init__(self, module):
def forward(self, inputs):
waveforms, masks = inputs
return self.module(waveforms), masks


# is this what they want us to do??
torch.serialization.add_safe_globals(
[ResidualForm, WaveformOnlyResidualForm, ChannelwiseDropout, Cat, Permute, WaveformOnly, nn.Flatten, nn.Linear, nn.Conv1d, nn.ReLU, nn.BatchNorm1d]
)

0 comments on commit faa8601

Please sign in to comment.