Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjustments to make it possible to execute the EasyReflectometryApp #179

Merged
merged 7 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
requires-python = ">=3.9,<3.12"
dependencies = [
"easyscience>=1.0.1",
'easyscience @ git+https://github.com/EasyScience/EasyScience.git@adjustments-to-fit-app',
"scipp>=23.12.0",
"refnx>=0.1.15",
"refl1d>=0.8.14",
Expand Down
5 changes: 3 additions & 2 deletions src/easyreflectometry/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import copy
from numbers import Number
from typing import Optional
from typing import Union

import numpy as np
Expand Down Expand Up @@ -189,7 +190,7 @@ def __repr__(self) -> str:
"""String representation of the layer."""
return yaml_dump(self._dict_repr)

def as_dict(self, skip: list = None) -> dict:
def as_dict(self, skip: Optional[list[str]] = None) -> dict:
"""Produces a cleaned dict using a custom as_dict method to skip necessary things.
The resulting dict matches the parameters in __init__

Expand All @@ -200,7 +201,7 @@ def as_dict(self, skip: list = None) -> dict:
skip.extend(['sample', 'resolution_function', 'interface'])
this_dict = super().as_dict(skip=skip)
this_dict['sample'] = self.sample.as_dict(skip=skip)
this_dict['resolution_function'] = self.resolution_function.as_dict()
this_dict['resolution_function'] = self.resolution_function.as_dict(skip=skip)
if self.interface is None:
this_dict['interface'] = None
else:
Expand Down
6 changes: 3 additions & 3 deletions src/easyreflectometry/experiment/model_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__author__ = 'github.com/arm61'

from typing import List
from typing import Optional
from typing import Tuple

from easyreflectometry.sample.base_element_collection import SIZE_DEFAULT_COLLECTION
from easyreflectometry.sample.base_element_collection import BaseElementCollection
Expand All @@ -17,7 +17,7 @@ class ModelCollection(BaseElementCollection):

def __init__(
self,
*models: Optional[tuple[Model]],
*models: Tuple[Model],
name: str = 'EasyModels',
interface=None,
populate_if_none: bool = True,
Expand Down Expand Up @@ -52,7 +52,7 @@ def remove_model(self, idx: int):
del self[idx]

def as_dict(self, skip: List[str] | None = None) -> dict:
this_dict = super().as_dict(skip)
this_dict = super().as_dict(skip=skip)
this_dict['populate_if_none'] = self.populate_if_none
return this_dict

Expand Down
18 changes: 12 additions & 6 deletions src/easyreflectometry/experiment/resolution_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from __future__ import annotations

from abc import abstractmethod
from typing import List
from typing import Optional
from typing import Union

import numpy as np
Expand All @@ -17,10 +19,10 @@

class ResolutionFunction:
@abstractmethod
def smearing(q: Union[np.array, float]) -> np.array: ...
def smearing(self, q: Union[np.array, float]) -> np.array: ...

@abstractmethod
def as_dict() -> dict: ...
def as_dict(self, skip: Optional[List[str]] = None) -> dict: ...

@classmethod
def from_dict(cls, data: dict) -> ResolutionFunction:
Expand All @@ -31,7 +33,7 @@ def from_dict(cls, data: dict) -> ResolutionFunction:
raise ValueError('Unknown resolution function type')


class PercentageFhwm:
class PercentageFhwm(ResolutionFunction):
def __init__(self, constant: Union[None, float] = None):
if constant is None:
constant = DEFAULT_RESOLUTION_FWHM_PERCENTAGE
Expand All @@ -40,17 +42,21 @@ def __init__(self, constant: Union[None, float] = None):
def smearing(self, q: Union[np.array, float]) -> np.array:
return np.ones(np.array(q).size) * self.constant

def as_dict(self) -> dict:
def as_dict(
self, skip: Optional[List[str]] = None
) -> dict[str, str]: # skip is kept for consistency of the as_dict signature
return {'smearing': 'PercentageFhwm', 'constant': self.constant}


class LinearSpline:
class LinearSpline(ResolutionFunction):
def __init__(self, q_data_points: np.array, fwhm_values: np.array):
self.q_data_points = q_data_points
self.fwhm_values = fwhm_values

def smearing(self, q: Union[np.array, float]) -> np.array:
return np.interp(q, self.q_data_points, self.fwhm_values)

def as_dict(self) -> dict:
def as_dict(
self, skip: Optional[List[str]] = None
) -> dict[str, str]: # skip is kept for consistency of the as_dict signature
return {'smearing': 'LinearSpline', 'q_data_points': self.q_data_points, 'fwhm_values': self.fwhm_values}
2 changes: 1 addition & 1 deletion src/easyreflectometry/sample/assemblies/gradient_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _dict_repr(self) -> dict[str, str]:
'front_layer': self.front_layer._dict_repr,
}

def as_dict(self, skip: list = None) -> dict:
def as_dict(self, skip: Optional[list[str]] = None) -> dict:
"""Produces a cleaned dict using a custom as_dict method to skip necessary things.
The resulting dict matches the parameters in __init__

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _dict_repr(self) -> dict:
}
}

