Skip to content

Commit

Permalink
test(activation): mv doc tests to unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
drewxs committed Sep 9, 2023
1 parent 41ac6b6 commit b648b47
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 64 deletions.
33 changes: 17 additions & 16 deletions src/activation/leaky_relu.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
/// Returns the result of the leaky rectified linear unit function.
///
/// # Examples
///
/// ```
/// # use engram::*;
/// assert_eq!(leaky_relu(0.0), 0.0, "leaky_relu(0.0)");
/// assert_eq!(leaky_relu(-10.0), -0.1, "leaky_relu(-10.0)");
/// ```
pub fn leaky_relu(x: f64) -> f64 {
f64::max(x, 0.01 * x)
}

/// Returns the derivative of the leaky rectified linear unit function.
///
/// # Examples
///
/// ```
/// # use engram::*;
/// assert_eq!(d_leaky_relu(0.0), 0.01, "d_leaky_relu(0.0)");
/// assert_eq!(d_leaky_relu(-10.0), 0.01, "d_leaky_relu(-10.0)");
/// ```
pub fn d_leaky_relu(x: f64) -> f64 {
if x > 0.0 {
1.0
} else {
0.01
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_leaky_relu() {
assert_eq!(leaky_relu(0.0), 0.0, "leaky_relu(0.0)");
assert_eq!(leaky_relu(-10.0), -0.1, "leaky_relu(-10.0)");
}

#[test]
fn test_d_leaky_relu() {
assert_eq!(d_leaky_relu(0.0), 0.01, "d_leaky_relu(0.0)");
assert_eq!(d_leaky_relu(-10.0), 0.01, "d_leaky_relu(-10.0)");
}
}
33 changes: 17 additions & 16 deletions src/activation/relu.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
/// Returns the result of the rectified linear unit function.
///
/// # Examples
///
/// ```
/// # use engram::*;
/// assert_eq!(relu(0.0), 0.0, "relu(0.0)");
/// assert_eq!(relu(10.0), 10.0, "relu(10.0)");
/// ```
pub fn relu(x: f64) -> f64 {
f64::max(x, 0.0)
}

/// Returns the derivative of the rectified linear unit function.
///
/// # Examples
///
/// ```
/// # use engram::*;
/// assert_eq!(d_relu(0.0), 0.0, "d_relu(0.0)");
/// assert_eq!(d_relu(-10.0), 0.0, "d_relu(-10.0)");
/// ```
pub fn d_relu(x: f64) -> f64 {
if x > 0.0 {
1.0
} else {
0.0
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_relu() {
assert_eq!(relu(0.0), 0.0, "relu(0.0)");
assert_eq!(relu(10.0), 10.0, "relu(10.0)");
}

#[test]
fn test_d_relu() {
assert_eq!(d_relu(0.0), 0.0, "d_relu(0.0)");
assert_eq!(d_relu(-10.0), 0.0, "d_relu(-10.0)");
}
}
33 changes: 17 additions & 16 deletions src/activation/sigmoid.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
/// Returns the result of the sigmoid function.
///
/// # Examples
///
/// ```
/// # use engram::*;
/// assert_eq!(sigmoid(0.0), 0.5, "sigmoid(0.0)");
/// assert_eq!(sigmoid(1.0), 0.7310585786300049, "sigmoid(1.0)");
/// ```
pub fn sigmoid(x: f64) -> f64 {
1.0 / (1.0 + (-x).exp())
}

/// Returns the derivative of the sigmoid function.
///
/// # Examples
///
/// ```
/// # use engram::*;
/// assert_eq!(d_sigmoid(0.0), 0.25, "d_sigmoid(0.0)");
/// assert_eq!(d_sigmoid(1.0), 0.19661193324148185, "d_sigmoid(1.0)");
/// ```
pub fn d_sigmoid(x: f64) -> f64 {
let y = sigmoid(x);
y * (1.0 - y)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_sigmoid() {
assert_eq!(sigmoid(0.0), 0.5, "sigmoid(0.0)");
assert_eq!(sigmoid(1.0), 0.7310585786300049, "sigmoid(1.0)");
}

#[test]
fn test_d_sigmoid() {
assert_eq!(d_sigmoid(0.0), 0.25, "d_sigmoid(0.0)");
assert_eq!(d_sigmoid(1.0), 0.19661193324148185, "d_sigmoid(1.0)");
}
}
33 changes: 17 additions & 16 deletions src/activation/tanh.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
/// Returns the result of the hyperbolic tangent function.
///
/// # Examples
///
/// ```
/// # use engram::*;
/// assert_eq!(tanh(0.0), 0.0, "tanh(0.0)");
/// assert_eq!(tanh(1.0), 0.7615941559557649, "tanh(1.0)");
/// ```
pub fn tanh(x: f64) -> f64 {
x.tanh()
}

/// Returns the derivative of the hyperbolic tangent function.
///
/// # Examples
///
/// ```
/// # use engram::*;
/// assert_eq!(d_tanh(0.0), 1.0, "d_tanh(0.0)");
/// assert_eq!(d_tanh(1.0), 0.41997434161402614, "d_tanh(1.0)");
/// ```
pub fn d_tanh(x: f64) -> f64 {
1.0 - tanh(x).powi(2)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_tanh() {
assert_eq!(tanh(0.0), 0.0, "tanh(0.0)");
assert_eq!(tanh(1.0), 0.7615941559557649, "tanh(1.0)");
}

#[test]
fn test_d_tanh() {
assert_eq!(d_tanh(0.0), 1.0, "d_tanh(0.0)");
assert_eq!(d_tanh(1.0), 0.41997434161402614, "d_tanh(1.0)");
}
}

0 comments on commit b648b47

Please sign in to comment.