Skip to content

Commit

Permalink
Make flash parser LArCV version agnostic
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-drielsma committed Nov 18, 2024
1 parent eeb2f64 commit d09435b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 25 deletions.
2 changes: 1 addition & 1 deletion spine/io/parse/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __init__(self, dtype, particle_event=None, add_particle_info=False,
if self.add_particle_info:
assert particle_event is not None, (
"If `add_particle_info` is `True`, must provide the "
"`particle_event` argument")
"`particle_event` argument.")

def __call__(self, trees):
"""Parse one entry.
Expand Down
29 changes: 14 additions & 15 deletions spine/io/parse/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def process(self, sparse_event=None, cluster_event=None):
"""
# Check on the input, pick a source for the metadata
assert (sparse_event is not None) ^ (cluster_event is not None), (
"Must specify either `sparse_event` or `cluster_event`")
"Must specify either `sparse_event` or `cluster_event`.")
ref_event = sparse_event if sparse_event is not None else cluster_event

# Fetch a specific projection, if needed
Expand Down Expand Up @@ -141,7 +141,7 @@ def process(self, sparse_event=None, cluster_event=None):
"""
# Check on the input, pick a source for the run information
assert (sparse_event is not None) ^ (cluster_event is not None), (
"Must specify either `sparse_event` or `cluster_event`")
"Must specify either `sparse_event` or `cluster_event`.")
ref_event = sparse_event if sparse_event is not None else cluster_event

return RunInfo.from_larcv(ref_event)
Expand Down Expand Up @@ -195,34 +195,33 @@ def process(self, flash_event=None, flash_event_list=None):
List[Flash]
List of optical flash objects
"""
# Check on the input, aggregate the sources for the optical flashes
# Check on the input
assert ((flash_event is not None) ^
(flash_event_list is not None)), (
"Must specify either `flash_event` or `flash_event_list`")
"Must specify either `flash_event` or `flash_event_list`.")

# Parse flash objects
if flash_event is not None:
# If there is a single flash event, parse it as is
flash_list = flash_event.as_vector()
flashes = [Flash.from_larcv(larcv.Flash(f)) for f in flash_list]

else:
# Otherwise, set the volume ID of the flash to the source index
# and count the flash index from 0 to the largest number
flash_list = []
flashes = []
idx = 0
for volume_id, flash_event in enumerate(flash_event_list):
for flash in flash_event.as_vector():
# Update attributes (TODO: simplify volume_id with update)
flash.id(idx)
for attr in ['tpc', 'volume_id']:
if hasattr(flash, attr):
getattr(flash, attr)(volume_id)
for f in flash_event.as_vector():
# Cast and update attributes
flash = Flash.from_larcv(f)
flash.id = idx
flash.volume_id = volume_id

# Append, increment counter
flash_list.append(flash)
flashes.append(flash)
idx += 1

# Output as a list of LArCV optical flash objects
flashes = [Flash.from_larcv(larcv.Flash(f)) for f in flash_list]

return ObjectList(flashes, Flash())


Expand Down
16 changes: 8 additions & 8 deletions spine/io/parse/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def __init__(self, dtype, projection_id, sparse_event=None,

# Get the number of features in the output tensor
assert (sparse_event is not None) ^ (sparse_event_list is not None), (
"Must provide either `sparse_event` or `sparse_event_list`")
"Must provide either `sparse_event` or `sparse_event_list`.")
assert sparse_event_list is None or len(sparse_event_list), (
"Must provide as least 1 sparse_event in the list")
"Must provide as least 1 sparse_event in the list.")

self.num_features = 1
if sparse_event_list is not None:
Expand Down Expand Up @@ -118,9 +118,9 @@ def process(self, sparse_event=None, sparse_event_list=None):
larcv.fill_2d_voxels(tensor, np_voxels)
else:
assert meta == tensor.meta(), (
"The metadata must match between tensors")
"The metadata must match between tensors.")
assert num_points == tensor.as_vector().size(), (
"The number of pixels must match between tensors")
"The number of pixels must match between tensors.")

# Get the feature vector for this tensor
np_data = np.empty((num_points, 1), dtype=self.ftype)
Expand Down Expand Up @@ -194,9 +194,9 @@ def __init__(self, dtype, sparse_event=None, sparse_event_list=None,

# Get the number of features in the output tensor
assert (sparse_event is not None) ^ (sparse_event_list is not None), (
"Must provide either `sparse_event` or `sparse_event_list`")
"Must provide either `sparse_event` or `sparse_event_list`.")
assert sparse_event_list is None or len(sparse_event_list), (
"Must provide as least 1 sparse_event in the list")
"Must provide as least 1 sparse_event in the list.")

num_tensors = 1 if sparse_event is not None else len(sparse_event_list)
if self.num_features is not None:
Expand Down Expand Up @@ -260,7 +260,7 @@ def process(self, sparse_event=None, sparse_event_list=None):
meta = sparse_event.meta()
else:
assert meta == sparse_event.meta(), (
"The metadata must match between tensors")
"The metadata must match between tensors.")

if num_points is None:
num_points = sparse_event.as_vector().size()
Expand All @@ -269,7 +269,7 @@ def process(self, sparse_event=None, sparse_event_list=None):
larcv.fill_3d_voxels(sparse_event, np_voxels)
else:
assert num_points == sparse_event.as_vector().size(), (
"The number of pixels must match between tensors")
"The number of pixels must match between tensors.")

# Get the feature vector for this tensor
np_data = np.empty((num_points, 1), dtype=self.ftype)
Expand Down
2 changes: 1 addition & 1 deletion spine/post/optical/flash_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def process(self, data):
matches = self.matcher.get_matches(interactions_v, flashes_v)

# Store flash information
for i, (inter_v, flash, match) in enumerate(matches):
for inter_v, flash, match in matches:
# Get the interaction that matches the cropped version
inter = interactions[inter_v.id]

Expand Down

0 comments on commit d09435b

Please sign in to comment.