Skip to content

Commit

Permalink
pr response
Browse files Browse the repository at this point in the history
  • Loading branch information
andped10 committed Oct 2, 2024
1 parent b3c4816 commit b86df21
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 25 deletions.
16 changes: 8 additions & 8 deletions src/easyreflectometry/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class Project:
def __init__(self):
self._info = self._default_info()
self._root_path = Path(os.path.expanduser('~'))
self._path_project_parent = Path(os.path.expanduser('~'))
self._models = ModelCollection(populate_if_none=False, unique_name='project_models')
self._materials = MaterialCollection(populate_if_none=False, unique_name='project_materials')
self._calculator = None
Expand All @@ -44,7 +44,7 @@ def reset(self):
self._materials = MaterialCollection(populate_if_none=False, unique_name='project_materials')

self._info = self._default_info()
self._root_path = Path(os.path.expanduser('~'))
self._path_project_parent = Path(os.path.expanduser('~'))
self._calculator = None
self._minimizer = None
self._experiments = None
Expand All @@ -61,10 +61,10 @@ def created(self) -> bool:

@property
def path(self):
return self._root_path / self._info['name']
return self._path_project_parent / self._info['name']

def set_root_path(self, path: Union[Path, str]):
self._root_path = Path(path)
def set_path_project_parent(self, path: Union[Path, str]):
self._path_project_parent = Path(path)

@property
def models(self) -> ModelCollection:
Expand Down Expand Up @@ -145,13 +145,13 @@ def create(self):
print(f'ERROR: Directory {self.path} already exists')

def save_as_json(self, overwrite=False):
if self.path_json.exists() and not overwrite:
if self.path_json.exists() and overwrite:
print(f'File already exists {self.path_json}. Overwriting...')
self.path_json.unlink()
try:
project_json = json.dumps(self.as_dict(include_materials_not_in_model=True), indent=4)
self.path_json.parent.mkdir(exist_ok=True, parents=True)
with open(self.path_json, 'w') as file:
with open(self.path_json, mode='x') as file:
file.write(project_json)
except Exception as exception:
print(exception)
Expand All @@ -165,7 +165,7 @@ def load_from_json(self, path: Optional[Union[Path, str]] = None):
project_dict = json.load(file)
self.reset()
self.from_dict(project_dict)
self._root_path = path.parents[1]
self._path_project_parent = path.parents[1]
self._created = True
else:
print(f'ERROR: File {path} does not exist')
Expand Down
74 changes: 57 additions & 17 deletions tests/test_project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import os
import time
from pathlib import Path

from easyscience import global_object
Expand All @@ -25,7 +26,7 @@ def test_constructor(self):
'experiments': 'None',
'modified': datetime.datetime.now().strftime('%d.%m.%Y %H:%M'),
}
assert project._root_path == Path(os.path.expanduser('~'))
assert project._path_project_parent == Path(os.path.expanduser('~'))
assert len(project._materials) == 0
assert len(project._models) == 0
assert project._calculator is None
Expand All @@ -47,7 +48,7 @@ def test_reset(self):
project._report = 'report'
project._created = True
project._with_experiments = True
project._root_path = 'project_path'
project._path_project_parent = 'project_path'

# Then
project.reset()
Expand All @@ -65,7 +66,7 @@ def test_reset(self):
assert project._materials.unique_name == 'project_materials'
assert len(project._materials) == 0

assert project._root_path == Path(os.path.expanduser('~'))
assert project._path_project_parent == Path(os.path.expanduser('~'))
assert project._calculator is None
assert project._minimizer is None
assert project._experiments is None
Expand Down Expand Up @@ -132,7 +133,7 @@ def test_experiments(self):
def test_path_json(self, tmp_path):
# When
project = Project()
project.set_root_path(tmp_path)
project.set_path_project_parent(tmp_path)

# Then Expect
assert project.path_json == Path(tmp_path) / 'Example Project' / 'project.json'
Expand Down Expand Up @@ -313,50 +314,89 @@ def test_dict_round_trip(self):
assert project_dict[key] == new_project_dict[key]
assert project_materials_dict == new_project_materials_dict

def test_save_project(self, tmp_path):
def test_save_as_json(self, tmp_path):
# When
global_object.map._clear()
project = Project()
project.set_root_path(tmp_path)
project.set_path_project_parent(tmp_path)
project._models.append(Model())
project._info['name'] = 'Test Project'

# Then
project.save_as_json()
project.path_json

# Expect
assert project.path_json.exists()

def test_save_as_json_overwrite(self, tmp_path):
# When
global_object.map._clear()
project = Project()
project.set_path_project_parent(tmp_path)
project._models.append(Model())
project.save_as_json()
file_info = project.path_json.stat()

# Then
project._info['short_description'] = 'short_description'
project.save_as_json(overwrite=True)

# Expect
assert file_info != project.path_json.stat()

def test_save_as_json_dont_overwrite(self, tmp_path):
# When
global_object.map._clear()
project = Project()
project.set_path_project_parent(tmp_path)
project._models.append(Model())
project.save_as_json()
file_info = project.path_json.stat()

# Then
project_path = project.path_json
project._info['short_description'] = 'short_description'
project.save_as_json()

# Expect
assert project_path.exists()
assert file_info == project.path_json.stat()

def test_load_project(self, tmp_path):
def test_load_from_json(self, tmp_path):
# When
global_object.map._clear()
project = Project()
project.set_root_path(tmp_path)
project.set_path_project_parent(tmp_path)
project._models.append(Model())
project._info['name'] = 'Test Project'
project._info['name'] = 'name'
project._info['short_description'] = 'short_description'
project._info['samples'] = 'samples'
project._info['experiments'] = 'experiments'

project.save_as_json()
project_dict = project.as_dict()

global_object.map._clear()
new_project = Project()

# Then
new_project.load_from_json(tmp_path / 'Test Project' / 'project.json')
new_project.load_from_json(tmp_path / 'name' / 'project.json')
# Do it twice to ensure that potential global objects don't collide
new_project.load_from_json(tmp_path / 'Test Project' / 'project.json')
new_project.load_from_json(tmp_path / 'name' / 'project.json')

# Expect
assert len(new_project._models) == 1
assert new_project._info['name'] == 'Test Project'
assert new_project.as_dict() == project_dict
assert new_project._root_path == tmp_path
assert new_project._info['name'] == 'name'
assert new_project._info['short_description'] == 'short_description'
assert new_project._info['samples'] == 'samples'
assert new_project._info['experiments'] == 'experiments'
assert project_dict == new_project.as_dict()
assert new_project._path_project_parent == tmp_path
assert new_project.created is True

def test_create(self, tmp_path):
# When
project = Project()
project.set_root_path(tmp_path)
project.set_path_project_parent(tmp_path)
project._info['modified'] = 'modified'
project._info['name'] = 'Test Project'

Expand Down

0 comments on commit b86df21

Please sign in to comment.