Skip to content

Commit

Permalink
load and save models
Browse files Browse the repository at this point in the history
  • Loading branch information
maximilian-heeg committed Jul 4, 2024
1 parent b581940 commit 6771b6a
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 14 deletions.
1 change: 1 addition & 0 deletions 128.json

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ leptos_hotkeys = "0.2.1"
rand = "0.8.5"
rayon = "1.10.0"
serde = { version = "1.0.203", features = ["derive"] }
serde_json = "1.0.120"
strum = "0.26.3"
strum_macros = "0.26.4"
38 changes: 30 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use clap::{Parser, Subcommand};
use leptos::*;
use leptos_2048::*;
use nn::activation::ActivationFunction;
use nn::{activation::ActivationFunction, NeuralNetwork};
use population::Population;
use ui::RenderGame;

Expand All @@ -17,12 +17,20 @@ struct Arguments {
enum Commands {
Web {},

Cli {},
Train {
/// Save the model every 10 steps
#[arg(short, long)]
save: Option<String>,

/// Load the model from file
#[arg(short, long)]
load: Option<String>,
},
}

pub const BRAIN_MUTATION_RATE: f64 = 0.1;
pub const BRAIN_MUTATION_VARIATION: f64 = 0.1;
pub const AGENTS_KEEP_PROPORTION: f64 = 0.05;
pub const AGENTS_KEEP_PROPORTION: f64 = 0.02;

fn main() {
let args = Arguments::parse();
Expand All @@ -33,16 +41,27 @@ fn main() {
<RenderGame />
}
}),
Some(Commands::Cli {}) => {
let layers = &[16, 12, 8, 4];
let act_funs = &[ActivationFunction::ReLU; 3];

Some(Commands::Train { save, load }) => {
let rounds = 100;
let max_steps = 10000;
let evolution_steps = 10000;
let n_agents = 1000;

let mut population = Population::new(n_agents, layers, act_funs);
let mut population = match load {
Some(file) => {
let nn = NeuralNetwork::load(&file).expect("Failed to load NN");
Population::from_nn(n_agents, nn)
}
None => {
let layers = &[16, 128, 64, 4];
let act_funs = &[
ActivationFunction::ReLU,
ActivationFunction::ReLU,
ActivationFunction::None,
];
Population::new(n_agents, layers, act_funs)
}
};

for _ in 0..evolution_steps {
for _ in 0..rounds {
Expand All @@ -60,6 +79,9 @@ fn main() {
best.avg_score(),
best.get_highest_tile().expect("Error getting best tile")
);
if let Some(file) = save.clone() {
best.nn.save(&file).expect("Failed to save model");
}
}

population.evolve(
Expand Down
4 changes: 3 additions & 1 deletion src/nn/activation.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use serde::{Deserialize, Serialize};

/// Computes the sigmoid activation function.
pub fn sigmoid(x: f64) -> f64 {
1.0 / (1.0 + (-x).exp())
Expand All @@ -18,7 +20,7 @@ pub fn tanh(x: f64) -> f64 {
}

/// Enumeration of possible activation functions.
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Serialize, Deserialize)]
pub enum ActivationFunction {
Sigmoid,
ReLU,
Expand Down
3 changes: 2 additions & 1 deletion src/nn/layer.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use super::activation::ActivationFunction;
use super::node::Node;
use serde::{Deserialize, Serialize};

/// A layer in a neural network, consisting of multiple nodes.
#[derive(Clone)]
#[derive(Clone, Serialize, Deserialize)]
pub struct Layer {
pub nodes: Vec<Node>,
pub activation_function: ActivationFunction,
Expand Down
53 changes: 52 additions & 1 deletion src/nn/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
pub mod activation;
mod layer;
mod node;
use std::fs::File;
use std::io::{self};

use activation::ActivationFunction;
use layer::Layer;
use serde::{Deserialize, Serialize};

/// A neural network consisting of multiple layers.
#[derive(Clone)]
#[derive(Clone, Serialize, Deserialize)]
pub struct NeuralNetwork {
pub layers: Vec<Layer>,
}
Expand Down Expand Up @@ -69,12 +72,43 @@ impl NeuralNetwork {
layer.update(rate, variation);
}
}

/// Saves the neural network to a file in JSON format.
///
/// # Arguments
///
/// * `filename` - The name of the file to save the neural network to.
///
/// # Returns
///
/// A `Result` indicating success or failure.
pub fn save(&self, filename: &str) -> io::Result<()> {
let file = File::create(filename)?;
serde_json::to_writer(file, &self)?;
Ok(())
}

/// Loads a neural network from a file in JSON format.
///
/// # Arguments
///
/// * `filename` - The name of the file to load the neural network from.
///
/// # Returns
///
/// A `Result` containing the loaded `NeuralNetwork` or an error.
pub fn load(filename: &str) -> io::Result<Self> {
let file = File::open(filename)?;
let network = serde_json::from_reader(file)?;
Ok(network)
}
}

#[cfg(test)]
mod tests {
use super::activation::ActivationFunction;
use super::*;
use std::fs;

#[test]
fn test_new_neural_network() {
Expand Down Expand Up @@ -114,4 +148,21 @@ mod tests {
assert_eq!(outputs.len(), 1);
assert!((outputs[0] - expected_final_output).abs() < 1e-6);
}

#[test]
fn test_save_and_load() {
let layer_sizes = vec![2, 3, 1];
let activation_functions = vec![ActivationFunction::ReLU, ActivationFunction::Sigmoid];
let network = NeuralNetwork::new(&layer_sizes, &activation_functions);

let filename = "test_model.json";
network.save(filename).expect("Failed to save the network");

let loaded_network = NeuralNetwork::load(filename).expect("Failed to load the network");

assert_eq!(network.layers.len(), loaded_network.layers.len());

// Clean up
fs::remove_file(filename).expect("Failed to remove test file");
}
}
3 changes: 2 additions & 1 deletion src/nn/node.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use rand::Rng;
use serde::{Deserialize, Serialize};

/// A single node (neuron) in a neural network layer.
#[derive(Clone)]
#[derive(Clone, Serialize, Deserialize)]
pub struct Node {
pub weights: Vec<f64>,
pub bias: f64,
Expand Down
11 changes: 11 additions & 0 deletions src/population/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ impl Population {
}
}

pub fn from_nn(n_agents: usize, nn: NeuralNetwork) -> Self {
let mut agents = vec![];
for _ in 0..n_agents {
agents.push(Agent::new(nn.clone(), Game::new()));
}
Self {
agents,
evolution_step: 0,
}
}

pub fn play(&mut self, max_steps: usize) {
let mut v = vec![0usize; 5];
v.par_iter_mut().enumerate().for_each(|(i, x)| *x = i);
Expand Down

0 comments on commit 6771b6a

Please sign in to comment.