
This repo implements and trains DallE-1 on a synthetically generated dataset of colored MNIST digits on texture or solid backgrounds, paired with auto-generated captions.


explainingai-code/Dalle-Pytorch


DallE Implementation in PyTorch with Generation using minGPT

This repository implements DallE-1 (Zero-Shot Text-to-Image Generation) on a synthetic dataset of colored MNIST digits on texture or solid backgrounds.

DallE Tutorial Video


Sample from dataset
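Each sample pairs an image with an auto-generated caption. The helper below is a hypothetical sketch that builds captions in the same form as the generation prompts listed under "Sample Output for DallE" (for example "Generate 0 in blue and solid background of olive"). make_caption is not a function from this repo, and the template used by the actual dataset code may differ.

```python
# Hypothetical caption builder; it only mirrors the prompt wording used in this README.
# The repo's dataset code may use a different caption template.


def make_caption(digit: int, color: str, background_type: str, background_name: str) -> str:
    """Build a caption such as 'Generate 0 in blue and solid background of olive'."""
    if not 0 <= digit <= 9:
        raise ValueError("MNIST digits range from 0 to 9")
    if background_type not in ("solid", "texture"):
        raise ValueError("background_type must be 'solid' or 'texture'")
    return f"Generate {digit} in {color} and {background_type} background of {background_name}"
```

For example, make_caption(8, "red", "texture", "lego") returns "Generate 8 in red and texture background of lego".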

Large parts of the implementation are taken from the following two repositories:

  1. GPT from - https://github.com/karpathy/minGPT/blob/master/mingpt/model.py

  2. Parts of the DallE implementation from https://github.com/lucidrains/DALLE-pytorch/tree/main/dalle_pytorch

    I have kept only a minimal version of DallE that gives decent results (on this dataset) and is easy to play around with. If you are looking for a more efficient and complete implementation, please use the lucidrains repository above.

Data preparation

To set up the MNIST dataset, follow https://github.com/explainingai-code/Pytorch-VAE#data-preparation

Download the quarter-resolution RGB texture data from the ALOT homepage.

If you face issues with the download, use curl:

curl -O https://aloi.science.uva.nl/public_alot/tars/alot_png4.tar
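If you prefer to unpack the archive from Python, the sketch below uses the standard-library tarfile module. extract_textures is a hypothetical helper, and the folder layout inside alot_png4.tar is not described in this README, so after extracting, check that each texture ends up in its own DallE/data/textures/{texture_number} folder and move folders if needed.

```python
import tarfile
from pathlib import Path
from typing import List


def extract_textures(tar_path: str, dest_dir: str = "DallE/data/textures") -> List[str]:
    """Extract a texture archive into dest_dir and return the extracted PNG paths.

    Hypothetical helper: the archive's internal layout is an assumption to verify.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_path) as tar:
        # Use the safer "data" extraction filter when this Python version provides it.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)
    # Return paths relative to dest_dir so the layout can be inspected easily.
    return sorted(str(p.relative_to(dest)) for p in dest.rglob("*.png"))
```

Usage: files = extract_textures("alot_png4.tar"), then inspect a few entries of files to see how the archive is organized.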

If you want to train at a higher resolution, you can download that version as well, but you will have to create new train.json and test.json files. The rest of the code should work fine as long as the json files are valid.

Download train.json and test.json from Drive. After downloading the textures, verify that the data directory has the following structure:

DallE/data/textures/{texture_number}
	*.png
DallE/data/train/images/{0/1/.../9}
	*.png
DallE/data/test/images/{0/1/.../9}
	*.png
DallE/data/train.json
DallE/data/test.json
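A quick way to check this layout is a small script like the one below. missing_paths is a hypothetical helper, not part of the repo; it only checks for the files and folders listed above and does not validate the contents of the json files.

```python
from pathlib import Path
from typing import List


def missing_paths(root: str = "DallE/data") -> List[str]:
    """Return a description of every expected file or folder that is missing under root."""
    base = Path(root)
    problems = []

    # Caption/annotation files.
    for name in ("train.json", "test.json"):
        if not (base / name).is_file():
            problems.append(str(base / name))

    # At least one texture folder containing PNG files.
    textures = base / "textures"
    if not (textures.is_dir() and any(textures.glob("*/*.png"))):
        problems.append(str(textures / "{texture_number}" / "*.png"))

    # One folder of PNG images per digit, for both splits.
    for split in ("train", "test"):
        for digit in range(10):
            folder = base / split / "images" / str(digit)
            if not (folder.is_dir() and any(folder.glob("*.png"))):
                problems.append(str(folder / "*.png"))
    return problems
```

An empty list from missing_paths() means every expected path is present.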

Quickstart

  • Create a new conda environment with Python 3.8, then run the commands below:
  • git clone https://github.com/explainingai-code/DallE.git
  • cd DallE
  • pip install -r requirements.txt
  • To train or run inference with the discrete VAE and GPT, use the commands below. If you want to experiment, pass the desired configuration file as the config argument.
  • python -m tools.train_dvae for training the discrete VAE
  • python -m tools.infer_dvae for generating reconstructions
  • python -m tools.train_dalle for training the minimal version of DallE
  • python -m tools.generate_image for generating images with the trained DallE
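Training happens in two stages: the discrete VAE turns each image into a grid of codebook indices, and the GPT then learns to predict those image tokens after the caption tokens. The sketch below shows one common way to build such a sequence, following the general DallE-1 recipe: pad the text tokens to a fixed length and shift image token ids past the text vocabulary so both share one embedding table. build_sequence, next_token_pairs, the padding id and all sizes are hypothetical; the repo's own tokenization and padding may differ.

```python
from typing import List, Tuple


def build_sequence(text_ids: List[int], image_ids: List[int], text_vocab_size: int,
                   max_text_len: int, pad_id: int = 0) -> List[int]:
    """Concatenate padded text tokens with image tokens shifted by text_vocab_size.

    Assumes pad_id is reserved for padding in the text vocabulary.
    """
    text_ids = text_ids[:max_text_len]  # truncate overly long captions
    padded_text = text_ids + [pad_id] * (max_text_len - len(text_ids))
    # Offset image ids so they never collide with text ids.
    shifted_image = [text_vocab_size + i for i in image_ids]
    return padded_text + shifted_image


def next_token_pairs(seq: List[int]) -> Tuple[List[int], List[int]]:
    """Standard autoregressive setup: predict token t+1 from tokens up to t."""
    return seq[:-1], seq[1:]
```

For example, with a caption [5, 7, 2], a 2x2 image grid [3, 0, 1, 3], a text vocabulary of 10 and max_text_len=4, build_sequence returns [5, 7, 2, 0, 13, 10, 11, 13]. At generation time the caption part is given as a prefix, the GPT samples the image tokens one at a time, and the discrete VAE decoder maps them back to pixels.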

Configuration

  • config/default.yaml - Lets you modify the different components of the discrete VAE as well as DallE and experiment with these modifications

Output

Outputs are saved according to the configuration in the yaml files.

For every run, a folder named after the task_name key in the config is created, and output_train_dir is created inside it.

During training of the Discrete VAE and DallE, the following output is saved:

  • Best model checkpoints (DVAE and DallE) in the task_name directory

During inference, the following output is saved:

  • Reconstructions for a sample of the test set in task_name/dvae_reconstruction.png
  • GPT generation output in task_name/generation_results.png
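The output layout described above can be summarized as a few path joins. prepare_run_dirs is a hypothetical helper; task_name and output_train_dir are the config keys named in this README, but where they sit inside default.yaml and the example values below are assumptions.

```python
import os
from typing import Dict


def prepare_run_dirs(task_name: str, output_train_dir: str) -> Dict[str, str]:
    """Create the per-run folders and return the output paths described in this README."""
    train_dir = os.path.join(task_name, output_train_dir)
    os.makedirs(train_dir, exist_ok=True)  # also creates task_name if needed
    return {
        "run_dir": task_name,
        "train_output_dir": train_dir,
        # Inference outputs live directly under task_name.
        "dvae_reconstruction": os.path.join(task_name, "dvae_reconstruction.png"),
        "generation_results": os.path.join(task_name, "generation_results.png"),
    }
```

For example, prepare_run_dirs("mnist_dalle", "train_outputs") would create mnist_dalle/train_outputs (hypothetical names).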

Sample Output for DallE

Running the DiscreteVAE with the default config should give the reconstructions below (left: input | right: reconstruction).

Sample generation output after 40 epochs with 4 layers, a hidden dimension of 512 and 8 attention heads:
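The reconstructions go through a discrete bottleneck: at every spatial position the encoder scores each codebook entry, and the image becomes a grid of codebook indices. In the DallE-1 paper the dVAE is trained with a gumbel-softmax relaxation, while the transformer's image tokens come from the argmax of the encoder logits. The plain-Python sketch below shows only that argmax step on a nested list of logits (height x width x codebook size); logits_to_tokens and tokens_to_grid are hypothetical names, and this repo operates on PyTorch tensors whose exact layout may differ.

```python
from typing import List, Sequence


def logits_to_tokens(logits: Sequence[Sequence[Sequence[float]]]) -> List[int]:
    """Pick the highest-scoring codebook index at each (row, col), flattened row-major."""
    tokens = []
    for row in logits:
        for cell in row:
            # max() returns the first index on ties.
            tokens.append(max(range(len(cell)), key=cell.__getitem__))
    return tokens


def tokens_to_grid(tokens: List[int], width: int) -> List[List[int]]:
    """Reshape a flat token list back into rows of length width."""
    return [tokens[i:i + width] for i in range(0, len(tokens), width)]
```

For a 2x2 grid with a codebook of size 3, logits [[[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]], [[0.2, 0.1, 0.9], [1.0, 1.0, 0.0]]] map to the tokens [1, 0, 2, 0].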

Generate 0 in blue and solid background of olive

Generate 1 in cyan and texture background of cracker

Generate 6 in pink and texture background of stones

Generate 8 in red and texture background of lego
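For a sense of scale, the GPT configuration above (4 layers, hidden dimension 512, 8 heads) gives 512 / 8 = 64 dimensions per head. The rough count below assumes minGPT-style blocks (Q, K, V and output projections plus an MLP with 4x expansion) and ignores biases, LayerNorm, the token and position embeddings and the output head, whose sizes depend on the vocabulary and sequence length, which are not stated here. approx_transformer_params is a hypothetical helper.

```python
def approx_transformer_params(n_layers: int, d_model: int, mlp_ratio: int = 4) -> int:
    """Approximate weight count of the transformer blocks only."""
    attn = 4 * d_model * d_model              # Q, K, V and output projection matrices
    mlp = 2 * mlp_ratio * d_model * d_model   # up- and down-projection matrices
    return n_layers * (attn + mlp)


head_dim = 512 // 8
print(head_dim)                              # 64
print(approx_transformer_params(4, 512))     # 12582912
```

So the blocks alone hold roughly 12.6M weights (12 x 512^2 per layer, times 4 layers).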

Citations

@misc{ramesh2021zeroshot,
      title={Zero-Shot Text-to-Image Generation}, 
      author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
      year={2021},
      eprint={2102.12092},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
