
Commit 720389e

(improvement/AttentionalDecoder.predict_max) Avoid computation on reached EOS for prediction (See emanjavacas#73)

The current prediction time is quite slow, and we agree that there might be room for improvement.

After taking a good look at it, it became clear that we were still computing over items that no longer needed it: strings that have already reached EOS.

I propose here my refactor of the predict_max function, which stops computing over elements that have reached EOS. There is probably still room for improvement.
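To make the bookkeeping easier to follow before reading the diff, here is a minimal, self-contained sketch (toy values, not the commit's code) of the two moves the refactor combines: dropping finished rows from the working tensors, and remapping the surviving row indices back to their original hypothesis slots.

import torch

eos = 2
inp = torch.tensor([5, 2, 7, 2])     # last predictions; rows 1 and 3 just emitted EOS
indexes = {0: 0, 1: 1, 2: 2, 3: 3}   # tensor row -> original hypothesis slot

# keep only the rows that did not emit EOS
keep = torch.nonzero(inp != eos, as_tuple=True)[0]   # tensor([0, 2])
inp = inp[keep]                                      # tensor([5, 7])

# remap: surviving row 0 still feeds hypothesis 0, row 1 now feeds hypothesis 2
indexes = {new: indexes[old] for new, old in enumerate(keep.tolist())}
print(indexes)  # {0: 0, 1: 2}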

For a group of 19 sentences over 100 iterations:
Average tagging time with default: 0.556127781867981 s
Median tagging time with default: 0.5420029163360596 s
Total tagging time with default: 55.612778186798096 s

For a group of 19 sentences over 100 iterations:
Average tagging time with new: 0.4061899709701538 s
Median tagging time with new: 0.40130531787872314 s
Total tagging time with new: 40.61899709701538 s

That is a 27 % reduction in overall tagging time (lemma only).
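For reference, numbers like these can be produced with a simple harness along the following lines (a sketch; the `tag` callable and sentence list stand in for pie's actual tagging entry point):

import time
import statistics

def benchmark(tag, sentences, iterations=100):
    # run the tagger repeatedly over the same sentences and collect wall-clock times
    times = []
    for _ in range(iterations):
        start = time.time()
        tag(sentences)
        times.append(time.time() - start)
    print("Average tagging time:", statistics.mean(times), "s")
    print("Median tagging time:", statistics.median(times), "s")
    print("Total tagging time:", sum(times), "s")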
PonteIneptique committed Aug 7, 2020
1 parent 230e0e6 commit 720389e
Showing 2 changed files with 64 additions and 12 deletions.
1 change: 1 addition & 0 deletions pie/models/attention.py
@@ -160,6 +160,7 @@ def forward(self, dec_out, enc_outs, lengths):

         # apply source length mask
         mask = torch_utils.make_length_mask(lengths)
+        # (batch x src_seq_len) => (trg_seq_len x batch x src_seq_len)
         mask = mask.unsqueeze(0).expand_as(weights)
         # weights = weights * mask.float()
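For context on the hunk above: torch_utils.make_length_mask builds the (batch x src_seq_len) mask that the new comment then describes being broadcast over target steps. A minimal sketch of an equivalent helper, assuming it marks valid (non-padding) source positions — pie's actual implementation may differ:

import torch

def make_length_mask(lengths):
    # (batch,) lengths -> (batch x max_len) boolean mask,
    # True at valid positions, False at padding
    max_len = lengths.max().item()
    return torch.arange(max_len, device=lengths.device).unsqueeze(0) < lengths.unsqueeze(1)

print(make_length_mask(torch.tensor([3, 1])))
# tensor([[ True,  True,  True],
#         [ True, False, False]])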
75 changes: 63 additions & 12 deletions pie/models/decoder.py
Expand Up @@ -315,8 +315,8 @@ def loss(self, logits, targets):
         return loss

     def predict_max(self, enc_outs, lengths,
-                    max_seq_len=20, bos=None, eos=None,
-                    context=None):
+                max_seq_len=20, bos=None, eos=None,
+                context=None):
         """
         Decoding routine for inference with step-wise argmax procedure
@@ -328,35 +328,86 @@ def predict_max(self, enc_outs, lengths,
         eos = eos or self.label_encoder.get_eos()
         bos = bos or self.label_encoder.get_bos()
         hidden, batch, device = None, enc_outs.size(1), enc_outs.device
-        mask = torch.ones(batch, dtype=torch.int64, device=device)
         inp = torch.zeros(batch, dtype=torch.int64, device=device) + bos
-        hyps, scores = [], 0
+        hyps, scores = [], [0 for _ in range(batch)]
+
+        # Conversion table from tensor row index to hypothesis index:
+        # Tensor Index -> Hyp Index
+        indexes = {
+            x: x for x in range(batch)
+        }

         for _ in range(max_seq_len):
-            if mask.sum().item() == 0:
-                break
-
             # prepare input
+            # Context is NEVER changed after the method has been called
             emb = self.embs(inp)
             if context is not None:
                 emb = torch.cat([emb, context], dim=1)

             # run rnn
+            # Turn the embeddings from a 2-D tensor into a 3-D tensor
+            # of shape (1 x word count x emb size (+ context))
             emb = emb.unsqueeze(0)
+
+            # Hidden is always reused
+            # -> Hidden is (1 x word count x hidden size)
             outs, hidden = self.rnn(emb, hidden)

             outs, _ = self.attn(outs, enc_outs, lengths)
             outs = self.proj(outs).squeeze(0)

             # get logits
             probs = F.log_softmax(outs, dim=1)

             # sample and accumulate
             score, inp = probs.max(1)
-            hyps.append(inp.tolist())
-            mask = mask * (inp != eos).long()
-            score = score.cpu()
-            score[mask == 0] = 0
-            scores += score
+
+            # Mask of the items that have not yet ended their string
+            non_eos = (inp != eos)
+
+            # keep holds the indices of the items we keep computing on
+            # (i.e. those that did not reach EOS)
+            keep = torch.nonzero(non_eos, as_tuple=True)[0]
+
+            # add new chars to hypotheses:
+            # prepare a list the size of the full output (pre-filled with EOS),
+            # then place each value using the table of equivalencies
+            to_append = [eos for _ in range(batch)]
+            for ind, (hyp, sc) in enumerate(zip(inp.tolist(), score.tolist())):
+                to_append[indexes[ind]] = hyp
+                if hyp != eos:
+                    scores[indexes[ind]] += sc
+
+            hyps.append(to_append)
+
+            # If every item has reached EOS, prediction is over
+            if True not in non_eos:
+                break
+
+            # Update the indexes so that each tensor "row" index maps to the
+            # correct hypothesis index
+            indexes = {elem: indexes[former_index]
+                       for elem, former_index in enumerate(keep.tolist())}
+
+            # Drop the elements that reached EOS from the input tensors
+            inp = inp[keep]
+            if context is not None:
+                # context may be None when no conditioning context was given
+                context = context[keep]
+            lengths = lengths[keep]
+
+            # Hidden is 3-D with 1 in the first dimension
+            hidden = hidden.squeeze(0)[keep].unsqueeze(0)
+
+            # enc_outs is (seq_len x batch x size), so we transpose, index and transpose back.
+            # seq_len is supposed to equal max(lengths), but if the longest element is popped
+            # we need to shrink the first dimension of enc_outs as well
+            max_seq_len = lengths.max()
+            enc_outs = enc_outs[:max_seq_len].transpose(0, 1)[keep].transpose(0, 1)

         hyps = [self.label_encoder.stringify(hyp) for hyp in zip(*hyps)]
-        scores = [s/(len(hyp) + TINY) for s, hyp in zip(scores.tolist(), hyps)]
+        scores = [s / (len(hyp) + TINY) for s, hyp in zip(scores, hyps)]

         return hyps, scores

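A quick check of the enc_outs reshuffling done at the end of the loop (toy shapes, not part of the commit): transposing to put batch first, selecting the kept rows, and transposing back preserves the (seq_len x batch x size) layout. Indexing the batch dimension directly would be equivalent and avoid the two transposes:

import torch

seq_len, batch, size = 5, 4, 3
enc_outs = torch.randn(seq_len, batch, size)
keep = torch.tensor([0, 2])  # rows still being decoded

kept = enc_outs.transpose(0, 1)[keep].transpose(0, 1)
print(kept.shape)  # torch.Size([5, 2, 3])

# direct batch-dimension indexing gives the same result
assert torch.equal(kept, enc_outs[:, keep])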
