Skip to content

Commit

Permalink
refactor: refactor: remove FilterCallable, abs_to, exp, filter, gathe…
Browse files Browse the repository at this point in the history
…r, scatter, add map_inplace. provide default impl of fill (#79)

* refactor: remove FilterCallable, abs_to, exp, filter, gather, scatter, add map_inplace

* refactor: provide default impl of fill

* ignore notes.txt

* cargo fmt
  • Loading branch information
martinjrobins authored Aug 1, 2024
1 parent 9642a69 commit a600d30
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 218 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ perf.data*
flamegraph.svg
*.png
env/
*.mtx
*.mtx
*notes.txt
6 changes: 4 additions & 2 deletions src/ode_solver/test_models/gaussian_decay.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::ode_solver::problem::OdeSolverSolution;
use crate::OdeSolverProblem;
use crate::{scalar::scale, ConstantOp, DenseMatrix, OdeBuilder, OdeEquations, Vector};
use nalgebra::ComplexField;
use num_traits::Pow;
use num_traits::Zero;
use std::ops::MulAssign;
Expand Down Expand Up @@ -40,9 +41,10 @@ pub fn gaussian_decay_problem<M: DenseMatrix + 'static>(
let mut soln = OdeSolverSolution::default();
for i in 0..10 {
let t = M::T::from(i as f64 / 1.0);
let px = M::V::from_vec(p.clone()) * scale(t.pow(2)) / scale(M::T::from(-2.0));
let mut y: M::V = problem.eqn.init().call(M::T::zero());
y.component_mul_assign(&px.exp());
let mut px = M::V::from_vec(p.clone()) * scale(t.pow(2) / M::T::from(-2.0));
px.map_inplace(|x| x.exp());
y.component_mul_assign(&px);
soln.push(y, t);
}
(problem, soln)
Expand Down
70 changes: 0 additions & 70 deletions src/op/filter.rs

This file was deleted.

1 change: 0 additions & 1 deletion src/op/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ pub mod closure_no_jac;
pub mod closure_with_sens;
pub mod constant_closure;
pub mod constant_closure_with_sens;
pub mod filter;
pub mod init;
pub mod linear_closure;
pub mod linear_closure_with_sens;
Expand Down
36 changes: 2 additions & 34 deletions src/vector/faer_serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@ impl<T: Scalar> Vector for Col<T> {
}
acc / Self::T::from(self.len() as f64)
}
fn abs_to(&self, y: &mut Self) {
zipped!(self, y.as_mut()).for_each(|unzipped!(xi, mut yi)| *yi = xi.faer_abs());
}
fn as_view(&self) -> Self::View<'_> {
self.as_ref()
}
Expand Down Expand Up @@ -147,8 +144,8 @@ impl<T: Scalar> Vector for Col<T> {
zipped!(self.as_mut(), x)
.for_each(|unzipped!(mut si, xi)| si.write(si.read() * beta + xi.read() * alpha));
}
fn exp(&self) -> Self {
zipped!(self).map(|unzipped!(xi)| xi.exp())
fn map_inplace(&mut self, f: impl Fn(Self::T) -> Self::T) {
zipped!(self.as_mut()).for_each(|unzipped!(mut xi)| xi.write(f(*xi)));
}
fn component_mul_assign(&mut self, other: &Self) {
zipped!(self.as_mut(), other.as_view()).for_each(|unzipped!(mut s, o)| *s *= *o);
Expand All @@ -175,21 +172,6 @@ impl<T: Scalar> Vector for Col<T> {
}
acc
}
fn gather_from(&mut self, other: &Self, indices: &Self::Index) {
for (i, &index) in indices.iter().enumerate() {
self[i] = other[index];
}
}
fn scatter_from(&mut self, other: &Self, indices: &Self::Index) {
for (i, &index) in indices.iter().enumerate() {
self[index] = other[i];
}
}
fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) {
for &index in indices {
self[index] = value;
}
}
}

