Skip to content

Commit

Permalink
Merge pull request #22 from martinjrobins/docs
Browse files Browse the repository at this point in the history
Docs
  • Loading branch information
martinjrobins authored Apr 3, 2024
2 parents 2e65bf0 + e795dc9 commit 6354696
Show file tree
Hide file tree
Showing 14 changed files with 309 additions and 206 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ state, we can use the following code:

```rust
let mut state = OdeSolverState::new(&problem);
solver.set_problem(&mut state, problem);
solver.set_problem(&mut state, &problem);
while state.t <= t {
solver.step(&mut state).unwrap();
}
Expand Down
52 changes: 50 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,51 @@
//! # DiffSol
//!
//! DiffSol is a library for solving differential equations. It provides a simple interface to solve ODEs and semi-explicit DAEs.
//!
//! ## Getting Started
//!
//! To create a new problem, use the [OdeBuilder] struct. You can set the initial time, initial step size, relative tolerance, absolute tolerance, and parameters,
//! or leave them at their default values. Then, call the [OdeBuilder::build_ode] method with the ODE equations, or the [OdeBuilder::build_ode_with_mass] method
//! with the ODE equations and the mass matrix equations.
//!
//! You will also need to choose a matrix type to use. DiffSol can use the [nalgebra](https://nalgebra.org) `DMatrix` type, or any other type that implements the
//! [Matrix] trait. You can also use the [sundials](https://computation.llnl.gov/projects/sundials) library for the matrix and vector types (see [SundialsMatrix]).
//!
//! To solve the problem, you need to choose a solver. DiffSol provides a pure rust [Bdf] solver, or you can use the [SundialsIda] solver from the sundials library (requires the `sundials` feature).
//! See the [OdeSolverMethod] trait for a more detailed description of the available methods on the solver.
//!
//! ```rust
//! use diffsol::{OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod};
//! type M = nalgebra::DMatrix<f64>;
//!
//! let problem = OdeBuilder::new()
//! .rtol(1e-6)
//! .p([0.1])
//! .build_ode::<M, _, _, _>(
//! // dy/dt = -ay
//! |x, p, t, y| {
//! y[0] = -p[0] * x[0];
//! },
//! // Jv = -av
//! |x, p, t, v, y| {
//! y[0] = -p[0] * v[0];
//! },
//! // y(0) = 1
//! |p, t| {
//! nalgebra::DVector::from_vec(vec![1.0])
//! },
//! ).unwrap();
//!
//! let mut solver = Bdf::default();
//! let t = 0.4;
//! let mut state = OdeSolverState::new(&problem);
//! solver.set_problem(&mut state, &problem);
//! while state.t <= t {
//! solver.step(&mut state).unwrap();
//! }
//! let y = solver.interpolate(&state, t);
//! ```

#[cfg(feature = "diffsl-llvm10")]
pub extern crate diffsl10_0 as diffsl;
#[cfg(feature = "diffsl-llvm11")]
Expand Down Expand Up @@ -56,8 +104,8 @@ use matrix::{DenseMatrix, Matrix, MatrixViewMut};
pub use nonlinear_solver::newton::NewtonNonlinearSolver;
use nonlinear_solver::NonLinearSolver;
pub use ode_solver::{
bdf::Bdf, builder::OdeBuilder, equations::OdeEquations, OdeSolverMethod, OdeSolverProblem,
OdeSolverState,
bdf::Bdf, builder::OdeBuilder, equations::OdeEquations, method::OdeSolverMethod,
method::OdeSolverState, problem::OdeSolverProblem,
};
use op::NonLinearOp;
use scalar::{IndexType, Scalar, Scale};
Expand Down
21 changes: 18 additions & 3 deletions src/ode_solver/bdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ use serde::Serialize;

