Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker #5348

Merged

Conversation

sroy745
Copy link
Collaborator

@sroy745 sroy745 commented Jun 7, 2024

FIX #5015 (link existing issues this PR will resolve)
In this PR we integrate the TypicalAcceptanceSampler into the spec_decode_worker framework. To that end we make the following changes

  1. Introduce 3 engine level args to specify what is the type of verification sampler (RejectionSampler or TypicalAcceptanceSampler) to use for speculative-decoding. If it is the TypicalAcceptanceSampler then we provide 2 more arguments to configure the same.

  2. We plumb all these arguments to the spec_decode_worker and depending on the values of the arguments we use the requested verification sampler.

  3. Modify the spec_decode/e2e/test_multistep_correctness.py to also run with TypicalAcceptanceSampler.

  4. Modify other tests like test_spec_decode_worker etc to work with TypicalAcceptanceSampler.

Please ignore the changes in vllm/model_executor/layers/. They are already being reviewed as part of another pr.

@sroy745 sroy745 marked this pull request as draft June 7, 2024 16:53
@sroy745 sroy745 changed the title [WIP][Spec Decode] integrate typical acceptance sampler into Spec Decode Worker [Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker Jun 11, 2024
@sroy745 sroy745 marked this pull request as ready for review June 11, 2024 20:31
Copy link
Collaborator

@cadedaniel cadedaniel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall. some comments on testing and config.

Comment on lines 14 to 22
At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
highest probability in the target distribution are accepted. Therefore, we can
expect greedy equality for the TypicalAcceptanceSampler at temp=0.

For temp>0, we rely on unit tests on the rejection sampler to verify that the
output distribution is the same with spec decode vs. no spec decode (this would
be prohibitively expensive to run with a real model).
be prohibitively expensive to run with a real model). Similary, for the
TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
test cases.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For testing strategy:

I am concerned that we are adding many E2E tests that don't provide a lot of signal over what already exists. The tradeoff of more tests is that we can accidentally explode CI time. This is because we rely on E2E tests for spec decode correctness and any small regression in model loading or vLLM initialization time can hurt us bad.

So, what I suggest:

  • E2E tests over the interaction between spec decode worker and typical acceptance
    • make sure it can handle different BS
    • make sure it can handle different K
  • E2E test over one other proposer method
    • just need one to make sure typical acceptance works beyond draft model
  • We don't need tests around preemption, disabling/skipping speculation, different block size, since all of these are no different for typical acceptance.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a single test in test_multistep_correctness.py to cover different batch size and speculation_length values with TypicalAcceptanceSampler. Added a similar test to test_ngram_correctness.py

Comment on lines 60 to 61
def test_correctly_calls_target_model(
k: int, batch_size: int, mock_spec_decode_sampler):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great unit test coverage!

Comment on lines 225 to 229
if isinstance(spec_decode_sampler, RejectionSampler):
# The draft probabilites is used only by the RejectionSampler.
# Ensure that if the sampler is a RejectionSampler then the
# draft probs are being passed.
assert torch.equal(actual.draft_probs, proposal_probs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed this in the other PR. Can we have both acceptance routines take in the draft probabilities so we don't have branches like this? We can have the typical acceptance just ignore them.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed this. Added the same interface for both the samplers. I have included the changes in this PR because it is not a lot of changes. Let me know if that is fine for you in terms of reviewing.

tests/spec_decode/test_utils.py Outdated Show resolved Hide resolved
vllm/config.py Outdated
Comment on lines 824 to 828
draft_token_sampling_method (Optional[str]): The sampling method
to use for accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we call this spec_acceptance_routine or spec_sampling_method? draft_token_sampling_method makes it seem this is what is used to sample only draft tokens.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed this to spec_decoding_acceptance_method. Replaced the name draft_token_sampling_method with draft_token_acceptance_method to avoid confusion.

vllm/engine/arg_utils.py Outdated Show resolved Hide resolved
vllm/spec_decode/spec_decode_worker.py Outdated Show resolved Hide resolved
vllm/spec_decode/spec_decode_worker.py Show resolved Hide resolved
Comment on lines 174 to 177
assert (
self.spec_decode_sampler is not None,
"Sampler is Not set, which is not expected."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we should rely on upper layers to provide validation and raise; we don't need to assert something basic this deep.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Removed this here. Added validations earlier.

Comment on lines 499 to 513
if isinstance(self.spec_decode_sampler, RejectionSampler):
accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_probs=proposal_probs,
draft_token_ids=proposal_token_ids,
)
else:
assert isinstance(self.spec_decode_sampler, TypicalAcceptanceSampler)
accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,
bonus_token_ids=bonus_token_ids,
draft_token_ids=proposal_token_ids,
)
return accepted_token_ids
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to making them meet the same interface (both take draft_token_probs, typical acceptance ignores them)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an abstract forward method to the base SpecDecodeBaseSampler and implementing the same in both the samplers.

