
[Core] Support thread-based async tokenizer pools #3449

Open · njhill wants to merge 4 commits into main

Conversation

@njhill (Member) commented Mar 16, 2024

#2879 added support for using ray to offload tokenization from the asyncio event loop.

This PR extends that to support using a thread pool instead of ray, and makes that the default, with the default pool size determined based on the number of available CPU cores and the tensor parallel size.

The main thing to note is that separate tokenizer instances are used per thread. This is because officially the HF tokenizers are not thread-safe. In practice I think they are unless you're making use of padding/truncation, which we aren't currently but may want to soon (see for example #3144).

Also includes some type hint additions to related parts of the code.

This replaces the original PR #3206 from before #2879 was reworked and merged.
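For illustration, here is a minimal sketch of the approach described above: a thread pool where each worker thread lazily creates its own tokenizer instance, with encode calls offloaded from the asyncio event loop. The class and method names are assumptions for this sketch, not the PR's actual code.

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List

from transformers import AutoTokenizer


class ThreadedTokenizerGroup:
    # Sketch only: offload tokenization to a thread pool, one tokenizer per thread.

    def __init__(self, tokenizer_id: str, num_threads: int):
        self.tokenizer_id = tokenizer_id
        self.local = threading.local()  # per-thread tokenizer storage
        self.executor = ThreadPoolExecutor(max_workers=num_threads)

    def _encode_local(self, prompt: str) -> List[int]:
        # HF tokenizers are not officially thread-safe, so each worker
        # thread builds and reuses its own instance.
        tokenizer = getattr(self.local, "tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
            self.local.tokenizer = tokenizer
        return tokenizer.encode(prompt)

    async def encode_async(self, prompt: str) -> List[int]:
        # Keeps the asyncio event loop free while tokenization runs.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self._encode_local, prompt)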

@njhill njhill marked this pull request as ready for review March 16, 2024 18:53
@simon-mo simon-mo requested a review from Yard1 March 16, 2024 20:41
@Yard1 Yard1 self-assigned this Mar 17, 2024
@Yard1 (Collaborator) commented Mar 17, 2024

Thanks, I will review properly on Monday. One comment:

This PR extends that to support using a thread pool instead of ray, and makes that the default, with the default pool size determined based on the number of available CPU cores and the tensor parallel size.

I don't think we should default to all available cores. This will cause issues and resource contention in multi-tenant cases (plus some threads will already be occupied by e.g. vLLM workers). I think we should default to a thread pool of size 1 and let users configure the size if they want to.

@njhill (Member, Author) commented Mar 17, 2024

Thanks @Yard1!

I don't think we should default to all available cores. This will cause issues and resource contention in multi-tenant cases (plus some threads will already be occupied by e.g. vLLM workers). I think we should default to a thread pool of size 1 and let users configure the size if they want to.

I'm not actually setting the pool size equal to the number of cores; rather, I'm using the formula max(1, min(16, cpu_count - tp_size - 1)). IMHO it's better for users to not have to know about and individually tune all of these things. I think multi-tenant concerns will be less common, since I'd expect most folks to be deploying in containers with dedicated resource allocations.
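For reference, a sketch of how that default could be computed; the helper name here is hypothetical, the formula is the one quoted above.

import os


def default_tokenizer_pool_size(tp_size: int) -> int:
    # Leave cores for the tensor-parallel workers and the main process,
    # cap at 16, and never go below 1.
    cpu_count = os.cpu_count() or 1
    return max(1, min(16, cpu_count - tp_size - 1))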

vllm/config.py: review thread (outdated, resolved)
@Yard1 (Collaborator) left a comment


Thanks, looks good. Left some comments. Let's add a test for this as well (there are tests where you just need to change the TokenizerGroup type). I am still not sure about the automatic thread pool size; it would be great if some other committers could weigh in.

def _encode_local(self, *args, **kwargs):
    return self.local.tokenizer.encode(*args, **kwargs)

def encode(self, *args, **kwargs):
@Yard1 (Collaborator):

Let's do the full signature instead of *args/**kwargs; it makes the code easier to read.

@njhill (Member, Author):

@Yard1 let me try to make a case for this here, happy to change if you aren't convinced :)

The actual args are kind of irrelevant here since it's essentially a pass-through within a specific implementation. Calling code is always working with the methods in the abstract superclass.

So using *args/**kwargs here to me is actually clearer because you can more easily see that it's a pass-through without the verbose boilerplate, and additionally does not need to be kept in sync if the specific args change.
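To make the trade-off concrete, a small sketch (hypothetical class names): the abstract method carries the full, documented signature that callers code against, while the concrete implementation passes straight through.

from abc import ABC, abstractmethod
from typing import List, Optional


class BaseTokenizerGroup(ABC):
    @abstractmethod
    def encode(self, prompt: str, request_id: Optional[str] = None) -> List[int]:
        # Callers always work against this full signature.
        ...


class ThreadPoolTokenizerGroup(BaseTokenizerGroup):
    def encode(self, *args, **kwargs) -> List[int]:
        # Pure pass-through: no boilerplate to keep in sync if the
        # abstract signature changes.
        return self._encode_local(*args, **kwargs)

    def _encode_local(self, *args, **kwargs) -> List[int]:
        # Placeholder for the real per-thread encode call.
        return []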

@Yard1 (Collaborator):

I see your point. Seeing as this is a stylistic choice, it would probably be best for the authors to weigh in and set a precedent to follow: @WoosukKwon @zhuohan123

vllm/utils.py: review thread (resolved)
@Qubitium (Contributor) commented Mar 19, 2024

Threads will hit the GIL, so only one actual tokenizer op can run at a time; the ops are effectively serialized by the GIL. Perhaps a ProcessPoolExecutor?

@njhill (Member, Author) commented Mar 19, 2024

@Qubitium the majority of tokenizers use a fast rust-based implementation which doesn't hold the GIL.

@Qubitium (Contributor) commented Mar 19, 2024

@Qubitium the majority of tokenizers use a fast rust-based implementation which doesn't hold the GIL.

@njhill But isn't the entry point to the Rust code (HF fast tokenizers) just normal Python, which is still bound by the GIL? I would think the threads just pile up waiting to call the lower-level Rust code.

@njhill (Member, Author) commented Mar 19, 2024

Argh @Qubitium @Yard1 yes, I may have been mistaken w.r.t. the GIL being released in the Rust code for encoding. So we can hold this for now; I'll look into it more. A process pool would also be an option (i.e. for parallel tokenization without ray), but I'm concerned about the serialization overhead.
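For comparison, a rough sketch of the process-pool alternative mentioned above (illustrative only; the function names are assumptions). The serialization overhead comes from pickling each prompt and the returned token ids across the process boundary.

from concurrent.futures import ProcessPoolExecutor
from typing import List

from transformers import AutoTokenizer

_tokenizer = None  # one tokenizer per worker process


def _init_worker(tokenizer_id: str) -> None:
    global _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)


def _encode(prompt: str) -> List[int]:
    # Runs inside a worker process, using that process's tokenizer.
    return _tokenizer.encode(prompt)


def make_tokenizer_pool(tokenizer_id: str, workers: int) -> ProcessPoolExecutor:
    # Each worker process loads its own tokenizer once via the initializer;
    # every encode call still pays pickle overhead for inputs and outputs.
    return ProcessPoolExecutor(max_workers=workers,
                               initializer=_init_worker,
                               initargs=(tokenizer_id,))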

@remusao commented Mar 19, 2024

@Qubitium the majority of tokenizers use a fast rust-based implementation which doesn't hold the GIL.

@njhill But isn't the entry point to the Rust code (HF fast tokenizers) just normal Python, which is still bound by the GIL? I would think the threads just pile up waiting to call the lower-level Rust code.

Why would the threads pile up? If the wrapping Python code is not CPU-bound, I would expect that once the 'no GIL' section of the code is reached (i.e. the Rust code), the tokenizing work can proceed in that thread while other threads start jobs of their own. I made a quick benchmark to assess this, and using a thread pool seems to help (although increasing the number of threads leads to diminishing returns and eventually increases the total time, so the optimal value is probably something to benchmark more carefully in the context of vLLM).

We see that going from 1 thread to 2 threads ~halves the total execution time:

With threadpool (n=1): 0.8177816867828369
With threadpool (n=2): 0.4640841484069824
With threadpool (n=3): 0.31597471237182617
With threadpool (n=4): 0.3116874694824219
With threadpool (n=5): 0.3126490116119385
With threadpool (n=6): 0.32393932342529297
With threadpool (n=7): 0.38161683082580566
With threadpool (n=8): 0.5974140167236328
With threadpool (n=9): 0.6499311923980713
With threadpool (n=10): 0.6498591899871826
With threadpool (n=11): 0.6028220653533936
With threadpool (n=12): 0.5991456508636475
With threadpool (n=13): 0.5550651550292969
With threadpool (n=14): 0.5854413509368896
With threadpool (n=15): 0.6049799919128418
*Without* threadpool: 0.8543827533721924
Benchmarking code
import time
from concurrent.futures import ThreadPoolExecutor

from transformers import LlamaTokenizerFast

# Fast (Rust-backed) tokenizer from the HF tokenizers library.
tokenizer = LlamaTokenizerFast.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")


def tokenize(*args, **kwargs):
    # Arguments are ignored; they only exist so executor.map can pass values.
    return tokenizer.tokenize(
        """Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. At tempor commodo ullamcorper a. Blandit massa enim nec dui nunc. Turpis egestas sed tempus urna et pharetra pharetra. Blandit cursus risus at ultrices. Turpis egestas maecenas pharetra convallis posuere morbi leo urna. Vitae justo eget magna fermentum iaculis eu non diam. Id eu nisl nunc mi ipsum faucibus. Viverra ipsum nunc aliquet bibendum enim facilisis gravida neque. Congue eu consequat ac felis donec et odio pellentesque diam. Tincidunt eget nullam non nisi est sit amet facilisis magna. Sed lectus vestibulum mattis ullamcorper velit sed ullamcorper morbi. Odio tempor orci dapibus ultrices in. Sodales ut etiam sit amet. Facilisi cras fermentum odio eu feugiat pretium. Convallis convallis tellus id interdum velit laoreet id donec ultrices.

Venenatis urna cursus eget nunc scelerisque. Massa enim nec dui nunc mattis enim ut. Semper risus in hendrerit gravida rutrum quisque non tellus. Sed lectus vestibulum mattis ullamcorper velit. Odio euismod lacinia at quis risus sed. Ultricies leo integer malesuada nunc vel. Proin sed libero enim sed faucibus. Auctor urna nunc id cursus metus aliquam. Tortor posuere ac ut consequat semper viverra. Volutpat est velit egestas dui id. Maecenas pharetra convallis posuere morbi leo urna molestie at. Dapibus ultrices in iaculis nunc sed augue. Mi proin sed libero enim sed faucibus turpis. Est velit egestas dui id. Felis eget velit aliquet sagittis id consectetur purus ut. Aliquam eleifend mi in nulla posuere sollicitudin aliquam ultrices. Pharetra magna ac placerat vestibulum lectus mauris ultrices. Malesuada nunc vel risus commodo viverra maecenas.

Nibh tortor id aliquet lectus proin nibh nisl condimentum id. Sodales ut eu sem integer. Morbi enim nunc faucibus a pellentesque sit amet porttitor eget. Risus in hendrerit gravida rutrum quisque non. Integer feugiat scelerisque varius morbi enim nunc faucibus a. Massa id neque aliquam vestibulum. Nam libero justo laoreet sit amet cursus sit amet. Ornare quam viverra orci sagittis eu volutpat odio. Tellus pellentesque eu tincidunt tortor. Etiam dignissim diam quis enim. Sed lectus vestibulum mattis ullamcorper velit sed ullamcorper morbi. Magna eget est lorem ipsum dolor sit amet consectetur. Laoreet sit amet cursus sit amet.

Amet purus gravida quis blandit. Amet justo donec enim diam vulputate ut pharetra sit. Diam sit amet nisl suscipit adipiscing bibendum est ultricies integer. Pellentesque diam volutpat commodo sed. Pulvinar pellentesque habitant morbi tristique senectus. Elementum eu facilisis sed odio morbi quis. Tristique senectus et netus et malesuada fames ac turpis egestas. Tellus in hac habitasse platea dictumst vestibulum rhoncus. Tellus elementum sagittis vitae et leo duis ut diam. Purus in mollis nunc sed id semper risus in hendrerit. Sed id semper risus in hendrerit gravida rutrum quisque. Pharetra convallis posuere morbi leo urna molestie at.

Arcu cursus vitae congue mauris rhoncus aenean. Praesent elementum facilisis leo vel fringilla. Et ultrices neque ornare aenean euismod elementum nisi quis. Lectus mauris ultrices eros in cursus. Pretium fusce id velit ut. Commodo odio aenean sed adipiscing diam donec adipiscing. Purus non enim praesent elementum facilisis. Eget mauris pharetra et ultrices neque ornare aenean euismod elementum. Nunc vel risus commodo viverra maecenas. Ultricies lacus sed turpis tincidunt id aliquet risus feugiat in. Egestas quis ipsum suspendisse ultrices gravida. Sed nisi lacus sed viverra tellus in hac habitasse platea. Tortor at risus viverra adipiscing at in tellus integer. Lacus suspendisse faucibus interdum posuere lorem ipsum dolor sit. Enim tortor at auctor urna nunc. Morbi non arcu risus quis. Erat velit scelerisque in dictum non consectetur. Quis hendrerit dolor magna eget est lorem ipsum. In aliquam sem fringilla ut.

Duis tristique sollicitudin nibh sit. Fames ac turpis egestas integer eget aliquet. Congue nisi vitae suscipit tellus mauris a diam maecenas. Auctor augue mauris augue neque gravida in. Mauris augue neque gravida in fermentum et. Ut consequat semper viverra nam libero justo laoreet. Porta lorem mollis aliquam ut porttitor leo a. Quis auctor elit sed vulputate mi sit amet mauris. Quisque sagittis purus sit amet volutpat. Nam at lectus urna duis convallis. Rhoncus aenean vel elit scelerisque.

Viverra tellus in hac habitasse platea dictumst vestibulum rhoncus est. Felis eget velit aliquet sagittis. Lectus mauris ultrices eros in cursus. Nec feugiat nisl pretium fusce id velit ut. Posuere morbi leo urna molestie at elementum eu facilisis. Amet risus nullam eget felis. Sed pulvinar proin gravida hendrerit lectus. Ac turpis egestas sed tempus. Rhoncus mattis rhoncus urna neque viverra justo nec. Platea dictumst quisque sagittis purus sit amet volutpat consequat mauris. Vel turpis nunc eget lorem. Diam quam nulla porttitor massa id neque aliquam vestibulum.

Arcu cursus euismod quis viverra nibh cras pulvinar. Vitae justo eget magna fermentum iaculis. Ac turpis egestas sed tempus urna et pharetra. Urna porttitor rhoncus dolor purus non. Morbi non arcu risus quis varius quam. Neque aliquam vestibulum morbi blandit cursus risus at ultrices mi. Volutpat blandit aliquam etiam erat. Sapien et ligula ullamcorper malesuada proin libero nunc consequat interdum. Luctus venenatis lectus magna fringilla urna. Facilisi morbi tempus iaculis urna. Semper auctor neque vitae tempus quam pellentesque. Nisl condimentum id venenatis a condimentum vitae sapien pellentesque habitant.

Sem viverra aliquet eget sit amet tellus. Ullamcorper velit sed ullamcorper morbi. Sit amet tellus cras adipiscing enim. Natoque penatibus et magnis dis parturient montes nascetur ridiculus. Aenean euismod elementum nisi quis eleifend quam adipiscing vitae proin. Nibh sit amet commodo nulla. Posuere urna nec tincidunt praesent semper feugiat nibh sed. Rhoncus aenean vel elit scelerisque mauris pellentesque pulvinar. In massa tempor nec feugiat nisl pretium fusce id. Adipiscing enim eu turpis egestas pretium aenean pharetra magna ac. A diam sollicitudin tempor id eu nisl. Id diam vel quam elementum pulvinar etiam non. Aliquam faucibus purus in massa. Arcu risus quis varius quam quisque. Est lorem ipsum dolor sit amet consectetur. Lacinia quis vel eros donec ac. Id semper risus in hendrerit.

Arcu odio ut sem nulla pharetra diam sit. Ut consequat semper viverra nam. Urna nunc id cursus metus aliquam eleifend. Nibh tortor id aliquet lectus proin nibh nisl condimentum id. Id venenatis a condimentum vitae. Leo vel fringilla est ullamcorper eget nulla. Blandit aliquam etiam erat velit scelerisque. Amet porttitor eget dolor morbi non arcu. Ut porttitor leo a diam sollicitudin. Laoreet suspendisse interdum consectetur libero. Tempor id eu nisl nunc mi ipsum faucibus vitae. Leo duis ut diam quam nulla porttitor. Duis at tellus at urna condimentum. Egestas erat imperdiet sed euismod nisi porta."""
    )


def with_threadpool(n=1):
    with ThreadPoolExecutor(max_workers=n) as executor:
        # Exiting the `with` block waits for all submitted tasks to finish.
        # Note: `chunksize` is ignored by ThreadPoolExecutor.
        executor.map(tokenize, range(1000), chunksize=8)


def without_threadpool():
    for _ in range(1000):
        tokenize()


if __name__ == "__main__":
    # Warm-up
    with_threadpool()
    without_threadpool()

    for n in range(1, 16):
        t0 = time.time()
        with_threadpool(n=n)
        t1 = time.time()
        print(f"With threadpool (n={n}):", t1 - t0)

    t0 = time.time()
    without_threadpool()
    t1 = time.time()
    print("*Without* threadpool:", t1 - t0)

@Qubitium (Contributor) commented Mar 19, 2024

@remusao I may be wrong, but I believe you're just timing the task submission. You also want to init the pools outside and pass them in, or else the warm-up doesn't work.

executor.map(tokenize, range(1000), chunksize=8)
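
A sketch of that adjustment (assuming the same tokenize function from the benchmark above): create the pool outside the timed region, warm it up, and consume the map iterator so all work is accounted for.

import time
from concurrent.futures import ThreadPoolExecutor

for n in range(1, 16):
    executor = ThreadPoolExecutor(max_workers=n)
    list(executor.map(tokenize, range(100)))   # warm up this pool's threads
    t0 = time.time()
    list(executor.map(tokenize, range(1000)))  # consume all results
    print(f"With threadpool (n={n}):", time.time() - t0)
    executor.shutdown()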

@Qubitium (Contributor) commented Mar 19, 2024

@remusao You're right. I ran the code, making sure to loop over the map iterable so that all the yielded results are actually produced, and I also see a 2-2.5x speed improvement.

@Yard1 (Collaborator) commented Mar 19, 2024

I think it would be good to use the serving benchmark to see what the improvement is.

@njhill (Member, Author) commented Mar 19, 2024

@Yard1 while we investigate the perf of threads/processes for this, I have opened a couple of narrower PRs covering the orthogonal parts of this one, PTAL! (They're a subset of this one, so there isn't really any new code to review.)

@zhaoyang-star (Contributor) commented:

Great work! I am also curious about the speedup on the throughput benchmark. Is this PR still active? @njhill


This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale label Oct 29, 2024

mergify bot commented Oct 29, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @njhill please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
