From 0aa0c619dc27b26fb4764c2491f53bbbf35e86d1 Mon Sep 17 00:00:00 2001 From: Andreas Pedersen Date: Mon, 23 Sep 2024 09:17:56 +0200 Subject: [PATCH] default collections --- .../collections/base_element_collection.py | 7 ++++++ .../sample/collections/material_collection.py | 2 +- .../sample/collections/sample.py | 22 +++++++++++++------ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/easyreflectometry/sample/collections/base_element_collection.py b/src/easyreflectometry/sample/collections/base_element_collection.py index 892d556c..cb7967e2 100644 --- a/src/easyreflectometry/sample/collections/base_element_collection.py +++ b/src/easyreflectometry/sample/collections/base_element_collection.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import List from typing import Optional @@ -52,3 +53,9 @@ def as_dict(self, skip: Optional[List[str]] = None) -> dict: for collection_element in self: this_dict['data'].append(collection_element.as_dict(skip=skip)) return this_dict + + def _make_defalut_collection(self, default_collection: List, interface): + elements = deepcopy(default_collection) + for element in elements: + element.interface = interface + return elements diff --git a/src/easyreflectometry/sample/collections/material_collection.py b/src/easyreflectometry/sample/collections/material_collection.py index d00f1fc3..64e164b7 100644 --- a/src/easyreflectometry/sample/collections/material_collection.py +++ b/src/easyreflectometry/sample/collections/material_collection.py @@ -23,7 +23,7 @@ def __init__( ): if not materials: # Empty tuple if no materials are provided if populate_if_none: - materials = DEFAULT_COLLECTION + materials = self._make_defalut_collection(DEFAULT_COLLECTION, interface) else: materials = () # Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict diff --git a/src/easyreflectometry/sample/collections/sample.py b/src/easyreflectometry/sample/collections/sample.py index d5505ce9..1ebb5bf6 100644 --- a/src/easyreflectometry/sample/collections/sample.py +++ b/src/easyreflectometry/sample/collections/sample.py @@ -2,6 +2,7 @@ __author__ = 'github.com/arm61' +from copy import deepcopy from typing import List from typing import Optional @@ -15,7 +16,8 @@ from ..assemblies.surfactant_layer import SurfactantLayer from ..elements.layers.layer import Layer -NR_DEFAULT_ASSEMBLIES = 2 +# NR_DEFAULT_ASSEMBLIES = 2 +DEFAULT_COLLECTION = [Multilayer(), Multilayer()] class Sample(BaseCollection): @@ -23,7 +25,7 @@ class Sample(BaseCollection): def __init__( self, - *list_assemblies: Optional[List[BaseAssembly]], + *assemblies: Optional[List[BaseAssembly]], name: str = 'EasySample', interface=None, populate_if_none: bool = True, @@ -35,19 +37,19 @@ def __init__( :param name: Name of the sample, defaults to 'EasySample'. :param interface: Calculator interface, defaults to `None`. """ - if not list_assemblies: + if not assemblies: if populate_if_none: - list_assemblies = [Multilayer(interface=interface) for _ in range(NR_DEFAULT_ASSEMBLIES)] + assemblies = self._make_defalut_collection(DEFAULT_COLLECTION, interface) else: - list_assemblies = [] + assemblies = [] # Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict # Else collisions might occur in global_object.map self.populate_if_none = False - for assembly in list_assemblies: + for assembly in assemblies: if not issubclass(type(assembly), BaseAssembly): raise ValueError('The elements must be an Assembly.') - super().__init__(name, *list_assemblies, **kwargs) + super().__init__(name, *assemblies, **kwargs) self.interface = interface def add_assembly(self, assembly: Optional[BaseAssembly] = None): @@ -164,3 +166,9 @@ def as_dict(self, skip: list = None) -> dict: this_dict['data'][i] = assembly.as_dict(skip=skip) this_dict['populate_if_none'] = self.populate_if_none return this_dict + + def _make_defalut_collection(self, default_collection: List, interface): + elements = deepcopy(default_collection) + for element in elements: + element.interface = interface + return elements