Copy link
Collaborator Author

@sroy745 sroy745 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. Addressed comments. PTAL.

Copy link
Collaborator

@cadedaniel cadedaniel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, see comment about checking test time

tests/spec_decode/e2e/test_ngram_correctness.py Outdated Show resolved Hide resolved
@sroy745
Copy link
Collaborator Author

sroy745 commented Jun 27, 2024

FYI there are some test failures. I am fixing them. Will ping once they are fixed.

@DarkLight1337
Copy link
Member

To speed up the CI queue, I've cancelled the distributed tests for the latest CI run in this PR since they won't pass anyway until #5905 has been merged. Now that it has been merged, please merge main into your branch so that the CI can pass once again.

@sroy745
Copy link
Collaborator Author

sroy745 commented Jun 29, 2024

All tests are passing except 1 test in tests/models/test_big_models.py which failure doesn't seem related. Other than that the pr is ready for merging with all comments addressed. PTAL

@cadedaniel
Copy link
Collaborator

OK, will take a look tomorrow !

@cadedaniel cadedaniel merged commit 80ca1e6 into vllm-project:main Jul 1, 2024
68 checks passed
prashantgupta24 pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 1, 2024
prashantgupta24 pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 1, 2024
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jul 1, 2024
kzawora-intel added a commit to HabanaAI/vllm-fork that referenced this pull request Jul 2, 2024
* [Hardware][Intel] Optimize CPU backend and add more performance tips (vllm-project#4971)

Co-authored-by: Jianan Gu <[email protected]>

* [Docs] Add 4th meetup slides (vllm-project#5509)

* [Misc] Add vLLM version getter to utils (vllm-project#5098)

* [CI/Build] Simplify OpenAI server setup in tests (vllm-project#5100)

* [Doc] Update LLaVA docs (vllm-project#5437)

Co-authored-by: Roger Wang <[email protected]>

* [Kernel] Factor out epilogues from cutlass kernels (vllm-project#5391)

Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: zifeitong <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>

* [MISC] Remove FP8 warning (vllm-project#5472)

Co-authored-by: Philipp Moritz <[email protected]>

* Seperate dev requirements into lint and test (vllm-project#5474)

* Revert "[Core] Remove unnecessary copies in flash attn backend" (vllm-project#5478)

* [misc] fix format.sh (vllm-project#5511)

* [CI/Build] Disable test_fp8.py (vllm-project#5508)

* [Kernel] Disable CUTLASS kernels for fp8 (vllm-project#5505)

* Add `cuda_device_count_stateless` (vllm-project#5473)

* [Hardware][Intel] Support CPU inference with AVX2 ISA (vllm-project#5452)

* [Misc] Fix arg names in quantizer script (vllm-project#5507)

* bump version to v0.5.0.post1 (vllm-project#5522)

* [CI/Build][Misc] Add CI that benchmarks vllm performance on those PRs with `perf-benchmarks` label (vllm-project#5073)

Co-authored-by: simon-mo <[email protected]>

* [CI/Build] Disable LLaVA-NeXT CPU test (vllm-project#5529)

* [Kernel] Fix CUTLASS 3.x custom broadcast load epilogue (vllm-project#5516)

* [Misc] Fix arg names (vllm-project#5524)

* [ Misc ] Rs/compressed tensors cleanup (vllm-project#5432)

Co-authored-by: mgoin <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>

* [Kernel] Suppress mma.sp warning on CUDA 12.5 and later (vllm-project#5401)

* [mis] fix flaky test of test_cuda_device_count_stateless (vllm-project#5546)

* [Core] Remove duplicate processing in async engine (vllm-project#5525)

* [misc][distributed] fix benign error in `is_in_the_same_node` (vllm-project#5512)

* [Docs] Add ZhenFund as a Sponsor (vllm-project#5548)

* [Doc] Update documentation on Tensorizer (vllm-project#5471)

* [Bugfix] Enable loading FP8 checkpoints for gpt_bigcode models  (vllm-project#5460)

Signed-off-by: Thomas Parnell <[email protected]>

* [Bugfix] Fix typo in Pallas backend (vllm-project#5558)

* [Core][Distributed] improve p2p cache generation (vllm-project#5528)

* Add ccache to amd (vllm-project#5555)

* [Core][Bugfix]: fix prefix caching for blockv2 (vllm-project#5364)

Signed-off-by: Lei Wen <[email protected]>
Co-authored-by: Lei Wen <[email protected]>

* [mypy] Enable type checking for test directory (vllm-project#5017)

* [CI/Build] Test both text and token IDs in batched OpenAI Completions API (vllm-project#5568)

* [misc] Do not allow to use lora with chunked prefill. (vllm-project#5538)

Co-authored-by: Cyrus Leung <[email protected]>

* add gptq_marlin test for bug report vllm-project#5088 (vllm-project#5145)

* [BugFix] Don't start a Ray cluster when not using Ray (vllm-project#5570)

* [Fix] Correct OpenAI batch response format (vllm-project#5554)

* Add basic correctness 2 GPU tests to 4 GPU pipeline (vllm-project#5518)

* [CI][BugFix] Flip is_quant_method_supported condition (vllm-project#5577)

* [build][misc] limit numpy version (vllm-project#5582)

* [Doc] add debugging tips for crash and multi-node debugging (vllm-project#5581)

* Fix w8a8 benchmark and add Llama-3-8B (vllm-project#5562)

* [Model] Rename Phi3 rope scaling type (vllm-project#5595)

* Correct alignment in the seq_len diagram. (vllm-project#5592)

Co-authored-by: Liqian Chen <[email protected]>

* [Kernel] `compressed-tensors` marlin 24 support (vllm-project#5435)

* [Misc] use AutoTokenizer for benchmark serving when vLLM not installed (vllm-project#5588)

* [Hardware][Intel GPU] Add Intel GPU(XPU) inference backend (vllm-project#3814)

Co-authored-by: Jiang Li <[email protected]>
Co-authored-by: Abhilash Majumder <[email protected]>
Co-authored-by: Abhilash Majumder <[email protected]>

* [CI/BUILD] Support non-AVX512 vLLM building and testing (vllm-project#5574)

* [CI] the readability of benchmarking and prepare for dashboard (vllm-project#5571)

[CI] Improve the readability of performance benchmarking results and prepare for upcoming performance dashboard (vllm-project#5571)

* [bugfix][distributed] fix 16 gpus local rank arrangement (vllm-project#5604)

* [Optimization] use a pool to reuse LogicalTokenBlock.token_ids (vllm-project#5584)

* [Bugfix] Fix KV head calculation for MPT models when using GQA (vllm-project#5142)

* [Fix] Use utf-8 encoding in entrypoints/openai/run_batch.py (vllm-project#5606)

* [Speculative Decoding 1/2 ] Add typical acceptance sampling as one of the sampling techniques in the verifier (vllm-project#5131)

* [Model] Initialize Phi-3-vision support (vllm-project#4986)

* [Kernel] Add punica dimensions for Granite 13b (vllm-project#5559)

Signed-off-by: Joe Runde <[email protected]>

* [misc][typo] fix typo (vllm-project#5620)

* [Misc] Fix typo (vllm-project#5618)

* [CI] Avoid naming different metrics with the same name in performance benchmark (vllm-project#5615)

* [bugfix][distributed] improve p2p capability test (vllm-project#5612)

[bugfix][distributed] do not error if two processes do not agree on p2p capability (vllm-project#5612)

* [Misc] Remove import from transformers logging (vllm-project#5625)

* [CI/Build][Misc] Update Pytest Marker for VLMs (vllm-project#5623)

* [ci] Deprecate original CI template (vllm-project#5624)

Signed-off-by: kevin <[email protected]>

* [Misc] Add OpenTelemetry support (vllm-project#4687)

This PR adds basic support for OpenTelemetry distributed tracing.
It includes changes to enable tracing functionality and improve monitoring capabilities.

I've also added a markdown with print-screens to guide users how to use this feature. You can find it here

* [Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization (vllm-project#5542)

* [ci] Setup Release pipeline and build release wheels with cache (vllm-project#5610)

Signed-off-by: kevin <[email protected]>

* [Model] LoRA support added for command-r (vllm-project#5178)

* [Bugfix] Fix for inconsistent behaviour related to sampling and repetition penalties  (vllm-project#5639)

Signed-off-by: Thomas Parnell <[email protected]>

* [Doc] Added cerebrium as Integration option (vllm-project#5553)

* [Bugfix] Fix CUDA version check for mma warning suppression (vllm-project#5642)

* [Bugfix] Fix w8a8 benchmarks for int8 case (vllm-project#5643)

* [Bugfix] Fix Phi-3 Long RoPE scaling implementation (vllm-project#5628)

* [Bugfix] Added test for sampling repetition penalty bug. (vllm-project#5659)

Signed-off-by: Thomas Parnell <[email protected]>

* [Bugfix][CI/Build][AMD][ROCm]Fixed the cmake build bug which generate garbage on certain devices (vllm-project#5641)

* [misc][distributed] use 127.0.0.1 for single-node (vllm-project#5619)

* [Model] Add FP8 kv cache for Qwen2 (vllm-project#5656)

* [Bugfix] Fix sampling_params passed incorrectly in Phi3v example (vllm-project#5684)

* [Misc]Add param max-model-len in benchmark_latency.py (vllm-project#5629)

* [CI/Build] Add tqdm to dependencies (vllm-project#5680)

* [ci] Add A100 queue into AWS CI template (vllm-project#5648)

Signed-off-by: kevin <[email protected]>

* [Frontend][Bugfix] Fix preemption_mode -> preemption-mode for CLI arg in arg_utils.py (vllm-project#5688)

* [ci][distributed] add tests for custom allreduce (vllm-project#5689)

* [Bugfix] AsyncLLMEngine hangs with asyncio.run (vllm-project#5654)

* [Doc] Update docker references (vllm-project#5614)

Signed-off-by: Rafael Vasquez <[email protected]>

* [Misc] Add per channel support for static activation quantization; update w8a8 schemes to share base classes (vllm-project#5650)

* [ci] Limit num gpus if specified for A100 (vllm-project#5694)

Signed-off-by: kevin <[email protected]>

* [Misc] Improve conftest (vllm-project#5681)

* [Bugfix][Doc] FIx Duplicate Explicit Target Name Errors (vllm-project#5703)

* [Kernel] Update Cutlass int8 kernel configs for SM90 (vllm-project#5514)

Co-authored-by: Varun Sundar Rabindranath <[email protected]>

* [Model] Port over CLIPVisionModel for VLMs (vllm-project#5591)

* [Kernel] Update Cutlass int8 kernel configs for SM80 (vllm-project#5275)

Co-authored-by: Varun Sundar Rabindranath <[email protected]>

* [Bugfix] Fix the CUDA version check for FP8 support in the CUTLASS kernels (vllm-project#5715)

* [Frontend] Add FlexibleArgumentParser to support both underscore and dash in names (vllm-project#5718)

* [distributed][misc] use fork by default for mp (vllm-project#5669)

* [Model] MLPSpeculator speculative decoding support (vllm-project#4947)

Signed-off-by: Thomas Parnell <[email protected]>

Co-authored-by: Thomas Parnell <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Davis Wertheimer <[email protected]>

* [Kernel] Add punica dimension for Qwen2 LoRA (vllm-project#5441)

* [BugFix] Fix test_phi3v.py (vllm-project#5725)

* [Bugfix] Add  fully sharded layer for QKVParallelLinearWithLora (vllm-project#5665)

Co-authored-by: Antoni Baum <[email protected]>

* [Core][Distributed] add shm broadcast (vllm-project#5399)

Co-authored-by: Cody Yu <[email protected]>

* [Kernel][CPU] Add Quick `gelu` to CPU (vllm-project#5717)

* [Doc] Documentation on supported hardware for quantization methods (vllm-project#5745)

* [BugFix] exclude version 1.15.0 for modelscope (vllm-project#5668)

* [ci][test] fix ca test in main (vllm-project#5746)

* [LoRA] Add support for pinning lora adapters in the LRU cache (vllm-project#5603)

* [CI][Hardware][Intel GPU] add Intel GPU(XPU) ci pipeline (vllm-project#5616)

* [Model] Support Qwen-VL and Qwen-VL-Chat models with text-only inputs (vllm-project#5710)

Co-authored-by: Roger Wang <[email protected]>

* [Misc] Remove vllm-project#4789 workaround left in vllm/entrypoints/openai/run_batch.py (vllm-project#5756)

* [Bugfix] Fix pin_lora error in TPU executor (vllm-project#5760)

* [Docs][TPU] Add installation tip for TPU (vllm-project#5761)

* [core][distributed] improve shared memory broadcast (vllm-project#5754)

* [BugFix] [Kernel] Add Cutlass2x fallback kernels (vllm-project#5744)

Co-authored-by: Varun Sundar Rabindranath <[email protected]>

* [Distributed] Add send and recv helpers (vllm-project#5719)

* [Bugfix] Add phi3v resize for dynamic shape and fix torchvision requirement (vllm-project#5772)

* [doc][faq] add warning to download models for every nodes (vllm-project#5783)

* post-rebase api adjustments

* [Doc] Add "Suggest edit" button to doc pages (vllm-project#5789)

* [Doc] Add Phi-3-medium to list of supported models (vllm-project#5788)

* [Bugfix] Fix FlexibleArgumentParser replaces _ with - for actual args (vllm-project#5795)

* [ci] Remove aws template (vllm-project#5757)

Signed-off-by: kevin <[email protected]>

* [Doc] Add notice about breaking changes to VLMs (vllm-project#5818)

* [Speculative Decoding] Support draft model on different tensor-parallel size than target model (vllm-project#5414)

* add pin_lora to habana components

* add WA for model loader

* fix api mismatches with ray

* tensor parallel fixes

* workers cpu alignment fix

* [Misc] Remove useless code in cpu_worker (vllm-project#5824)

* prefill/decode metadata fixes

* [Core] Add fault tolerance for `RayTokenizerGroupPool` (vllm-project#5748)

* re-enable attn metadata trimming

* worker_use_ray fix

* [doc][distributed] add both gloo and nccl tests (vllm-project#5834)

* [CI/Build] Add unit testing for FlexibleArgumentParser (vllm-project#5798)

* [Misc] Update `w4a16` `compressed-tensors` support to include `w8a16` (vllm-project#5794)

* [Hardware][TPU] Refactor TPU backend (vllm-project#5831)

* [Hardware][AMD][CI/Build][Doc] Upgrade to ROCm 6.1, Dockerfile improvements, test fixes (vllm-project#5422)

* [Hardware][TPU] Raise errors for unsupported sampling params (vllm-project#5850)

* [CI/Build] Add E2E tests for MLPSpeculator (vllm-project#5791)

Signed-off-by: Thomas Parnell <[email protected]>

* [Bugfix] Fix assertion in NeuronExecutor (vllm-project#5841)

* [Core] Refactor Worker and ModelRunner to consolidate control plane communication (vllm-project#5408)

Signed-off-by: Stephanie Wang <[email protected]>
Signed-off-by: Stephanie <[email protected]>
Co-authored-by: Stephanie <[email protected]>

* [Misc][Doc] Add Example of using OpenAI Server with VLM (vllm-project#5832)

* [bugfix][distributed] fix shm broadcast when the queue size is full (vllm-project#5801)

* [Bugfix] Fix embedding to support 2D inputs (vllm-project#5829)

* [Bugfix][TPU] Fix KV cache size calculation (vllm-project#5860)

* [CI/Build] Refactor image test assets (vllm-project#5821)

* [Kernel] Adding bias epilogue support for `cutlass_scaled_mm` (vllm-project#5560)

Co-authored-by: Chih-Chieh-Yang <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>

* [Frontend] Add tokenize/detokenize endpoints (vllm-project#5054)

* [Hardware][TPU] Support parallel sampling & Swapping (vllm-project#5855)

* [Bugfix][TPU] Fix CPU cache allocation (vllm-project#5869)

* Support CPU inference with VSX PowerPC ISA (vllm-project#5652)

* [doc] update usage of env var to avoid conflict (vllm-project#5873)

* [Misc] Add example for LLaVA-NeXT (vllm-project#5879)

* [BugFix] Fix cuda graph for MLPSpeculator (vllm-project#5875)

Co-authored-by: Abhinav Goyal <[email protected]>

* [Doc] Add note about context length in Phi-3-Vision example (vllm-project#5887)

* [VLM][Bugfix] Make sure that `multi_modal_kwargs` is broadcasted properly (vllm-project#5880)

Signed-off-by: Xiaowei Jiang <[email protected]>

* [Model] Add base class for LoRA-supported models (vllm-project#5018)

* [Bugfix] Fix img_sizes Parsing in Phi3-Vision (vllm-project#5888)

* [CI/Build] [1/3] Reorganize entrypoints tests (vllm-project#5526)

* add collective crash WA

* add comment to the weird mark_step

* [Model][Bugfix] Implicit model flags and reenable Phi-3-Vision (vllm-project#5896)

* [doc][misc] add note for Kubernetes users (vllm-project#5916)

* [BugFix] Fix `MLPSpeculator` handling of `num_speculative_tokens` (vllm-project#5876)

* [BugFix] Fix `min_tokens` behaviour for multiple eos tokens (vllm-project#5849)

* [CI/Build] Fix Args for `_get_logits_warper` in Sampler Test (vllm-project#5922)

* [Model] Add Gemma 2 (vllm-project#5908)

* [core][misc] remove logical block (vllm-project#5882)

* [Kernel][ROCm][AMD] fused_moe Triton configs v2 for mi300X (vllm-project#5932)

* [Hardware][TPU] Optimize KV cache swapping (vllm-project#5878)

* [VLM][BugFix] Make sure that `multi_modal_kwargs` can broadcast properly with ring buffer. (vllm-project#5905)

Signed-off-by: Xiaowei Jiang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>

* [Bugfix][Hardware][Intel CPU] Fix unpassed multi_modal_kwargs for CPU runner (vllm-project#5956)

* [Core] Registry for processing model inputs (vllm-project#5214)

Co-authored-by: ywang96 <[email protected]>

* Unmark fused_moe config json file as executable (vllm-project#5960)

* [Hardware][Intel] OpenVINO vLLM backend (vllm-project#5379)

* [Bugfix] Better error message for MLPSpeculator when `num_speculative_tokens` is set too high (vllm-project#5894)

Signed-off-by: Thomas Parnell <[email protected]>

* [CI/Build] [2/3] Reorganize entrypoints tests (vllm-project#5904)

* [Distributed] Make it clear that % should not be in tensor dict keys. (vllm-project#5927)

Signed-off-by: Xiaowei Jiang <[email protected]>

* [Spec Decode] Introduce DraftModelRunner (vllm-project#5799)

* [Bugfix] Fix compute datatype for cutlass 3.x epilogues (vllm-project#5931)

* [ Misc ] Remove `fp8_shard_indexer` from Col/Row Parallel Linear (Simplify Weight Loading) (vllm-project#5928)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* [ Bugfix ] Enabling Loading Models With Fused QKV/MLP on Disk with FP8 (vllm-project#5921)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* Support Deepseek-V2 (vllm-project#4650)

Co-authored-by: Philipp Moritz <[email protected]>

* [Bugfix] Only add `Attention.kv_scale` if kv cache quantization is enabled (vllm-project#5936)

* Unmark more files as executable (vllm-project#5962)

* [Bugfix] Fix Engine Failing After Invalid Request - AsyncEngineDeadError (vllm-project#5963)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* [Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode (vllm-project#4628)

Co-authored-by: LiuXiaoxuanPKU <[email protected]>, bong-furiosa <[email protected]>

* [Bugfix][TPU] Fix TPU sampler output (vllm-project#5978)

* [Bugfix][TPU] Fix pad slot id (vllm-project#5977)

* [Bugfix] fix missing last itl in openai completions benchmark (vllm-project#5926)

* [Misc] Extend vLLM Metrics logging API (vllm-project#5925)

Co-authored-by: Antoni Baum <[email protected]>

* [Kernel] Add punica dimensions for Granite 3b and 8b (vllm-project#5930)

Signed-off-by: Joe Runde <[email protected]>

* [Bugfix] Fix precisions in Gemma 1 (vllm-project#5913)

* [Misc] Update Phi-3-Vision Example (vllm-project#5981)

Co-authored-by: Cyrus Leung <[email protected]>

* [Bugfix] Support `eos_token_id` from `config.json` (vllm-project#5954)

* [Core] Optimize `SequenceStatus.is_finished` by switching to IntEnum (vllm-project#5974)

* [Kernel] Raise an exception in MoE kernel if the batch size is larger then 65k (vllm-project#5939)

* [ CI/Build ] Added E2E Test For Compressed Tensors (vllm-project#5839)

Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* [CI/Build] Add TP test for vision models (vllm-project#5892)

* [ CI/Build ] LM Eval Harness Based CI Testing (vllm-project#5838)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* [Bugfix][CI/Build][Hardware][AMD] Install matching torchvision to fix AMD tests (vllm-project#5949)

* [CI/Build] Temporarily Remove Phi3-Vision from TP Test (vllm-project#5989)

* [CI/Build] Reuse code for checking output consistency (vllm-project#5988)

* [CI/Build] [3/3] Reorganize entrypoints tests (vllm-project#5966)

* [ci][distributed] fix device count call

[ci][distributed] fix some cuda init that makes it necessary to use spawn (vllm-project#5991)

* [Frontend]: Support base64 embedding (vllm-project#5935)

Co-authored-by: Cyrus Leung <[email protected]>

* [Lora] Use safetensor keys instead of adapter_config.json to find unexpected modules.  (vllm-project#5909)

Co-authored-by: sang <[email protected]>

* [ CI ] Temporarily Disable Large LM-Eval Tests (vllm-project#6005)

Co-authored-by: [email protected] <rshaw@neuralmagic>

* [Misc] Fix `get_min_capability` (vllm-project#5971)

* [ Misc ] Refactor w8a8 to use `process_weights_after_load` (Simplify Weight Loading) (vllm-project#5940)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>

* [misc][cuda] use nvml to avoid accidentally cuda initialization (vllm-project#6007)

* [Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (vllm-project#5348)

* Revert test changes

* cleanup

* llm engine cleanup

* utils.py cleanup

* custom ops refactor

* move xops to ops

* remove vllm/hpu/attn_bias.py

* whitespace fix

* revert accidental changes in rmsnorm

* Fix hpugraph hashing

* add trim_attn_metadata comment

* fix prompt bucketing:

* [ CI ] Re-enable Large Model LM Eval (vllm-project#6031)

* [doc][misc] remove deprecated api server in doc (vllm-project#6037)

* [Misc] update benchmark backend for scalellm (vllm-project#6018)

* [doc][misc] further lower visibility of simple api server (vllm-project#6041)

Co-authored-by: Simon Mo <[email protected]>

* [Bugfix] Use RayActorError for older versions of Ray in  RayTokenizerGroupPool (vllm-project#6039)

* [Bugfix] adding chunking mechanism to fused_moe to handle large inputs (vllm-project#6029)

* add FAQ doc under 'serving' (vllm-project#5946)

* [Bugfix][Doc] Fix Doc Formatting (vllm-project#6048)

* [Bugfix] Add explicit `end_forward` calls to flashinfer (vllm-project#6044)

* [BugFix] Ensure worker model loop is always stopped at the right time (vllm-project#5987)

* [Frontend] Relax api url assertion for openai benchmarking (vllm-project#6046)

* [Model] Changes to MLPSpeculator to support tie_weights and input_scale (vllm-project#5965)

Signed-off-by: Thomas Parnell <[email protected]>
Co-authored-by: Joshua Rosenkranz <[email protected]>

* [Core] Optimize block_manager_v2 vs block_manager_v1 (to make V2 default)  (vllm-project#5602)

* [Frontend] Add template related params to request (vllm-project#5709)

* [VLM] Remove `image_input_type` from VLM config (vllm-project#5852)

Signed-off-by: Xiaowei Jiang <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Roger Wang <[email protected]>

* [Doc] Reinstate doc dependencies (vllm-project#6061)

* guard model loader wa for hpu

---------

Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Lei Wen <[email protected]>
Signed-off-by: Joe Runde <[email protected]>
Signed-off-by: kevin <[email protected]>
Signed-off-by: Rafael Vasquez <[email protected]>
Signed-off-by: Stephanie Wang <[email protected]>
Signed-off-by: Stephanie <[email protected]>
Signed-off-by: Xiaowei Jiang <[email protected]>
Signed-off-by: Joe Runde <[email protected]>
Co-authored-by: Li, Jiang <[email protected]>
Co-authored-by: Jianan Gu <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: zifeitong <[email protected]>
Co-authored-by: Robert Shaw <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
Co-authored-by: Philipp Moritz <[email protected]>
Co-authored-by: Antoni Baum <[email protected]>
Co-authored-by: Jie Fu (傅杰) <[email protected]>
Co-authored-by: Allen.Dou <[email protected]>
Co-authored-by: Simon Mo <[email protected]>
Co-authored-by: Kuntai Du <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
Co-authored-by: Sanger Steel <[email protected]>
Co-authored-by: Thomas Parnell <[email protected]>
Co-authored-by: leiwen83 <[email protected]>
Co-authored-by: Lei Wen <[email protected]>
Co-authored-by: SangBin Cho <[email protected]>
Co-authored-by: Alexander Matveev <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Co-authored-by: Amit Garg <[email protected]>
Co-authored-by: Charles Riggins <[email protected]>
Co-authored-by: Liqian Chen <[email protected]>
Co-authored-by: zhyncs <[email protected]>
Co-authored-by: Kunshang Ji <[email protected]>
Co-authored-by: Abhilash Majumder <[email protected]>
Co-authored-by: Abhilash Majumder <[email protected]>
Co-authored-by: Bruce Fontaine <[email protected]>
Co-authored-by: zifeitong <[email protected]>
Co-authored-by: sroy745 <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Joe Runde <[email protected]>
Co-authored-by: Chang Su <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: Kevin H. Luu <[email protected]>
Co-authored-by: Ronen Schaffer <[email protected]>
Co-authored-by: sergey-tinkoff <[email protected]>
Co-authored-by: milo157 <[email protected]>
Co-authored-by: Shukant Pal <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: DearPlanet <[email protected]>
Co-authored-by: Rafael Vasquez <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Joshua Rosenkranz <[email protected]>
Co-authored-by: Davis Wertheimer <[email protected]>
Co-authored-by: Jinzhen Lin <[email protected]>
Co-authored-by: Jee Li <[email protected]>
Co-authored-by: rohithkrn <[email protected]>
Co-authored-by: Murali Andoorveedu <[email protected]>
Co-authored-by: Woo-Yeon Lee <[email protected]>
Co-authored-by: Matt Wong <[email protected]>
Co-authored-by: aws-patlange <[email protected]>
Co-authored-by: Stephanie Wang <[email protected]>
Co-authored-by: Stephanie <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Chih-Chieh-Yang <[email protected]>
Co-authored-by: Lucas Wilkinson <[email protected]>
Co-authored-by: sasha0552 <[email protected]>
Co-authored-by: Chip Kerchner <[email protected]>
Co-authored-by: Abhinav Goyal <[email protected]>
Co-authored-by: xwjiang2010 <[email protected]>
Co-authored-by: Divakar Verma <[email protected]>
Co-authored-by: Ilya Lavrenov <[email protected]>
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: wangding zeng <[email protected]>
Co-authored-by: Lily Liu <[email protected]>
Co-authored-by: LiuXiaoxuanPKU <[email protected]>, bong-furiosa <[email protected]>
Co-authored-by: mcalman <[email protected]>
Co-authored-by: William Lin <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: llmpros <[email protected]>
Co-authored-by: sang <[email protected]>
Co-authored-by: Avshalom Manevich <[email protected]>
Co-authored-by: James Whedbee <[email protected]>
Co-authored-by: Joshua Rosenkranz <[email protected]>
Co-authored-by: danieljannai21 <[email protected]>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 8, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Help wanted] [Spec decode]: Increase acceptance rate via Medusa's typical acceptance
3 participants