Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ONNXInference algorithm, use it to provide EcalEndcapNClusterParticleIDs #1618

Merged
merged 2 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/algorithms/onnx/CalorimeterParticleIDPostML.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
// Copyright (C) 2024 Dmitry Kalinkin

#include <edm4eic/EDM4eicVersion.h>

#if EDM4EIC_VERSION_MAJOR >= 8
#include <cstddef>
#include <fmt/core.h>
#include <gsl/pointers>
#include <stdexcept>

#include "CalorimeterParticleIDPostML.h"

namespace eicrecon {

void CalorimeterParticleIDPostML::init() {
// Nothing
}

void CalorimeterParticleIDPostML::process(
const CalorimeterParticleIDPostML::Input& input,
const CalorimeterParticleIDPostML::Output& output) const {

const auto [in_clusters, in_assocs, prediction_tensors] = input;
auto [out_clusters, out_assocs, out_particle_ids] = output;
veprbl marked this conversation as resolved.
Show resolved Hide resolved

if (prediction_tensors->size() != 1) {
error("Expected to find a single tensor, found {}", prediction_tensors->size());
throw std::runtime_error("");
}
edm4eic::Tensor prediction_tensor = (*prediction_tensors)[0];

if (prediction_tensor.shape_size() != 2) {
error("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size());
throw std::runtime_error(fmt::format("Expected tensor rank to be 2, but it is {}", prediction_tensor.shape_size()));
}

if (prediction_tensor.getShape(0) != in_clusters->size()) {
error("Length mismatch between tensor's 0th axis and number of clusters: {} != {}", prediction_tensor.getShape(0), in_clusters->size());
throw std::runtime_error(fmt::format("Length mismatch between tensor's 0th axis and number of clusters: {} != {}", prediction_tensor.getShape(0), in_clusters->size()));
}

if (prediction_tensor.getShape(1) != 2) {
error("Expected 2 values per cluster in the output tensor, got {}", prediction_tensor.getShape(0));
throw std::runtime_error(fmt::format("Expected 2 values per cluster in the output tensor, got {}", prediction_tensor.getShape(0)));
}

if (prediction_tensor.getElementType() != 1) { // 1 - float
error("Expected a tensor of floats, but element type is {}", prediction_tensor.getElementType());
throw std::runtime_error(fmt::format("Expected a tensor of floats, but element type is {}", prediction_tensor.getElementType()));
}

for (size_t cluster_ix = 0; cluster_ix < in_clusters->size(); cluster_ix++) {
edm4eic::Cluster in_cluster = (*in_clusters)[cluster_ix];
edm4eic::MutableCluster out_cluster = in_cluster.clone();
out_clusters->push_back(out_cluster);

float prob_pion = prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 0);
float prob_electron = prediction_tensor.getFloatData(cluster_ix * prediction_tensor.getShape(1) + 1);

out_cluster.addToParticleIDs(out_particle_ids->create(
0, // std::int32_t type
211, // std::int32_t PDG
0, // std::int32_t algorithmType
prob_pion // float likelihood
));
out_cluster.addToParticleIDs(out_particle_ids->create(
0, // std::int32_t type
11, // std::int32_t PDG
0, // std::int32_t algorithmType
prob_electron // float likelihood
));

// propagate associations
for (auto in_assoc : *in_assocs) {
simonge marked this conversation as resolved.
Show resolved Hide resolved
if (in_assoc.getRec() == in_cluster) {
auto out_assoc = in_assoc.clone();
out_assoc.setRec(out_cluster);
out_assocs->push_back(out_assoc);
}
}
}
}

} // namespace eicrecon
#endif
44 changes: 44 additions & 0 deletions src/algorithms/onnx/CalorimeterParticleIDPostML.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
// Copyright (C) 2024 Dmitry Kalinkin

#pragma once

#include <algorithms/algorithm.h>
#include <edm4eic/ClusterCollection.h>
#include <edm4eic/MCRecoClusterParticleAssociationCollection.h>
#include <edm4eic/TensorCollection.h>
#include <edm4hep/ParticleIDCollection.h>
#include <optional>
#include <string>
#include <string_view>

#include "algorithms/interfaces/WithPodConfig.h"

namespace eicrecon {

using CalorimeterParticleIDPostMLAlgorithm =
algorithms::Algorithm<
algorithms::Input<edm4eic::ClusterCollection,
std::optional<edm4eic::MCRecoClusterParticleAssociationCollection>,
edm4eic::TensorCollection>,
algorithms::Output<edm4eic::ClusterCollection,
std::optional<edm4eic::MCRecoClusterParticleAssociationCollection>,
edm4hep::ParticleIDCollection>
>;

class CalorimeterParticleIDPostML : public CalorimeterParticleIDPostMLAlgorithm,
public WithPodConfig<NoConfig> {

public:
CalorimeterParticleIDPostML(std::string_view name)
: CalorimeterParticleIDPostMLAlgorithm{name,
{"inputClusters", "inputClusterAssociations", "inputPredictionsTensor"},
{"outputClusters", "outputClusterAssociations", "outputParticleIDs"},
""} {
}

void init() final;
void process(const Input&, const Output&) const final;
};

} // namespace eicrecon
100 changes: 100 additions & 0 deletions src/algorithms/onnx/CalorimeterParticleIDPreML.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
// Copyright (C) 2024 Dmitry Kalinkin

#include <edm4eic/EDM4eicVersion.h>

