Skip to content

Commit

Permalink
bool-based Duals, duals.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-williams committed Sep 24, 2023
1 parent 7a5dba3 commit aa96f6b
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 171 deletions.
18 changes: 8 additions & 10 deletions src/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use log::debug;
use serde::{Deserialize, Serialize};
use tsify::Tsify;
use crate::{
dual::{D, Dual, d_fns},
dual::{D, Dual},
r2::R2, rotate::{Rotate as _Rotate, RotateArg}, shape::{Duals, Shape, Shapes}, transform::{Projection, CanTransform}, transform::Transform::{Rotate, Scale, ScaleXY, Translate, self}, ellipses::xyrr::{XYRR, UnitCircleGap}, sqrt::Sqrt, math::{is_normal::IsNormal, recip::Recip}, to::To, intersect::Intersect
};

Expand Down Expand Up @@ -253,10 +253,9 @@ impl Add<R2<i64>> for Circle<f64> {

impl Intersect<Circle<f64>, D> for Circle<f64> {
fn intersect(&self, other: &Circle<f64>) -> Vec<R2<D>> {
let ( _z, d ) = d_fns(6);
let [ s0, s1 ] = Shapes::from(&[
(Shape::Circle(* self), vec![ d; 3 ]),
(Shape::Circle(*other), vec![ d; 3 ]),
let [ s0, s1 ] = Shapes::from([
(Shape::Circle(* self), vec![ true; 3 ]),
(Shape::Circle(*other), vec![ true; 3 ]),
]);
s0.intersect(&s1)
}
Expand All @@ -279,7 +278,7 @@ mod tests {

use super::*;

use crate::{intersect::Intersect, dual::d_fns};
use crate::{intersect::Intersect, duals::{D, Z}};

pub fn r2(x: f64, dx: Vec<f64>, y: f64, dy: Vec<f64>) -> R2<D> {
R2 { x: Dual::new(x, dx), y: Dual::new(y, dy) }
Expand Down Expand Up @@ -331,10 +330,9 @@ mod tests {

#[test]
fn tangent_circles() {
let ( z, mut d ) = d_fns(1);
let [ s0, s1 ] = Shapes::from(&[
(Shape::Circle(Circle { c: R2 { x: 0., y: 0. }, r: 2. }), vec![ z, z, z ]),
(Shape::Circle(Circle { c: R2 { x: 3., y: 0. }, r: 1. }), vec![ d, z, z ]),
let [ s0, s1 ] = Shapes::from([
(Shape::Circle(Circle { c: R2 { x: 0., y: 0. }, r: 2. }), vec![ Z, Z, Z ]),
(Shape::Circle(Circle { c: R2 { x: 3., y: 0. }, r: 1. }), vec![ D, Z, Z ]),
]);
let ps = s0.intersect(&s1);
assert_eq!(ps, vec![
Expand Down
76 changes: 1 addition & 75 deletions src/dual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ impl Serialize for Dual {

impl<'de> Deserialize<'de> for Dual {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>
where D: Deserializer<'de>
{
#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "lowercase")]
Expand Down Expand Up @@ -396,79 +395,6 @@ impl Sum for Dual {
}
}

/// Placeholder for the derivative component of a [`Dual`] (corresponding to one "coordinate" of a Shape, e.g. `c.x`).
/// The `usize` argument indicates the length of the derivative vector, which should be the same for all coordinates in a [`Model`].
/// `Zeros` values are stateless, but `OneHot`s are expanded during [`Model`]/[`Step`] construction, so that the "hot" element proceeds through the vector,
/// i.e. the first differentiable coordinate will be `one_hot(0)`, the second will be `one_hot(1)`, etc.
/// This level of indirection makes it easier to specify a model's Shapes, and toggle which coordinates are considered moveable (and whose error-partial-derivative ∂(error) is propagated through all calculations).
/// See [`model::tests`] for example:
/// ```rust
/// let ( z, d ) = d_fns(2);
/// let inputs = vec![
/// (circle(0., 0., 1.), vec![ z, z, z, ]),
/// (circle(1., 0., 1.), vec![ d, z, d, ]),
/// ];
/// let targets = [
/// ("0*", 1. / 3.), // Fizz (multiples of 3)
/// ("*1", 1. / 5.), // Buzz (multiples of 5)
/// ("01", 1. / 15.), // Fizz Buzz (multiples of both 3 and 5)
/// ];
/// let model = Model::new(inputs, targets.to());
/// ```
/// This initializes two [`Circle`]s (with f64 coordinate values), such that

#[derive(Clone, Copy, Debug)]
pub enum InitDual { Zeros(usize), OneHot(usize), }
use InitDual::*;

pub struct InitDuals {
pub nxt: usize,
}
impl InitDuals {
pub fn new() -> Self {
InitDuals { nxt: 0 }
}
pub fn next(&mut self, id: &InitDual) -> Vec<f64> {
match id {
Zeros(n) => vec![0.; *n],
OneHot(n) => {
let duals = one_hot(&self.nxt, n);
self.nxt += 1;
duals
},
}
}
pub fn shape(&mut self, (s, init_duals): &InputSpec) -> Shape<D> {
let duals: Duals = init_duals.iter().map(|init_dual| self.next(init_dual)).collect();
s.dual(&duals)
}
}

pub fn one_hot(idx: &usize, size: &usize) -> Vec<f64> {
let mut v = vec![0.; *size];
v[*idx] = 1.;
v
}

pub fn is_one_hot(v: &Vec<f64>) -> Option<usize> {
let mut idx = None;
for (i, x) in v.iter().enumerate() {
if *x == 1. {
if idx.is_some() {
return None;
}
idx = Some(i);
} else if *x != 0. {
return None;
}
}
idx
}

pub fn d_fns(n: usize) -> (InitDual, InitDual) {
( Zeros(n), OneHot(n), )
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
78 changes: 78 additions & 0 deletions src/duals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use crate::{shape::{InputSpec, Shape, Duals}, dual::Dual};

pub static Z: bool = false;
pub static D: bool = true;

/// Placeholder for the derivative component of a [`Dual`] (corresponding to one "coordinate" of a Shape, e.g. `c.x`).
/// The `usize` argument indicates the length of the derivative vector, which should be the same for all coordinates in a [`Model`].
/// `Zeros` values are stateless, but `OneHot`s are expanded during [`Model`]/[`Step`] construction, so that the "hot" element proceeds through the vector,
/// i.e. the first differentiable coordinate will be `one_hot(0)`, the second will be `one_hot(1)`, etc.
/// This level of indirection makes it easier to specify a model's Shapes, and toggle which coordinates are considered moveable (and whose error-partial-derivative ∂(error) is propagated through all calculations).
/// See [`model::tests`] for example:
/// ```rust
/// use crate::dual::{d, z};
/// let inputs = vec![
/// (circle(0., 0., 1.), vec![ Z, Z, Z ]),
/// (circle(1., 0., 1.), vec![ D, Z, D ]),
/// ];
/// let targets = [
/// ("0*", 1. / 3.), // Fizz (multiples of 3)
/// ("*1", 1. / 5.), // Buzz (multiples of 5)
/// ("01", 1. / 15.), // Fizz Buzz (multiples of both 3 and 5)
/// ];
/// let model = Model::new(inputs, targets.to());
/// ```
/// This initializes two [`Circle`]s (with f64 coordinate values), such that
pub struct InitDuals {
pub nxt: usize,
pub n: usize
}
impl InitDuals {
pub fn from<const N: usize>(input_specs: [InputSpec; N]) -> [ Shape<Dual>; N ] {
let n = input_specs.iter().map(|(_, spec)| spec.iter().filter(|v| **v).count()).sum();
let mut init = InitDuals::new(n);
input_specs.map(|spec| init.shape(&spec))
}
pub fn from_vec(input_specs: &Vec<InputSpec>) -> Vec<Shape<Dual>> {
let n = input_specs.iter().map(|(_, spec)| spec.iter().filter(|v| **v).count()).sum();
let mut init = InitDuals::new(n);
input_specs.iter().map(|spec| init.shape(spec)).collect()
}
pub fn new(n: usize) -> Self {
InitDuals { nxt: 0, n }
}
pub fn next(&mut self, differentiable: bool) -> Vec<f64> {
if differentiable {
let duals = one_hot(&self.nxt, &self.n);
self.nxt += 1;
duals
} else {
vec![0.; self.n]
}
}
pub fn shape(&mut self, (s, init_duals): &InputSpec) -> Shape<Dual> {
let duals: Duals = init_duals.iter().map(|init_dual| self.next(*init_dual)).collect();
s.dual(&duals)
}
}

pub fn one_hot(idx: &usize, size: &usize) -> Vec<f64> {
let mut v = vec![0.; *size];
v[*idx] = 1.;
v
}

pub fn is_one_hot(v: &Vec<f64>) -> Option<usize> {
let mut idx = None;
for (i, x) in v.iter().enumerate() {
if *x == 1. {
if idx.is_some() {
return None;
}
idx = Some(i);
} else if *x != 0. {
return None;
}
}
idx
}
12 changes: 5 additions & 7 deletions src/ellipses/xyrr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ impl<D: RelativeEq<Epsilon = f64> + Clone> RelativeEq for XYRR<D> {
mod tests {
use std::{fmt, f64::NAN};

use crate::{dual::{Dual, d_fns}, circle::Circle, intersect::Intersect, to::To, shape::Shape};
use crate::{dual::Dual, circle::Circle, intersect::Intersect, to::To, shape::{Shape, xyrr, Shapes}, duals::{D, Z}};

use super::*;
use approx::{AbsDiffEq, RelativeEq};
Expand Down Expand Up @@ -390,12 +390,10 @@ mod tests {

#[test]
fn tangent_circles() {
let ( z, mut d ) = d_fns(1);
let ellipses = [
Shape::XYRR(XYRR { c: R2 { x: 0., y: 0. }, r: R2 { x: 2., y: 2., } }.dual(&vec![ z(), z(), z(), z() ])),
Shape::XYRR(XYRR { c: R2 { x: 3., y: 0. }, r: R2 { x: 1., y: 1., } }.dual(&vec![ d(), z(), z(), z() ])),
];
let [ e0, e1 ] = ellipses;
let [ e0, e1 ] = Shapes::from([
(xyrr(0., 0., 2., 2.), vec![ Z, Z, Z, Z ]),
(xyrr(3., 0., 1., 1.), vec![ D, Z, Z, Z ]),
]);
let ps = e0.intersect(&e1);
assert_eq!(ps, vec![
// Heads up! Tangent points have NAN gradients. [`Scene`] detects/filters them, but at this level of the stack they are passed along
Expand Down
7 changes: 4 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod contains;
pub mod d5;
pub mod distance;
pub mod dual;
pub mod duals;
pub mod edge;
pub mod ellipses;
pub mod float_arr;
Expand Down Expand Up @@ -42,7 +43,7 @@ pub mod zero;
pub mod js_dual;

use targets::Targets;
use shape::Input;
use shape::InputSpec;
use step::Step;
use dual::D;
use ellipses::xyrr::XYRR;
Expand Down Expand Up @@ -83,15 +84,15 @@ pub fn update_log_level(level: JsValue) {

#[wasm_bindgen]
pub fn make_step(inputs: JsValue, targets: JsValue) -> JsValue {
let inputs: Vec<Input> = serde_wasm_bindgen::from_value(inputs).unwrap();
let inputs: Vec<InputSpec> = serde_wasm_bindgen::from_value(inputs).unwrap();
let targets: TargetsMap<f64> = serde_wasm_bindgen::from_value(targets.clone()).unwrap();
let step = Step::new(inputs, targets.into());
serde_wasm_bindgen::to_value(&step).unwrap()
}

#[wasm_bindgen]
pub fn make_model(inputs: JsValue, targets: JsValue) -> JsValue {
let inputs: Vec<Input> = serde_wasm_bindgen::from_value(inputs).unwrap();
let inputs: Vec<InputSpec> = serde_wasm_bindgen::from_value(inputs).unwrap();
let targets: TargetsMap<f64> = serde_wasm_bindgen::from_value(targets.clone()).unwrap();
let model = Model::new(inputs, targets);
serde_wasm_bindgen::to_value(&model).unwrap()
Expand Down
Loading

0 comments on commit aa96f6b

Please sign in to comment.