diff --git a/.gitignore b/.gitignore index 806d6d2..6f6c7a5 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ __pycache__ *.so .DS_Store -.vscode/* \ No newline at end of file +.vscode/* +.venv +.pytest_cache \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 1992913..132804d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,9 +46,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "block-buffer" @@ -67,9 +67,9 @@ checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" [[package]] name = "bytemuck" -version = "1.14.3" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" +checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" dependencies = [ "bytemuck_derive", ] @@ -82,7 +82,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -329,7 +329,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -349,7 +349,7 @@ checksum = "60d08acb9849f7fb4401564f251be5a526829183a3645a90197dea8e786cf3ae" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -701,6 +701,21 @@ version = "2.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" +[[package]] +name = "inventory" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f958d3d68f4167080a18141e10381e7634563984a537f2a49a30fd8e53ac5767" + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "jpeg-decoder" version = "0.3.1" @@ -750,7 +765,7 @@ version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "libc", "redox_syscall", ] @@ -902,7 +917,10 @@ checksum = "cefb764932cd108cfce13339af181a7de5382d33e6b52c64d3f2aaec0f1f8c47" dependencies = [ "approx", "nalgebra", + "ndarray", "num-traits", + "numpy", + "pyo3", "simba", ] @@ -1037,7 +1055,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -1124,18 +1142,18 @@ checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] [[package]] name = "pulp" -version = "0.18.8" +version = "0.18.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "091bad01115892393939669b38f88ff2b70838e969a7ac172a9d06d05345a732" +checksum = "03457ac216146f43f921500bac4e892d5cd32b0479b929cbfc90f95cd6c599c2" dependencies = [ "bytemuck", "libm", @@ -1164,6 +1182,7 @@ checksum = "53bdbb96d49157e65d45cc287af5f32ffadd5f4761438b527b055fb0d4bb8233" dependencies = [ "cfg-if", "indoc", + "inventory", "libc", "memoffset", "parking_lot", @@ -1203,7 +1222,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -1216,7 +1235,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -1382,7 +1401,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -1417,9 +1436,9 @@ checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" [[package]] name = "smallvec" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "syn" @@ -1434,9 +1453,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.52" +version = "2.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" dependencies = [ "proc-macro2", "quote", @@ -1449,7 +1468,7 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "byteorder", "enum-as-inner", "libc", @@ -1465,30 +1484,31 @@ checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" [[package]] name = "thiserror" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] name = "tiny-solver" -version = "0.3.4" +version = "0.4.0" dependencies = [ "faer", "faer-ext", + "itertools", "nalgebra", "num-dual", "num-traits", @@ -1571,7 +1591,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", "wasm-bindgen-shared", ] @@ -1593,7 +1613,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index 2ed8fcf..64d2c75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tiny-solver" -version = "0.3.4" +version = "0.4.0" edition = "2021" authors = ["Powei Lin "] readme = "README.md" @@ -18,7 +18,7 @@ exclude = ["/.github/*", "*.ipynb", "./scripts/*", "examples/*"] faer = "0.18.2" faer-ext = { version = "0.1.0", features = ["nalgebra"] } nalgebra = "0.32.4" -num-dual = "0.8.1" +num-dual = "0.8.1" num-traits = "0.2.18" numpy = { version = "0.20.0", features = ["nalgebra"] } pyo3 = { version = "0.20.3", features = ["abi3", "abi3-py38"] } @@ -26,10 +26,13 @@ rayon = "1.9.0" [[example]] name = "m3500_benchmark" -# crate-type = ["bin"] path = "examples/m3500_benchmark.rs" +[features] +python = ["num-dual/python"] + [dev-dependencies] +itertools = "0.12.1" plotters = "0.3.5" [profile.dev.package.faer] @@ -37,4 +40,4 @@ opt-level = 3 [lib] name = "tiny_solver" -# crate-type = ["cdylib"] +# crate-type = ["staticlib"] diff --git a/examples/python/small_problem.py b/examples/python/small_problem.py new file mode 100644 index 0000000..2efa2ae --- /dev/null +++ b/examples/python/small_problem.py @@ -0,0 +1,24 @@ +import numpy as np +from tiny_solver import Problem, GaussNewtonOptimizer +from tiny_solver.factors import PriorFactor, PyFactor + + +def cost(x, y, z): + r0 = x[0] + 2*y[0] + 4*z[0] + r1 = y[0] * z[0] + return np.array([r0, r1]) + +def main(): + problem = Problem() + pf = PyFactor(cost) + problem.add_residual_block(2, [('x', 1), ('y', 1), ('z', 1),], pf, None) + pp = PriorFactor(np.array([3.0])) + problem.add_residual_block(1, [('x', 1)], pp, None) + gn = GaussNewtonOptimizer() + init_values = {"x": np.array([0.7]), "y": np.array([-30.2]), "z": np.array([123.9])} + result_values = gn.optimize(problem, init_values) + print(result_values) + + +if __name__ == '__main__': + main() diff --git a/examples/python/try_import.py b/examples/python/try_import.py index 28631f3..e73cb17 100644 --- a/examples/python/try_import.py +++ b/examples/python/try_import.py @@ -1,9 +1,17 @@ import tiny_solver -from tiny_solver import GaussNewtonOptimizer, Problem, LinearSolver, OptimizerOptions -from tiny_solver.factors import PriorFactor, BetweenFactorSE2 +from tiny_solver import GaussNewtonOptimizer, Problem, LinearSolver, OptimizerOptions, first_derivative_test +from tiny_solver.factors import PriorFactor, BetweenFactorSE2, PyFactor from tiny_solver.loss_functions import HuberLoss import numpy as np +def f(x: np.ndarray, y: np.ndarray): + # print("py ", x*x) + return np.array([2*x[0], x[1]*x[1]*x[1], y[1]*4.0]) + +def fa(): + print("fa") + return 123 + def main(): print(f"{tiny_solver.__version__=}") @@ -14,6 +22,11 @@ def main(): print(opt_option) loss = HuberLoss(1.0) print(loss) + a = np.array([1.0, 2.0]) + # j = first_derivative_test(f, a) + # print(j) + a = PyFactor(f) + a.call_func() exit() # print(tiny_solver.sum_as_string(1, 2)) diff --git a/pyproject.toml b/pyproject.toml index 2d3963b..248ba54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,4 +20,4 @@ dependencies = ["numpy"] [tool.maturin] -features = ["pyo3/extension-module"] +features = ["pyo3/extension-module", "python"] diff --git a/src/gauss_newton_optimizer.rs b/src/gauss_newton_optimizer.rs index 3247731..f0a5da9 100644 --- a/src/gauss_newton_optimizer.rs +++ b/src/gauss_newton_optimizer.rs @@ -21,7 +21,6 @@ impl optimizer::Optimizer for GaussNewtonOptimizer { let opt_option = optimizer_option.unwrap_or_default(); let mut last_err: f64 = 1.0; - let mut symbolic_pattern: Option> = None; for i in 0..opt_option.max_iteration { diff --git a/src/lib.rs b/src/lib.rs index b43367b..0a0b078 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,6 @@ pub mod linear; pub mod loss_functions; pub mod optimizer; pub mod problem; -pub mod python; pub mod residual_block; pub mod tiny_solver_old; @@ -12,5 +11,7 @@ pub use gauss_newton_optimizer::*; pub use linear::*; pub use optimizer::*; pub use problem::*; -pub use python::*; pub use residual_block::*; + +#[cfg(feature = "python")] +pub mod python; diff --git a/src/problem.rs b/src/problem.rs index cbe9ed0..f22be06 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -15,6 +15,7 @@ pub struct Problem { pub total_residual_dimension: usize, residual_blocks: Vec, pub variable_name_to_col_idx_dict: HashMap, + _has_py_factor: bool, } impl Problem { pub fn new() -> Problem { @@ -23,6 +24,7 @@ impl Problem { total_residual_dimension: 0, residual_blocks: Vec::::new(), variable_name_to_col_idx_dict: HashMap::::new(), + _has_py_factor: false, } } pub fn add_residual_block( @@ -70,52 +72,35 @@ impl Problem { variable_key_value_map: &HashMap>, ) -> (faer::Mat, SparseColMat) { // multi - let total_residual = Arc::new(Mutex::new(na::DVector::::zeros( + let total_residual: Arc< + Mutex< + na::Matrix, na::VecStorage>>, + >, + > = Arc::new(Mutex::new(na::DVector::::zeros( self.total_residual_dimension, ))); - let jacobian_list = Arc::new(Mutex::new(Vec::<(usize, usize, f64)>::new())); + let jacobian_list: Arc>> = + Arc::new(Mutex::new(Vec::<(usize, usize, f64)>::new())); - self.residual_blocks.par_iter().for_each(|residual_block| { - let mut params = Vec::>::new(); - let mut variable_local_idx_size_list = Vec::<(usize, usize)>::new(); - let mut count_variable_local_idx: usize = 0; - for vk in &residual_block.variable_key_list { - if let Some(param) = variable_key_value_map.get(vk) { - params.push(param.clone()); - variable_local_idx_size_list.push((count_variable_local_idx, param.shape().0)); - count_variable_local_idx += param.shape().0; - }; - } - let (res, jac) = residual_block.jacobian(¶ms); - - { - let mut total_residual = total_residual.lock().unwrap(); - total_residual - .rows_mut( - residual_block.residual_row_start_idx, - residual_block.dim_residual, - ) - .copy_from(&res); - } - - for (i, vk) in residual_block.variable_key_list.iter().enumerate() { - if let Some(variable_global_idx) = self.variable_name_to_col_idx_dict.get(vk) { - let (variable_local_idx, var_size) = variable_local_idx_size_list[i]; - let variable_jac = jac.view((0, variable_local_idx), (jac.shape().0, var_size)); - let mut local_jacobian_list = Vec::new(); - for row_idx in 0..jac.shape().0 { - for col_idx in 0..var_size { - let global_row_idx = residual_block.residual_row_start_idx + row_idx; - let global_col_idx = variable_global_idx + col_idx; - let value = variable_jac[(row_idx, col_idx)]; - local_jacobian_list.push((global_row_idx, global_col_idx, value)); - } - } - let mut jacobian_list = jacobian_list.lock().unwrap(); - jacobian_list.extend(local_jacobian_list); - } - } - }); + if self._has_py_factor { + self.residual_blocks.iter().for_each(|residual_block| { + self.compute_residual_and_jacobian_impl( + residual_block, + variable_key_value_map, + &total_residual, + &jacobian_list, + ) + }); + } else { + self.residual_blocks.par_iter().for_each(|residual_block| { + self.compute_residual_and_jacobian_impl( + residual_block, + variable_key_value_map, + &total_residual, + &jacobian_list, + ) + }); + } let total_residual = Arc::try_unwrap(total_residual) .unwrap() @@ -136,4 +121,58 @@ impl Problem { .unwrap(); (residual_faer, jacobian_faer) } + fn compute_residual_and_jacobian_impl( + &self, + residual_block: &crate::ResidualBlock, + variable_key_value_map: &HashMap>, + total_residual: &Arc< + Mutex< + na::Matrix, na::VecStorage>>, + >, + >, + jacobian_list: &Arc>>, + ) { + let mut params = Vec::>::new(); + let mut variable_local_idx_size_list = Vec::<(usize, usize)>::new(); + let mut count_variable_local_idx: usize = 0; + for vk in &residual_block.variable_key_list { + if let Some(param) = variable_key_value_map.get(vk) { + params.push(param.clone()); + variable_local_idx_size_list.push((count_variable_local_idx, param.shape().0)); + count_variable_local_idx += param.shape().0; + }; + } + let (res, jac) = residual_block.jacobian(¶ms); + + { + let mut total_residual = total_residual.lock().unwrap(); + total_residual + .rows_mut( + residual_block.residual_row_start_idx, + residual_block.dim_residual, + ) + .copy_from(&res); + } + + for (i, vk) in residual_block.variable_key_list.iter().enumerate() { + if let Some(variable_global_idx) = self.variable_name_to_col_idx_dict.get(vk) { + let (variable_local_idx, var_size) = variable_local_idx_size_list[i]; + let variable_jac = jac.view((0, variable_local_idx), (jac.shape().0, var_size)); + let mut local_jacobian_list = Vec::new(); + for row_idx in 0..jac.shape().0 { + for col_idx in 0..var_size { + let global_row_idx = residual_block.residual_row_start_idx + row_idx; + let global_col_idx = variable_global_idx + col_idx; + let value = variable_jac[(row_idx, col_idx)]; + local_jacobian_list.push((global_row_idx, global_col_idx, value)); + } + } + let mut jacobian_list = jacobian_list.lock().unwrap(); + jacobian_list.extend(local_jacobian_list); + } + } + } + pub fn has_py_factor(&mut self) { + self._has_py_factor = true + } } diff --git a/src/python/mod.rs b/src/python/mod.rs index 36e843f..4ecb6bc 100644 --- a/src/python/mod.rs +++ b/src/python/mod.rs @@ -8,12 +8,14 @@ mod py_factors; mod py_loss_functions; mod py_optimizer; mod py_problem; +use self::py_factors::*; fn register_child_module(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> { // For factors submodule let factors_module = PyModule::new(py, "factors")?; factors_module.add_class::()?; factors_module.add_class::()?; + factors_module.add_class::()?; parent_module.add_submodule(factors_module)?; py.import("sys")? .getattr("modules")? diff --git a/src/python/py_factors.rs b/src/python/py_factors.rs index 0b25075..9c96da4 100644 --- a/src/python/py_factors.rs +++ b/src/python/py_factors.rs @@ -1,6 +1,8 @@ -use num_dual::{try_first_derivative, Dual64}; +use nalgebra as na; +use num_dual::python::PyDual64Dyn; use numpy::PyReadonlyArray1; -use pyo3::{exceptions::PyTypeError, prelude::*}; +use pyo3::prelude::*; +use pyo3::types::PyTuple; use crate::factors::*; @@ -26,34 +28,59 @@ impl PriorFactor { } } -#[pyclass(name = "Dual64")] -#[derive(Clone, Debug)] -pub struct PyDual64(Dual64); +// #[pyclass(name = "DualDVec64")] +// #[derive(Clone)] +// pub struct PyDualDVec64(DualDVec64); +// #[pymethods] +// impl PyDualDVec64 { +// #[new] +// pub fn new(re: f64) -> Self { +// Self(DualDVec64::from_re(re)) +// } +// } + +#[pyclass] +#[derive(Debug, Clone)] +pub struct PyFactor { + pub func: Py, +} + #[pymethods] -impl PyDual64 { +impl PyFactor { #[new] - pub fn new(re: f64, eps: f64) -> Self { - Self(Dual64::new(re, eps)) + pub fn new(f: Py) -> Self { + PyFactor { func: f } } +} - #[getter] - pub fn get_first_derivative(&self) -> f64 { - self.0.eps +impl Factor for PyFactor { + fn residual_func( + &self, + params: &Vec>, + ) -> na::DVector { + let residual_py = Python::with_gil(|py| -> PyResult> { + let py_params: Vec> = params + .iter() + .map(|param| { + param + .data + .as_vec() + .iter() + .map(|x| PyDual64Dyn::from(x.clone())) + .collect::>() + }) + .map(|x| x.into_py(py)) + .collect(); + let args = PyTuple::new(py, py_params); + let result = self.func.call1(py, args); + let residual_py = result.unwrap().extract::>(py); + residual_py + }); + let residual_py: Vec = residual_py + .unwrap() + .iter() + .map(|x| ::clone(x).into()) + .collect(); + na::DVector::from_vec(residual_py) } } -#[pyfunction] -pub fn first_derivative(f: &PyAny, x: f64) -> PyResult<(f64, f64)> { - let g = |x| { - let res = f.call1((PyDual64::from(x),))?; - if let Ok(res) = res.extract::() { - Ok(res.0) - } else { - Err(PyErr::new::( - "argument 'f' must return a scalar. For vector functions use 'jacobian' instead." - .to_string(), - )) - } - }; - Ok((1.0, 2.0)) - // try_first_derivative(g, x) -} diff --git a/src/python/py_problem.rs b/src/python/py_problem.rs index 5264f94..e6c70e3 100644 --- a/src/python/py_problem.rs +++ b/src/python/py_problem.rs @@ -4,16 +4,22 @@ use crate::factors::*; use crate::loss_functions::*; use crate::problem::Problem; -fn convert_pyany_to_factor(py_any: &PyAny) -> PyResult> { +use super::PyFactor; + +fn convert_pyany_to_factor(py_any: &PyAny) -> PyResult<(bool, Box)> { let factor_name: String = py_any.get_type().getattr("__name__")?.extract()?; match factor_name.as_str() { "BetweenFactorSE2" => { let factor: BetweenFactorSE2 = py_any.extract().unwrap(); - Ok(Box::new(factor)) + Ok((false, Box::new(factor))) } "PriorFactor" => { let factor: PriorFactor = py_any.extract().unwrap(); - Ok(Box::new(factor)) + Ok((false, Box::new(factor))) + } + "PyFactor" => { + let factor: PyFactor = py_any.extract().unwrap(); + Ok((true, Box::new(factor))) } _ => Err(PyErr::new::( "Unknown factor type", @@ -49,13 +55,16 @@ impl Problem { pyfactor: &PyAny, pyloss_func: &PyAny, ) -> PyResult<()> { + let (is_pyfactor, factor) = convert_pyany_to_factor(pyfactor).unwrap(); self.add_residual_block( dim_residual, variable_key_size_list, - convert_pyany_to_factor(pyfactor).unwrap(), + factor, convert_pyany_to_loss_function(pyloss_func).unwrap(), ); - + if is_pyfactor { + self.has_py_factor() + } Ok(()) } } diff --git a/tests/test_tiny_solver.rs b/tests/test_tiny_solver.rs index 38f1b2e..e80f804 100644 --- a/tests/test_tiny_solver.rs +++ b/tests/test_tiny_solver.rs @@ -1,25 +1,25 @@ -use num_dual::DualNum; -use tiny_solver::TinySolver; -extern crate nalgebra as na; -use std::ops::Mul; +// use num_dual::DualNum; +// use tiny_solver::TinySolver; +// extern crate nalgebra as na; +// use std::ops::Mul; -struct TestProblem; +// struct TestProblem; -impl TinySolver<3, 2> for TestProblem { - fn cost_function( - params: nalgebra::SVector, 3>, - ) -> nalgebra::SVector, 2> { - let x = params[0]; - let y = params[1]; - let z = params[2]; - return nalgebra::SVector::from([x + y.mul(3.0) + z.powf(1.1), y * z]); - } -} +// impl TinySolver<3, 2> for TestProblem { +// fn cost_function( +// params: nalgebra::SVector, 3>, +// ) -> nalgebra::SVector, 2> { +// let x = params[0]; +// let y = params[1]; +// let z = params[2]; +// return nalgebra::SVector::from([x + y.mul(3.0) + z.powf(1.1), y * z]); +// } +// } -#[test] -fn test_residual() { - let xvec = na::SVector::from([5.0, 3.0, 2.0]); - let residual = TestProblem::cost_function(xvec.map(num_dual::DualVec::from_re)); - assert_eq!(residual[0].re, 16.143546925072584); - assert_eq!(residual[1].re, 6.0); -} +// #[test] +// fn test_residual() { +// let xvec = na::SVector::from([5.0, 3.0, 2.0]); +// let residual = TestProblem::cost_function(xvec.map(num_dual::DualVec::from_re)); +// assert_eq!(residual[0].re, 16.143546925072584); +// assert_eq!(residual[1].re, 6.0); +// } diff --git a/tiny_solver/factors/__init__.pyi b/tiny_solver/factors/__init__.pyi index 797b89b..6bba068 100644 --- a/tiny_solver/factors/__init__.pyi +++ b/tiny_solver/factors/__init__.pyi @@ -7,3 +7,6 @@ class BetweenFactorSE2(Factor): class PriorFactor(Factor): def __init__(self, x: np.ndarray) -> None: ... + +class PyFactor(Factor): + def __init__(self, func: callable) -> None: ... \ No newline at end of file diff --git a/tiny_solver/tiny_solver.pyi b/tiny_solver/tiny_solver.pyi index c09c59d..6c3064e 100644 --- a/tiny_solver/tiny_solver.pyi +++ b/tiny_solver/tiny_solver.pyi @@ -32,3 +32,5 @@ class OptimizerOptions: min_rel_error_decrease_threshold: float = 1e-5, min_error_threshold: float = 1e-8, ) -> None: ... + +def first_derivative_test(f: callable, x): ... \ No newline at end of file