Merge pull request #27 from francois-drielsma/develop
Various bug fixes
francois-drielsma authored Oct 8, 2024
2 parents 5396653 + 7810f3a commit a555af3
Showing 3 changed files with 49 additions and 28 deletions.
8 changes: 6 additions & 2 deletions spine/driver.py
@@ -335,6 +335,7 @@ def initialize_io(self, loader=None, reader=None, writer=None):
 
         # Initialize the data loader/reader
         self.loader = None
+        self.unwrapper = None
         if loader is not None:
             # Initialize the torch data loader
             self.watch.initialize('load')
@@ -409,7 +410,10 @@ def get_prefixes(file_paths, split_output):
 
         # If there is only one file, done
         if len(file_names) == 1:
-            return prefix
+            if not split_output:
+                return prefix, prefix
+            else:
+                return prefix, [prefix]
 
         # Otherwise, form the suffix from the first and last file names
         first = os.path.splitext(file_names[0][len(prefix):])
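Note on the hunk above: for a single input file, get_prefixes previously returned a bare prefix, while the rest of the driver presumably expects a (log_prefix, output_prefix) pair whose second element is a list when the output is split. The sketch below only mirrors the added branch with a hypothetical helper (not the spine API) to show the two return shapes.

import os

def single_file_prefixes(file_path, split_output):
    # Hypothetical mirror of the single-file branch added above: build the
    # prefix from the file name and return it in the shape callers expect
    prefix = os.path.splitext(os.path.basename(file_path))[0]
    if not split_output:
        return prefix, prefix
    else:
        return prefix, [prefix]

# Shared output file vs. one output file per input
print(single_file_prefixes('run_001.h5', split_output=False))  # ('run_001', 'run_001')
print(single_file_prefixes('run_001.h5', split_output=True))   # ('run_001', ['run_001'])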
@@ -531,7 +535,7 @@ def process(self, entry=None, run=None, event=None, iteration=None):
         self.watch.update(self.model.watch, 'model')
 
         # 3. Unwrap
-        if self.unwrap:
+        if self.unwrapper is not None:
             self.watch.start('unwrap')
             data = self.unwrapper(data)
             self.watch.stop('unwrap')
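The two driver.py hunks that touch the unwrap step work together: initialize_io now always defines self.unwrapper (left as None when no unwrapping is configured), and process tests the object itself instead of a separate self.unwrap flag. A minimal sketch of that pattern, using hypothetical names rather than the spine classes:

class MiniDriver:
    """Toy driver illustrating the initialize-to-None, test-for-None pattern."""

    def __init__(self, unwrapper=None):
        # Always define the attribute so later checks cannot raise AttributeError
        self.unwrapper = unwrapper

    def process(self, data):
        # Only unwrap when an unwrapper was actually configured
        if self.unwrapper is not None:
            data = self.unwrapper(data)
        return data

print(MiniDriver().process([1, 2, 3]))                           # [1, 2, 3]
print(MiniDriver(unwrapper=lambda d: d[:1]).process([1, 2, 3]))  # [1]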
8 changes: 5 additions & 3 deletions spine/io/read/base.py
@@ -93,8 +93,10 @@ def process_file_paths(self, file_keys, limit_num_files=None,
             Maximum number of loaded file names to be printed
         """
         # Some basic checks
+        assert file_keys is not None, (
+                "No input `file_keys` provided, abort.")
         assert limit_num_files is None or limit_num_files > 0, (
-                "If `limit_num_files` is provided, it must be larger than 0")
+                "If `limit_num_files` is provided, it must be larger than 0.")
 
         # If the file_keys points to a single text file, it must be a text
         # file containing a list of file paths. Parse it to a list.
@@ -103,7 +105,7 @@ def process_file_paths(self, file_keys, limit_num_files=None,
             # If the file list is a text file, extract the list of paths
             assert os.path.isfile(file_keys), (
                     "If the `file_keys` are specified as a single string, "
-                    "it must be the path to a text file with a file list")
+                    "it must be the path to a text file with a file list.")
             with open(file_keys, 'r', encoding='utf-8') as f:
                 file_keys = f.read().splitlines()
 
@@ -114,7 +116,7 @@ def process_file_paths(self, file_keys, limit_num_files=None,
         for file_key in file_keys:
             file_paths = glob.glob(file_key)
             assert file_paths, (
-                    f"File key {file_key} yielded no compatible path")
+                    f"File key {file_key} yielded no compatible path.")
             for path in file_paths:
                 if (limit_num_files is not None and
                         len(self.file_paths) > limit_num_files):
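The reader changes above tighten input validation: a missing file_keys argument now fails immediately with an explicit message instead of an obscure error later on, and the assertion messages gain consistent punctuation. The standalone sketch below reproduces the same input-handling flow under those assumptions; it is a hypothetical helper, not the spine reader itself.

import glob
import os

def expand_file_keys(file_keys, limit_num_files=None):
    # Basic checks, mirroring the assertions added/adjusted in the patch
    assert file_keys is not None, "No input `file_keys` provided, abort."
    assert limit_num_files is None or limit_num_files > 0, (
            "If `limit_num_files` is provided, it must be larger than 0.")

    # A single string is interpreted as a text file listing one path per line
    if isinstance(file_keys, str):
        assert os.path.isfile(file_keys), (
                "If the `file_keys` are specified as a single string, "
                "it must be the path to a text file with a file list.")
        with open(file_keys, 'r', encoding='utf-8') as f:
            file_keys = f.read().splitlines()

    # Expand each key with glob and honor the optional file count limit
    file_paths = []
    for file_key in file_keys:
        matches = sorted(glob.glob(file_key))
        assert matches, f"File key {file_key} yielded no compatible path."
        for path in matches:
            if limit_num_files is not None and len(file_paths) >= limit_num_files:
                return file_paths
            file_paths.append(path)
    return file_paths

# Example: expand a glob pattern, keeping at most two files
# print(expand_file_keys(['data/*.h5'], limit_num_files=2))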
61 changes: 38 additions & 23 deletions spine/model/manager.py
Expand Up @@ -236,6 +236,10 @@ def __call__(self, data, iteration=None):
self.save_state(iteration)
self.watch.stop('save')

# If requested, cast the result dictionary to numpy
if self.to_numpy:
self.cast_to_numpy(result)

return result

def clean_config(self, config):
@@ -492,29 +496,7 @@ def forward(self, data, iteration=None):
             result.update(self.loss_fn(
                 iteration=iteration, **loss_dict, **result))
 
-        # Convert to numpy, if requested
-        if self.to_numpy:
-            for key, value in result.items():
-                if np.isscalar(value):
-                    # Scalar
-                    result[key] = value
-                elif (isinstance(value, torch.Tensor) and
-                      value.numel() == 1):
-                    # Scalar tensor
-                    result[key] = value.item()
-                elif isinstance(
-                        value, (TensorBatch, IndexBatch, EdgeIndexBatch)):
-                    # Batch of data
-                    result[key] = value.to_numpy()
-                elif (isinstance(value, list) and
-                      len(value) and
-                      isinstance(value[0], TensorBatch)):
-                    # List of tensor batches
-                    result[key] = [v.to_numpy() for v in value]
-                else:
-                    raise ValueError(f"Cannot cast output {key} to numpy")
-
-        return result
+        return result
 
     def backward(self, loss):
         """Run the backward step on the model.
@@ -540,6 +522,39 @@ def backward(self, loss):
             logger.info('Updating buffers')
             self.net.update_buffers()
 
+    def cast_to_numpy(self, result):
+        """Casts the model output data products to numpy object in place.
+
+        Parameters
+        ----------
+        result : dict
+            Dictionary of model and loss outputs
+        """
+        # Loop over the key, value pairs in the result dictionary
+        for key, value in result.items():
+            # Cast to numpy or python scalars
+            if np.isscalar(value):
+                # Scalar
+                result[key] = value
+
+            elif (isinstance(value, torch.Tensor) and value.numel() == 1):
+                # Scalar tensor
+                result[key] = value.item()
+
+            elif isinstance(value, (TensorBatch, IndexBatch, EdgeIndexBatch)):
+                # Batch of data
+                result[key] = value.to_numpy()
+
+            elif (isinstance(value, list) and len(value) and
+                  isinstance(value[0], TensorBatch)):
+                # List of tensor batches
+                result[key] = [v.to_numpy() for v in value]
+
+            else:
+                dtype = type(value)
+                raise ValueError(
+                        f"Cannot cast output {key} of type {dtype} to numpy.")
+
     def save_state(self, iteration):
         """Save the model state.
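The manager.py refactor moves the numpy conversion out of forward() into the dedicated cast_to_numpy() method called from __call__, and the error branch now reports the offending type. The sketch below reproduces the same in-place dispatch on a plain dictionary; FakeBatch is a hypothetical stand-in for spine's TensorBatch-like containers, not the real class.

import numpy as np
import torch

class FakeBatch:
    """Hypothetical stand-in for spine's TensorBatch-like containers."""
    def __init__(self, tensor):
        self.tensor = tensor
    def to_numpy(self):
        return self.tensor.detach().cpu().numpy()

def cast_to_numpy(result):
    """Cast the values of a result dictionary to numpy/python objects in place."""
    for key, value in result.items():
        if np.isscalar(value):
            result[key] = value                 # Already a python scalar
        elif isinstance(value, torch.Tensor) and value.numel() == 1:
            result[key] = value.item()          # One-element tensor -> python scalar
        elif isinstance(value, FakeBatch):
            result[key] = value.to_numpy()      # Batch object -> numpy array
        elif isinstance(value, list) and len(value) and isinstance(value[0], FakeBatch):
            result[key] = [v.to_numpy() for v in value]
        else:
            raise ValueError(
                    f"Cannot cast output {key} of type {type(value)} to numpy.")

result = {'loss': torch.tensor(0.5), 'count': 3, 'clusts': FakeBatch(torch.ones(4, 3))}
cast_to_numpy(result)
print(result['loss'], result['count'], result['clusts'].shape)  # 0.5 3 (4, 3)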
