This repo provides a benchmark of joint self-attention and selective state space model backbones for conditional image synthesis using latent diffusion models.
It contains model implementations of:
- DiT: Diffusion Transformer (baseline)
- DiM: Diffusion Mamba (ours)
- M-DiT: Mamba Diffusion Transformer (ours)
- DiMT: Diffusion Mamba-Transformer (ours)
- E-DiMT: Enhanced Diffusion Mamba-Transformer (ours)
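All of these backbones operate on VAE latents and differ in how they combine self-attention with Mamba (selective state space) token mixers. As a rough illustration of the joint design (not the repo's actual block definition in models/), a hybrid block might interleave the two mixers like this:

```python
# Illustrative sketch only: one way to interleave a Mamba (selective SSM)
# mixer with self-attention in a single block. The real E-DiMT block in
# models/edimt.py may differ; mamba_ssm is installed later from mamba/.
import torch
import torch.nn as nn
from mamba_ssm import Mamba

class HybridBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.mamba = Mamba(d_model=dim)  # selective state space mixer
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, tokens, dim) latent patch tokens
        x = x + self.mamba(self.norm1(x))  # sequence mixing via SSM scan
        h = self.norm2(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # attention mixing
        return x
```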
The repository is organised as follows:
- This repository is built upon fast-DiT, an improved PyTorch implementation of Scalable Diffusion Models with Transformers (DiT).
- Refer to Vim, the official implementation of Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model.
- Refer to ADM's TensorFlow evaluation suite for details on computing FID, Inception Score, and other metrics.
e-dimt-main
├── diffusion/            # Diffusion dir
├── docs/                 # Documentation figures
├── evaluator/            # Evaluator dir
│   ├── eval.py           # Evaluation script
│   └── ...
├── mamba/                # Mamba model dir
├── models/               # Backbones dir
│   ├── layers.py         # Layers and utility functions
│   ├── edimt.py          # E-DiMT backbone
│   └── ...
├── extract_features.py   # Feature extraction script
├── requirements.txt      # Requirements
├── sample.py             # Sampling script
├── sample_ddp.py         # DDP sampling script
└── train.py              # Training script
First, download and set up the repo:
git clone https://github.com/ahmedgh970/e-dimt.git
cd e-dimt
Then, create a Python 3.10 conda env, install the requirements, and install the local Mamba package:
conda create --name edimt python=3.10
conda activate edimt
pip install -r requirements.txt
cd mamba
pip install -e .
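A quick way to verify the install (this assumes the local mamba/ package installs as mamba_ssm, as in the upstream Mamba repo, and that a CUDA GPU is available):

```python
# Environment sanity check. Assumes the editable install from mamba/
# exposes the mamba_ssm package (as upstream Mamba does) and that a
# CUDA GPU is present, since the selective scan kernels require one.
import torch
from mamba_ssm import Mamba

assert torch.cuda.is_available(), "Mamba's selective scan kernels need CUDA"
block = Mamba(d_model=64).cuda()
x = torch.randn(2, 16, 64, device="cuda")  # (batch, tokens, dim)
print(block(x).shape)  # expected: torch.Size([2, 16, 64])
```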
To extract ImageNet-256 features:
torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --image-size 256 --data-path /path/to/imagenet256 --features-path /path/to/store/features
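Pre-extracting features means training reads VAE latents from disk instead of encoding pixels every step. A minimal sketch of the core encoding step, assuming the script follows fast-DiT and uses the sd-vae-ft-ema autoencoder from diffusers:

```python
# Minimal sketch of the extraction step (assumes the fast-DiT setup:
# the sd-vae-ft-ema autoencoder from diffusers and its 0.18215 scaling).
import torch
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to("cuda")
with torch.no_grad():
    # x stands in for a batch of normalized 256x256 images
    x = torch.randn(4, 3, 256, 256, device="cuda")  # placeholder batch
    latents = vae.encode(x).latent_dist.sample().mul_(0.18215)
print(latents.shape)  # (B, 4, 32, 32): 256x256 images -> 32x32 latents
```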
We provide a training script for the E-DiMT model in train.py. To train a different model from the proposed benchmark, modify the necessary import and model definition/name.
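For example, a hypothetical swap to train DiMT instead of E-DiMT (the module and registry names below are illustrative; match them to the actual definitions under models/):

```python
# Hypothetical edit inside train.py; the module and registry names are
# illustrative and must be matched to the actual files under models/.
# from models.edimt import EDiMT_models          # default backbone
from models.dimt import DiMT_models              # train DiMT instead

model = DiMT_models[args.model](
    input_size=latent_size,       # 32 for 256x256 images (8x VAE downsampling)
    num_classes=args.num_classes,
)
```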
To launch EDiMT-L/2 (256x256) training with N GPUs on one node:
accelerate launch --multi_gpu --num_processes N --mixed_precision fp16 train.py --model EDiMT-L/2 --image-size 256 --features-path /path/to/store/features
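Since the repo builds on fast-DiT, the training step presumably follows its interface: latents x and labels y come from the pre-extracted features, and the loss is produced by the diffusion/ package. A sketch, assuming fast-DiT's create_diffusion API:

```python
# Core training step, assuming the diffusion/ package keeps fast-DiT's
# interface (create_diffusion / training_losses). Here x are VAE latents,
# y are class labels, and model is the backbone, all from the train loop.
import torch
from diffusion import create_diffusion

diffusion = create_diffusion(timestep_respacing="")  # 1000-step training
t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=x.device)
loss_dict = diffusion.training_losses(model, x, t, model_kwargs=dict(y=y))
loss = loss_dict["loss"].mean()
loss.backward()
```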
To sample from the EMA weights of a trained (256x256) EDiMT-L/2 model, run:
python sample.py --model EDiMT-L/2 --image-size 256 --ckpt /path/to/model.pt
The sampling results are saved in the model's results directory, under the samples subdirectory.
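fast-DiT-style training checkpoints store the EMA state dict under an "ema" key; assuming this layout carries over, loading the EMA weights looks roughly like this:

```python
# Minimal sketch of loading EMA weights for sampling. Assumes the
# DiT-style checkpoint layout ({"model": ..., "ema": ..., "opt": ...});
# `model` is an already-instantiated backbone.
import torch

checkpoint = torch.load("/path/to/model.pt", map_location="cpu")
state_dict = checkpoint["ema"] if "ema" in checkpoint else checkpoint
model.load_state_dict(state_dict)
model.eval()  # disable train-time behavior before sampling
```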
To sample 50K images from the E-DiMT model over N GPUs, run:
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model EDiMT-L/2 --num-fid-samples 50000
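Under the hood, DiT-style sample_ddp.py scripts round the requested sample count up to a multiple of the global batch size, so every rank runs the same number of full batches. A sketch of that arithmetic (the scheme is assumed from fast-DiT; the numbers are examples):

```python
# Per-rank sharding arithmetic, assuming the fast-DiT scheme: round the
# requested total up to a multiple of the global batch size; the extra
# samples are dropped when the final .npz is packed.
import math

num_fid_samples = 50_000
world_size, per_proc_batch = 8, 32            # example values
global_batch = per_proc_batch * world_size
total = int(math.ceil(num_fid_samples / global_batch)) * global_batch
per_rank = total // world_size                # samples generated by each rank
iterations = per_rank // per_proc_batch       # full batches per rank
print(total, per_rank, iterations)            # 50176 6272 196
```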
sample_ddp.py generates a folder of samples as well as a .npz file which can be used directly with evaluator/eval.py to compute FID, Inception Score, and other metrics, as follows:
python eval.py /path/to/reference/batch.npz /path/to/generated/batch.npz
...
computing reference batch activations...
computing/reading reference batch statistics...
computing sample batch activations...
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 215.8370361328125
FID: 3.9425574129223264
sFID: 6.140433703346162
Precision: 0.8265
Recall: 0.5309