Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Introduce Serialization Support #154

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ ignore_path = ["docs/_build", "docs/source/api_reference/auto_generated"]

[tool.bandit]
exclude_dirs = ["*/tests/*", "*/testing/*"]
skips = ["B301", "B403"]

[tool.setuptools]
zip-safe = true
Expand Down
86 changes: 86 additions & 0 deletions skbase/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,92 @@ def _components(self, base_class=None):

return comp_dict

def save(self, path=None):
"""Save serialized self to bytes-like object or to (.zip) file.

Behaviour:
if `path` is None, returns an in-memory serialized self
if `path` is a file location, stores self at that location as a zip file

saved files are zip files with following contents:
_metadata - contains class of self, i.e., type(self)
_obj - serialized self. This class uses the default serialization (pickle).

Parameters
----------
path : None or file location (str or Path)
if None, self is saved to an in-memory object
if file location, self is saved to that file location. If:
path="estimator" then a zip file `estimator.zip` will be made at cwd.
path="/home/stored/estimator" then a zip file `estimator.zip` will be
stored in `/home/stored/`.

Returns
-------
if `path` is None - in-memory serialized self
if `path` is file location - ZipFile with reference to the file
"""
import pickle
import shutil
from pathlib import Path
from zipfile import ZipFile

if path is None:
return (type(self), pickle.dumps(self))
if not isinstance(path, (str, Path)):
raise TypeError(
"`path` is expected to either be a string or a Path object "
f"but found of type:{type(path)}."
)

path = Path(path) if isinstance(path, str) else path
path.mkdir()

mfile_ = open(path / "_metadata", "wb")
ofile_ = open(path / "_obj", "wb")
pickle.dump(type(self), mfile_)
pickle.dump(self, ofile_)
mfile_.close()
ofile_.close()

shutil.make_archive(base_name=path, format="zip", root_dir=path)
shutil.rmtree(path)
return ZipFile(path.with_name(f"{path.stem}.zip"))

@classmethod
def load_from_serial(cls, serial):
"""Load object from serialized memory container.

Parameters
----------
serial : 1st element of output of `cls.save(None)`

Returns
-------
deserialized self resulting in output `serial`, of `cls.save(None)`
"""
import pickle

return pickle.loads(serial)

@classmethod
def load_from_path(cls, serial):
"""Load object from file location.

Parameters
----------
serial : result of ZipFile(path).open("object)

Returns
-------
deserialized self resulting in output at `path`, of `cls.save(path)`
"""
import pickle
from zipfile import ZipFile

with ZipFile(serial, "r") as file:
return pickle.loads(file.open("_obj").read())


class TagAliaserMixin:
"""Mixin class for tag aliasing and deprecation of old tags.
Expand Down
62 changes: 62 additions & 0 deletions skbase/base/_serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
# copyright: skbase developers, BSD-3-Clause License (see LICENSE file)
"""Utilities for serializing and deserializing objects.

IMPORTANT CAVEAT FOR DEVELOPERS:
Do not add estimator specific functionality to the `load` utility.
All estimator specific functionality should be in
the class methods `load_from_serial` and `load_from_path`.
"""

__author__ = ["fkiraly", "achieveordie"]


def load(serial):
"""Load an object either from in-memory object or from a file location.

Parameters
----------
serial : serialized container (tuple), str (path), or Path object (reference)
if serial is a tuple (serialized container):
Contains two elements, first in-memory metadata and second
the related object.
if serial is a string (path reference):
The name of the file without the extension, for e.g: if the file
is `estimator.zip`, `serial='estimator'`. It can also represent a
path, for eg: if location is `home/stored/models/estimator.zip`
then `serial='home/stored/models/estimator'`.
if serial is a Path object (path reference):
`serial` then points to the `.zip` file into which the
object was stored using class method `.save()` of an estimator.

Returns
-------
Deserialized self resulting in output `serial`, of `cls.save`
"""
import pickle
from pathlib import Path
from zipfile import ZipFile

if isinstance(serial, tuple):
if len(serial) != 2:
raise ValueError(
"`serial` should be a tuple of size 2 "
f"found, a tuple of size: {len(serial)}"
)
cls, stored = serial
return cls.load_from_serial(stored)

elif isinstance(serial, (str, Path)):
path = Path(serial + ".zip") if isinstance(serial, str) else serial
if not path.exists():
raise FileNotFoundError(f"The given save location: {serial}\nwas not found")
with ZipFile(path) as file:
cls = pickle.loads(file.open("_metadata", "r").read())
return cls.load_from_path(path)
else:
raise TypeError(
"serial must either be a serialized in-memory sktime object, "
"a str, Path or ZipFile pointing to a file which is a serialized sktime "
"object, created by save of an sktime object; but found serial "
f"of type {serial}"
)
2 changes: 2 additions & 0 deletions skbase/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"skbase.base",
"skbase.base._base",
"skbase.base._meta",
"skbase.base._serialize",
"skbase.base._tagmanager",
"skbase.lookup",
"skbase.lookup.tests",
Expand Down Expand Up @@ -90,6 +91,7 @@
}
)
SKBASE_PUBLIC_FUNCTIONS_BY_MODULE = {
"skbase.base._serialize": ("load",),
"skbase.lookup": ("all_objects", "get_package_metadata"),
"skbase.lookup._lookup": ("all_objects", "get_package_metadata"),
"skbase.testing.utils._conditional_fixtures": (
Expand Down