This repository implements a simple VAE trained on CPU on the MNIST dataset and provides the ability to visualize the latent space and the entire manifold, as well as to visualize how digits interpolate between one another.
The purpose of this project is to build a better understanding of VAEs by playing with the different parameters and visualizations.
- Create a new conda environment with python 3.8, then run the commands below

```
git clone https://github.com/explainingai-code/Pytorch-VAE.git
cd Pytorch-VAE
pip install -r requirements.txt
```
- For running a simple fully connected VAE with latent dimension 2, run

```
python run_simple_vae.py
```
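For orientation, the kind of model this script trains can be sketched as a minimal fully connected VAE with a 2D latent space. This is an illustrative sketch only; the actual layer sizes and names in `run_simple_vae.py` may differ.

```python
import torch
import torch.nn as nn

class SimpleVAE(nn.Module):
    """Minimal fully connected VAE with a 2D latent space (illustrative sketch)."""
    def __init__(self, in_dim=784, hidden=256, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.fc_mu = nn.Linear(hidden, latent_dim)
        self.fc_logvar = nn.Linear(hidden, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, in_dim), nn.Sigmoid(),
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        eps = torch.randn_like(mu)
        z = mu + torch.exp(0.5 * logvar) * eps
        return self.decoder(z), mu, logvar
```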
- For playing around with the VAE and running visualizations, replace the config argument of `tools/train_vae.py` and `tools/inference.py` with the desired one, or pass it to the following commands

```
python -m tools.train_vae
python -m tools.inference
```

Available configs:
- `config/vae_nokl.yaml` - VAE with only reconstruction loss
- `config/vae_kl.yaml` - VAE with reconstruction and KL loss
- `config/vae_kl_latent4.yaml` - VAE with reconstruction and KL loss with latent dimension as 4 (instead of 2)
- `config/vae_kl_latent4_enc_channel_dec_fc_condition.yaml` - Conditional VAE with reconstruction and KL loss with latent dimension as 4
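The difference between the reconstruction-only and KL-regularized configs comes down to whether the KL term is included in the loss. A minimal sketch of such a loss follows; the weight name and reduction are assumptions, not the repo's exact implementation.

```python
import torch
import torch.nn.functional as F

def vae_loss(recon, target, mu, logvar, kl_weight=1.0):
    """ELBO-style VAE loss: reconstruction term plus an optionally weighted
    KL term. kl_weight=0 roughly corresponds to a reconstruction-only config
    like vae_nokl.yaml (assumption: the repo's config keys may differ)."""
    recon_loss = F.mse_loss(recon, target, reduction="sum")
    # KL divergence between N(mu, sigma^2) and the standard normal prior
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * kl
```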
We don't use the torchvision MNIST dataset, so that it can easily be swapped for any other image dataset.
For setting up the dataset:
- Create `data/train/images` and `data/test/images` folders
- Download the csv files for MNIST (https://www.kaggle.com/datasets/oddrationale/mnist-in-csv) and save them under the `data` directory
- Run `python utils/extract_mnist_images.py`

Verify the data directory has the following structure:

```
Pytorch-VAE/data/train/images/{0/1/.../9}/*.png
Pytorch-VAE/data/test/images/{0/1/.../9}/*.png
```
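For reference, each row of the Kaggle CSV holds a label followed by 784 pixel values (one 28x28 image). A rough sketch of the parsing step that `utils/extract_mnist_images.py` presumably performs before writing PNGs into the per-digit folders (the function name and exact logic here are assumptions):

```python
import csv

def parse_mnist_csv(lines):
    """Parse a Kaggle-style MNIST CSV (label, pixel0..pixel783 per row)
    into (label, 28x28 nested list) pairs. Hypothetical helper sketching
    the extraction step; the actual script also writes the PNG files."""
    reader = csv.reader(lines)
    next(reader)  # skip the "label,pixel0,..." header row
    for row in reader:
        label = int(row[0])
        pixels = [int(v) for v in row[1:]]
        # Reshape the flat 784-value row into 28 rows of 28 pixels
        image = [pixels[r * 28:(r + 1) * 28] for r in range(28)]
        yield label, image
```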
Outputs will be saved according to the configuration present in the yaml files.

For every run, a folder named after the `task_name` key in the config will be created, and `output_train_dir` will be created inside it.

During training the following outputs will be saved:
- Best model checkpoints in the `task_name` directory
- PCA information in a pickle file in the `task_name` directory
- 2D latent space plots of the test set images for each epoch in the `task_name/output_train_dir` directory
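When the latent dimension is larger than 2 (e.g. the latent-4 configs), the latent codes have to be projected down to 2D before plotting. A minimal PCA sketch follows; whether the repo uses this exact SVD-based variant or a library routine is an assumption.

```python
import numpy as np

def pca_2d(latents):
    """Project latent vectors of shape (N, D) onto their top-2 principal
    components for 2D visualization (sketch of the pickled PCA step)."""
    centered = latents - latents.mean(axis=0)
    # SVD of the centered data; rows of vt are the principal directions
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T
```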
During inference the following outputs will be saved:
- Reconstructions for a sample of the test set in `task_name/output_train_dir/reconstruction.png`
- Decoder outputs for points evenly spaced across the 2D projection of the latent space in `task_name/output_train_dir/manifold.png`
- Interpolations between two randomly sampled points in the `task_name/output_train_dir/interp` directory
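The interpolation frames can be produced by linearly blending two latent points and decoding each intermediate code. A sketch of the blending step (decoding and file naming are left out and would depend on the repo's actual code):

```python
import numpy as np

def interpolate_latents(z_start, z_end, steps=10):
    """Return `steps` latent points linearly interpolated from z_start to
    z_end; decoding each one yields the interp frames (sketch only)."""
    alphas = np.linspace(0.0, 1.0, steps)
    return [(1 - a) * z_start + a * z_end for a in alphas]
```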
Latent Visualization
Manifold
Reconstruction Images (reconstruction in black font and original in white font)
Because we pass the label to the decoder, the model learns the capability to generate ALL digits from every point in the latent space.
Instead, the model learns to distinguish points in the latent space by whether they should produce a left- or right-tilted digit, or how thick the digit's stroke should be. Below, one can visualize those patterns when we attempt to generate all digits from all points.
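The conditioning idea can be sketched as concatenating a one-hot label vector with the latent code before decoding. This is a toy fully connected illustration; the repo's `enc_channel_dec_fc_condition` architecture will differ in its actual layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConditionalDecoder(nn.Module):
    """Toy fc decoder conditioned on the digit label via one-hot
    concatenation (illustrative sketch, not the repo's architecture)."""
    def __init__(self, latent_dim=4, num_classes=10, out_dim=784):
        super().__init__()
        self.num_classes = num_classes
        self.net = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, out_dim),
            nn.Sigmoid(),
        )

    def forward(self, z, labels):
        # One-hot encode the digit label and concatenate it with the latent code
        one_hot = F.one_hot(labels, num_classes=self.num_classes).float()
        return self.net(torch.cat([z, one_hot], dim=1))
```

Because the label carries the digit identity, the latent code is free to encode style (tilt, stroke thickness) instead of class.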
Reconstruction Images (reconstruction in black font and original in white font)