Skip to content

Commit

Permalink
moving responsibilities from model to sample
Browse files Browse the repository at this point in the history
  • Loading branch information
andped10 committed Sep 18, 2024
1 parent 2625550 commit 12f31a4
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 80 deletions.
46 changes: 15 additions & 31 deletions src/easyreflectometry/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from easyreflectometry.parameter_utils import get_as_parameter
from easyreflectometry.parameter_utils import yaml_dump
from easyreflectometry.sample import BaseAssembly
from easyreflectometry.sample import Layer
from easyreflectometry.sample import LayerCollection
from easyreflectometry.sample import Sample

from .resolution_functions import PercentageFhwm
Expand Down Expand Up @@ -92,51 +90,37 @@ def __init__(
# Must be set after resolution function
self.interface = interface

def add_item(self, *assemblies: list[BaseAssembly]) -> None:
def add_assemblies(self, *assemblies: list[BaseAssembly]) -> None:
"""Add a layer or item to the model sample.
:param assemblies: Assemblies to add to model sample.
"""
for arg in assemblies:
if issubclass(arg.__class__, BaseAssembly):
self.sample.append(arg)
for assembly in assemblies:
if issubclass(assembly.__class__, BaseAssembly):
self.sample.add_assembly(assembly)
if self.interface is not None:
self.interface().add_item_to_model(arg.unique_name, self.unique_name)
self.interface().add_item_to_model(assembly.unique_name, self.unique_name)
else:
raise ValueError(f'Object {arg} is not a valid type, must be a child of BaseAssembly.')
raise ValueError(f'Object {assembly} is not a valid type, must be a child of BaseAssembly.')

def duplicate_item(self, idx: int) -> None:
def duplicate_assembly(self, index: int) -> None:
"""Duplicate a given item or layer in a sample.
:param idx: Index of the item or layer to duplicate
"""
to_duplicate = self.sample[idx]
duplicate_layers = []
for i in to_duplicate.layers:
duplicate_layers.append(
Layer(
material=i.material,
thickness=i.thickness.value,
roughness=i.roughness.value,
name=i.name + ' duplicate',
interface=i.interface,
)
)
duplicate = to_duplicate.__class__(
LayerCollection(*duplicate_layers, name=to_duplicate.layers.name + ' duplicate'),
name=to_duplicate.name + ' duplicate',
)
self.add_item(duplicate)
self.sample.duplicate_assembly(index)
if self.interface is not None:
self.interface().add_item_to_model(self.sample[-1].unique_name, self.unique_name)

def remove_item(self, idx: int) -> None:
"""Remove an item from the model.
def remove_assembly(self, index: int) -> None:
"""Remove an assembly from the model.
:param idx: Index of the item to remove.
"""
item_unique_name = self.sample[idx].unique_name
del self.sample[idx]
assembly_unique_name = self.sample[index].unique_name
self.sample.remove_assembly(index)
if self.interface is not None:
self.interface().remove_item_from_model(item_unique_name, self.unique_name)
self.interface().remove_item_from_model(assembly_unique_name, self.unique_name)

@property
def resolution_function(self) -> ResolutionFunction:
Expand Down
3 changes: 0 additions & 3 deletions src/easyreflectometry/sample/assemblies/multilayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,4 @@ def from_dict(cls, data: dict) -> Multilayer:
:return: Multilayer
"""
multilayer = super().from_dict(data)
# Remove the default materials
for i in range(SIZE_DEFAULT_COLLECTION):
del multilayer.layers[0]
return multilayer
65 changes: 56 additions & 9 deletions src/easyreflectometry/sample/collections/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

__author__ = 'github.com/arm61'

from typing import List
from typing import Union

from easyscience.Objects.Groups import BaseCollection
Expand All @@ -11,15 +10,15 @@

from ..assemblies.base_assembly import BaseAssembly
from ..assemblies.multilayer import Multilayer
from ..assemblies.repeating_multilayer import RepeatingMultilayer
from ..assemblies.surfactant_layer import SurfactantLayer
from ..elements.layers.layer import Layer

NR_DEFAULT_ASSEMBLIES = 2


class Sample(BaseCollection):
"""Collection of assemblies that represent the sample for which experimental measurements exist."""

assemblies: List[BaseAssembly]
"""A sample is a collection of assemblies that represent the structure for which experimental measurements exist."""

def __init__(
self,
Expand Down Expand Up @@ -55,24 +54,72 @@ def __init__(
super().__init__(name, *assemblies, **kwargs)
self.interface = interface

def remove_assmbly(self, index: int):
"""Remove the assembly at given index from the sample.
def add_assembly(self, assembly: BaseAssembly):
"""Add an assembly to the sample.
:param assembly: Assembly to add.
"""
self._enable_changes_to_outermost_layers()
self.append(assembly)
self._disable_changes_to_outermost_layers()

def duplicate_assembly(self, index: int):
"""Add an assembly to the sample.
:param assembly: Assembly to add.
"""
self._enable_changes_to_outermost_layers()
to_be_duplicated = self[index]
if isinstance(to_be_duplicated, RepeatingMultilayer):
duplicate = RepeatingMultilayer.from_dict(to_be_duplicated.as_dict(skip=['unique_name']))
elif isinstance(to_be_duplicated, SurfactantLayer):
duplicate = SurfactantLayer.from_dict(to_be_duplicated.as_dict(skip=['unique_name']))
elif isinstance(to_be_duplicated, Multilayer):
duplicate = Multilayer.from_dict(to_be_duplicated.as_dict(skip=['unique_name']))
duplicate.name = duplicate.name + ' duplicate'
self.append(duplicate)
self._disable_changes_to_outermost_layers()

def move_assembly_up(self, index: int):
"""Move the assembly at the given index up in the sample.
:param index: Index of the assembly to move up.
"""
if index == 0:
return
self._enable_changes_to_outermost_layers()
self.insert(index - 1, self.pop(index))
self._disable_changes_to_outermost_layers()

def move_assembly_down(self, index: int):
"""Move the assembly at the given index down in the sample.
:param index: Index of the assembly to move down.
"""
if index == len(self) - 1:
return
self._enable_changes_to_outermost_layers()
self.insert(index + 1, self.pop(index))
self._disable_changes_to_outermost_layers()

def remove_assembly(self, index: int):
"""Remove the assembly at the given index from the sample.
:param index: Index of the assembly to remove.
"""
self._enable_changes_to_outermost_layers()
self.assemblies.remove(index)
self.pop(index)
self._disable_changes_to_outermost_layers()

@property
def superphase(self) -> Layer:
"""The superphase of the sample."""
return self.assemblies[0].front_layer
return self[0].front_layer

@property
def subphase(self) -> Layer:
"""The superphase of the sample."""
return self.assemblies[1].back_layer
return self[-1].back_layer

def _enable_changes_to_outermost_layers(self):
"""Allowed to change the outermost layers of the sample.
Expand Down
Loading

0 comments on commit 12f31a4

Please sign in to comment.