Skip to content

Commit

Permalink
Merge branch 'andreytkachenko-opt-out-python-bindings'
Browse files Browse the repository at this point in the history
  • Loading branch information
bwsw committed Jul 27, 2023
2 parents 78747ad + 83b3741 commit 58bcd5e
Show file tree
Hide file tree
Showing 27 changed files with 1,891 additions and 1,263 deletions.
17 changes: 13 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repository = "https://github.com/insight-platform/Similari"
readme = "README.md"
keywords = ["machine-learning", "similarity", "tracking", "SORT", "DeepSORT"]
categories = ["algorithms", "data-structures", "computer-vision", "science"]
version = "0.26.4"
version = "0.26.5"
edition = "2021"
license="Apache-2.0"
rust-version = "1.66"
Expand All @@ -16,6 +16,10 @@ rust-version = "1.66"
crate-type = ["cdylib", "lib"]
name = "similari"

[features]
default = ["python"]
python = ["dep:pyo3", "dep:pyo3-build-config","dep:pyo3-log"]

[dependencies]
itertools = "0.10"
anyhow = "1.0"
Expand All @@ -27,17 +31,22 @@ crossbeam = "0.8"
rand = "0.8"
log = "0.4"
nalgebra = "0.32"
pathfinding = "4.2"
geo = "0.23"
pathfinding = "4.3"
geo = "0.25"
rayon = "1.7"
env_logger = "0.10"

[dependencies.pyo3]
version = "0.18"
features = ["extension-module"]
optional = true

[dependencies.pyo3-log]
version = "0.8"
optional = true

[build-dependencies]
pyo3-build-config = "0.18"
pyo3-build-config = { version = "0.18", optional = true }

[dev-dependencies]
wide = "0.7"
Expand Down
4 changes: 2 additions & 2 deletions benches/nms_oriented.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ fn bench_nms(objects: usize, b: &mut Bencher) {
for (indx, i) in iterators.iter_mut().enumerate() {
let b = i.next();
let bb: Universal2DBox = b.unwrap().into();
observations.push((bb.rotate(indx as f32 / 10.0).gen_vertices(), None));
observations.push((bb.rotate(indx as f32 / 10.0).gen_vertices().clone(), None));
}
nms(&observations, 0.8, None);
nms(observations.as_slice(), 0.8, None);
});
}
1 change: 1 addition & 0 deletions build.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
fn main() {
#[cfg(feature = "python")]
pyo3_build_config::add_extension_module_link_args();
}
1 change: 1 addition & 0 deletions python/visual_sort/visual_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get_opts():
tracker = VisualSort(shards=1, opts=get_opts())


assert False

