Skip to content

Commit

Permalink
experiments and q values
Browse files Browse the repository at this point in the history
  • Loading branch information
andped10 committed Oct 23, 2024
1 parent d6a661a commit 2cca2df
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 13 deletions.
66 changes: 53 additions & 13 deletions src/easyreflectometry/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
import json
import os
from pathlib import Path
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

import numpy as np
from easyscience import global_object
from easyscience.fitting import AvailableMinimizers
from scipp import DataGroup

from easyreflectometry.calculators import CalculatorFactory
from easyreflectometry.data import DataSet1D
from easyreflectometry.data import load
from easyreflectometry.model import Model
from easyreflectometry.model import ModelCollection
from easyreflectometry.model import PercentageFhwm
Expand All @@ -27,15 +30,6 @@

DEFAULT_MINIZER = AvailableMinimizers.LMFit_leastsq

EXPERIMENTAL_DATA = [
DataSet1D(
name='Example Data 0',
x=np.linspace(Q_MIN, Q_MAX, Q_ELEMENTS),
y=3 * np.linspace(Q_MIN, Q_MAX, Q_ELEMENTS),
ye=0.1 * np.linspace(Q_MIN, Q_MAX, Q_ELEMENTS),
)
]


class Project:
def __init__(self):
Expand All @@ -45,9 +39,12 @@ def __init__(self):
self._materials = MaterialCollection(populate_if_none=False, unique_name='project_materials')
self._calculator = CalculatorFactory()
self._minimizer = DEFAULT_MINIZER
self._experiments: List[DataSet1D] = None
self._experiments: Dict[DataGroup] = {}
self._colors = None
self._report = None
self._q_min = None
self._q_max = None
self._q_elements = None

# Project flags
self._created = False
Expand All @@ -65,14 +62,44 @@ def reset(self):
self._path_project_parent = Path(os.path.expanduser('~'))
self._calculator = CalculatorFactory()
self._minimizer = DEFAULT_MINIZER
self._experiments = None
self._experiments = {}
self._colors = None
self._report = None

# Project flags
self._created = False
self._with_experiments = False

@property
def q_min(self):
if self._q_min is None:
return Q_MIN
return self._q_min

@q_min.setter
def q_min(self, value: float) -> None:
self._q_min = value

@property
def q_max(self):
if self._q_max is None:
return Q_MAX
return self._q_max

@q_max.setter
def q_max(self, value: float) -> None:
self._q_max = value

@property
def q_elements(self):
if self._q_elements is None:
return Q_ELEMENTS
return self._q_elements

@q_elements.setter
def q_elements(self, value: int) -> None:
self._q_elements = value

@property
def created(self) -> bool:
return self._created
Expand Down Expand Up @@ -113,6 +140,9 @@ def experiments(self, experiments: List[DataSet1D]) -> None:
def path_json(self):
return self.path / 'project.json'

def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Optional[int] = 0) -> None:
self._experiments[index] = load(str(path))

def sld_data_for_model_at_index(self, index: int = 0) -> DataSet1D:
self.models[index].interface = self._calculator
sld = self.models[index].interface().sld_profile(self._models[index].unique_name)
Expand All @@ -132,7 +162,7 @@ def sample_data_for_model_at_index(self, index: int = 0, q_range: Optional[np.ar

def model_data_for_model_at_index(self, index: int = 0, q_range: Optional[np.array] = None) -> DataSet1D:
if q_range is None:
q_range = np.linspace(Q_MIN, Q_MAX, Q_ELEMENTS)
q_range = np.linspace(self.q_min, self.q_max, self.q_elements)
self.models[index].interface = self._calculator
reflectivity = self.models[index].interface().reflectity_profile(q_range, self._models[index].unique_name)
return DataSet1D(
Expand All @@ -142,7 +172,17 @@ def model_data_for_model_at_index(self, index: int = 0, q_range: Optional[np.arr
)

def experimental_data_for_model_at_index(self, index: int = 0) -> DataSet1D:
return EXPERIMENTAL_DATA[index]
if index in self._experiments.keys():
return DataSet1D(
name=f'Experiment for Model {index}',
x=self._experiments[index]['coords']['Qz_0'].values,
y=self._experiments[index]['data']['R_0'].values,
ye=self._experiments[index]['data']['R_0'].variances,
xe=self._experiments[index]['coords']['Qz_0'].variances,
model=self.models[index],
)
else:
raise IndexError(f'No experiment data for model at index {index}')

def default_model(self):
self._replace_collection(MaterialCollection(), self._materials)
Expand Down
59 changes: 59 additions & 0 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
from easyscience import global_object
from easyscience.fitting import AvailableMinimizers
from numpy.testing import assert_allclose
from scipp import DataGroup

import easyreflectometry
from easyreflectometry.data import DataSet1D
from easyreflectometry.model import Model
from easyreflectometry.model import ModelCollection
from easyreflectometry.project import Project
from easyreflectometry.sample import Material
from easyreflectometry.sample import MaterialCollection

PATH_STATIC = os.path.join(os.path.dirname(easyreflectometry.__file__), '..', '..', 'tests', '_static')


class TestProject:
def test_constructor(self):
Expand Down Expand Up @@ -465,3 +470,57 @@ def test_create(self, tmp_path):
'experiments': 'None',
'modified': datetime.datetime.now().strftime('%d.%m.%Y %H:%M'),
}

def test_load_experiment(self):
# When
project = Project()
fpath = os.path.join(PATH_STATIC, 'example.ort')

# Then
project.load_experiment_for_model_at_index(fpath, 5)

# Expect
assert list(project.experiments.keys()) == [5]
assert isinstance(project.experiments[5], DataGroup)

def test_experimental_data_at_index(self):
# When
project = Project()
fpath = os.path.join(PATH_STATIC, 'example.ort')
project.load_experiment_for_model_at_index(fpath)
project.models = ModelCollection(Model())

# Then
data = project.experimental_data_for_model_at_index()

# Expect
assert data.name == 'Experiment for Model 0'
assert data.is_experiment
assert isinstance(data, DataSet1D)
assert len(data.x) == 408
assert len(data.xe) == 408
assert len(data.y) == 408
assert len(data.ye) == 408

def test_q(self):
# When
project = Project()

# Then
q = project.q_min, project.q_max, project.q_elements

# Expect
assert q == (0.001, 0.3, 500)

def test_set_q(self):
# When
project = Project()

# Then
project.q_min = 1
project.q_max = 2
project.q_elements = 3

# Expect
q = project.q_min, project.q_max, project.q_elements
assert q == (1, 2, 3)

0 comments on commit 2cca2df

Please sign in to comment.