Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
powei-lin committed Mar 19, 2024
1 parent 8142cbe commit e7e75cc
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
4 changes: 2 additions & 2 deletions examples/m3500_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fn read_g2o(filename: &str) -> (problem::Problem, HashMap<String, na::DVector<f6
3,
vec![(id0, 3), (id1, 3)],
Box::new(edge),
Some(Box::new(HuberLoss { scale: 1.0 })),
Some(Box::new(HuberLoss::new(1.0))),
);
}
_ => {
Expand All @@ -54,7 +54,7 @@ fn read_g2o(filename: &str) -> (problem::Problem, HashMap<String, na::DVector<f6
3,
vec![("x0".to_string(), 3)],
Box::new(origin_factor),
Some(Box::new(HuberLoss { scale: 1.0 })),
Some(Box::new(HuberLoss::new(1.0))),
);
(problem, init_values)
}
Expand Down
8 changes: 7 additions & 1 deletion src/loss_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@ pub trait Loss: Send + Sync {
#[pyclass]
#[derive(Debug, Clone)]
pub struct HuberLoss {
pub scale: f64,
scale: f64,
}
impl HuberLoss {
pub fn new(scale: f64) -> Self {
if scale <= 0.0 {
panic!("scale needs to be larger than zero");
}
HuberLoss { scale }
}
fn weight(&self, abs_err: f64) -> f64 {
if abs_err < self.scale {
1.0
Expand Down
4 changes: 2 additions & 2 deletions src/python/py_loss_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::loss_functions::*;
impl HuberLoss {
#[new]
#[pyo3(signature=(scale=1.0))]
pub fn new(scale: f64) -> Self {
HuberLoss { scale }
pub fn new_py(scale: f64) -> Self {
HuberLoss::new(scale)
}
}

0 comments on commit e7e75cc

Please sign in to comment.