Merge pull request #22 from francois-drielsma/develop
Bug fixes in the accuracy estimators, add support to evaluate GrapPA alone
francois-drielsma authored Sep 17, 2024
2 parents 600a864 + de5a5cf commit 0306b38
Showing 4 changed files with 90 additions and 40 deletions.
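The headline change: the `cluster_eval` script (`ClusterAna`) can now score a clustering produced by GrapPA alone, without the per-object products of the full reconstruction chain. Below is a hypothetical usage sketch based only on the diff that follows; the import path is taken from the file listing, and the 'group' column name is a placeholder for whatever names the 'cluster' enum accepts through enum_factory.

    from spine.ana.metric.cluster import ClusterAna

    # Existing per-object mode: score the clustering of each requested object type
    chain_ana = ClusterAna(obj_type=['particle', 'interaction'])

    # New standalone GrapPA mode: no object types; score a single clustering
    # label column ('group' is a hypothetical name, resolved through
    # enum_factory('cluster', ...))
    grappa_ana = ClusterAna(per_object=False, label_col='group')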
6 changes: 4 additions & 2 deletions spine/ana/base.py
@@ -56,10 +56,12 @@ def __init__(self, obj_type=None, run_mode=None, append=False,
             Name to prefix every output CSV file with
         """
         # Initialize default keys
-        self.keys = {
+        if self.keys is None:
+            self.keys = {}
+        self.keys.update({
                 'index': True, 'file_index': True,
                 'file_entry_index': False, 'run_info': False
-        }
+        })

         # If run mode is specified, process it
         self.run_mode = run_mode
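The fix above stops `AnaBase.__init__` from clobbering keys that were already registered on the instance before it runs. A self-contained sketch of the pattern this enables (not SPINE code; the class names and key names here are invented for illustration):

    class Base:
        keys = None

        def __init__(self):
            # Update instead of assigning, so pre-registered keys survive
            if self.keys is None:
                self.keys = {}
            self.keys.update({'index': True, 'run_info': False})

    class Child(Base):
        def __init__(self):
            self.keys = {'ppn_pred': True}  # registered before the parent runs
            super().__init__()

    assert Child().keys == {'ppn_pred': True, 'index': True, 'run_info': False}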
91 changes: 64 additions & 27 deletions spine/ana/metric/cluster.py
@@ -6,6 +6,7 @@
 from scipy.special import softmax

 import spine.utils.metrics
+from spine.utils.enums import enum_factory
 from spine.utils.globals import (
     SHAPE_COL, LOWES_SHP, CLUST_COL, GROUP_COL, INTER_COL)

@@ -23,24 +24,27 @@ class ClusterAna(AnaBase):
"""
name = 'cluster_eval'

# Label column to use for each clustering target
# Label column to use for each clustering label_col
_label_cols = {
'fragment': CLUST_COL, 'particle': GROUP_COL,
'interaction': INTER_COL
}

def __init__(self, obj_type, use_objects=False, per_shape=True,
metrics=('pur', 'eff', 'ari'), label_key='clust_label_adapt',
**kwargs):
def __init__(self, obj_type=None, use_objects=False, per_object=True,
per_shape=True, metrics=('pur', 'eff', 'ari'),
label_key='clust_label_adapt', label_col=None, **kwargs):
"""Initialize the analysis script.
Parameters
----------
obj_type : Union[str, List[str]]
obj_type : Union[str, List[str]], optional
Name or list of names of the object types to process
use_objects : bool, default False
If `True`, rebuild the clustering labels for truth and reco
If `True`, rebuild the clustering assignments for truth and reco
from the set of truth and reco particles
per_object : bool, default True
Evaluate the clustering accuracy for each object type (not relevant
if running GrapPA standalone)
per_shape : bool, default True
Evaluate the clustering accuracy for each object shape (not
relevant in the case of interactions)
@@ -49,33 +53,59 @@ def __init__(self, obj_type, use_objects=False, per_shape=True,
         label_key : str, default 'clust_label_adapt'
             Name of the tensor which contains the cluster labels, when
             using the raw reconstruction output
+        label_col : str, optional
+            Column name in the label tensor specifying the aggregation target
         **kwargs : dict, optional
             Additional arguments to pass to :class:`AnaBase`
         """
+        # Check parameters
+        assert obj_type is not None or not per_object, (
+                "If evaluating clustering metrics per object, provide a list "
+                "of object types to evaluate the clustering for.")
+        assert per_object or label_col is not None, (
+                "If evaluating clustering standalone (not per object), must "
+                "provide the name of the target clustering label column.")
+        assert per_object or not use_objects, (
+                "If evaluating clustering standalone (not per object), cannot "
+                "use objects to evaluate it.")
+
         # Initialize the parent class
         super().__init__(obj_type, 'both', **kwargs)
         if not use_objects:
             for key in self.obj_keys:
                 del self.keys[key]
+        if not per_object:
+            self.obj_type = [label_col]

         # Store the basic parameters
         self.use_objects = use_objects
+        self.per_object = per_object
         self.per_shape = per_shape
         self.label_key = label_key

+        # Parse the label column, if necessary
+        if label_col is not None:
+            self.label_col = enum_factory('cluster', label_col)
+
         # Convert metric strings to functions
         self.metrics = {m: getattr(spine.utils.metrics, m) for m in metrics}

         # List the necessary data products
-        if not self.use_objects:
-            # Store the labels and the clusters output by the reco chain
-            self.keys[label_key] = True
-            for obj in self.obj_type:
-                self.keys[f'{obj}_clusts'] = True
-                self.keys[f'{obj}_shapes'] = True
+        if self.per_object:
+            if not self.use_objects:
+                # Store the labels and the clusters output by the reco chain
+                self.keys[label_key] = True
+                for obj in self.obj_type:
+                    self.keys[f'{obj}_clusts'] = True
+                    self.keys[f'{obj}_shapes'] = True

+            else:
+                self.keys['points'] = True

         else:
             self.keys['points'] = True
+            self.keys[label_key] = True
+            self.keys['clusts'] = True
+            self.keys['group_pred'] = True

         # Initialize the output
         for obj in self.obj_type:
@@ -94,8 +124,9 @@ def process(self, data):
         # Build the cluster labels for this object type
         if not self.use_objects:
             # Fetch the right label column
+            label_col = self.label_col or self._label_cols[obj_type]
             num_points = len(data[self.label_key])
-            labels = data[self.label_key][:, self._label_cols[obj_type]]
+            labels = data[self.label_key][:, label_col]
             shapes = data[self.label_key][:, SHAPE_COL]
             num_truth = len(np.unique(labels[labels > -1]))

@@ -109,21 +140,27 @@

         # Build the cluster predictions for this object type
         preds = -np.ones(num_points)
-        shapes = -np.full(num_points, LOWES_SHP)
-        if not self.use_objects:
-            # Use clusters directly from the full chain output
-            num_reco = len(data[f'{obj_type}_clusts'])
-            for i, index in enumerate(data[f'{obj_type}_clusts']):
-                preds[index] = i
-                shapes[index] = data[f'{obj_type}_shapes'][i]
+        if self.per_object:
+            shapes = -np.full(num_points, LOWES_SHP)
+            if not self.use_objects:
+                # Use clusters directly from the full chain output
+                num_reco = len(data[f'{obj_type}_clusts'])
+                for i, index in enumerate(data[f'{obj_type}_clusts']):
+                    preds[index] = i
+                    shapes[index] = data[f'{obj_type}_shapes'][i]

+            else:
+                # Use clusters from the object indexes
+                num_reco = len(data[f'reco_{obj_type}s'])
+                for i, obj in enumerate(data[f'reco_{obj_type}s']):
+                    preds[obj.index] = i
+                    if obj_type != 'interaction':
+                        shapes[obj.index] = obj.shape
+
         else:
-            # Use clusters from the object indexes
-            num_reco = len(data[f'reco_{obj_type}s'])
-            for i, obj in enumerate(data[f'reco_{obj_type}s']):
-                preds[obj.index] = i
-                if obj_type != 'interaction':
-                    shapes[obj.index] = obj.shape
+            num_reco = len(data['clusts'])
+            for i, index in enumerate(data['clusts']):
+                preds[index] = data['group_pred'][i]

         # Evaluate clustering metrics
         row_dict = {'num_points': num_points, 'num_truth': num_truth,
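The `metrics` tuple `('pur', 'eff', 'ari')` is resolved by name from `spine.utils.metrics` and evaluated on the `labels`/`preds` assignments built above. As a reference for what such estimators conventionally compute, here is a minimal sketch, not SPINE's exact implementations:

    import numpy as np
    from scipy.special import comb

    def contingency(labels, preds):
        """Counts of points shared by each (truth cluster, reco cluster) pair."""
        lab_ids, lab_inv = np.unique(labels, return_inverse=True)
        pred_ids, pred_inv = np.unique(preds, return_inverse=True)
        mat = np.zeros((len(lab_ids), len(pred_ids)), dtype=np.int64)
        np.add.at(mat, (lab_inv, pred_inv), 1)
        return mat

    def pur(labels, preds):
        """Mean purity: fraction of a reco cluster inside its main truth match."""
        mat = contingency(labels, preds)
        return np.mean(mat.max(axis=0) / mat.sum(axis=0))

    def eff(labels, preds):
        """Mean efficiency: fraction of a truth cluster inside its main reco match."""
        mat = contingency(labels, preds)
        return np.mean(mat.max(axis=1) / mat.sum(axis=1))

    def ari(labels, preds):
        """Adjusted Rand index of the two partitions."""
        mat = contingency(labels, preds)
        sum_ij = comb(mat, 2).sum()
        sum_a = comb(mat.sum(axis=1), 2).sum()
        sum_b = comb(mat.sum(axis=0), 2).sum()
        expected = sum_a * sum_b / comb(len(labels), 2)
        max_index = (sum_a + sum_b) / 2
        return (sum_ij - expected) / (max_index - expected)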
26 changes: 18 additions & 8 deletions spine/ana/metric/point.py
@@ -27,7 +27,8 @@ class PointProposalAna(AnaBase):
"""
name = 'point_eval'

def __init__(self, num_classes=LOWES_SHP, label_key='ppn_label', **kwargs):
def __init__(self, num_classes=LOWES_SHP, label_key='ppn_label',
endpoints=False, **kwargs):
"""Initialize the analysis script.
Parameters
@@ -36,6 +37,8 @@ def __init__(self, num_classes=LOWES_SHP, label_key='ppn_label', **kwargs):
             Number of pixel classes, excluding the ghost class
         label_key : str, default 'ppn_label'
             Name of the tensor which contains the segmentation labels
+        endpoints : bool, default False
+            Evaluate the accuracy of end point classification
         **kwargs : dict, optional
             Additional arguments to pass to :class:`AnaBase`
         """
@@ -45,6 +48,7 @@ def __init__(self, num_classes=LOWES_SHP, label_key='ppn_label', **kwargs):
         # Store the basic parameters
         self.num_classes = num_classes
         self.label_key = label_key
+        self.endpoints = endpoints

         # Append other required key
         self.keys['ppn_pred'] = True
@@ -56,10 +60,12 @@ def __init__(self, num_classes=LOWES_SHP, label_key='ppn_label', **kwargs):

         # Initialize a dummy dictionary to return when there is no match
         self.dummy_dict = {
-            'dist': -1., 'shape': -1, 'end': -1,
-            'closest_shape': -1, 'closest_end': -1}
+            'dist': -1., 'shape': -1, 'closest_shape': -1}
         for s in range(self.num_classes):
             self.dummy_dict[f'dist_{s}'] = -1.
+        if endpoints:
+            self.dummy_dict['end'] = -1
+            self.dummy_dict['closest_end'] = -1

@@ -75,13 +81,13 @@ def process(self, data):
         # Fetch the list of label points and their characteristics
         points['truth'] = data[self.label_key][:, COORD_COLS]
         types['truth'] = data[self.label_key][:, PPN_LTYPE_COL].astype(int)
-        if PPN_LENDP_COL < data[self.label_key].shape[1]:
+        if self.endpoints:
             ends['truth'] = data[self.label_key][:, PPN_LENDP_COL].astype(int)

         # Fetch the list of predicted points and their characteristics
         points['reco'] = data['ppn_pred'][:, COORD_COLS]
         types['reco'] = data['ppn_pred'][:, PPN_SHAPE_COL].astype(int)
-        if PPN_END_COLS[0] < data['ppn_pred'].shape[1]:
+        if self.endpoints:
             ends['reco'] = np.argmax(data['ppn_pred'][:, PPN_END_COLS], axis=1)

         # Compute the pair-wise distances between label and predicted points
@@ -98,12 +104,16 @@ def process(self, data):

         # If there are no target points, record no match
         if not len(points[target]):
+            # Append dummy values
             dummy = {**self.dummy_dict}
             for i in range(len(points[source])):
                 dummy['shape'] = types[source][i]
-                if len(ends):
+                if self.endpoints:
                     dummy['end'] = ends[source][i]
-                self.append(source, **dummy)
+                self.append(f'{source}_to_{target}', **dummy)
+
+            # Proceed
+            continue

         # Otherwise, use closest point as reference
         dists = dist_mat[source]
@@ -114,7 +124,7 @@ def process(self, data):
             point_dict['dist'] = dists[i, closest_index[i]]
             point_dict['shape'] = types[source][i]
             point_dict['closest_shape'] = types[target][closest_index[i]]
-            if len(ends):
+            if self.endpoints:
                 point_dict['end'] = ends[source][i]
                 point_dict['closest_end'] = ends[target][closest_index[i]]
             for s in range(self.num_classes):
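For context on the matching that `point_eval` performs: each point on one side is paired with its nearest neighbor on the other side through a pair-wise distance matrix, and the `dist`, `shape`/`closest_shape` (and now, optionally, `end`/`closest_end`) columns are filled from that match. A standalone sketch with made-up inputs, assuming only the variable names visible in the diff:

    import numpy as np
    from scipy.spatial.distance import cdist

    # Made-up coordinates standing in for the label and predicted points
    truth = np.random.rand(5, 3)
    reco = np.random.rand(7, 3)

    dist_mat = {'truth': cdist(truth, reco), 'reco': cdist(reco, truth)}
    for source in ('truth', 'reco'):
        dists = dist_mat[source]
        closest_index = np.argmin(dists, axis=1)  # nearest counterpart per point
        closest_dist = dists[np.arange(len(dists)), closest_index]
        # closest_dist feeds the 'dist' column; with endpoints=True, the end
        # classification of the matched point would be compared as well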
7 changes: 4 additions & 3 deletions spine/utils/vertex.py
@@ -45,7 +45,8 @@ def get_vertex(start_points, end_points, directions, semantics,
     if anchor_vertex:
         # If there is a unique point where >=2 particles meet, pick it. Include
         # track start and end points, to not rely on direction predictions
-        vertices = get_confluence_points(start_points, end_points, touching_threshold)
+        vertices = get_confluence_points(
+                start_points, end_points, touching_threshold)
         if len(vertices) == 1:
             if return_mode:
                 return vertices[0], 'confluence_nodir'
@@ -54,8 +55,8 @@ def get_vertex(start_points, end_points, directions, semantics,
         # If there is more than one option, restrict track end points to the
         # predicted start points (relies on direction prediction), check again
         if len(vertices) > 1 and len(track_ids):
-            vertices = get_confluence_points(start_points,
-                    touching_threshold=touching_threshold)
+            vertices = get_confluence_points(
+                    start_points, touching_threshold=touching_threshold)
             if len(vertices) == 1:
                 if return_mode:
                     return vertices[0], 'confluence_dir'
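Both calls rely on `get_confluence_points`, which returns candidate vertices where at least two particle points sit within `touching_threshold` of one another. A toy stand-in for that behavior (not the actual implementation), grouping touching points via connected components and averaging each group:

    import numpy as np
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components
    from scipy.spatial.distance import cdist

    def confluence_points_sketch(points, touching_threshold=5.0):
        """Mean position of every group of >= 2 points that sit within
        touching_threshold of each other (toy O(N^2) version)."""
        adjacency = cdist(points, points) < touching_threshold
        n_comp, comp_ids = connected_components(
                csr_matrix(adjacency), directed=False)
        return [points[comp_ids == c].mean(axis=0)
                for c in range(n_comp) if np.sum(comp_ids == c) > 1]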
