From 29e2a4cb088a3851f2a5effbae4b2a1aaf6c6515 Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Fri, 1 Nov 2024 16:20:51 +0000 Subject: [PATCH] refactor: remove use of Rc in OdeEquations trait (#104) refactor: sparsity returned as owned refactor: sparsity functions moved from Op to relevent traits refactor: diffsl struct owns context, now has 'static lifetime --- Cargo.toml | 3 +- benches/ode_solvers.rs | 18 +- src/jacobian/mod.rs | 34 +- src/lib.rs | 19 +- src/linear_solver/faer/lu.rs | 5 +- src/linear_solver/faer/sparse_lu.rs | 8 +- src/linear_solver/nalgebra/lu.rs | 6 +- src/linear_solver/suitesparse/klu.rs | 4 +- src/matrix/mod.rs | 2 +- src/matrix/sparse_faer.rs | 52 ++-- src/matrix/sparse_serial.rs | 13 +- src/matrix/sparsity.rs | 32 +- src/ode_solver/adjoint_equations.rs | 80 ++--- src/ode_solver/bdf.rs | 76 ++--- src/ode_solver/builder.rs | 86 ++---- src/ode_solver/diffsl.rs | 290 ++++++++---------- src/ode_solver/equations.rs | 249 +++++++++------ src/ode_solver/method.rs | 9 +- src/ode_solver/mod.rs | 55 ++-- src/ode_solver/problem.rs | 17 +- src/ode_solver/sdirk.rs | 41 +-- src/ode_solver/sens_equations.rs | 47 ++- src/ode_solver/sundials.rs | 22 +- .../test_models/exponential_decay.rs | 18 +- .../exponential_decay_with_algebraic.rs | 37 +-- src/ode_solver/test_models/foodweb.rs | 269 ++++++++-------- src/ode_solver/test_models/heat2d.rs | 34 +- src/ode_solver/test_models/robertson.rs | 49 +-- src/op/bdf.rs | 31 +- src/op/closure.rs | 12 +- src/op/closure_with_adjoint.rs | 34 +- src/op/closure_with_sens.rs | 23 +- src/op/constant_closure_with_adjoint.rs | 2 +- src/op/constant_op.rs | 44 ++- src/op/init.rs | 10 +- src/op/linear_closure.rs | 12 +- src/op/linear_closure_with_adjoint.rs | 24 +- src/op/linear_op.rs | 23 +- src/op/linearise.rs | 6 +- src/op/matrix.rs | 8 +- src/op/mod.rs | 116 +++++-- src/op/nonlinear_op.rs | 35 ++- src/op/sdirk.rs | 35 +-- src/vector/sundials.rs | 2 +- 44 files changed, 1009 insertions(+), 983 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b5a6675..5b6f18a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ nalgebra = [] sundials = ["suitesparse_sys", "bindgen", "cc"] suitesparse = ["suitesparse_sys"] diffsl-cranelift = ["diffsl-no-llvm", "diffsl"] -diffsl = [] +diffsl = [ ] diffsl-llvm = [] diffsl-llvm13 = ["diffsl13-0", "diffsl-llvm", "diffsl"] diffsl-llvm14 = ["diffsl14-0", "diffsl-llvm", "diffsl"] @@ -29,7 +29,6 @@ diffsl-llvm17 = ["diffsl17-0", "diffsl-llvm", "diffsl"] nalgebra = "0.33" nalgebra-sparse = { version = "0.10", features = ["io"] } num-traits = "0.2.17" -ouroboros = "0.18.2" serde = { version = "1.0.196", features = ["derive"] } diffsl-no-llvm = { package = "diffsl", version = "=0.2.0", optional = true } diffsl13-0 = { package = "diffsl", version = "=0.2.0", features = ["llvm13-0"], optional = true } diff --git a/benches/ode_solvers.rs b/benches/ode_solvers.rs index 224433d..0d192f6 100644 --- a/benches/ode_solvers.rs +++ b/benches/ode_solvers.rs @@ -2,8 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use diffsol::{ ode_solver::test_models::{ exponential_decay::exponential_decay_problem, foodweb::foodweb_problem, - foodweb::FoodWebContext, heat2d::head2d_problem, robertson::robertson, - robertson_ode::robertson_ode, + heat2d::head2d_problem, robertson::robertson, robertson_ode::robertson_ode, }, FaerLU, FaerSparseLU, NalgebraLU, SparseColMat, }; @@ -222,10 +221,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(stringify!($name), |b| { use diffsol::diffsl::LlvmModule; use diffsol::ode_solver::test_models::robertson::*; - let mut context = diffsol::DiffSlContext::default(); - robertson_diffsl_compile::<$matrix, LlvmModule>(&mut context); b.iter(|| { - let (problem, soln) = robertson_diffsl_problem(&mut context, false); + let (problem, soln) = robertson_diffsl_problem::<$matrix, LlvmModule>(); let ls = $linear_solver::default(); benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls) }) @@ -336,8 +333,7 @@ fn criterion_benchmark(c: &mut Criterion) { ($name:ident, $solver:ident, $linear_solver:ident, $model:ident, $model_problem:ident, $matrix:ty, $($N:expr),+) => { $(c.bench_function(concat!(stringify!($name), "_", $N), |b| { b.iter(|| { - let context = FoodWebContext::default(); - let (problem, soln) = $model_problem::<$matrix, $N>(&context); + let (problem, soln) = $model_problem::<$matrix, $N>(); let ls = $linear_solver::default(); benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls) }) @@ -429,10 +425,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(concat!(stringify!($name), "_", $N), |b| { use diffsol::ode_solver::test_models::heat2d::*; use diffsol::diffsl::LlvmModule; - let mut context = diffsol::DiffSlContext::default(); - heat2d_diffsl_compile::<$matrix, LlvmModule, $N>(&mut context); b.iter(|| { - let (problem, soln) = heat2d_diffsl_problem(&mut context); + let (problem, soln) = heat2d_diffsl_problem::<$matrix, LlvmModule, $N>(); let ls = $linear_solver::default(); benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls) }) @@ -506,10 +500,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(concat!(stringify!($name), "_", $N), |b| { use diffsol::ode_solver::test_models::foodweb::*; use diffsol::diffsl::LlvmModule; - let mut context = diffsol::DiffSlContext::default(); - foodweb_diffsl_compile::<$matrix, LlvmModule, $N>(&mut context); b.iter(|| { - let (problem, soln) = foodweb_diffsl_problem(&mut context); + let (problem, soln) = foodweb_diffsl_problem::<$matrix, LlvmModule, $N>(); let ls = $linear_solver::default(); benchmarks::$solver(&problem, soln.solution_points.last().unwrap().t, ls) }) diff --git a/src/jacobian/mod.rs b/src/jacobian/mod.rs index ae79ac1..36dd49e 100644 --- a/src/jacobian/mod.rs +++ b/src/jacobian/mod.rs @@ -1,8 +1,8 @@ use std::collections::HashSet; use crate::{ - LinearOp, LinearOpTranspose, Matrix, MatrixSparsityRef, NonLinearOp, NonLinearOpAdjoint, - NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Op, Scalar, Vector, VectorIndex, + LinearOp, LinearOpTranspose, Matrix, MatrixSparsity, NonLinearOp, NonLinearOpAdjoint, + NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Scalar, Vector, VectorIndex, }; use num_traits::{One, Zero}; @@ -93,12 +93,9 @@ pub struct JacobianColoring { } impl JacobianColoring { - pub fn new_from_non_zeros>(op: &F, non_zeros: Vec<(usize, usize)>) -> Self { - let sparsity = op - .sparsity() - .expect("Jacobian sparsity not defined, cannot use coloring"); - let ncols = op.nstates(); - let graph = nonzeros2graph(non_zeros.as_slice(), ncols); + pub fn new(sparsity: &impl MatrixSparsity, non_zeros: &[(usize, usize)]) -> Self { + let ncols = sparsity.ncols(); + let graph = nonzeros2graph(non_zeros, ncols); let coloring = color_graph_greedy(&graph); let max_color = coloring.iter().max().copied().unwrap_or(0); let mut dst_indices_per_color = Vec::new(); @@ -224,7 +221,6 @@ mod tests { use std::rc::Rc; use crate::jacobian::{find_jacobian_non_zeros, JacobianColoring}; - use crate::matrix::sparsity::MatrixSparsityRef; use crate::matrix::Matrix; use crate::op::linear_closure::LinearClosure; use crate::vector::Vector; @@ -238,8 +234,6 @@ mod tests { use num_traits::{One, Zero}; use std::ops::MulAssign; - use super::find_matrix_non_zeros; - fn helper_triplets2op_nonlinear<'a, M: Matrix + 'a>( triplets: &'a [(usize, usize, M::T)], nrows: usize, @@ -394,9 +388,12 @@ mod tests { let op = helper_triplets2op_nonlinear::(triplets.as_slice(), n, n); let y0 = M::V::zeros(n); let t0 = M::T::zero(); - let non_zeros = find_jacobian_non_zeros(&op, &y0, t0); - let coloring = JacobianColoring::new_from_non_zeros(&op, non_zeros); - let mut jac = M::new_from_sparsity(3, 3, op.sparsity().map(|s| s.to_owned())); + let nonzeros = triplets + .iter() + .map(|(i, j, _v)| (*i, *j)) + .collect::>(); + let coloring = JacobianColoring::new(&op.jacobian_sparsity().unwrap(), &nonzeros); + let mut jac = M::new_from_sparsity(3, 3, op.jacobian_sparsity()); coloring.jacobian_inplace(&op, &y0, t0, &mut jac); let mut gemv1 = M::V::zeros(n); let v = M::V::from_element(3, M::T::one()); @@ -410,9 +407,12 @@ mod tests { for triplets in test_triplets { let op = helper_triplets2op_linear::(triplets.as_slice(), n, n); let t0 = M::T::zero(); - let non_zeros = find_matrix_non_zeros(&op, t0); - let coloring = JacobianColoring::new_from_non_zeros(&op, non_zeros); - let mut jac = M::new_from_sparsity(3, 3, op.sparsity().map(|s| s.to_owned())); + let nonzeros = triplets + .iter() + .map(|(i, j, _v)| (*i, *j)) + .collect::>(); + let coloring = JacobianColoring::new(&op.sparsity().unwrap(), &nonzeros); + let mut jac = M::new_from_sparsity(3, 3, op.sparsity()); coloring.matrix_inplace(&op, t0, &mut jac); let mut gemv1 = M::V::zeros(n); let v = M::V::from_element(3, M::T::one()); diff --git a/src/lib.rs b/src/lib.rs index 43fc089..924dc63 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ //! ## Solving ODEs //! //! The simplest way to create a new problem is to use the [OdeBuilder] struct. You can set the initial time, initial step size, relative tolerance, absolute tolerance, and parameters, -//! or leave them at their default values. Then, call one of the `build_*` functions (e.g. [OdeBuilder::build_ode], [OdeBuilder::build_ode_with_mass], [OdeBuilder::build_diffsl]) to create a [OdeSolverProblem]. +//! or leave them at their default values. Then, call one of the `build_*` functions (e.g. [OdeBuilder::build_ode], [OdeBuilder::build_ode_with_mass], [OdeBuilder::build_from_eqn]) to create a [OdeSolverProblem]. //! //! You will also need to choose a matrix type to use. DiffSol can use the [nalgebra](https://nalgebra.org) `DMatrix` type, the [faer](https://github.com/sarah-ek/faer-rs) `Mat` type, or any other type that implements the //! [Matrix] trait. @@ -35,7 +35,7 @@ //! DiffSL is a domain-specific language for specifying differential equations . It uses the LLVM compiler framwork //! to compile the equations to efficient machine code and uses the EnzymeAD library to compute the jacobian. //! -//! You can use DiffSL with DiffSol using the [DiffSlContext] struct and [OdeBuilder::build_diffsl] method. You need to enable one of the `diffsl-llvm*` features +//! You can use DiffSL with DiffSol using the [DiffSlContext] and [DiffSl] structs and [OdeBuilder::build_from_eqn] method. You need to enable one of the `diffsl-llvm*` features //! corresponding to the version of LLVM you have installed. E.g. to use your LLVM 10 installation, enable the `diffsl-llvm10` feature. //! //! For more information on the DiffSL language, see the [DiffSL documentation](https://martinjrobins.github.io/diffsl/) @@ -54,7 +54,7 @@ //! of the output vector `J(x) v` are also `NaN`, using the fact that `NaN`s propagate through most operations. However, this method is not foolproof and will fail if, //! for example, your jacobian function uses any control flow that depends on the input vector. If this is the case, you can provide the jacobian matrix directly by //! implementing the optional [NonLinearOpJacobian::jacobian_inplace] and the [LinearOp::matrix_inplace] (if applicable) functions, -//! or by providing a sparsity pattern using the [Op::sparsity] function. +//! or by providing a sparsity pattern using the [NonLinearOpJacobian::jacobian_sparsity] and [LinearOp::sparsity] functions. //! //! ## Events / Root finding //! @@ -173,7 +173,7 @@ pub use ode_solver::sundials::SundialsIda; pub use linear_solver::suitesparse::klu::KLU; #[cfg(feature = "diffsl")] -pub use ode_solver::diffsl::DiffSlContext; +pub use ode_solver::diffsl::{DiffSl, DiffSlContext}; pub use jacobian::{ find_adjoint_non_zeros, find_jacobian_non_zeros, find_matrix_non_zeros, @@ -196,11 +196,12 @@ pub use ode_solver::{ bdf_state::BdfState, builder::OdeBuilder, checkpointing::Checkpointing, checkpointing::HermiteInterpolator, equations::AugmentedOdeEquations, equations::AugmentedOdeEquationsImplicit, equations::NoAug, equations::OdeEquations, - equations::OdeEquationsAdjoint, equations::OdeEquationsImplicit, equations::OdeEquationsSens, - equations::OdeSolverEquations, method::AdjointOdeSolverMethod, method::OdeSolverMethod, - method::OdeSolverStopReason, method::SensitivitiesOdeSolverMethod, problem::OdeSolverProblem, - sdirk::Sdirk, sdirk::SdirkAdj, sdirk_state::SdirkState, sens_equations::SensEquations, - sens_equations::SensInit, sens_equations::SensRhs, state::OdeSolverState, tableau::Tableau, + equations::OdeEquationsAdjoint, equations::OdeEquationsImplicit, equations::OdeEquationsRef, + equations::OdeEquationsSens, equations::OdeSolverEquations, method::AdjointOdeSolverMethod, + method::OdeSolverMethod, method::OdeSolverStopReason, method::SensitivitiesOdeSolverMethod, + problem::OdeSolverProblem, sdirk::Sdirk, sdirk::SdirkAdj, sdirk_state::SdirkState, + sens_equations::SensEquations, sens_equations::SensInit, sens_equations::SensRhs, + state::OdeSolverState, tableau::Tableau, }; pub use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint}; pub use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose}; diff --git a/src/linear_solver/faer/lu.rs b/src/linear_solver/faer/lu.rs index 500eb32..8e139c9 100644 --- a/src/linear_solver/faer/lu.rs +++ b/src/linear_solver/faer/lu.rs @@ -3,8 +3,7 @@ use std::rc::Rc; use crate::{error::LinearSolverError, linear_solver_error}; use crate::{ - error::DiffsolError, linear_solver::LinearSolver, Matrix, MatrixSparsityRef, - NonLinearOpJacobian, Scalar, + error::DiffsolError, linear_solver::LinearSolver, Matrix, NonLinearOpJacobian, Scalar, }; use faer::{linalg::solvers::FullPivLu, solvers::SpSolver, Col, Mat}; @@ -58,7 +57,7 @@ impl LinearSolver> for LU { ) { let ncols = op.nstates(); let nrows = op.nout(); - let matrix = C::M::new_from_sparsity(nrows, ncols, op.sparsity().map(|s| s.to_owned())); + let matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity()); self.matrix = Some(matrix); } } diff --git a/src/linear_solver/faer/sparse_lu.rs b/src/linear_solver/faer/sparse_lu.rs index 719d216..788933b 100644 --- a/src/linear_solver/faer/sparse_lu.rs +++ b/src/linear_solver/faer/sparse_lu.rs @@ -4,7 +4,6 @@ use crate::{ error::{DiffsolError, LinearSolverError}, linear_solver::LinearSolver, linear_solver_error, - matrix::sparsity::MatrixSparsityRef, scalar::IndexType, Matrix, NonLinearOpJacobian, Scalar, SparseColMat, }; @@ -73,12 +72,7 @@ impl LinearSolver> for FaerSparseLU { ) { let ncols = op.nstates(); let nrows = op.nout(); - let matrix = C::M::new_from_sparsity( - nrows, - ncols, - op.sparsity() - .map(|s| MatrixSparsityRef::>::to_owned(&s)), - ); + let matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity()); self.matrix = Some(matrix); self.lu_symbolic = Some( SymbolicLu::try_new(self.matrix.as_ref().unwrap().faer().symbolic()) diff --git a/src/linear_solver/nalgebra/lu.rs b/src/linear_solver/nalgebra/lu.rs index 48c0b67..1bdd92a 100644 --- a/src/linear_solver/nalgebra/lu.rs +++ b/src/linear_solver/nalgebra/lu.rs @@ -4,9 +4,7 @@ use nalgebra::{DMatrix, DVector, Dyn}; use crate::{ error::{DiffsolError, LinearSolverError}, - linear_solver_error, - matrix::sparsity::MatrixSparsityRef, - LinearSolver, Matrix, NonLinearOpJacobian, Scalar, + linear_solver_error, LinearSolver, Matrix, NonLinearOpJacobian, Scalar, }; /// A [LinearSolver] that uses the LU decomposition in the [`nalgebra` library](https://nalgebra.org/) to solve the linear system. @@ -62,7 +60,7 @@ impl LinearSolver> for LU { ) { let ncols = op.nstates(); let nrows = op.nout(); - let matrix = C::M::new_from_sparsity(nrows, ncols, op.sparsity().map(|s| s.to_owned())); + let matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity()); self.matrix = Some(matrix); } } diff --git a/src/linear_solver/suitesparse/klu.rs b/src/linear_solver/suitesparse/klu.rs index 0593942..d52b4c2 100644 --- a/src/linear_solver/suitesparse/klu.rs +++ b/src/linear_solver/suitesparse/klu.rs @@ -29,7 +29,7 @@ use crate::{ linear_solver_error, matrix::MatrixCommon, vector::Vector, - Matrix, MatrixSparsityRef, NonLinearOpJacobian, SparseColMat, + Matrix, NonLinearOpJacobian, SparseColMat, }; trait MatrixKLU: Matrix { @@ -231,7 +231,7 @@ where ) { let ncols = op.nstates(); let nrows = op.nout(); - let mut matrix = C::M::new_from_sparsity(nrows, ncols, op.sparsity().map(|s| s.to_owned())); + let mut matrix = C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity()); let mut klu_common = self.klu_common.borrow_mut(); self.klu_symbolic = KluSymbolic::try_from_matrix(&mut matrix, klu_common.as_mut()).ok(); self.matrix = Some(matrix); diff --git a/src/matrix/mod.rs b/src/matrix/mod.rs index 854dfd0..dc72c76 100644 --- a/src/matrix/mod.rs +++ b/src/matrix/mod.rs @@ -113,7 +113,7 @@ pub trait MatrixView<'a>: } /// A base matrix trait (including sparse and dense matrices) -pub trait Matrix: MatrixCommon + Mul, Output = Self> + Clone { +pub trait Matrix: MatrixCommon + Mul, Output = Self> + Clone + 'static { type Sparsity: MatrixSparsity; type SparsityRef<'a>: MatrixSparsityRef<'a, Self> where diff --git a/src/matrix/sparse_faer.rs b/src/matrix/sparse_faer.rs index 0fddb0b..2877114 100644 --- a/src/matrix/sparse_faer.rs +++ b/src/matrix/sparse_faer.rs @@ -103,6 +103,32 @@ impl MatrixSparsity> for SymbolicSparseColMat Err(DiffsolError::Other(e.to_string())), } } + + fn get_index( + &self, + rows: &[IndexType], + cols: &[IndexType], + ) -> < as MatrixCommon>::V as Vector>::Index { + let col_ptrs = self.col_ptrs(); + let row_indices = self.row_indices(); + let mut indices = Vec::with_capacity(rows.len()); + for (&i, &j) in rows.iter().zip(cols.iter()) { + let col_ptr = col_ptrs[j]; + let next_col_ptr = col_ptrs[j + 1]; + for (ii, &ri) in row_indices + .iter() + .enumerate() + .take(next_col_ptr) + .skip(col_ptr) + { + if ri == i { + indices.push(ii); + break; + } + } + } + indices + } } impl<'a, T: Scalar> MatrixSparsityRef<'a, SparseColMat> @@ -132,32 +158,6 @@ impl<'a, T: Scalar> MatrixSparsityRef<'a, SparseColMat> } indices } - - fn get_index( - &self, - rows: &[IndexType], - cols: &[IndexType], - ) -> < as MatrixCommon>::V as Vector>::Index { - let col_ptrs = self.col_ptrs(); - let row_indices = self.row_indices(); - let mut indices = Vec::with_capacity(rows.len()); - for (&i, &j) in rows.iter().zip(cols.iter()) { - let col_ptr = col_ptrs[j]; - let next_col_ptr = col_ptrs[j + 1]; - for (ii, &ri) in row_indices - .iter() - .enumerate() - .take(next_col_ptr) - .skip(col_ptr) - { - if ri == i { - indices.push(ii); - break; - } - } - } - indices - } } impl Mul> for SparseColMat { diff --git a/src/matrix/sparse_serial.rs b/src/matrix/sparse_serial.rs index 1b95ade..2c4f775 100644 --- a/src/matrix/sparse_serial.rs +++ b/src/matrix/sparse_serial.rs @@ -138,13 +138,6 @@ impl MatrixSparsity> for SparsityPattern { major_offsets.push(n); SparsityPattern::try_from_offsets_and_indices(n, n, major_offsets, minor_indices).unwrap() } -} - -impl<'a, T: Scalar> MatrixSparsityRef<'a, CscMatrix> for &'a SparsityPattern { - fn to_owned(&self) -> SparsityPattern { - SparsityPattern::clone(self) - } - fn get_index(&self, rows: &[IndexType], cols: &[IndexType]) -> DVector { let mut index = DVector::::zeros(rows.len()); #[allow(unused_mut)] @@ -156,6 +149,12 @@ impl<'a, T: Scalar> MatrixSparsityRef<'a, CscMatrix> for &'a SparsityPattern } index } +} + +impl<'a, T: Scalar> MatrixSparsityRef<'a, CscMatrix> for &'a SparsityPattern { + fn to_owned(&self) -> SparsityPattern { + SparsityPattern::clone(self) + } fn nrows(&self) -> IndexType { self.minor_dim() diff --git a/src/matrix/sparsity.rs b/src/matrix/sparsity.rs index bef3ffc..fbc81e2 100644 --- a/src/matrix/sparsity.rs +++ b/src/matrix/sparsity.rs @@ -7,7 +7,7 @@ use crate::{ use super::Matrix; -pub trait MatrixSparsity: Sized { +pub trait MatrixSparsity: Sized + Clone { fn nrows(&self) -> IndexType; fn ncols(&self) -> IndexType; fn is_sparse() -> bool; @@ -20,6 +20,7 @@ pub trait MatrixSparsity: Sized { fn union(self, other: M::SparsityRef<'_>) -> Result; fn new_diagonal(n: IndexType) -> Self; fn as_ref(&self) -> M::SparsityRef<'_>; + fn get_index(&self, rows: &[IndexType], cols: &[IndexType]) -> ::Index; } pub trait MatrixSparsityRef<'a, M: Matrix> { @@ -28,9 +29,9 @@ pub trait MatrixSparsityRef<'a, M: Matrix> { fn is_sparse() -> bool; fn indices(&self) -> Vec<(IndexType, IndexType)>; fn to_owned(&self) -> M::Sparsity; - fn get_index(&self, rows: &[IndexType], cols: &[IndexType]) -> ::Index; } +#[derive(Clone)] pub struct Dense { nrows: IndexType, ncols: IndexType, @@ -102,6 +103,19 @@ where fn new_diagonal(n: IndexType) -> Self { Dense::new(n, n) } + fn get_index(&self, rows: &[IndexType], cols: &[IndexType]) -> ::Index { + let indices: Vec<_> = rows + .iter() + .zip(cols.iter()) + .map(|(i, j)| { + if i >= &self.nrows() || j >= &self.ncols() { + panic!("Index out of bounds") + } + j * self.nrows() + i + }) + .collect(); + ::Index::from_slice(indices.as_slice()) + } } impl<'a, M> MatrixSparsityRef<'a, M> for DenseRef<'a, M> @@ -124,20 +138,6 @@ where false } - fn get_index(&self, rows: &[IndexType], cols: &[IndexType]) -> ::Index { - let indices: Vec<_> = rows - .iter() - .zip(cols.iter()) - .map(|(i, j)| { - if i >= &self.nrows() || j >= &self.ncols() { - panic!("Index out of bounds") - } - j * self.nrows() + i - }) - .collect(); - ::Index::from_slice(indices.as_slice()) - } - fn indices(&self) -> Vec<(IndexType, IndexType)> { Vec::new() } diff --git a/src/ode_solver/adjoint_equations.rs b/src/ode_solver/adjoint_equations.rs index 6dc6b74..3320b19 100644 --- a/src/ode_solver/adjoint_equations.rs +++ b/src/ode_solver/adjoint_equations.rs @@ -1,16 +1,20 @@ use num_traits::{One, Zero}; -use std::{cell::RefCell, ops::AddAssign, ops::SubAssign, rc::Rc}; +use std::{ + cell::RefCell, + ops::{AddAssign, SubAssign}, + rc::Rc, +}; use crate::{ op::nonlinear_op::NonLinearOpJacobian, AugmentedOdeEquations, Checkpointing, ConstantOp, ConstantOpSensAdjoint, LinearOp, LinearOpTranspose, Matrix, NonLinearOp, NonLinearOpAdjoint, - NonLinearOpSensAdjoint, OdeEquations, OdeEquationsAdjoint, OdeSolverMethod, OdeSolverProblem, - Op, Vector, + NonLinearOpSensAdjoint, OdeEquations, OdeEquationsAdjoint, OdeEquationsRef, OdeSolverMethod, + OdeSolverProblem, Op, Vector, }; pub struct AdjointContext where - Eqn: OdeEquationsAdjoint, + Eqn: OdeEquations, Method: OdeSolverMethod, { checkpointer: Checkpointing, @@ -169,7 +173,7 @@ where /// We need the current state x(t), which is obtained from the checkpointed forward solve at the current time step. pub struct AdjointRhs where - Eqn: OdeEquationsAdjoint, + Eqn: OdeEquations, Method: OdeSolverMethod, { eqn: Rc, @@ -180,7 +184,7 @@ where impl AdjointRhs where - Eqn: OdeEquationsAdjoint, + Eqn: OdeEquations, Method: OdeSolverMethod, { pub fn new( @@ -217,9 +221,6 @@ where fn nparams(&self) -> usize { self.eqn.rhs().nparams() } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.eqn.rhs().sparsity_adjoint() - } } impl NonLinearOp for AdjointRhs @@ -267,6 +268,9 @@ where let x = context.state(); self.eqn.rhs().adjoint_inplace(x, t, y); } + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + self.eqn.rhs().adjoint_sparsity() + } } /// Output of the adjoint equations is: @@ -327,9 +331,6 @@ where fn nparams(&self) -> usize { self.eqn.rhs().nparams() } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.eqn.rhs().sparsity_sens_adjoint() - } } impl NonLinearOp for AdjointOut @@ -374,6 +375,9 @@ where let x = context.state(); self.eqn.rhs().sens_adjoint_inplace(x, t, y); } + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + self.eqn.rhs().sens_adjoint_sparsity() + } } /// Adjoint equations for ODEs @@ -388,9 +392,9 @@ where Method: OdeSolverMethod, { eqn: Rc, - rhs: Rc>, - out: Option>>, - mass: Option>>, + rhs: AdjointRhs, + out: Option>, + mass: Option>, context: Rc>>, tmp: RefCell, tmp2: RefCell, @@ -412,10 +416,10 @@ where with_out: bool, ) -> Self { let eqn = problem.eqn.clone(); - let rhs = Rc::new(AdjointRhs::new(&eqn, context.clone(), with_out)); + let rhs = AdjointRhs::new(&eqn, context.clone(), with_out); let init = Rc::new(AdjointInit::new(&eqn)); let out = if with_out { - Some(Rc::new(AdjointOut::new(&eqn, context.clone(), with_out))) + Some(AdjointOut::new(&eqn, context.clone(), with_out)) } else { None }; @@ -441,7 +445,7 @@ where None }; let out_rtol = if with_out { problem.out_rtol } else { None }; - let mass = eqn.mass().map(|_m| Rc::new(AdjointMass::new(&eqn))); + let mass = eqn.mass().map(|_m| AdjointMass::new(&eqn)); Self { rhs, init, @@ -466,10 +470,10 @@ where mass.call_transpose_inplace(s_i, t, &mut tmp2); self.eqn .init() - .sens_mul_transpose_inplace(t, &tmp2, &mut tmp); + .sens_transpose_mul_inplace(t, &tmp2, &mut tmp); sg_i.sub_assign(&*tmp); } else { - self.eqn.init().sens_mul_transpose_inplace(t, s_i, &mut tmp); + self.eqn.init().sens_transpose_mul_inplace(t, s_i, &mut tmp); sg_i.sub_assign(&*tmp); } } @@ -506,36 +510,36 @@ where } } -impl OdeEquations for AdjointEquations +impl<'a, Eqn, Method> OdeEquationsRef<'a> for AdjointEquations where Eqn: OdeEquationsAdjoint, Method: OdeSolverMethod, { - type T = Eqn::T; - type V = Eqn::V; - type M = Eqn::M; - type Rhs = AdjointRhs; - type Mass = AdjointMass; - type Root = Eqn::Root; - type Init = AdjointInit; - type Out = AdjointOut; + type Rhs = &'a AdjointRhs; + type Mass = &'a AdjointMass; + type Root = >::Root; + type Init = &'a AdjointInit; + type Out = &'a AdjointOut; +} - fn rhs(&self) -> &Rc { +impl OdeEquations for AdjointEquations +where + Eqn: OdeEquationsAdjoint, + Method: OdeSolverMethod, +{ + fn rhs(&self) -> &AdjointRhs { &self.rhs } - fn mass(&self) -> Option<&Rc> { + fn mass(&self) -> Option<&AdjointMass> { self.mass.as_ref() } - fn root(&self) -> Option<&Rc> { + fn root(&self) -> Option<>::Root> { None } - fn init(&self) -> &Rc { + fn init(&self) -> &AdjointInit { &self.init } - fn set_params(&mut self, _p: Self::V) { - panic!("Not implemented for SensEquations"); - } - fn out(&self) -> Option<&Rc> { + fn out(&self) -> Option<&AdjointOut> { self.out.as_ref() } } @@ -576,7 +580,7 @@ where fn update_rhs_out_state(&mut self, _y: &Eqn::V, _dy: &Eqn::V, _t: Eqn::T) {} - fn update_init_state(&mut self, _t: ::T) {} + fn update_init_state(&mut self, _t: ::T) {} } #[cfg(test)] diff --git a/src/ode_solver/bdf.rs b/src/ode_solver/bdf.rs index 994e8ca..53fe8f3 100644 --- a/src/ode_solver/bdf.rs +++ b/src/ode_solver/bdf.rs @@ -4,8 +4,8 @@ use std::rc::Rc; use crate::{ error::{DiffsolError, OdeSolverError}, - AdjointEquations, NoAug, OdeEquationsAdjoint, OdeEquationsSens, SensEquations, StateRef, - StateRefMut, + AdjointEquations, AugmentedOdeEquationsImplicit, NoAug, OdeEquationsAdjoint, OdeEquationsSens, + SensEquations, StateRef, StateRefMut, }; use num_traits::{abs, One, Pow, Zero}; @@ -25,9 +25,8 @@ use crate::{ }; use super::jacobian_update::SolverState; -use super::{ - equations::OdeEquations, - method::{AdjointOdeSolverMethod, AugmentedOdeSolverMethod, SensitivitiesOdeSolverMethod}, +use super::method::{ + AdjointOdeSolverMethod, AugmentedOdeSolverMethod, SensitivitiesOdeSolverMethod, }; #[derive(Clone, Debug, Serialize, Default)] @@ -82,7 +81,7 @@ pub struct Bdf< M: DenseMatrix, Eqn: OdeEquationsImplicit, Nls: NonLinearSolver, - AugmentedEqn: AugmentedOdeEquations + OdeEquationsImplicit = NoAug, + AugmentedEqn: AugmentedOdeEquationsImplicit = NoAug, > { nonlinear_solver: Nls, ode_problem: Option>, @@ -758,7 +757,7 @@ where )) } - fn interpolate_sens(&self, t: ::T) -> Result, DiffsolError> { + fn interpolate_sens(&self, t: ::T) -> Result, DiffsolError> { // state must be set let state = self.state.as_ref().ok_or(ode_solver_error!(StateNotSet))?; if self.is_state_modified { @@ -843,7 +842,7 @@ where self.root_finder .as_ref() .unwrap() - .init(root_fn.as_ref(), &state.y, state.t); + .init(&root_fn, &state.y, state.t); } // (re)allocate internal state @@ -1082,8 +1081,8 @@ where // check for root within accepted step if let Some(root_fn) = self.problem().as_ref().unwrap().eqn.root() { let ret = self.root_finder.as_ref().unwrap().check_root( - &|t: ::T| self.interpolate(t), - root_fn.as_ref(), + &|t: ::T| self.interpolate(t), + &root_fn, &self.state.as_ref().unwrap().y, self.state.as_ref().unwrap().t, ); @@ -1102,7 +1101,7 @@ where Ok(OdeSolverStopReason::InternalTimestep) } - fn set_stop_time(&mut self, tstop: ::T) -> Result<(), DiffsolError> { + fn set_stop_time(&mut self, tstop: ::T) -> Result<(), DiffsolError> { self.tstop = Some(tstop); if let Some(OdeSolverStopReason::TstopReached) = self.handle_tstop(tstop)? { let error = OdeSolverError::StopTimeBeforeCurrentTime { @@ -1168,7 +1167,7 @@ where impl AdjointOdeSolverMethod for Bdf where Eqn: OdeEquationsAdjoint, - AugmentedEqn: AugmentedOdeEquations + OdeEquationsAdjoint, + AugmentedEqn: AugmentedOdeEquations + OdeEquationsImplicit, M: DenseMatrix, Nls: NonLinearSolver, for<'b> &'b Eqn::V: VectorRef, @@ -1198,7 +1197,7 @@ mod test { exponential_decay_with_algebraic_problem, exponential_decay_with_algebraic_problem_sens, }, - foodweb::{foodweb_problem, FoodWebContext}, + foodweb::foodweb_problem, gaussian_decay::gaussian_decay_problem, heat2d::head2d_problem, robertson::{robertson, robertson_sens}, @@ -1251,7 +1250,6 @@ mod test { let (problem, soln) = exponential_decay_problem::(false); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 11 number_of_steps: 47 number_of_error_test_failures: 0 @@ -1259,7 +1257,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 84 number_of_jac_muls: 2 number_of_matrix_evals: 1 @@ -1289,7 +1286,6 @@ mod test { let (problem, soln) = exponential_decay_problem::(false); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 11 number_of_steps: 47 number_of_error_test_failures: 0 @@ -1297,7 +1293,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 84 number_of_jac_muls: 2 number_of_matrix_evals: 1 @@ -1311,7 +1306,6 @@ mod test { let (problem, soln) = exponential_decay_problem_sens::(false); test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 11 number_of_steps: 44 number_of_error_test_failures: 0 @@ -1319,7 +1313,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 87 number_of_jac_muls: 136 number_of_matrix_evals: 1 @@ -1333,14 +1326,12 @@ mod test { let (problem, soln) = exponential_decay_problem_adjoint::(); let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" - --- number_of_calls: 84 number_of_jac_muls: 6 number_of_matrix_evals: 3 number_of_jac_adj_muls: 492 "###); insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" - --- number_of_linear_solver_setups: 24 number_of_steps: 86 number_of_error_test_failures: 12 @@ -1355,14 +1346,12 @@ mod test { let (problem, soln) = exponential_decay_with_algebraic_adjoint_problem::(); let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" - --- number_of_calls: 190 number_of_jac_muls: 24 number_of_matrix_evals: 8 number_of_jac_adj_muls: 278 "###); insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" - --- number_of_linear_solver_setups: 32 number_of_steps: 74 number_of_error_test_failures: 15 @@ -1377,7 +1366,6 @@ mod test { let (problem, soln) = exponential_decay_with_algebraic_problem::(false); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 20 number_of_steps: 41 number_of_error_test_failures: 4 @@ -1385,7 +1373,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 83 number_of_jac_muls: 6 number_of_matrix_evals: 2 @@ -1408,7 +1395,6 @@ mod test { let (problem, soln) = exponential_decay_with_algebraic_problem_sens::(); test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 18 number_of_steps: 43 number_of_error_test_failures: 3 @@ -1416,7 +1402,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 71 number_of_jac_muls: 100 number_of_matrix_evals: 3 @@ -1430,7 +1415,6 @@ mod test { let (problem, soln) = robertson::(false); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 77 number_of_steps: 316 number_of_error_test_failures: 3 @@ -1438,7 +1422,6 @@ mod test { number_of_nonlinear_solver_fails: 19 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 725 number_of_jac_muls: 60 number_of_matrix_evals: 20 @@ -1471,10 +1454,8 @@ mod test { use diffsl::LlvmModule; use crate::ode_solver::test_models::robertson; - let mut context = crate::DiffSlContext::default(); let mut s = Bdf::default(); - robertson::robertson_diffsl_compile(&mut context); - let (problem, soln) = robertson::robertson_diffsl_problem::(&context, false); + let (problem, soln) = robertson::robertson_diffsl_problem::(); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } @@ -1484,7 +1465,6 @@ mod test { let (problem, soln) = robertson_sens::(); test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 160 number_of_steps: 410 number_of_error_test_failures: 4 @@ -1492,7 +1472,6 @@ mod test { number_of_nonlinear_solver_fails: 81 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 996 number_of_jac_muls: 2495 number_of_matrix_evals: 71 @@ -1506,7 +1485,6 @@ mod test { let (problem, soln) = robertson::(true); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 77 number_of_steps: 316 number_of_error_test_failures: 3 @@ -1514,7 +1492,6 @@ mod test { number_of_nonlinear_solver_fails: 19 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 725 number_of_jac_muls: 63 number_of_matrix_evals: 20 @@ -1528,7 +1505,6 @@ mod test { let (problem, soln) = robertson_ode::(false, 3); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 86 number_of_steps: 416 number_of_error_test_failures: 1 @@ -1536,7 +1512,6 @@ mod test { number_of_nonlinear_solver_fails: 15 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 913 number_of_jac_muls: 162 number_of_matrix_evals: 18 @@ -1550,7 +1525,6 @@ mod test { let (problem, soln) = robertson_ode_with_sens::(false); test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 112 number_of_steps: 467 number_of_error_test_failures: 2 @@ -1558,7 +1532,6 @@ mod test { number_of_nonlinear_solver_fails: 49 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 1041 number_of_jac_muls: 2672 number_of_matrix_evals: 45 @@ -1572,7 +1545,6 @@ mod test { let (problem, soln) = dydt_y2_problem::(false, 10); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 27 number_of_steps: 161 number_of_error_test_failures: 0 @@ -1580,7 +1552,6 @@ mod test { number_of_nonlinear_solver_fails: 3 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 357 number_of_jac_muls: 50 number_of_matrix_evals: 5 @@ -1594,7 +1565,6 @@ mod test { let (problem, soln) = dydt_y2_problem::(true, 10); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 27 number_of_steps: 161 number_of_error_test_failures: 0 @@ -1602,7 +1572,6 @@ mod test { number_of_nonlinear_solver_fails: 3 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 357 number_of_jac_muls: 15 number_of_matrix_evals: 5 @@ -1616,7 +1585,6 @@ mod test { let (problem, soln) = gaussian_decay_problem::(false, 10); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 14 number_of_steps: 66 number_of_error_test_failures: 1 @@ -1624,7 +1592,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 132 number_of_jac_muls: 20 number_of_matrix_evals: 2 @@ -1640,7 +1607,6 @@ mod test { let (problem, soln) = head2d_problem::, 10>(); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 21 number_of_steps: 167 number_of_error_test_failures: 0 @@ -1648,7 +1614,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 333 number_of_jac_muls: 128 number_of_matrix_evals: 4 @@ -1661,26 +1626,22 @@ mod test { fn test_bdf_faer_sparse_heat2d_diffsl() { use diffsl::LlvmModule; - use crate::ode_solver::test_models::heat2d::{self, heat2d_diffsl_compile}; + use crate::ode_solver::test_models::heat2d; let linear_solver = FaerSparseLU::default(); let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); - let mut context = crate::DiffSlContext::default(); let mut s = Bdf::, _, _>::new(nonlinear_solver); - heat2d_diffsl_compile::, LlvmModule, 10>(&mut context); - let (problem, soln) = heat2d::heat2d_diffsl_problem(&context); + let (problem, soln) = heat2d::heat2d_diffsl_problem::, LlvmModule, 10>(); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } #[test] fn test_bdf_faer_sparse_foodweb() { - let foodweb_context = FoodWebContext::default(); let linear_solver = FaerSparseLU::default(); let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); let mut s = Bdf::, _, _>::new(nonlinear_solver); - let (problem, soln) = foodweb_problem::, 10>(&foodweb_context); + let (problem, soln) = foodweb_problem::, 10>(); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 45 number_of_steps: 161 number_of_error_test_failures: 2 @@ -1695,12 +1656,11 @@ mod test { use diffsl::LlvmModule; use crate::ode_solver::test_models::foodweb; - let mut context = crate::DiffSlContext::default(); let linear_solver = FaerSparseLU::default(); let nonlinear_solver = NewtonNonlinearSolver::new(linear_solver); let mut s = Bdf::, _, _>::new(nonlinear_solver); - foodweb::foodweb_diffsl_compile::, LlvmModule, 10>(&mut context); - let (problem, soln) = foodweb::foodweb_diffsl_problem(&context); + let (problem, soln) = + foodweb::foodweb_diffsl_problem::, LlvmModule, 10>(); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); } diff --git a/src/ode_solver/builder.rs b/src/ode_solver/builder.rs index 6236b6b..f677bc1 100644 --- a/src/ode_solver/builder.rs +++ b/src/ode_solver/builder.rs @@ -225,7 +225,7 @@ impl OdeBuilder { nstates: usize, nout: Option, nparam: usize, - ) -> Result<(V, Option, Option, Option), DiffsolError> { + ) -> Result<(Rc, Option>, Option>, Option>), DiffsolError> { let atol = Self::build_atol(atol, nstates, "states")?; let out_atol = match out_atol { Some(out_atol) => Some(Self::build_atol(out_atol, nout.unwrap_or(0), "output")?), @@ -239,7 +239,12 @@ impl OdeBuilder { Some(sens_atol) => Some(Self::build_atol(sens_atol, nstates, "sensitivity")?), None => None, }; - Ok((atol, sens_atol, out_atol, param_atol)) + Ok(( + Rc::new(atol), + sens_atol.map(Rc::new), + out_atol.map(Rc::new), + param_atol.map(Rc::new), + )) } fn build_p(p: Vec) -> V { @@ -322,9 +327,7 @@ impl OdeBuilder { rhs.calculate_sparsity(&y0, t0); mass.calculate_sparsity(t0); } - let mass = Some(Rc::new(mass)); - let rhs = Rc::new(rhs); - let init = Rc::new(init); + let mass = Some(mass); let nparams = p.len(); let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( self.atol, @@ -337,7 +340,7 @@ impl OdeBuilder { )?; let eqn = OdeSolverEquations::new(rhs, mass, None, init, None, p); OdeSolverProblem::new( - eqn, + Rc::new(eqn), M::T::from(self.rtol), atol, self.sens_rtol.map(M::T::from), @@ -398,10 +401,8 @@ impl OdeBuilder { rhs.calculate_sparsity(&y0, t0); mass.calculate_sparsity(t0); } - let mass = Some(Rc::new(mass)); - let rhs = Rc::new(rhs); - let init = Rc::new(init); - let out = Some(Rc::new(out)); + let mass = Some(mass); + let out = Some(out); let eqn = OdeSolverEquations::new(rhs, mass, None, init, out, p); let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( self.atol, @@ -413,7 +414,7 @@ impl OdeBuilder { nparams, )?; OdeSolverProblem::new( - eqn, + Rc::new(eqn), M::T::from(self.rtol), atol, self.sens_rtol.map(M::T::from), @@ -484,8 +485,6 @@ impl OdeBuilder { if self.use_coloring || M::is_sparse() { rhs.calculate_sparsity(&y0, t0); } - let rhs = Rc::new(rhs); - let init = Rc::new(init); let nparams = p.len(); let eqn = OdeSolverEquations::new(rhs, None, None, init, None, p); let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( @@ -498,7 +497,7 @@ impl OdeBuilder { nparams, )?; OdeSolverProblem::new( - eqn, + Rc::new(eqn), M::T::from(self.rtol), atol, self.sens_rtol.map(M::T::from), @@ -574,8 +573,6 @@ impl OdeBuilder { rhs.calculate_jacobian_sparsity(&y0, t0); rhs.calculate_sens_sparsity(&y0, t0); } - let rhs = Rc::new(rhs); - let init = Rc::new(init); let nparams = p.len(); let eqn = OdeSolverEquations::new(rhs, None, None, init, None, p); let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( @@ -588,7 +585,7 @@ impl OdeBuilder { nparams, )?; OdeSolverProblem::new( - eqn, + Rc::new(eqn), M::T::from(self.rtol), atol, self.sens_rtol.map(M::T::from), @@ -671,13 +668,11 @@ impl OdeBuilder { let y0 = init(&p, t0); let nstates = y0.len(); let mut rhs = Closure::new(rhs, rhs_jac, nstates, nstates, p.clone()); - let root = Rc::new(ClosureNoJac::new(root, nstates, nroots, p.clone())); + let root = ClosureNoJac::new(root, nstates, nroots, p.clone()); let init = ConstantClosure::new(init, p.clone()); if self.use_coloring || M::is_sparse() { rhs.calculate_sparsity(&y0, t0); } - let rhs = Rc::new(rhs); - let init = Rc::new(init); let nparams = p.len(); let eqn = OdeSolverEquations::new(rhs, None, Some(root), init, None, p); let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( @@ -690,7 +685,7 @@ impl OdeBuilder { nparams, )?; OdeSolverProblem::new( - eqn, + Rc::new(eqn), M::T::from(self.rtol), atol, self.sens_rtol.map(M::T::from), @@ -726,7 +721,7 @@ impl OdeBuilder { } /// Build an ODE problem from a set of equations - pub fn build_from_eqn(self, eqn: Eqn) -> Result, DiffsolError> + pub fn build_from_eqn(self, mut eqn: Eqn) -> Result, DiffsolError> where Eqn: OdeEquations, { @@ -742,8 +737,10 @@ impl OdeBuilder { nout, nparams, )?; + let p = Rc::new(Self::build_p(self.p)); + eqn.set_params(p); OdeSolverProblem::new( - eqn, + Rc::new(eqn), Eqn::T::from(self.rtol), atol, self.sens_rtol.map(Eqn::T::from), @@ -757,47 +754,4 @@ impl OdeBuilder { self.integrate_out, ) } - - /// Build an ODE problem using the DiffSL language (requires either the `diffsl-cranelift` or `diffls-llvm` features). - /// The source code is provided as a string, please see the [DiffSL documentation](https://martinjrobins.github.io/diffsl/) for more information. - #[cfg(feature = "diffsl")] - pub fn build_diffsl( - self, - context: &crate::ode_solver::diffsl::DiffSlContext, - ) -> Result>, DiffsolError> - where - M: Matrix, - CG: diffsl::execution::module::CodegenModule, - { - use crate::ode_solver::diffsl; - let p = Self::build_p::(self.p); - let nparams = p.len(); - let mut eqn = diffsl::DiffSl::new(context, self.use_coloring || M::is_sparse()); - let nstates = eqn.rhs().nstates(); - let nout = eqn.out().map(|out| out.nout()); - eqn.set_params(p); - let (atol, sens_atol, out_atol, param_atol) = Self::build_atols( - self.atol, - self.sens_atol, - self.out_atol, - self.param_atol, - nstates, - nout, - nparams, - )?; - OdeSolverProblem::new( - eqn, - self.rtol, - atol, - self.sens_rtol.map(M::T::from), - sens_atol, - self.out_rtol.map(M::T::from), - out_atol, - self.param_rtol.map(M::T::from), - param_atol, - self.t0, - self.h0, - self.integrate_out, - ) - } } diff --git a/src/ode_solver/diffsl.rs b/src/ode_solver/diffsl.rs index f9e5c96..e42d77d 100644 --- a/src/ode_solver/diffsl.rs +++ b/src/ode_solver/diffsl.rs @@ -6,7 +6,7 @@ use crate::{ error::DiffsolError, find_jacobian_non_zeros, find_matrix_non_zeros, jacobian::JacobianColoring, matrix::sparsity::MatrixSparsity, op::nonlinear_op::NonLinearOpJacobian, ConstantOp, LinearOp, Matrix, NonLinearOp, OdeEquations, - Op, Vector, + OdeEquationsRef, Op, Vector, }; pub type T = f64; @@ -21,7 +21,7 @@ pub type T = f64; /// # Example /// /// ```rust -/// use diffsol::{OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod, DiffSlContext, diffsl::LlvmModule}; +/// use diffsol::{OdeBuilder, Bdf, OdeSolverState, OdeSolverMethod, DiffSlContext, DiffSl, diffsl::LlvmModule}; /// /// // dy/dt = -ay /// // y(0) = 1 @@ -32,10 +32,11 @@ pub type T = f64; /// F { -a*u } /// out { u } /// ").unwrap(); +/// let eqn = DiffSl::from_context(context); /// let problem = OdeBuilder::new() /// .rtol(1e-6) /// .p([0.1]) -/// .build_diffsl(&context).unwrap(); +/// .build_from_eqn(eqn).unwrap(); /// let mut solver = Bdf::default(); /// let t = 0.4; /// let state = OdeSolverState::new(&problem, &solver).unwrap(); @@ -58,7 +59,7 @@ pub struct DiffSlContext, CG: CodegenModule> { impl, CG: CodegenModule> DiffSlContext { /// Create a new context for the ODE equations specified using the [DiffSL language](https://martinjrobins.github.io/diffsl/). - /// The input parameters are not initialized and must be set using the [OdeEquations::set_params] function before solving the ODE. + /// The input parameters are not initialized and must be set using the [Op::set_params] function before solving the ODE. pub fn new(text: &str) -> Result { let compiler = Compiler::from_discrete_str(text).map_err(|e| DiffsolError::Other(e.to_string()))?; @@ -107,120 +108,54 @@ impl, CG: CodegenModule> Default for DiffSlContext { } } -pub struct DiffSl<'a, M: Matrix, CG: CodegenModule> { - context: &'a DiffSlContext, - rhs: Rc>, - mass: Option>>, - root: Rc>, - init: Rc>, - out: Rc>, +pub struct DiffSl, CG: CodegenModule> { + context: DiffSlContext, + mass_sparsity: Option, + mass_coloring: Option>, + rhs_sparsity: Option, + rhs_coloring: Option>, } -impl<'a, M: Matrix, CG: CodegenModule> DiffSl<'a, M, CG> { - pub fn new(context: &'a DiffSlContext, use_coloring: bool) -> Self { - let rhs = Rc::new(DiffSlRhs::new(context, use_coloring)); - let mass = DiffSlMass::new(context, use_coloring).map(Rc::new); - let root = Rc::new(DiffSlRoot::new(context)); - let init = Rc::new(DiffSlInit::new(context)); - let out = Rc::new(DiffSlOut::new(context)); - Self { - context, - rhs, - mass, - root, - init, - out, - } - } -} - -pub struct DiffSlRoot<'a, M: Matrix, CG: CodegenModule> { - context: &'a DiffSlContext, -} - -pub struct DiffSlOut<'a, M: Matrix, CG: CodegenModule> { - context: &'a DiffSlContext, -} - -pub struct DiffSlRhs<'a, M: Matrix, CG: CodegenModule> { - context: &'a DiffSlContext, - coloring: Option>, - sparsity: Option, -} - -pub struct DiffSlMass<'a, M: Matrix, CG: CodegenModule> { - context: &'a DiffSlContext, - coloring: Option>, - sparsity: Option, -} - -pub struct DiffSlInit<'a, M: Matrix, CG: CodegenModule> { - context: &'a DiffSlContext, -} - -impl<'a, M: Matrix, CG: CodegenModule> DiffSlOut<'a, M, CG> { - pub fn new(context: &'a DiffSlContext) -> Self { - Self { context } - } -} - -impl<'a, M: Matrix, CG: CodegenModule> DiffSlRoot<'a, M, CG> { - pub fn new(context: &'a DiffSlContext) -> Self { - Self { context } - } -} - -impl<'a, M: Matrix, CG: CodegenModule> DiffSlInit<'a, M, CG> { - pub fn new(context: &'a DiffSlContext) -> Self { - Self { context } - } -} - -impl<'a, M: Matrix, CG: CodegenModule> DiffSlRhs<'a, M, CG> { - pub fn new(context: &'a DiffSlContext, use_coloring: bool) -> Self { +impl, CG: CodegenModule> DiffSl { + pub fn from_context(context: DiffSlContext) -> Self { let mut ret = Self { context, - coloring: None, - sparsity: None, + mass_coloring: None, + mass_sparsity: None, + rhs_coloring: None, + rhs_sparsity: None, }; - - if use_coloring { - let x0 = M::V::zeros(context.nstates); + if M::is_sparse() { + let op = ret.rhs(); let t0 = 0.0; - let non_zeros = find_jacobian_non_zeros(&ret, &x0, t0); - ret.sparsity = Some( - MatrixSparsity::try_from_indices(ret.nout(), ret.nstates(), non_zeros.clone()) - .expect("invalid sparsity pattern"), - ); - ret.coloring = Some(JacobianColoring::new_from_non_zeros(&ret, non_zeros)); + let x0 = M::V::zeros(op.nstates()); + let non_zeros = find_jacobian_non_zeros(&op, &x0, t0); + let sparsity = + M::Sparsity::try_from_indices(op.nout(), op.nstates(), non_zeros.clone()) + .expect("invalid sparsity pattern"); + let coloring = JacobianColoring::new(&sparsity, &non_zeros); + ret.rhs_coloring = Some(coloring); + ret.rhs_sparsity = Some(sparsity); + + if let Some(op) = ret.mass() { + let non_zeros = find_matrix_non_zeros(&op, t0); + let sparsity = + M::Sparsity::try_from_indices(op.nout(), op.nstates(), non_zeros.clone()) + .expect("invalid sparsity pattern"); + let coloring = JacobianColoring::new(&sparsity, &non_zeros); + ret.mass_coloring = Some(coloring); + ret.mass_sparsity = Some(sparsity); + } } ret } } -impl<'a, M: Matrix, CG: CodegenModule> DiffSlMass<'a, M, CG> { - pub fn new(context: &'a DiffSlContext, use_coloring: bool) -> Option { - if !context.compiler.has_mass() { - return None; - } - let mut ret = Self { - context, - coloring: None, - sparsity: None, - }; - - if use_coloring { - let t0 = 0.0; - let non_zeros = find_matrix_non_zeros(&ret, t0); - ret.sparsity = Some( - MatrixSparsity::try_from_indices(ret.nout(), ret.nstates(), non_zeros.clone()) - .expect("invalid sparsity pattern"), - ); - ret.coloring = Some(JacobianColoring::new_from_non_zeros(&ret, non_zeros)); - } - Some(ret) - } -} +pub struct DiffSlRoot<'a, M: Matrix, CG: CodegenModule>(&'a DiffSl); +pub struct DiffSlOut<'a, M: Matrix, CG: CodegenModule>(&'a DiffSl); +pub struct DiffSlRhs<'a, M: Matrix, CG: CodegenModule>(&'a DiffSl); +pub struct DiffSlMass<'a, M: Matrix, CG: CodegenModule>(&'a DiffSl); +pub struct DiffSlInit<'a, M: Matrix, CG: CodegenModule>(&'a DiffSl); macro_rules! impl_op_for_diffsl { ($name:ident) => { @@ -230,17 +165,14 @@ macro_rules! impl_op_for_diffsl { type V = M::V; fn nstates(&self) -> usize { - self.context.nstates + self.0.context.nstates } #[allow(clippy::misnamed_getters)] fn nout(&self) -> usize { - self.context.nstates + self.0.context.nstates } fn nparams(&self) -> usize { - self.context.nparams - } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.sparsity.as_ref().map(|s| s.as_ref()) + self.0.context.nparams } } }; @@ -255,14 +187,14 @@ impl, CG: CodegenModule> Op for DiffSlInit<'_, M, CG> { type V = M::V; fn nstates(&self) -> usize { - self.context.nstates + self.0.context.nstates } #[allow(clippy::misnamed_getters)] fn nout(&self) -> usize { - self.context.nstates + self.0.context.nstates } fn nparams(&self) -> usize { - self.context.nparams + self.0.context.nparams } } @@ -272,14 +204,14 @@ impl, CG: CodegenModule> Op for DiffSlRoot<'_, M, CG> { type V = M::V; fn nstates(&self) -> usize { - self.context.nstates + self.0.context.nstates } #[allow(clippy::misnamed_getters)] fn nout(&self) -> usize { - self.context.nroots + self.0.context.nroots } fn nparams(&self) -> usize { - self.context.nparams + self.0.context.nparams } } @@ -289,31 +221,31 @@ impl, CG: CodegenModule> Op for DiffSlOut<'_, M, CG> { type V = M::V; fn nstates(&self) -> usize { - self.context.nstates + self.0.context.nstates } fn nout(&self) -> usize { - self.context.nout + self.0.context.nout } fn nparams(&self) -> usize { - self.context.nparams + self.0.context.nparams } } impl, CG: CodegenModule> ConstantOp for DiffSlInit<'_, M, CG> { fn call_inplace(&self, _t: Self::T, y: &mut Self::V) { - self.context.compiler.set_u0( + self.0.context.compiler.set_u0( y.as_mut_slice(), - self.context.data.borrow_mut().as_mut_slice(), + self.0.context.data.borrow_mut().as_mut_slice(), ); } } impl, CG: CodegenModule> NonLinearOp for DiffSlRoot<'_, M, CG> { fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) { - self.context.compiler.calc_stop( + self.0.context.compiler.calc_stop( t, x.as_slice(), - self.context.data.borrow_mut().as_mut_slice(), + self.0.context.data.borrow_mut().as_mut_slice(), y.as_mut_slice(), ); } @@ -327,42 +259,44 @@ impl, CG: CodegenModule> NonLinearOpJacobian for DiffSlRoot<'_, impl, CG: CodegenModule> NonLinearOp for DiffSlOut<'_, M, CG> { fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) { - self.context.compiler.calc_out( + self.0.context.compiler.calc_out( t, x.as_slice(), - self.context.data.borrow_mut().as_mut_slice(), + self.0.context.data.borrow_mut().as_mut_slice(), ); let out = self + .0 .context .compiler - .get_out(self.context.data.borrow().as_slice()); + .get_out(self.0.context.data.borrow().as_slice()); y.copy_from_slice(out); } } impl, CG: CodegenModule> NonLinearOpJacobian for DiffSlOut<'_, M, CG> { fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { - self.context.compiler.calc_out_grad( + self.0.context.compiler.calc_out_grad( t, x.as_slice(), v.as_slice(), - self.context.data.borrow_mut().as_mut_slice(), - self.context.ddata.borrow_mut().as_mut_slice(), + self.0.context.data.borrow_mut().as_mut_slice(), + self.0.context.ddata.borrow_mut().as_mut_slice(), ); let out_grad = self + .0 .context .compiler - .get_out(self.context.ddata.borrow().as_slice()); + .get_out(self.0.context.ddata.borrow().as_slice()); y.copy_from_slice(out_grad); } } impl, CG: CodegenModule> NonLinearOp for DiffSlRhs<'_, M, CG> { fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) { - self.context.compiler.rhs( + self.0.context.compiler.rhs( t, x.as_slice(), - self.context.data.borrow_mut().as_mut_slice(), + self.0.context.data.borrow_mut().as_mut_slice(), y.as_mut_slice(), ); } @@ -371,33 +305,36 @@ impl, CG: CodegenModule> NonLinearOp for DiffSlRhs<'_, M, CG> { impl, CG: CodegenModule> NonLinearOpJacobian for DiffSlRhs<'_, M, CG> { fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { let mut dummy_rhs = Self::V::zeros(self.nstates()); - self.context.compiler.rhs_grad( + self.0.context.compiler.rhs_grad( t, x.as_slice(), v.as_slice(), - self.context.data.borrow_mut().as_mut_slice(), - self.context.ddata.borrow_mut().as_mut_slice(), + self.0.context.data.borrow_mut().as_mut_slice(), + self.0.context.ddata.borrow_mut().as_mut_slice(), dummy_rhs.as_mut_slice(), y.as_mut_slice(), ); } fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - if let Some(coloring) = &self.coloring { + if let Some(coloring) = &self.0.rhs_coloring { coloring.jacobian_inplace(self, x, t, y); } else { self._default_jacobian_inplace(x, t, y); } } + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + self.0.rhs_sparsity.clone() + } } impl, CG: CodegenModule> LinearOp for DiffSlMass<'_, M, CG> { fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) { - let mut tmp = self.context.tmp.borrow_mut(); - self.context.compiler.mass( + let mut tmp = self.0.context.tmp.borrow_mut(); + self.0.context.compiler.mass( t, x.as_slice(), - self.context.data.borrow_mut().as_mut_slice(), + self.0.context.data.borrow_mut().as_mut_slice(), tmp.as_mut_slice(), ); @@ -406,37 +343,32 @@ impl, CG: CodegenModule> LinearOp for DiffSlMass<'_, M, CG> { } fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) { - if let Some(coloring) = &self.coloring { + if let Some(coloring) = &self.0.mass_coloring { coloring.matrix_inplace(self, t, y); } else { self._default_matrix_inplace(t, y); } } + fn sparsity(&self) -> Option<::Sparsity> { + self.0.mass_sparsity.clone() + } } -impl<'a, M: Matrix, CG: CodegenModule> OdeEquations for DiffSl<'a, M, CG> { +impl, CG: CodegenModule> Op for DiffSl { type M = M; type T = T; type V = M::V; - type Mass = DiffSlMass<'a, M, CG>; - type Rhs = DiffSlRhs<'a, M, CG>; - type Root = DiffSlRoot<'a, M, CG>; - type Init = DiffSlInit<'a, M, CG>; - type Out = DiffSlOut<'a, M, CG>; - fn rhs(&self) -> &Rc { - &self.rhs + fn nstates(&self) -> usize { + self.context.nstates } - - fn mass(&self) -> Option<&Rc> { - self.mass.as_ref() + fn nout(&self) -> usize { + self.context.nout } - - fn root(&self) -> Option<&Rc> { - Some(&self.root) + fn nparams(&self) -> usize { + self.context.nparams } - - fn set_params(&mut self, p: Self::V) { + fn set_params(&mut self, p: Rc) { // set the parameters in data self.context .compiler @@ -449,24 +381,48 @@ impl<'a, M: Matrix, CG: CodegenModule> OdeEquations for DiffSl<'a, M, CG> self.context.data.borrow_mut().as_mut_slice(), ); } +} + +impl<'a, M: Matrix, CG: CodegenModule> OdeEquationsRef<'a> for DiffSl { + type Mass = DiffSlMass<'a, M, CG>; + type Rhs = DiffSlRhs<'a, M, CG>; + type Root = DiffSlRoot<'a, M, CG>; + type Init = DiffSlInit<'a, M, CG>; + type Out = DiffSlOut<'a, M, CG>; +} + +impl, CG: CodegenModule> OdeEquations for DiffSl { + fn rhs(&self) -> DiffSlRhs<'_, M, CG> { + DiffSlRhs(self) + } - fn init(&self) -> &Rc { - &self.init + fn mass(&self) -> Option> { + self.context.compiler.has_mass().then_some(DiffSlMass(self)) } - fn out(&self) -> Option<&Rc> { - Some(&self.out) + fn root(&self) -> Option> { + Some(DiffSlRoot(self)) + } + + fn init(&self) -> DiffSlInit<'_, M, CG> { + DiffSlInit(self) + } + + fn out(&self) -> Option> { + Some(DiffSlOut(self)) } } #[cfg(test)] mod tests { + use std::rc::Rc; + use diffsl::{execution::module::CodegenModule, CraneliftModule}; use nalgebra::DVector; use crate::{ Bdf, ConstantOp, LinearOp, NonLinearOp, NonLinearOpJacobian, OdeBuilder, OdeEquations, - OdeSolverMethod, OdeSolverState, Vector, + OdeSolverMethod, OdeSolverState, Op, Vector, }; use super::{DiffSl, DiffSlContext}; @@ -512,9 +468,9 @@ mod tests { let k = 1.0; let r = 1.0; let context = DiffSlContext::, CG>::new(text).unwrap(); - let mut eqn = DiffSl::new(&context, false); let p = DVector::from_vec(vec![r, k]); - eqn.set_params(p); + let mut eqn = DiffSl::from_context(context); + eqn.set_params(Rc::new(p)); // test that the initial values look ok let y0 = 0.1; @@ -535,7 +491,7 @@ mod tests { mass_y.assert_eq_st(&mass_y_expect, 1e-10); // solver a bit and check the state and output - let problem = OdeBuilder::new().p([r, k]).build_diffsl(&context).unwrap(); + let problem = OdeBuilder::new().p([r, k]).build_from_eqn(eqn).unwrap(); let mut solver = Bdf::default(); let t = 1.0; let state = OdeSolverState::new(&problem, &solver).unwrap(); diff --git a/src/ode_solver/equations.rs b/src/ode_solver/equations.rs index 0f16b9f..53b3d4c 100644 --- a/src/ode_solver/equations.rs +++ b/src/ode_solver/equations.rs @@ -3,7 +3,7 @@ use std::rc::Rc; use crate::{ op::{constant_op::ConstantOpSensAdjoint, linear_op::LinearOpTranspose}, ConstantOp, ConstantOpSens, LinearOp, Matrix, NonLinearOp, NonLinearOpAdjoint, - NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Scalar, UnitCallable, Vector, + NonLinearOpJacobian, NonLinearOpSens, NonLinearOpSensAdjoint, Op, UnitCallable, }; use serde::Serialize; @@ -49,7 +49,7 @@ pub trait AugmentedOdeEquations: fn out_atol(&self) -> Option<&Rc>; } -pub trait AugmentedOdeEquationsImplicit: +pub trait AugmentedOdeEquationsImplicit: AugmentedOdeEquations + OdeEquationsImplicit { } @@ -57,7 +57,7 @@ pub trait AugmentedOdeEquationsImplicit: impl AugmentedOdeEquationsImplicit for Aug where Aug: AugmentedOdeEquations + OdeEquationsImplicit, - Eqn: OdeEquations, + Eqn: OdeEquationsImplicit, { } @@ -65,42 +65,63 @@ pub struct NoAug { _phantom: std::marker::PhantomData, } -impl OdeEquations for NoAug { +impl Op for NoAug +where + Eqn: OdeEquations, +{ type T = Eqn::T; type V = Eqn::V; type M = Eqn::M; - type Mass = Eqn::Mass; - type Rhs = Eqn::Rhs; - type Root = Eqn::Root; - type Init = Eqn::Init; - type Out = Eqn::Out; - fn set_params(&mut self, _p: Self::V) { + fn nout(&self) -> usize { + panic!("This should never be called") + } + fn nparams(&self) -> usize { + panic!("This should never be called") + } + fn nstates(&self) -> usize { + panic!("This should never be called") + } + fn statistics(&self) -> crate::op::OpStatistics { + panic!("This should never be called") + } + + fn set_params(&mut self, _p: Rc) { panic!("This should never be called") } +} - fn rhs(&self) -> &Rc { +impl<'a, Eqn: OdeEquations> OdeEquationsRef<'a> for NoAug { + type Mass = >::Mass; + type Rhs = >::Rhs; + type Root = >::Root; + type Init = >::Init; + type Out = >::Out; +} + +impl OdeEquations for NoAug { + fn rhs(&self) -> >::Rhs { panic!("This should never be called") } - fn mass(&self) -> Option<&Rc> { + fn mass(&self) -> Option<>::Mass> { panic!("This should never be called") } - fn root(&self) -> Option<&Rc> { + fn root(&self) -> Option<>::Root> { panic!("This should never be called") } - fn out(&self) -> Option<&Rc> { + fn out(&self) -> Option<>::Out> { panic!("This should never be called") } - fn init(&self) -> &Rc { + fn init(&self) -> >::Init { panic!("This should never be called") } } -impl AugmentedOdeEquations for NoAug { +impl AugmentedOdeEquations for NoAug { fn update_rhs_out_state(&mut self, _y: &Eqn::V, _dy: &Eqn::V, _t: Eqn::T) { panic!("This should never be called") } @@ -110,19 +131,19 @@ impl AugmentedOdeEquations for NoAug { fn set_index(&mut self, _index: usize) { panic!("This should never be called") } - fn atol(&self) -> Option<&Rc<::V>> { + fn atol(&self) -> Option<&Rc<::V>> { panic!("This should never be called") } fn include_out_in_error_control(&self) -> bool { panic!("This should never be called") } - fn out_atol(&self) -> Option<&Rc<::V>> { + fn out_atol(&self) -> Option<&Rc<::V>> { panic!("This should never be called") } - fn out_rtol(&self) -> Option<::T> { + fn out_rtol(&self) -> Option<::T> { panic!("This should never be called") } - fn rtol(&self) -> Option<::T> { + fn rtol(&self) -> Option<::T> { panic!("This should never be called") } fn max_index(&self) -> usize { @@ -133,7 +154,8 @@ impl AugmentedOdeEquations for NoAug { } } -/// this is the trait that defines the ODE equations of the form +/// this is the reference trait that defines the ODE equations of the form, this is used to define the ODE equations for a given lifetime. +/// See [OdeEquations] for the main trait that defines the ODE equations. /// /// $$ /// M \frac{dy}{dt} = F(t, y) @@ -141,45 +163,63 @@ impl AugmentedOdeEquations for NoAug { /// $$ /// /// The ODE equations are defined by: -/// - the right-hand side function `F(t, y)`, which is given as a [NonLinearOp] using the `Rhs` associated type and [Self::rhs] function, -/// - the initial condition `y_0(t_0)`, which is given using the [Self::init] function. +/// - the right-hand side function `F(t, y)`, which is given as a [NonLinearOp] using the `Rhs` associated type and [OdeEquations::rhs] function, +/// - the initial condition `y_0(t_0)`, which is given as a [ConstantOp] using the `Init` associated type and [OdeEquations::init] function. /// /// Optionally, the ODE equations can also include: -/// - the mass matrix `M` which is given as a [LinearOp] using the `Mass` associated type and the [Self::mass] function, -/// - the root function `G(t, y)` which is given as a [NonLinearOp] using the `Root` associated type and the [Self::root] function -/// - the output function `H(t, y)` which is given as a [NonLinearOp] using the `Out` associated type and the [Self::out] function -pub trait OdeEquations { - type T: Scalar; - type V: Vector; - type M: Matrix; +/// - the mass matrix `M` which is given as a [LinearOp] using the `Mass` associated type and the [OdeEquations::mass] function, +/// - the root function `G(t, y)` which is given as a [NonLinearOp] using the `Root` associated type and the [OdeEquations::root] function +/// - the output function `H(t, y)` which is given as a [NonLinearOp] using the `Out` associated type and the [OdeEquations::out] function +pub trait OdeEquationsRef<'a, ImplicitBounds: Sealed = Bounds<&'a Self>>: Op { type Mass: LinearOp; type Rhs: NonLinearOp; type Root: NonLinearOp; type Init: ConstantOp; type Out: NonLinearOp; +} - /// The parameters of the ODE equations are assumed to be constant. This function sets the parameters to the given value before solving the ODE. - /// Note that `set_params` must always be called before calling any of the other functions in this trait. - fn set_params(&mut self, p: Self::V); +// seal the trait so that users must use the provided default type for ImplicitBounds +mod sealed { + pub trait Sealed: Sized {} + pub struct Bounds(T); + impl Sealed for Bounds {} +} +use sealed::{Bounds, Sealed}; +/// this is the trait that defines the ODE equations of the form +/// +/// $$ +/// M \frac{dy}{dt} = F(t, y) +/// y(t_0) = y_0(t_0) +/// $$ +/// +/// The ODE equations are defined by: +/// - the right-hand side function `F(t, y)`, which is given as a [NonLinearOp] using the `Rhs` associated type and [OdeEquations::rhs] function, +/// - the initial condition `y_0(t_0)`, which is given as a [ConstantOp] using the `Init` associated type and [OdeEquations::init] function. +/// +/// Optionally, the ODE equations can also include: +/// - the mass matrix `M` which is given as a [LinearOp] using the `Mass` associated type and the [OdeEquations::mass] function, +/// - the root function `G(t, y)` which is given as a [NonLinearOp] using the `Root` associated type and the [OdeEquations::root] function +/// - the output function `H(t, y)` which is given as a [NonLinearOp] using the `Out` associated type and the [OdeEquations::out] function +pub trait OdeEquations: for<'a> OdeEquationsRef<'a> { /// returns the right-hand side function `F(t, y)` as a [NonLinearOp] - fn rhs(&self) -> &Rc; + fn rhs(&self) -> >::Rhs; /// returns the mass matrix `M` as a [LinearOp] - fn mass(&self) -> Option<&Rc>; + fn mass(&self) -> Option<>::Mass>; /// returns the root function `G(t, y)` as a [NonLinearOp] - fn root(&self) -> Option<&Rc> { + fn root(&self) -> Option<>::Root> { None } /// returns the output function `H(t, y)` as a [NonLinearOp] - fn out(&self) -> Option<&Rc> { + fn out(&self) -> Option<>::Out> { None } /// returns the initial condition, i.e. `y(t)`, where `t` is the initial time - fn init(&self) -> &Rc; + fn init(&self) -> >::Init; } pub trait OdeEquationsImplicit: @@ -277,17 +317,17 @@ impl OdeEquationsAdjoint for T where /// } /// /// -/// let rhs = Rc::new(MyProblem); +/// let rhs = MyProblem; /// /// // use the provided constant closure to define the initial condition /// let init_fn = |p: &V, _t: f64| V::from_vec(vec![1.0]); -/// let init = Rc::new(ConstantClosure::new(init_fn, Rc::new(V::from_vec(vec![])))); +/// let init = ConstantClosure::new(init_fn, Rc::new(V::from_vec(vec![]))); /// /// // we don't have a mass matrix, root or output functions, so we can set to None /// // we still need to give a placeholder type for these, so we use the diffsol::UnitCallable type -/// let mass: Option>> = None; -/// let root: Option>> = None; -/// let out: Option>> = None; +/// let mass: Option> = None; +/// let root: Option> = None; +/// let out: Option> = None; /// /// let p = Rc::new(V::from_vec(vec![])); /// let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p); @@ -313,36 +353,26 @@ pub struct OdeSolverEquations< Out = UnitCallable, > where M: Matrix, - Rhs: NonLinearOp, - Mass: LinearOp, - Root: NonLinearOp, - Init: ConstantOp, - Out: NonLinearOp, { - rhs: Rc, - mass: Option>, - root: Option>, - init: Rc, - out: Option>, + rhs: Rhs, + mass: Option, + root: Option, + init: Init, + out: Option, p: Rc, } impl OdeSolverEquations where M: Matrix, - Rhs: NonLinearOp, - Mass: LinearOp, - Root: NonLinearOp, - Init: ConstantOp, - Out: NonLinearOp, { #[allow(clippy::too_many_arguments)] pub fn new( - rhs: Rc, - mass: Option>, - root: Option>, - init: Rc, - out: Option>, + rhs: Rhs, + mass: Option, + root: Option, + init: Init, + out: Option, p: Rc, ) -> Self { Self { @@ -356,7 +386,48 @@ where } } -impl OdeEquations +impl Op for OdeSolverEquations +where + M: Matrix, + Init: Op, + Rhs: Op, + Mass: Op, + Root: Op, + Out: Op, +{ + type T = M::T; + type V = M::V; + type M = M; + fn nstates(&self) -> usize { + self.init.nstates() + } + fn nout(&self) -> usize { + self.rhs.nout() + } + fn nparams(&self) -> usize { + self.rhs.nparams() + } + fn statistics(&self) -> crate::op::OpStatistics { + self.rhs.statistics() + } + fn set_params(&mut self, p: Rc) { + self.rhs.set_params(p.clone()); + self.init.set_params(p.clone()); + if let Some(mass) = self.mass.as_mut() { + mass.set_params(p.clone()); + } + if let Some(root) = self.root.as_mut() { + root.set_params(p.clone()); + } + + if let Some(out) = self.out.as_mut() { + out.set_params(p.clone()); + } + self.p = p; + } +} + +impl<'a, M, Rhs, Init, Mass, Root, Out> OdeEquationsRef<'a> for OdeSolverEquations where M: Matrix, @@ -366,47 +437,39 @@ where Init: ConstantOp, Out: NonLinearOp, { - type T = M::T; - type V = M::V; - type M = M; - type Rhs = Rhs; - type Mass = Mass; - type Root = Root; - type Init = Init; - type Out = Out; + type Rhs = &'a Rhs; + type Mass = &'a Mass; + type Root = &'a Root; + type Init = &'a Init; + type Out = &'a Out; +} - fn rhs(&self) -> &Rc { +impl OdeEquations + for OdeSolverEquations +where + M: Matrix, + Rhs: NonLinearOp, + Mass: LinearOp, + Root: NonLinearOp, + Init: ConstantOp, + Out: NonLinearOp, +{ + fn rhs(&self) -> &Rhs { &self.rhs } - fn mass(&self) -> Option<&Rc> { + fn mass(&self) -> Option<&Mass> { self.mass.as_ref() } - fn root(&self) -> Option<&Rc> { + fn root(&self) -> Option<&Root> { self.root.as_ref() } - fn init(&self) -> &Rc { + fn init(&self) -> &Init { &self.init } - fn out(&self) -> Option<&Rc> { + fn out(&self) -> Option<&Out> { self.out.as_ref() } - - fn set_params(&mut self, p: Self::V) { - self.p = Rc::new(p); - Rc::::get_mut(&mut self.rhs) - .unwrap() - .set_params(self.p.clone()); - if let Some(m) = self.mass.as_mut() { - Rc::::get_mut(m).unwrap().set_params(self.p.clone()); - } - if let Some(r) = self.root.as_mut() { - Rc::::get_mut(r).unwrap().set_params(self.p.clone()) - } - if let Some(o) = self.out.as_mut() { - Rc::::get_mut(o).unwrap().set_params(self.p.clone()) - } - } } #[cfg(test)] diff --git a/src/ode_solver/method.rs b/src/ode_solver/method.rs index 7b231fc..bfab1b4 100644 --- a/src/ode_solver/method.rs +++ b/src/ode_solver/method.rs @@ -6,10 +6,10 @@ use crate::{ matrix::default_solver::DefaultSolver, ode_solver_error, scalar::Scalar, - AdjointContext, AdjointEquations, AugmentedOdeEquations, Checkpointing, DefaultDenseMatrix, - DenseMatrix, Matrix, NewtonNonlinearSolver, NonLinearOp, OdeEquations, OdeEquationsAdjoint, - OdeEquationsSens, OdeSolverProblem, OdeSolverState, Op, SensEquations, StateRef, StateRefMut, - Vector, VectorViewMut, + AdjointContext, AdjointEquations, Checkpointing, DefaultDenseMatrix, DenseMatrix, Matrix, + NewtonNonlinearSolver, NonLinearOp, OdeEquations, OdeEquationsAdjoint, OdeEquationsSens, + OdeSolverProblem, OdeSolverState, Op, SensEquations, StateRef, StateRefMut, Vector, + VectorViewMut, }; use super::checkpointing::HermiteInterpolator; @@ -397,7 +397,6 @@ where pub trait AugmentedOdeSolverMethod: OdeSolverMethod where Eqn: OdeEquations, - AugmentedEqn: AugmentedOdeEquations, { fn set_augmented_problem( &mut self, diff --git a/src/ode_solver/mod.rs b/src/ode_solver/mod.rs index 7e8517c..504a43b 100644 --- a/src/ode_solver/mod.rs +++ b/src/ode_solver/mod.rs @@ -32,11 +32,12 @@ mod tests { use super::*; use crate::matrix::Matrix; use crate::op::unit::UnitCallable; - use crate::{ConstantOp, DefaultDenseMatrix, DefaultSolver, NonLinearOp, Op, Vector}; use crate::{ - NonLinearOpJacobian, OdeEquations, OdeEquationsAdjoint, OdeEquationsImplicit, - OdeEquationsSens, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, + op::OpStatistics, NonLinearOpJacobian, OdeEquations, OdeEquationsAdjoint, + OdeEquationsImplicit, OdeEquationsRef, OdeEquationsSens, OdeSolverMethod, OdeSolverProblem, + OdeSolverState, OdeSolverStopReason, }; + use crate::{ConstantOp, DefaultDenseMatrix, DefaultSolver, NonLinearOp, Op, Vector}; use num_traits::One; use num_traits::Zero; @@ -402,44 +403,60 @@ mod tests { } } - impl OdeEquations for TestEqn { + impl Op for TestEqn { type T = M::T; type V = M::V; type M = M; - type Rhs = TestEqnRhs; - type Mass = UnitCallable; - type Root = UnitCallable; - type Init = TestEqnInit; - type Out = UnitCallable; + fn set_params(&mut self, _p: Rc) {} + fn nout(&self) -> usize { + 1 + } + fn nparams(&self) -> usize { + 0 + } + fn nstates(&self) -> usize { + 1 + } + fn statistics(&self) -> crate::op::OpStatistics { + OpStatistics::default() + } + } - fn set_params(&mut self, _p: Self::V) {} + impl<'a, M: Matrix> OdeEquationsRef<'a> for TestEqn { + type Rhs = &'a TestEqnRhs; + type Mass = &'a UnitCallable; + type Root = &'a UnitCallable; + type Init = &'a TestEqnInit; + type Out = &'a UnitCallable; + } - fn rhs(&self) -> &Rc { + impl OdeEquations for TestEqn { + fn rhs(&self) -> &TestEqnRhs { &self.rhs } - fn mass(&self) -> Option<&Rc> { + fn mass(&self) -> Option<&UnitCallable> { None } - fn root(&self) -> Option<&Rc> { + fn root(&self) -> Option<&UnitCallable> { None } - fn init(&self) -> &Rc { + fn init(&self) -> &TestEqnInit { &self.init } - fn out(&self) -> Option<&Rc> { + fn out(&self) -> Option<&UnitCallable> { None } } pub fn test_interpolate>>(mut s: Method) { let problem = OdeSolverProblem::new( - TestEqn::new(), + Rc::new(TestEqn::new()), M::T::from(1e-6), - M::V::from_element(1, M::T::from(1e-6)), + Rc::new(M::V::from_element(1, M::T::from(1e-6))), None, None, None, @@ -474,9 +491,9 @@ mod tests { pub fn test_state_mut>>(mut s: Method) { let problem = OdeSolverProblem::new( - TestEqn::new(), + Rc::new(TestEqn::new()), M::T::from(1e-6), - M::V::from_element(1, M::T::from(1e-6)), + Rc::new(M::V::from_element(1, M::T::from(1e-6))), None, None, None, diff --git a/src/ode_solver/problem.rs b/src/ode_solver/problem.rs index 40fa74d..45d2128 100644 --- a/src/ode_solver/problem.rs +++ b/src/ode_solver/problem.rs @@ -57,24 +57,19 @@ impl OdeSolverProblem { } #[allow(clippy::too_many_arguments)] pub(crate) fn new( - eqn: Eqn, + eqn: Rc, rtol: Eqn::T, - atol: Eqn::V, + atol: Rc, sens_rtol: Option, - sens_atol: Option, + sens_atol: Option>, out_rtol: Option, - out_atol: Option, + out_atol: Option>, param_rtol: Option, - param_atol: Option, + param_atol: Option>, t0: Eqn::T, h0: Eqn::T, integrate_out: bool, ) -> Result { - let eqn = Rc::new(eqn); - let atol = Rc::new(atol); - let out_atol = out_atol.map(Rc::new); - let param_atol = param_atol.map(Rc::new); - let sens_atol = sens_atol.map(Rc::new); Ok(Self { eqn, rtol, @@ -94,7 +89,7 @@ impl OdeSolverProblem { pub fn set_params(&mut self, p: Eqn::V) -> Result<(), DiffsolError> { let eqn = Rc::get_mut(&mut self.eqn).ok_or(ode_solver_error!(FailedToGetMutableReference))?; - eqn.set_params(p); + eqn.set_params(Rc::new(p)); Ok(()) } } diff --git a/src/ode_solver/sdirk.rs b/src/ode_solver/sdirk.rs index c666fab..38d56e9 100644 --- a/src/ode_solver/sdirk.rs +++ b/src/ode_solver/sdirk.rs @@ -23,7 +23,7 @@ use crate::SensEquations; use crate::Tableau; use crate::{ nonlinear_solver::NonLinearSolver, op::sdirk::SdirkCallable, scale, AdjointOdeSolverMethod, - AugmentedOdeEquations, DenseMatrix, JacobianUpdate, NonLinearOp, OdeEquations, + AugmentedOdeEquations, AugmentedOdeEquationsImplicit, DenseMatrix, JacobianUpdate, NonLinearOp, OdeEquationsAdjoint, OdeEquationsImplicit, OdeEquationsSens, OdeSolverMethod, OdeSolverProblem, OdeSolverState, Op, Scalar, StateRef, StateRefMut, Vector, VectorViewMut, }; @@ -166,7 +166,7 @@ where LS: LinearSolver, M: DenseMatrix, Eqn: OdeEquationsImplicit, - AugmentedEqn: AugmentedOdeEquations, + AugmentedEqn: AugmentedOdeEquationsImplicit, for<'a> &'a Eqn::V: VectorRef, for<'a> &'a Eqn::M: MatrixRef, { @@ -451,7 +451,7 @@ where LS: LinearSolver, M: DenseMatrix, Eqn: OdeEquationsImplicit, - AugmentedEqn: AugmentedOdeEquations, + AugmentedEqn: AugmentedOdeEquationsImplicit, for<'a> &'a Eqn::V: VectorRef, for<'a> &'a Eqn::M: MatrixRef, { @@ -533,7 +533,7 @@ where self.root_finder .as_ref() .unwrap() - .init(root_fn.as_ref(), &state.y, state.t); + .init(&root_fn, &state.y, state.t); } Ok(()) } @@ -828,7 +828,7 @@ where if let Some(root_fn) = self.problem.as_ref().unwrap().eqn.root() { let ret = self.root_finder.as_ref().unwrap().check_root( &|t| self.interpolate(t), - root_fn.as_ref(), + &root_fn, &self.state.as_ref().unwrap().y, self.state.as_ref().unwrap().t, ); @@ -848,7 +848,7 @@ where Ok(OdeSolverStopReason::InternalTimestep) } - fn set_stop_time(&mut self, tstop: ::T) -> Result<(), DiffsolError> { + fn set_stop_time(&mut self, tstop: ::T) -> Result<(), DiffsolError> { self.tstop = Some(tstop); if let Some(OdeSolverStopReason::TstopReached) = self.handle_tstop(tstop)? { let error = OdeSolverError::StopTimeBeforeCurrentTime { @@ -861,10 +861,7 @@ where Ok(()) } - fn interpolate_sens( - &self, - t: ::T, - ) -> Result::V>, DiffsolError> { + fn interpolate_sens(&self, t: ::T) -> Result::V>, DiffsolError> { if self.state.is_none() { return Err(ode_solver_error!(StateNotSet)); } @@ -1007,7 +1004,7 @@ where LS: LinearSolver, M: DenseMatrix, Eqn: OdeEquationsImplicit, - AugmentedEqn: AugmentedOdeEquations, + AugmentedEqn: AugmentedOdeEquationsImplicit, for<'a> &'a Eqn::V: VectorRef, for<'a> &'a Eqn::M: MatrixRef, { @@ -1130,7 +1127,6 @@ mod test { let (problem, soln) = exponential_decay_problem::(false); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 4 number_of_steps: 29 number_of_error_test_failures: 0 @@ -1138,7 +1134,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 118 number_of_jac_muls: 2 number_of_matrix_evals: 1 @@ -1152,7 +1147,6 @@ mod test { let (problem, soln) = exponential_decay_problem_sens::(false); test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 7 number_of_steps: 52 number_of_error_test_failures: 0 @@ -1160,7 +1154,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 210 number_of_jac_muls: 318 number_of_matrix_evals: 2 @@ -1174,7 +1167,6 @@ mod test { let (problem, soln) = exponential_decay_problem::(false); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 3 number_of_steps: 13 number_of_error_test_failures: 0 @@ -1182,7 +1174,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 86 number_of_jac_muls: 2 number_of_matrix_evals: 1 @@ -1196,7 +1187,6 @@ mod test { let (problem, soln) = exponential_decay_problem_sens::(false); test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 5 number_of_steps: 20 number_of_error_test_failures: 0 @@ -1204,7 +1194,6 @@ mod test { number_of_nonlinear_solver_fails: 0 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 122 number_of_jac_muls: 201 number_of_matrix_evals: 1 @@ -1218,14 +1207,12 @@ mod test { let (problem, soln) = exponential_decay_problem_adjoint::(); let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" - --- number_of_calls: 196 number_of_jac_muls: 6 number_of_matrix_evals: 3 number_of_jac_adj_muls: 599 "###); insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" - --- number_of_linear_solver_setups: 18 number_of_steps: 29 number_of_error_test_failures: 10 @@ -1240,14 +1227,12 @@ mod test { let (problem, soln) = exponential_decay_with_algebraic_adjoint_problem::(); let adjoint_solver = test_ode_solver_adjoint(s, &problem, soln); insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###" - --- number_of_calls: 171 number_of_jac_muls: 12 number_of_matrix_evals: 4 number_of_jac_adj_muls: 287 "###); insta::assert_yaml_snapshot!(adjoint_solver.get_statistics(), @r###" - --- number_of_linear_solver_setups: 18 number_of_steps: 20 number_of_error_test_failures: 11 @@ -1262,7 +1247,6 @@ mod test { let (problem, soln) = robertson::(false); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 97 number_of_steps: 232 number_of_error_test_failures: 0 @@ -1270,7 +1254,6 @@ mod test { number_of_nonlinear_solver_fails: 18 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 1924 number_of_jac_muls: 36 number_of_matrix_evals: 12 @@ -1284,7 +1267,6 @@ mod test { let (problem, soln) = robertson_sens::(); test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 112 number_of_steps: 216 number_of_error_test_failures: 0 @@ -1292,7 +1274,6 @@ mod test { number_of_nonlinear_solver_fails: 37 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 1420 number_of_jac_muls: 3277 number_of_matrix_evals: 27 @@ -1306,7 +1287,6 @@ mod test { let (problem, soln) = robertson::(false); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 100 number_of_steps: 141 number_of_error_test_failures: 0 @@ -1314,7 +1294,6 @@ mod test { number_of_nonlinear_solver_fails: 24 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 1796 number_of_jac_muls: 54 number_of_matrix_evals: 18 @@ -1328,7 +1307,6 @@ mod test { let (problem, soln) = robertson_sens::(); test_ode_solver(&mut s, &problem, soln, None, false, true); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 114 number_of_steps: 131 number_of_error_test_failures: 0 @@ -1336,7 +1314,6 @@ mod test { number_of_nonlinear_solver_fails: 44 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 1492 number_of_jac_muls: 3136 number_of_matrix_evals: 33 @@ -1350,7 +1327,6 @@ mod test { let (problem, soln) = robertson_ode::(false, 1); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" - --- number_of_linear_solver_setups: 113 number_of_steps: 304 number_of_error_test_failures: 1 @@ -1358,7 +1334,6 @@ mod test { number_of_nonlinear_solver_fails: 15 "###); insta::assert_yaml_snapshot!(problem.eqn.as_ref().rhs().statistics(), @r###" - --- number_of_calls: 2603 number_of_jac_muls: 39 number_of_matrix_evals: 13 diff --git a/src/ode_solver/sens_equations.rs b/src/ode_solver/sens_equations.rs index cffa77d..ea7ced6 100644 --- a/src/ode_solver/sens_equations.rs +++ b/src/ode_solver/sens_equations.rs @@ -2,9 +2,9 @@ use num_traits::Zero; use std::{cell::RefCell, rc::Rc}; use crate::{ - matrix::sparsity::MatrixSparsityRef, op::nonlinear_op::NonLinearOpJacobian, - AugmentedOdeEquations, ConstantOp, ConstantOpSens, Matrix, NonLinearOp, NonLinearOpSens, - OdeEquations, OdeEquationsSens, OdeSolverProblem, Op, Vector, + op::nonlinear_op::NonLinearOpJacobian, AugmentedOdeEquations, ConstantOp, ConstantOpSens, + Matrix, NonLinearOp, NonLinearOpSens, OdeEquations, OdeEquationsRef, OdeEquationsSens, + OdeSolverProblem, Op, Vector, }; pub struct SensInit @@ -23,11 +23,7 @@ where pub fn new(eqn: &Rc) -> Self { let nstates = eqn.rhs().nstates(); let nparams = eqn.rhs().nparams(); - let init_sens = Eqn::M::new_from_sparsity( - nstates, - nparams, - eqn.init().sparsity_sens().map(|s| s.to_owned()), - ); + let init_sens = Eqn::M::new_from_sparsity(nstates, nparams, eqn.init().sens_sparsity()); let index = 0; Self { eqn: eqn.clone(), @@ -113,7 +109,7 @@ where let rhs_sens = Eqn::M::new_from_sparsity( nstates, nparams, - eqn.rhs().sparsity_sens().map(|s| s.to_owned()), + eqn.rhs().sens_sparsity().map(|s| s.to_owned()), ); let y = RefCell::new(::zeros(nstates)); let index = RefCell::new(0); @@ -257,35 +253,34 @@ where } } -impl OdeEquations for SensEquations +impl<'a, Eqn> OdeEquationsRef<'a> for SensEquations where Eqn: OdeEquationsSens, { - type T = Eqn::T; - type V = Eqn::V; - type M = Eqn::M; - type Rhs = SensRhs; - type Mass = Eqn::Mass; - type Root = Eqn::Root; - type Init = SensInit; - type Out = Eqn::Out; + type Rhs = &'a SensRhs; + type Mass = >::Mass; + type Root = >::Root; + type Init = &'a SensInit; + type Out = >::Out; +} - fn rhs(&self) -> &Rc { +impl OdeEquations for SensEquations +where + Eqn: OdeEquationsSens, +{ + fn rhs(&self) -> &SensRhs { &self.rhs } - fn mass(&self) -> Option<&Rc> { + fn mass(&self) -> Option<>::Mass> { self.eqn.mass() } - fn root(&self) -> Option<&Rc> { + fn root(&self) -> Option<>::Root> { None } - fn init(&self) -> &Rc { + fn init(&self) -> &SensInit { &self.init } - fn set_params(&mut self, _p: Self::V) { - panic!("Not implemented for SensEquations"); - } - fn out(&self) -> Option<&Rc> { + fn out(&self) -> Option<>::Out> { None } } diff --git a/src/ode_solver/sundials.rs b/src/ode_solver/sundials.rs index 3379a09..b4eea99 100644 --- a/src/ode_solver/sundials.rs +++ b/src/ode_solver/sundials.rs @@ -19,9 +19,9 @@ use std::{ }; use crate::{ - error::*, matrix::sparsity::MatrixSparsityRef, ode_solver_error, scale, LinearOp, Matrix, - NonLinearOp, NonLinearOpJacobian, OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, - OdeSolverState, OdeSolverStopReason, Op, SundialsMatrix, SundialsVector, Vector, + error::*, ode_solver_error, scale, LinearOp, Matrix, NonLinearOp, NonLinearOpJacobian, + OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, OdeSolverState, OdeSolverStopReason, + Op, SundialsMatrix, SundialsVector, Vector, }; #[cfg(not(sundials_version_major = "5"))] @@ -115,17 +115,16 @@ where { fn new(eqn: Rc) -> Self { let n = eqn.rhs().nstates(); - let rhs = eqn.rhs(); - let rhs_jac_sparsity = rhs.sparsity().map(|s| MatrixSparsityRef::to_owned(&s)); + let rhs_jac_sparsity = eqn.rhs().jacobian_sparsity(); let rhs_jac = SundialsMatrix::new_from_sparsity(n, n, rhs_jac_sparsity); let mass = if let Some(mass) = eqn.mass() { - let mass_sparsity = mass.sparsity().map(|s| MatrixSparsityRef::to_owned(&s)); + let mass_sparsity = mass.sparsity(); SundialsMatrix::new_from_sparsity(n, n, mass_sparsity) } else { let ones = SundialsVector::from_element(n, 1.0); SundialsMatrix::from_diagonal(&ones) }; - Self { eqn, rhs_jac, mass } + Self { rhs_jac, mass, eqn } } } @@ -466,10 +465,8 @@ mod test { use crate::{ ode_solver::{ test_models::{ - exponential_decay::exponential_decay_problem, - foodweb::{foodweb_problem, FoodWebContext}, - heat2d::head2d_problem, - robertson::robertson, + exponential_decay::exponential_decay_problem, foodweb::foodweb_problem, + heat2d::head2d_problem, robertson::robertson, }, tests::{ test_interpolate, test_no_set_problem, test_ode_solver_no_sens, test_state_mut, @@ -538,9 +535,8 @@ mod test { #[test] fn test_sundials_foodweb() { - let foodweb_context = FoodWebContext::default(); let mut s = crate::SundialsIda::default(); - let (problem, soln) = foodweb_problem::(&foodweb_context); + let (problem, soln) = foodweb_problem::(); test_ode_solver_no_sens(&mut s, &problem, soln, None, false); insta::assert_yaml_snapshot!(s.get_statistics(), @r###" --- diff --git a/src/ode_solver/test_models/exponential_decay.rs b/src/ode_solver/test_models/exponential_decay.rs index 74cbf59..aaccf7f 100644 --- a/src/ode_solver/test_models/exponential_decay.rs +++ b/src/ode_solver/test_models/exponential_decay.rs @@ -276,23 +276,21 @@ pub fn exponential_decay_problem_adjoint() -> ( rhs.calculate_jacobian_sparsity(&y0, t0); rhs.calculate_adjoint_sparsity(&y0, t0); } - let rhs = Rc::new(rhs); - let init = Rc::new(init); - let out = Some(Rc::new(out)); - let mass: Option>> = None; - let root: Option>> = None; + let out = Some(out); + let mass: Option> = None; + let root: Option> = None; let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone()); let rtol = M::T::from(1e-6); - let atol = M::V::from_element(nstates, M::T::from(1e-6)); + let atol = Rc::new(M::V::from_element(nstates, M::T::from(1e-6))); let out_rtol = Some(M::T::from(1e-6)); - let out_atol = Some(M::V::from_element(nout, M::T::from(1e-6))); + let out_atol = Some(Rc::new(M::V::from_element(nout, M::T::from(1e-6)))); let param_rtol = Some(M::T::from(1e-6)); - let param_atol = Some(M::V::from_element(p.len(), M::T::from(1e-6))); + let param_atol = Some(Rc::new(M::V::from_element(p.len(), M::T::from(1e-6)))); let sens_rtol = Some(M::T::from(1e-6)); - let sens_atol = Some(M::V::from_element(nstates, M::T::from(1e-6))); + let sens_atol = Some(Rc::new(M::V::from_element(nstates, M::T::from(1e-6)))); let integrate_out = true; let problem = OdeSolverProblem::new( - eqn, + Rc::new(eqn), rtol, atol, sens_rtol, diff --git a/src/ode_solver/test_models/exponential_decay_with_algebraic.rs b/src/ode_solver/test_models/exponential_decay_with_algebraic.rs index 2e118d6..ae5053d 100644 --- a/src/ode_solver/test_models/exponential_decay_with_algebraic.rs +++ b/src/ode_solver/test_models/exponential_decay_with_algebraic.rs @@ -286,24 +286,22 @@ pub fn exponential_decay_with_algebraic_adjoint_problem() - mass.calculate_sparsity(t0); mass.calculate_adjoint_sparsity(t0); } - let rhs = Rc::new(rhs); - let init = Rc::new(init); - let out = Some(Rc::new(out)); + let out = Some(out); - let root: Option>> = None; - let mass = Some(Rc::new(mass)); + let root: Option> = None; + let mass = Some(mass); let eqn = OdeSolverEquations::new(rhs, mass, root, init, out, p.clone()); let rtol = M::T::from(1e-6); - let atol = M::V::from_element(nstates, M::T::from(1e-6)); + let atol = Rc::new(M::V::from_element(nstates, M::T::from(1e-6))); let out_rtol = Some(M::T::from(1e-6)); - let out_atol = Some(M::V::from_element(nout, M::T::from(1e-6))); + let out_atol = Some(Rc::new(M::V::from_element(nout, M::T::from(1e-6)))); let param_rtol = Some(M::T::from(1e-6)); - let param_atol = Some(M::V::from_element(1, M::T::from(1e-6))); - let sens_atol = Some(M::V::from_element(nstates, M::T::from(1e-6))); + let param_atol = Some(Rc::new(M::V::from_element(1, M::T::from(1e-6)))); + let sens_atol = Some(Rc::new(M::V::from_element(nstates, M::T::from(1e-6)))); let sens_rtol = Some(M::T::from(1e-6)); let integrate_out = true; let problem = OdeSolverProblem::new( - eqn, + Rc::new(eqn), rtol, atol, sens_rtol, @@ -368,22 +366,15 @@ pub fn exponential_decay_with_algebraic_problem_sens() -> ( mass.calculate_sparsity(t0); } - let out: Option>> = None; - let root: Option>> = None; - let eqn = OdeSolverEquations::new( - Rc::new(rhs), - Some(Rc::new(mass)), - root, - Rc::new(init), - out, - p.clone(), - ); + let out: Option> = None; + let root: Option> = None; + let eqn = OdeSolverEquations::new(rhs, Some(mass), root, init, out, p.clone()); let sens_rtol = Some(M::T::from(1e-6)); - let sens_atol = Some(M::V::from_element(3, M::T::from(1e-6))); + let sens_atol = Some(Rc::new(M::V::from_element(3, M::T::from(1e-6)))); let problem = OdeSolverProblem::new( - eqn, + Rc::new(eqn), M::T::from(1e-6), - M::V::from_element(3, M::T::from(1e-6)), + Rc::new(M::V::from_element(3, M::T::from(1e-6))), sens_rtol, sens_atol, None, diff --git a/src/ode_solver/test_models/foodweb.rs b/src/ode_solver/test_models/foodweb.rs index 8a5cd09..9977dfd 100644 --- a/src/ode_solver/test_models/foodweb.rs +++ b/src/ode_solver/test_models/foodweb.rs @@ -3,8 +3,8 @@ use std::rc::Rc; use crate::{ find_jacobian_non_zeros, find_matrix_non_zeros, ode_solver::problem::OdeSolverSolution, ConstantOp, JacobianColoring, LinearOp, Matrix, MatrixSparsity, NonLinearOp, - NonLinearOpJacobian, OdeEquations, OdeEquationsImplicit, OdeSolverProblem, Op, UnitCallable, - Vector, + NonLinearOpJacobian, OdeEquations, OdeEquationsImplicit, OdeEquationsRef, OdeSolverProblem, Op, + UnitCallable, Vector, }; use num_traits::Zero; @@ -24,14 +24,18 @@ const ALPHA: f64 = 50.0; const BETA: f64 = 1000.0; #[cfg(feature = "diffsl")] -pub fn foodweb_diffsl_compile( - diffsl_context: &mut crate::DiffSlContext, -) where +#[allow(clippy::type_complexity)] +pub fn foodweb_diffsl_problem() -> ( + OdeSolverProblem>, + OdeSolverSolution, +) +where M: Matrix, CG: diffsl::execution::module::CodegenModule, { - let context = FoodWebContext::::default(); - let (problem, _soln) = foodweb_problem::(&context); + use crate::{DiffSl, DiffSlContext, OdeBuilder}; + + let (problem, _soln) = foodweb_problem::(); let u0 = problem.eqn.init().call(0.0); let diffop = FoodWebDiff::::new(&u0, 0.0); let diff = diffop.jacobian(&u0, 0.0); @@ -134,27 +138,11 @@ pub fn foodweb_diffsl_compile( n2 = 2 * NX * NX, ); - diffsl_context.recompile(code.as_str()).unwrap(); -} - -#[allow(clippy::type_complexity)] -#[cfg(feature = "diffsl")] -pub fn foodweb_diffsl_problem( - diffsl_context: &crate::DiffSlContext, -) -> ( - OdeSolverProblem + '_>, - OdeSolverSolution, -) -where - M: Matrix, - CG: diffsl::execution::module::CodegenModule, -{ - use crate::OdeBuilder; - + let eqn: DiffSl = DiffSl::from_context(DiffSlContext::new(code.as_str()).unwrap()); let problem = OdeBuilder::new() .rtol(1e-5) .atol([1e-5]) - .build_diffsl(diffsl_context) + .build_from_eqn(eqn) .unwrap(); let soln = soln::(); (problem, soln) @@ -299,7 +287,7 @@ struct FoodWebInit<'a, M, const NX: usize> where M: Matrix, { - pub context: &'a FoodWebContext, + pub foodweb: &'a FoodWeb, } // macro for bringing in constants from Context @@ -309,8 +297,8 @@ macro_rules! context_consts { where M: Matrix, { - pub fn new(context: &'a FoodWebContext) -> Self { - Self { context } + pub fn new(foodweb: &'a FoodWeb) -> Self { + Self { foodweb } } } }; @@ -328,13 +316,13 @@ macro_rules! impl_op { type T = M::T; fn nout(&self) -> usize { - self.context.nstates + self.foodweb.context.nstates } fn nparams(&self) -> usize { 0 } fn nstates(&self) -> usize { - self.context.nstates + self.foodweb.context.nstates } } }; @@ -379,27 +367,15 @@ struct FoodWebRhs<'a, M, const NX: usize> where M: Matrix, { - pub context: &'a FoodWebContext, - pub sparsity: Option, - pub coloring: Option>, + pub foodweb: &'a FoodWeb, } impl<'a, M, const NX: usize> FoodWebRhs<'a, M, NX> where M: Matrix, { - pub fn new(context: &'a FoodWebContext, y0: &M::V, t0: M::T) -> Self { - let mut ret = Self { - context, - sparsity: None, - coloring: None, - }; - let non_zeros = find_jacobian_non_zeros(&ret, y0, t0); - ret.sparsity = Some( - MatrixSparsity::try_from_indices(ret.nout(), ret.nstates(), non_zeros.clone()).unwrap(), - ); - ret.coloring = Some(JacobianColoring::new_from_non_zeros(&ret, non_zeros)); - ret + pub fn new(foodweb: &'a FoodWeb) -> Self { + Self { foodweb } } } @@ -412,16 +388,13 @@ where type T = M::T; fn nout(&self) -> usize { - self.context.nstates + self.foodweb.context.nstates } fn nparams(&self) -> usize { 0 } fn nstates(&self) -> usize { - self.context.nstates - } - fn sparsity(&self) -> Option> { - self.sparsity.as_ref().map(|s| s.as_ref()) + self.foodweb.context.nstates } } @@ -478,7 +451,7 @@ where for (is, rate) in rates.iter_mut().enumerate().take(NUM_SPECIES) { let mut dp = M::T::zero(); for js in 0..NUM_SPECIES { - dp += self.context.acoef[is][js] * x[loc + js]; + dp += self.foodweb.context.acoef[is][js] * x[loc + js]; } *rate = dp; } @@ -490,7 +463,7 @@ where ); for is in 0..NUM_SPECIES { - rates[is] = x[loc + is] * (self.context.bcoef[is] * fac + rates[is]); + rates[is] = x[loc + is] * (self.foodweb.context.bcoef[is] * fac + rates[is]); } /* Loop over species, do differencing, load crate segment. */ @@ -504,8 +477,8 @@ where let dcxui = x[locxu + is] - x[loc + is]; /* Compute the crate values at (xx,yy). */ - y[loc + is] = self.context.coy[is] * (dcyui - dcyli) - + self.context.cox[is] * (dcxui - dcxli) + y[loc + is] = self.foodweb.context.coy[is] * (dcyui - dcyli) + + self.foodweb.context.cox[is] * (dcxui - dcxli) + rates[is]; } } @@ -562,8 +535,8 @@ where let mut ddp = M::T::zero(); let mut dp = M::T::zero(); for js in 0..NUM_SPECIES { - dp += self.context.acoef[is][js] * x[loc + js]; - ddp += self.context.acoef[is][js] * v[loc + js]; + dp += self.foodweb.context.acoef[is][js] * x[loc + js]; + ddp += self.foodweb.context.acoef[is][js] * v[loc + js]; } rates[is] = dp; drates[is] = ddp; @@ -577,7 +550,7 @@ where for is in 0..NUM_SPECIES { drates[is] = x[loc + is] * drates[is] - + v[loc + is] * (self.context.bcoef[is] * fac + rates[is]); + + v[loc + is] * (self.foodweb.context.bcoef[is] * fac + rates[is]); } /* Loop over species, do differencing, load crate segment. */ @@ -591,47 +564,38 @@ where let dcxui = v[locxu + is] - v[loc + is]; /* Compute the crate values at (xx,yy). */ - y[loc + is] = self.context.coy[is] * (dcyui - dcyli) - + self.context.cox[is] * (dcxui - dcxli) + y[loc + is] = self.foodweb.context.coy[is] * (dcyui - dcyli) + + self.foodweb.context.cox[is] * (dcxui - dcxli) + drates[is]; } } } } fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { - if let Some(coloring) = self.coloring.as_ref() { + if let Some(coloring) = self.foodweb.rhs_coloring.as_ref() { coloring.jacobian_inplace(self, x, t, y); } else { self._default_jacobian_inplace(x, t, y); } } + fn jacobian_sparsity(&self) -> Option { + self.foodweb.rhs_sparsity.clone() + } } struct FoodWebMass<'a, M, const NX: usize> where M: Matrix, { - pub context: &'a FoodWebContext, - pub sparsity: Option, - pub coloring: Option>, + pub foodweb: &'a FoodWeb, } impl<'a, M, const NX: usize> FoodWebMass<'a, M, NX> where M: Matrix, { - pub fn new(context: &'a FoodWebContext, t0: M::T) -> Self { - let mut ret = Self { - context, - sparsity: None, - coloring: None, - }; - let non_zeros = find_matrix_non_zeros(&ret, t0); - ret.sparsity = Some( - MatrixSparsity::try_from_indices(ret.nout(), ret.nstates(), non_zeros.clone()).unwrap(), - ); - ret.coloring = Some(JacobianColoring::new_from_non_zeros(&ret, non_zeros)); - ret + pub fn new(foodweb: &'a FoodWeb) -> Self { + Self { foodweb } } } @@ -644,16 +608,13 @@ where type T = M::T; fn nout(&self) -> usize { - self.context.nstates + self.foodweb.context.nstates } fn nparams(&self) -> usize { 0 } fn nstates(&self) -> usize { - self.context.nstates - } - fn sparsity(&self) -> Option> { - self.sparsity.as_ref().map(|s| s.as_ref()) + self.foodweb.context.nstates } } @@ -680,13 +641,16 @@ where } } } + fn sparsity(&self) -> Option { + self.foodweb.mass_sparsity.clone() + } } struct FoodWebOut<'a, M, const NX: usize> where M: Matrix, { - pub context: &'a FoodWebContext, + pub foodweb: &'a FoodWeb, } context_consts!(FoodWebOut); @@ -706,7 +670,7 @@ where 0 } fn nstates(&self) -> usize { - self.context.nstates + self.foodweb.context.nstates } } @@ -751,67 +715,101 @@ where } } -struct FoodWeb<'a, M, const NX: usize> +struct FoodWeb where M: Matrix, { - pub rhs: Rc>, - pub mass: Rc>, - pub init: Rc>, - pub out: Rc>, + context: FoodWebContext, + rhs_sparsity: Option, + rhs_coloring: Option>, + mass_sparsity: Option, + mass_coloring: Option>, } -impl<'a, M, const NX: usize> FoodWeb<'a, M, NX> +impl FoodWeb where M: Matrix, { - pub fn new(context: &'a FoodWebContext, t0: M::T) -> Self { - let init = FoodWebInit::new(context); + pub fn new(context: FoodWebContext, t0: M::T) -> Self { + let mut ret = Self { + context, + rhs_sparsity: None, + rhs_coloring: None, + mass_sparsity: None, + mass_coloring: None, + }; + let init = FoodWebInit::new(&ret); let y0 = init.call(t0); - let rhs = FoodWebRhs::new(context, &y0, t0); - let mass = FoodWebMass::new(context, t0); - let out = FoodWebOut::new(context); - - let init = Rc::new(init); - let rhs = Rc::new(rhs); - let mass = Rc::new(mass); - let out = Rc::new(out); - - Self { - rhs, - mass, - init, - out, - } + let rhs = FoodWebRhs::new(&ret); + let non_zeros = find_jacobian_non_zeros(&rhs, &y0, t0); + ret.rhs_sparsity = Some( + MatrixSparsity::try_from_indices(rhs.nout(), rhs.nstates(), non_zeros.clone()).unwrap(), + ); + ret.rhs_coloring = Some(JacobianColoring::new( + ret.rhs_sparsity.as_ref().unwrap(), + &non_zeros, + )); + + let mass = FoodWebMass::new(&ret); + let non_zeros = find_matrix_non_zeros(&mass, t0); + ret.mass_sparsity = Some( + MatrixSparsity::try_from_indices(mass.nout(), mass.nstates(), non_zeros.clone()) + .unwrap(), + ); + ret.mass_coloring = Some(JacobianColoring::new( + ret.mass_sparsity.as_ref().unwrap(), + &non_zeros, + )); + ret } } -impl<'a, M, const NX: usize> OdeEquations for FoodWeb<'a, M, NX> +impl Op for FoodWeb where M: Matrix, { type M = M; type V = M::V; type T = M::T; + + fn nout(&self) -> usize { + 2 * NUM_SPECIES + } + fn nparams(&self) -> usize { + 0 + } + fn nstates(&self) -> usize { + self.context.nstates + } +} + +impl<'a, M, const NX: usize> OdeEquationsRef<'a> for FoodWeb +where + M: Matrix, +{ type Init = FoodWebInit<'a, M, NX>; type Rhs = FoodWebRhs<'a, M, NX>; type Mass = FoodWebMass<'a, M, NX>; - type Root = UnitCallable; + type Root = &'a UnitCallable; type Out = FoodWebOut<'a, M, NX>; +} - fn rhs(&self) -> &Rc { - &self.rhs +impl OdeEquations for FoodWeb +where + M: Matrix, +{ + fn rhs(&self) -> FoodWebRhs<'_, M, NX> { + FoodWebRhs::new(self) } - fn init(&self) -> &Rc { - &self.init + fn init(&self) -> FoodWebInit<'_, M, NX> { + FoodWebInit::new(self) } - fn mass(&self) -> Option<&Rc> { - Some(&self.mass) + fn mass(&self) -> Option> { + Some(FoodWebMass::new(self)) } - fn out(&self) -> Option<&Rc> { - Some(&self.out) + fn out(&self) -> Option> { + Some(FoodWebOut::new(self)) } - fn set_params(&mut self, _p: Self::V) {} } #[cfg(feature = "diffsl")] @@ -855,9 +853,6 @@ where fn nstates(&self) -> usize { NX * NX } - fn sparsity(&self) -> Option> { - self.sparsity.as_ref().map(|s| s.as_ref()) - } } #[cfg(feature = "diffsl")] @@ -944,6 +939,9 @@ where } } } + fn jacobian_sparsity(&self) -> Option { + self.sparsity.clone() + } } fn soln() -> OdeSolverSolution { @@ -1020,10 +1018,8 @@ fn soln() -> OdeSolverSolution { } #[allow(clippy::type_complexity)] -pub fn foodweb_problem( - context: &FoodWebContext, -) -> ( - OdeSolverProblem + '_>, +pub fn foodweb_problem() -> ( + OdeSolverProblem>, OdeSolverSolution, ) where @@ -1033,9 +1029,21 @@ where let atol = M::V::from_element(NUM_SPECIES * NX * NX, M::T::from(1e-5)); let t0 = M::T::zero(); let h0 = M::T::from(1.0); + let context = FoodWebContext::::new(); let eqn = FoodWeb::new(context, t0); let problem = OdeSolverProblem::new( - eqn, rtol, atol, None, None, None, None, None, None, t0, h0, false, + Rc::new(eqn), + rtol, + Rc::new(atol), + None, + None, + None, + None, + None, + None, + t0, + h0, + false, ) .unwrap(); let soln = soln::(); @@ -1052,8 +1060,7 @@ mod tests { fn test_jacobian() { type M = nalgebra::DMatrix; const NX: usize = 10; - let context = FoodWebContext::::new(); - let (problem, _soln) = foodweb_problem::(&context); + let (problem, _soln) = foodweb_problem::(); let u0 = problem.eqn.init().call(0.0); let jac = problem.eqn.rhs().jacobian(&u0, 0.0); @@ -1087,15 +1094,12 @@ mod tests { type M = nalgebra::DMatrix; const NX: usize = 10; - let context = FoodWebContext::::new(); - let (problem, _soln) = foodweb_problem(&context); + let (problem, _soln) = foodweb_problem::(); let u0 = problem.eqn.init().call(0.0); let jac = problem.eqn.rhs().jacobian(&u0, 0.0); let y0 = problem.eqn.rhs().call(&u0, 0.0); - let mut diffsl_context = crate::DiffSlContext::default(); - foodweb_diffsl_compile::(&mut diffsl_context); - let (problem_diffsl, _soln) = foodweb_diffsl_problem(&diffsl_context); + let (problem_diffsl, _soln) = foodweb_diffsl_problem::(); let u0_diffsl = problem_diffsl.eqn.init().call(0.0); for i in 0..u0.len() { let i_diffsl = if i % NUM_SPECIES >= NPREY { @@ -1157,8 +1161,7 @@ mod tests { fn test_mass() { type M = nalgebra::DMatrix; const NX: usize = 10; - let context = FoodWebContext::::new(); - let (problem, _soln) = foodweb_problem::(&context); + let (problem, _soln) = foodweb_problem::(); let mass = problem.eqn.mass().unwrap().matrix(0.0); for i in 0..mass.ncols() { for j in 0..mass.nrows() { diff --git a/src/ode_solver/test_models/heat2d.rs b/src/ode_solver/test_models/heat2d.rs index f8b7657..17db675 100644 --- a/src/ode_solver/test_models/heat2d.rs +++ b/src/ode_solver/test_models/heat2d.rs @@ -17,13 +17,17 @@ use num_traits::{One, Zero}; use crate::{ConstantOp, LinearOp, NonLinearOpJacobian, OdeEquations}; #[cfg(feature = "diffsl")] -pub fn heat2d_diffsl_compile< - M: Matrix + 'static, +#[allow(clippy::type_complexity)] +pub fn heat2d_diffsl_problem< + M: Matrix, CG: diffsl::execution::module::CodegenModule, const MGRID: usize, ->( - context: &mut crate::DiffSlContext, +>() -> ( + OdeSolverProblem>, + OdeSolverSolution, ) { + use crate::{DiffSl, DiffSlContext}; + let (problem, _soln) = head2d_problem::(); let u0 = problem.eqn.init().call(0.0); let jac = problem.eqn.rhs().jacobian(&u0, 0.0); @@ -87,24 +91,12 @@ pub fn heat2d_diffsl_compile< dx2 = (1.0 / (MGRID as f64 - 1.0)).powi(2), ); - context.recompile(code.as_str()).unwrap(); -} - -#[allow(clippy::type_complexity)] -#[cfg(feature = "diffsl")] -pub fn heat2d_diffsl_problem< - M: Matrix + 'static, - CG: diffsl::execution::module::CodegenModule, ->( - context: &crate::DiffSlContext, -) -> ( - OdeSolverProblem + '_>, - OdeSolverSolution, -) { + let context: DiffSlContext = DiffSlContext::new(code.as_str()).unwrap(); + let eqn = DiffSl::from_context(context); let problem = OdeBuilder::new() .rtol(1e-7) .atol([1e-7]) - .build_diffsl(context) + .build_from_eqn(eqn) .unwrap(); let soln = soln::(); (problem, soln) @@ -334,9 +326,7 @@ mod tests { use diffsl::CraneliftModule; use faer::Col; - let mut context = crate::DiffSlContext::default(); - heat2d_diffsl_compile::, CraneliftModule, 5>(&mut context); - let (problem, _soln) = heat2d_diffsl_problem(&context); + let (problem, _soln) = heat2d_diffsl_problem::, CraneliftModule, 5>(); let u = Col::from_vec(vec![ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, diff --git a/src/ode_solver/test_models/robertson.rs b/src/ode_solver/test_models/robertson.rs index f11b82a..51799ae 100644 --- a/src/ode_solver/test_models/robertson.rs +++ b/src/ode_solver/test_models/robertson.rs @@ -10,12 +10,16 @@ use crate::{ use num_traits::Zero; #[cfg(feature = "diffsl")] -pub fn robertson_diffsl_compile< - M: Matrix + 'static, +#[allow(clippy::type_complexity)] +pub fn robertson_diffsl_problem< + M: Matrix, CG: diffsl::execution::module::CodegenModule, ->( - context: &mut crate::DiffSlContext, +>() -> ( + OdeSolverProblem>, + OdeSolverSolution, ) { + use crate::{DiffSl, DiffSlContext}; + let code = " in = [k1, k2, k3] k1 { 0.04 } @@ -47,27 +51,13 @@ pub fn robertson_diffsl_compile< z, }"; - context.recompile(code).unwrap(); -} - -#[allow(clippy::type_complexity)] -#[cfg(feature = "diffsl")] -pub fn robertson_diffsl_problem< - M: Matrix + 'static, - CG: diffsl::execution::module::CodegenModule, ->( - context: &crate::DiffSlContext, - use_coloring: bool, -) -> ( - OdeSolverProblem + '_>, - OdeSolverSolution, -) { + let context = DiffSlContext::::new(code).unwrap(); + let eqn = DiffSl::from_context(context); let problem = OdeBuilder::new() .p([0.04, 1.0e4, 3.0e7]) .rtol(1e-4) .atol([1.0e-8, 1.0e-6, 1.0e-6]) - .use_coloring(use_coloring) - .build_diffsl(context) + .build_from_eqn(eqn) .unwrap(); let mut soln = soln::(); soln.rtol = problem.rtol; @@ -197,22 +187,15 @@ pub fn robertson_sens() -> ( mass.calculate_sparsity(t0); } - let out: Option>> = None; - let root: Option>> = None; - let eqn = OdeSolverEquations::new( - Rc::new(rhs), - Some(Rc::new(mass)), - root, - Rc::new(init), - out, - p.clone(), - ); + let out: Option> = None; + let root: Option> = None; + let eqn = OdeSolverEquations::new(rhs, Some(mass), root, init, out, p.clone()); let rtol = M::T::from(1e-4); let atol = M::V::from_vec(vec![M::T::from(1e-8), M::T::from(1e-6), M::T::from(1e-6)]); let problem = OdeSolverProblem::new( - eqn, + Rc::new(eqn), rtol, - atol, + Rc::new(atol), None, None, None, diff --git a/src/op/bdf.rs b/src/op/bdf.rs index d211d31..17cf143 100644 --- a/src/op/bdf.rs +++ b/src/op/bdf.rs @@ -1,7 +1,7 @@ use crate::{ matrix::DenseMatrix, ode_solver::equations::OdeEquationsImplicit, scale, LinearOp, Matrix, - MatrixRef, MatrixSparsity, MatrixSparsityRef, NonLinearOp, NonLinearOpJacobian, - OdeSolverProblem, Op, Vector, VectorRef, + MatrixRef, MatrixSparsity, NonLinearOp, NonLinearOpJacobian, OdeSolverProblem, Op, Vector, + VectorRef, }; use num_traits::{One, Zero}; use std::ops::MulAssign; @@ -82,26 +82,25 @@ impl BdfCallable { let tmp = RefCell::new(::zeros(n)); // create the mass and rhs jacobians according to the sparsity pattern - let rhs_jac_sparsity = eqn.rhs().sparsity(); + let rhs_jac_sparsity = eqn.rhs().jacobian_sparsity(); let rhs_jac = RefCell::new(Eqn::M::new_from_sparsity( n, n, rhs_jac_sparsity.map(|s| s.to_owned()), )); - let sparsity = if let Some(rhs_jac_sparsity) = eqn.rhs().sparsity() { + let sparsity = if let Some(rhs_jac_sparsity) = eqn.rhs().jacobian_sparsity() { if let Some(mass) = eqn.mass() { - // have mass, use the union of the mass and rhs jacobians sparse patterns - Some( - mass.sparsity() - .unwrap() - .to_owned() - .union(rhs_jac_sparsity) - .unwrap(), - ) + if let Some(mass_sparsity) = mass.sparsity() { + // have mass, use the union of the mass and rhs jacobians sparse patterns + Some(mass_sparsity.union(rhs_jac_sparsity.as_ref()).unwrap()) + } else { + // no mass sparsity, panic + panic!("Mass matrix must have a sparsity pattern if the rhs jacobian has a sparsity pattern"); + } } else { // no mass, use the identity let mass_sparsity = ::Sparsity::new_diagonal(n); - Some(mass_sparsity.union(rhs_jac_sparsity).unwrap()) + Some(mass_sparsity.union(rhs_jac_sparsity.as_ref()).unwrap()) } } else { None @@ -200,9 +199,6 @@ impl Op for BdfCallable { fn nparams(&self) -> usize { self.eqn.rhs().nparams() } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.sparsity.as_ref().map(|s| s.as_ref()) - } } // dF(y)/dp = dM/dp (y - y0 + psi) + Ms - c * df(y)/dp - c df(y)/dy s = 0 @@ -276,6 +272,9 @@ where let number_of_jac_evals = *self.number_of_jac_evals.borrow() + 1; self.number_of_jac_evals.replace(number_of_jac_evals); } + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + self.sparsity.clone() + } } #[cfg(test)] diff --git a/src/op/closure.rs b/src/op/closure.rs index 04a9dcf..0554585 100644 --- a/src/op/closure.rs +++ b/src/op/closure.rs @@ -51,7 +51,10 @@ where MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), ); - self.coloring = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + self.coloring = Some(JacobianColoring::new( + self.sparsity.as_ref().unwrap(), + &non_zeros, + )); } } @@ -77,9 +80,7 @@ where assert_eq!(p.len(), self.nparams); self.p = p; } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.sparsity.as_ref().map(|s| s.as_ref()) - } + fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } @@ -115,4 +116,7 @@ where self._default_jacobian_inplace(x, t, y); } } + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + self.sparsity.clone() + } } diff --git a/src/op/closure_with_adjoint.rs b/src/op/closure_with_adjoint.rs index e93c811..cb5c9bd 100644 --- a/src/op/closure_with_adjoint.rs +++ b/src/op/closure_with_adjoint.rs @@ -79,7 +79,10 @@ where MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), ); - self.coloring = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + self.coloring = Some(JacobianColoring::new( + self.sparsity.as_ref().unwrap(), + &non_zeros, + )); } pub fn calculate_adjoint_sparsity(&mut self, y0: &M::V, t0: M::T) { @@ -88,7 +91,10 @@ where MatrixSparsity::try_from_indices(self.nstates, self.nout, non_zeros.clone()) .expect("invalid sparsity pattern"), ); - self.coloring_adjoint = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + self.coloring_adjoint = Some(JacobianColoring::new( + self.sparsity_adjoint.as_ref().unwrap(), + &non_zeros, + )); } pub fn calculate_sens_adjoint_sparsity(&mut self, y0: &M::V, t0: M::T) { @@ -97,7 +103,10 @@ where MatrixSparsity::try_from_indices(self.nstates, self.nparams, non_zeros.clone()) .expect("invalid sparsity pattern"), ); - self.coloring_sens_adjoint = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + self.coloring_sens_adjoint = Some(JacobianColoring::new( + self.sens_sparsity.as_ref().unwrap(), + &non_zeros, + )); } } @@ -125,15 +134,7 @@ where assert_eq!(p.len(), self.nparams); self.p = p; } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.sparsity.as_ref().map(|s| s.as_ref()) - } - fn sparsity_adjoint(&self) -> Option<::SparsityRef<'_>> { - self.sparsity_adjoint.as_ref().map(|s| s.as_ref()) - } - fn sparsity_sens_adjoint(&self) -> Option<::SparsityRef<'_>> { - self.sens_sparsity.as_ref().map(|s| s.as_ref()) - } + fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } @@ -173,6 +174,9 @@ where self._default_jacobian_inplace(x, t, y); } } + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + self.sparsity.clone() + } } impl NonLinearOpAdjoint for ClosureWithAdjoint @@ -195,6 +199,9 @@ where self._default_adjoint_inplace(x, t, y); } } + fn adjoint_sparsity(&self) -> Option<::Sparsity> { + self.sparsity_adjoint.clone() + } } impl NonLinearOpSensAdjoint for ClosureWithAdjoint @@ -215,4 +222,7 @@ where self._default_sens_adjoint_inplace(x, t, y); } } + fn sens_adjoint_sparsity(&self) -> Option<::Sparsity> { + self.sens_sparsity.clone() + } } diff --git a/src/op/closure_with_sens.rs b/src/op/closure_with_sens.rs index acd79cb..bbf2020 100644 --- a/src/op/closure_with_sens.rs +++ b/src/op/closure_with_sens.rs @@ -66,7 +66,10 @@ where MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), ); - self.coloring = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + self.coloring = Some(JacobianColoring::new( + self.sparsity.as_ref().unwrap(), + &non_zeros, + )); } pub fn calculate_sens_sparsity(&mut self, y0: &M::V, t0: M::T) { let non_zeros = find_sens_non_zeros(self, y0, t0); @@ -74,7 +77,10 @@ where MatrixSparsity::try_from_indices(self.nout(), self.nparams, non_zeros.clone()) .expect("invalid sparsity pattern"), ); - self.sens_coloring = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + self.sens_coloring = Some(JacobianColoring::new( + self.sens_sparsity.as_ref().unwrap(), + &non_zeros, + )); } } @@ -101,12 +107,7 @@ where assert_eq!(p.len(), self.nparams); self.p = p; } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.sparsity.as_ref().map(|x| x.as_ref()) - } - fn sparsity_sens(&self) -> Option<::SparsityRef<'_>> { - self.sens_sparsity.as_ref().map(|x| x.as_ref()) - } + fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } @@ -144,6 +145,9 @@ where self._default_jacobian_inplace(x, t, y); } } + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + self.sparsity.clone() + } } impl NonLinearOpSens for ClosureWithSens @@ -164,4 +168,7 @@ where self._default_sens_inplace(x, t, y); } } + fn sens_sparsity(&self) -> Option<::Sparsity> { + self.sens_sparsity.clone() + } } diff --git a/src/op/constant_closure_with_adjoint.rs b/src/op/constant_closure_with_adjoint.rs index 0369ad3..0d1f1f1 100644 --- a/src/op/constant_closure_with_adjoint.rs +++ b/src/op/constant_closure_with_adjoint.rs @@ -83,7 +83,7 @@ where I: Fn(&M::V, M::T) -> M::V, J: Fn(&M::V, M::T, &M::V, &mut M::V), { - fn sens_mul_transpose_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) { + fn sens_transpose_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) { (self.func_sens_adjoint)(self.p.as_ref(), t, v, y); } } diff --git a/src/op/constant_op.rs b/src/op/constant_op.rs index e3b1513..1950ed4 100644 --- a/src/op/constant_op.rs +++ b/src/op/constant_op.rs @@ -1,5 +1,5 @@ use super::Op; -use crate::{Matrix, MatrixSparsityRef, Vector}; +use crate::{Matrix, Vector}; use num_traits::{One, Zero}; pub trait ConstantOp: Op { @@ -17,7 +17,7 @@ pub trait ConstantOpSens: ConstantOp { fn sens_mul_inplace(&self, _t: Self::T, _v: &Self::V, _y: &mut Self::V); /// Compute the gradient of the operator wrt a parameter vector p and store it in the matrix `y`. - /// `y` should have been previously initialised using the output of [`Op::sparsity`]. + /// `y` should have been previously initialised using the output of [Self::sens_sparsity]. /// The default implementation of this method computes the gradient using [Self::sens_mul_inplace], /// but it can be overriden for more efficient implementations. fn sens_inplace(&self, t: Self::T, y: &mut Self::M) { @@ -41,14 +41,50 @@ pub trait ConstantOpSens: ConstantOp { fn sens(&self, t: Self::T) -> Self::M { let n = self.nstates(); let m = self.nparams(); - let mut y = Self::M::new_from_sparsity(n, m, self.sparsity_sens().map(|s| s.to_owned())); + let mut y = Self::M::new_from_sparsity(n, m, self.sens_sparsity()); self.sens_inplace(t, &mut y); y } + fn sens_sparsity(&self) -> Option<::Sparsity> { + None + } } pub trait ConstantOpSensAdjoint: ConstantOp { /// Compute the product of the transpose of the gradient of F wrt a parameter vector p with a given vector `-J_p^T(x, t) * v`. /// Note that the vector v is of size nstates() and the result is of size nparam(). - fn sens_mul_transpose_inplace(&self, _t: Self::T, _v: &Self::V, _y: &mut Self::V); + fn sens_transpose_mul_inplace(&self, _t: Self::T, _v: &Self::V, _y: &mut Self::V); + + /// Compute the negative transpose of the gradient of the operator wrt a parameter vector p and return it. + /// See [Self::sens_adjoint_inplace] for a non-allocating version. + fn sens_adjoint(&self, t: Self::T) -> Self::M { + let n = self.nstates(); + let mut y = Self::M::new_from_sparsity(n, n, self.sens_adjoint_sparsity()); + self.sens_adjoint_inplace(t, &mut y); + y + } + + /// Compute the negative transpose of the gradient of the operator wrt a parameter vector p and store it in the matrix `y`. + /// `y` should have been previously initialised using the output of [Self::sens_adjoint_sparsity]. + /// The default implementation of this method computes the gradient using [Self::sens_transpose_mul_inplace], + /// but it can be overriden for more efficient implementations. + fn sens_adjoint_inplace(&self, t: Self::T, y: &mut Self::M) { + self._default_sens_adjoint_inplace(t, y); + } + + /// Default implementation of the gradient computation (this is the default for [Self::sens_adjoint_inplace]). + fn _default_sens_adjoint_inplace(&self, t: Self::T, y: &mut Self::M) { + let mut v = Self::V::zeros(self.nstates()); + let mut col = Self::V::zeros(self.nout()); + for j in 0..self.nstates() { + v[j] = Self::T::one(); + self.sens_transpose_mul_inplace(t, &v, &mut col); + y.set_column(j, &col); + v[j] = Self::T::zero(); + } + } + + fn sens_adjoint_sparsity(&self) -> Option<::Sparsity> { + None + } } diff --git a/src/op/init.rs b/src/op/init.rs index 0cf83b6..87672f4 100644 --- a/src/op/init.rs +++ b/src/op/init.rs @@ -1,5 +1,6 @@ use crate::{ - scale, LinearOp, Matrix, NonLinearOpJacobian, OdeEquationsImplicit, Vector, VectorIndex, + scale, LinearOp, Matrix, MatrixSparsityRef, NonLinearOpJacobian, OdeEquationsImplicit, Vector, + VectorIndex, }; use num_traits::{One, Zero}; use std::{cell::RefCell, rc::Rc}; @@ -83,9 +84,6 @@ impl Op for InitOp { fn nparams(&self) -> usize { self.eqn.rhs().nparams() } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.jac.sparsity() - } } impl NonLinearOp for InitOp { @@ -115,6 +113,10 @@ impl NonLinearOpJacobian for InitOp { fn jacobian_inplace(&self, _x: &Self::V, _t: Self::T, y: &mut Self::M) { y.copy_from(&self.jac); } + + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + self.jac.sparsity().map(|x| x.to_owned()) + } } #[cfg(test)] diff --git a/src/op/linear_closure.rs b/src/op/linear_closure.rs index 6b46bcf..82416b3 100644 --- a/src/op/linear_closure.rs +++ b/src/op/linear_closure.rs @@ -47,7 +47,10 @@ where MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), ); - self.coloring = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + self.coloring = Some(JacobianColoring::new( + self.sparsity.as_ref().unwrap(), + &non_zeros, + )); } } @@ -73,9 +76,7 @@ where assert_eq!(p.len(), self.nparams); self.p = p; } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.sparsity.as_ref().map(|s| s.as_ref()) - } + fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } @@ -99,4 +100,7 @@ where self._default_matrix_inplace(t, y); } } + fn sparsity(&self) -> Option<::Sparsity> { + self.sparsity.clone() + } } diff --git a/src/op/linear_closure_with_adjoint.rs b/src/op/linear_closure_with_adjoint.rs index 594fdbd..ae328d2 100644 --- a/src/op/linear_closure_with_adjoint.rs +++ b/src/op/linear_closure_with_adjoint.rs @@ -55,7 +55,10 @@ where MatrixSparsity::try_from_indices(self.nout(), self.nstates(), non_zeros.clone()) .expect("invalid sparsity pattern"), ); - self.coloring = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + self.coloring = Some(JacobianColoring::new( + self.sparsity.as_ref().unwrap(), + &non_zeros, + )); } pub fn calculate_adjoint_sparsity(&mut self, t0: M::T) { let non_zeros = find_transpose_non_zeros(self, t0); @@ -63,7 +66,10 @@ where MatrixSparsity::try_from_indices(self.nstates, self.nout, non_zeros.clone()) .expect("invalid sparsity pattern"), ); - self.coloring_adjoint = Some(JacobianColoring::new_from_non_zeros(self, non_zeros)); + self.coloring_adjoint = Some(JacobianColoring::new( + self.sparsity_adjoint.as_ref().unwrap(), + &non_zeros, + )); } } @@ -90,12 +96,7 @@ where assert_eq!(p.len(), self.nparams); self.p = p; } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.sparsity.as_ref().map(|s| s.as_ref()) - } - fn sparsity_adjoint(&self) -> Option<::SparsityRef<'_>> { - self.sparsity_adjoint.as_ref().map(|s| s.as_ref()) - } + fn statistics(&self) -> OpStatistics { self.statistics.borrow().clone() } @@ -120,6 +121,9 @@ where self._default_matrix_inplace(t, y); } } + fn sparsity(&self) -> Option<::Sparsity> { + self.sparsity.clone() + } } impl LinearOpTranspose for LinearClosureWithAdjoint @@ -138,4 +142,8 @@ where self._default_transpose_inplace(t, y); } } + + fn transpose_sparsity(&self) -> Option<::Sparsity> { + self.sparsity_adjoint.clone() + } } diff --git a/src/op/linear_op.rs b/src/op/linear_op.rs index 43cd270..3d9f343 100644 --- a/src/op/linear_op.rs +++ b/src/op/linear_op.rs @@ -1,5 +1,5 @@ use super::Op; -use crate::{Matrix, MatrixSparsityRef, Vector}; +use crate::{Matrix, Vector}; use num_traits::{One, Zero}; /// LinearOp is a trait for linear operators (i.e. they only depend linearly on the input `x`), see [crate::NonLinearOp] for a non-linear op. @@ -19,11 +19,7 @@ pub trait LinearOp: Op { /// Compute the matrix representation of the operator `A(t)` and return it. /// See [Self::matrix_inplace] for a non-allocating version. fn matrix(&self, t: Self::T) -> Self::M { - let mut y = Self::M::new_from_sparsity( - self.nstates(), - self.nstates(), - self.sparsity().map(|s| s.to_owned()), - ); + let mut y = Self::M::new_from_sparsity(self.nstates(), self.nstates(), self.sparsity()); self.matrix_inplace(t, &mut y); y } @@ -46,6 +42,10 @@ pub trait LinearOp: Op { v[j] = Self::T::zero(); } } + + fn sparsity(&self) -> Option<::Sparsity> { + None + } } pub trait LinearOpTranspose: LinearOp { @@ -76,6 +76,9 @@ pub trait LinearOpTranspose: LinearOp { v[j] = Self::T::zero(); } } + fn transpose_sparsity(&self) -> Option<::Sparsity> { + None + } } pub trait LinearOpSens: LinearOp { @@ -93,7 +96,7 @@ pub trait LinearOpSens: LinearOp { } /// Compute the gradient of the operator wrt a parameter vector p and store it in the matrix `y`. - /// `y` should have been previously initialised using the output of [`Op::sparsity`]. + /// `y` should have been previously initialised using the output of [Self::sens_sparsity]. /// The default implementation of this method computes the gradient using [Self::sens_mul_inplace], /// but it can be overriden for more efficient implementations. fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { @@ -117,8 +120,12 @@ pub trait LinearOpSens: LinearOp { fn sens(&self, x: &Self::V, t: Self::T) -> Self::M { let n = self.nstates(); let m = self.nparams(); - let mut y = Self::M::new_from_sparsity(n, m, self.sparsity_sens().map(|s| s.to_owned())); + let mut y = Self::M::new_from_sparsity(n, m, self.sens_sparsity()); self.sens_inplace(x, t, &mut y); y } + + fn sens_sparsity(&self) -> Option<::Sparsity> { + None + } } diff --git a/src/op/linearise.rs b/src/op/linearise.rs index de00db5..6a6a7f8 100644 --- a/src/op/linearise.rs +++ b/src/op/linearise.rs @@ -51,9 +51,6 @@ impl Op for LinearisedOp { fn nparams(&self) -> usize { self.callable.nparams() } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.callable.sparsity() - } } impl LinearOp for LinearisedOp { @@ -69,4 +66,7 @@ impl LinearOp for LinearisedOp { fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) { self.callable.jacobian_inplace(&self.x, t, y); } + fn sparsity(&self) -> Option<::Sparsity> { + self.callable.jacobian_sparsity() + } } diff --git a/src/op/matrix.rs b/src/op/matrix.rs index 92ea40a..be4164e 100644 --- a/src/op/matrix.rs +++ b/src/op/matrix.rs @@ -1,4 +1,4 @@ -use crate::{LinearOp, Matrix, Op}; +use crate::{LinearOp, Matrix, MatrixSparsityRef, Op}; pub struct MatrixOp { m: M, @@ -23,13 +23,13 @@ impl Op for MatrixOp { fn nparams(&self) -> usize { 0 } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.m.sparsity() - } } impl LinearOp for MatrixOp { fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) { self.m.gemv(t, x, beta, y); } + fn sparsity(&self) -> Option<::Sparsity> { + self.m.sparsity().map(|s| s.to_owned()) + } } diff --git a/src/op/mod.rs b/src/op/mod.rs index 0c37719..5a82c46 100644 --- a/src/op/mod.rs +++ b/src/op/mod.rs @@ -1,6 +1,9 @@ use std::rc::Rc; -use crate::{LinearOp, Matrix, NonLinearOp, Scalar, Vector}; +use crate::{ + ConstantOp, ConstantOpSens, ConstantOpSensAdjoint, LinearOp, LinearOpTranspose, Matrix, + NonLinearOp, NonLinearOpAdjoint, NonLinearOpSens, NonLinearOpSensAdjoint, Scalar, Vector, +}; use nonlinear_op::NonLinearOpJacobian; use serde::Serialize; @@ -50,26 +53,6 @@ pub trait Op { assert_eq!(p.len(), self.nparams()); } - /// Return sparsity information for the jacobian or matrix (if available) - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - None - } - - /// Return sparsity information for the jacobian or matrix (if available) - fn sparsity_adjoint(&self) -> Option<::SparsityRef<'_>> { - None - } - - /// Return sparsity information for the sensitivity of the operator wrt a parameter vector p (if available) - fn sparsity_sens(&self) -> Option<::SparsityRef<'_>> { - None - } - - /// Return sparsity information for the sensitivity of the operator wrt a parameter vector p (if available) - fn sparsity_sens_adjoint(&self) -> Option<::SparsityRef<'_>> { - None - } - /// Return statistics about the operator (e.g. how many times it was called, how many times the jacobian was computed, etc.) fn statistics(&self) -> OpStatistics { OpStatistics::default() @@ -124,6 +107,9 @@ impl Op for &C { fn nparams(&self) -> usize { C::nparams(*self) } + fn statistics(&self) -> OpStatistics { + C::statistics(*self) + } } impl NonLinearOp for &C { @@ -139,10 +125,98 @@ impl NonLinearOpJacobian for &C { fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { C::jacobian_inplace(*self, x, t, y) } + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + C::jacobian_sparsity(*self) + } +} + +impl NonLinearOpAdjoint for &C { + fn adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { + C::adjoint_inplace(*self, x, t, y) + } + fn adjoint_sparsity(&self) -> Option<::Sparsity> { + C::adjoint_sparsity(*self) + } + fn jac_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { + C::jac_transpose_mul_inplace(*self, x, t, v, y) + } +} + +impl NonLinearOpSens for &C { + fn sens_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { + C::sens_mul_inplace(*self, x, t, v, y) + } + fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { + C::sens_inplace(*self, x, t, y) + } + + fn sens_sparsity(&self) -> Option<::Sparsity> { + C::sens_sparsity(*self) + } +} + +impl NonLinearOpSensAdjoint for &C { + fn sens_transpose_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { + C::sens_transpose_mul_inplace(*self, x, t, v, y) + } + fn sens_adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { + C::sens_adjoint_inplace(*self, x, t, y) + } + fn sens_adjoint_sparsity(&self) -> Option<::Sparsity> { + C::sens_adjoint_sparsity(*self) + } } impl LinearOp for &C { fn gemv_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) { C::gemv_inplace(*self, x, t, beta, y) } + fn sparsity(&self) -> Option<::Sparsity> { + C::sparsity(*self) + } + fn matrix_inplace(&self, t: Self::T, y: &mut Self::M) { + C::matrix_inplace(*self, t, y) + } +} + +impl LinearOpTranspose for &C { + fn gemv_transpose_inplace(&self, x: &Self::V, t: Self::T, beta: Self::T, y: &mut Self::V) { + C::gemv_transpose_inplace(*self, x, t, beta, y) + } + fn transpose_inplace(&self, t: Self::T, y: &mut Self::M) { + C::transpose_inplace(*self, t, y) + } + fn transpose_sparsity(&self) -> Option<::Sparsity> { + C::transpose_sparsity(*self) + } +} + +impl ConstantOp for &C { + fn call_inplace(&self, t: Self::T, y: &mut Self::V) { + C::call_inplace(*self, t, y) + } +} + +impl ConstantOpSens for &C { + fn sens_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) { + C::sens_mul_inplace(*self, t, v, y) + } + fn sens_inplace(&self, t: Self::T, y: &mut Self::M) { + C::sens_inplace(*self, t, y) + } + fn sens_sparsity(&self) -> Option<::Sparsity> { + C::sens_sparsity(*self) + } +} + +impl ConstantOpSensAdjoint for &C { + fn sens_transpose_mul_inplace(&self, t: Self::T, v: &Self::V, y: &mut Self::V) { + C::sens_transpose_mul_inplace(*self, t, v, y) + } + fn sens_adjoint_inplace(&self, t: Self::T, y: &mut Self::M) { + C::sens_adjoint_inplace(*self, t, y) + } + fn sens_adjoint_sparsity(&self) -> Option<::Sparsity> { + C::sens_adjoint_sparsity(*self) + } } diff --git a/src/op/nonlinear_op.rs b/src/op/nonlinear_op.rs index 068567b..ef8fe09 100644 --- a/src/op/nonlinear_op.rs +++ b/src/op/nonlinear_op.rs @@ -1,5 +1,5 @@ use super::Op; -use crate::{Matrix, MatrixSparsityRef, Vector}; +use crate::{Matrix, Vector}; use num_traits::{One, Zero}; // NonLinearOp is a trait that defines a nonlinear operator or function `F` that maps an input vector `x` to an output vector `y`, (i.e. `y = F(x, t)`). @@ -34,7 +34,7 @@ pub trait NonLinearOpSens: NonLinearOp { } /// Compute the gradient of the operator wrt a parameter vector p and store it in the matrix `y`. - /// `y` should have been previously initialised using the output of [`Op::sparsity`]. + /// `y` should have been previously initialised using the output of [Self::sens_sparsity]. /// The default implementation of this method computes the gradient using [Self::sens_mul_inplace], /// but it can be overriden for more efficient implementations. fn sens_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { @@ -58,10 +58,13 @@ pub trait NonLinearOpSens: NonLinearOp { fn sens(&self, x: &Self::V, t: Self::T) -> Self::M { let n = self.nstates(); let m = self.nparams(); - let mut y = Self::M::new_from_sparsity(n, m, self.sparsity_sens().map(|s| s.to_owned())); + let mut y = Self::M::new_from_sparsity(n, m, self.sens_sparsity()); self.sens_inplace(x, t, &mut y); y } + fn sens_sparsity(&self) -> Option<::Sparsity> { + None + } } pub trait NonLinearOpSensAdjoint: NonLinearOp { /// Compute the product of the negative tramspose of the gradient of F wrt a parameter vector p with a given vector `-J_p(x, t)^T * v`. @@ -71,14 +74,13 @@ pub trait NonLinearOpSensAdjoint: NonLinearOp { /// See [Self::sens_adjoint_inplace] for a non-allocating version. fn sens_adjoint(&self, x: &Self::V, t: Self::T) -> Self::M { let n = self.nstates(); - let mut y = - Self::M::new_from_sparsity(n, n, self.sparsity_sens_adjoint().map(|s| s.to_owned())); + let mut y = Self::M::new_from_sparsity(n, n, self.sens_adjoint_sparsity()); self.sens_adjoint_inplace(x, t, &mut y); y } /// Compute the negative transpose of the gradient of the operator wrt a parameter vector p and store it in the matrix `y`. - /// `y` should have been previously initialised using the output of [`Op::sparsity_sens_adjoint`]. + /// `y` should have been previously initialised using the output of [Self::sens_adjoint_sparsity]. /// The default implementation of this method computes the gradient using [Self::sens_transpose_mul_inplace], /// but it can be overriden for more efficient implementations. fn sens_adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { @@ -96,6 +98,10 @@ pub trait NonLinearOpSensAdjoint: NonLinearOp { v[j] = Self::T::zero(); } } + + fn sens_adjoint_sparsity(&self) -> Option<::Sparsity> { + None + } } pub trait NonLinearOpAdjoint: NonLinearOp { /// Compute the product of the transpose of the Jacobian with a given vector `-J(x, t)^T * v`. @@ -104,7 +110,7 @@ pub trait NonLinearOpAdjoint: NonLinearOp { fn jac_transpose_mul_inplace(&self, _x: &Self::V, _t: Self::T, _v: &Self::V, _y: &mut Self::V); /// Compute the Adjoint matrix `-J^T(x, t)` of the operator and store it in the matrix `y`. - /// `y` should have been previously initialised using the output of [`Op::sparsity`]. + /// `y` should have been previously initialised using the output of [`Self::adjoint_sparsity`]. /// The default implementation of this method computes the Jacobian using [Self::jac_transpose_mul_inplace], /// but it can be overriden for more efficient implementations. fn adjoint_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { @@ -127,10 +133,14 @@ pub trait NonLinearOpAdjoint: NonLinearOp { /// See [Self::adjoint_inplace] for a non-allocating version. fn adjoint(&self, x: &Self::V, t: Self::T) -> Self::M { let n = self.nstates(); - let mut y = Self::M::new_from_sparsity(n, n, self.sparsity_adjoint().map(|s| s.to_owned())); + let mut y = Self::M::new_from_sparsity(n, n, self.adjoint_sparsity()); self.adjoint_inplace(x, t, &mut y); y } + /// Return sparsity information (if available) + fn adjoint_sparsity(&self) -> Option<::Sparsity> { + None + } } pub trait NonLinearOpJacobian: NonLinearOp { /// Compute the product of the Jacobian with a given vector `J(x, t) * v`. @@ -148,13 +158,18 @@ pub trait NonLinearOpJacobian: NonLinearOp { /// See [Self::jacobian_inplace] for a non-allocating version. fn jacobian(&self, x: &Self::V, t: Self::T) -> Self::M { let n = self.nstates(); - let mut y = Self::M::new_from_sparsity(n, n, self.sparsity().map(|s| s.to_owned())); + let mut y = Self::M::new_from_sparsity(n, n, self.jacobian_sparsity()); self.jacobian_inplace(x, t, &mut y); y } + /// Return sparsity information (if available) + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + None + } + /// Compute the Jacobian matrix `J(x, t)` of the operator and store it in the matrix `y`. - /// `y` should have been previously initialised using the output of [`Op::sparsity`]. + /// `y` should have been previously initialised using the output of [Self::jacobian_sparsity]. /// The default implementation of this method computes the Jacobian using [Self::jac_mul_inplace], /// but it can be overriden for more efficient implementations. fn jacobian_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::M) { diff --git a/src/op/sdirk.rs b/src/op/sdirk.rs index 5ceec27..8826d21 100644 --- a/src/op/sdirk.rs +++ b/src/op/sdirk.rs @@ -1,8 +1,8 @@ use crate::{ matrix::{MatrixRef, MatrixView}, ode_solver::equations::OdeEquations, - scale, LinearOp, Matrix, MatrixSparsity, MatrixSparsityRef, NonLinearOpJacobian, - OdeEquationsImplicit, OdeSolverProblem, Vector, VectorRef, + scale, LinearOp, Matrix, MatrixSparsity, NonLinearOpJacobian, OdeEquationsImplicit, + OdeSolverProblem, Vector, VectorRef, }; use num_traits::{One, Zero}; use std::{ @@ -28,7 +28,7 @@ pub struct SdirkCallable { sparsity: Option<::Sparsity>, } -impl SdirkCallable { +impl SdirkCallable { // y = h g(phi + c * y_s) pub fn integrate_out(&self, ys: &Eqn::V, t: Eqn::T, y: &mut Eqn::V) { self.eqn.out().unwrap().call_inplace(ys, t, y); @@ -75,22 +75,21 @@ impl SdirkCallable { let rhs_jac = RefCell::new(Eqn::M::new_from_sparsity( n, n, - eqn.rhs().sparsity().map(|s| s.to_owned()), + eqn.rhs().jacobian_sparsity(), )); - let sparsity = if let Some(rhs_jac_sparsity) = eqn.rhs().sparsity() { + let sparsity = if let Some(rhs_jac_sparsity) = eqn.rhs().jacobian_sparsity() { if let Some(mass) = eqn.mass() { - // have mass, use the union of the mass and rhs jacobians sparse patterns - Some( - mass.sparsity() - .unwrap() - .to_owned() - .union(rhs_jac_sparsity) - .unwrap(), - ) + if let Some(mass_sparsity) = mass.sparsity() { + // have mass, use the union of the mass and rhs jacobians sparse patterns + Some(mass_sparsity.union(rhs_jac_sparsity.as_ref()).unwrap()) + } else { + // no mass sparsity, panic! + panic!("Mass matrix must have a sparsity pattern if the rhs jacobian has a sparsity pattern") + } } else { // no mass, use the identity let mass_sparsity = ::Sparsity::new_diagonal(n); - Some(mass_sparsity.union(rhs_jac_sparsity).unwrap()) + Some(mass_sparsity.union(rhs_jac_sparsity.as_ref()).unwrap()) } } else { None @@ -179,12 +178,9 @@ impl Op for SdirkCallable { fn nparams(&self) -> usize { self.eqn.rhs().nparams() } - fn sparsity(&self) -> Option<::SparsityRef<'_>> { - self.sparsity.as_ref().map(|s| s.as_ref()) - } } -impl NonLinearOp for SdirkCallable +impl NonLinearOp for SdirkCallable where for<'b> &'b Eqn::V: VectorRef, for<'b> &'b Eqn::M: MatrixRef, @@ -257,6 +253,9 @@ where let number_of_jac_evals = *self.number_of_jac_evals.borrow() + 1; self.number_of_jac_evals.replace(number_of_jac_evals); } + fn jacobian_sparsity(&self) -> Option<::Sparsity> { + self.sparsity.clone() + } } #[cfg(test)] diff --git a/src/vector/sundials.rs b/src/vector/sundials.rs index 45fb8e8..2da18a6 100644 --- a/src/vector/sundials.rs +++ b/src/vector/sundials.rs @@ -44,7 +44,7 @@ impl SundialsVector { #[cfg(not(sundials_version_major = "5"))] let nv = { let ctx = get_suncontext(); - unsafe { N_VNew_Serial(len as i32, *ctx) } + unsafe { N_VNew_Serial(len as i64, *ctx) } }; #[cfg(sundials_version_major = "5")]