impl VectorIndex for Vec<IndexType> {
Expand Down Expand Up @@ -221,9 +203,6 @@ impl_vector_common!(ColMut<'a, T>);

impl<'a, T: Scalar> VectorView<'a> for ColRef<'a, T> {
type Owned = Col<T>;
fn abs_to(&self, y: &mut Self::Owned) {
zipped!(self, y.as_mut()).for_each(|unzipped!(xi, mut yi)| *yi = xi.faer_abs());
}
fn into_owned(self) -> Col<T> {
self.to_owned()
}
Expand All @@ -248,9 +227,6 @@ impl<'a, T: Scalar> VectorView<'a> for ColRef<'a, T> {
impl<'a, T: Scalar> VectorViewMut<'a> for ColMut<'a, T> {
type Owned = Col<T>;
type View = ColRef<'a, T>;
fn abs_to(&self, y: &mut Self::Owned) {
zipped!(self, y.as_mut()).for_each(|unzipped!(xi, mut yi)| *yi = xi.faer_abs());
}
fn copy_from(&mut self, other: &Self::Owned) {
self.copy_from(other);
}
Expand All @@ -265,14 +241,6 @@ mod tests {
use super::*;
use crate::scalar::scale;

#[test]
fn test_abs() {
let v = Col::from_vec(vec![1.0, -2.0, 3.0]);
let mut v_abs = v.clone();
v.abs_to(&mut v_abs);
assert_eq!(v_abs, Col::from_vec(vec![1.0, 2.0, 3.0]));
}

#[test]
fn test_mult() {
let v = Col::from_vec(vec![1.0, -2.0, 3.0]);
Expand Down
30 changes: 8 additions & 22 deletions src/vector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ pub trait VectorViewMut<'a>:
{
type Owned;
type View;
fn abs_to(&self, y: &mut Self::Owned);
fn copy_from(&mut self, other: &Self::Owned);
fn copy_from_view(&mut self, other: &Self::View);
}
Expand All @@ -98,7 +97,6 @@ pub trait VectorView<'a>:
{
type Owned;
fn squared_norm(&self, y: &Self::Owned, atol: &Self::Owned, rtol: Self::T) -> Self::T;
fn abs_to(&self, y: &mut Self::Owned);
fn norm(&self) -> Self::T;
fn into_owned(self) -> Self::Owned;
}
Expand Down Expand Up @@ -134,13 +132,14 @@ pub trait Vector:
fn is_empty(&self) -> bool {
self.len() == 0
}
fn abs_to(&self, y: &mut Self);
fn exp(&self) -> Self;
fn map_inplace(&mut self, f: impl Fn(Self::T) -> Self::T);
fn from_element(nstates: usize, value: Self::T) -> Self;
fn zeros(nstates: usize) -> Self {
Self::from_element(nstates, Self::T::zero())
}
fn fill(&mut self, value: Self::T);
fn fill(&mut self, value: Self::T) {
self.map_inplace(|_| value);
}
fn as_view(&self) -> Self::View<'_>;
fn as_view_mut(&mut self) -> Self::ViewMut<'_>;
fn as_slice(&self) -> &[Self::T];
Expand All @@ -158,29 +157,16 @@ pub trait Vector:
fn binary_fold<B, F>(&self, other: &Self, init: B, f: F) -> B
where
F: Fn(B, Self::T, Self::T, IndexType) -> B;
fn filter(&self, indices: &Self::Index) -> Self {
let mut result = Self::zeros(indices.len());
result.gather_from(self, indices);
result
fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) {
for i in 0..indices.len() {
self[indices[i]] = value;
}
}
fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T);

// for i in 0..indices.len():
// self[i] = value[indices[i]]
fn gather_from(&mut self, other: &Self, indices: &Self::Index);

// for i in 0..indices.len():
// self[indices[i]] = value[i]
fn scatter_from(&mut self, other: &Self, indices: &Self::Index);

// for i in 0..indices.len():
// self[indices[i]] = value[indices[i]]
fn copy_from_indices(&mut self, other: &Self, indices: &Self::Index) {
for i in 0..indices.len() {
self[indices[i]] = other[indices[i]];
}
}

fn assert_eq_st(&self, other: &Self, tol: Self::T) {
let tol = Self::from_element(self.len(), tol);
self.assert_eq(other, &tol);
Expand Down
28 changes: 2 additions & 26 deletions src/vector/nalgebra_serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ impl_vector_common!(DVectorViewMut<'a, T>);

impl<'a, T: Scalar> VectorView<'a> for DVectorView<'a, T> {
type Owned = DVector<T>;
fn abs_to(&self, y: &mut Self::Owned) {
y.zip_apply(self, |y, x| *y = x.abs());
}
fn into_owned(self) -> Self::Owned {
self.into_owned()
}
Expand Down Expand Up @@ -112,9 +109,6 @@ impl_mul_assign_scale_vector!(DVectorViewMut<'a, T>);
impl<'a, T: Scalar> VectorViewMut<'a> for DVectorViewMut<'a, T> {
type Owned = DVector<T>;
type View = DVectorView<'a, T>;
fn abs_to(&self, y: &mut Self::Owned) {
y.zip_apply(self, |y, x| *y = x.abs());
}
fn copy_from(&mut self, other: &Self::Owned) {
self.copy_from(other);
}
Expand Down Expand Up @@ -162,9 +156,6 @@ impl<T: Scalar> Vector for DVector<T> {
}
acc / Self::T::from(self.len() as f64)
}
fn abs_to(&self, y: &mut Self) {
y.zip_apply(self, |y, x| *y = x.abs());
}
fn fill(&mut self, value: T) {
self.fill(value);
}
Expand All @@ -177,8 +168,8 @@ impl<T: Scalar> Vector for DVector<T> {
fn copy_from(&mut self, other: &Self) {
self.copy_from(other);
}
fn exp(&self) -> Self {
self.map(|x| x.exp())
fn map_inplace(&mut self, f: impl Fn(Self::T) -> Self::T) {
self.iter_mut().for_each(|x| *x = f(*x));
}
fn copy_from_view(&mut self, other: &Self::View<'_>) {
self.copy_from(other);
Expand Down Expand Up @@ -216,21 +207,6 @@ impl<T: Scalar> Vector for DVector<T> {
}
Self::Index::from_vec(indices)
}
fn gather_from(&mut self, other: &Self, indices: &Self::Index) {
for (i, &index) in indices.iter().enumerate() {
self[i] = other[index];
}
}
fn scatter_from(&mut self, other: &Self, indices: &Self::Index) {
for (i, &index) in indices.iter().enumerate() {
self[index] = other[i];
}
}
fn assign_at_indices(&mut self, indices: &Self::Index, value: Self::T) {
for &index in indices.iter() {
self[index] = value;
}
}
fn binary_fold<B, F>(&self, other: &Self, init: B, f: F) -> B
where
F: Fn(B, Self::T, Self::T, IndexType) -> B,
Expand Down
Loading

0 comments on commit a600d30

Please sign in to comment.