diff --git a/spine/ana/script/save.py b/spine/ana/script/save.py index 45386350..5b3b44da 100644 --- a/spine/ana/script/save.py +++ b/spine/ana/script/save.py @@ -23,7 +23,7 @@ class SaveAna(AnaBase): 'reco_fragment': RecoFragment(), 'truth_fragment': TruthFragment(), 'reco_particle': RecoParticle(), 'truth_particle': TruthParticle(), 'reco_interaction': RecoInteraction(), - 'truth_particle': TruthInteraction() + 'truth_interaction': TruthInteraction() } def __init__(self, obj_type, fragment=None, particle=None, interaction=None, diff --git a/spine/build/interaction.py b/spine/build/interaction.py index dae48b2b..3f2aa851 100644 --- a/spine/build/interaction.py +++ b/spine/build/interaction.py @@ -140,20 +140,21 @@ def _build_truth(self, truth_particles, neutrinos=None): p.interaction_id = i # Append the neutrino information, if it is provided - if neutrinos is not None: - nu_ids = [part.nu_id for part in inter_particles] - assert len(np.unique(nu_ids)) == 1, ( - "Interaction made up of particles with different " - "neutrino IDs. Must be unique.") - if nu_ids[0] > -1: - interaction.attach_neutrino(neutrinos[nu_ids[0]]) + nu_ids = [part.nu_id for part in inter_particles] + assert len(np.unique(nu_ids)) == 1, ( + "Interaction made up of particles with different " + "neutrino IDs. Must be unique.") + interaction.nu_id = nu_ids[0] + + if neutrinos is not None and nu_ids[0] > -1: + interaction.attach_neutrino(neutrinos[nu_ids[0]]) else: anc_pos = [part.ancestor_position for part in inter_particles] anc_pos = np.unique(anc_pos, axis=0) assert len(anc_pos) == 1, ( - "Particles making up a true interaction have different " - "ancestor positions.") + "Particles making up a true interaction have " + "different ancestor positions.") interaction.vertex = anc_pos.flatten() # Append diff --git a/spine/data/out/interaction.py b/spine/data/out/interaction.py index 4228fd21..af10fafe 100644 --- a/spine/data/out/interaction.py +++ b/spine/data/out/interaction.py @@ -60,7 +60,9 @@ class InteractionBase: _fixed_length_attrs = {'vertex': 3} # Variable-length attributes as (key, dtype) pairs - _var_length_attrs = {'particle_ids': np.int32} + _var_length_attrs = { + 'particles': object, 'particle_ids': np.int32 + } # Attributes specifying coordinates _pos_attrs = ['vertex'] diff --git a/spine/data/out/particle.py b/spine/data/out/particle.py index cd90de2d..037d6a91 100644 --- a/spine/data/out/particle.py +++ b/spine/data/out/particle.py @@ -95,7 +95,9 @@ class ParticleBase: } # Variable-length attributes as (key, dtype) pairs - _var_length_attrs = {'fragment_ids': np.int32} + _var_length_attrs = { + 'fragments': object, 'fragment_ids': np.int32 + } # Attributes specifying coordinates _pos_attrs = ['start_point', 'end_point'] diff --git a/spine/io/write/hdf5.py b/spine/io/write/hdf5.py index be1509c7..0ad08e1c 100644 --- a/spine/io/write/hdf5.py +++ b/spine/io/write/hdf5.py @@ -387,6 +387,11 @@ def __call__(self, data, cfg=None): Dictionary containing the complete SPINE configuration """ # If this function has never been called, initialiaze the HDF5 file + # TODO: make this nicer? + if np.isscalar(data['index']): + for k in data: + data[k] = [data[k]] + if (not self.ready and (not self.append or os.path.isfile(self.file_name))): self.create(data, cfg)