From 2e65a011c98d227c61c4dd5e81f82ee43d83e2d7 Mon Sep 17 00:00:00 2001 From: Andreas Pedersen Date: Wed, 15 May 2024 14:27:23 +0200 Subject: [PATCH] ability to serialize model_collection --- src/easyreflectometry/experiment/model.py | 2 +- .../experiment/model_collection.py | 2 +- .../sample/base_element_collection.py | 14 ++++ src/easyreflectometry/sample/sample.py | 2 +- tests/experiment/test_model_collection.py | 72 +++++++++++++++++++ 5 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 tests/experiment/test_model_collection.py diff --git a/src/easyreflectometry/experiment/model.py b/src/easyreflectometry/experiment/model.py index 4994dbab..e9f1dd10 100644 --- a/src/easyreflectometry/experiment/model.py +++ b/src/easyreflectometry/experiment/model.py @@ -201,7 +201,7 @@ def as_dict(self, skip: list = None) -> dict: if skip is None: skip = [] this_dict = super().as_dict(skip=skip) - this_dict['sample'] = self.sample.as_dict() + this_dict['sample'] = self.sample.as_dict(skip=skip) this_dict['resolution_function'] = self.resolution_function.as_dict() return this_dict diff --git a/src/easyreflectometry/experiment/model_collection.py b/src/easyreflectometry/experiment/model_collection.py index fb0f8b04..351446d9 100644 --- a/src/easyreflectometry/experiment/model_collection.py +++ b/src/easyreflectometry/experiment/model_collection.py @@ -19,7 +19,7 @@ def __init__( interface=None, **kwargs, ): - if models is None: + if models == (): models = [Model(interface=interface) for _ in range(SIZE_DEFAULT_COLLECTION)] super().__init__(name, interface, *models, **kwargs) self.interface = interface diff --git a/src/easyreflectometry/sample/base_element_collection.py b/src/easyreflectometry/sample/base_element_collection.py index d55456b1..d59cc9cc 100644 --- a/src/easyreflectometry/sample/base_element_collection.py +++ b/src/easyreflectometry/sample/base_element_collection.py @@ -1,4 +1,6 @@ from typing import Any +from typing import List +from typing import Optional import yaml from easyscience.Objects.Groups import BaseCollection @@ -48,6 +50,18 @@ def _dict_repr(self) -> dict: """ return {self.name: [i._dict_repr for i in self]} + def as_dict(self, skip: Optional[List[str]] = None) -> dict: + """ + Create a dictionary representation of the collection. + + :return: A dictionary representation of the collection + """ + this_dict = super().as_dict(skip=skip) + this_dict['data'] = [] + for collection_element in self: + this_dict['data'].append(collection_element.as_dict(skip=skip)) + return this_dict + @classmethod def from_dict(cls, data: dict) -> Any: """ diff --git a/src/easyreflectometry/sample/sample.py b/src/easyreflectometry/sample/sample.py index 7b695341..7dec0c94 100644 --- a/src/easyreflectometry/sample/sample.py +++ b/src/easyreflectometry/sample/sample.py @@ -69,7 +69,7 @@ def as_dict(self, skip: list = None) -> dict: skip = [] this_dict = super().as_dict(skip=skip) for i, layer in enumerate(self.data): - this_dict['data'][i] = layer.as_dict() + this_dict['data'][i] = layer.as_dict(skip=skip) return this_dict @classmethod diff --git a/tests/experiment/test_model_collection.py b/tests/experiment/test_model_collection.py new file mode 100644 index 00000000..4b89246a --- /dev/null +++ b/tests/experiment/test_model_collection.py @@ -0,0 +1,72 @@ +import unittest + +from easyreflectometry.experiment.model import Model +from easyreflectometry.experiment.model_collection import ModelCollection + + +class TestModelCollection(unittest.TestCase): + def test_default(self): + # When Then + collection = ModelCollection() + + # Expect + assert collection.name == 'EasyModels' + assert collection.interface is None + assert len(collection) == 2 + assert collection[0].name == 'EasyModel' + assert collection[1].name == 'EasyModel' + + def test_from_pars(self): + # When + model_1 = Model(name='Model1') + model_2 = Model(name='Model2') + model_3 = Model(name='Model3') + + # Then + collection = ModelCollection(model_1, model_2, model_3) + + # Expect + assert collection.name == 'EasyModels' + assert collection.interface is None + assert len(collection) == 3 + assert collection[0].name == 'Model1' + assert collection[1].name == 'Model2' + assert collection[2].name == 'Model3' + + def test_add_model(self): + # When + model_1 = Model(name='Model1') + model_2 = Model(name='Model2') + + # Then + collection = ModelCollection(model_1) + collection.add_model(model_2) + + # Expect + assert len(collection) == 2 + assert collection[0].name == 'Model1' + assert collection[1].name == 'Model2' + + def test_delete_model(self): + # When + model_1 = Model(name='Model1') + model_2 = Model(name='Model2') + + # Then + collection = ModelCollection(model_1, model_2) + collection.remove_model(0) + + # Expect + assert len(collection) == 1 + assert collection[0].name == 'Model2' + + def test_as_dict(self): + # When + model_1 = Model(name='Model1') + collection = ModelCollection(model_1) + + # Then + dict_repr = collection.as_dict() + + # Expect + assert dict_repr['data'][0]['resolution_function'] == {'smearing': 'PercentageFhwm', 'constant': 5.0}