From 7233abcd9faad425965d544248d300949a3d314e Mon Sep 17 00:00:00 2001 From: Sagar Mishra <54197164+achieveordie@users.noreply.github.com> Date: Sun, 2 Apr 2023 11:31:33 +0530 Subject: [PATCH 1/2] add serialization functionality --- pyproject.toml | 1 + skbase/base/_base.py | 82 +++++++++++++++++++++++++++++++++++++++ skbase/base/_serialize.py | 62 +++++++++++++++++++++++++++++ 3 files changed, 145 insertions(+) create mode 100644 skbase/base/_serialize.py diff --git a/pyproject.toml b/pyproject.toml index dff9498e..89ce4d21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,6 +141,7 @@ ignore_path = ["docs/_build", "docs/source/api_reference/auto_generated"] [tool.bandit] exclude_dirs = ["*/tests/*", "*/testing/*"] +skips = ["B301", "B403"] [tool.setuptools] zip-safe = true diff --git a/skbase/base/_base.py b/skbase/base/_base.py index 970bf350..c6491533 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -682,6 +682,88 @@ def _components(self, base_class=None): return comp_dict + def save(self, path=None): + """Save serialized self to bytes-like object or to (.zip) file. + + Behaviour: + if `path` is None, returns an in-memory serialized self + if `path` is a file location, stores self at that location as a zip file + + saved files are zip files with following contents: + _metadata - contains class of self, i.e., type(self) + _obj - serialized self. This class uses the default serialization (pickle). + + Parameters + ---------- + path : None or file location (str or Path) + if None, self is saved to an in-memory object + if file location, self is saved to that file location. If: + path="estimator" then a zip file `estimator.zip` will be made at cwd. + path="/home/stored/estimator" then a zip file `estimator.zip` will be + stored in `/home/stored/`. + + Returns + ------- + if `path` is None - in-memory serialized self + if `path` is file location - ZipFile with reference to the file + """ + import pickle + import shutil + from pathlib import Path + from zipfile import ZipFile + + if path is None: + return (type(self), pickle.dumps(self)) + if not isinstance(path, (str, Path)): + raise TypeError( + "`path` is expected to either be a string or a Path object " + f"but found of type:{type(path)}." + ) + + path = Path(path) if isinstance(path, str) else path + path.mkdir() + + pickle.dump(type(self), open(path / "_metadata", "wb")) + pickle.dump(self, open(path / "_obj", "wb")) + + shutil.make_archive(base_name=path, format="zip", root_dir=path) + shutil.rmtree(path) + return ZipFile(path.with_name(f"{path.stem}.zip")) + + @classmethod + def load_from_serial(cls, serial): + """Load object from serialized memory container. + + Parameters + ---------- + serial : 1st element of output of `cls.save(None)` + + Returns + ------- + deserialized self resulting in output `serial`, of `cls.save(None)` + """ + import pickle + + return pickle.loads(serial) + + @classmethod + def load_from_path(cls, serial): + """Load object from file location. + + Parameters + ---------- + serial : result of ZipFile(path).open("object) + + Returns + ------- + deserialized self resulting in output at `path`, of `cls.save(path)` + """ + import pickle + from zipfile import ZipFile + + with ZipFile(serial, "r") as file: + return pickle.loads(file.open("_obj").read()) + class TagAliaserMixin: """Mixin class for tag aliasing and deprecation of old tags. diff --git a/skbase/base/_serialize.py b/skbase/base/_serialize.py new file mode 100644 index 00000000..51817361 --- /dev/null +++ b/skbase/base/_serialize.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# copyright: skbase developers, BSD-3-Clause License (see LICENSE file) +"""Utilities for serializing and deserializing objects. + +IMPORTANT CAVEAT FOR DEVELOPERS: +Do not add estimator specific functionality to the `load` utility. +All estimator specific functionality should be in +the class methods `load_from_serial` and `load_from_path`. +""" + +__author__ = ["fkiraly", "achieveordie"] + + +def load(serial): + """Load an object either from in-memory object or from a file location. + + Parameters + ---------- + serial : serialized container (tuple), str (path), or Path object (reference) + if serial is a tuple (serialized container): + Contains two elements, first in-memory metadata and second + the related object. + if serial is a string (path reference): + The name of the file without the extension, for e.g: if the file + is `estimator.zip`, `serial='estimator'`. It can also represent a + path, for eg: if location is `home/stored/models/estimator.zip` + then `serial='home/stored/models/estimator'`. + if serial is a Path object (path reference): + `serial` then points to the `.zip` file into which the + object was stored using class method `.save()` of an estimator. + + Returns + ------- + Deserialized self resulting in output `serial`, of `cls.save` + """ + import pickle + from pathlib import Path + from zipfile import ZipFile + + if isinstance(serial, tuple): + if len(serial) != 2: + raise ValueError( + "`serial` should be a tuple of size 2 " + f"found, a tuple of size: {len(serial)}" + ) + cls, stored = serial + return cls.load_from_serial(stored) + + elif isinstance(serial, (str, Path)): + path = Path(serial + ".zip") if isinstance(serial, str) else serial + if not path.exists(): + raise FileNotFoundError(f"The given save location: {serial}\nwas not found") + with ZipFile(path) as file: + cls = pickle.loads(file.open("_metadata", "r").read()) + return cls.load_from_path(path) + else: + raise TypeError( + "serial must either be a serialized in-memory sktime object, " + "a str, Path or ZipFile pointing to a file which is a serialized sktime " + "object, created by save of an sktime object; but found serial " + f"of type {serial}" + ) From aea9904f23f0ac145f7761d1a0ea59da4a61bd47 Mon Sep 17 00:00:00 2001 From: Sagar Mishra <54197164+achieveordie@users.noreply.github.com> Date: Thu, 6 Apr 2023 10:53:12 +0530 Subject: [PATCH 2/2] add new file & function to config files --- skbase/base/_base.py | 8 ++++++-- skbase/tests/conftest.py | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/skbase/base/_base.py b/skbase/base/_base.py index c6491533..30761311 100644 --- a/skbase/base/_base.py +++ b/skbase/base/_base.py @@ -723,8 +723,12 @@ def save(self, path=None): path = Path(path) if isinstance(path, str) else path path.mkdir() - pickle.dump(type(self), open(path / "_metadata", "wb")) - pickle.dump(self, open(path / "_obj", "wb")) + mfile_ = open(path / "_metadata", "wb") + ofile_ = open(path / "_obj", "wb") + pickle.dump(type(self), mfile_) + pickle.dump(self, ofile_) + mfile_.close() + ofile_.close() shutil.make_archive(base_name=path, format="zip", root_dir=path) shutil.rmtree(path) diff --git a/skbase/tests/conftest.py b/skbase/tests/conftest.py index 4f36034a..a251fc68 100644 --- a/skbase/tests/conftest.py +++ b/skbase/tests/conftest.py @@ -22,6 +22,7 @@ "skbase.base", "skbase.base._base", "skbase.base._meta", + "skbase.base._serialize", "skbase.base._tagmanager", "skbase.lookup", "skbase.lookup.tests", @@ -90,6 +91,7 @@ } ) SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = { + "skbase.base._serialize": ("load",), "skbase.lookup": ("all_objects", "get_package_metadata"), "skbase.lookup._lookup": ("all_objects", "get_package_metadata"), "skbase.testing.utils._conditional_fixtures": (