def as_dict(self, skip: list = None) -> dict:
def as_dict(self, skip: Optional[list[str]] = None) -> dict:
"""Produces a cleaned dict using a custom as_dict method to skip necessary things.
The resulting dict matches the parameters in __init__

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,16 +262,16 @@ def _dict_repr(self) -> dict[str, str]:
dict_repr['area_per_molecule'] = f'{self.area_per_molecule:.2f} ' f'{self._area_per_molecule.unit}'
return dict_repr

def as_dict(self, skip: list = None) -> dict[str, str]:
def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, str]:
"""Produces a cleaned dict using a custom as_dict method to skip necessary things.
The resulting dict matches the parameters in __init__

:param skip: List of keys to skip, defaults to `None`.
"""
this_dict = super().as_dict(skip=skip)
this_dict['solvent_fraction'] = self.material._fraction.as_dict()
this_dict['area_per_molecule'] = self._area_per_molecule.as_dict()
this_dict['solvent'] = self.solvent.as_dict()
this_dict['solvent_fraction'] = self.material._fraction.as_dict(skip=skip)
this_dict['area_per_molecule'] = self._area_per_molecule.as_dict(skip=skip)
this_dict['solvent'] = self.solvent.as_dict(skip=skip)
del this_dict['material']
del this_dict['_scattering_length_real']
del this_dict['_scattering_length_imag']
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__author__ = 'github.com/arm61'
from typing import Tuple
from typing import Union

from ...base_element_collection import SIZE_DEFAULT_COLLECTION
Expand All @@ -13,13 +14,21 @@ class MaterialCollection(BaseElementCollection):

def __init__(
self,
*materials: Union[list[Union[Material, MaterialMixture]], None],
*materials: Tuple[Union[Material, MaterialMixture]],
name: str = 'EasyMaterials',
interface=None,
populate_if_none: bool = True,
**kwargs,
):
if not materials:
materials = [Material(interface=interface) for _ in range(SIZE_DEFAULT_COLLECTION)]
if not materials: # Empty tuple if no materials are provided
if populate_if_none:
materials = [Material(interface=interface) for _ in range(SIZE_DEFAULT_COLLECTION)]
else:
materials = []
# Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict
# Else collisions might occur in global_object.map
self.populate_if_none = False

super().__init__(
name,
interface,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,14 @@ def _dict_repr(self) -> dict[str, str]:
}
}

def as_dict(self, skip: list = None) -> dict[str, str]:
def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, str]:
"""Produces a cleaned dict using a custom as_dict method to skip necessary things.
The resulting dict matches the parameters in __init__

