-
Notifications
You must be signed in to change notification settings - Fork 38
/
rnn_reader.py
139 lines (114 loc) · 5.17 KB
/
rnn_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python3
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Implementation of the RNN based DrQA reader."""
import torch
import torch.nn as nn
import layers
# ------------------------------------------------------------------------------
# Network
# ------------------------------------------------------------------------------
class RnnDocReader(nn.Module):
    """RNN-based DrQA document reader.

    Encodes the document and the question with stacked RNN encoders, merges
    the question hidden states into a single vector, and scores document
    positions as answer start / end with two bilinear attention layers.
    """

    # Maps the ``args.rnn_type`` config string to a torch RNN module class.
    RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}
    # Cell-level counterparts of RNN_TYPES; not referenced inside this class.
    CELL_TYPES = {'lstm': nn.LSTMCell, 'gru': nn.GRUCell, 'rnn': nn.RNNCell}
    # Accepted values of ``args.question_merge``.
    _QUESTION_MERGE_MODES = ('avg', 'self_attn')

    def __init__(self, args, normalize=True):
        """Build all sub-modules from a config object.

        Args:
            args: config namespace. Attributes read here: question_merge,
                vocab_size, embedding_dim, use_qemb, num_features,
                hidden_size, doc_layers, question_layers, dropout_rnn,
                dropout_rnn_output, concat_rnn_layers, rnn_type and
                rnn_padding. ``forward`` additionally reads dropout_emb.
            normalize: forwarded unchanged to both ``layers.BilinearSeqAttn``
                span scorers. Its exact effect is defined in ``layers``;
                presumably it toggles normalization of the scores (confirm
                there).

        Raises:
            NotImplementedError: if ``args.question_merge`` is neither
                'avg' nor 'self_attn'.
            KeyError: if ``args.rnn_type`` is not a key of RNN_TYPES.
        """
        super(RnnDocReader, self).__init__()
        # Store config
        self.args = args

        # Reject an unsupported merge mode before allocating any parameters.
        if args.question_merge not in self._QUESTION_MERGE_MODES:
            raise NotImplementedError(
                'question_merge = %s' % args.question_merge)

        # Word embeddings. Index 0 is the padding index (torch excludes it
        # from gradient updates). Assumes args.vocab_size already counts the
        # padding token; TODO confirm against the vocabulary builder.
        self.embedding = nn.Embedding(args.vocab_size,
                                      args.embedding_dim,
                                      padding_idx=0)

        # Attention of document embeddings over question embeddings; forward
        # appends its embedding_dim-sized output to the document RNN input.
        if args.use_qemb:
            self.qemb_match = layers.SeqAttnMatch(args.embedding_dim)

        # Document RNN input width = word embedding + manual features
        # (+ aligned question embedding when use_qemb). This must match the
        # torch.cat in forward.
        doc_input_size = args.embedding_dim + args.num_features
        if args.use_qemb:
            doc_input_size += args.embedding_dim

        rnn_type = self.RNN_TYPES[args.rnn_type]

        # RNN document encoder
        self.doc_rnn = layers.StackedBRNN(
            input_size=doc_input_size,
            hidden_size=args.hidden_size,
            num_layers=args.doc_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=rnn_type,
            padding=args.rnn_padding,
        )

        # RNN question encoder (word embeddings only, no extra features)
        self.question_rnn = layers.StackedBRNN(
            input_size=args.embedding_dim,
            hidden_size=args.hidden_size,
            num_layers=args.question_layers,
            dropout_rate=args.dropout_rnn,
            dropout_output=args.dropout_rnn_output,
            concat_layers=args.concat_rnn_layers,
            rnn_type=rnn_type,
            padding=args.rnn_padding,
        )

        # Output sizes of the RNN encoders:
        # - 2 * hidden_size, presumably one hidden_size per direction
        #   (assumes StackedBRNN is bidirectional; confirm in layers);
        # - multiplied by the layer count when layer outputs are concatenated.
        doc_hidden_size = 2 * args.hidden_size
        question_hidden_size = 2 * args.hidden_size
        if args.concat_rnn_layers:
            doc_hidden_size *= args.doc_layers
            question_hidden_size *= args.question_layers

        # Question merging: 'self_attn' needs a learned scorer; 'avg' uses
        # parameter-free uniform weights in forward.
        if args.question_merge == 'self_attn':
            self.self_attn = layers.LinearSeqAttn(question_hidden_size)

        # Bilinear attention for span start/end
        self.start_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )
        self.end_attn = layers.BilinearSeqAttn(
            doc_hidden_size,
            question_hidden_size,
            normalize=normalize,
        )

    def forward(self, x1, x1_c, x1_f, x1_mask, x2, x2_c, x2_f, x2_mask):
        """Score document positions as answer start and end.

        Inputs:
            x1 = document word indices [batch * len_d]
            x1_c = not read by this model (presumably document character
                indices kept for a shared call signature; confirm)
            x1_f = document word features [batch * len_d * nfeat];
                concatenated to the embeddings, used only when
                args.num_features > 0
            x1_mask = document padding mask [batch * len_d]
            x2 = question word indices [batch * len_q]
            x2_c = not read by this model (see x1_c)
            x2_f = not read by this model
            x2_mask = question padding mask [batch * len_q]

        Returns:
            (start_scores, end_scores): the outputs of the start and end
            ``layers.BilinearSeqAttn`` scorers, presumably [batch * len_d]
            each (confirm in layers).
        """
        # Embed both document and question
        x1_emb = self.embedding(x1)
        x2_emb = self.embedding(x2)

        # Dropout on embeddings (only active when self.training is True)
        if self.args.dropout_emb > 0:
            x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb,
                                           training=self.training)
            x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb,
                                           training=self.training)

        # Form document encoding inputs; the order here must match the
        # doc_input_size computation in __init__.
        drnn_input = [x1_emb]
        # Add attention-weighted question representation
        if self.args.use_qemb:
            x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask)
            drnn_input.append(x2_weighted_emb)
        # Add manual features
        if self.args.num_features > 0:
            drnn_input.append(x1_f)

        # Encode document with RNN (inputs joined along the feature dim 2)
        doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask)

        # Encode question with RNN, then merge its hiddens into one vector
        question_hiddens = self.question_rnn(x2_emb, x2_mask)
        if self.args.question_merge == 'avg':
            q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask)
        else:
            # 'self_attn' -- the only other mode accepted by __init__.
            q_merge_weights = self.self_attn(question_hiddens, x2_mask)
        question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights)

        # Predict start and end positions
        start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask)
        end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask)
        return start_scores, end_scores