Skip to content

Commit

Permalink
optimize pow with 1, 2, ¯1
Browse files Browse the repository at this point in the history
  • Loading branch information
kaikalii committed Nov 26, 2024
1 parent a4e9fd4 commit ed3e30d
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 62 deletions.
2 changes: 1 addition & 1 deletion src/algorithm/loops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ pub fn split_by(f: SigNode, scalar: bool, env: &mut Uiua) -> UiuaResult {
|| matches!(haystack, Value::Complex(_))
{
let mask = if scalar {
delim.is_ne(haystack.clone(), 0, 0, env)?
delim.is_ne(haystack.clone(), env)?
} else {
delim.mask(&haystack, env)?.not(env)?
};
Expand Down
19 changes: 2 additions & 17 deletions src/algorithm/pervade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,24 +122,14 @@ fn derive_new_shape(
Ok(new_shape)
}

pub fn bin_pervade<A, B, C, F>(
a: Array<A>,
b: Array<B>,
a_depth: usize,
b_depth: usize,
env: &Uiua,
f: F,
) -> UiuaResult<Array<C>>
pub fn bin_pervade<A, B, C, F>(a: Array<A>, b: Array<B>, env: &Uiua, f: F) -> UiuaResult<Array<C>>
where
A: ArrayValue,
B: ArrayValue,
C: ArrayValue,
F: PervasiveFn<A, B, Output = C> + Clone,
F::Error: Into<UiuaError>,
{
let _a_depth = a_depth.min(a.rank());
let _b_depth = b_depth.min(b.rank());

let a_fill = env.scalar_fill::<A>();
let b_fill = env.scalar_fill::<B>();
let new_shape = derive_new_shape(
Expand Down Expand Up @@ -282,17 +272,12 @@ where
pub fn bin_pervade_mut<T>(
mut a: Array<T>,
b: &mut Array<T>,
a_depth: usize,
b_depth: usize,
env: &Uiua,
f: impl Fn(T, T) -> T + Copy,
) -> UiuaResult
where
T: ArrayValue + Copy,
{
let _a_depth = a_depth.min(a.rank());
let _b_depth = b_depth.min(b.rank());

fn derive_new_shape(
ash: &Shape,
bsh: &Shape,
Expand Down Expand Up @@ -1425,7 +1410,7 @@ bin_op_mod!(
a.atan2(b),
"Cannot get the atan2 of {a} and {b}"
);
pub mod pow {
pub mod scalar_pow {
use super::*;
pub fn num_num(a: f64, b: f64) -> f64 {
b.powf(a)
Expand Down
4 changes: 2 additions & 2 deletions src/algorithm/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,9 @@ impl<T: Clone> Array<T> {

fn spanned_dy_fn(
span: usize,
f: impl Fn(Value, Value, usize, usize, &Uiua) -> UiuaResult<Value> + 'static,
f: impl Fn(Value, Value, &Uiua) -> UiuaResult<Value> + 'static,
) -> ValueDyFn {
Box::new(move |a, b, ad, bd, env| env.with_span(span, |env| f(a, b, ad, bd, env)))
Box::new(move |a, b, _, _, env| env.with_span(span, |env| f(a, b, env)))
}

fn prim_dy_fast_fn(prim: Primitive, span: usize) -> Option<ValueDyFn> {
Expand Down
7 changes: 7 additions & 0 deletions src/compile/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ static OPTIMIZATIONS: &[&dyn Optimization] = &[
&((Pop, Rand), ReplaceRand),
&((Pop, Pop, Rand), ReplaceRand2),
&((1, Flip, Div, Pow), Root),
&((-1, Pow), (1, Flip, Div)),
&((2, Pow), (Dup, Mul)),
&InlineCustomInverse,
&TransposeOpt,
&ReduceTableOpt,
Expand Down Expand Up @@ -440,6 +442,11 @@ impl OptReplace for ImplPrimitive {
Node::ImplPrim(*self, span)
}
}
impl OptReplace for i32 {
fn replacement_node(&self, _: usize) -> Node {
Node::new_push(*self)
}
}

fn replace_nodes(nodes: &mut EcoVec<Node>, i: usize, n: usize, new: Node) {
// dbg!(&nodes, i, n, &new);
Expand Down
2 changes: 1 addition & 1 deletion src/primitive/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ primitive!(
/// : ∨ [0 1 0 1] [0 0 1 1]
/// ex: # Experimental!
/// : ⊞∨.[0 1]
/// Non boolean values give the GCD
/// Non-boolean values give the GCD.
/// ex: # Experimental!
/// : ∨ 16 24
/// ex: # Experimental!
Expand Down
42 changes: 21 additions & 21 deletions src/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -774,24 +774,24 @@ impl Primitive {
Primitive::Floor => env.monadic_env(Value::floor)?,
Primitive::Ceil => env.monadic_env(Value::ceil)?,
Primitive::Round => env.monadic_env(Value::round)?,
Primitive::Eq => env.dyadic_oo_00_env(Value::is_eq)?,
Primitive::Ne => env.dyadic_oo_00_env(Value::is_ne)?,
Primitive::Lt => env.dyadic_oo_00_env(Value::other_is_lt)?,
Primitive::Le => env.dyadic_oo_00_env(Value::other_is_le)?,
Primitive::Gt => env.dyadic_oo_00_env(Value::other_is_gt)?,
Primitive::Ge => env.dyadic_oo_00_env(Value::other_is_ge)?,
Primitive::Add => env.dyadic_oo_00_env(Value::add)?,
Primitive::Sub => env.dyadic_oo_00_env(Value::sub)?,
Primitive::Mul => env.dyadic_oo_00_env(Value::mul)?,
Primitive::Div => env.dyadic_oo_00_env(Value::div)?,
Primitive::Modulus => env.dyadic_oo_00_env(Value::modulus)?,
Primitive::Or => env.dyadic_oo_00_env(Value::or)?,
Primitive::Pow => env.dyadic_oo_00_env(Value::pow)?,
Primitive::Log => env.dyadic_oo_00_env(Value::log)?,
Primitive::Min => env.dyadic_oo_00_env(Value::min)?,
Primitive::Max => env.dyadic_oo_00_env(Value::max)?,
Primitive::Atan => env.dyadic_oo_00_env(Value::atan2)?,
Primitive::Complex => env.dyadic_oo_00_env(Value::complex)?,
Primitive::Eq => env.dyadic_oo_env(Value::is_eq)?,
Primitive::Ne => env.dyadic_oo_env(Value::is_ne)?,
Primitive::Lt => env.dyadic_oo_env(Value::other_is_lt)?,
Primitive::Le => env.dyadic_oo_env(Value::other_is_le)?,
Primitive::Gt => env.dyadic_oo_env(Value::other_is_gt)?,
Primitive::Ge => env.dyadic_oo_env(Value::other_is_ge)?,
Primitive::Add => env.dyadic_oo_env(Value::add)?,
Primitive::Sub => env.dyadic_oo_env(Value::sub)?,
Primitive::Mul => env.dyadic_oo_env(Value::mul)?,
Primitive::Div => env.dyadic_oo_env(Value::div)?,
Primitive::Modulus => env.dyadic_oo_env(Value::modulus)?,
Primitive::Or => env.dyadic_oo_env(Value::or)?,
Primitive::Pow => env.dyadic_oo_env(Value::pow)?,
Primitive::Log => env.dyadic_oo_env(Value::log)?,
Primitive::Min => env.dyadic_oo_env(Value::min)?,
Primitive::Max => env.dyadic_oo_env(Value::max)?,
Primitive::Atan => env.dyadic_oo_env(Value::atan2)?,
Primitive::Complex => env.dyadic_oo_env(Value::complex)?,
Primitive::Match => env.dyadic_rr(|a, b| a == b)?,
Primitive::Join => env.dyadic_oo_env(|a, b, env| a.join(b, true, env))?,
Primitive::Transpose => env.monadic_mut(Value::transpose)?,
Expand Down Expand Up @@ -1191,7 +1191,7 @@ impl ImplPrimitive {
ImplPrimitive::DeshapeSub(i) => {
env.monadic_mut_env(|val, env| val.deshape_sub(*i, true, env))?
}
ImplPrimitive::Root => env.dyadic_oo_00_env(Value::root)?,
ImplPrimitive::Root => env.dyadic_oo_env(Value::root)?,
ImplPrimitive::Cos => env.monadic_env(Value::cos)?,
ImplPrimitive::Asin => env.monadic_env(Value::asin)?,
ImplPrimitive::Acos => env.monadic_env(Value::acos)?,
Expand Down Expand Up @@ -1579,7 +1579,7 @@ impl ImplPrimitive {
ImplPrimitive::MatchLe => {
let max = env.pop(1)?;
let val = env.pop(2)?;
let le = max.clone().other_is_le(val.clone(), 0, 0, env)?;
let le = max.clone().other_is_le(val.clone(), env)?;
if le.all_true() {
env.push(val);
return Ok(());
Expand All @@ -1594,7 +1594,7 @@ impl ImplPrimitive {
ImplPrimitive::MatchGe => {
let min = env.pop(1)?;
let val = env.pop(2)?;
let ge = min.clone().other_is_ge(val.clone(), 0, 0, env)?;
let ge = min.clone().other_is_ge(val.clone(), env)?;
if ge.all_true() {
env.push(val);
return Ok(());
Expand Down
9 changes: 0 additions & 9 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1219,15 +1219,6 @@ impl Uiua {
self.push(f(a, b, self)?);
Ok(())
}
pub(crate) fn dyadic_oo_00_env<V: Into<Value>>(
&mut self,
f: fn(Value, Value, usize, usize, &Self) -> UiuaResult<V>,
) -> UiuaResult {
let a = self.pop(1)?;
let b = self.pop(2)?;
self.push(f(a, b, 0, 0, self)?);
Ok(())
}
pub(crate) fn dyadic_rr_env<V: Into<Value>>(
&mut self,
f: fn(&Value, &Value, &Self) -> UiuaResult<V>,
Expand Down
34 changes: 23 additions & 11 deletions src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,18 @@ impl Value {
value => value.scalar_neg(env),
}
}
/// Raise a value to a power
pub fn pow(self, base: Self, env: &Uiua) -> UiuaResult<Self> {
if let Ok(pow) = self.as_int(env, "") {
match pow {
1 => return Ok(base),
2 => return base.clone().mul(base, env),
-1 => return base.div(Value::from(1), env),
_ => {}
}
}
self.scalar_pow(base, env)
}
}

fn optimize_types(a: Value, b: Value) -> (Value, Value) {
Expand All @@ -1604,7 +1616,7 @@ macro_rules! value_bin_impl {
),* ) => {
impl Value {
#[allow(unreachable_patterns, unused_mut, clippy::wrong_self_convention)]
pub(crate) fn $name(self, other: Self, a_depth: usize, b_depth: usize, env: &Uiua) -> UiuaResult<Self> {
pub(crate) fn $name(self, other: Self, env: &Uiua) -> UiuaResult<Self> {
let (mut a, mut b) = optimize_types(self, other);
a.match_fill(env);
b.match_fill(env);
Expand All @@ -1613,43 +1625,43 @@ macro_rules! value_bin_impl {
let f = |$meta: &ArrayMeta| $pred;
f(a.meta()) && f(b.meta())
})* => {
bin_pervade_mut(a, &mut b, a_depth, b_depth, env, $name::$f2)?;
bin_pervade_mut(a, &mut b, env, $name::$f2)?;
let mut val: Value = b.into();
$(if $reset_meta {
val.reset_meta_flags();
})*
val
},)*)*
$($((Value::$na(a), Value::$nb(b)) => {
let mut val: Value = bin_pervade(a, b, a_depth, b_depth, env, InfalliblePervasiveFn::new($name::$f1))?.into();
let mut val: Value = bin_pervade(a, b, env, InfalliblePervasiveFn::new($name::$f1))?.into();
val.reset_meta_flags();
val
},)*)*
(Value::Box(a), Value::Box(b)) => {
let (a, b) = match (a.into_unboxed(), b.into_unboxed()) {
(Ok(a), Ok(b)) => return Ok(Boxed(Value::$name(a, b, a_depth, b_depth, env)?).into()),
(Ok(a), Ok(b)) => return Ok(Boxed(Value::$name(a, b, env)?).into()),
(Ok(a), Err(b)) => (a.coerce_as_boxes().into_owned(), b),
(Err(a), Ok(b)) => (a, b.coerce_as_boxes().into_owned()),
(Err(a), Err(b)) => (a, b),
};
let mut val: Value = bin_pervade(a, b, a_depth, b_depth, env, FalliblePerasiveFn::new(|a: Boxed, b: Boxed, env: &Uiua| {
Ok(Boxed(Value::$name(a.0, b.0, a_depth, b_depth, env)?))
let mut val: Value = bin_pervade(a, b, env, FalliblePerasiveFn::new(|a: Boxed, b: Boxed, env: &Uiua| {
Ok(Boxed(Value::$name(a.0, b.0, env)?))
}))?.into();
val.reset_meta_flags();
val
}
(Value::Box(a), b) => {
let b = b.coerce_as_boxes().into_owned();
let mut val: Value = bin_pervade(a, b, a_depth, b_depth, env, FalliblePerasiveFn::new(|a: Boxed, b: Boxed, env: &Uiua| {
Ok(Boxed(Value::$name(a.0, b.0, a_depth, b_depth, env)?))
let mut val: Value = bin_pervade(a, b, env, FalliblePerasiveFn::new(|a: Boxed, b: Boxed, env: &Uiua| {
Ok(Boxed(Value::$name(a.0, b.0, env)?))
}))?.into();
val.reset_meta_flags();
val
},
(a, Value::Box(b)) => {
let a = a.coerce_as_boxes().into_owned();
let mut val: Value = bin_pervade(a, b, a_depth, b_depth, env, FalliblePerasiveFn::new(|a: Boxed, b: Boxed, env: &Uiua| {
Ok(Boxed(Value::$name(a.0, b.0, a_depth, b_depth, env)?))
let mut val: Value = bin_pervade(a, b, env, FalliblePerasiveFn::new(|a: Boxed, b: Boxed, env: &Uiua| {
Ok(Boxed(Value::$name(a.0, b.0, env)?))
}))?.into();
val.reset_meta_flags();
val
Expand Down Expand Up @@ -1704,7 +1716,7 @@ value_bin_math_impl!(
value_bin_math_impl!(div, (Num, Char, num_char), (Byte, Char, byte_char),);
value_bin_math_impl!(modulus, (Complex, Complex, com_com));
value_bin_math_impl!(or, [|meta| meta.flags.is_boolean(), Byte, bool_bool]);
value_bin_math_impl!(pow);
value_bin_math_impl!(scalar_pow);
value_bin_math_impl!(root);
value_bin_math_impl!(log);
value_bin_math_impl!(atan2);
Expand Down

0 comments on commit ed3e30d

Please sign in to comment.