
Align with huggingface Top K sampling #753

Merged: 3 commits into vllm-project:main, Aug 15, 2023

Conversation

Abraham-Xu (Contributor)

main modifications
All modifications are in vllm/model_executor/layers/sampler.py; a sketch of the resulting flow follows the list.

  1. Reverse the order of the softmax and the _apply_top_p_top_k call in the forward function of the Sampler class, so that _apply_top_p_top_k takes logits as input instead of probabilities.
  2. In the _apply_top_p_top_k function, first compute temporary probabilities (the softmax of the logits) so the top_p step can locate the indices whose cumulative probability exceeds p. (ATTENTION: the output of _apply_top_p_top_k is still logits, not probabilities. This is the main difference from the original code.)
  3. Change the value used to mask out top_p- and top_k-filtered entries from 0 to -float("Inf"), so the masked entries get zero probability after the softmax.
  4. In the _sample_from_prompt function, call torch.multinomial without the replacement=True argument.
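A minimal sketch of the flow described above (the function and tensor names here are illustrative, not the exact code in sampler.py):

import torch

def filter_logits(logits, top_k, top_p):
    # Top-k: keep the k largest logits and mask the rest with -inf.
    if top_k > 0:
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, -float("inf"))
    # Top-p: softmax temporarily to find where the cumulative
    # probability exceeds p, then mask the corresponding logits.
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_mask = cumulative_probs > top_p
        # Shift right so the first token that crosses p is kept.
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(mask, -float("inf"))
    # The output is still logits; the softmax happens afterwards.
    return logits

# Softmax only after filtering, then sample without replacement.
logits = torch.randn(1, 32000)
probs = torch.softmax(filter_logits(logits, top_k=50, top_p=0.9), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)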

tested result
The input probability distribution fed to torch.multinomial for the first token matches that of huggingface/transformers under the same weights and input sentence.
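For completeness, one hypothetical way to check this claim (the dump files and tolerance are illustrative; in practice one would print or save the probs tensor at the torch.multinomial call site in each framework):

import torch

# Hypothetical: probability tensors saved from each framework at the
# torch.multinomial call site during the runs below.
hf_probs = torch.load("hf_probs.pt")
vllm_probs = torch.load("vllm_probs.pt")

print(torch.allclose(hf_probs, vllm_probs, atol=1e-5))
# Side-by-side view of the most likely tokens.
print(torch.topk(hf_probs, 10))
print(torch.topk(vllm_probs, 10))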

test code:
huggingface/transformers

from transformers import AutoTokenizer, AutoModelForCausalLM

PATH_TO_CONVERTED_WEIGHTS = "/data/xutianci/llama_hf/"
PATH_TO_CONVERTED_TOKENIZER = "/data/xutianci/vllm/llama-tokenizer/"

model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

prompt = "Hello, my name is"
inputs = tokenizer(prompt, return_tensors="pt")
print(f'inputs={inputs}')

# Generate
generate_ids = model.generate(inputs.input_ids, do_sample=True, max_new_tokens=1)
print(f'generate_ids={generate_ids}')
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(f'output={output}')
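Note: with do_sample=True and no explicit overrides, transformers samples with its default generation config (top_k=50, top_p=1.0, temperature=1.0), which is what the explicit top_k=50 in the vllm run below is matching.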

vllm

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
]
sampling_params = SamplingParams(top_k=50, max_tokens=1)

llm = LLM(model="/data/xutianci/llama_hf/", tokenizer="/data/xutianci/vllm/llama-tokenizer/")

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

related issue
https://github.com/vllm-project/vllm/issues/718

Abraham-Xu changed the title from "Align with huggingface greedy search" to "Align with huggingface Top K sampling" on Aug 13, 2023
zhuohan123 (Member) left a comment


LGTM! Thanks for the fix! Committed some code to make sure the format is correct and to rename some variables.

zhuohan123 merged commit d174437 into vllm-project:main on Aug 15, 2023
2 checks passed
Abraham-Xu (Contributor, Author)

LGTM! Thanks for the fix! Committed some code to make sure the format is correct and to rename some variables.

Glad to see it merged. Thanks for checking and revising! The renaming makes the algorithm more self-explanatory!

randxie pushed a commit to randxie/vllm that referenced this pull request Aug 29, 2023
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
sjchoi1 pushed a commit to casys-kaist-internal/vllm that referenced this pull request May 7, 2024