use crate::{
matrix::MatrixRef, op::ode::BdfCallable, scalar::scale, DenseMatrix, IndexType, MatrixViewMut,
NewtonNonlinearSolver, NonLinearSolver, Scalar, SolverProblem, Vector, VectorRef, VectorView,
VectorViewMut, LU,
NewtonNonlinearSolver, NonLinearSolver, OdeSolverMethod, OdeSolverProblem, OdeSolverState,
Scalar, SolverProblem, Vector, VectorRef, VectorView, VectorViewMut, LU,
};

use super::{equations::OdeEquations, OdeSolverMethod, OdeSolverProblem, OdeSolverState};
use super::equations::OdeEquations;

#[derive(Clone, Debug, Serialize)]
pub struct BdfStatistics<T: Scalar> {
Expand All @@ -39,6 +39,21 @@ impl<T: Scalar> Default for BdfStatistics<T> {
}
}

/// Implements a Backward Difference formula (BDF) implicit multistep integrator.
/// The basic algorithm is derived in \[1\]. This
/// particular implementation follows that implemented in the Matlab routine ode15s
/// described in \[2\] and the SciPy implementation
/// /[3/], which features the NDF formulas for improved
/// stability with associated differences in the error constants, and calculates
/// the jacobian at J(t_{n+1}, y^0_{n+1}). This implementation was based on that
/// implemented in the SciPy library \[3\], which also mainly
/// follows \[2\] but uses the more standard Jacobian update.
///
/// # References
///
/// \[1\] Byrne, G. D., & Hindmarsh, A. C. (1975). A polyalgorithm for the numerical solution of ordinary differential equations. ACM Transactions on Mathematical Software (TOMS), 1(1), 71-96.
/// \[2\] Shampine, L. F., & Reichelt, M. W. (1997). The matlab ode suite. SIAM journal on scientific computing, 18(1), 1-22.
/// \[3\] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, T., Cournapeau, D., ... & Van Mulbregt, P. (2020). SciPy 1.0: fundamental algorithms for scientific computing in Python. Nature methods, 17(3), 261-272.
pub struct Bdf<M: DenseMatrix<T = Eqn::T, V = Eqn::V>, Eqn: OdeEquations> {
nonlinear_solver: Box<dyn NonLinearSolver<BdfCallable<Eqn>>>,
ode_problem: Option<OdeSolverProblem<Eqn>>,
Expand Down
133 changes: 133 additions & 0 deletions src/ode_solver/method.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
use anyhow::Result;
use std::rc::Rc;

use crate::{
op::{filter::FilterCallable, ode_rhs::OdeRhs},
Matrix, NonLinearSolver, OdeEquations, OdeSolverProblem, SolverProblem, Vector, VectorIndex,
};

/// Trait for ODE solver methods. This is the main user interface for the ODE solvers.
/// The solver is responsible for stepping the solution (given in the `OdeSolverState`), and interpolating the solution at a given time.
/// However, the solver does not own the state, so the user is responsible for creating and managing the state. If the user
/// wants to change the state, they should call `set_problem` again.
///
/// # Example
///
/// ```
/// use diffsol::{ OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeEquations };
///
/// fn solve_ode<Eqn: OdeEquations>(solver: &mut impl OdeSolverMethod<Eqn>, problem: &OdeSolverProblem<Eqn>, t: Eqn::T) -> Eqn::V {
/// let mut state = OdeSolverState::new(problem);
/// solver.set_problem(&mut state, problem);
/// while state.t <= t {
/// solver.step(&mut state).unwrap();
/// }
/// solver.interpolate(&state, t)
/// }
/// ```
pub trait OdeSolverMethod<Eqn: OdeEquations> {
/// Get the current problem if it has been set
fn problem(&self) -> Option<&OdeSolverProblem<Eqn>>;

/// Set the problem to solve, this performs any initialisation required by the solver.
/// Call this before calling `step` or `solve`, and call it again if the state is changed manually (i.e. not by the solver)
fn set_problem(&mut self, state: &mut OdeSolverState<Eqn::M>, problem: &OdeSolverProblem<Eqn>);

/// Step the solution forward by one step, altering the state in place
fn step(&mut self, state: &mut OdeSolverState<Eqn::M>) -> Result<()>;

/// Interpolate the solution at a given time. This time should be between the current time and the last solver time step
fn interpolate(&self, state: &OdeSolverState<Eqn::M>, t: Eqn::T) -> Eqn::V;

/// Reinitialise the solver state and solve the problem up to time `t`
fn solve(&mut self, problem: &OdeSolverProblem<Eqn>, t: Eqn::T) -> Result<Eqn::V> {
let mut state = OdeSolverState::new(problem);
self.set_problem(&mut state, problem);
while state.t <= t {
self.step(&mut state)?;
}
Ok(self.interpolate(&state, t))
}

/// Reinitialise the solver state making it consistent with the algebraic constraints and solve the problem up to time `t`
fn make_consistent_and_solve<RS: NonLinearSolver<FilterCallable<OdeRhs<Eqn>>>>(
&mut self,
problem: &OdeSolverProblem<Eqn>,
t: Eqn::T,
root_solver: &mut RS,
) -> Result<Eqn::V> {
let mut state = OdeSolverState::new_consistent(problem, root_solver)?;
self.set_problem(&mut state, problem);
while state.t <= t {
self.step(&mut state)?;
}
Ok(self.interpolate(&state, t))
}
}

/// State for the ODE solver, containing the current solution `y`, the current time `t`, and the current step size `h`.
pub struct OdeSolverState<M: Matrix> {
pub y: M::V,
pub t: M::T,
pub h: M::T,
_phantom: std::marker::PhantomData<M>,
}

impl<M: Matrix> OdeSolverState<M> {
/// Create a new solver state from an ODE problem. Note that this does not make the state consistent with the algebraic constraints.
/// If you need to make the state consistent, use `new_consistent` instead.
pub fn new<Eqn>(ode_problem: &OdeSolverProblem<Eqn>) -> Self
where
Eqn: OdeEquations<M = M, T = M::T, V = M::V>,
{
let t = ode_problem.t0;
let h = ode_problem.h0;
let y = ode_problem.eqn.init(t);
Self {
y,
t,
h,
_phantom: std::marker::PhantomData,
}
}

/// Create a new solver state from an ODE problem, making the state consistent with the algebraic constraints.
pub fn new_consistent<Eqn, S>(
ode_problem: &OdeSolverProblem<Eqn>,
root_solver: &mut S,
) -> Result<Self>
where
Eqn: OdeEquations<M = M, T = M::T, V = M::V>,
S: NonLinearSolver<FilterCallable<OdeRhs<Eqn>>> + ?Sized,
{
let t = ode_problem.t0;
let h = ode_problem.h0;
let indices = ode_problem.eqn.algebraic_indices();
let mut y = ode_problem.eqn.init(t);
if indices.len() == 0 {
return Ok(Self {
y,
t,
h,
_phantom: std::marker::PhantomData,
});
}
let mut y_filtered = y.filter(&indices);
let atol = Rc::new(ode_problem.atol.as_ref().filter(&indices));
let rhs = Rc::new(OdeRhs::new(ode_problem.eqn.clone()));
let f = Rc::new(FilterCallable::new(rhs, &y, indices));
let rtol = ode_problem.rtol;
let init_problem = SolverProblem::new(f, t, atol, rtol);
root_solver.set_problem(init_problem);
root_solver.solve_in_place(&mut y_filtered)?;
let init_problem = root_solver.problem().unwrap();
let indices = init_problem.f.indices();
y.scatter_from(&y_filtered, indices);
Ok(Self {
y,
t,
h,
_phantom: std::marker::PhantomData,
})
}
}
Loading

0 comments on commit 6354696

Please sign in to comment.