diff --git a/spine/ana/template.py b/spine/ana/template.py
new file mode 100644
index 00000000..e206734d
--- /dev/null
+++ b/spine/ana/template.py
@@ -0,0 +1,83 @@
+"""Analysis module template.
+
+Use this template as a basis to build your own analysis script. An analysis
+script takes the output of the reconstruction and the post-processors and
+performs basic selection cuts and stores the output to a CSV file.
+"""
+
+# Add the imports specific to this module here
+# import ...
+
+# Must import the analysis script base class
+from spine.ana.base import AnaBase
+
+# Must list the analysis script(s) here to be found by the factory.
+# You must also add it to the list of imported modules in the
+# `spine.ana.factories`!
+__all__ = ['TemplateAna']
+
+
+class TemplateAna(AnaBase):
+    """Description of what the analysis script is supposed to be doing."""
+    name = 'template'  # Name used to call the analysis script in the config
+
+    def __init__(self, arg0, arg1, obj_type, run_mode, append_file,
+                 overwrite_file, output_prefix):
+        """Initialize the analysis script.
+
+        Parameters
+        ----------
+        arg0 : type
+            Description of arg0
+        arg1 : type
+            Description of arg1
+        obj_type : Union[str, List[str]]
+            Name or list of names of the object types to process
+        run_mode : str, optional
+            If specified, tells whether the post-processor must run on
+            reconstructed ('reco'), true ('true') or both objects
+            ('both' or 'all')
+        append_file : bool, default False
+            If True, appends to existing CSV files instead of creating new ones
+        overwrite_file : bool, default False
+            If True and the output CSV file exists, overwrite it
+        output_prefix : str, default None
+            Name to prefix every output CSV file with
+        """
+        # Initialize the parent class
+        super().__init__(
+                obj_type, run_mode, append_file, overwrite_file, output_prefix)
+
+        # Store parameter
+        self.arg0 = arg0
+        self.arg1 = arg1
+
+        # Initialize the CSV writer(s) you want
+        self.initialize_writer('template')
+
+        # Add additional required data products
+        self.keys['prod'] = True  # Means we must have 'prod' in the dictionary
+
+    def process(self, data):
+        """Pass data products corresponding to one entry through the analysis.
+
+        Parameters
+        ----------
+        data : dict
+            Dictionary of data products
+        """
+        # Fetch the keys you want
+        data = data['prod']
+
+        # Loop over all requested object types
+        for key in self.obj_keys:
+            # Loop over all objects of that type
+            for obj in data[key]:
+                # Do something with the object
+                disp = obj.end_point - obj.start_point
+
+                # Make a dictionary out of it
+                out = {'disp_x': disp[0], 'disp_y': disp[1], 'disp_z': disp[2]}
+
+                # Write the row to file
+                self.append('template', **out)
diff --git a/spine/build/particle.py b/spine/build/particle.py
index 92fb8a8b..77d7af92 100644
--- a/spine/build/particle.py
+++ b/spine/build/particle.py
@@ -214,7 +214,7 @@ def _build_truth(self, label_tensor, label_adapt_tensor, particles, points,
 
         # Update the attributes shared between reconstructed and true
         particle.length = particle.distance_travel
-        particle.is_primary = particle.interaction_primary
+        particle.is_primary = bool(particle.interaction_primary > 0)
         particle.start_point = particle.first_step
         if particle.shape == TRACK_SHP:
             particle.end_point = particle.last_step
diff --git a/spine/post/reco/kinematics.py b/spine/post/reco/kinematics.py
index 9c94b81d..d79a36ad 100644
--- a/spine/post/reco/kinematics.py
+++ b/spine/post/reco/kinematics.py
@@ -71,7 +71,7 @@ def process(self, data):
                 primary_scores[primary_range] = part.primary_scores[primary_range]
                 primary_scores /= np.sum(primary_scores)
                 part.primary_scores = primary_scores
-                part.is_primary = np.argmax(primary_scores)
+                part.is_primary = bool(np.argmax(primary_scores))
 
 
 class ParticleThresholdProcessor(PostBase):
@@ -148,7 +148,8 @@ def process(self, data):
 
             # Adjust the primary ID
             if self.primary_threshold is not None:
-                part.is_primary = part.primary_scores[1] >= self.primary_threshold
+                part.is_primary = bool(
+                        part.primary_scores[1] >= self.primary_threshold)
 
 
 class InteractionTopologyProcessor(PostBase):
diff --git a/spine/post/reco/label.py b/spine/post/reco/label.py
index d6931860..bccc83ee 100644
--- a/spine/post/reco/label.py
+++ b/spine/post/reco/label.py
@@ -56,16 +56,16 @@ def process(self, data):
             G = nx.DiGraph()
             edges = []
             for obj in data[k]:
-                G.add_node(obj.id, attr=getattr(obj, self.mode))
+                G.add_node(obj.orig_id, attr=getattr(obj, self.mode))
                 parent = obj.parent_id
-                if parent in G and int(parent) != int(obj.id):
-                    edges.append((parent, obj.id))
+                if parent in G and int(parent) != int(obj.orig_id):
+                    edges.append((parent, obj.orig_id))
             G.add_edges_from(edges)
             G.remove_edges_from(nx.selfloop_edges(G))
 
             # Count children
             for obj in data[k]:
-                successors = list(G.successors(obj.id))
+                successors = list(G.successors(obj.orig_id))
                 counter = Counter()
                 counter.update([G.nodes[succ]['attr'] for succ in successors])
                 children_counts = np.zeros(self.num_classes, dtype=np.int64)
diff --git a/spine/post/template.py b/spine/post/template.py
new file mode 100644
index 00000000..e355a862
--- /dev/null
+++ b/spine/post/template.py
@@ -0,0 +1,89 @@
+"""Post-processor module template.
+
+Use this template as a basis to build your own post-processor. A post-processor
+takes the output of the reconstruction and either
+- Sets additional reconstruction attributes (e.g. direction estimates)
+- Adds entirely new data products (e.g. trigger time)
+"""
+
+# Add the imports specific to this module here
+# import ...
+
+# Must import the post-processor base class
+from spine.post.base import PostBase
+
+# Must list the post-processor(s) here to be found by the factory.
+# You must also add it to the list of imported modules in the
+# `spine.post.factories`!
+__all__ = ['TemplateProcessor']
+
+
+class TemplateProcessor(PostBase):
+    """Description of what the post-processor is supposed to be doing."""
+    name = 'template'  # Name used to call the post-processor in the config
+
+    def __init__(self, arg0, arg1, obj_type, run_mode):
+        """Initialize the post-processor.
+
+        Parameters
+        ----------
+        arg0 : type
+            Description of arg0
+        arg1 : type
+            Description of arg1
+        obj_type : Union[str, List[str]]
+            Types of objects needed in this post-processor (fragments,
+            particles and/or interactions). This argument is shared between
+            all post-processors. If None, does not load these objects.
+        run_mode : str
+            One of 'reco', 'truth' or 'both'. Determines what kind of object
+            the post-processor has to run on.
+        """
+        # Initialize the parent class
+        super().__init__(obj_type, run_mode)
+
+        # Store parameter
+        self.arg0 = arg0
+        self.arg1 = arg1
+
+        # Add additional required data products
+        self.keys['prod'] = True  # Means we must have 'prod' in the dictionary
+
+    def process(self, data):
+        """Pass data products corresponding to one entry through the processor.
+
+        Parameters
+        ----------
+        data : dict
+            Dictionary of data products
+        """
+        # Fetch the keys you want
+        data = data['prod']
+
+        # Loop over all requested object types
+        for key in self.obj_keys:
+            # Loop over all objects of that type
+            for obj in data[key]:
+                # Fetch point attributes
+                points = self.get_points(obj)
+
+                # Get another attribute
+                sources = obj.sources
+
+                # Do something...
+
+        # Loop over requested specific types of objects
+        for key in self.fragment_keys:
+            # Do something...
+            pass
+
+        for key in self.particle_keys:
+            # Do something...
+            pass
+
+        for key in self.interaction_keys:
+            # Do something...
+            pass
+
+        # Return an update or override to the current data product dictionary
+        return {}  # Can have no return as well if objects are edited in place