From a600d30ba23ac5bd7638c37ce7282838b22758f0 Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Thu, 1 Aug 2024 10:00:27 +0100 Subject: [PATCH] refactor: refactor: remove FilterCallable, abs_to, exp, filter, gather, scatter, add map_inplace. provide default impl of fill (#79) * refactor: remove FilterCallable, abs_to, exp, filter, gather, scatter, add map_inplace * refactor: provide default impl of fill * ignore notes.txt * cargo fmt --- .gitignore | 3 +- src/ode_solver/test_models/gaussian_decay.rs | 6 +- src/op/filter.rs | 70 -------------------- src/op/mod.rs | 1 - src/vector/faer_serial.rs | 36 +--------- src/vector/mod.rs | 30 +++------ src/vector/nalgebra_serial.rs | 28 +------- src/vector/sundials.rs | 69 ++----------------- 8 files changed, 25 insertions(+), 218 deletions(-) delete mode 100644 src/op/filter.rs diff --git a/.gitignore b/.gitignore index 6afcee37..37b269bf 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,5 @@ perf.data* flamegraph.svg *.png env/ -*.mtx \ No newline at end of file +*.mtx +*notes.txt \ No newline at end of file diff --git a/src/ode_solver/test_models/gaussian_decay.rs b/src/ode_solver/test_models/gaussian_decay.rs index 239361ad..ea55e303 100644 --- a/src/ode_solver/test_models/gaussian_decay.rs +++ b/src/ode_solver/test_models/gaussian_decay.rs @@ -1,6 +1,7 @@ use crate::ode_solver::problem::OdeSolverSolution; use crate::OdeSolverProblem; use crate::{scalar::scale, ConstantOp, DenseMatrix, OdeBuilder, OdeEquations, Vector}; +use nalgebra::ComplexField; use num_traits::Pow; use num_traits::Zero; use std::ops::MulAssign; @@ -40,9 +41,10 @@ pub fn gaussian_decay_problem( let mut soln = OdeSolverSolution::default(); for i in 0..10 { let t = M::T::from(i as f64 / 1.0); - let px = M::V::from_vec(p.clone()) * scale(t.pow(2)) / scale(M::T::from(-2.0)); let mut y: M::V = problem.eqn.init().call(M::T::zero()); - y.component_mul_assign(&px.exp()); + let mut px = M::V::from_vec(p.clone()) * scale(t.pow(2) / M::T::from(-2.0)); + px.map_inplace(|x| x.exp()); + y.component_mul_assign(&px); soln.push(y, t); } (problem, soln) diff --git a/src/op/filter.rs b/src/op/filter.rs deleted file mode 100644 index bbd2344c..00000000 --- a/src/op/filter.rs +++ /dev/null @@ -1,70 +0,0 @@ -// a callable that takes another callable and a mask vector -// this callable, when called, will call the other callable with the mask applied - -use std::{cell::RefCell, rc::Rc}; - -use crate::{Vector, VectorIndex}; - -use super::{NonLinearOp, Op}; - -pub struct FilterCallable { - callable: Rc, - indices: ::Index, - y_full: RefCell, - x_full: RefCell, - v_full: RefCell, -} - -impl FilterCallable { - pub fn new(callable: Rc, x: &C::V, indices: ::Index) -> Self { - let y_full = RefCell::new(C::V::zeros(callable.nout())); - let x_full = RefCell::new(x.clone()); - let v_full = RefCell::new(C::V::zeros(callable.nstates())); - Self { - callable, - indices, - y_full, - x_full, - v_full, - } - } - - pub fn indices(&self) -> &::Index { - &self.indices - } -} - -impl Op for FilterCallable { - type V = C::V; - type T = C::T; - type M = C::M; - fn nstates(&self) -> usize { - self.indices.len() - } - fn nout(&self) -> usize { - self.indices.len() - } - fn nparams(&self) -> usize { - self.callable.nparams() - } -} - -impl NonLinearOp for FilterCallable { - fn call_inplace(&self, x: &Self::V, t: Self::T, y: &mut Self::V) { - let mut y_full = self.y_full.borrow_mut(); - let mut x_full = self.x_full.borrow_mut(); - x_full.scatter_from(x, &self.indices); - self.callable.call_inplace(&x_full, t, &mut y_full); - y.gather_from(&y_full, &self.indices); - } - fn jac_mul_inplace(&self, x: &Self::V, t: Self::T, v: &Self::V, y: &mut Self::V) { - let mut y_full = self.y_full.borrow_mut(); - let mut x_full = self.x_full.borrow_mut(); - let mut v_full = self.v_full.borrow_mut(); - x_full.scatter_from(x, &self.indices); - v_full.scatter_from(v, &self.indices); - self.callable - .jac_mul_inplace(&x_full, t, &v_full, &mut y_full); - y.gather_from(&y_full, &self.indices); - } -} diff --git a/src/op/mod.rs b/src/op/mod.rs index 2bcb0788..861fd178 100644 --- a/src/op/mod.rs +++ b/src/op/mod.rs @@ -11,7 +11,6 @@ pub mod closure_no_jac; pub mod closure_with_sens; pub mod constant_closure; pub mod constant_closure_with_sens; -pub mod filter; pub mod init; pub mod linear_closure; pub mod linear_closure_with_sens; diff --git a/src/vector/faer_serial.rs b/src/vector/faer_serial.rs index 74d24740..264dfc24 100644 --- a/src/vector/faer_serial.rs +++ b/src/vector/faer_serial.rs @@ -109,9 +109,6 @@ impl Vector for Col { } acc / Self::T::from(self.len() as f64) } - fn abs_to(&self, y: &mut Self) { - zipped!(self, y.as_mut()).for_each(|unzipped!(xi, mut yi)| *yi = xi.faer_abs()); - } fn as_view(&self) -> Self::View<'_> { self.as_ref() } @@ -147,8 +144,8 @@ impl Vector for Col { zipped!(self.as_mut(), x) .for_each(|unzipped!(mut si, xi)| si.write(si.read() * beta + xi.read() * alpha)); } - fn exp(&self) -> Self { - zipped!(self).map(|unzipped!(xi)| xi.exp()) + fn map_inplace(&mut self, f: impl Fn(Self::T) -> Self::T) { + zipped!(self.as_mut()).for_each(|unzipped!(mut xi)| xi.write(f(*xi))); } fn component_mul_assign(&mut self, other: &Self) { zipped!(self.as_mut(), other.as_view()).for_each(|unzipped!(mut s, o)| *s *= *o); @@ -175,21 +172,6 @@ impl Vector for Col { } acc } - fn gather_from(&mut self, other: &Self, indices: &Self::Index) { - for (i, &index) in indices.iter().enumerate() { - self[i] = other[index]; - } - } - fn scatter_from(&mut self, other: &Self, indices: &Self::Index) { - for (i, &index) in indices.iter().enumerate() { - self[index] = other[i]; - } - } - fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) { - for &index in indices { - self[index] = value; - } - } } impl VectorIndex for Vec { @@ -221,9 +203,6 @@ impl_vector_common!(ColMut<'a, T>); impl<'a, T: Scalar> VectorView<'a> for ColRef<'a, T> { type Owned = Col; - fn abs_to(&self, y: &mut Self::Owned) { - zipped!(self, y.as_mut()).for_each(|unzipped!(xi, mut yi)| *yi = xi.faer_abs()); - } fn into_owned(self) -> Col { self.to_owned() } @@ -248,9 +227,6 @@ impl<'a, T: Scalar> VectorView<'a> for ColRef<'a, T> { impl<'a, T: Scalar> VectorViewMut<'a> for ColMut<'a, T> { type Owned = Col; type View = ColRef<'a, T>; - fn abs_to(&self, y: &mut Self::Owned) { - zipped!(self, y.as_mut()).for_each(|unzipped!(xi, mut yi)| *yi = xi.faer_abs()); - } fn copy_from(&mut self, other: &Self::Owned) { self.copy_from(other); } @@ -265,14 +241,6 @@ mod tests { use super::*; use crate::scalar::scale; - #[test] - fn test_abs() { - let v = Col::from_vec(vec![1.0, -2.0, 3.0]); - let mut v_abs = v.clone(); - v.abs_to(&mut v_abs); - assert_eq!(v_abs, Col::from_vec(vec![1.0, 2.0, 3.0])); - } - #[test] fn test_mult() { let v = Col::from_vec(vec![1.0, -2.0, 3.0]); diff --git a/src/vector/mod.rs b/src/vector/mod.rs index c8915f36..80daaef0 100644 --- a/src/vector/mod.rs +++ b/src/vector/mod.rs @@ -83,7 +83,6 @@ pub trait VectorViewMut<'a>: { type Owned; type View; - fn abs_to(&self, y: &mut Self::Owned); fn copy_from(&mut self, other: &Self::Owned); fn copy_from_view(&mut self, other: &Self::View); } @@ -98,7 +97,6 @@ pub trait VectorView<'a>: { type Owned; fn squared_norm(&self, y: &Self::Owned, atol: &Self::Owned, rtol: Self::T) -> Self::T; - fn abs_to(&self, y: &mut Self::Owned); fn norm(&self) -> Self::T; fn into_owned(self) -> Self::Owned; } @@ -134,13 +132,14 @@ pub trait Vector: fn is_empty(&self) -> bool { self.len() == 0 } - fn abs_to(&self, y: &mut Self); - fn exp(&self) -> Self; + fn map_inplace(&mut self, f: impl Fn(Self::T) -> Self::T); fn from_element(nstates: usize, value: Self::T) -> Self; fn zeros(nstates: usize) -> Self { Self::from_element(nstates, Self::T::zero()) } - fn fill(&mut self, value: Self::T); + fn fill(&mut self, value: Self::T) { + self.map_inplace(|_| value); + } fn as_view(&self) -> Self::View<'_>; fn as_view_mut(&mut self) -> Self::ViewMut<'_>; fn as_slice(&self) -> &[Self::T]; @@ -158,29 +157,16 @@ pub trait Vector: fn binary_fold(&self, other: &Self, init: B, f: F) -> B where F: Fn(B, Self::T, Self::T, IndexType) -> B; - fn filter(&self, indices: &Self::Index) -> Self { - let mut result = Self::zeros(indices.len()); - result.gather_from(self, indices); - result + fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) { + for i in 0..indices.len() { + self[indices[i]] = value; + } } - fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T); - - // for i in 0..indices.len(): - // self[i] = value[indices[i]] - fn gather_from(&mut self, other: &Self, indices: &Self::Index); - - // for i in 0..indices.len(): - // self[indices[i]] = value[i] - fn scatter_from(&mut self, other: &Self, indices: &Self::Index); - - // for i in 0..indices.len(): - // self[indices[i]] = value[indices[i]] fn copy_from_indices(&mut self, other: &Self, indices: &Self::Index) { for i in 0..indices.len() { self[indices[i]] = other[indices[i]]; } } - fn assert_eq_st(&self, other: &Self, tol: Self::T) { let tol = Self::from_element(self.len(), tol); self.assert_eq(other, &tol); diff --git a/src/vector/nalgebra_serial.rs b/src/vector/nalgebra_serial.rs index f66400f6..a9eddf0f 100644 --- a/src/vector/nalgebra_serial.rs +++ b/src/vector/nalgebra_serial.rs @@ -53,9 +53,6 @@ impl_vector_common!(DVectorViewMut<'a, T>); impl<'a, T: Scalar> VectorView<'a> for DVectorView<'a, T> { type Owned = DVector; - fn abs_to(&self, y: &mut Self::Owned) { - y.zip_apply(self, |y, x| *y = x.abs()); - } fn into_owned(self) -> Self::Owned { self.into_owned() } @@ -112,9 +109,6 @@ impl_mul_assign_scale_vector!(DVectorViewMut<'a, T>); impl<'a, T: Scalar> VectorViewMut<'a> for DVectorViewMut<'a, T> { type Owned = DVector; type View = DVectorView<'a, T>; - fn abs_to(&self, y: &mut Self::Owned) { - y.zip_apply(self, |y, x| *y = x.abs()); - } fn copy_from(&mut self, other: &Self::Owned) { self.copy_from(other); } @@ -162,9 +156,6 @@ impl Vector for DVector { } acc / Self::T::from(self.len() as f64) } - fn abs_to(&self, y: &mut Self) { - y.zip_apply(self, |y, x| *y = x.abs()); - } fn fill(&mut self, value: T) { self.fill(value); } @@ -177,8 +168,8 @@ impl Vector for DVector { fn copy_from(&mut self, other: &Self) { self.copy_from(other); } - fn exp(&self) -> Self { - self.map(|x| x.exp()) + fn map_inplace(&mut self, f: impl Fn(Self::T) -> Self::T) { + self.iter_mut().for_each(|x| *x = f(*x)); } fn copy_from_view(&mut self, other: &Self::View<'_>) { self.copy_from(other); @@ -216,21 +207,6 @@ impl Vector for DVector { } Self::Index::from_vec(indices) } - fn gather_from(&mut self, other: &Self, indices: &Self::Index) { - for (i, &index) in indices.iter().enumerate() { - self[i] = other[index]; - } - } - fn scatter_from(&mut self, other: &Self, indices: &Self::Index) { - for (i, &index) in indices.iter().enumerate() { - self[index] = other[i]; - } - } - fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) { - for &index in indices.iter() { - self[index] = value; - } - } fn binary_fold(&self, other: &Self, init: B, f: F) -> B where F: Fn(B, Self::T, Self::T, IndexType) -> B, diff --git a/src/vector/sundials.rs b/src/vector/sundials.rs index e3d92046..365b1c6c 100644 --- a/src/vector/sundials.rs +++ b/src/vector/sundials.rs @@ -3,7 +3,7 @@ use std::fmt::{Debug, Display, Formatter}; use std::ops::{Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Sub, SubAssign}; use crate::sundials_sys::{ - realtype, N_VAbs, N_VAddConst, N_VClone, N_VConst, N_VDestroy, N_VDiv, N_VGetArrayPointer, + realtype, N_VAddConst, N_VClone, N_VConst, N_VDestroy, N_VDiv, N_VGetArrayPointer, N_VGetLength_Serial, N_VLinearSum, N_VNew_Serial, N_VProd, N_VScale, N_VWL2Norm_Serial, N_Vector, }; @@ -409,9 +409,6 @@ impl_sub_view_owned!(SundialsVectorView, SundialsVector); impl<'a> VectorViewMut<'a> for SundialsVectorViewMut<'a> { type Owned = SundialsVector; type View = SundialsVectorView<'a>; - fn abs_to(&self, y: &mut Self::Owned) { - unsafe { N_VAbs(self.sundials_vector(), y.sundials_vector()) } - } fn copy_from(&mut self, other: &Self::Owned) { unsafe { N_VScale(1.0, other.sundials_vector(), self.sundials_vector()) } } @@ -422,9 +419,6 @@ impl<'a> VectorViewMut<'a> for SundialsVectorViewMut<'a> { impl<'a> VectorView<'a> for SundialsVectorView<'a> { type Owned = SundialsVector; - fn abs_to(&self, y: &mut Self::Owned) { - unsafe { N_VAbs(self.sundials_vector(), y.sundials_vector()) } - } fn into_owned(self) -> Self::Owned { let mut z = SundialsVector::new_serial(self.len()); z.copy_from_view(&self); @@ -515,9 +509,6 @@ impl Vector for SundialsVector { fn fill(&mut self, value: Self::T) { unsafe { N_VConst(value, self.sundials_vector()) } } - fn abs_to(&self, y: &mut Self) { - unsafe { N_VAbs(self.sundials_vector(), y.sundials_vector()) } - } fn add_scalar_mut(&mut self, scalar: Self::T) { unsafe { N_VAddConst(self.sundials_vector(), scalar, self.sundials_vector()) } } @@ -573,12 +564,10 @@ impl Vector for SundialsVector { fn copy_from_view(&mut self, other: &Self::View<'_>) { unsafe { N_VScale(1.0, other.sundials_vector(), self.sundials_vector()) } } - fn exp(&self) -> Self { - let mut z = SundialsVector::new_clone(self); + fn map_inplace(&mut self, f: impl Fn(Self::T) -> Self::T) { for i in 0..self.len() { - z[i] = self[i].exp(); + self[i] = f(self[i]); } - z } fn filter_indices bool>(&self, f: F) -> Self::Index { let mut indices = vec![]; @@ -601,21 +590,6 @@ impl Vector for SundialsVector { } v } - fn gather_from(&mut self, other: &Self, indices: &Self::Index) { - for i in 0..indices.len() { - self[i] = other[indices[i]]; - } - } - fn scatter_from(&mut self, other: &Self, indices: &Self::Index) { - for i in 0..indices.len() { - self[indices[i]] = other[i]; - } - } - fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) { - for i in 0..indices.len() { - self[indices[i]] = value; - } - } fn binary_fold(&self, other: &Self, init: B, f: F) -> B where F: Fn(B, Self::T, Self::T, IndexType) -> B, @@ -668,17 +642,6 @@ mod tests { assert_eq!(v3[1], 2.0); } - #[test] - fn test_abs() { - let mut v = SundialsVector::new_serial(2); - v[0] = -1.0; - v[1] = 2.0; - let mut v2 = v.clone(); - v.abs_to(&mut v2); - assert_eq!(v2[0], 1.0); - assert_eq!(v2[1], 2.0); - } - #[test] fn test_axpy() { let mut v = SundialsVector::new_serial(2); @@ -720,13 +683,13 @@ mod tests { } #[test] - fn test_exp() { + fn test_map() { let mut v = SundialsVector::new_serial(2); v[0] = 1.0; v[1] = 2.0; - let v2 = v.exp(); - assert_eq!(v2[0], 1.0_f64.exp()); - assert_eq!(v2[1], 2.0_f64.exp()); + v.map_inplace(f64::exp); + assert_eq!(v[0], 1.0_f64.exp()); + assert_eq!(v[1], 2.0_f64.exp()); } #[test] @@ -739,24 +702,6 @@ mod tests { assert_eq!(indices[0], 1); } - #[test] - fn test_gather_scatter() { - let mut v = SundialsVector::new_serial(3); - v[0] = 1.0; - v[1] = 2.0; - v[2] = 3.0; - let mut v2 = SundialsVector::new_serial(2); - v2.gather_from(&v, &SundialsIndexVector(vec![0, 2])); - assert_eq!(v2[0], 1.0); - assert_eq!(v2[1], 3.0); - v2[0] = 4.0; - v2[1] = 5.0; - v.scatter_from(&v2, &SundialsIndexVector(vec![0, 2])); - assert_eq!(v[0], 4.0); - assert_eq!(v[1], 2.0); - assert_eq!(v[2], 5.0); - } - #[test] fn test_zeros() { let v = SundialsIndexVector::zeros(1);