diff --git a/src/easyreflectometry/sample/collections/material_collection.py b/src/easyreflectometry/sample/collections/material_collection.py index 875a5175..6295d8ef 100644 --- a/src/easyreflectometry/sample/collections/material_collection.py +++ b/src/easyreflectometry/sample/collections/material_collection.py @@ -3,9 +3,14 @@ from typing import Tuple from ..elements.materials.material import Material -from .base_element_collection import SIZE_DEFAULT_COLLECTION from .base_element_collection import BaseElementCollection +DEFAULT_COLLECTION = ( + Material(sld=0.0, isld=0.0, name='Air'), + Material(sld=6.335, isld=0.0, name='D2O'), + Material(sld=2.074, isld=0.0, name='Si'), +) + class MaterialCollection(BaseElementCollection): def __init__( @@ -18,7 +23,7 @@ def __init__( ): if not materials: # Empty tuple if no materials are provided if populate_if_none: - materials = (Material(interface=interface) for _ in range(SIZE_DEFAULT_COLLECTION)) + materials = DEFAULT_COLLECTION # (Material(interface=interface) for _ in range(SIZE_DEFAULT_COLLECTION)) else: materials = () # Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict @@ -37,6 +42,8 @@ def add_material(self, material: Optional[Material] = None): :param material: Material to add. """ + if material is None: + material = Material(name='New EasyMaterial', interface=self.interface) self.append(material) def duplicate_material(self, index: int): diff --git a/tests/sample/collections/test_material_collection.py b/tests/sample/collections/test_material_collection.py index 5934a9d9..ca85e3ee 100644 --- a/tests/sample/collections/test_material_collection.py +++ b/tests/sample/collections/test_material_collection.py @@ -13,9 +13,10 @@ def test_default(self): p = MaterialCollection() assert p.name == 'EasyMaterials' assert p.interface is None - assert len(p) == 2 - assert p[0].name == 'EasyMaterial' - assert p[1].name == 'EasyMaterial' + assert len(p) == 3 + assert p[0].name == 'Air' + assert p[1].name == 'D2O' + assert p[2].name == 'Si' def test_from_pars(self): m = Material(6.908, -0.278, 'Boron') @@ -37,8 +38,9 @@ def test_dict_repr(self): p = MaterialCollection() assert p._dict_repr == { 'EasyMaterials': [ - {'EasyMaterial': {'isld': '0.000e-6 1/Å^2', 'sld': '4.186e-6 1/Å^2'}}, - {'EasyMaterial': {'isld': '0.000e-6 1/Å^2', 'sld': '4.186e-6 1/Å^2'}}, + {'Air': {'isld': '0.000e-6 1/Å^2', 'sld': '0.000e-6 1/Å^2'}}, + {'D2O': {'isld': '0.000e-6 1/Å^2', 'sld': '6.335e-6 1/Å^2'}}, + {'Si': {'isld': '0.000e-6 1/Å^2', 'sld': '2.074e-6 1/Å^2'}}, ] } @@ -47,7 +49,7 @@ def test_repr(self): p.__repr__() assert ( p.__repr__() - == 'EasyMaterials:\n- EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n- EasyMaterial:\n sld: 4.186e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n' # noqa: E501 + == 'EasyMaterials:\n- Air:\n sld: 0.000e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n- D2O:\n sld: 6.335e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n- Si:\n sld: 2.074e-6 1/Å^2\n isld: 0.000e-6 1/Å^2\n' # noqa: E501 ) def test_dict_round_trip(self): @@ -56,7 +58,7 @@ def test_dict_round_trip(self): k = Material(0.487, 0.000, 'Potassium') p = MaterialCollection() p.insert(0, m) - p.append(k) + p.add_material(k) p_dict = p.as_dict() global_object.map._clear() @@ -75,7 +77,7 @@ def test_add_material(self): p.add_material(m) # Then - assert p[2] == m + assert p[3] == m def test_duplicate_material(self): # Given @@ -84,10 +86,10 @@ def test_duplicate_material(self): p.add_material(m) # When - p.duplicate_material(2) + p.duplicate_material(3) # Then - assert p[3].name == 'Boron duplicate' + assert p[4].name == 'Boron duplicate' def test_move_material_up(self): # Given @@ -96,11 +98,11 @@ def test_move_material_up(self): p.add_material(k) # When - p.move_material_up(2) + p.move_material_up(3) # Then - assert p[1].name == 'Bottom' - assert p[2].name == 'EasyMaterial' + assert p[2].name == 'Bottom' + assert p[3].name == 'Si' def test_move_material_up_to_top_and_further(self): # Given @@ -109,13 +111,14 @@ def test_move_material_up_to_top_and_further(self): p.add_material(m) # When + p.move_material_up(3) p.move_material_up(2) p.move_material_up(1) p.move_material_up(0) # Then assert p[0].name == 'Bottom' - assert p[2].name == 'EasyMaterial' + assert p[3].name == 'Si' def test_move_material_down(self): # Given @@ -124,11 +127,11 @@ def test_move_material_down(self): p.add_material(m) # When - p.move_material_down(1) + p.move_material_down(2) # Then - assert p[1].name == 'Bottom' - assert p[2].name == 'EasyMaterial' + assert p[2].name == 'Bottom' + assert p[3].name == 'Si' def test_move_material_down_to_bottom_and_further(self): # Given @@ -139,13 +142,13 @@ def test_move_material_down_to_bottom_and_further(self): p.add_material(m) # When - p.move_material_down(2) p.move_material_down(3) + p.move_material_down(4) # Then - assert p[0].name == 'EasyMaterial' - assert p[2].name == 'Bottom' - assert p[3].name == 'Middle' + assert p[0].name == 'Air' + assert p[3].name == 'Bottom' + assert p[4].name == 'Middle' def test_remove_material(self): # Given @@ -157,6 +160,6 @@ def test_remove_material(self): p.remove_material(1) # Then - assert len(p) == 2 - assert p[0].name == 'EasyMaterial' - assert p[1].name == 'Bottom' + assert len(p) == 3 + assert p[0].name == 'Air' + assert p[2].name == 'Bottom'