
beam search #186

Closed
vince62s opened this issue Jun 5, 2018 · 29 comments

Comments

@vince62s

vince62s commented Jun 5, 2018

Hi guys,

Quick question for you:
are you using multiple threads when decoding with beam search, to parallelize over segments?

@emjotde
Member

emjotde commented Jun 5, 2018

Hi,
on the GPU, no. There you can adjust the batch size instead, for instance with:

--mini-batch 64 --maxi-batch 100 --mini-batch-sort src

This will translate 64 sentences at once, preload 100 mini-batches and bucket by source sentence length for batch packing.
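The batching described above can be sketched roughly as follows. This is a minimal illustration of the idea, not Marian's implementation: the function `make_batches` and its parameters are hypothetical stand-ins for `--mini-batch`, `--maxi-batch` and `--mini-batch-sort src`, and sentences are assumed to be already tokenized. Each sentence keeps its original position so that translations can be written back in input order.

```python
from typing import Iterable, Iterator, List, Tuple

Batch = List[Tuple[int, List[str]]]  # (original position, source tokens)

def make_batches(sentences: Iterable[List[str]],
                 mini_batch: int = 64,
                 maxi_batch: int = 100) -> Iterator[Batch]:
    """Hypothetical sketch: preload mini_batch * maxi_batch sentences,
    sort the pool by source length, then cut it into mini-batches."""
    pool: Batch = []
    pool_size = mini_batch * maxi_batch

    def emit(pool: Batch) -> Iterator[Batch]:
        # Similar-length sentences end up together, which reduces padding.
        pool.sort(key=lambda item: len(item[1]))
        for start in range(0, len(pool), mini_batch):
            yield pool[start:start + mini_batch]

    for idx, sent in enumerate(sentences):
        pool.append((idx, sent))
        if len(pool) == pool_size:
            yield from emit(pool)
            pool = []
    if pool:  # flush the last, partially filled pool
        yield from emit(pool)

# Four sentences of lengths 5, 1, 3, 2 with mini_batch=2, maxi_batch=2:
sents = [["a"] * n for n in (5, 1, 3, 2)]
print([[i for i, _ in b] for b in make_batches(sents, 2, 2)])  # [[1, 3], [2, 0]]
```

The larger the maxi-batch pool, the better the length bucketing, at the cost of reading more input before the first batch is translated.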

When you use the CPU, you can set (the = is required):

--cpu-threads=N

to translate sentences in parallel.
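A rough sketch of what "translating sentences in parallel" with N worker threads means, while keeping the output in input order. `translate_one` is a hypothetical placeholder for the actual per-sentence beam search, and this is not how Marian schedules its threads internally. Note that in CPython, CPU-bound pure-Python work is serialized by the GIL, so a real decoder would run native code inside each thread.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List

def translate_parallel(sentences: List[str],
                       translate_one: Callable[[str], str],
                       cpu_threads: int = 4) -> List[str]:
    # executor.map hands sentences to the worker threads but yields the
    # results in input order, so the output lines up with the source file.
    with ThreadPoolExecutor(max_workers=cpu_threads) as pool:
        return list(pool.map(translate_one, sentences))

# Stand-in "translator" that just reverses the words of each sentence:
src = ["ein kleiner Test", "hallo Welt"]
print(translate_parallel(src, lambda s: " ".join(reversed(s.split())), cpu_threads=2))
# ['Test kleiner ein', 'Welt hallo']
```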

@emjotde
Member

emjotde commented Jun 5, 2018

Actually, let me correct that. When using multiple GPUs, they will be used in parallel to translate multiple mini-batches.

@vince62s
Author

vince62s commented Jun 5, 2018

On the GPU, can you be more specific?
You batch sentences for the encoder/decoder/generator,
but the beam search itself over the output: is it done on the CPU? On the GPU? In parallel?

@emjotde
Member

emjotde commented Jun 5, 2018

We have a GPU mode and a CPU mode. In GPU mode almost everything happens on the GPU; we only record, in CPU memory, the indices used to select hypotheses for the next step. Top-K search happens on the GPU. This is one of the reasons we are so much faster at translation than anyone else.
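To make that division of labour concrete, here is a minimal CPU-only sketch of one beam-search step for a single sentence: a top-K over all (hypothesis, word) pairs, returning the back-pointers and word indices that are the only things that need to be kept around, plus the backtrace that rebuilds the best output at the end. Plain Python lists stand in for GPU tensors, end-of-sentence handling and length normalization are omitted, and `beam_step`/`backtrace` are hypothetical names, not Marian's code.

```python
import heapq
from typing import List, Tuple

def beam_step(beam_scores: List[float],
              log_probs: List[List[float]],
              beam_size: int) -> Tuple[List[float], List[int], List[int]]:
    """beam_scores[h]: running score of hypothesis h;
    log_probs[h][w]: log-probability of word w following hypothesis h."""
    candidates = ((score + lp, h, w)
                  for h, score in enumerate(beam_scores)
                  for w, lp in enumerate(log_probs[h]))
    # Top-K over the flattened (hypothesis x vocabulary) scores; on a GPU
    # this corresponds to a single top-k over one large score matrix.
    best = heapq.nlargest(beam_size, candidates, key=lambda c: c[0])
    scores = [s for s, _, _ in best]
    back_ptrs = [h for _, h, _ in best]  # which parent each survivor extends
    words = [w for _, _, w in best]      # which word it appended
    return scores, back_ptrs, words

def backtrace(history: List[Tuple[List[int], List[int]]], final_slot: int) -> List[int]:
    """Follow the recorded (back_ptrs, words) from the last step to the first."""
    out, slot = [], final_slot
    for back_ptrs, words in reversed(history):
        out.append(words[slot])
        slot = back_ptrs[slot]
    return out[::-1]

# Toy run with vocabulary {0, 1, 2} and beam size 2.
scores, history = [0.0], []
for step_log_probs in ([[-0.5, -1.0, -2.0]],
                       [[-2.0, -1.5, -3.0], [-0.1, -2.0, -2.5]]):
    scores, back_ptrs, words = beam_step(scores, step_log_probs, beam_size=2)
    history.append((back_ptrs, words))  # only these indices are recorded
print(backtrace(history, final_slot=0))  # [1, 0]
```

In the toy run the final winner extends the hypothesis that was only second-best after step 1, which is why the back-pointers must be recorded at every step.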

@emjotde
Member

emjotde commented Jun 7, 2018

Do you need any more information or can I close this?

@vince62s
Author

vince62s commented Jun 7, 2018

We also do the top-k search on the GPU, but obviously we are missing something.
You're doing a great job.
Cheers.

@vince62s vince62s closed this as completed Jun 7, 2018
@hieuhoang
Collaborator

What toolkit are you using? Some comparison numbers would be good, out of curiosity.

@vince62s
Author

vince62s commented Jun 7, 2018

Hi Hieu,
I am on https://github.com/Ubiqus/OpenNMT-py
We made it multi-GPU, and training is as fast as Marian (on 4 GPUs).
We are about to commit AAN; we need to fix 2 things:

  • caching in the decoder
  • batched beam search, which is too slow right now.

But even with those fixed, I doubt I can match your decoding numbers.

@hieuhoang
Collaborator

It would be good to put some flesh on the bones with numbers and details. Then maybe we can exchange tips.

@vince62s
Author

vince62s commented Jun 7, 2018

Sure. For DE-to-EN training on 4 GPUs I am at about 30K tok/sec (source or target, about the same since I use SentencePiece).
Decoding is too slow for now, so I prefer to wait until we have integrated caching and batched beam search; then I'll tell you.

@emjotde
Member

emjotde commented Jun 7, 2018

What kind of model?

@emjotde
Member

emjotde commented Jun 7, 2018

@hieuhoang It would also be nice if the exchange would not be one-sided :)

@vince62s
Author

vince62s commented Jun 7, 2018

transformer_base
Are you asking Hieu to give me numbers, or do you need more info from me? :)

@emjotde
Member

emjotde commented Jun 7, 2018

I was rather extending Hieu's comment with a small dose of snark :)

I believe our implementation of transformer-base is sub-optimal at training time. Still more to do. At least scaling across GPUs is decent enough.

@vince62s
Author

vince62s commented Jun 7, 2018

One tip: did you implement an "accumulated gradient" feature?
Basically, it emulates several GPUs on one.
For instance, if I set accum=2 on 4 GPUs, it acts like 8 GPUs.
The Transformer is really sensitive to the global batch size.
See my issue here: tensorflow/tensor2tensor#444

So we compute the loss on 2 mini-batches and update the parameters once every 2 mini-batches.
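A minimal sketch of gradient accumulation, assuming a toy one-parameter linear model y = w·x with a summed squared-error loss and plain SGD. All names (`grad`, `train`, `accum`) are hypothetical; this is not the OpenNMT-py or Marian implementation. Gradients from `accum` consecutive mini-batches are summed and a single update is applied, so two half-size mini-batches with accum=2 produce the same update as one full-size batch.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs

def grad(w: float, batch: Batch) -> float:
    # d/dw of sum((w * x - y)^2) over the batch (summed, not averaged)
    return sum(2.0 * (w * x - y) * x for x, y in batch)

def train(w: float, batches: List[Batch], lr: float, accum: int) -> float:
    acc, pending = 0.0, 0
    for batch in batches:
        acc += grad(w, batch)  # accumulate; parameters stay fixed meanwhile
        pending += 1
        if pending == accum:   # one update per `accum` mini-batches
            w -= lr * acc
            acc, pending = 0.0, 0
    if pending:                # flush a trailing partial group
        w -= lr * acc
    return w

data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
two_halves = train(0.0, [data[:2], data[2:]], lr=0.01, accum=2)
one_batch = train(0.0, [data], lr=0.01, accum=1)
print(two_halves == one_batch)  # True
```

With accum=1 the two halves would instead give two separate updates, the second one computed at an already-moved w. If the loss were a per-batch mean rather than a sum, the accumulated gradient would need to be normalized by the total number of examples (or tokens) to keep the equivalence.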

@kpu
Member

kpu commented Jun 7, 2018

Yeah, we call it --optimizer-delay

@emjotde
Member

emjotde commented Jun 7, 2018

Oh, is that how you get faster? In that case we need to update benchmarks :)

@kpu
Member

kpu commented Jun 7, 2018

I'm just gonna leave this training gauntlet here: https://arxiv.org/pdf/1806.00187.pdf

@vince62s
Author

vince62s commented Jun 7, 2018

No no!!! :)
It was just a BLEU performance trick, but I see everyone is doing the same thing.
Why are you saying we are faster? I think we are about the same for training, but you're way faster at decoding.

@vince62s
Author

vince62s commented Jun 7, 2018

Yeah, we used PyTorch distributed too, so if I need it I will also implement multi-node training, but there is no use for it right now.

@emjotde
Member

emjotde commented Jun 7, 2018

"Faster" as in faster than before.

@kpu So, how is multi-node sync training going? :)

@vince62s
Author

I am still stuck, so I need to ask again.
Where it says 12 seconds of wall time for decoding 3K segments on the "base-transformer-aan" line,
is it a real Transformer that decodes that fast, or an RNN student distilled from a Transformer teacher?

@emjotde
Member

emjotde commented Jun 15, 2018

A real one, with the decoder self-attention replaced by an average attention network. The rest is the same. RNN students are worse in terms of BLEU.
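The core of the average attention network, as introduced by Zhang, Xiong and Su (ACL 2018), is a cumulative average over the decoder inputs seen so far, g_j = (1/j) Σ_{k≤j} v_k. At decoding time it can be updated in constant time per step from the previous average, instead of attending over all earlier target positions. The sketch below covers only that cumulative-average part on plain Python vectors; the feed-forward and gating layers that follow it in the paper are omitted, and the names are hypothetical rather than Marian's code.

```python
from typing import List

Vector = List[float]

def cumulative_average(inputs: List[Vector]) -> List[Vector]:
    """Training-time view: position j averages inputs[0..j] (causal)."""
    out: List[Vector] = []
    running = [0.0] * len(inputs[0])
    for j, v in enumerate(inputs, start=1):
        running = [r + x for r, x in zip(running, v)]
        out.append([r / j for r in running])
    return out

class IncrementalAverage:
    """Decoding-time view: one running average is the only state kept."""
    def __init__(self, dim: int) -> None:
        self.avg = [0.0] * dim
        self.steps = 0

    def step(self, v: Vector) -> Vector:
        self.steps += 1
        j = self.steps
        # g_j = ((j - 1) * g_{j-1} + v_j) / j
        self.avg = [((j - 1) * g + x) / j for g, x in zip(self.avg, v)]
        return self.avg

inputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 9.0]]
inc = IncrementalAverage(dim=2)
print(cumulative_average(inputs))       # [[1.0, 2.0], [2.0, 3.0], [3.0, 5.0]]
print([inc.step(v) for v in inputs])    # same values, computed step by step
```

Because the per-step state does not grow with the target length, this part of the decoder stays cheap however long the output gets.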

@vince62s
Author

Hmmm, AAN is not that fast for us, and BLEU degrades. We need to rework the batched beam search, but it will never be that fast in PyTorch.

@emjotde
Member

emjotde commented Jun 15, 2018

That's our whole story, meta-algorithms in C++ beat Python :)

@emjotde
Member

emjotde commented Jun 15, 2018

Just looked at the paper: the one where it says 12s is a normal Transformer; the one with the AAN takes 7s.

@vince62s
Author

https://docs.google.com/spreadsheets/d/1wZQegK-9CKY378eAWRlahg23Fq155WTm4TQ8ikf8_6E/edit#gid=0
Line 46 says 12 seconds (7 is the empty timing, right?).
Anyway, it's just not in the same order of magnitude as what I'm seeing...

@emjotde
Member

emjotde commented Jun 15, 2018

7s is wall time minus empty, i.e. the actual time without start-up (the empty run).

@emjotde
Member

emjotde commented Jun 18, 2018

@vince62s Should be a good deal faster now after the weekend coding session. Possibly by a factor of 1.5-2.
