Skip to content

Commit

Permalink
load and save project with experimental data
Browse files Browse the repository at this point in the history
  • Loading branch information
andped10 committed Nov 8, 2024
1 parent 49368a2 commit 574e770
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 42 deletions.
2 changes: 2 additions & 0 deletions src/easyreflectometry/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .data_store import DataSet1D
from .data_store import ProjectData
from .measurement import load
from .measurement import load_as_dataset

__all__ = [
load,
load_as_dataset,
ProjectData,
DataSet1D,
]
13 changes: 13 additions & 0 deletions src/easyreflectometry/data/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@
from orsopy.fileio import Header
from orsopy.fileio import orso

from easyreflectometry.data import DataSet1D


def load_as_dataset(fname: Union[TextIO, str]) -> DataSet1D:
"""Load data from an ORSO .ort file as a DataSet1D."""
data_group = load(fname)
return DataSet1D(
x=data_group['coords']['Qz_0'].values,
y=data_group['data']['R_0'].values,
ye=data_group['data']['R_0'].variances,
xe=data_group['coords']['Qz_0'].variances,
)


def load(fname: Union[TextIO, str]) -> sc.DataGroup:
"""Load data from an ORSO .ort file.
Expand Down
2 changes: 1 addition & 1 deletion src/easyreflectometry/model/resolution_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ def smearing(self, q: Union[np.array, float]) -> np.array:
def as_dict(
self, skip: Optional[List[str]] = None
) -> dict[str, str]: # skip is kept for consistency of the as_dict signature
return {'smearing': 'LinearSpline', 'q_data_points': self.q_data_points, 'fwhm_values': self.fwhm_values}
return {'smearing': 'LinearSpline', 'q_data_points': list(self.q_data_points), 'fwhm_values': list(self.fwhm_values)}
72 changes: 34 additions & 38 deletions src/easyreflectometry/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from easyreflectometry.calculators import CalculatorFactory
from easyreflectometry.data import DataSet1D
from easyreflectometry.data import load
from easyreflectometry.data import load_as_dataset
from easyreflectometry.fitting import MultiFitter
from easyreflectometry.model import LinearSpline
from easyreflectometry.model import Model
Expand Down Expand Up @@ -198,24 +198,29 @@ def minimizer(self, minimizer: AvailableMinimizers) -> None:
self._fitter.easy_science_multi_fitter.switch_minimizer(minimizer)

@property
def experiments(self) -> List[DataSet1D]:
def experiments(self) -> Dict[int, DataSet1D]:
return self._experiments

@experiments.setter
def experiments(self, experiments: List[DataSet1D]) -> None:
def experiments(self, experiments: Dict[int, DataSet1D]) -> None:
self._experiments = experiments

@property
def path_json(self):
return self.path / 'project.json'

def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Optional[int] = 0) -> None:
self._experiments[index] = load(str(path))
self._experiments[index] = load_as_dataset(str(path))
self._experiments[index].name = f'Experiment for Model {index}'
self._experiments[index].model = self.models[index]

self._with_experiments = True

# Set the resolution function if variance data is present
if sum(self._experiments[index]['coords']['Qz_0'].variances) != 0:
if sum(self._experiments[index].ye) != 0:
resolution_function = LinearSpline(
q_data_points=self._experiments[index]['coords']['Qz_0'].values,
fwhm_values=np.sqrt(self._experiments[index]['coords']['Qz_0'].variances),
q_data_points=self._experiments[index].y,
fwhm_values=np.sqrt(self._experiments[index].ye),
)
self._models[index].resolution_function = resolution_function

Expand Down Expand Up @@ -249,14 +254,7 @@ def model_data_for_model_at_index(self, index: int = 0, q_range: Optional[np.arr

def experimental_data_for_model_at_index(self, index: int = 0) -> DataSet1D:
if index in self._experiments.keys():
return DataSet1D(
name=f'Experiment for Model {index}',
x=self._experiments[index]['coords']['Qz_0'].values,
y=self._experiments[index]['data']['R_0'].values,
ye=self._experiments[index]['data']['R_0'].variances,
xe=self._experiments[index]['coords']['Qz_0'].variances,
model=self.models[index],
)
return self._experiments[index]
else:
raise IndexError(f'No experiment data for model at index {index}')

Expand Down Expand Up @@ -364,16 +362,16 @@ def _as_dict_add_materials_not_in_model_dict(self, project_dict: dict):
project_dict['materials_not_in_model'] = MaterialCollection(materials_not_in_model).as_dict(skip=['interface'])

def _as_dict_add_experiments(self, project_dict: dict):
project_dict['experiments'] = []
project_dict['experiments_models'] = []
project_dict['experiments_names'] = []
for experiment in self._experiments:
if self._experiments[0].xe is not None:
project_dict['experiments'].append([experiment.x, experiment.y, experiment.ye, experiment.xe])
else:
project_dict['experiments'].append([experiment.x, experiment.y, experiment.ye])
project_dict['experiments_models'].append(experiment.model.name)
project_dict['experiments_names'].append(experiment.name)
project_dict['experiments'] = {}
project_dict['experiments_models'] = {}
project_dict['experiments_names'] = {}

for key, experiment in self._experiments.items():
project_dict['experiments'][key] = [list(experiment.x), list(experiment.y), list(experiment.ye)]
if experiment.xe is not None:
project_dict['experiments'][key].append(list(experiment.xe))
project_dict['experiments_models'][key] = experiment.model.name
project_dict['experiments_names'][key] = experiment.name

def from_dict(self, project_dict: dict):
keys = list(project_dict.keys())
Expand All @@ -395,20 +393,18 @@ def from_dict(self, project_dict: dict):
else:
self._experiments = None

def _from_dict_extract_experiments(self, project_dict: dict):
self._experiments: List[DataSet1D] = []

for i in range(len(project_dict['experiments'])):
self._experiments.append(
DataSet1D(
name=project_dict['experiments_names'][i],
x=project_dict['experiments'][i][0],
y=project_dict['experiments'][i][1],
ye=project_dict['experiments'][i][2],
xe=project_dict['experiments'][i][3],
model=self._models[project_dict['experiments_models'][i]],
)
def _from_dict_extract_experiments(self, project_dict: dict) -> Dict[int, DataSet1D]:
experiments = {}
for key in project_dict['experiments'].keys():
experiments[key] = DataSet1D(
name=project_dict['experiments_names'][key],
x=project_dict['experiments'][key][0],
y=project_dict['experiments'][key][1],
ye=project_dict['experiments'][key][2],
xe=project_dict['experiments'][key][3],
model=self._models[project_dict['experiments_models'][key]],
)
return experiments

def _get_materials_in_models(self) -> MaterialCollection:
materials_in_model = MaterialCollection(populate_if_none=False)
Expand Down
13 changes: 10 additions & 3 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ def test_dict_round_trip(self):
project.add_material(material)
minimizer = AvailableMinimizers.LMFit
project.minimizer = minimizer
fpath = os.path.join(PATH_STATIC, 'example.ort')
project.load_experiment_for_model_at_index(fpath)
project_dict = project.as_dict(include_materials_not_in_model=True)
project_materials_dict = project._materials.as_dict()

Expand All @@ -411,9 +413,11 @@ def test_save_as_json(self, tmp_path):
project.default_model()
project._info['name'] = 'Test Project'

fpath = os.path.join(PATH_STATIC, 'example.ort')
project.load_experiment_for_model_at_index(fpath)

# Then
project.save_as_json()
project.path_json

# Expect
assert project.path_json.exists()
Expand Down Expand Up @@ -508,15 +512,18 @@ def test_create(self, tmp_path):
def test_load_experiment(self):
# When
project = Project()
project.models = ModelCollection(Model(), Model(), Model(), Model(), Model(), Model())
model_5 = Model()
project.models = ModelCollection(Model(), Model(), Model(), Model(), Model(), model_5)
fpath = os.path.join(PATH_STATIC, 'example.ort')

# Then
project.load_experiment_for_model_at_index(fpath, 5)

# Expect
assert list(project.experiments.keys()) == [5]
assert isinstance(project.experiments[5], DataGroup)
assert isinstance(project.experiments[5], DataSet1D)
assert project.experiments[5].name == 'Experiment for Model 5'
assert project.experiments[5].model == model_5
assert isinstance(project.models[5].resolution_function, LinearSpline)
assert isinstance(project.models[4].resolution_function, PercentageFhwm)

Expand Down

0 comments on commit 574e770

Please sign in to comment.