diff --git a/src/easyreflectometry/data/__init__.py b/src/easyreflectometry/data/__init__.py
index fdecf1dd..18294db3 100644
--- a/src/easyreflectometry/data/__init__.py
+++ b/src/easyreflectometry/data/__init__.py
@@ -1,9 +1,11 @@
 from .data_store import DataSet1D
 from .data_store import ProjectData
 from .measurement import load
+from .measurement import load_as_dataset

 __all__ = [
     load,
+    load_as_dataset,
     ProjectData,
     DataSet1D,
 ]
diff --git a/src/easyreflectometry/data/measurement.py b/src/easyreflectometry/data/measurement.py
index ff502af7..3664af4d 100644
--- a/src/easyreflectometry/data/measurement.py
+++ b/src/easyreflectometry/data/measurement.py
@@ -8,6 +8,19 @@
 from orsopy.fileio import Header
 from orsopy.fileio import orso

+from easyreflectometry.data import DataSet1D
+
+
+def load_as_dataset(fname: Union[TextIO, str]) -> DataSet1D:
+    """Load data from an ORSO .ort file as a DataSet1D."""
+    data_group = load(fname)
+    return DataSet1D(
+        x=data_group['coords']['Qz_0'].values,
+        y=data_group['data']['R_0'].values,
+        ye=data_group['data']['R_0'].variances,
+        xe=data_group['coords']['Qz_0'].variances,
+    )
+

 def load(fname: Union[TextIO, str]) -> sc.DataGroup:
     """Load data from an ORSO .ort file.
diff --git a/src/easyreflectometry/model/resolution_functions.py b/src/easyreflectometry/model/resolution_functions.py
index da276078..379ef207 100644
--- a/src/easyreflectometry/model/resolution_functions.py
+++ b/src/easyreflectometry/model/resolution_functions.py
@@ -59,4 +59,4 @@ def smearing(self, q: Union[np.array, float]) -> np.array:
     def as_dict(
         self, skip: Optional[List[str]] = None
     ) -> dict[str, str]:  # skip is kept for consistency of the as_dict signature
-        return {'smearing': 'LinearSpline', 'q_data_points': self.q_data_points, 'fwhm_values': self.fwhm_values}
+        return {'smearing': 'LinearSpline', 'q_data_points': list(self.q_data_points), 'fwhm_values': list(self.fwhm_values)}
diff --git a/src/easyreflectometry/project.py b/src/easyreflectometry/project.py
index 00aef2b1..65a320e8 100644
--- a/src/easyreflectometry/project.py
+++ b/src/easyreflectometry/project.py
@@ -16,7 +16,7 @@

 from easyreflectometry.calculators import CalculatorFactory
 from easyreflectometry.data import DataSet1D
-from easyreflectometry.data import load
+from easyreflectometry.data import load_as_dataset
 from easyreflectometry.fitting import MultiFitter
 from easyreflectometry.model import LinearSpline
 from easyreflectometry.model import Model
@@ -198,11 +198,11 @@ def minimizer(self, minimizer: AvailableMinimizers) -> None:
         self._fitter.easy_science_multi_fitter.switch_minimizer(minimizer)

     @property
-    def experiments(self) -> List[DataSet1D]:
+    def experiments(self) -> Dict[int, DataSet1D]:
         return self._experiments

     @experiments.setter
-    def experiments(self, experiments: List[DataSet1D]) -> None:
+    def experiments(self, experiments: Dict[int, DataSet1D]) -> None:
         self._experiments = experiments

     @property
@@ -210,12 +210,17 @@ def path_json(self):
         return self.path / 'project.json'

     def load_experiment_for_model_at_index(self, path: Union[Path, str], index: Optional[int] = 0) -> None:
-        self._experiments[index] = load(str(path))
+        self._experiments[index] = load_as_dataset(str(path))
+        self._experiments[index].name = f'Experiment for Model {index}'
+        self._experiments[index].model = self.models[index]
+
+        self._with_experiments = True
+
         # Set the resolution function if variance data is present
-        if sum(self._experiments[index]['coords']['Qz_0'].variances) != 0:
+        if sum(self._experiments[index].ye) != 0:
             resolution_function = LinearSpline(
-                q_data_points=self._experiments[index]['coords']['Qz_0'].values,
-                fwhm_values=np.sqrt(self._experiments[index]['coords']['Qz_0'].variances),
+                q_data_points=self._experiments[index].y,
+                fwhm_values=np.sqrt(self._experiments[index].ye),
             )
             self._models[index].resolution_function = resolution_function

@@ -249,14 +254,7 @@ def model_data_for_model_at_index(self, index: int = 0, q_range: Optional[np.arr

     def experimental_data_for_model_at_index(self, index: int = 0) -> DataSet1D:
         if index in self._experiments.keys():
-            return DataSet1D(
-                name=f'Experiment for Model {index}',
-                x=self._experiments[index]['coords']['Qz_0'].values,
-                y=self._experiments[index]['data']['R_0'].values,
-                ye=self._experiments[index]['data']['R_0'].variances,
-                xe=self._experiments[index]['coords']['Qz_0'].variances,
-                model=self.models[index],
-            )
+            return self._experiments[index]
         else:
             raise IndexError(f'No experiment data for model at index {index}')

@@ -364,16 +362,16 @@ def _as_dict_add_materials_not_in_model_dict(self, project_dict: dict):
         project_dict['materials_not_in_model'] = MaterialCollection(materials_not_in_model).as_dict(skip=['interface'])

     def _as_dict_add_experiments(self, project_dict: dict):
-        project_dict['experiments'] = []
-        project_dict['experiments_models'] = []
-        project_dict['experiments_names'] = []
-        for experiment in self._experiments:
-            if self._experiments[0].xe is not None:
-                project_dict['experiments'].append([experiment.x, experiment.y, experiment.ye, experiment.xe])
-            else:
-                project_dict['experiments'].append([experiment.x, experiment.y, experiment.ye])
-            project_dict['experiments_models'].append(experiment.model.name)
-            project_dict['experiments_names'].append(experiment.name)
+        project_dict['experiments'] = {}
+        project_dict['experiments_models'] = {}
+        project_dict['experiments_names'] = {}
+
+        for key, experiment in self._experiments.items():
+            project_dict['experiments'][key] = [list(experiment.x), list(experiment.y), list(experiment.ye)]
+            if experiment.xe is not None:
+                project_dict['experiments'][key].append(list(experiment.xe))
+            project_dict['experiments_models'][key] = experiment.model.name
+            project_dict['experiments_names'][key] = experiment.name

     def from_dict(self, project_dict: dict):
         keys = list(project_dict.keys())
@@ -395,20 +393,18 @@ def from_dict(self, project_dict: dict):
         else:
             self._experiments = None

-    def _from_dict_extract_experiments(self, project_dict: dict):
-        self._experiments: List[DataSet1D] = []
-
-        for i in range(len(project_dict['experiments'])):
-            self._experiments.append(
-                DataSet1D(
-                    name=project_dict['experiments_names'][i],
-                    x=project_dict['experiments'][i][0],
-                    y=project_dict['experiments'][i][1],
-                    ye=project_dict['experiments'][i][2],
-                    xe=project_dict['experiments'][i][3],
-                    model=self._models[project_dict['experiments_models'][i]],
-                )
+    def _from_dict_extract_experiments(self, project_dict: dict) -> Dict[int, DataSet1D]:
+        experiments = {}
+        for key in project_dict['experiments'].keys():
+            experiments[key] = DataSet1D(
+                name=project_dict['experiments_names'][key],
+                x=project_dict['experiments'][key][0],
+                y=project_dict['experiments'][key][1],
+                ye=project_dict['experiments'][key][2],
+                xe=project_dict['experiments'][key][3],
+                model=self._models[project_dict['experiments_models'][key]],
             )
+        return experiments

     def _get_materials_in_models(self) -> MaterialCollection:
         materials_in_model = MaterialCollection(populate_if_none=False)
diff --git a/tests/test_project.py b/tests/test_project.py
index 4e6c7c57..8edd4de6 100644
--- a/tests/test_project.py
+++ b/tests/test_project.py
@@ -385,6 +385,8 @@ def test_dict_round_trip(self):
         project.add_material(material)
         minimizer = AvailableMinimizers.LMFit
         project.minimizer = minimizer
+        fpath = os.path.join(PATH_STATIC, 'example.ort')
+        project.load_experiment_for_model_at_index(fpath)

         project_dict = project.as_dict(include_materials_not_in_model=True)
         project_materials_dict = project._materials.as_dict()
@@ -411,9 +413,11 @@ def test_save_as_json(self, tmp_path):
         project.default_model()
         project._info['name'] = 'Test Project'

+        fpath = os.path.join(PATH_STATIC, 'example.ort')
+        project.load_experiment_for_model_at_index(fpath)
+
         # Then
         project.save_as_json()
-        project.path_json

         # Expect
         assert project.path_json.exists()
@@ -508,7 +512,8 @@ def test_create(self, tmp_path):
     def test_load_experiment(self):
         # When
         project = Project()
-        project.models = ModelCollection(Model(), Model(), Model(), Model(), Model(), Model())
+        model_5 = Model()
+        project.models = ModelCollection(Model(), Model(), Model(), Model(), Model(), model_5)
         fpath = os.path.join(PATH_STATIC, 'example.ort')

         # Then
@@ -516,7 +521,9 @@ def test_load_experiment(self):

         # Expect
         assert list(project.experiments.keys()) == [5]
-        assert isinstance(project.experiments[5], DataGroup)
+        assert isinstance(project.experiments[5], DataSet1D)
+        assert project.experiments[5].name == 'Experiment for Model 5'
+        assert project.experiments[5].model == model_5
         assert isinstance(project.models[5].resolution_function, LinearSpline)
         assert isinstance(project.models[4].resolution_function, PercentageFhwm)
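
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new load_as_dataset helper from src/easyreflectometry/data/measurement.py is meant to be called. The file name 'example.ort' is an assumption borrowed from the PATH_STATIC fixture used in tests/test_project.py; any ORSO .ort file with Qz_0/R_0 columns and R_0 uncertainties should behave the same way.

    # Minimal sketch; 'example.ort' is an assumed path, not shipped with the library.
    from easyreflectometry.data import load_as_dataset

    dataset = load_as_dataset('example.ort')  # DataSet1D instead of a scipp DataGroup
    print(dataset.x[:3])   # Qz points (coords['Qz_0'].values)
    print(dataset.y[:3])   # reflectivity (data['R_0'].values)
    print(dataset.ye[:3])  # reflectivity variances (data['R_0'].variances)
    print(dataset.xe[:3])  # Qz variances (coords['Qz_0'].variances)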