-
Notifications
You must be signed in to change notification settings - Fork 0
/
validation.py
76 lines (62 loc) · 2.08 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import functools
from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image

from edm import EDMSampler
from model import *
from utils import *
from type_alias import *
def ModelBackToCPU(validFunc: Callable):
    """Wrap a validation function so the denoiser sits on ``device`` only while it runs.

    The returned wrapper takes the same six arguments as ``validFunc``. It:

    1. moves ``denoiser`` to ``device``;
    2. calls ``validFunc`` with every argument unchanged;
    3. moves ``denoiser`` back to the CPU.

    Step 3 also runs when ``validFunc`` raises, so a failed validation run
    does not leave the model on ``device``. Only ``denoiser`` is moved;
    ``extractor`` is passed through untouched.

    Args:
        validFunc: callable with the signature
            ``(sampler, dataloader, denoiser, extractor, device, saveFilename)``.

    Returns:
        The wrapper. It returns whatever ``validFunc`` returns and carries
        ``validFunc``'s name and docstring (via ``functools.wraps``).
    """
    # Project types are written as forward references so that applying the
    # decorator does not need to evaluate them.
    @functools.wraps(validFunc)
    def ModelBackToCPU_Valid(
        sampler     : "EDMSampler",
        dataloader  : DataLoader,
        denoiser    : "UNet",
        extractor   : "Extractor",
        device      : torch.device,
        saveFilename: str
    ):
        denoiser.to(device)
        try:
            return validFunc(sampler, dataloader, denoiser, extractor, device, saveFilename)
        finally:
            # Always return the model to the CPU, even if validation failed.
            denoiser.cpu()
    return ModelBackToCPU_Valid
@torch.inference_mode()
def Valid(
    sampler     : EDMSampler,
    dataloader  : DataLoader,
    denoiser    : UNet,
    extractor   : Extractor,
    device      : torch.device,
    saveFilename: str
):
    """Generate samples for every dataloader batch and save them as one image grid.

    For each ``(images, masks, toExtracts)`` batch yielded by ``dataloader``:

    1. Build a conditional / unconditional argument pair.
    2. Call ``sampler.Run(denoiser, B, None, denoiseArgs)``, where ``B`` is
       the batch size.
    3. Accumulate the result with ``DefaultConcatTensor``.

    The accumulated tensor is clamped to ``[0, 1]`` and written to
    ``saveFilename`` as a near-square grid (``nrow = round(sqrt(total))``).

    The whole call runs under ``torch.inference_mode()``. ``denoiser`` and
    ``extractor`` are put in eval mode, and their original ``training`` flags
    are restored on exit, including when an exception is raised.

    Args:
        sampler: object whose ``Run`` method produces the generated batch.
        dataloader: iterable of ``(images, masks, toExtracts)`` tensor tuples.
        denoiser: network handed to the sampler; must expose ``inChannel``.
        extractor: conditioning network applied to ``toExtracts``; must expose
            ``MakeUncondTensor(B, device=...)``.
        device: device each batch is moved to before use.
        saveFilename: destination path passed to ``save_image``.

    Raises:
        ValueError: if ``dataloader`` yields no batches, so there is nothing to save.
    """
    isDenoiserTraining  = denoiser.training
    isExtractorTraining = extractor.training
    denoiser.eval()
    extractor.eval()
    try:
        generateds, nTotal = None, 0
        for images, masks, toExtracts in dataloader:
            images, masks, toExtracts = images.to(device), masks.to(device), toExtracts.to(device)
            B, C, H, W = images.size()
            nTotal += B
            denoiseArgs = {
                "cond": {
                    "concat" : masks,
                    "extract": extractor(toExtracts)
                },
                # Unconditional branch: a zero tensor with denoiser.inChannel - C
                # channels stands in for `masks` (assumes masks has exactly that
                # many channels -- TODO confirm). The extract tensor comes from the
                # extractor's MakeUncondTensor helper.
                "uncond": {
                    "concat" : torch.zeros([B, denoiser.inChannel - C, H, W], device=device),
                    "extract": extractor.MakeUncondTensor(B, device=device)
                }
            }
            batchRes = sampler.Run(denoiser, B, None, denoiseArgs)
            # Presumably concatenates along the batch dimension, returning batchRes
            # when generateds is None -- verify against utils.DefaultConcatTensor.
            generateds = DefaultConcatTensor(generateds, batchRes)
        if generateds is None:
            raise ValueError(
                f"dataloader yielded no batches; nothing to save to {saveFilename!r}"
            )
        save_image(
            # max(1, ...) keeps nrow valid even if every batch was empty.
            make_grid(generateds.clamp(0., 1.), nrow=max(1, round(nTotal ** 0.5))),
            saveFilename
        )
    finally:
        # Restore the callers' train/eval state even if sampling or saving failed.
        if isDenoiserTraining:
            denoiser.train()
        if isExtractorTraining:
            extractor.train()
        torch.cuda.empty_cache()