Skip to content

Commit

Permalink
update pyo3
Browse files Browse the repository at this point in the history
  • Loading branch information
powei-lin committed Apr 12, 2024
1 parent 80472e5 commit adb51aa
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 56 deletions.
48 changes: 15 additions & 33 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tiny-solver"
version = "0.4.1"
version = "0.5.0"
edition = "2021"
authors = ["Powei Lin <poweilin1994@gmail.com>"]
readme = "README.md"
Expand All @@ -19,19 +19,19 @@ faer = "0.18.2"
faer-ext = { version = "0.1.0", features = ["nalgebra"] }
log = "0.4.21"
nalgebra = "0.32.4"
num-dual = "0.8.1"
num-dual = "0.9.0"
num-traits = "0.2.18"
numpy = { version = "0.20.0", features = ["nalgebra"], optional = true }
pyo3 = { version = "0.20.3", features = ["abi3", "abi3-py38"] }
pyo3-log = { version = "0.9.0", optional = true }
numpy = { version = "0.21.0", features = ["nalgebra"], optional = true }
pyo3 = { version = "0.21.0", features = ["abi3", "abi3-py38"] }
# pyo3-log = { version = "0.9.0", optional = true }
rayon = "1.9.0"

[[example]]
name = "m3500_benchmark"
path = "examples/m3500_benchmark.rs"

[features]
python = ["num-dual/python", "numpy", "pyo3-log"]
python = ["num-dual/python", "numpy"]

[dev-dependencies]
env_logger = "0.11.3"
Expand Down
24 changes: 14 additions & 10 deletions src/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,40 @@ mod py_optimizer;
mod py_problem;
use self::py_factors::*;

fn register_child_module(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> {
fn register_child_module(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
// For factors submodule
let factors_module = PyModule::new(py, "factors")?;
let factors_module = PyModule::new_bound(parent_module.py(), "factors")?;
factors_module.add_class::<BetweenFactorSE2>()?;
factors_module.add_class::<PriorFactor>()?;
factors_module.add_class::<PyFactor>()?;
parent_module.add_submodule(factors_module)?;
py.import("sys")?
parent_module.add_submodule(&factors_module)?;
parent_module
.py()
.import_bound("sys")?
.getattr("modules")?
.set_item("tiny_solver.factors", factors_module)?;

let loss_functions_module = PyModule::new(py, "loss_functions")?;
let loss_functions_module = PyModule::new_bound(parent_module.py(), "loss_functions")?;
loss_functions_module.add_class::<HuberLoss>()?;
parent_module.add_submodule(loss_functions_module)?;
py.import("sys")?
parent_module.add_submodule(&loss_functions_module)?;
parent_module
.py()
.import_bound("sys")?
.getattr("modules")?
.set_item("tiny_solver.loss_functions", loss_functions_module)?;
Ok(())
}

/// A Python module implemented in Rust.
#[pymodule]
pub fn tiny_solver<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> {
pyo3_log::init();
pub fn tiny_solver(m: &Bound<'_, PyModule>) -> PyResult<()> {
// pyo3_log::init();
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_class::<Problem>()?;
m.add_class::<LinearSolver>()?;
m.add_class::<OptimizerOptions>()?;
m.add_class::<GaussNewtonOptimizer>()?;
register_child_module(_py, m)?;
register_child_module(m)?;

Ok(())
}
2 changes: 1 addition & 1 deletion src/python/py_factors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl Factor for PyFactor {
})
.map(|x| x.into_py(py))
.collect();
let args = PyTuple::new(py, py_params);
let args = PyTuple::new_bound(py, py_params);
let result = self.func.call1(py, args);
let residual_py = result.unwrap().extract::<Vec<PyDual64Dyn>>(py);
residual_py
Expand Down
4 changes: 2 additions & 2 deletions src/python/py_optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl GaussNewtonOptimizer {
&self,
py: Python<'_>,
problem: &Problem,
initial_values: &PyDict,
initial_values: &Bound<'_, PyDict>,
optimizer_options: Option<OptimizerOptions>,
) -> PyResult<HashMap<String, Py<PyArray2<f64>>>> {
let init_values: HashMap<String, PyReadonlyArray1<f64>> = initial_values.extract().unwrap();
Expand All @@ -35,7 +35,7 @@ impl GaussNewtonOptimizer {

let output_d: HashMap<String, Py<PyArray2<f64>>> = result
.iter()
.map(|(k, v)| (k.to_string(), v.to_pyarray(py).to_owned().into()))
.map(|(k, v)| (k.to_string(), v.to_pyarray_bound(py).to_owned().into()))
.collect();
Ok(output_d)
}
Expand Down
10 changes: 6 additions & 4 deletions src/python/py_problem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::problem::Problem;

use super::PyFactor;

fn convert_pyany_to_factor(py_any: &PyAny) -> PyResult<(bool, Box<dyn Factor + Send>)> {
fn convert_pyany_to_factor(py_any: &Bound<'_, PyAny>) -> PyResult<(bool, Box<dyn Factor + Send>)> {
let factor_name: String = py_any.get_type().getattr("__name__")?.extract()?;
match factor_name.as_str() {
"BetweenFactorSE2" => {
Expand All @@ -26,7 +26,9 @@ fn convert_pyany_to_factor(py_any: &PyAny) -> PyResult<(bool, Box<dyn Factor + S
)),
}
}
fn convert_pyany_to_loss_function(py_any: &PyAny) -> PyResult<Option<Box<dyn Loss + Send>>> {
fn convert_pyany_to_loss_function(
py_any: &Bound<'_, PyAny>,
) -> PyResult<Option<Box<dyn Loss + Send>>> {
let factor_name: String = py_any.get_type().getattr("__name__")?.extract()?;
match factor_name.as_str() {
"HuberLoss" => {
Expand All @@ -52,8 +54,8 @@ impl Problem {
&mut self,
dim_residual: usize,
variable_key_size_list: Vec<(String, usize)>,
pyfactor: &PyAny,
pyloss_func: &PyAny,
pyfactor: &Bound<'_, PyAny>,
pyloss_func: &Bound<'_, PyAny>,
) -> PyResult<()> {
let (is_pyfactor, factor) = convert_pyany_to_factor(pyfactor).unwrap();
self.add_residual_block(
Expand Down

0 comments on commit adb51aa

Please sign in to comment.