:param skip: List of keys to skip, defaults to `None`.
"""
this_dict = super().as_dict(skip=skip)
this_dict['material_a'] = self._material_a.as_dict()
this_dict['material_b'] = self._material_b.as_dict()
this_dict['fraction'] = self._fraction.as_dict()
this_dict['material_a'] = self._material_a.as_dict(skip=skip)
this_dict['material_b'] = self._material_b.as_dict(skip=skip)
this_dict['fraction'] = self._fraction.as_dict(skip=skip)
return this_dict
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,16 @@ def _dict_repr(self) -> dict[str, str]:
}
}

def as_dict(self, skip: list = None) -> dict[str, str]:
def as_dict(self, skip: Optional[list[str]] = None) -> dict[str, str]:
"""Produces a cleaned dict using a custom as_dict method to skip necessary things.
The resulting dict matches the parameters in __init__

:param skip: List of keys to skip, defaults to `None`.
"""
this_dict = super().as_dict(skip=skip)
this_dict['material'] = self.material.as_dict()
this_dict['solvent'] = self.solvent.as_dict()
this_dict['solvent_fraction'] = self._fraction.as_dict()
this_dict['material'] = self.material.as_dict(skip=skip)
this_dict['solvent'] = self.solvent.as_dict(skip=skip)
this_dict['solvent_fraction'] = self._fraction.as_dict(skip=skip)
# Property and protected varible from material_mixture
del this_dict['material_a']
del this_dict['_material_a']
Expand Down
44 changes: 43 additions & 1 deletion tests/experiment/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
__version__ = '0.0.1'

import unittest
from copy import copy
from unittest.mock import MagicMock

import numpy as np
Expand Down Expand Up @@ -391,11 +392,37 @@ def test_repr_resolution_function(self):
)


def test_copy():
andped10 marked this conversation as resolved.
Show resolved Hide resolved
# When
resolution_function = LinearSpline([0, 10], [0, 10])
model = Model(interface=CalculatorFactory())
model.resolution_function = resolution_function
for additional_layer in [SurfactantLayer(), Multilayer(), RepeatingMultilayer()]:
model.add_item(additional_layer)

# Then
model_copy = copy(model)

# Expect
assert sorted(model.as_data_dict()) == sorted(model_copy.as_data_dict())
assert model._resolution_function.smearing(5.5) == model_copy._resolution_function.smearing(5.5)
assert model.interface().name == model_copy.interface().name
assert_almost_equal(
model.interface().fit_func([0.3], model.unique_name),
model_copy.interface().fit_func([0.3], model_copy.unique_name),
)
assert model.unique_name != model_copy.unique_name
assert model.name == model_copy.name
assert model.as_data_dict(skip=['interface', 'unique_name', 'resolution_function']) == model_copy.as_data_dict(
skip=['interface', 'unique_name', 'resolution_function']
)


@pytest.mark.parametrize(
'interface',
[None, CalculatorFactory()],
)
def test_dict_round_trip(interface): # , additional_layer):
def test_dict_round_trip(interface):
# When
resolution_function = LinearSpline([0, 10], [0, 10])
model = Model(interface=interface)
Expand All @@ -419,3 +446,18 @@ def test_dict_round_trip(interface): # , additional_layer):
model.interface().fit_func([0.3], model.unique_name),
model_from_dict.interface().fit_func([0.3], model_from_dict.unique_name),
)


def test_dict_skip_unique_name():
andped10 marked this conversation as resolved.
Show resolved Hide resolved
# When
resolution_function = LinearSpline([0, 10], [0, 10])
model = Model(interface=CalculatorFactory())
model.resolution_function = resolution_function
for additional_layer in [SurfactantLayer(), Multilayer(), RepeatingMultilayer()]:
model.add_item(additional_layer)

# Then
dict_no_unique_name = model.as_dict(skip=['unique_name'])

# Expect
assert 'unique_name' not in dict_no_unique_name
Loading