From 41ac6b6551ac09fb0a597605272fd0fedf367b0b Mon Sep 17 00:00:00 2001
From: "Andrew X. Shah"
Date: Sat, 9 Sep 2023 11:02:01 -0600
Subject: [PATCH] test(optimizer): add unit/doc tests

---
 src/optimizer/adagrad.rs | 38 ++++++++++++++++++++++++++++++--------
 src/optimizer/mod.rs     | 15 +++++++++++++++
 src/optimizer/sgd.rs     | 26 +++++++++++++++++++++-----
 3 files changed, 66 insertions(+), 13 deletions(-)

diff --git a/src/optimizer/adagrad.rs b/src/optimizer/adagrad.rs
index 5a4e7f1..d8fac5a 100644
--- a/src/optimizer/adagrad.rs
+++ b/src/optimizer/adagrad.rs
@@ -1,13 +1,12 @@
-//! Adaptive Gradient (Adagrad):
-//!
-//! Adapts the learning rate based on the history of gradients.
-//! Divides the learning rate by a running average of the magnitude of the gradients.
-//! This allows the learning rate to decrease for parameters that have consistently large gradients
-//! and increase for parameters that have consistently small gradients.
-//! Includes an option to apply weight decay regularization to the gradients.
-
 use crate::{Optimize, Tensor};
 
+/// Adaptive Gradient (Adagrad):
+///
+/// Adapts the learning rate based on the history of gradients.
+/// Divides the learning rate by a running average of the magnitude of the gradients.
+/// This allows the learning rate to decrease for parameters that have consistently large gradients
+/// and increase for parameters that have consistently small gradients.
+/// Includes an option to apply weight decay regularization to the gradients.
 #[derive(Clone, Debug)]
 pub struct Adagrad {
     learning_rate: f64,
@@ -17,6 +16,7 @@ pub struct Adagrad {
 }
 
 impl Adagrad {
+    /// Creates a new Adagrad optimizer with the specified parameters.
     pub fn new(
         learning_rate: f64,
         shape: (usize, usize),
@@ -49,3 +49,25 @@ impl Optimize for Adagrad {
         );
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::*;
+
+    #[test]
+    fn test_adagrad() {
+        let mut adagrad = Adagrad::new(0.1, (2, 3), None, Some(1e-8));
+        let mut weights = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
+        let mut gradients = weights.gradient(&Activation::ReLU);
+
+        adagrad.step(&mut weights, &mut gradients);
+
+        assert_eq!(
+            weights,
+            tensor![
+                [0.9000000005, 1.9000000005, 2.9000000005],
+                [3.9000000005, 4.9000000005, 5.9000000005]
+            ]
+        );
+    }
+}
diff --git a/src/optimizer/mod.rs b/src/optimizer/mod.rs
index 01697ab..986cd27 100644
--- a/src/optimizer/mod.rs
+++ b/src/optimizer/mod.rs
@@ -31,6 +31,21 @@ pub trait Optimize {
 }
 
 impl Optimize for Optimizer {
+    /// Updates the weights based on the gradients of the loss function
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use engram::*;
+    ///
+    /// let mut optimizer = Optimizer::SGD { learning_rate: 0.1 };
+    /// let mut weights = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
+    /// let mut gradients = weights.gradient(&Activation::ReLU);
+    ///
+    /// optimizer.step(&mut weights, &mut gradients);
+    ///
+    /// # assert_eq!(weights, tensor![[0.9, 1.9, 2.9], [3.9, 4.9, 5.9]]);
+    /// ```
     fn step(&mut self, weights: &mut Tensor, gradients: &mut Tensor) {
         match self {
             Optimizer::SGD { learning_rate } => {
diff --git a/src/optimizer/sgd.rs b/src/optimizer/sgd.rs
index f4c7454..9e17d8a 100644
--- a/src/optimizer/sgd.rs
+++ b/src/optimizer/sgd.rs
@@ -1,16 +1,16 @@
-//! Stochastic Gradient Descent (SGD).
-//!
-//! A basic optimizer that updates the weights based on the gradients of the loss function
-//! with respect to the weights multiplied by a learning rate.
-
 use crate::{Optimize, Tensor};
 
+/// Stochastic Gradient Descent (SGD).
+///
+/// A basic optimizer that updates the weights based on the gradients of the loss function
+/// with respect to the weights multiplied by a learning rate.
 #[derive(Clone, Debug)]
 pub struct SGD {
     learning_rate: f64,
 }
 
 impl SGD {
+    /// Creates a new SGD optimizer with the specified parameters.
     pub fn new(learning_rate: f64) -> SGD {
         SGD { learning_rate }
     }
@@ -21,3 +21,19 @@ impl Optimize for SGD {
         weights.sub_assign(&gradients.mul_scalar(self.learning_rate));
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::*;
+
+    #[test]
+    fn test_sgd() {
+        let mut sgd = SGD::new(0.1);
+        let mut weights = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
+        let mut gradients = weights.gradient(&Activation::ReLU);
+
+        sgd.step(&mut weights, &mut gradients);
+
+        assert_eq!(weights, tensor![[0.9, 1.9, 2.9], [3.9, 4.9, 5.9]]);
+    }
+}
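---

For reference, the per-element update that test_adagrad exercises boils down to w -= lr * g / (sqrt(G) + eps), where G is the accumulated sum of squared gradients and weight decay, if any, is folded into g. The sketch below is a minimal scalar version under that assumption; the ScalarAdagrad type and its fields are illustrative only and do not mirror the crate's Tensor-based API.

    // Illustrative sketch only: a scalar version of the Adagrad update rule
    // described in the adagrad.rs doc comment. Names and fields here are
    // hypothetical, not the crate's actual types.
    struct ScalarAdagrad {
        learning_rate: f64,
        weight_decay: Option<f64>,
        epsilon: f64,
        accumulated: f64, // running sum of squared gradients (G)
    }

    impl ScalarAdagrad {
        fn step(&mut self, weight: &mut f64, mut gradient: f64) {
            // Optional weight decay regularization applied to the gradient.
            if let Some(wd) = self.weight_decay {
                gradient += wd * *weight;
            }
            // Accumulate the squared gradient, then scale the step by 1 / (sqrt(G) + eps).
            self.accumulated += gradient * gradient;
            *weight -= self.learning_rate * gradient / (self.accumulated.sqrt() + self.epsilon);
        }
    }

    fn main() {
        let mut opt = ScalarAdagrad {
            learning_rate: 0.1,
            weight_decay: None,
            epsilon: 1e-8,
            accumulated: 0.0,
        };
        let mut w = 1.0;
        // With a gradient of 1.0, the first step is lr * 1 / (1 + eps) ≈ 0.1,
        // which is why the expected weights in test_adagrad sit close to w - 0.1.
        opt.step(&mut w, 1.0);
        println!("{w}"); // ≈ 0.9
    }

test_sgd follows the same pattern with the simpler rule w -= lr * g, which is also what the new doc test on Optimizer::step in mod.rs asserts.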