maml.py
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from transformers import BertForSequenceClassification
from copy import deepcopy
import gc
import torch
from sklearn.metrics import accuracy_score
import numpy as np


class Learner(nn.Module):
    """
    Meta Learner
    """
    def __init__(self, args):
        """
        :param args: namespace holding the meta-learning hyperparameters
        """
        super(Learner, self).__init__()

        self.num_labels = args.num_labels
        self.outer_batch_size = args.outer_batch_size
        self.inner_batch_size = args.inner_batch_size
        self.outer_update_lr = args.outer_update_lr
        self.inner_update_lr = args.inner_update_lr
        self.inner_update_step = args.inner_update_step
        self.inner_update_step_eval = args.inner_update_step_eval
        self.bert_model = args.bert_model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = BertForSequenceClassification.from_pretrained(self.bert_model, num_labels=self.num_labels)
        self.outer_optimizer = Adam(self.model.parameters(), lr=self.outer_update_lr)
        self.model.train()

    def forward(self, batch_tasks, training=True):
        """
        batch_tasks = [(support TensorDataset, query TensorDataset),
                       (support TensorDataset, query TensorDataset),
                       (support TensorDataset, query TensorDataset),
                       (support TensorDataset, query TensorDataset)]

        # support = TensorDataset(all_input_ids, all_attention_mask, all_segment_ids, all_label_ids)
        """
        task_accs = []
        sum_gradients = []
        num_task = len(batch_tasks)
        num_inner_update_step = self.inner_update_step if training else self.inner_update_step_eval

        for task_id, task in enumerate(batch_tasks):
            support = task[0]
            query = task[1]

            # Clone the meta model so the inner-loop updates adapt a task-specific copy.
            fast_model = deepcopy(self.model)
            fast_model.to(self.device)
            support_dataloader = DataLoader(support, sampler=RandomSampler(support),
                                            batch_size=self.inner_batch_size)

            inner_optimizer = Adam(fast_model.parameters(), lr=self.inner_update_lr)
            fast_model.train()

            print('----Task', task_id, '----')
            for i in range(num_inner_update_step):
                all_loss = []
                for inner_step, batch in enumerate(support_dataloader):
                    batch = tuple(t.to(self.device) for t in batch)
                    input_ids, attention_mask, segment_ids, label_id = batch
                    outputs = fast_model(input_ids, attention_mask, segment_ids, labels=label_id)

                    loss = outputs[0]
                    loss.backward()
                    inner_optimizer.step()
                    inner_optimizer.zero_grad()

                    all_loss.append(loss.item())

                if i % 4 == 0:
                    print("Inner Loss: ", np.mean(all_loss))

            # Evaluate the adapted model on the full query set in a single batch.
            query_dataloader = DataLoader(query, sampler=None, batch_size=len(query))
            query_batch = next(iter(query_dataloader))
            query_batch = tuple(t.to(self.device) for t in query_batch)
            q_input_ids, q_attention_mask, q_segment_ids, q_label_id = query_batch
            q_outputs = fast_model(q_input_ids, q_attention_mask, q_segment_ids, labels=q_label_id)

            if training:
                q_loss = q_outputs[0]
                q_loss.backward()
                fast_model.to(torch.device('cpu'))
                # Accumulate the query-set gradients of each task; they are averaged
                # below and applied to the meta model.
                for i, params in enumerate(fast_model.parameters()):
                    if task_id == 0:
                        sum_gradients.append(deepcopy(params.grad))
                    else:
                        sum_gradients[i] += deepcopy(params.grad)

            q_logits = F.softmax(q_outputs[1], dim=1)
            pre_label_id = torch.argmax(q_logits, dim=1)
            pre_label_id = pre_label_id.detach().cpu().numpy().tolist()
            q_label_id = q_label_id.detach().cpu().numpy().tolist()

            acc = accuracy_score(pre_label_id, q_label_id)
            task_accs.append(acc)

            del fast_model, inner_optimizer
            torch.cuda.empty_cache()

        if training:
            # Average the accumulated gradients across tasks.
            for i in range(len(sum_gradients)):
                sum_gradients[i] = sum_gradients[i] / float(num_task)

            # Assign the averaged gradients to the original model, then use the
            # outer optimizer to update its weights.
            for i, params in enumerate(self.model.parameters()):
                params.grad = sum_gradients[i]

            self.outer_optimizer.step()
            self.outer_optimizer.zero_grad()

            del sum_gradients
            gc.collect()

        return np.mean(task_accs)
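

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal, hedged example of how this Learner could be driven. The hyperparameter
# values, the SimpleNamespace-based `args`, and the randomly generated TensorDatasets
# below are assumptions for illustration only; the real project builds its support
# and query sets from tokenized text, as described in the docstring of Learner.forward.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(
        num_labels=2,
        outer_batch_size=2,
        inner_batch_size=4,
        outer_update_lr=5e-5,
        inner_update_lr=5e-5,
        inner_update_step=5,
        inner_update_step_eval=10,
        bert_model='bert-base-uncased',
    )
    learner = Learner(args)

    def random_task(num_examples=8, seq_len=16, num_labels=2):
        # Build a toy TensorDataset with random token ids, in the
        # (input_ids, attention_mask, segment_ids, label_ids) layout the Learner expects.
        input_ids = torch.randint(0, 30522, (num_examples, seq_len))
        attention_mask = torch.ones(num_examples, seq_len, dtype=torch.long)
        segment_ids = torch.zeros(num_examples, seq_len, dtype=torch.long)
        label_ids = torch.randint(0, num_labels, (num_examples,))
        return TensorDataset(input_ids, attention_mask, segment_ids, label_ids)

    # Each task is a (support set, query set) pair; one outer step averages
    # the query gradients across these tasks.
    batch_tasks = [(random_task(), random_task()) for _ in range(2)]
    acc = learner(batch_tasks, training=True)
    print('Mean query accuracy:', acc)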