Skip to content

Commit

Permalink
default collections
Browse files Browse the repository at this point in the history
  • Loading branch information
andped10 committed Sep 23, 2024
1 parent b848b86 commit 0aa0c61
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import List
from typing import Optional

Expand Down Expand Up @@ -52,3 +53,9 @@ def as_dict(self, skip: Optional[List[str]] = None) -> dict:
for collection_element in self:
this_dict['data'].append(collection_element.as_dict(skip=skip))
return this_dict

def _make_defalut_collection(self, default_collection: List, interface):
elements = deepcopy(default_collection)
for element in elements:
element.interface = interface
return elements
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
):
if not materials: # Empty tuple if no materials are provided
if populate_if_none:
materials = DEFAULT_COLLECTION
materials = self._make_defalut_collection(DEFAULT_COLLECTION, interface)
else:
materials = ()
# Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict
Expand Down
22 changes: 15 additions & 7 deletions src/easyreflectometry/sample/collections/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__author__ = 'github.com/arm61'

from copy import deepcopy
from typing import List
from typing import Optional

Expand All @@ -15,15 +16,16 @@
from ..assemblies.surfactant_layer import SurfactantLayer
from ..elements.layers.layer import Layer

NR_DEFAULT_ASSEMBLIES = 2
# NR_DEFAULT_ASSEMBLIES = 2
DEFAULT_COLLECTION = [Multilayer(), Multilayer()]


class Sample(BaseCollection):
"""A sample is a collection of assemblies that represent the structure for which experimental measurements exist."""

def __init__(
self,
*list_assemblies: Optional[List[BaseAssembly]],
*assemblies: Optional[List[BaseAssembly]],
name: str = 'EasySample',
interface=None,
populate_if_none: bool = True,
Expand All @@ -35,19 +37,19 @@ def __init__(
:param name: Name of the sample, defaults to 'EasySample'.
:param interface: Calculator interface, defaults to `None`.
"""
if not list_assemblies:
if not assemblies:
if populate_if_none:
list_assemblies = [Multilayer(interface=interface) for _ in range(NR_DEFAULT_ASSEMBLIES)]
assemblies = self._make_defalut_collection(DEFAULT_COLLECTION, interface)
else:
list_assemblies = []
assemblies = []
# Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict
# Else collisions might occur in global_object.map
self.populate_if_none = False

for assembly in list_assemblies:
for assembly in assemblies:
if not issubclass(type(assembly), BaseAssembly):
raise ValueError('The elements must be an Assembly.')
super().__init__(name, *list_assemblies, **kwargs)
super().__init__(name, *assemblies, **kwargs)
self.interface = interface

def add_assembly(self, assembly: Optional[BaseAssembly] = None):
Expand Down Expand Up @@ -164,3 +166,9 @@ def as_dict(self, skip: list = None) -> dict:
this_dict['data'][i] = assembly.as_dict(skip=skip)
this_dict['populate_if_none'] = self.populate_if_none
return this_dict

def _make_defalut_collection(self, default_collection: List, interface):
elements = deepcopy(default_collection)
for element in elements:
element.interface = interface
return elements

0 comments on commit 0aa0c61

Please sign in to comment.