Skip to content

Commit

Permalink
extracted two test to outmost level
Browse files Browse the repository at this point in the history
  • Loading branch information
andped10 committed Sep 16, 2024
1 parent 20fa247 commit 7649d6b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 42 deletions.
42 changes: 0 additions & 42 deletions tests/experiment/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
__version__ = '0.0.1'

import unittest
from copy import copy
from unittest.mock import MagicMock

import numpy as np
Expand Down Expand Up @@ -392,32 +391,6 @@ def test_repr_resolution_function(self):
)


def test_copy():
# When
resolution_function = LinearSpline([0, 10], [0, 10])
model = Model(interface=CalculatorFactory())
model.resolution_function = resolution_function
for additional_layer in [SurfactantLayer(), Multilayer(), RepeatingMultilayer()]:
model.add_item(additional_layer)

# Then
model_copy = copy(model)

# Expect
assert sorted(model.as_data_dict()) == sorted(model_copy.as_data_dict())
assert model._resolution_function.smearing(5.5) == model_copy._resolution_function.smearing(5.5)
assert model.interface().name == model_copy.interface().name
assert_almost_equal(
model.interface().fit_func([0.3], model.unique_name),
model_copy.interface().fit_func([0.3], model_copy.unique_name),
)
assert model.unique_name != model_copy.unique_name
assert model.name == model_copy.name
assert model.as_data_dict(skip=['interface', 'unique_name', 'resolution_function']) == model_copy.as_data_dict(
skip=['interface', 'unique_name', 'resolution_function']
)


@pytest.mark.parametrize(
'interface',
[None, CalculatorFactory()],
Expand Down Expand Up @@ -446,18 +419,3 @@ def test_dict_round_trip(interface):
model.interface().fit_func([0.3], model.unique_name),
model_from_dict.interface().fit_func([0.3], model_from_dict.unique_name),
)


def test_dict_skip_unique_name():
# When
resolution_function = LinearSpline([0, 10], [0, 10])
model = Model(interface=CalculatorFactory())
model.resolution_function = resolution_function
for additional_layer in [SurfactantLayer(), Multilayer(), RepeatingMultilayer()]:
model.add_item(additional_layer)

# Then
dict_no_unique_name = model.as_dict(skip=['unique_name'])

# Expect
assert 'unique_name' not in dict_no_unique_name
56 changes: 56 additions & 0 deletions tests/test_topmost_nesting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Tests exercising the methods of the topmost classes for nested structure.
To ensure that the parameters are relayed.
"""

from copy import copy

from numpy.testing import assert_almost_equal

from easyreflectometry.calculators import CalculatorFactory
from easyreflectometry.experiment import LinearSpline
from easyreflectometry.experiment import Model
from easyreflectometry.sample import Multilayer
from easyreflectometry.sample import RepeatingMultilayer
from easyreflectometry.sample import SurfactantLayer


def test_dict_skip_unique_name():
# When
resolution_function = LinearSpline([0, 10], [0, 10])
model = Model(interface=CalculatorFactory())
model.resolution_function = resolution_function
for additional_layer in [SurfactantLayer(), Multilayer(), RepeatingMultilayer()]:
model.add_item(additional_layer)

# Then
dict_no_unique_name = model.as_dict(skip=['unique_name'])

# Expect
assert 'unique_name' not in dict_no_unique_name


def test_copy():
# When
resolution_function = LinearSpline([0, 10], [0, 10])
model = Model(interface=CalculatorFactory())
model.resolution_function = resolution_function
for additional_layer in [SurfactantLayer(), Multilayer(), RepeatingMultilayer()]:
model.add_item(additional_layer)

# Then
model_copy = copy(model)

# Expect
assert sorted(model.as_data_dict()) == sorted(model_copy.as_data_dict())
assert model._resolution_function.smearing(5.5) == model_copy._resolution_function.smearing(5.5)
assert model.interface().name == model_copy.interface().name
assert_almost_equal(
model.interface().fit_func([0.3], model.unique_name),
model_copy.interface().fit_func([0.3], model_copy.unique_name),
)
assert model.unique_name != model_copy.unique_name
assert model.name == model_copy.name
assert model.as_data_dict(skip=['interface', 'unique_name', 'resolution_function']) == model_copy.as_data_dict(
skip=['interface', 'unique_name', 'resolution_function']
)

0 comments on commit 7649d6b

Please sign in to comment.