Improve prediction time (predict_max) #73

Open
PonteIneptique opened this issue Aug 7, 2020 · 12 comments
@PonteIneptique
Contributor

PonteIneptique commented Aug 7, 2020

The current prediction time is quite slow; we agree that there might be room for improvement.

After taking a good look at it, it seemed clear that we were still computing on items that technically no longer needed it (strings that have reached EOS).

I propose here my refactor of the predict_max function, which stops computing over elements that have reached EOS. There is probably still room for improvement.

For a group of 19 sentences over 100 iterations
Average tagging time with default: 0.556127781867981 s
Median tagging time with default: 0.5420029163360596
Total tagging time with default: 55.612778186798096 s

For a group of 19 sentences over 100 iterations
Average tagging time with new: 0.4061899709701538 s
Median tagging time with new: 0.40130531787872314
Total tagging time with new: 40.61899709701538 s

That is a ~27% reduction of the whole tagging time (lemma only).
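
For reference, here is a minimal, self-contained sketch of the bookkeeping this relies on (toy tokens and made-up lengths, not the pie API): after every decoding step, rows that emitted EOS are dropped from the working tensors, and a small `indexes` map records which surviving tensor row belongs to which original hypothesis.

```python
import torch

EOS = 0
TOK = 1

# Hypothetical target lengths for 4 items; the toy "decoder" emits TOK until
# an item's length is reached, then EOS.
remaining = torch.tensor([3, 1, 5, 2])
batch = remaining.size(0)

hyps = [[] for _ in range(batch)]       # one hypothesis per original item
indexes = {i: i for i in range(batch)}  # tensor row -> original hypothesis index

while remaining.numel() > 0:
    # toy step: items with remaining > 0 emit TOK, the others emit EOS
    inp = torch.where(remaining > 0, torch.tensor(TOK), torch.tensor(EOS))
    remaining = remaining - 1

    for row, tok in enumerate(inp.tolist()):
        hyps[indexes[row]].append(tok)

    # rows that did not emit EOS keep being decoded; the rest are dropped
    keep = torch.nonzero(inp != EOS, as_tuple=True)[0]
    if keep.numel() == 0:
        break

    # remap surviving rows to their original hypothesis indices
    indexes = {new_row: indexes[old_row] for new_row, old_row in enumerate(keep.tolist())}
    remaining = remaining[keep]
    # in the real predict_max the same indexing shrinks inp, context, lengths,
    # hidden and enc_outs

print(hyps)  # [[1, 1, 1, 0], [1, 0], [1, 1, 1, 1, 1, 0], [1, 1, 0]]
```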

@PonteIneptique
Contributor Author

PonteIneptique commented Aug 7, 2020

Evaluate PredictMax.ipynb.zip

Improving predict_max

The idea is to reduce computation at prediction time: the current system keeps predicting for tokens even after they have reached EOS, so every token keeps being computed on until the longest prediction is done.

Setup

```python
import torch
import torch.quantization
import torch.nn as nn
import copy
import os
import time

from pie.models.decoder import AttentionalDecoder
import torch.nn.functional as F
from pie.models import BaseModel
from pie.tagger import Tagger
from pie.data import Dataset, Reader
from pie.settings import load_default_settings, settings_from_file

TINY = 1e-8


DEVICE = "cpu"

import copy

old = copy.deepcopy(AttentionalDecoder.predict_max)

def load_and_monkey_patch(patch = None):
    if not patch:
        patch = old
        
    AttentionalDecoder.predict_max = patch    
    tagger = Tagger()
    tagger.add_model("models/Final-Latin-Lemma-H384-C700-lemma-2020_08_06-18_44_24.tar", "lemma")
    NormalModel = tagger.models[0][0]
    settings = NormalModel._settings
    settings.device = DEVICE
    settings.shuffle = False    # avoid shuffling
    return tagger, NormalModel

sentences = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. 
Phasellus dolor sapien, laoreet non turpis eget, tincidunt commodo magna. Duis at dapibus ipsum. 
Etiam fringilla et magna sed vehicula. 
Nunc tristique eros non faucibus viverra. 
Sed dictum scelerisque tortor, eu ullamcorper odio. 
Aenean fermentum a urna quis tempus. 
Maecenas imperdiet est a nisi pellentesque dictum. 
Maecenas ac hendrerit ante. Vestibulum eleifend nulla at vulputate sagittis. 
Maecenas sed magna diam sed facilisis tempus ipsum, nec mattis elit tincidunt lobortis Phasellus vel ex lorem nulla nunc odio, tempor non consequat in, luctus elementum dolor. 
Nullam tincidunt purus vel lorem placerat, ac pulvinar turpis sodales. 
Sed eget urna ac quam cursus porta. 
Pellentesque luctus aliquet sem, a egestas purus finibus ac. 
Mauris nec mauris non metus tempor faucibus non in est. 
Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. 
Proin tristique nulla nec purus iaculis, eu scelerisque mi egestas. 
In hac habitasse platea dictumst.
Ut placerat a neque eget aliquet. """.lower().replace("\n", "").replace(",", "").split(".")
Test = [
    sent.split()
    for sent in sentences
    if sent.split()
]
```

Testing function

```python
import time
import statistics

def test(patch, n_iters=10, sentences=Test, name="default"):
    tagger, _ = load_and_monkey_patch(patch)
    lengths = [len(x) for x in sentences]
    
    times = []
    for i in range(n_iters):
        start = time.time()
        out = tagger.tag(sentences, lengths)
        times.append(time.time() - start)
    
    print("====")
    print(f"For a group of {len(sentences)} sentences over {n_iters} iterations")
    print(f"Average tagging time with `{name}`: {sum(times) / n_iters} s")
    print(f"Median tagging time with `{name}`: {statistics.median(times)}")
    print(f"Total tagging time with `{name}`: {sum(times)} s")
    
    return out
```

New function

```python
def predict_max_debug(self, enc_outs, lengths,
                max_seq_len=20, bos=None, eos=None,
                context=None):
    """
    Decoding routine for inference with step-wise argmax procedure

    Parameters
    ===========
    enc_outs : tensor(src_seq_len x batch x hidden_size)
    context : tensor(batch x hidden_size), optional
    """
    eos = eos or self.label_encoder.get_eos()
    bos = bos or self.label_encoder.get_bos()
    hidden, batch, device = None, enc_outs.size(1), enc_outs.device
    inp = torch.zeros(batch, dtype=torch.int64, device=device) + bos
    hyps, scores = [], [0 for _ in range(batch)]
    
    # We store a conversion table from tensor row index to hypothesis index
    # (Tensor Index -> Hyp Index)
    indexes = {
        x: x for x in range(batch)
    }
    
    for _ in range(max_seq_len):

        # prepare input
        #    Context is NEVER changed after the method has been called
        
        emb = self.embs(inp)
        if context is not None:
            emb = torch.cat([emb, context], dim=1)
            
        # run rnn
        # Expand embeddings from a 2-D tensor to a 3-D tensor (1, word number, emb size (+ context))
        emb = emb.unsqueeze(0)
        
        # Hidden is always reused
        #  -> Hidden is (1, word number, emb size)
        outs, hidden = self.rnn(emb, hidden)
        
        outs, _ = self.attn(outs, enc_outs, lengths)
        outs = self.proj(outs).squeeze(0)
        
        # get logits
        probs = F.log_softmax(outs, dim=1)
        
        # sample and accumulate
        score, inp = probs.max(1)
        
        # We create a mask of the values that are not ending the string
        non_eos = (inp != eos)
        
        # `keep` holds the indices of the items we choose to keep (i.e., not ending with EOS)
        keep = torch.nonzero(non_eos, as_tuple=True)[0]
        
        # add new chars to hypotheses
        # We prepare a list the size of the full batch, pre-filled with EOS
        # Once done, we replace the values using the table of equivalences
        to_append = [eos for _ in range(batch)]
        new_scores = [0 for _ in range(batch)]
        
        for ind, (hyp, sc) in enumerate(zip(inp.tolist(), score.tolist())):
            to_append[indexes[ind]] = hyp
            if hyp != eos:
                scores[indexes[ind]] += sc
            
        hyps.append(to_append)
        
        # If no element is still non-EOS, prediction is over
        if not non_eos.any():
            break
        
        # We update the indexes so that tensor "row" index maps to the correct 
        #   hypothesis value
        indexes = {elem: indexes[former_index] for elem, former_index in enumerate(keep.tolist())}
        # print(indexes)
        
        # Keep only the rows listed in `keep`, dropping elements that reached EOS
        inp = inp[keep]
        context = context[keep]
        lengths = lengths[keep]
        
        # Hidden is 3D with 1 in first dimension
        hidden = hidden.squeeze(0)[keep].unsqueeze(0)
        
        # enc_outs is (seq x batch x size), so we transpose, index and transpose back.
        #   seq_len is supposed to equal max(lengths), but if the element with the maximum
        #   length gets dropped, we need to reduce the first dimension of enc_outs as well
        max_seq_len = lengths.max()
        
        enc_outs = enc_outs[:max_seq_len].transpose(0, 1)[keep].transpose(0, 1)
        
    
    hyps = [self.label_encoder.stringify(hyp) for hyp in zip(*hyps)]
    scores = [s/(len(hyp) + TINY) for s, hyp in zip(scores, hyps)]

    return hyps, scores
```

new_out = test(patch=predict_max_debug, sentences=Test, name="new", n_iters=100)
====
For a group of 19 sentences over 100 iterations
Average tagging time with `new`: 0.4061899709701538 s
Median tagging time with `new`: 0.40130531787872314
Total tagging time with `new`: 40.61899709701538 s

Current function

```python
def predict_max(self, enc_outs, lengths,
                max_seq_len=20, bos=None, eos=None,
                context=None):
    """
    Decoding routine for inference with step-wise argmax procedure

    Parameters
    ===========
    enc_outs : tensor(src_seq_len x batch x hidden_size)
    context : tensor(batch x hidden_size), optional
    """
    eos = eos or self.label_encoder.get_eos()
    bos = bos or self.label_encoder.get_bos()
    hidden, batch, device = None, enc_outs.size(1), enc_outs.device
    mask = torch.ones(batch, dtype=torch.int64, device=device)
    inp = torch.zeros(batch, dtype=torch.int64, device=device) + bos
    hyps, scores = [], 0

    for _ in range(max_seq_len):
        if mask.sum().item() == 0:
            break

        # prepare input
        emb = self.embs(inp)
        if context is not None:
            emb = torch.cat([emb, context], dim=1)
        # run rnn
        emb = emb.unsqueeze(0)
        outs, hidden = self.rnn(emb, hidden)
        outs, _ = self.attn(outs, enc_outs, lengths)
        outs = self.proj(outs).squeeze(0)
        # get logits
        probs = F.log_softmax(outs, dim=1)
        # sample and accumulate
        score, inp = probs.max(1)
        hyps.append(inp.tolist())
        mask = mask * (inp != eos).long()
        score = score.cpu()
        score[mask == 0] = 0
        scores += score

    
    hyps = [self.label_encoder.stringify(hyp) for hyp in zip(*hyps)]
    scores = [s/(len(hyp) + TINY) for s, hyp in zip(scores.tolist(), hyps)]

    return hyps, scores
```

former_out = test(patch=predict_max, sentences=Test, name="default", n_iters=100)
====
For a group of 19 sentences over 100 iterations
Average tagging time with `default`: 0.556127781867981 s
Median tagging time with `default`: 0.5420029163360596
Total tagging time with `default`: 55.612778186798096 s
new_out == former_out
True

@PonteIneptique
Contributor Author

This time I'll wait for your feedback @emanjavacas

@PonteIneptique
Contributor Author

Note that the improvements I measured are highly dependent on the seq_len and on the disparity across prediction sizes.
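
As a rough way to see why (a back-of-envelope sketch with made-up lengths, ignoring per-step overhead and the cost of attention over enc_outs): without pruning, every row is decoded until the longest prediction in the batch finishes, while with pruning each row stops at its own EOS.

```python
# Hypothetical per-item prediction lengths (incl. EOS) for one batch -- not measured data.
lengths = [4, 6, 7, 15]

full = len(lengths) * max(lengths)  # old predict_max: every row runs until the longest is done
pruned = sum(lengths)               # pruned version: each row stops at its own EOS

print(full, pruned, f"{pruned / full:.0%}")  # 60 32 53%
```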

@emanjavacas
Owner

emanjavacas commented Aug 7, 2020 via email

@PonteIneptique
Contributor Author

I tested on CPU for now, which is generally what users have for inference or web services.

@PonteIneptique
Contributor Author

PonteIneptique commented Aug 7, 2020

For a quick and dirty evaluation, I ran the same model on README.md and went from 13 seconds to 7 seconds...

PonteIneptique added a commit to PonteIneptique/pie that referenced this issue Aug 7, 2020
…ched EOS for prediction (See emanjavacas#73)

@PonteIneptique
Contributor Author

I actually opened a PR so that you can pull and test :)

@emanjavacas
Owner

emanjavacas commented Aug 7, 2020 via email

@PonteIneptique
Contributor Author

PonteIneptique commented Aug 7, 2020

Actually, on real-world data, the improvements are substantial:

For this test, CPU = i7 6700k, GPU = 1080

| | Unit | New (CPU) | Actual (CPU) | Diff (CPU) | New (GPU) | Actual (GPU) | Diff (GPU) |
|---|---|---|---|---|---|---|---|
| Aeneid (9843 units) | sec | 58 | 81 | 71.60% | 19 | 23 | 82.61% |
| | it/sec | 169.64 | 120.45 | 140.84% | 495.12 | 417.56 | 118.57% |
| Priapea (127 units) | sec | 3 | 5 | 60.00% | 1 | 2 | 50.00% |
| | it/sec | 39.98 | 24.4 | 163.85% | 64.25 | 61.3 | 104.81% |
| Martial (1600 units) | sec | 44 | 76 | 57.89% | 8 | 9 | 88.89% |
| | it/sec | 36.06 | 20.88 | 172.70% | 187.6 | 163.6 | 114.67% |

For this test: i5 8265U, on a laptop (base freq 1.6 GHz, 4 cores, 8 threads)

| | Unit | New | Actual | Diff |
|---|---|---|---|---|
| Aeneid | sec | 176 | 245 | 71.84% |
| | it/sec | 55.74 | 40.11 | 138.97% |
| Priapea | sec | 8 | 15 | 53.33% |
| | it/sec | 14.26 | 8.1 | 176.05% |
| Martial | sec | 125 | 224 | 55.80% |
| | it/sec | 12.77 | 7.12 | 179.35% |

Diff is computed as new / actual: for it/sec, anything above 100% is an improvement; for sec, anything below 100% is an improvement.
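
To make the Diff column explicit (values taken from the Aeneid CPU row above):

```python
# Diff = new / actual
print(f"{58 / 81:.2%}")          # 71.60%  -- sec: lower is better
print(f"{169.64 / 120.45:.2%}")  # 140.84% -- it/sec: higher is better
```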

It's actually better with the files
test.zip

@PonteIneptique
Contributor Author

On CPU it consistently beats the current version by a huge margin, and on GPU by a relatively small one (from ~4 to ~18%).

@PonteIneptique
Contributor Author

PonteIneptique commented Aug 8, 2020

> I am quite busy right now,

I did not address this earlier, but please take the time you need. It's August; you are probably on vacation, and if you are not, I technically am. So it can wait ;)

I finished the PR by adding a lot of comments and making sure variable names are clear. I also added/edited my comments here with some more information.

@PonteIneptique
Contributor Author

I know you are busy, but if you are merging PRs, #74 might be somewhere you could stop by if you have time :) Those are some nice improvements :)

PonteIneptique added a commit to lascivaroma/PaPie that referenced this issue May 19, 2021
* (improvement/AttentionalDecoder.predict_max) Avoid computation on reached EOS for prediction (See emanjavacas#73)

* (Improvement/AttentionalDecoder.predict_max) Lots of comment and some performance fixes (avoid transposing when non needed)

* Update torch

* (improvement/argmax.decoder) Improved a little more readability

* (improvement/decode) Removed loop by tensor use for prediction

* Do not update sneakily the torch requirements