# let's say frame_objs is a list of objs detected in a frame
for frame_objs in frames:
Expand Down
48 changes: 25 additions & 23 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,23 @@ pub enum Errors {

pub const EPS: f32 = 0.00001;

#[cfg(feature = "python")]
mod python {
use crate::prelude::{
BatchSort, BoundingBox, Sort, SortTrack, SpatioTemporalConstraints, Universal2DBox,
VisualSort, VisualSortOptions,
use crate::trackers::batch::python::PyPredictionBatchResult;
use crate::trackers::sort::batch_api::python::{PyBatchSort, PySortPredictionBatchRequest};
use crate::trackers::sort::python::{PyPositionalMetricType, PySortTrack, PyWastedSortTrack};
use crate::trackers::sort::simple_api::python::PySort;
use crate::trackers::spatio_temporal_constraints::python::PySpatioTemporalConstraints;
use crate::trackers::visual_sort::batch_api::python::{
PyBatchVisualSort, PyVisualSortPredictionBatchRequest,
};
use crate::trackers::batch::PredictionBatchResult;
use crate::trackers::sort::sort_py::PySortPredictionBatchRequest;
use crate::trackers::sort::{PyPositionalMetricType, PyWastedSortTrack};
use crate::trackers::visual_sort::batch_api::BatchVisualSort;
use crate::trackers::visual_sort::metric::PyVisualSortMetricType;
use crate::trackers::visual_sort::visual_sort_py::{
PyVisualSortObservation, PyVisualSortObservationSet, PyVisualSortPredictionBatchRequest,
use crate::trackers::visual_sort::metric::python::PyVisualSortMetricType;
use crate::trackers::visual_sort::options::python::PyVisualSortOptions;
use crate::trackers::visual_sort::python::{
PyVisualSortObservation, PyVisualSortObservationSet, PyWastedVisualSortTrack,
};
use crate::trackers::visual_sort::PyWastedVisualSortTrack;
use crate::trackers::visual_sort::simple_api::python::PyVisualSort;
use crate::utils::bbox::python::{PyBoundingBox, PyUniversal2DBox};
use crate::utils::clipping::clipping_py::{
intersection_area_py, sutherland_hodgman_clip_py, PyPolygon,
};
Expand All @@ -114,12 +117,12 @@ mod python {
#[pymodule]
#[pyo3(name = "similari")]
fn similari(_py: Python, m: &PyModule) -> PyResult<()> {
let _ = env_logger::try_init();
pyo3_log::init();

m.add_class::<BoundingBox>()?;
m.add_class::<Universal2DBox>()?;
m.add_class::<PyBoundingBox>()?;
m.add_class::<PyUniversal2DBox>()?;
m.add_class::<PyPolygon>()?;
m.add_class::<SortTrack>()?;
m.add_class::<PySortTrack>()?;
m.add_class::<PyWastedSortTrack>()?;

m.add_class::<PyUniversal2DBoxKalmanFilterState>()?;
Expand All @@ -131,25 +134,24 @@ mod python {
m.add_class::<PyVec2DKalmanFilter>()?;

m.add_class::<PySortPredictionBatchRequest>()?;
m.add_class::<SpatioTemporalConstraints>()?;
m.add_class::<Sort>()?;
m.add_class::<PySpatioTemporalConstraints>()?;
m.add_class::<PySort>()?;

m.add_class::<PyPositionalMetricType>()?;
m.add_class::<PyVisualSortMetricType>()?;
m.add_class::<VisualSortOptions>()?;
m.add_class::<PyVisualSortOptions>()?;
m.add_class::<PyVisualSortObservation>()?;
m.add_class::<PyVisualSortObservationSet>()?;
m.add_class::<PyVisualSortPredictionBatchRequest>()?;
m.add_class::<PyWastedVisualSortTrack>()?;
m.add_class::<VisualSort>()?;
m.add_class::<PyVisualSort>()?;

m.add_class::<PredictionBatchResult>()?;
m.add_class::<PyPredictionBatchResult>()?;

m.add_class::<PySortPredictionBatchRequest>()?;
m.add_class::<BatchSort>()?;
m.add_class::<PyBatchSort>()?;

m.add_class::<PyVisualSortPredictionBatchRequest>()?;
m.add_class::<BatchVisualSort>()?;
m.add_class::<PyBatchVisualSort>()?;

m.add_function(wrap_pyfunction!(version, m)?)?;
m.add_function(wrap_pyfunction!(nms_py, m)?)?;
Expand Down
47 changes: 34 additions & 13 deletions src/trackers/batch.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::prelude::SortTrack;
use crossbeam::channel::{Receiver, Sender};
use log::debug;
use pyo3::prelude::*;

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

Expand All @@ -15,34 +15,56 @@ pub struct PredictionBatchRequest<T> {
batch_size: Arc<Mutex<usize>>,
}

#[pyclass]
#[derive(Clone, Debug)]
pub struct PredictionBatchResult {
receiver: Receiver<SceneTracks>,
batch_size: Arc<Mutex<usize>>,
}

#[pymethods]
impl PredictionBatchResult {
pub fn ready(&self) -> bool {
!self.receiver.is_empty()
}

#[pyo3(name = "get", signature = ())]
fn get_py(&self) -> SceneTracks {
Python::with_gil(|py| py.allow_threads(|| self.get()))
pub fn get(&self) -> SceneTracks {
self.receiver
.recv()
.expect("Receiver must always receive batch computation result")
}

pub fn batch_size(&self) -> usize {
*self.batch_size.lock().unwrap()
}
}

impl PredictionBatchResult {
pub fn get(&self) -> SceneTracks {
self.receiver
.recv()
.expect("Receiver must always receive batch computation result")
#[cfg(feature = "python")]
pub mod python {
use crate::trackers::sort::python::PySortTrack;

use super::PredictionBatchResult;
use pyo3::prelude::*;

pub type PySceneTracks = (u64, Vec<PySortTrack>);

#[pyclass]
#[derive(Clone, Debug)]
#[pyo3(name = "PredictionBatchResult")]
pub struct PyPredictionBatchResult(pub(crate) PredictionBatchResult);

#[pymethods]
impl PyPredictionBatchResult {
pub fn ready(&self) -> bool {
self.0.ready()
}

#[pyo3(signature = ())]
fn get(&self) -> PySceneTracks {
Python::with_gil(|py| py.allow_threads(|| unsafe { std::mem::transmute(self.0.get()) }))
}

pub fn batch_size(&self) -> usize {
self.0.batch_size()
}
}
}

Expand Down Expand Up @@ -112,8 +134,7 @@ mod tests {
request.add(0, Universal2DBox::new(0.0, 0.0, Some(0.5), 1.0, 5.0));
request.add(0, Universal2DBox::new(5.0, 5.0, Some(0.0), 1.5, 10.0));
request.add(1, Universal2DBox::new(0.0, 0.0, Some(1.0), 0.7, 5.1));
let batch = request.get_batch();
drop(batch);
let _batch = request.get_batch();
assert_eq!(result.batch_size(), 2);

assert!(request.send((0, vec![])));
Expand Down
Loading

0 comments on commit 58bcd5e

Please sign in to comment.