Skip to content

Commit

Permalink
cargo fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Nov 29, 2024
1 parent 8c1d34e commit 512079f
Show file tree
Hide file tree
Showing 15 changed files with 149 additions and 124 deletions.
22 changes: 11 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//!
//! ## Solving ODEs
//!
//! The simplest way to create a new problem is to use the [OdeBuilder] struct. You can set many configuration options such as the initial time ([OdeBuilder::t0]), initial step size ([OdeBuilder::h0]),
//! The simplest way to create a new problem is to use the [OdeBuilder] struct. You can set many configuration options such as the initial time ([OdeBuilder::t0]), initial step size ([OdeBuilder::h0]),
//! relative tolerance ([OdeBuilder::rtol]), absolute tolerance ([OdeBuilder::atol]), parameters ([OdeBuilder::p]) and equations ([OdeBuilder::rhs_implicit], [OdeBuilder::init], [OdeBuilder::mass] etc.)
//! or leave them at their default values. Then, call the [OdeBuilder::build] function to create a [OdeSolverProblem].
//!
Expand All @@ -25,7 +25,7 @@
//! To solve the problem given the initial state, you need to choose a solver. DiffSol provides the following solvers:
//! - A Backwards Difference Formulae [Bdf] solver, suitable for stiff problems and singular mass matrices.
//! - A Singly Diagonally Implicit Runge-Kutta (SDIRK or ESDIRK) solver [Sdirk]. You can use your own butcher tableau using [Tableau] or use one of the provided ([Tableau::tr_bdf2], [Tableau::esdirk34]).
//!
//!
//! The easiest way to create a solver is to use one of the provided methods on the [OdeSolverProblem] struct ([OdeSolverProblem::bdf_solver], [OdeSolverProblem::tr_bdf2_solver], [OdeSolverProblem::esdirk34_solver]).
//! These create a new solver from a provided state and problem. Alternatively, you can create both the solver and the state at once using [OdeSolverProblem::bdf], [OdeSolverProblem::tr_bdf2], [OdeSolverProblem::esdirk34].
//!
Expand Down Expand Up @@ -71,7 +71,7 @@
//! DiffSol provides a way to compute the forward sensitivity of the solution with respect to the parameters. You can provide the requires equations to the builder using [OdeBuilder::rhs_sens_implicit] and [OdeBuilder::init_sens],
//! or your equations struct must implement the [OdeEquationsSens] trait,
//! Note that by default the sensitivity equations are included in the error control for the solvers, you can change this by setting tolerances using the [OdeBuilder::sens_atol] and [OdeBuilder::sens_rtol] methods.
//!
//!
//! The easiest way to obtain the sensitivity solution is to use the [OdeSolverMethod::solve_dense_sensitivities] method, which will solve the forward problem and the sensitivity equations simultaneously and return the result.
//! If you are manually stepping the solver, you can use the [OdeSolverMethod::interpolate_sens] method to obtain the sensitivity solution at a given time. Otherwise the sensitivity vectors are stored in the [OdeSolverState] struct.
//!
Expand Down Expand Up @@ -99,14 +99,14 @@
//! to solve the adjoint equations.
//!
//! To provide the builder with the required equations, you can use the [OdeBuilder::rhs_adjoint_implicit], [OdeBuilder::init_adjoint], and [OdeBuilder::out_adjoint_implicit] methods,
//! or your equations struct must implement the [OdeEquationsAdjoint] trait.
//!
//! or your equations struct must implement the [OdeEquationsAdjoint] trait.
//!
//! The easiest way to obtain the adjoint solution is to use the [OdeSolverMethod::solve_adjoint] method, which will solve the forwards problem, then the adjoint problem and return the result.
//! If you wish to manually do the timestepping, then the best place to start is by looking at the source code for the [OdeSolverMethod::solve_adjoint] method. During the solution of the forwards problem
//! If you wish to manually do the timestepping, then the best place to start is by looking at the source code for the [OdeSolverMethod::solve_adjoint] method. During the solution of the forwards problem
//! you will need to use checkpointing to store the solution at a set of times.
//! From this you should obtain a `Vec<OdeSolverState>` (that can be the start and end of the solution), and
//! a [HermiteInterpolator] that can be used to interpolate the solution between the last two checkpoints. You can then use the [AdjointOdeSolverMethod::adjoint_equations] and then create
//! an adjoint solver either manually or using the [AdjointOdeSolverMethod::default_adjoint_solver] method. You can then use this solver to step the adjoint equations backwards in time using [OdeSolverMethod::step] as normal.
//! an adjoint solver either manually or using the [AdjointOdeSolverMethod::default_adjoint_solver] method. You can then use this solver to step the adjoint equations backwards in time using [OdeSolverMethod::step] as normal.
//! Once the adjoint equations have been solved,
//! the sensitivities of the output function will be stored in the [StateRef::sg] field of the adjoint solver state. If your parameters are used to calculate the initial conditions
//! of the forward problem, then you will need to use the [AdjointEquations::correct_sg_for_init] method to correct the sensitivities for the initial conditions.
Expand Down Expand Up @@ -197,10 +197,10 @@ pub use ode_solver::{
equations::AugmentedOdeEquations, equations::AugmentedOdeEquationsImplicit, equations::NoAug,
equations::OdeEquations, equations::OdeEquationsAdjoint, equations::OdeEquationsImplicit,
equations::OdeEquationsRef, equations::OdeEquationsSens, equations::OdeSolverEquations,
method::AdjointOdeSolverMethod, method::OdeSolverMethod, method::OdeSolverStopReason,
problem::OdeSolverProblem, sdirk::Sdirk, sdirk_state::SdirkState,
method::AdjointOdeSolverMethod, method::AugmentedOdeSolverMethod, method::OdeSolverMethod,
method::OdeSolverStopReason, problem::OdeSolverProblem, sdirk::Sdirk, sdirk_state::SdirkState,
sens_equations::SensEquations, sens_equations::SensInit, sens_equations::SensRhs,
state::OdeSolverState, tableau::Tableau, method::AugmentedOdeSolverMethod,
state::OdeSolverState, tableau::Tableau,
};
pub use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint};
pub use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose};
Expand All @@ -210,7 +210,7 @@ pub use op::nonlinear_op::{
pub use op::{
closure::Closure, closure_with_adjoint::ClosureWithAdjoint, constant_closure::ConstantClosure,
constant_closure_with_adjoint::ConstantClosureWithAdjoint, linear_closure::LinearClosure,
unit::UnitCallable, Op, BuilderOp, ParameterisedOp,
unit::UnitCallable, BuilderOp, Op, ParameterisedOp,
};
use op::{
closure_no_jac::ClosureNoJac, closure_with_sens::ClosureWithSens,
Expand Down
5 changes: 1 addition & 4 deletions src/linear_solver/faer/lu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ impl<T: Scalar> LinearSolver<Mat<T>> for LU<T> {
Ok(())
}

fn set_problem<C: NonLinearOpJacobian<T = T, V = Col<T>, M = Mat<T>>>(
&mut self,
op: &C,
) {
fn set_problem<C: NonLinearOpJacobian<T = T, V = Col<T>, M = Mat<T>>>(&mut self, op: &C) {
let ncols = op.nstates();
let nrows = op.nout();
let matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity());
Expand Down
6 changes: 1 addition & 5 deletions src/linear_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ pub use nalgebra::lu::LU as NalgebraLU;

/// A solver for the linear problem `Ax = b`, where `A` is a linear operator that is obtained by taking the linearisation of a nonlinear operator `C`
pub trait LinearSolver<M: Matrix>: Default {

// sets the point at which the linearisation of the operator is evaluated
// the operator is assumed to have the same sparsity as that given to [Self::set_problem]
fn set_linearisation<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M>>(
Expand All @@ -30,10 +29,7 @@ pub trait LinearSolver<M: Matrix>: Default {
/// Set the problem to be solved, any previous problem is discarded.
/// Any internal state of the solver is reset.
/// This function will normally set the sparsity pattern of the matrix to be solved.
fn set_problem<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M>>(
&mut self,
op: &C,
);
fn set_problem<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M>>(&mut self, op: &C);

/// Solve the problem `Ax = b` and return the solution `x`.
/// panics if [Self::set_linearisation] has not been called previously
Expand Down
5 changes: 1 addition & 4 deletions src/linear_solver/suitesparse/klu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,7 @@ where
Ok(())
}

fn set_problem<C: NonLinearOpJacobian<T = M::T, V = M::V, M = M>>(
&mut self,
op: &C,
) {
fn set_problem<C: NonLinearOpJacobian<T = M::T, V = M::V, M = M>>(&mut self, op: &C) {
let ncols = op.nstates();
let nrows = op.nout();
let mut matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity());
Expand Down
9 changes: 4 additions & 5 deletions src/nonlinear_solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ impl<V> NonLinearSolveSolution<V> {
/// A solver for the nonlinear problem `F(x) = 0`.
pub trait NonLinearSolver<M: Matrix>: Default {
/// Set the problem to be solved, any previous problem is discarded.
fn set_problem<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M>>(
&mut self,
op: &C,
);
fn set_problem<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M>>(&mut self, op: &C);

/// Reset the approximation of the Jacobian matrix.
fn reset_jacobian<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M>>(
Expand Down Expand Up @@ -131,7 +128,9 @@ pub mod tests {
let t = C::T::zero();
solver.reset_jacobian(&op, &solns[0].x0, t);
for soln in solns {
let x = solver.solve(&op, &soln.x0, t, &soln.x0, &mut convergence).unwrap();
let x = solver
.solve(&op, &soln.x0, t, &soln.x0, &mut convergence)
.unwrap();
let tol = x.clone() * scale(rtol) + atol;
x.assert_eq(&soln.x, &tol);
}
Expand Down
9 changes: 2 additions & 7 deletions src/nonlinear_solver/newton.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,8 @@ impl<M: Matrix, Ls: LinearSolver<M>> Default for NewtonNonlinearSolver<M, Ls> {
}
}

impl<M: Matrix, Ls: LinearSolver<M>> NonLinearSolver<M>
for NewtonNonlinearSolver<M, Ls>
{
fn set_problem<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M>>(
&mut self,
op: &C,
) {
impl<M: Matrix, Ls: LinearSolver<M>> NonLinearSolver<M> for NewtonNonlinearSolver<M, Ls> {
fn set_problem<C: NonLinearOpJacobian<V = M::V, T = M::T, M = M>>(&mut self, op: &C) {
self.linear_solver.set_problem(op);
self.is_jacobian_set = false;
self.tmp = C::V::zeros(op.nstates());
Expand Down
2 changes: 1 addition & 1 deletion src/ode_solver/adjoint_equations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ where
fn update_rhs_out_state(&mut self, _y: &Eqn::V, _dy: &Eqn::V, _t: Eqn::T) {}

fn update_init_state(&mut self, _t: <Eqn as Op>::T) {}

fn integrate_main_eqn(&self) -> bool {
false
}
Expand Down
69 changes: 45 additions & 24 deletions src/ode_solver/bdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use nalgebra::ComplexField;
use std::ops::AddAssign;

use crate::{
error::{DiffsolError, OdeSolverError}, AdjointEquations, AugmentedOdeEquationsImplicit, Convergence, DefaultDenseMatrix, LinearSolver, NoAug, OdeEquationsAdjoint, OdeEquationsSens, SensEquations, StateRef, StateRefMut
error::{DiffsolError, OdeSolverError},
AdjointEquations, AugmentedOdeEquationsImplicit, Convergence, DefaultDenseMatrix, LinearSolver,
NoAug, OdeEquationsAdjoint, OdeEquationsSens, SensEquations, StateRef, StateRefMut,
};

use num_traits::{abs, One, Pow, Zero};
Expand Down Expand Up @@ -44,7 +46,8 @@ where
}
}

impl<'a, M, Eqn, Nls> SensitivitiesOdeSolverMethod<'a, Eqn> for Bdf<'a, Eqn, Nls, M, SensEquations<'a, Eqn>>
impl<'a, M, Eqn, Nls> SensitivitiesOdeSolverMethod<'a, Eqn>
for Bdf<'a, Eqn, Nls, M, SensEquations<'a, Eqn>>
where
Eqn: OdeEquationsSens,
M: DenseMatrix<T = Eqn::T, V = Eqn::V>,
Expand All @@ -64,18 +67,12 @@ where
Eqn::V: DefaultDenseMatrix,
Nls: NonLinearSolver<Eqn::M> + 'a,
{
type DefaultAdjointSolver = Bdf<
'a,
Eqn,
Nls,
M,
AdjointEquations<'a, Eqn, Bdf<'a, Eqn, Nls, M>>,
>;
type DefaultAdjointSolver =
Bdf<'a, Eqn, Nls, M, AdjointEquations<'a, Eqn, Bdf<'a, Eqn, Nls, M>>>;
fn default_adjoint_solver<LS: LinearSolver<Eqn::M>>(
self,
mut aug_eqn: AdjointEquations<'a, Eqn, Self>,
) -> Result<Self::DefaultAdjointSolver, DiffsolError>
{
) -> Result<Self::DefaultAdjointSolver, DiffsolError> {
let problem = self.problem();
let nonlinear_solver = self.nonlinear_solver;
let state = self.state.into_adjoint::<LS, _, _>(problem, &mut aug_eqn)?;
Expand Down Expand Up @@ -216,7 +213,7 @@ where
const MAX_THRESHOLD: f64 = 2.0;
const MIN_THRESHOLD: f64 = 0.9;
const MIN_TIMESTEP: f64 = 1e-32;

pub fn new(
problem: &'a OdeSolverProblem<Eqn>,
state: BdfState<Eqn::V, M>,
Expand Down Expand Up @@ -340,7 +337,12 @@ where
) -> Result<Self, DiffsolError> {
state.check_sens_consistent_with_problem(problem, &augmented_eqn)?;

let mut ret = Self::_new(problem, state, nonlinear_solver, augmented_eqn.integrate_main_eqn())?;
let mut ret = Self::_new(
problem,
state,
nonlinear_solver,
augmented_eqn.integrate_main_eqn(),
)?;

ret.state.set_augmented_problem(problem, &augmented_eqn)?;

Expand All @@ -353,7 +355,8 @@ where
} else {
let bdf_callable = BdfCallable::new(augmented_eqn);
ret.nonlinear_solver.set_problem(&bdf_callable);
ret.nonlinear_solver.reset_jacobian(&bdf_callable, &ret.state.s[0], ret.state.t);
ret.nonlinear_solver
.reset_jacobian(&bdf_callable, &ret.state.s[0], ret.state.t);
Some(bdf_callable)
};

Expand Down Expand Up @@ -400,18 +403,22 @@ where
if self.jacobian_update.check_rhs_jacobian_update(c, &state) {
if let Some(op) = self.op.as_mut() {
op.set_jacobian_is_stale();
self.nonlinear_solver.reset_jacobian(op, &self.state.y, self.state.t);
self.nonlinear_solver
.reset_jacobian(op, &self.state.y, self.state.t);
} else if let Some(s_op) = self.s_op.as_mut() {
s_op.set_jacobian_is_stale();
self.nonlinear_solver.reset_jacobian(s_op, &self.state.s[0], self.state.t);
self.nonlinear_solver
.reset_jacobian(s_op, &self.state.s[0], self.state.t);
}
self.jacobian_update.update_rhs_jacobian();
self.jacobian_update.update_jacobian(c);
} else if self.jacobian_update.check_jacobian_update(c, &state) {
if let Some(op) = self.op.as_mut() {
self.nonlinear_solver.reset_jacobian(op, &self.state.y, self.state.t);
self.nonlinear_solver
.reset_jacobian(op, &self.state.y, self.state.t);
} else if let Some(s_op) = self.s_op.as_mut() {
self.nonlinear_solver.reset_jacobian(s_op, &self.state.s[0], self.state.t);
self.nonlinear_solver
.reset_jacobian(s_op, &self.state.s[0], self.state.t);
}
self.jacobian_update.update_jacobian(c);
}
Expand All @@ -433,7 +440,12 @@ where
let ru = r.mat_mul(&self.u);
{
if self.op.is_some() {
Self::_update_diff_for_step_size(&ru, &mut self.state.diff, &mut self.diff_tmp, order);
Self::_update_diff_for_step_size(
&ru,
&mut self.state.diff,
&mut self.diff_tmp,
order,
);
if self.ode_problem.integrate_out {
Self::_update_diff_for_step_size(
&ru,
Expand All @@ -446,7 +458,7 @@ where
for diff in self.state.sdiff.iter_mut() {
Self::_update_diff_for_step_size(&ru, diff, &mut self.diff_tmp, order);
}

for diff in self.state.sgdiff.iter_mut() {
Self::_update_diff_for_step_size(&ru, diff, &mut self.sgdiff_tmp, order);
}
Expand Down Expand Up @@ -693,7 +705,8 @@ where
if self.op.is_some() {
let atol = &self.ode_problem.atol;
let rtol = self.ode_problem.rtol;
error_norm += self.y_delta.squared_norm(&state.y, atol, rtol) * self.error_const2[order - 1];
error_norm +=
self.y_delta.squared_norm(&state.y, atol, rtol) * self.error_const2[order - 1];
ncontrib += 1;
if output_in_error_control {
let rtol = self.ode_problem.out_rtol.unwrap();
Expand Down Expand Up @@ -833,8 +846,13 @@ where
let s_new = &mut self.state.s[i];
s_new.copy_from(&self.s_predict);
// todo: should be a separate convergence object?
self.nonlinear_solver
.solve_in_place(&*s_op, s_new, t_new, &self.s_predict, &mut self.convergence)?;
self.nonlinear_solver.solve_in_place(
&*s_op,
s_new,
t_new,
&self.s_predict,
&mut self.convergence,
)?;
self.statistics.number_of_nonlinear_solver_iterations += self.convergence.niter();
let s_new = &*s_new;
self.s_deltas[i].copy_from(s_new);
Expand Down Expand Up @@ -1050,7 +1068,10 @@ where
}

// only calculate sensitivities if solve was successful
if solve_result.is_ok() && integrate_sens && self.sensitivity_solve(self.t_predict).is_err() {
if solve_result.is_ok()
&& integrate_sens
&& self.sensitivity_solve(self.t_predict).is_err()
{
solve_result = Err(ode_solver_error!(SensitivitySolveFailed));
}

Expand Down
29 changes: 22 additions & 7 deletions src/ode_solver/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,10 @@ where
}
}

pub fn rhs_custom<RhsCustom: BuilderOp>(self, rhs: RhsCustom) -> OdeBuilder<M, RhsCustom, Init, Mass, Root, Out> {
pub fn rhs_custom<RhsCustom: BuilderOp>(
self,
rhs: RhsCustom,
) -> OdeBuilder<M, RhsCustom, Init, Mass, Root, Out> {
OdeBuilder::<M, RhsCustom, Init, Mass, Root, Out> {
rhs: Some(rhs),
init: self.init,
Expand Down Expand Up @@ -346,7 +349,10 @@ where
}
}

pub fn init_custom<InitCustom: BuilderOp>(self, init: InitCustom) -> OdeBuilder<M, Rhs, InitCustom, Mass, Root, Out> {
pub fn init_custom<InitCustom: BuilderOp>(
self,
init: InitCustom,
) -> OdeBuilder<M, Rhs, InitCustom, Mass, Root, Out> {
OdeBuilder::<M, Rhs, InitCustom, Mass, Root, Out> {
rhs: self.rhs,
init: Some(init),
Expand Down Expand Up @@ -493,7 +499,10 @@ where
}
}

pub fn mass_custom<MassCustom: BuilderOp>(self, mass: MassCustom) -> OdeBuilder<M, Rhs, Init, MassCustom, Root, Out> {
pub fn mass_custom<MassCustom: BuilderOp>(
self,
mass: MassCustom,
) -> OdeBuilder<M, Rhs, Init, MassCustom, Root, Out> {
OdeBuilder::<M, Rhs, Init, MassCustom, Root, Out> {
rhs: self.rhs,
init: self.init,
Expand Down Expand Up @@ -554,7 +563,10 @@ where
}
}

pub fn root_custom<RootCustom: BuilderOp>(self, root: RootCustom) -> OdeBuilder<M, Rhs, Init, Mass, RootCustom, Out> {
pub fn root_custom<RootCustom: BuilderOp>(
self,
root: RootCustom,
) -> OdeBuilder<M, Rhs, Init, Mass, RootCustom, Out> {
OdeBuilder::<M, Rhs, Init, Mass, RootCustom, Out> {
rhs: self.rhs,
init: self.init,
Expand Down Expand Up @@ -657,7 +669,10 @@ where
}
}

pub fn out_custom<OutCustom: BuilderOp>(self, out: OutCustom) -> OdeBuilder<M, Rhs, Init, Mass, Root, OutCustom> {
pub fn out_custom<OutCustom: BuilderOp>(
self,
out: OutCustom,
) -> OdeBuilder<M, Rhs, Init, Mass, Root, OutCustom> {
OdeBuilder::<M, Rhs, Init, Mass, Root, OutCustom> {
rhs: self.rhs,
init: self.init,
Expand Down Expand Up @@ -897,12 +912,12 @@ where
mass.set_nparams(nparams);
mass.set_nout(nstates);
}

if let Some(ref mut root) = root {
root.set_nstates(nstates);
root.set_nparams(nparams);
}

if let Some(ref mut out) = out {
out.set_nstates(nstates);
out.set_nparams(nparams);
Expand Down
Loading

0 comments on commit 512079f

Please sign in to comment.