Improve prediction time (predict_max) #73
Improving predict_max

The idea is to reduce computation at prediction time: the current system keeps predicting for every token even after its EOS has been reached, so all tokens are computed on until the longest prediction in the batch is finished.

Setup

import torch
import torch.quantization
import torch.nn as nn
import copy
import os
import time
from pie.models.decoder import AttentionalDecoder
import torch.nn.functional as F
from pie.models import BaseModel
from pie.tagger import Tagger
from pie.data import Dataset, Reader
from pie.settings import load_default_settings, settings_from_file
TINY = 1e-8
DEVICE = "cpu"
old = copy.deepcopy(AttentionalDecoder.predict_max)
def load_and_monkey_patch(patch=None):
if not patch:
patch = old
AttentionalDecoder.predict_max = patch
tagger = Tagger()
tagger.add_model("models/Final-Latin-Lemma-H384-C700-lemma-2020_08_06-18_44_24.tar", "lemma")
NormalModel = tagger.models[0][0]
settings = NormalModel._settings
settings.device = DEVICE
settings.shuffle = False # avoid shuffling
return tagger, NormalModel
sentences = """Lorem ipsum dolor sit amet, consectetur adipiscing elit.
Phasellus dolor sapien, laoreet non turpis eget, tincidunt commodo magna. Duis at dapibus ipsum.
Etiam fringilla et magna sed vehicula.
Nunc tristique eros non faucibus viverra.
Sed dictum scelerisque tortor, eu ullamcorper odio.
Aenean fermentum a urna quis tempus.
Maecenas imperdiet est a nisi pellentesque dictum.
Maecenas ac hendrerit ante. Vestibulum eleifend nulla at vulputate sagittis.
Maecenas sed magna diam sed facilisis tempus ipsum, nec mattis elit tincidunt lobortis Phasellus vel ex lorem nulla nunc odio, tempor non consequat in, luctus elementum dolor.
Nullam tincidunt purus vel lorem placerat, ac pulvinar turpis sodales.
Sed eget urna ac quam cursus porta.
Pellentesque luctus aliquet sem, a egestas purus finibus ac.
Mauris nec mauris non metus tempor faucibus non in est.
Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos.
Proin tristique nulla nec purus iaculis, eu scelerisque mi egestas.
In hac habitasse platea dictumst.
Ut placerat a neque eget aliquet. """.lower().replace("\n", "").replace(",", "").split(".")
Test = [
sent.split()
for sent in sentences
if sent.split()
]

Testing function

import time
import statistics
def test(patch, n_iters=10, sentences=Test, name="default"):
tagger, _ = load_and_monkey_patch(patch)
lengths = [len(x) for x in sentences]
times = []
for i in range(n_iters):
start = time.time()
out = tagger.tag(sentences, lengths)
times.append(time.time() - start)
print("====")
print(f"For a group of {len(sentences)} sentences over {n_iters} iterations")
print(f"Average tagging time with `{name}`: {sum(times) / n_iters} s")
print(f"Median tagging time with `{name}`: {statistics.median(times)}")
print(f"Total tagging time with `{name}`: {sum(times)} s")
return out

New function

def predict_max_debug(self, enc_outs, lengths,
max_seq_len=20, bos=None, eos=None,
context=None):
"""
Decoding routine for inference with step-wise argmax procedure
Parameters
===========
enc_outs : tensor(src_seq_len x batch x hidden_size)
context : tensor(batch x hidden_size), optional
"""
eos = eos or self.label_encoder.get_eos()
bos = bos or self.label_encoder.get_bos()
hidden, batch, device = None, enc_outs.size(1), enc_outs.device
inp = torch.zeros(batch, dtype=torch.int64, device=device) + bos
hyps, scores = [], [0 for _ in range(batch)]
# Conversion table mapping tensor row index -> original hypothesis index
indexes = {
x: x for x in range(batch)
}
for _ in range(max_seq_len):
# prepare input
# context is never recomputed during decoding; it is only filtered down to the surviving rows
emb = self.embs(inp)
if context is not None:
emb = torch.cat([emb, context], dim=1)
# run rnn
# Reshape the embeddings from 2-D to 3-D: (1, active batch, emb size (+ context))
emb = emb.unsqueeze(0)
# hidden is carried over between decoding steps
# -> hidden is (1, active batch, hidden size)
outs, hidden = self.rnn(emb, hidden)
outs, _ = self.attn(outs, enc_outs, lengths)
outs = self.proj(outs).squeeze(0)
# get logits
probs = F.log_softmax(outs, dim=1)
# sample and accumulate
score, inp = probs.max(1)
# Mask of rows whose predicted token is not EOS
non_eos = (inp != eos)
# keep holds the indices of the rows we keep (i.e. those that did not emit EOS)
keep = torch.nonzero(non_eos, as_tuple=True)[0]
# add new chars to hypotheses
# Prepare a batch-sized list pre-filled with EOS, then write each predicted
# char back into its original hypothesis slot using the conversion table
to_append = [eos for _ in range(batch)]
for ind, (hyp, sc) in enumerate(zip(inp.tolist(), score.tolist())):
to_append[indexes[ind]] = hyp
if hyp != eos:
scores[indexes[ind]] += sc
hyps.append(to_append)
# If every remaining row emitted EOS, decoding is finished
if not non_eos.any():
break
# Update the conversion table so that each new tensor row index maps back to
# the correct original hypothesis slot
indexes = {new_index: indexes[former_index] for new_index, former_index in enumerate(keep.tolist())}
# Filter every step-dependent tensor down to the surviving rows
inp = inp[keep]
if context is not None:
    context = context[keep]
lengths = lengths[keep]
# Hidden is 3D with 1 in first dimension
hidden = hidden.squeeze(0)[keep].unsqueeze(0)
# enc_outs is (src_seq_len, batch, size), so we transpose, index, and transpose back
# Seq_len is supposed to be equal to max(lengths), but if the maximum length is popped
# We need to reduce the dimension of enc_outs as well
max_src_len = lengths.max()
enc_outs = enc_outs[:max_src_len].transpose(0, 1)[keep].transpose(0, 1)
hyps = [self.label_encoder.stringify(hyp) for hyp in zip(*hyps)]
scores = [s/(len(hyp) + TINY) for s, hyp in zip(scores, hyps)]
return hyps, scores
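To make the row-dropping bookkeeping above easier to follow, here is a minimal standalone sketch with toy values (the numbers are made up and independent of the model; only torch is assumed):

import torch

# Toy values: 4 active rows, two of which just predicted EOS
eos = 0
inp = torch.tensor([5, eos, 7, eos])
# indexes maps the current tensor row -> original hypothesis slot
indexes = {0: 0, 1: 1, 2: 2, 3: 3}

# Rows that did not emit EOS survive to the next step
keep = torch.nonzero(inp != eos, as_tuple=True)[0]  # tensor([0, 2])
inp = inp[keep]                                     # tensor([5, 7])

# Remap so that new row 0 -> old slot 0 and new row 1 -> old slot 2
indexes = {new: indexes[old] for new, old in enumerate(keep.tolist())}
print(indexes)  # {0: 0, 1: 2}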
new_out = test(patch=predict_max_debug, sentences=Test, name="new", n_iters=100)
Current function

def predict_max(self, enc_outs, lengths,
max_seq_len=20, bos=None, eos=None,
context=None):
"""
Decoding routine for inference with step-wise argmax procedure
Parameters
===========
enc_outs : tensor(src_seq_len x batch x hidden_size)
context : tensor(batch x hidden_size), optional
"""
eos = eos or self.label_encoder.get_eos()
bos = bos or self.label_encoder.get_bos()
hidden, batch, device = None, enc_outs.size(1), enc_outs.device
mask = torch.ones(batch, dtype=torch.int64, device=device)
inp = torch.zeros(batch, dtype=torch.int64, device=device) + bos
hyps, scores = [], 0
for _ in range(max_seq_len):
if mask.sum().item() == 0:
break
# prepare input
emb = self.embs(inp)
if context is not None:
emb = torch.cat([emb, context], dim=1)
# run rnn
emb = emb.unsqueeze(0)
outs, hidden = self.rnn(emb, hidden)
outs, _ = self.attn(outs, enc_outs, lengths)
outs = self.proj(outs).squeeze(0)
# get logits
probs = F.log_softmax(outs, dim=1)
# sample and accumulate
score, inp = probs.max(1)
hyps.append(inp.tolist())
mask = mask * (inp != eos).long()
score = score.cpu()
score[mask == 0] = 0
scores += score
hyps = [self.label_encoder.stringify(hyp) for hyp in zip(*hyps)]
scores = [s/(len(hyp) + TINY) for s, hyp in zip(scores.tolist(), hyps)]
return hyps, scores
former_out = test(patch=predict_max, sentences=Test, name="default", n_iters=100)
new_out == former_out
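Note that a strict new_out == former_out can report False on tiny floating-point differences in the averaged scores even when the predictions themselves are identical. A possible, more tolerant check (a sketch only: approx_equal is a hypothetical helper, and it only assumes the outputs are nested lists/tuples of strings and floats):

import math

def approx_equal(a, b, tol=1e-6):
    # Recursively compare nested lists/tuples, allowing small drift on floats
    if isinstance(a, float) and isinstance(b, float):
        return math.isclose(a, b, rel_tol=tol, abs_tol=tol)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(approx_equal(x, y, tol) for x, y in zip(a, b))
    return a == b

print(approx_equal(new_out, former_out))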
This time I'll wait for your feedback, @emanjavacas.
Note that the improvements I noted are highly dependent on the seq_len and on the disparity across prediction sizes.
Are you testing this on a gpu?
I tested on CPU for now, which is generally what users have for inference or web services.
For a quick-and-dirty evaluation, I ran the same model on README.md: it went from 13 seconds to 7 seconds...
I actually PRed so that you can pull and test :)
I am quite busy right now, so it will take some time until I can look at this. My impression is that on GPU this will actually be slower, since parallelizing there is very cheap. I am not opposed to the idea of optimizing at this level if people can profit from it (even if the gains aren't really that significant), but the readability of the code would have to improve, because this is a bug-sensitive part of the codebase that I eventually have to come back to and modify.
Actually, on real-world data, the improvements are substantial. For this test: CPU = i7 6700K, GPU = 1080.
For this test: an i5 8265U on a laptop (base frequency 1.6 GHz, 4 cores, 8 threads).
For it/sec, above 100 % means an improvement; for sec, below 100 % means an improvement. It's actually better with the files.
The new version consistently beats CPU by a huge margin, and beats GPU by a relatively small margin (from 4 to 18 %).
I did not address this, but please take the time you need. It's August; you are probably on vacation, and if you are not, I technically am, so it can wait ;) I finished the PR by adding a lot of comments and making sure the variable names are clear. I also added/edited my comments to add some more information.
I know you are busy, but if you are merging PRs, #74 might be somewhere you could stop by, if you have time :) Those are some nice improvements :)
* (improvement/AttentionalDecoder.predict_max) Avoid computation on reached EOS for prediction (see emanjavacas#73)
* (Improvement/AttentionalDecoder.predict_max) Lots of comments and some performance fixes (avoid transposing when not needed)
* Update torch
* (improvement/argmax.decoder) Improved readability a little more
* (improvement/decode) Removed the loop by using tensor operations for prediction
* Do not sneakily update the torch requirements
The current prediction time is quite slow; we agree that there might be room for improvement.
After having a good look at it, it seemed clear that we were computing on items that technically did not need to be computed upon any further (strings that had reached EOS).
I propose here my refactor of the predict_max function, which stops computing over elements that have reached EOS. There is probably still room for improvement here.
For a group of 19 sentences over 100 iterations
Average tagging time with `default`: 0.556127781867981 s
Median tagging time with `default`: 0.5420029163360596
Total tagging time with `default`: 55.612778186798096 s

For a group of 19 sentences over 100 iterations
Average tagging time with `new`: 0.4061899709701538 s
Median tagging time with `new`: 0.40130531787872314
Total tagging time with `new`: 40.61899709701538 s

That is about 27 % less time for the whole tagging (lemma only).
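For reference, the 27 % figure can be recomputed from the totals above:

# Relative reduction in total tagging time
print(1 - 40.61899709701538 / 55.612778186798096)  # -> about 0.27, i.e. a 27 % reduction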