#if EDM4EIC_VERSION_MAJOR >= 8
#include <cstddef>
#include <edm4hep/MCParticle.h>
#include <edm4hep/Vector3f.h>
#include <edm4hep/utils/vector_utils.h>
#include <cstdint>
#include <stdexcept>
#include <fmt/core.h>
#include <gsl/pointers>

#include "CalorimeterParticleIDPreML.h"

namespace eicrecon {

void CalorimeterParticleIDPreML::init() {
// Nothing
}

void CalorimeterParticleIDPreML::process(
const CalorimeterParticleIDPreML::Input& input,
const CalorimeterParticleIDPreML::Output& output) const {

const auto [clusters, cluster_assocs] = input;
auto [feature_tensors, target_tensors] = output;

edm4eic::MutableTensor feature_tensor = feature_tensors->create();
feature_tensor.addToShape(clusters->size());
feature_tensor.addToShape(11); // p, E/p, azimuthal, polar, 7 shape parameters
feature_tensor.setElementType(1); // 1 - float

edm4eic::MutableTensor target_tensor;
if (cluster_assocs) {
target_tensor = target_tensors->create();
target_tensor.addToShape(clusters->size());
target_tensor.addToShape(2); // is electron, is hadron
target_tensor.setElementType(7); // 7 - int64
}

for (edm4eic::Cluster cluster : *clusters) {
double momentum;
{
// FIXME: use track momentum once matching to tracks becomes available
edm4eic::MCRecoClusterParticleAssociation best_assoc;
for (auto assoc : *cluster_assocs) {
if (assoc.getRec() == cluster) {
if ((not best_assoc.isAvailable()) || (assoc.getWeight() > best_assoc.getWeight())) {
best_assoc = assoc;
}
}
}
if (best_assoc.isAvailable()) {
momentum = edm4hep::utils::magnitude(best_assoc.getSim().getMomentum());
} else {
warning("Can't find association for cluster. Skipping...");
continue;
}
}

feature_tensor.addToFloatData(momentum);
feature_tensor.addToFloatData(cluster.getEnergy() / momentum);
auto pos = cluster.getPosition();
feature_tensor.addToFloatData(edm4hep::utils::anglePolar(pos));
feature_tensor.addToFloatData(edm4hep::utils::angleAzimuthal(pos));
for (int par_ix = 0; par_ix < cluster.shapeParameters_size(); par_ix++) {
feature_tensor.addToFloatData(cluster.getShapeParameters(par_ix));
}

if (cluster_assocs) {
edm4eic::MCRecoClusterParticleAssociation best_assoc;
for (auto assoc : *cluster_assocs) {
if (assoc.getRec() == cluster) {
if ((not best_assoc.isAvailable()) || (assoc.getWeight() > best_assoc.getWeight())) {
best_assoc = assoc;
}
}
}
int64_t is_electron = 0, is_pion = 0;
if (best_assoc.isAvailable()) {
is_electron = best_assoc.getSim().getPDG() == 11;
is_pion = best_assoc.getSim().getPDG() != 11;
}
target_tensor.addToInt64Data(is_pion);
target_tensor.addToInt64Data(is_electron);
}
}

size_t expected_num_entries = feature_tensor.getShape(0) * feature_tensor.getShape(1);
if (feature_tensor.floatData_size() != expected_num_entries) {
error("Inconsistent output tensor shape and element count: {} != {}", feature_tensor.floatData_size(), expected_num_entries);
throw std::runtime_error(fmt::format("Inconsistent output tensor shape and element count: {} != {}", feature_tensor.floatData_size(), expected_num_entries));
}
}

} // namespace eicrecon
#endif
39 changes: 39 additions & 0 deletions src/algorithms/onnx/CalorimeterParticleIDPreML.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
// Copyright (C) 2024 Dmitry Kalinkin

#pragma once

#include <algorithms/algorithm.h>
#include <edm4eic/ClusterCollection.h>
#include <edm4eic/MCRecoClusterParticleAssociationCollection.h>
#include <edm4eic/TensorCollection.h>
#include <optional>
#include <string>
#include <string_view>

#include "algorithms/interfaces/WithPodConfig.h"

namespace eicrecon {

using CalorimeterParticleIDPreMLAlgorithm =
algorithms::Algorithm<algorithms::Input<edm4eic::ClusterCollection,
std::optional<edm4eic::MCRecoClusterParticleAssociationCollection>>,
algorithms::Output<edm4eic::TensorCollection,
std::optional<edm4eic::TensorCollection>>>;

class CalorimeterParticleIDPreML : public CalorimeterParticleIDPreMLAlgorithm,
public WithPodConfig<NoConfig> {

public:
CalorimeterParticleIDPreML(std::string_view name)
: CalorimeterParticleIDPreMLAlgorithm{name,
{"inputClusters"},
{"outputFeatureTensor", "outputTargetTensor"},
""} {
}

void init() final;
void process(const Input&, const Output&) const final;
};

} // namespace eicrecon
2 changes: 2 additions & 0 deletions src/algorithms/onnx/InclusiveKinematicsML.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ namespace eicrecon {
// onnxruntime setup
m_env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "inclusive-kinematics-ml");
Ort::SessionOptions session_options;
session_options.SetInterOpNumThreads(1);
session_options.SetIntraOpNumThreads(1);
try {
m_session = Ort::Session(m_env, m_cfg.modelPath.c_str(), session_options);

Expand Down
Loading
Loading