From faa860111d1dbdafe599d7e6bc0bd57293b9856c Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Thu, 17 Oct 2024 12:39:21 -0700 Subject: [PATCH] Trying to silence a new torch.load warning --- src/dartsort/__init__.py | 1 + src/dartsort/transform/all_transformers.py | 6 ++++++ src/dartsort/transform/pipeline.py | 4 +--- src/dartsort/util/data_util.py | 6 +++++- src/dartsort/util/nn_util.py | 6 ++++++ 5 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/dartsort/__init__.py b/src/dartsort/__init__.py index c2a5970f..b3de9a98 100644 --- a/src/dartsort/__init__.py +++ b/src/dartsort/__init__.py @@ -11,6 +11,7 @@ from .transform import WaveformPipeline from .util.analysis import DARTsortAnalysis from .util.data_util import DARTsortSorting +from .util.drift_util import registered_geometry from .util.waveform_util import make_channel_index from .cluster import merge, postprocess, density from . import util diff --git a/src/dartsort/transform/all_transformers.py b/src/dartsort/transform/all_transformers.py index 2a5b649b..a260dc99 100644 --- a/src/dartsort/transform/all_transformers.py +++ b/src/dartsort/transform/all_transformers.py @@ -1,3 +1,4 @@ +import torch from .amplitudes import AmplitudeFeatures, AmplitudeVector, MaxAmplitude, Voltage from .enforce_decrease import EnforceDecrease from .localize import Localization, PointSourceLocalization @@ -6,6 +7,7 @@ from .temporal_pca import TemporalPCADenoiser, TemporalPCAFeaturizer, TemporalPCA from .transform_base import Waveform, Passthrough from .decollider import Decollider +from .pipeline import WaveformPipeline all_transformers = [ Waveform, @@ -26,3 +28,7 @@ ] transformers_by_class_name = {cls.__name__: cls for cls in all_transformers} + +# serialization +others = [WaveformPipeline, set, torch.nn.ModuleList, slice, torch.nn.Sequential] +torch.serialization.add_safe_globals(all_transformers + others) diff --git a/src/dartsort/transform/pipeline.py b/src/dartsort/transform/pipeline.py index ab3fcaf7..3b34c93c 100644 --- a/src/dartsort/transform/pipeline.py +++ b/src/dartsort/transform/pipeline.py @@ -2,9 +2,6 @@ """ import torch -from .all_transformers import transformers_by_class_name - - class WaveformPipeline(torch.nn.Module): def __init__(self, transformers): super().__init__() @@ -15,6 +12,7 @@ def __init__(self, transformers): def from_class_names_and_kwargs( cls, geom, channel_index, class_names_and_kwargs ): + from .all_transformers import transformers_by_class_name return cls( [ transformers_by_class_name[name]( diff --git a/src/dartsort/util/data_util.py b/src/dartsort/util/data_util.py index 1f8db442..3644414d 100644 --- a/src/dartsort/util/data_util.py +++ b/src/dartsort/util/data_util.py @@ -194,7 +194,11 @@ def get_tpca(sorting): """Look for the TemporalPCAFeaturizer in the usual place.""" base_dir = sorting.parent_h5_path.parent model_dir = base_dir / f"{sorting.parent_h5_path.stem}_models" - pipeline = torch.load(model_dir / "featurization_pipeline.pt") + pipeline = torch.load( + model_dir / "featurization_pipeline.pt", + weights_only=True, + map_location="cpu", + ) tpca = pipeline.transformers[0] return tpca diff --git a/src/dartsort/util/nn_util.py b/src/dartsort/util/nn_util.py index d8be96ac..e09ddd04 100644 --- a/src/dartsort/util/nn_util.py +++ b/src/dartsort/util/nn_util.py @@ -130,3 +130,9 @@ def __init__(self, module): def forward(self, inputs): waveforms, masks = inputs return self.module(waveforms), masks + + +# is this what they want us to do?? +torch.serialization.add_safe_globals( + [ResidualForm, WaveformOnlyResidualForm, ChannelwiseDropout, Cat, Permute, WaveformOnly, nn.Flatten, nn.Linear, nn.Conv1d, nn.ReLU, nn.BatchNorm1d] +) \ No newline at end of file