diff --git a/src/activation/leaky_relu.rs b/src/activation/leaky_relu.rs index d6eb274..34915c8 100644 --- a/src/activation/leaky_relu.rs +++ b/src/activation/leaky_relu.rs @@ -1,25 +1,9 @@ /// Returns the result of the leaky rectified linear unit function. -/// -/// # Examples -/// -/// ``` -/// # use engram::*; -/// assert_eq!(leaky_relu(0.0), 0.0, "leaky_relu(0.0)"); -/// assert_eq!(leaky_relu(-10.0), -0.1, "leaky_relu(-10.0)"); -/// ``` pub fn leaky_relu(x: f64) -> f64 { f64::max(x, 0.01 * x) } /// Returns the derivative of the leaky rectified linear unit function. -/// -/// # Examples -/// -/// ``` -/// # use engram::*; -/// assert_eq!(d_leaky_relu(0.0), 0.01, "d_leaky_relu(0.0)"); -/// assert_eq!(d_leaky_relu(-10.0), 0.01, "d_leaky_relu(-10.0)"); -/// ``` pub fn d_leaky_relu(x: f64) -> f64 { if x > 0.0 { 1.0 @@ -27,3 +11,20 @@ pub fn d_leaky_relu(x: f64) -> f64 { 0.01 } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_leaky_relu() { + assert_eq!(leaky_relu(0.0), 0.0, "leaky_relu(0.0)"); + assert_eq!(leaky_relu(-10.0), -0.1, "leaky_relu(-10.0)"); + } + + #[test] + fn test_d_leaky_relu() { + assert_eq!(d_leaky_relu(0.0), 0.01, "d_leaky_relu(0.0)"); + assert_eq!(d_leaky_relu(-10.0), 0.01, "d_leaky_relu(-10.0)"); + } +} diff --git a/src/activation/relu.rs b/src/activation/relu.rs index eac215b..f42925a 100644 --- a/src/activation/relu.rs +++ b/src/activation/relu.rs @@ -1,25 +1,9 @@ /// Returns the result of the rectified linear unit function. -/// -/// # Examples -/// -/// ``` -/// # use engram::*; -/// assert_eq!(relu(0.0), 0.0, "relu(0.0)"); -/// assert_eq!(relu(10.0), 10.0, "relu(10.0)"); -/// ``` pub fn relu(x: f64) -> f64 { f64::max(x, 0.0) } /// Returns the derivative of the rectified linear unit function. -/// -/// # Examples -/// -/// ``` -/// # use engram::*; -/// assert_eq!(d_relu(0.0), 0.0, "d_relu(0.0)"); -/// assert_eq!(d_relu(-10.0), 0.0, "d_relu(-10.0)"); -/// ``` pub fn d_relu(x: f64) -> f64 { if x > 0.0 { 1.0 @@ -27,3 +11,20 @@ pub fn d_relu(x: f64) -> f64 { 0.0 } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_relu() { + assert_eq!(relu(0.0), 0.0, "relu(0.0)"); + assert_eq!(relu(10.0), 10.0, "relu(10.0)"); + } + + #[test] + fn test_d_relu() { + assert_eq!(d_relu(0.0), 0.0, "d_relu(0.0)"); + assert_eq!(d_relu(-10.0), 0.0, "d_relu(-10.0)"); + } +} diff --git a/src/activation/sigmoid.rs b/src/activation/sigmoid.rs index d067570..61e2479 100644 --- a/src/activation/sigmoid.rs +++ b/src/activation/sigmoid.rs @@ -1,26 +1,27 @@ /// Returns the result of the sigmoid function. -/// -/// # Examples -/// -/// ``` -/// # use engram::*; -/// assert_eq!(sigmoid(0.0), 0.5, "sigmoid(0.0)"); -/// assert_eq!(sigmoid(1.0), 0.7310585786300049, "sigmoid(1.0)"); -/// ``` pub fn sigmoid(x: f64) -> f64 { 1.0 / (1.0 + (-x).exp()) } /// Returns the derivative of the sigmoid function. -/// -/// # Examples -/// -/// ``` -/// # use engram::*; -/// assert_eq!(d_sigmoid(0.0), 0.25, "d_sigmoid(0.0)"); -/// assert_eq!(d_sigmoid(1.0), 0.19661193324148185, "d_sigmoid(1.0)"); -/// ``` pub fn d_sigmoid(x: f64) -> f64 { let y = sigmoid(x); y * (1.0 - y) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sigmoid() { + assert_eq!(sigmoid(0.0), 0.5, "sigmoid(0.0)"); + assert_eq!(sigmoid(1.0), 0.7310585786300049, "sigmoid(1.0)"); + } + + #[test] + fn test_d_sigmoid() { + assert_eq!(d_sigmoid(0.0), 0.25, "d_sigmoid(0.0)"); + assert_eq!(d_sigmoid(1.0), 0.19661193324148185, "d_sigmoid(1.0)"); + } +} diff --git a/src/activation/tanh.rs b/src/activation/tanh.rs index 4d398ec..df3a8b8 100644 --- a/src/activation/tanh.rs +++ b/src/activation/tanh.rs @@ -1,25 +1,26 @@ /// Returns the result of the hyperbolic tangent function. -/// -/// # Examples -/// -/// ``` -/// # use engram::*; -/// assert_eq!(tanh(0.0), 0.0, "tanh(0.0)"); -/// assert_eq!(tanh(1.0), 0.7615941559557649, "tanh(1.0)"); -/// ``` pub fn tanh(x: f64) -> f64 { x.tanh() } /// Returns the derivative of the hyperbolic tangent function. -/// -/// # Examples -/// -/// ``` -/// # use engram::*; -/// assert_eq!(d_tanh(0.0), 1.0, "d_tanh(0.0)"); -/// assert_eq!(d_tanh(1.0), 0.41997434161402614, "d_tanh(1.0)"); -/// ``` pub fn d_tanh(x: f64) -> f64 { 1.0 - tanh(x).powi(2) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tanh() { + assert_eq!(tanh(0.0), 0.0, "tanh(0.0)"); + assert_eq!(tanh(1.0), 0.7615941559557649, "tanh(1.0)"); + } + + #[test] + fn test_d_tanh() { + assert_eq!(d_tanh(0.0), 1.0, "d_tanh(0.0)"); + assert_eq!(d_tanh(1.0), 0.41997434161402614, "d_tanh(1.0)"); + } +}