Skip to content

Commit

Permalink
feat: migrate from anyhow to thiserror (#84)
Browse files Browse the repository at this point in the history
* Convert from anyhow to thiserror

Work in progress, does not compile

* cargo fmt

* Final touches

All tests now passing, added the "Other" error to all error enums.

* Provide macros for shorthand error handling

* Handle feature-gated errors

* Formatting

* Fixes for sundials

* Format

* Update src/error.rs

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>

* Update src/error.rs

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>

* Update src/error.rs

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>

* Update src/error.rs

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>

* Update src/error.rs

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>

* Update src/error.rs

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>

* Use OdeSolverError types for IDA errors

* Update src/error.rs

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>

* Adding some missing result types

* Clippy and macro fix

* Map diffsl errors

* Clippy

---------

Co-authored-by: Martin Robinson <martinjrobins@gmail.com>
  • Loading branch information
mhovd and martinjrobins authored Aug 5, 2024
1 parent 25a163a commit e727919
Show file tree
Hide file tree
Showing 26 changed files with 426 additions and 220 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ diffsl-llvm17 = ["diffsl17-0", "diffsl"]
[dependencies]
nalgebra = "0.33"
nalgebra-sparse = { version = "0.10", features = ["io"] }
anyhow = "1.0.86"
num-traits = "0.2.17"
ouroboros = "0.18.2"
serde = { version = "1.0.196", features = ["derive"] }
Expand All @@ -38,6 +37,7 @@ diffsl17-0 = { package = "diffsl", version = "=0.1.9", features = ["llvm17-0"],
petgraph = "0.6.4"
faer = "0.18.2"
suitesparse_sys = { version = "0.1.3", optional = true }
thiserror = "1.0.63"

[dev-dependencies]
insta = { version = "1.34.0", features = ["yaml"] }
Expand Down
146 changes: 146 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
use faer::sparse::CreationError;
use thiserror::Error;

/// Custom error type for Diffsol
///
/// This error type is used to wrap all possible errors that can occur when using Diffsol
#[derive(Error, Debug)]
pub enum DiffsolError {
#[error("Linear solver error: {0}")]
LinearSolverError(#[from] LinearSolverError),
#[error("Non-linear solver error: {0}")]
NonLinearSolverError(#[from] NonLinearSolverError),
#[error("ODE solver error: {0}")]
OdeSolverError(#[from] OdeSolverError),
#[error("Matrix error: {0}")]
MatrixError(#[from] MatrixError),
#[error("Error: {0}")]
Other(String),
}

/// Possible errors that can occur when solving a linear problem
#[derive(Error, Debug)]
pub enum LinearSolverError {
#[error("LU not initialized")]
LuNotInitialized,
#[error("LU solve failed")]
LuSolveFailed,
#[error("Linear solver not setup")]
LinearSolverNotSetup,
#[error("KLU failed to analyze")]
KluFailedToAnalyze,
#[error("KLU failed to factorize")]
KluFailedToFactorize,
#[error("Error: {0}")]
Other(String),
}

/// Possible errors that can occur when solving a non-linear problem
#[derive(Error, Debug)]
pub enum NonLinearSolverError {
#[error("Newton iterations did not converge")]
NewtonDidNotConverge,
#[error("LU solve failed")]
LuSolveFailed,
#[error("Error: {0}")]
Other(String),
}

/// Possible errors that can occur when solving an ODE
#[derive(Debug, Error)]
pub enum OdeSolverError {
#[error(
"Stop time = {} is less than current state time = {}",
stop_time,
state_time
)]
StopTimeBeforeCurrentTime { stop_time: f64, state_time: f64 },
#[error("Interpolation time is after current time")]
InterpolationTimeAfterCurrentTime,
#[error("Interpolation time is not within the current step. Step size is zero after calling state_mut()")]
InterpolationTimeOutsideCurrentStep,
#[error("Interpolation time is greater than current time")]
InterpolationTimeGreaterThanCurrentTime,
#[error("State not set")]
StateNotSet,
#[error("Sensitivity solve failed")]
SensitivitySolveFailed,
#[error("Step size is too small at time = {time}")]
StepSizeTooSmall { time: f64 },
#[error("Sensitivity requested but equations do not support it")]
SensitivityNotSupported,
#[error("Failed to get mutable reference to equations, is there a solver created with this problem?")]
FailedToGetMutableReference,
#[error("atol must have length 1 or equal to the number of states")]
AtolLengthMismatch,
#[error("t_eval must be increasing and all values must be greater than or equal to the current time")]
InvalidTEval,
#[error("Sundials error: {0}")]
SundialsError(String),
#[error("Problem not set")]
ProblemNotSet,
#[error("Error: {0}")]
Other(String),
}

/// Possible errors for matrix operations
#[derive(Error, Debug)]
pub enum MatrixError {
#[error("Failed to create matrix from triplets: {0}")]
FailedToCreateMatrixFromTriplets(#[from] CreationError),
#[error("Cannot union matrices with different shapes")]
UnionIncompatibleShapes,
#[error("Cannot create a matrix with zero rows or columns")]
MatrixShapeError,
#[error("Index out of bounds")]
IndexOutOfBounds,
#[error("Error: {0}")]
Other(String),
}

#[macro_export]
macro_rules! linear_solver_error {
($variant:ident) => {
DiffsolError::from(LinearSolverError::$variant)
};
($variant:ident, $($arg:tt)*) => {
DiffsolError::from(LinearSolverError::$variant($($arg)*))
};
}

#[macro_export]
macro_rules! non_linear_solver_error {
($variant:ident) => {
DiffsolError::from(NonLinearSolverError::$variant)
};
($variant:ident, $($arg:tt)*) => {
DiffsolError::from(NonLinearSolverError::$variant($($arg)*))
};
}

#[macro_export]
macro_rules! ode_solver_error {
($variant:ident) => {
DiffsolError::from(OdeSolverError::$variant)
};
($variant:ident, $($arg:tt)*) => {
DiffsolError::from(OdeSolverError::$variant($($arg)*.to_string()))
};
}

#[macro_export]
macro_rules! matrix_error {
($variant:ident) => {
DiffsolError::from(MatrixError::$variant)
};
($variant:ident, $($arg:tt)*) => {
DiffsolError::from(MatrixError::$variant($($arg)*))
};
}

#[macro_export]
macro_rules! other_error {
($msg:expr) => {
DiffsolError::Other($msg.to_string())
};
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,5 @@ pub use vector::DefaultDenseMatrix;
use vector::{Vector, VectorCommon, VectorIndex, VectorRef, VectorView, VectorViewMut};

pub use scalar::scale;

pub mod error;
11 changes: 6 additions & 5 deletions src/linear_solver/faer/lu.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::{error::LinearSolverError, linear_solver_error};
use std::rc::Rc;

use crate::{
linear_solver::LinearSolver, op::linearise::LinearisedOp, solver::SolverProblem, LinearOp,
Matrix, MatrixSparsityRef, NonLinearOp, Op, Scalar,
error::DiffsolError, linear_solver::LinearSolver, op::linearise::LinearisedOp,
solver::SolverProblem, LinearOp, Matrix, MatrixSparsityRef, NonLinearOp, Op, Scalar,
};
use anyhow::Result;

use faer::{linalg::solvers::FullPivLu, solvers::SpSolver, Col, Mat};
/// A [LinearSolver] that uses the LU decomposition in the [`faer`](https://github.com/sarah-ek/faer-rs) library to solve the linear system.
pub struct LU<T, C>
Expand Down Expand Up @@ -41,9 +42,9 @@ impl<T: Scalar, C: NonLinearOp<M = Mat<T>, V = Col<T>, T = T>> LinearSolver<C> f
self.lu = Some(matrix.full_piv_lu());
}

fn solve_in_place(&self, x: &mut C::V) -> Result<()> {
fn solve_in_place(&self, x: &mut C::V) -> Result<(), DiffsolError> {
if self.lu.is_none() {
return Err(anyhow::anyhow!("LU not initialized"));
return Err(linear_solver_error!(LuNotInitialized))?;
}
let lu = self.lu.as_ref().unwrap();
lu.solve_in_place(x);
Expand Down
17 changes: 11 additions & 6 deletions src/linear_solver/faer/sparse_lu.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use std::rc::Rc;

use crate::{
linear_solver::LinearSolver, matrix::sparsity::MatrixSparsityRef, op::linearise::LinearisedOp,
scalar::IndexType, solver::SolverProblem, LinearOp, Matrix, NonLinearOp, Op, Scalar,
SparseColMat,
error::{DiffsolError, LinearSolverError},
linear_solver::LinearSolver,
linear_solver_error,
matrix::sparsity::MatrixSparsityRef,
op::linearise::LinearisedOp,
scalar::IndexType,
solver::SolverProblem,
LinearOp, Matrix, NonLinearOp, Op, Scalar, SparseColMat,
};
use anyhow::Result;

use faer::{
solvers::SpSolver,
sparse::linalg::{solvers::Lu, solvers::SymbolicLu},
Expand Down Expand Up @@ -57,9 +62,9 @@ impl<T: Scalar, C: NonLinearOp<M = SparseColMat<T>, V = Col<T>, T = T>> LinearSo
)
}

fn solve_in_place(&self, x: &mut C::V) -> Result<()> {
fn solve_in_place(&self, x: &mut C::V) -> Result<(), DiffsolError> {
if self.lu.is_none() {
return Err(anyhow::anyhow!("LU not initialized"));
return Err(linear_solver_error!(LuNotInitialized))?;
}
let lu = self.lu.as_ref().unwrap();
lu.solve_in_place(x);
Expand Down
7 changes: 3 additions & 4 deletions src/linear_solver/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::{op::Op, solver::SolverProblem};
use anyhow::Result;
use crate::{error::DiffsolError, op::Op, solver::SolverProblem};

#[cfg(feature = "nalgebra")]
pub mod nalgebra;
Expand Down Expand Up @@ -27,13 +26,13 @@ pub trait LinearSolver<C: Op> {

/// Solve the problem `Ax = b` and return the solution `x`.
/// panics if [Self::set_linearisation] has not been called previously
fn solve(&self, b: &C::V) -> Result<C::V> {
fn solve(&self, b: &C::V) -> Result<C::V, DiffsolError> {
let mut b = b.clone();
self.solve_in_place(&mut b)?;
Ok(b)
}

fn solve_in_place(&self, b: &mut C::V) -> Result<()>;
fn solve_in_place(&self, b: &mut C::V) -> Result<(), DiffsolError>;
}

pub struct LinearSolveSolution<V> {
Expand Down
12 changes: 6 additions & 6 deletions src/linear_solver/nalgebra/lu.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::rc::Rc;

use anyhow::Result;
use nalgebra::{DMatrix, DVector, Dyn};
use std::rc::Rc;

use crate::{
error::{DiffsolError, LinearSolverError},
linear_solver_error,
matrix::sparsity::MatrixSparsityRef,
op::{linearise::LinearisedOp, NonLinearOp},
LinearOp, LinearSolver, Matrix, Op, Scalar, SolverProblem,
Expand Down Expand Up @@ -37,14 +37,14 @@ where
impl<T: Scalar, C: NonLinearOp<M = DMatrix<T>, V = DVector<T>, T = T>> LinearSolver<C>
for LU<T, C>
{
fn solve_in_place(&self, state: &mut C::V) -> Result<()> {
fn solve_in_place(&self, state: &mut C::V) -> Result<(), DiffsolError> {
if self.lu.is_none() {
return Err(anyhow::anyhow!("LU not initialized"));
return Err(linear_solver_error!(LuNotInitialized))?;
}
let lu = self.lu.as_ref().unwrap();
match lu.solve_mut(state) {
true => Ok(()),
false => Err(anyhow::anyhow!("LU solve failed")),
false => Err(linear_solver_error!(LuSolveFailed))?,
}
}

Expand Down
27 changes: 18 additions & 9 deletions src/linear_solver/suitesparse/klu.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::{cell::RefCell, rc::Rc};

use anyhow::Result;

use faer::Col;

#[cfg(target_pointer_width = "32")]
Expand All @@ -25,7 +23,12 @@ use suitesparse_sys::{
type KluIndextype = i64;

use crate::{
linear_solver::LinearSolver, matrix::MatrixCommon, op::linearise::LinearisedOp, vector::Vector,
error::{DiffsolError, LinearSolverError},
linear_solver::LinearSolver,
linear_solver_error,
matrix::MatrixCommon,
op::linearise::LinearisedOp,
vector::Vector,
LinearOp, Matrix, MatrixSparsityRef, NonLinearOp, Op, SolverProblem, SparseColMat,
};

Expand Down Expand Up @@ -68,7 +71,10 @@ struct KluSymbolic {
}

impl KluSymbolic {
fn try_from_matrix(mat: &mut impl MatrixKLU, common: *mut klu_common) -> Result<Self> {
fn try_from_matrix(
mat: &mut impl MatrixKLU,
common: *mut klu_common,
) -> Result<Self, DiffsolError> {
let n = mat.nrows() as i64;
let inner = unsafe {
klu_analyze(
Expand All @@ -79,7 +85,7 @@ impl KluSymbolic {
)
};
if inner.is_null() {
return Err(anyhow::anyhow!("KLU failed to analyze"));
return Err(linear_solver_error!(KluFailedToAnalyze));
};
Ok(Self { inner, common })
}
Expand All @@ -99,7 +105,10 @@ struct KluNumeric {
}

impl KluNumeric {
fn try_from_symbolic(symbolic: &mut KluSymbolic, mat: &mut impl MatrixKLU) -> Result<Self> {
fn try_from_symbolic(
symbolic: &mut KluSymbolic,
mat: &mut impl MatrixKLU,
) -> Result<Self, DiffsolError> {
let inner = unsafe {
klu_factor(
mat.column_pointers_mut_ptr(),
Expand All @@ -110,7 +119,7 @@ impl KluNumeric {
)
};
if inner.is_null() {
return Err(anyhow::anyhow!("KLU failed to factorize"));
return Err(linear_solver_error!(KluFailedToFactorize));
};
Ok(Self {
inner,
Expand Down Expand Up @@ -194,9 +203,9 @@ where
.ok();
}

fn solve_in_place(&self, x: &mut C::V) -> Result<()> {
fn solve_in_place(&self, x: &mut C::V) -> Result<(), DiffsolError> {
if self.klu_numeric.is_none() {
return Err(anyhow::anyhow!("LU not initialized"));
return Err(linear_solver_error!(LuNotInitialized));
}
let klu_numeric = self.klu_numeric.as_ref().unwrap();
let klu_symbolic = self.klu_symbolic.as_ref().unwrap();
Expand Down
11 changes: 5 additions & 6 deletions src/linear_solver/sundials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ use std::rc::Rc;
use crate::sundials_sys::{
realtype, SUNLinSolFree, SUNLinSolSetup, SUNLinSolSolve, SUNLinSol_Dense, SUNLinearSolver,
};
use anyhow::Result;

use crate::{
ode_solver::sundials::sundials_check, op::linearise::LinearisedOp,
vector::sundials::SundialsVector, LinearOp, Matrix, NonLinearOp, Op, SolverProblem,
SundialsMatrix,
error::*, linear_solver_error, ode_solver::sundials::sundials_check,
op::linearise::LinearisedOp, vector::sundials::SundialsVector, LinearOp, Matrix, NonLinearOp,
Op, SolverProblem, SundialsMatrix,
};

#[cfg(not(sundials_version_major = "5"))]
Expand Down Expand Up @@ -98,9 +97,9 @@ where
self.is_setup = true;
}

fn solve_in_place(&self, b: &mut Op::V) -> Result<()> {
fn solve_in_place(&self, b: &mut Op::V) -> Result<(), DiffsolError> {
if !self.is_setup {
return Err(anyhow::anyhow!("Linear solver not setup"));
return Err(linear_solver_error!(LinearSolverNotSetup));
}
let linear_solver = self.linear_solver.expect("Linear solver not set");
let matrix = self.matrix.as_ref().expect("Matrix not set");
Expand Down
Loading

0 comments on commit e727919

Please sign in to comment.