finish second version using auto grad
powei-lin committed Jul 23, 2023
1 parent dbad7f8 commit 1adf00b
Showing 1 changed file with 8 additions and 15 deletions.
src/tiny_solver.rs: 23 changes (8 additions, 15 deletions)
@@ -1,8 +1,6 @@
 extern crate nalgebra as na;
-use std::cmp;
-use std::ops::Mul;
-
 use num_dual;
+use std::ops::Mul;
 
 struct SolverParameters {
     gradient_threshold: f64,
@@ -88,10 +86,12 @@ pub trait TinySolver<const NUM_PARAMETERS: usize, const NUM_RESIDUALS: usize> {
                 println!("gradient too small. {}", max_gradient);
                 result.status = SolverStatus::GradientTooSmall;
                 break;
-            } else if (residual.norm() < solver_params.error_threshold) {
+            } else if residual.norm() < solver_params.error_threshold {
                 result.status = SolverStatus::ErrorTooSmall;
                 break;
             }
+
+            // initialize u and v
             if step == 0 {
                 u = solver_params.initial_scale_factor * jtj.diagonal().max();
                 v = 2;
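
The if step == 0 branch is the standard Levenberg-Marquardt damping initialization (Nielsen's strategy, as in Madsen, Nielsen & Tingleff): u starts at tau * max(diag(J^T J)) and the growth factor v starts at 2. A minimal standalone sketch of the same initialization; the matrix values and the 1e-4 scale factor are illustrative, not taken from this repository:

extern crate nalgebra as na;

fn main() {
    // Illustrative J^T * J for a two-parameter problem.
    let jtj = na::Matrix2::new(4.0, 1.0, 1.0, 9.0);
    let initial_scale_factor = 1e-4; // tau in Madsen et al.

    // u starts proportional to the largest diagonal entry of J^T * J;
    // v is the factor by which u grows after a rejected step.
    let u = initial_scale_factor * jtj.diagonal().max();
    let v: i32 = 2;
    println!("u = {u}, v = {v}"); // u = 0.0009, v = 2
}
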
@@ -101,19 +101,14 @@ pub trait TinySolver<const NUM_PARAMETERS: usize, const NUM_RESIDUALS: usize> {
             jtj_augmented.copy_from(&jtj);
             jtj_augmented.set_diagonal(&jtj_augmented.diagonal().add_scalar(u));
 
-            println!("jtj {}", jtj_augmented);
+            // println!("jtj {}", jtj_augmented);
             let dx = na::linalg::LU::new(jtj_augmented.clone())
                 .solve(&gradient)
                 .unwrap();
-            let solution: na::SMatrix<f64, NUM_PARAMETERS, 1> = jtj_augmented.fixed_view(0, 0) * dx;
-            let solved = (solution - gradient).abs().min() < solver_params.error_threshold;
-            if solved {
-                // if (dx.norm() < params.relative_step_threshold * x.norm()) {
-                //     results.status = RELATIVE_STEP_SIZE_TOO_SMALL;
-                //     break;
-                // }
-                // println!("success!");
-                if (dx.norm() < solver_params.relative_step_threshold * params.norm()) {
+            if dx.norm() < solver_params.relative_step_threshold * params.norm() {
                 result.status = SolverStatus::RelativeStepSizeTooSmall;
                 break;
             }
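
The rewritten loop body now solves the augmented normal equations (J^T J + u*I) dx = g by LU factorization and stops when the step is small relative to the parameters, instead of re-checking the solution against the gradient. A self-contained sketch of that solve with nalgebra; the 2x2 values are illustrative, while the real code uses SMatrix sized by NUM_PARAMETERS:

extern crate nalgebra as na;

fn main() {
    let jtj = na::Matrix2::new(4.0, 1.0, 1.0, 9.0); // J^T * J (illustrative)
    let gradient = na::Vector2::new(0.5, -1.0); // J^T * r (illustrative)
    let u = 0.0009; // current damping term

    // Augment the diagonal: jtj_augmented = J^T * J + u * I.
    let mut jtj_augmented = jtj;
    jtj_augmented.set_diagonal(&jtj_augmented.diagonal().add_scalar(u));

    // LU-factorize and solve for the step; solve() returns None if singular.
    let dx = na::linalg::LU::new(jtj_augmented)
        .solve(&gradient)
        .expect("augmented system is nonsingular for u > 0");
    println!("dx = {}", dx);
}
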
@@ -124,12 +119,11 @@ pub trait TinySolver<const NUM_PARAMETERS: usize, const NUM_RESIDUALS: usize> {
             // TODO: Error handling on user eval.
             let residual_new =
                 Self::cost_function(param_new.map(num_dual::DualSVec64::from_re)).map(|x| x.re);
-            let rho: f64 = ((residual.norm_squared() - residual_new.norm_squared())
-                / dx.dot(&(u * dx + gradient)));
+            let rho: f64 = (residual.norm_squared() - residual_new.norm_squared())
+                / dx.dot(&(u * dx + gradient));
             if rho > 0.0 {
                 // Accept the Gauss-Newton step because the linear model fits well.
                 *params = param_new;
-                // result.status = Update(function, x);
                 let tmp: f64 = 2.0 * rho - 1.0;
                 u = u * (1.0_f64 / 3.0).max(1.0 - tmp.powi(3));
                 v = 2;
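
The gain ratio rho compares the actual decrease in squared residual norm with the decrease predicted by the linearization, dx^T (u*dx + g). For rho > 0 the step is accepted and the damping is relaxed with Nielsen's rule u <- u * max(1/3, 1 - (2*rho - 1)^3), v <- 2. A sketch of the update with scalar stand-ins for the vector quantities; the rejected-step branch (u <- u*v, v <- 2v) is not shown in this hunk, so it is assumed here from the standard algorithm:

fn main() {
    // Illustrative values: squared residual norms before and after the step,
    // and the predicted reduction dx^T * (u * dx + gradient).
    let residual_sq = 2.0_f64;
    let residual_new_sq = 1.5_f64;
    let predicted = 0.4_f64;
    let mut u = 0.0009_f64;
    let mut v = 2_i32;

    let rho = (residual_sq - residual_new_sq) / predicted;
    if rho > 0.0 {
        // Step accepted: relax the damping (Nielsen's update).
        let tmp = 2.0 * rho - 1.0;
        u *= (1.0_f64 / 3.0).max(1.0 - tmp.powi(3));
        v = 2;
    } else {
        // Step rejected (assumed standard rule): increase damping, retry.
        u *= f64::from(v);
        v *= 2;
    }
    println!("rho = {rho}, u = {u}, v = {v}");
}
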
@@ -150,7 +144,6 @@ pub trait TinySolver<const NUM_PARAMETERS: usize, const NUM_RESIDUALS: usize> {
         }
         result.error_magnitude = residual.norm();
         result.gradient_magnitude = gradient.norm();
-        // println!("x0 {}", params);
 
         result
     }
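
The "auto grad" in the commit title refers to num_dual: cost_function is generic over dual numbers, so the same code yields residual values (the .re part) and, when the inputs are seeded, derivatives. The call param_new.map(num_dual::DualSVec64::from_re) lifts plain f64 parameters into duals with a zero derivative part, for when only the residual is needed. A toy sketch of forward-mode differentiation with num_dual, using the scalar Dual64 where the solver uses the statically sized DualSVec64:

use num_dual::{Dual64, DualNum};

// Generic over the scalar type, like the solver's cost_function.
fn f<D: DualNum<f64> + Copy>(x: D) -> D {
    x.powi(3) - x
}

fn main() {
    // Seed the input: real part 2.0, derivative part 1.
    let x = Dual64::from(2.0).derivative();
    let y = f(x);
    println!("f(2)  = {}", y.re);  // 6  (2^3 - 2)
    println!("f'(2) = {}", y.eps); // 11 (3 * 2^2 - 1)
}
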
