-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
executable file
·74 lines (66 loc) · 2.86 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from argparse import ArgumentParser, Namespace
from typing import Dict

import torch
from PIL import Image

import primaps_modules.transforms as transforms
from primaps_modules.primaps import PriMaPs
from primaps_modules.backbone.dino.dinovit import DinoFeaturizerv2
from primaps_modules.visualization import visualize_demo
# Seed the CPU and CUDA random generators with 0 so repeated demo runs
# are reproducible.
for _seed in (torch.manual_seed, torch.cuda.manual_seed):
    _seed(0)
del _seed  # keep the module namespace identical to the unrolled version
def main(opts: Namespace):
    """Visualize PriMaPs pseudo labels for a single image.

    Builds the SSL image encoder (``DinoFeaturizerv2``) and the PriMaPs
    module, preprocesses the image at ``opts.image_path``, computes backbone
    features and PriMaPs pseudo labels, and saves an overlay visualization
    as ``demo.png`` in the current working directory.

    Args:
        opts: Parsed command-line options. Attributes read:
            ``backbone_arch``, ``backbone_patch``, ``threshold``,
            ``validation_resize`` (list of ints), ``image_path``, ``device``.
    """
    # get SSL image encoder and primaps module
    net = DinoFeaturizerv2(opts.backbone_arch, opts.backbone_patch)
    net.to(opts.device)
    primaps_module = PriMaPs(threshold=opts.threshold, ignore_id=255)

    # get transforms; the center crop is square, sized by the first
    # resize value
    crop_size = opts.validation_resize[0]
    demo_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(opts.validation_resize),
        transforms.CenterCrop([crop_size, crop_size]),
        transforms.Normalize(),
    ])

    # load image and apply transforms.
    # The context manager closes the file handle PIL keeps open lazily;
    # convert('RGB') collapses RGBA / palette / grayscale inputs to three
    # channels (presumably what Normalize and the backbone expect).
    with Image.open(opts.image_path) as raw_img:
        img = raw_img.convert('RGB')
    # Dummy label for the joint transforms. PIL's size is (width, height)
    # whereas tensors are laid out (height, width), so reverse it.
    img, _ = demo_transforms(img, torch.zeros(img.size[::-1]))

    # `img` deliberately stays on the CPU; only the batched copy fed to the
    # backbone below is moved to the device.
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        # get SSL features
        feats = net(img.unsqueeze(0).to(opts.device), n=1).squeeze()
        # get primaps pseudo labels (all-zero class prior of spatial size HxW)
        primaps = primaps_module._get_pseudo(img, feats, torch.zeros(img.shape[1:]))

    # visualize overlay
    Image.fromarray(visualize_demo(img, primaps)).save('demo.png')
    print('Image saved as demo.png')
def build_parser() -> ArgumentParser:
    """Build the command-line parser for the PriMaPs demo.

    Defaults: backbone ``'dino_vitb'`` with patch size 8, resize ``[320]``,
    threshold 0.35, and the ``IMG_0709.jpg`` example image. The device
    defaults to ``'cuda:0'`` when CUDA is available, otherwise ``'cpu'``.

    Returns:
        The configured ``ArgumentParser``.
    """
    parser = ArgumentParser()
    # alternatives: 'dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb'
    parser.add_argument("--backbone-arch",
                        type=str,
                        default='dino_vitb',
                        help='backbone architecture')
    # alternatives: 8, 14, 16
    parser.add_argument("--backbone-patch",
                        type=int,
                        default=8,
                        help='patch size of the vit backbone')
    # alternatives: [320] (multiple of 8 and 16), [322] (multiple of 14)
    parser.add_argument("--validation-resize",
                        nargs='+',
                        type=int,
                        default=[320],
                        help='resize images to this size')
    parser.add_argument("--threshold",
                        type=float,
                        default=0.35,
                        help='primaps threshold')
    # Fall back to the CPU so the demo still runs on machines without CUDA.
    parser.add_argument("--device",
                        type=str,
                        default='cuda:0' if torch.cuda.is_available() else 'cpu',
                        help='device to use (default: cuda:0 if available, else cpu)')
    # alternatives: 'assets/demo_examples/cityscapes_example.png',
    # 'assets/demo_examples/coco_example.jpg',
    # 'assets/demo_examples/potsdam_example.png'
    parser.add_argument("--image-path",
                        type=str,
                        default='assets/demo_examples/IMG_0709.jpg',
                        help='path to images')
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    print(args)
    main(args)