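"""
extract_features.py

Pre-extract VAE latent features (and class labels) for an ImageFolder
dataset such as ImageNet. Each image is center-cropped, encoded with the
pretrained stabilityai/sd-vae-ft-* autoencoder, and the resulting latent
is saved to its own .npy file alongside a matching label file.
"""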
import os
import argparse
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np
from PIL import Image
from diffusers.models import AutoencoderKL


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Repeatedly halve with a box filter while the short side is still at
    # least twice the target size (cheap anti-aliasing before the final resize):
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Bicubic resize so the short side matches image_size exactly:
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    # Center-crop to image_size x image_size:
    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
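
# For example (illustrative): center_crop_arr(Image.open("img.jpg"), 256)
# returns a 256x256 PIL.Image regardless of the input's aspect ratio.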


def main(args):
    assert torch.cuda.is_available(), "Feature extraction requires at least one GPU."

    # Set up DDP:
    dist.init_process_group("nccl")
    assert args.global_batch_size % dist.get_world_size() == 0, \
        f"Global batch size ({args.global_batch_size}) must be divisible by world size ({dist.get_world_size()})."
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank  # unique seed per rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
    # Set up the output folders (created by rank 0; all ranks wait before writing):
    if rank == 0:
        os.makedirs(args.features_path, exist_ok=True)
        os.makedirs(os.path.join(args.features_path, 'imagenet256_features'), exist_ok=True)
        os.makedirs(os.path.join(args.features_path, 'imagenet256_labels'), exist_ok=True)
    dist.barrier()
    # Load the pretrained VAE from stabilityai:
    assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
# Setup data:
transform = transforms.Compose([
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
    dataset = ImageFolder(args.data_path, transform=transform)
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=False,  # deterministic order, so saved indices map back to the dataset
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=1,  # one sample per saved .npy file
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    train_steps = 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            # Map input images to latent space + normalize latents
            # (0.18215 is the SD VAE scaling factor):
            x = vae.encode(x).latent_dist.sample().mul_(0.18215)

        # Interleave indices across ranks so files never collide: with
        # shuffle=False, DistributedSampler assigns dataset index
        # train_steps * world_size + rank to this rank at this step.
        idx = train_steps * dist.get_world_size() + rank
        x = x.detach().cpu().numpy()  # (1, 4, image_size // 8, image_size // 8)
        np.save(f'{args.features_path}/imagenet256_features/{idx}.npy', x)
        y = y.detach().cpu().numpy()  # (1,)
        np.save(f'{args.features_path}/imagenet256_labels/{idx}.npy', y)

        train_steps += 1
        if train_steps % 100 == 0:
            print(f"rank={rank}: {train_steps} samples saved")

    dist.destroy_process_group()
# Example launch (single process):
#   torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --image-size 256 \
#       --data-path /dataset/imagenet --features-path /dataset/imagenet_features
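
# --------------------------------------------------------------------------
# Usage sketch (illustrative, not executed by this script): the saved
# feature/label pairs can be read back with a minimal Dataset. The class
# name `FeatureDataset` is hypothetical; any training loop that consumes
# the .npy files written above would look similar.
class FeatureDataset(torch.utils.data.Dataset):
    def __init__(self, features_dir, labels_dir):
        self.features_dir = features_dir
        self.labels_dir = labels_dir
        # Sort numerically so item i corresponds to file "i.npy":
        self.files = sorted(os.listdir(features_dir), key=lambda f: int(f.split('.')[0]))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        name = self.files[idx]
        x = np.load(os.path.join(self.features_dir, name))  # (1, 4, H/8, W/8)
        y = np.load(os.path.join(self.labels_dir, name))    # (1,)
        return torch.from_numpy(x).squeeze(0), torch.from_numpy(y).squeeze(0)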


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--features-path", type=str, default="/dataset/imagenet_features")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--global-batch-size", type=int, default=256)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")
    parser.add_argument("--num-workers", type=int, default=4)
    args = parser.parse_args()
    main(args)