Skip to content

Commit

Permalink
Merge pull request #17 from francois-drielsma/develop
Browse files Browse the repository at this point in the history
Minor bug fixes
  • Loading branch information
francois-drielsma authored Sep 12, 2024
2 parents 2e147c6 + b002d8b commit 1a19d14
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 2 deletions.
14 changes: 13 additions & 1 deletion spine/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torch.utils.data import Dataset

from spine.utils.factory import module_dict, instantiate
from spine.utils.augment import Augmenter

from . import parse
from .read import LArCVReader
Expand All @@ -28,7 +29,7 @@ class LArCVDataset(Dataset):
"""
name = 'larcv'

def __init__(self, schema, dtype, **kwargs):
def __init__(self, schema, dtype, augment=None, **kwargs):
"""Instantiates the LArCVDataset.
Parameters
Expand All @@ -42,6 +43,8 @@ def __init__(self, schema, dtype, **kwargs):
names and their values
dtype : str
Data type to cast the input data to (to match the downstream model)
augment : dict, optional
Augmentation strategy configuration
**kwargs : dict, optional
Additional arguments to pass to the LArCVReader class
"""
Expand All @@ -58,6 +61,11 @@ def __init__(self, schema, dtype, **kwargs):
if key not in tree_keys:
tree_keys.append(key)

# Parse the augmentation configuration
self.augmenter = None
if augment is not None:
self.augmenter = Augmenter(**augment)

# Instantiate the reader
self.reader = LArCVReader(tree_keys=tree_keys, **kwargs)

Expand Down Expand Up @@ -102,6 +110,10 @@ def __getitem__(self, idx):
print(f"Failed to produce {name} using {parser}")
raise err

# If requested, augment the data
if self.augmenter is not None:
result = self.augmenter(result)

return result

def data_keys(self):
Expand Down
6 changes: 5 additions & 1 deletion spine/post/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class PostBase(ABC):
name = None
aliases = ()
parent_path = ''
keys = {}
keys = None
truth_point_mode = 'points'
units = 'cm'

Expand Down Expand Up @@ -76,6 +76,10 @@ def __init__(self, obj_type=None, run_mode=None, truth_point_mode=None,
Path to the parent directory of the main analysis configuration. This
allows for the use of relative paths in the post-processors.
"""
# Initialize default keys
if self.keys is None:
self.keys = {}

# If run mode is specified, process it
if run_mode is not None:
# Check that the run mode is recognized
Expand Down
155 changes: 155 additions & 0 deletions spine/utils/augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Module with methods to augment the input data to SPINE."""

import numpy as np

from spine.data import Meta


class Augmenter:
"""Generic class to handle data augmentation in SPINE."""

def __init__(self, translate=None):
"""Initialize the augmenter.
Parameters
----------
translate : dict, optional
Translation confiugration (move input image around)
"""
# Make sure at least one augmentation scheme is requested
assert translate is not None, (
"Must provide `translate` block minimally to do any augmentation.")

# Parse the translation configuration
self.translater = None
if translate is not None:
self.translater = Translater(**translate)

def __call__(self, data):
"""Augment the data products in one event.
Parameters
----------
data : dict
Data product dictionary
"""
# Get the list of keys to augment and the shared metadata
augment_keys = []
meta = None
for key, value in data.items():
if (isinstance(value, tuple) and len(value) == 3 and
isinstance(value[2], Meta)):
augment_keys.append(key)
if meta is None:
meta = value[2]
else:
assert meta == value[2], (
"Metadata should be shared by all data products.")
elif isinstance(value, Meta):
augment_keys.append(key)
meta = value

# If there are no sparse tensors in the input data, nothing to do
if meta is None:
return data

# Translate
if self.translater is not None:
data = self.translater(data, meta, augment_keys)

return data


class Translater:
"""Generic class to handle moving images around."""

def __init__(self, lower, upper):
"""Initialize the translater..
This defines a way to move the image around within a volume greater
than that define by the image metadata. The box must be larger than
the image itself.
Parameters
----------
lower : np.ndarray
Lower bounds of the box in which to move the image around
upper : np.ndarray
Upper bounds of the box in which to move the image around
"""
# Sanity check
assert len(lower) == len(upper) == 3, (
"Must provide boundaries for each dimension.")

# Define a new image metadata corresponding to the full range
self.meta = Meta(lower=np.asarray(lower), upper=np.asarray(upper))

def __call__(self, data, meta, keys):
"""Move an image around within the the pre-defined volume.
Parameters
----------
data : dict
Dictionary of data products to offset
meta : Meta
Shared image metadata
keys : List[str]
List of keys with coordinates to offset
Returns
-------
np.ndarray
(N, 3) Translated points
"""
# Set the target volume pixel pitch to match that of the original image
if np.all(self.meta.size < 0.):
self.meta.size = meta.size
self.meta.count = (self.meta.upper - self.meta.lower)//meta.size
self.meta.count = self.meta.count.astype(int)

# Generate an offset
offset = self.generate_offset(meta)

# Offset all coordinates
for key in keys:
# If the key is the metadata, modify and continue
if isinstance(data[key], Meta):
data[key] = self.meta
continue

# Fetch attributes to modify
voxels, features, _ = data[key]

# Translate
width = voxels.shape[1]
voxels = (voxels.reshape(-1, 3) + offset).reshape(-1, width)

# Update
data[key] = (voxels, features, self.meta)

return data

def generate_offset(self, meta):
"""Generate an offset to apply to all the voxel index sets.
This offset is such that the the voxels will be randomly shifted
within the target bounding box.
Parameters
----------
meta : Meta
Metadata of the original image
Returns
-------
np.ndarray
Value by which to shift the pixels by
"""
# Check that the original metadata is compatible with the target volume
assert np.all(meta.count <= self.meta.count), (
"The input image is larger than the target translation volume.")

# Generate an offset with respect to the voxel indices
offset = np.random.randint((self.meta.count - meta.count) + 1)

return offset

0 comments on commit 1a19d14

Please sign in to comment.