
[Core] add and implement VLLM_LOGITS_PROCESSOR_THREADS #12368

Merged 5 commits into vllm-project:main from akeshet/threaded_logits on Feb 5, 2025

Conversation

akeshet
Contributor

@akeshet akeshet commented Jan 23, 2025

This PR adds an environment variable VLLM_LOGITS_PROCESSOR_THREADS.

If set, this will cause VLLM to use a threadpool of the given size to multithread (across sequences in a batch) its calls to the provided logits processors.

This can increase GPU utilization and decrease inter-token latency (ITL) in cases where batches are large and where logits processors either (a) launch additional CUDA kernels or (b) do significant CPU-bound work while not holding the Python GIL, or both.
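To make the mechanism concrete, here is a minimal sketch of the fan-out the PR describes: per-sequence logits-processor calls dispatched across a thread pool. The function and data-structure names are illustrative, not vLLM's actual internals.

```python
# Hedged sketch: apply each batch sequence's logits processors, optionally in
# parallel via a thread pool, mirroring what VLLM_LOGITS_PROCESSOR_THREADS
# enables. Names here are hypothetical, not vLLM's real implementation.
from concurrent.futures import ThreadPoolExecutor


def apply_logits_processors(seq_data, processors, num_threads=None):
    """Apply logits processors to every sequence in a batch.

    seq_data: list of (token_ids, logits) tuples, one per sequence.
    processors: list of callables f(token_ids, logits) -> logits.
    num_threads: if set (analogous to VLLM_LOGITS_PROCESSOR_THREADS),
        process sequences concurrently in a pool of this size.
    """
    def process_one(item):
        token_ids, logits = item
        for proc in processors:
            logits = proc(token_ids, logits)
        return logits

    if num_threads:
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            # pool.map preserves input order, so results line up with the batch
            return list(pool.map(process_one, seq_data))
    # Default behavior (variable unset): process sequences serially
    return [process_one(item) for item in seq_data]
```

Note that, because of the GIL, this only helps when the processors release the GIL (CUDA kernel launches, C-extension work); pure-Python processors would serialize anyway.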


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs will not trigger a full CI run by default. Instead, only fastcheck CI will run, which starts with a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@akeshet
Contributor Author

akeshet commented Jan 23, 2025

Note: this PR is slightly different from the one I have already tested internally. I'd like to test it a bit before it's truly ready for review, but I didn't see an option to create this PR as a draft.

@akeshet akeshet changed the title [Core] add and implement VLLM_LOGITS_PROCESSOR_THREADS [Core] add and implement VLLM_LOGITS_PROCESSOR_THREADS Jan 23, 2025
Signed-off-by: Aviv Keshet <akeshet@scaledcognition.com>
@akeshet akeshet force-pushed the akeshet/threaded_logits branch from dc3aa96 to ac402d3 Compare January 23, 2025 20:59
@akeshet
Contributor Author

akeshet commented Jan 23, 2025

> Note: this PR is slightly different from the one I have already tested internally. I'd like to test it a bit before it's truly ready for review, but I didn't see an option to create this PR as a draft.

My local tests look good, ready for review!

Signed-off-by: Aviv Keshet <akeshet@scaledcognition.com>
Signed-off-by: Aviv Keshet <akeshet@scaledcognition.com>
@russellb
Member

This makes sense to me. Do you happen to have any performance data you can share? Perhaps benchmarking with structured output and xgrammar? It would be nice to show the benefit concretely. There are some scripts that would help automate this -- benchmarks/benchmark_guided.py and benchmarks/benchmark_serving_guided.py.

Also, I wonder about just making this the default behavior instead of requiring it to be enabled via a tunable. Allowing the threadpool to be adjusted seems fine, but if the benefit is noticeable with one of our in-tree logits processors (xgrammar, in particular), then putting it on by default may make sense.

@akeshet
Contributor Author

akeshet commented Jan 28, 2025

> This makes sense to me. Do you happen to have any performance data you can share?

I can't share my raw numbers or traces, but I can say that we see a roughly 10% improvement in tok/s and req/s when using this feature in combination with our internal logits processor (and with request concurrency of ~50, and threadpool size 50).

> Also, I wonder about just making this the default behavior instead of requiring it to be enabled via a tunable. Allowing the threadpool to be adjusted seems fine, but if the benefit is noticeable with one of our in-tree logits processors (xgrammar, in particular), then putting it on by default may make sense.

I wanted to be conservative and not alter existing behavior. As is often the case with Python multithreading, there are cases where I suspect this would hurt rather than help performance. For instance, to get the benefit from this PR with our internal logits processor logic, I had to make some tweaks (a well-placed CUDA sync) to ensure that the logits processors in separate threads don't block each other via CUDA's memcpy lock.

@russellb
Member

> This makes sense to me. Do you happen to have any performance data you can share?

> I can't share my raw numbers or traces, but I can say that we see a roughly 10% improvement in tok/s and req/s when using this feature in combination with our internal logits processor (and with request concurrency of ~50, and threadpool size 50).

> Also, I wonder about just making this the default behavior instead of requiring it to be enabled via a tunable. Allowing the threadpool to be adjusted seems fine, but if the benefit is noticeable with one of our in-tree logits processors (xgrammar, in particular), then putting it on by default may make sense.

> I wanted to be conservative and not alter existing behavior. As is often the case with Python multithreading, there are cases where I suspect this would hurt rather than help performance. For instance, to get the benefit from this PR with our internal logits processor logic, I had to make some tweaks (a well-placed CUDA sync) to ensure that the logits processors in separate threads don't block each other via CUDA's memcpy lock.

OK, thanks. If we're not able to show how it helps something in-tree, then I agree with having it off by default.

Can you expand on the doc text to add some guidance on when it would be useful? Something like "useful when using custom logits processors that either (a) launch additional CUDA kernels or (b) do significant CPU-bound work while not holding the Python GIL, or both."

Signed-off-by: Aviv Keshet <akeshet@scaledcognition.com>
@akeshet
Contributor Author

akeshet commented Jan 29, 2025

> Can you expand on the doc text to add some guidance on when it would be useful? Something like "useful when using custom logits processors that either (a) launch additional CUDA kernels or (b) do significant CPU-bound work while not holding the Python GIL, or both."

Done.

@russellb (Member) left a comment

lgtm, thanks!

@mgoin (Member) left a comment

LGTM given this is limited to within logits_processor.py and is disabled by default, thanks

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 4, 2025
@simon-mo simon-mo merged commit b3a0d01 into vllm-project:main Feb 5, 2025
45 of 48 checks passed
fxmarty-amd pushed a commit to fxmarty-amd/vllm that referenced this pull request Feb 7, 2025
ShangmingCai pushed a commit to ShangmingCai/vllm that referenced this pull request Feb 10, 2025
panf2333 pushed a commit to yottalabsai/vllm that referenced this pull request Feb 18, 2025
kerthcet pushed a commit to kerthcet/vllm that referenced this pull request Feb 21, 2025
@alejopaullier96

@akeshet is there a minimal example of how to use a logits processor with this environment variable?

@akeshet
Contributor Author

akeshet commented Feb 28, 2025

> @akeshet is there a minimal example of how to use a logits processor with this environment variable?

@alejopaullier96 The minimal example would be simply to export VLLM_LOGITS_PROCESSOR_THREADS=$VALUE prior to launching vLLM. But whether this gives any performance benefit will really depend on the details of the logits processor you are using. I came to this optimization by examining nsys traces of our workload with a custom logits processor.
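For readers looking for the concrete commands, a minimal sketch of that setup might look like the following. The thread-count value and the model name are illustrative placeholders, not recommendations.

```shell
# Hedged sketch: enable the logits-processor thread pool before starting vLLM.
# 8 is an arbitrary starting point; tune against traces of your own workload.
export VLLM_LOGITS_PROCESSOR_THREADS=8

# Then launch vLLM as usual, e.g. (model name is a placeholder):
# vllm serve meta-llama/Llama-3.1-8B-Instruct
```

As noted above, this only pays off when the logits processors launch CUDA kernels or do CPU-bound work outside the GIL; otherwise the serial default is likely just as fast.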

lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Mar 5, 2025
Said-Akbar pushed a commit to Said-Akbar/vllm-rocm that referenced this pull request Mar 7, 2025
qscqesze pushed a commit to ZZBoom/vllm that referenced this pull request Mar 13, 2025