Utkarsh Ojha, Yijun Li, Jingwan Lu, Alexei A. Efros, Yong Jae Lee, Eli Shechtman, Richard Zhang
Adobe Research, UC Davis, UC Berkeley
PyTorch implementation of adapting a source GAN (trained on a large dataset) to a target domain using very few images.
Our method adapts the source GAN so that a one-to-one correspondence is preserved between the source image Gs(z) and the target image Gs→t(z) generated from the same noise vector z.
Note: The base model is taken from the StyleGAN2 implementation by @rosinality.
- Linux
- NVIDIA GPU + CUDA CuDNN 10.2
- PyTorch 1.7.0
- Python 3.6.9
- Install all the required libraries with:
pip install -r requirements.txt
We provide pre-trained models for several source GANs and their adapted (target) GANs.
Source GAN: Gs | Target GAN: Gs→t |
---|---|
FFHQ | [Sketches] [Caricatures] [Amedeo Modigliani] [Babies] [Sunglasses] [Rafael] [Otto Dix] |
LSUN Church | [Haunted houses] [Van Gogh houses] [Landscapes] [Caricatures] |
LSUN Cars | [Wrecked cars] [Landscapes] [Haunted houses] [Caricatures] |
LSUN Horses | [Landscapes] [Caricatures] [Haunted houses] |
Hand gestures | [Google Maps] [Landscapes] |
For now, we have only included the pre-trained models using FFHQ as the source domain, i.e. all the models in the first row. We will add the remaining ones soon.
Download the pre-trained model(s) and store them in the ./checkpoints directory.
To generate images from a pre-trained GAN, run the following command:
CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_target /path/to/model/
Here, model_name follows the notation source_target, e.g. ffhq_sketches. Use the --load_noise option to reuse the noise vectors behind some of the figures in the paper (Figures 1-4). For example:
CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_target ./checkpoints/ffhq_sketches.pt --load_noise noise.pt
This will save the images in the test_samples/ directory.
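The --load_noise option exists so that the same latent vectors can be reused across runs and across models. As a rough, hypothetical illustration of why this matters (it does not reproduce the format of noise.pt, which we assume simply stores latent codes), seeding the random generator makes the sampled z vectors identical every time, so Gs(z) and Gs→t(z) can be compared image by image. The helper name make_fixed_noise is an assumption, and the 512-dimensional latent size is the usual StyleGAN2 default rather than something stated here.

```python
import random
from typing import List


def make_fixed_noise(n_sample: int = 25, latent_dim: int = 512, seed: int = 0) -> List[List[float]]:
    """Hypothetical helper: draw n_sample standard-normal latent vectors z.

    A dedicated, seeded RNG makes the draw reproducible, which is what allows
    the same z to be fed to both the source and the adapted generator.
    """
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(n_sample)]


z_a = make_fixed_noise(seed=123)
z_b = make_fixed_noise(seed=123)
print(len(z_a), len(z_a[0]))  # 25 512
print(z_a == z_b)             # True: same seed, same latent codes
```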
To visualize the same noise in the source and adapted models, i.e. Gs(z) and Gs→t(z), run the following command(s):
# generate two image grids of 5x5 for source and target
CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_source /path/to/source --ckpt_target /path/to/target --load_noise noise.pt
# visualize the interpolations of source and target
CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_source /path/to/source --ckpt_target /path/to/target --load_noise noise.pt --mode interpolate
python traversal_gif.py 10
- The argument passed to traversal_gif.py (10 in the example above) denotes the number of images you want to interpolate between.
- --n_sample determines the number of images to be sampled (default: 25).
- --n_steps determines the number of steps taken when interpolating from G(z1) to G(z2) (default: 40).
- --mode determines the visualization type: either the image grids or the interpolation .gif.
- The .gif file will be saved in the gifs/ directory.
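A hedged sketch of what --n_steps controls: walking from z1 to z2 through n_steps evenly spaced points, each of which would then be rendered by both generators. We assume plain linear interpolation in z-space here; generate.py may interpolate differently (for example, spherically or in the intermediate W space), so treat interpolate_latents as a hypothetical illustration.

```python
from typing import List, Sequence


def interpolate_latents(z1: Sequence[float], z2: Sequence[float], n_steps: int = 40) -> List[List[float]]:
    """Hypothetical helper: n_steps evenly spaced points from z1 to z2, endpoints included."""
    if len(z1) != len(z2):
        raise ValueError("z1 and z2 must have the same dimensionality")
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2 to include both endpoints")
    path = []
    for i in range(n_steps):
        t = i / (n_steps - 1)  # t runs from 0.0 (z1) to 1.0 (z2)
        path.append([(1.0 - t) * a + t * b for a, b in zip(z1, z2)])
    return path


path = interpolate_latents([0.0, 2.0], [1.0, 0.0], n_steps=5)
print(len(path))  # 5
print(path[2])    # [0.5, 1.0]
```

Feeding every point of the path to Gs and to Gs→t gives two frame sequences that move in lockstep, which is what the source/target interpolation .gif visualizes.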
We collected ~18k images of random hand gestures performed on a plain surface and used them to train a source model from scratch. We then adapted it to two different target domains: landscape images and Google Maps. The goal was to see whether, at inference time, interpolating between hand gestures results in meaningful variations in the target images. Run the following command to see the results:
# use /path/to/landscapes instead of /path/to/maps for the Landscapes model
CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_source /path/to/source --ckpt_target /path/to/maps --load_noise noise.pt --mode interpolate
The following table provides a link to the test set of domains used in Table 1:
Download and unzip the images into a directory of your choice, then compute the FID score (using pytorch-fid) between the real (Rtest) and fake (F) images by running the following command:
python -m pytorch_fid /path/to/real/images /path/to/fake/images
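For reference, pytorch-fid reports the standard Fréchet Inception Distance: it fits a Gaussian with mean μ and covariance Σ to the Inception-v3 activations (2048-dimensional pool features by default) of each image set and computes the distance below, where r denotes the real test images (Rtest) and f the generated ones (F). Lower is better.

```latex
\mathrm{FID}(R_{test}, F) = \lVert \mu_r - \mu_f \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_f - 2\,(\Sigma_r \Sigma_f)^{1/2} \right)
```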
Download the entire set of images used for the results in Table 2 from this link (1.1 GB). The collection is organized as follows:
cluster_centers
└── amedeo  # target domain -- will be from [amedeo, sketches]
    └── ours  # baseline -- will be from [tgan, tgan_ada, freezeD, ewc, ours]
        └── c0  # center id -- there will be 10 clusters [c0, c1 ... c9]
            ├── center.png  # cluster center -- this is one of the 10 training images used. Each cluster will have its own center
            ├── img0.png  # generated images which matched with this cluster's center, according to LPIPS metric.
            ├── img1.png
            │   .
            │   .
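If you want a quick sanity check of the unzipped collection before running feat_cluster.py, the sketch below walks the layout shown above and counts the generated images in each cluster. The function name count_cluster_members is hypothetical, and treating every .png other than center.png as a generated member is our assumption based only on the tree.

```python
from pathlib import Path
from typing import Dict


def count_cluster_members(root: str, domain: str, baseline: str) -> Dict[str, int]:
    """Hypothetical helper: map each cluster id (c0..c9) to its number of generated images.

    Assumes the layout <root>/<domain>/<baseline>/<cluster_id>/, where center.png
    is the training image and every other .png is a matched generated sample.
    """
    base = Path(root) / domain / baseline
    counts = {}
    for cluster_dir in sorted(p for p in base.iterdir() if p.is_dir()):
        members = [f for f in cluster_dir.glob("*.png") if f.name != "center.png"]
        counts[cluster_dir.name] = len(members)
    return counts
```

For example, count_cluster_members("cluster_centers", "amedeo", "ours") would return a dictionary such as {"c0": ..., "c1": ..., ...} for the unzipped data.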
Unzip the file, and then run the following command to compute the results for a baseline on a dataset:
CUDA_VISIBLE_DEVICES=0 python feat_cluster.py --baseline <baseline> --dataset <target_domain> --mode intra_cluster_dist
E.g.
CUDA_VISIBLE_DEVICES=0 python feat_cluster.py --baseline tgan --dataset sketches --mode intra_cluster_dist
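Roughly, the intra-cluster distance measures diversity: each generated image is assigned to its closest training image (cluster center) under LPIPS, the average pairwise distance between members of the same cluster is computed, and these per-cluster values are averaged. A higher value suggests the generator has not collapsed onto the few training images. The sketch below reproduces only that bookkeeping, with a plain Euclidean distance standing in for LPIPS so that it runs without a network; the function names are hypothetical and this is not the code in feat_cluster.py.

```python
import itertools
import math
from typing import Callable, Dict, List, Sequence

Vector = Sequence[float]


def euclidean(a: Vector, b: Vector) -> float:
    """Stand-in for LPIPS (the real metric needs a pretrained network)."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def assign_to_centers(samples: List[Vector], centers: List[Vector],
                      dist: Callable[[Vector, Vector], float] = euclidean) -> Dict[int, List[Vector]]:
    """Assign every generated sample to the index of its closest center."""
    clusters = {i: [] for i in range(len(centers))}
    for s in samples:
        closest = min(range(len(centers)), key=lambda i: dist(s, centers[i]))
        clusters[closest].append(s)
    return clusters


def intra_cluster_dist(clusters: Dict[int, List[Vector]],
                       dist: Callable[[Vector, Vector], float] = euclidean) -> float:
    """Mean pairwise distance inside each cluster, then averaged over clusters.

    Clusters with fewer than two members have no pairs and are skipped
    (a choice made for this sketch).
    """
    per_cluster = []
    for members in clusters.values():
        pairs = list(itertools.combinations(members, 2))
        if pairs:
            per_cluster.append(sum(dist(a, b) for a, b in pairs) / len(pairs))
    return sum(per_cluster) / len(per_cluster) if per_cluster else 0.0


centers = [[0.0, 0.0], [10.0, 10.0]]  # toy "training images"
samples = [[0.0, 1.0], [1.0, 0.0], [10.0, 9.0], [9.0, 10.0], [10.0, 12.0]]  # toy "generated images"
clusters = assign_to_centers(samples, centers)
print({k: len(v) for k, v in clusters.items()})  # {0: 2, 1: 3}
print(round(intra_cluster_dist(clusters), 3))    # 1.815
```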
We also provide a utility to visualize the closest and farthest members of a cluster, as shown in Figure 14 of the paper, using the following command:
CUDA_VISIBLE_DEVICES=0 python feat_cluster.py --baseline tgan --dataset sketches --mode visualize_members
The command saves the generated images closest to and farthest from a cluster center as closest.png and farthest.png, respectively.
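Conceptually, picking the closest and farthest members of a cluster is just an arg-min and arg-max of each member's distance to the center. A minimal sketch under the same assumptions as before: Euclidean distance stands in for LPIPS, and closest_and_farthest is a hypothetical name, not a function from this repository.

```python
import math
from typing import List, Sequence, Tuple


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    # Stand-in for LPIPS; any distance where smaller means "more similar" works here.
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def closest_and_farthest(center: Sequence[float], members: List[Sequence[float]]) -> Tuple[int, int]:
    """Return the indices of the members nearest to and farthest from the center."""
    if not members:
        raise ValueError("cluster has no members")
    dists = [euclidean(center, m) for m in members]
    closest = min(range(len(members)), key=dists.__getitem__)
    farthest = max(range(len(members)), key=dists.__getitem__)
    return closest, farthest


print(closest_and_farthest([0.0, 0.0], [[3.0, 4.0], [0.5, 0.0], [6.0, 8.0]]))  # (1, 2)
```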
- Only the pre-trained model is needed, i.e. no need for access to the source data.
- Refer to the first column of the pre-trained models table above.
- If you wish to use some other source model, make sure that it follows the generator architecture defined in this PyTorch implementation of StyleGAN2.
- Below are the links to all the target domains, each consisting of 10 images, used in the paper.
Sketches | Amedeo Modigliani | Babies | Sunglasses | Rafael | Otto Dix | Haunted houses | Van Gogh houses | Landscapes | Wrecked cars | Maps |
---|---|---|---|---|---|---|---|---|---|---|
images | images | images | images | images | images | images | images | images | images | images |
processed | processed | processed | processed | processed | processed | processed | processed | processed | processed | processed |
Note: We cannot share the images for the caricature domain due to license issues.
- If downloading the raw images, unzip them into the ./raw_data folder.
  - Run
    python prepare_data.py --out processed_data/<dataset_name> --size 256 ./raw_data/<dataset_name>
  - This will generate the processed version of the data in the ./processed_data directory.
- Otherwise, if downloading the processed files directly, unzip them into the ./processed_data directory.
- Set the training parameters in train.py:
  - --n_train should be set to the number of training samples (default: 10).
  - --img_freq and --ckpt_freq control how often the intermediate generated images and the intermediate models are saved, respectively.
  - --iter determines the total number of iterations. In our experience, adapting a source GAN trained on FFHQ to artistic domains (e.g. Sketches/Amedeo's paintings) delivers decent results in 4k-5k iterations. For domains closer to natural faces (e.g. Babies/Sunglasses), we get the best results in about 1k iterations.
- Run the following command to adapt the source GAN (e.g. FFHQ) to the target domain (e.g. sketches):
CUDA_VISIBLE_DEVICES=0 python train.py --ckpt_source /path/to/source_model --data_path /path/to/target_data --exp <exp_name>
# sample run
CUDA_VISIBLE_DEVICES=0 python train.py --ckpt_source ./checkpoints/source_ffhq.pt --data_path ./processed_data/sketches --exp ffhq_to_sketches
This will create directories named ffhq_to_sketches in ./checkpoints/ (for the intermediate models) and in ./samples (for the intermediate generated images).
Running the above command with the default configuration (batch size = 4) uses ~20 GB of GPU memory.
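To make the effect of --iter, --img_freq, and --ckpt_freq concrete, here is a hedged sketch of the usual bookkeeping in such a training loop. We assume a save happens whenever the iteration index is a multiple of the corresponding frequency, which may not match train.py exactly; the frequency values below are purely illustrative, and both function names are hypothetical. The second helper shows how often, on average, each of the n_train images is drawn at the default batch size of 4.

```python
from typing import Dict, List


def save_schedule(n_iter: int, img_freq: int, ckpt_freq: int) -> Dict[str, List[int]]:
    """Iterations at which samples / checkpoints would be written (assumed rule: i % freq == 0)."""
    return {
        "images": [i for i in range(1, n_iter + 1) if i % img_freq == 0],
        "checkpoints": [i for i in range(1, n_iter + 1) if i % ckpt_freq == 0],
    }


def passes_over_data(n_iter: int, batch_size: int = 4, n_train: int = 10) -> float:
    """Average number of times each training image is drawn during adaptation."""
    return n_iter * batch_size / n_train


sched = save_schedule(n_iter=5000, img_freq=500, ckpt_freq=1000)  # illustrative values
print(sched["checkpoints"])    # [1000, 2000, 3000, 4000, 5000]
print(passes_over_data(5000))  # 2000.0
print(passes_over_data(1000))  # 400.0
```

With only 10 training images, even the shorter 1k-iteration schedule revisits each image hundreds of times, which is one reason intermediate checkpoints are worth keeping.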
If you find our code useful, please cite our paper:
@inproceedings{ojha2021few-shot-gan,
title={Few-shot Image Generation via Cross-domain Correspondence},
author={Ojha, Utkarsh and Li, Yijun and Lu, Cynthia and Efros, Alexei A. and Lee, Yong Jae and Shechtman, Eli and Zhang, Richard},
booktitle={CVPR},
year={2021}
}
As mentioned before, the StyleGAN2 model is borrowed from this wonderful PyTorch implementation by @rosinality. We are also thankful to @mseitzer and @richzhang for their user-friendly implementations for computing the FID score and the LPIPS metric, respectively.