test(optimizer): add unit/doc tests
drewxs committed Sep 9, 2023
1 parent 0f69745 commit 41ac6b6
Showing 3 changed files with 66 additions and 13 deletions.
38 changes: 30 additions & 8 deletions src/optimizer/adagrad.rs
@@ -1,13 +1,12 @@
//! Adaptive Gradient (Adagrad):
//!
//! Adapts the learning rate based on the history of gradients.
//! Divides the learning rate by a running average of the magnitude of the gradients.
//! This causes the effective learning rate to decay quickly for parameters that have consistently
//! large gradients and to decay more slowly for parameters that have consistently small gradients.
//! Includes an option to apply weight decay regularization to the gradients.

use crate::{Optimize, Tensor};

/// Adaptive Gradient (Adagrad):
///
/// Adapts the learning rate based on the history of gradients.
/// Divides the learning rate by a running average of the magnitude of the gradients.
/// This causes the effective learning rate to decay quickly for parameters that have consistently
/// large gradients and to decay more slowly for parameters that have consistently small gradients.
/// Includes an option to apply weight decay regularization to the gradients.
#[derive(Clone, Debug)]
pub struct Adagrad {
    learning_rate: f64,
@@ -17,6 +16,7 @@ pub struct Adagrad {
}

impl Adagrad {
    /// Creates a new Adagrad optimizer with the specified parameters.
    pub fn new(
        learning_rate: f64,
        shape: (usize, usize),
@@ -49,3 +49,25 @@ impl Optimize for Adagrad {
        );
    }
}

#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn test_adagrad() {
        let mut adagrad = Adagrad::new(0.1, (2, 3), None, Some(1e-8));
        let mut weights = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let mut gradients = weights.gradient(&Activation::ReLU);

        adagrad.step(&mut weights, &mut gradients);

        assert_eq!(
            weights,
            tensor![
                [0.9000000005, 1.9000000005, 2.9000000005],
                [3.9000000005, 4.9000000005, 5.9000000005]
            ]
        );
    }
}
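
A note on the expected values in test_adagrad: they match the textbook Adagrad rule in which each
parameter accumulates its squared gradient and the step is divided by the square root of that
accumulator, with epsilon added inside the square root. Assuming weights.gradient(&Activation::ReLU)
returns the ReLU derivative (1.0 for every positive entry), the first step for each weight is
0.1 * 1.0 / sqrt(1.0 + 1e-8) ≈ 0.0999999995, which yields 0.9000000005, 1.9000000005, and so on.
A minimal standalone sketch of that rule over plain f64 slices follows; the function name and
signature are hypothetical and not part of this crate's Tensor API.

// Textbook Adagrad step over f64 slices; a sketch, not the crate's implementation.
fn adagrad_step(weights: &mut [f64], grads: &[f64], accum: &mut [f64], lr: f64, eps: f64) {
    for i in 0..weights.len() {
        // Accumulate the squared gradient for this parameter.
        accum[i] += grads[i] * grads[i];
        // Divide the step by sqrt(accumulator + eps), so the rate decays as gradients accumulate.
        weights[i] -= lr * grads[i] / (accum[i] + eps).sqrt();
    }
}

fn main() {
    let mut w = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let g = [1.0; 6]; // ReLU derivative of positive inputs
    let mut acc = [0.0; 6];
    adagrad_step(&mut w, &g, &mut acc, 0.1, 1e-8);
    println!("{:?}", w); // ≈ [0.9000000005, 1.9000000005, ..., 5.9000000005]
}
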
15 changes: 15 additions & 0 deletions src/optimizer/mod.rs
@@ -31,6 +31,21 @@ pub trait Optimize {
}

impl Optimize for Optimizer {
    /// Updates the weights based on the gradients of the loss function
    ///
    /// # Examples
    ///
    /// ```
    /// # use engram::*;
    ///
    /// let mut optimizer = Optimizer::SGD { learning_rate: 0.1 };
    /// let mut weights = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    /// let mut gradients = weights.gradient(&Activation::ReLU);
    ///
    /// optimizer.step(&mut weights, &mut gradients);
    ///
    /// # assert_eq!(weights, tensor![[0.9, 1.9, 2.9], [3.9, 4.9, 5.9]]);
    /// ```
    fn step(&mut self, weights: &mut Tensor, gradients: &mut Tensor) {
        match self {
            Optimizer::SGD { learning_rate } => {
26 changes: 21 additions & 5 deletions src/optimizer/sgd.rs
@@ -1,16 +1,16 @@
//! Stochastic Gradient Descent (SGD).
//!
//! A basic optimizer that updates the weights by subtracting the gradient of the loss function
//! with respect to the weights, scaled by a learning rate.

use crate::{Optimize, Tensor};

/// Stochastic Gradient Descent (SGD).
///
/// A basic optimizer that updates the weights by subtracting the gradient of the loss function
/// with respect to the weights, scaled by a learning rate.
#[derive(Clone, Debug)]
pub struct SGD {
    learning_rate: f64,
}

impl SGD {
    /// Creates a new SGD optimizer with the specified parameters.
    pub fn new(learning_rate: f64) -> SGD {
        SGD { learning_rate }
    }
@@ -21,3 +21,19 @@ impl Optimize for SGD {
        weights.sub_assign(&gradients.mul_scalar(self.learning_rate));
    }
}

#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn test_sgd() {
        let mut sgd = SGD::new(0.1);
        let mut weights = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let mut gradients = weights.gradient(&Activation::ReLU);

        sgd.step(&mut weights, &mut gradients);

        assert_eq!(weights, tensor![[0.9, 1.9, 2.9], [3.9, 4.9, 5.9]]);
    }
}
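
For comparison, the SGD rule exercised by test_sgd is simply weights -= learning_rate * gradients.
A minimal standalone sketch over plain f64 slices follows; the names are hypothetical and not part
of this crate's Tensor API.

// Plain SGD step: subtract the learning-rate-scaled gradient from each weight.
fn sgd_step(weights: &mut [f64], grads: &[f64], lr: f64) {
    for (w, g) in weights.iter_mut().zip(grads.iter()) {
        *w -= lr * *g;
    }
}

fn main() {
    let mut w = [1.0, 2.0, 3.0];
    let g = [1.0; 3]; // ReLU derivative of positive inputs
    sgd_step(&mut w, &g, 0.1);
    println!("{:?}", w); // ≈ [0.9, 1.9, 2.9], up to floating-point rounding
}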
