-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_adv_wgan_gp_3modal.py
77 lines (72 loc) · 2.33 KB
/
run_adv_wgan_gp_3modal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
from mmkgc.config import Tester, WCGTrainerDB15KGP
from mmkgc.module.model import AdvRelRotatEDB15K
from mmkgc.module.loss import SigmoidLoss
from mmkgc.module.strategy import NegativeSamplingGP
from mmkgc.data import TrainDataLoader, TestDataLoader
from mmkgc.adv.modules import CombinedGenerator3
from args import get_args
if __name__ == "__main__":
    # Script entry point: build the data loaders, load the three pre-computed
    # embedding files, assemble the scoring model, its negative-sampling
    # wrapper and the generator, run the trainer, save a checkpoint, and
    # finally run link prediction on the test split.
    args = get_args()
    print(args)

    # set the seed
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Both data loaders read from the same per-dataset benchmark directory.
    benchmark_dir = "./benchmarks/" + args.dataset + "/"

    # dataloader for training
    train_dataloader = TrainDataLoader(
        in_path=benchmark_dir,
        batch_size=args.batch_size,
        threads=8,
        sampling_mode="normal",
        bern_flag=1,
        filter_flag=1,
        neg_ent=args.neg_num,
        neg_rel=0,
    )

    # dataloader for test
    test_dataloader = TestDataLoader(benchmark_dir, "link")

    # Per-dataset embedding files: <dataset>-visual / -textual / -numeric.
    # NOTE(review): torch.load is called without map_location, so the objects
    # are restored onto whatever device they were saved from -- confirm that
    # matches the machine this script runs on.
    embedding_prefix = "./embeddings/" + args.dataset
    img_emb = torch.load(embedding_prefix + "-visual.pth")
    text_emb = torch.load(embedding_prefix + "-textual.pth")
    num_emb = torch.load(embedding_prefix + "-numeric.pth")

    # define the model
    kge_score = AdvRelRotatEDB15K(
        ent_tot=train_dataloader.get_ent_tot(),
        rel_tot=train_dataloader.get_rel_tot(),
        dim=args.dim,
        margin=args.margin,
        epsilon=2.0,
        img_emb=img_emb,
        text_emb=text_emb,
        numeric_emb=num_emb,
    )
    print(kge_score)

    # define the loss function
    model = NegativeSamplingGP(
        model=kge_score,
        loss=SigmoidLoss(adv_temperature=args.adv_temp),
        batch_size=train_dataloader.get_batch_size(),
    )

    # Generator handed to the trainer below.
    # NOTE(review): the 2x / 3x multiples of args.dim presumably mirror the
    # embedding widths used inside AdvRelRotatEDB15K -- verify against the
    # model definition before changing either side.
    adv_generator = CombinedGenerator3(
        noise_dim=64,
        structure_dim=2 * args.dim,
        img_dim=3 * args.dim,
    )

    # train the model
    trainer = WCGTrainerDB15KGP(
        model=model,
        data_loader=train_dataloader,
        train_times=args.epoch,
        alpha=args.learning_rate,
        use_gpu=True,
        opt_method="Adam",
        generator=adv_generator,
        lrg=args.lrg,
        mu=args.mu,
    )
    trainer.run()
    kge_score.save_checkpoint(args.save)

    # test the model
    # The in-memory model is evaluated directly after training; un-comment the
    # next line to evaluate the checkpoint stored at args.save instead.
    # kge_score.load_checkpoint(args.save)
    tester = Tester(model=kge_score, data_loader=test_dataloader, use_gpu=True)
    tester.run_link_prediction(type_constrain=False)