# revgrad.py (forked from jvanvugt/pytorch-domain-adaptation)
"""
Implements RevGrad:
Unsupervised Domain Adaptation by Backpropagation, Ganin & Lemptsky (2014)
Domain-adversarial training of neural networks, Ganin et al. (2016)
"""
import argparse

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm

import config
from data import MNISTM
from models import Net
from utils import GrayscaleToRgb, GradientReversal

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
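
# `GradientReversal` (imported from utils above) is the heart of RevGrad: it is
# the identity on the forward pass but multiplies the gradient by a negative
# factor on the backward pass, so the feature extractor learns to confuse the
# domain discriminator. The sketch below illustrates the technique; it is an
# assumed equivalent of the imported implementation, not used by the training
# code itself.
class _GradientReversalFnSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Flip (and scale) the gradient flowing back into the features;
        # lambda_ itself receives no gradient.
        return -ctx.lambda_ * grad_output, None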


def main(args):
    # Start from a source-only pretrained model and fine-tune it adversarially.
    model = Net().to(device)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(args.MODEL_FILE, map_location=device))
    feature_extractor = model.feature_extractor
    clf = model.classifier
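
    # Domain discriminator: because GradientReversal sits at its input, one
    # backward pass trains the discriminator to separate the domains while
    # pushing the shared features toward domain invariance.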
    discriminator = nn.Sequential(
        GradientReversal(),
        nn.Linear(320, 50),
        nn.ReLU(),
        nn.Linear(50, 20),
        nn.ReLU(),
        nn.Linear(20, 1)
    ).to(device)
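
    # Each batch is half source (labeled MNIST) and half target (unlabeled
    # MNIST-M), so the concatenated batch below has args.batch_size samples.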
    half_batch = args.batch_size // 2
    source_dataset = MNIST(config.DATA_DIR/'mnist', train=True, download=True,
                           transform=Compose([GrayscaleToRgb(), ToTensor()]))
    source_loader = DataLoader(source_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    target_dataset = MNISTM(train=False)
    target_loader = DataLoader(target_dataset, batch_size=half_batch,
                               shuffle=True, num_workers=1, pin_memory=True)

    # A single optimizer updates both the discriminator and the full model.
    optim = torch.optim.Adam(list(discriminator.parameters()) + list(model.parameters()))

    for epoch in range(1, args.epochs + 1):
        batches = zip(source_loader, target_loader)
        n_batches = min(len(source_loader), len(target_loader))

        total_domain_loss = total_label_accuracy = 0
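
        # zip pairs one source batch with one target batch per step and stops
        # at the shorter loader, hence n_batches = min(...) above.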
        for (source_x, source_labels), (target_x, _) in tqdm(batches, leave=False, total=n_batches):
            x = torch.cat([source_x, target_x])
            x = x.to(device)
            # Domain labels: 1 for source samples, 0 for target samples.
            domain_y = torch.cat([torch.ones(source_x.shape[0]),
                                  torch.zeros(target_x.shape[0])])
            domain_y = domain_y.to(device)
            label_y = source_labels.to(device)

            features = feature_extractor(x).view(x.shape[0], -1)
            domain_preds = discriminator(features).squeeze()
            # Class predictions only for the labeled (source) half of the batch.
            label_preds = clf(features[:source_x.shape[0]])
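
            # Single combined objective: label_loss trains the classifier on
            # source data; domain_loss, through the reversed gradients, makes
            # the shared features indistinguishable across domains.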
            domain_loss = F.binary_cross_entropy_with_logits(domain_preds, domain_y)
            label_loss = F.cross_entropy(label_preds, label_y)
            loss = domain_loss + label_loss

            optim.zero_grad()
            loss.backward()
            optim.step()

            total_domain_loss += domain_loss.item()
            total_label_accuracy += (label_preds.max(1)[1] == label_y).float().mean().item()

        mean_loss = total_domain_loss / n_batches
        mean_accuracy = total_label_accuracy / n_batches
        tqdm.write(f'EPOCH {epoch:03d}: domain_loss={mean_loss:.4f}, '
                   f'source_accuracy={mean_accuracy:.4f}')

        # Overwrite the checkpoint after every epoch.
        torch.save(model.state_dict(), 'trained_models/revgrad.pt')


if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser(description='Domain adaptation using RevGrad')
    arg_parser.add_argument('MODEL_FILE', help='A model in trained_models')
    arg_parser.add_argument('--batch-size', type=int, default=64)
    arg_parser.add_argument('--epochs', type=int, default=15)
    args = arg_parser.parse_args()
    main(args)
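
# Usage sketch (assumes a source-only checkpoint already exists, e.g. one
# produced by train_source.py in the parent repo):
#   python revgrad.py trained_models/source.pt --batch-size 64 --epochs 15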