Skip to content

Commit

Permalink
ability to serialize model_collection
Browse files Browse the repository at this point in the history
  • Loading branch information
andped10 committed May 15, 2024
1 parent a734e8d commit 2e65a01
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/easyreflectometry/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def as_dict(self, skip: list = None) -> dict:
if skip is None:
skip = []
this_dict = super().as_dict(skip=skip)
this_dict['sample'] = self.sample.as_dict()
this_dict['sample'] = self.sample.as_dict(skip=skip)
this_dict['resolution_function'] = self.resolution_function.as_dict()
return this_dict

Expand Down
2 changes: 1 addition & 1 deletion src/easyreflectometry/experiment/model_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
interface=None,
**kwargs,
):
if models is None:
if models == ():
models = [Model(interface=interface) for _ in range(SIZE_DEFAULT_COLLECTION)]
super().__init__(name, interface, *models, **kwargs)
self.interface = interface
Expand Down
14 changes: 14 additions & 0 deletions src/easyreflectometry/sample/base_element_collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any
from typing import List
from typing import Optional

import yaml
from easyscience.Objects.Groups import BaseCollection
Expand Down Expand Up @@ -48,6 +50,18 @@ def _dict_repr(self) -> dict:
"""
return {self.name: [i._dict_repr for i in self]}

def as_dict(self, skip: Optional[List[str]] = None) -> dict:
"""
Create a dictionary representation of the collection.
:return: A dictionary representation of the collection
"""
this_dict = super().as_dict(skip=skip)
this_dict['data'] = []
for collection_element in self:
this_dict['data'].append(collection_element.as_dict(skip=skip))
return this_dict

@classmethod
def from_dict(cls, data: dict) -> Any:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/easyreflectometry/sample/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def as_dict(self, skip: list = None) -> dict:
skip = []
this_dict = super().as_dict(skip=skip)
for i, layer in enumerate(self.data):
this_dict['data'][i] = layer.as_dict()
this_dict['data'][i] = layer.as_dict(skip=skip)
return this_dict

@classmethod
Expand Down
72 changes: 72 additions & 0 deletions tests/experiment/test_model_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import unittest

from easyreflectometry.experiment.model import Model
from easyreflectometry.experiment.model_collection import ModelCollection


class TestModelCollection(unittest.TestCase):
def test_default(self):
# When Then
collection = ModelCollection()

# Expect
assert collection.name == 'EasyModels'
assert collection.interface is None
assert len(collection) == 2
assert collection[0].name == 'EasyModel'
assert collection[1].name == 'EasyModel'

def test_from_pars(self):
# When
model_1 = Model(name='Model1')
model_2 = Model(name='Model2')
model_3 = Model(name='Model3')

# Then
collection = ModelCollection(model_1, model_2, model_3)

# Expect
assert collection.name == 'EasyModels'
assert collection.interface is None
assert len(collection) == 3
assert collection[0].name == 'Model1'
assert collection[1].name == 'Model2'
assert collection[2].name == 'Model3'

def test_add_model(self):
# When
model_1 = Model(name='Model1')
model_2 = Model(name='Model2')

# Then
collection = ModelCollection(model_1)
collection.add_model(model_2)

# Expect
assert len(collection) == 2
assert collection[0].name == 'Model1'
assert collection[1].name == 'Model2'

def test_delete_model(self):
# When
model_1 = Model(name='Model1')
model_2 = Model(name='Model2')

# Then
collection = ModelCollection(model_1, model_2)
collection.remove_model(0)

# Expect
assert len(collection) == 1
assert collection[0].name == 'Model2'

def test_as_dict(self):
# When
model_1 = Model(name='Model1')
collection = ModelCollection(model_1)

# Then
dict_repr = collection.as_dict()

# Expect
assert dict_repr['data'][0]['resolution_function'] == {'smearing': 'PercentageFhwm', 'constant': 5.0}

0 comments on commit 2e65a01

Please sign in to comment.