Support repetition_penalty #1424

Merged — 2 commits into vllm-project:main on Oct 29, 2023
Conversation

@beginlner (Contributor)

It has the same behavior as this.
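Usage would look roughly like this (a minimal sketch based on the examples further down in this thread; as in HF, values greater than 1.0 penalize repetition):

from vllm import LLM, SamplingParams

# repetition_penalty follows the HF convention: > 1.0 discourages tokens
# that have already appeared.
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=64,
    repetition_penalty=1.2,
)
llm = LLM(model="huggyllama/llama-7b")
output = llm.generate(["Hello, what is apple?  "], sampling_params)[0]
print(output.outputs[0].text)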

@WoosukKwon (Collaborator)

Hi @beginlner, Could you tell us how it is different from #1392?

@beginlner (Contributor, Author) commented Oct 23, 2023

> Hi @beginlner, Could you tell us how it is different from #1392?

Hi, I think we implemented completely identical functions.

@zhuohan123 (Member) left a comment


LGTM! Thank you for your contribution! Added a small style fix.

@zhuohan123 zhuohan123 merged commit 69be658 into vllm-project:main Oct 29, 2023
2 checks passed
@resorcap commented Oct 30, 2023

In Hugging Face, input_token_ids contains the prompt tokens, but this PR only applies the penalty to generated tokens, so the behavior is not the same.
@zhuohan123 @beginlner
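For context, transformers applies its repetition penalty roughly like this (a sketch of the upstream RepetitionPenaltyLogitsProcessor logic, not vLLM code); note that input_ids here covers the prompt as well as previously generated tokens:

import torch

def hf_style_repetition_penalty(scores: torch.Tensor,
                                input_ids: torch.Tensor,
                                penalty: float) -> torch.Tensor:
    # Gather the logits of every token id that already appears in input_ids
    # (prompt tokens included, which is the point being raised above).
    score = torch.gather(scores, 1, input_ids)
    # Divide positive logits and multiply negative ones, so penalty > 1
    # always makes repeated tokens less likely.
    score = torch.where(score < 0, score * penalty, score / penalty)
    return scores.scatter(1, input_ids, score)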

@zwj536 commented Oct 31, 2023

I'm getting inconsistent results between HF and vllm with llama-7b @beginlner @WoosukKwon

## hf
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM
)

MODEL_NAME = "huggyllama/llama-7b"
#MODEL_NAME = "huggyllama/llama-13b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).cuda()  

prompt = [
    "Hello, what is apple?  ",
]
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
generated_ids = model.generate(
    input_ids, 
    do_sample=False, 
    repetition_penalty=1.2, 
    max_new_tokens=64,
)

texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
""" OUTPUT
Hello, what is orange?   \nHello, what are you doing here?\n— _The Wizard of Oz_ , L. Frank Baum (1900)\n#  **What Is a Computer?**\nA computer is an electronic device that can store and process information in the'
"""

## vllm
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, what is apple?  ",
]
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=64, 
    #frequency_penalty=1.2,
    #presence_penalty=1.2,
    repetition_penalty=1.2,
)

MODEL_NAME = "huggyllama/llama-7b"
#MODEL_NAME = "huggyllama/llama-13b"

llm = LLM(
    model=MODEL_NAME,
    trust_remote_code=True,
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

""" OUTPUT
Hello, what is apple?  \nHello, what is apple?  \nHello, what is apple?  \nHello, what is apple?  \nWhat\'s that you say?  What\'s that you say?"  And so on. The child will repeat the question until he gets an answer
"""

@resorcap commented Nov 1, 2023

> I'm getting inconsistent results between HF and vllm with llama-7b @beginlner @WoosukKwon
> [quoting the HF/vLLM comparison above; in the quoted version SamplingParams is set with frequency_penalty=1.2 rather than repetition_penalty=1.2]

Use repetition_penalty instead of frequency_penalty in vLLM. Another remaining defect is that the penalized input_ids are inconsistent: HF includes the prompt tokens, while this PR penalizes only the generated tokens.
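For reference, the two penalty families behave quite differently; a minimal sketch (not vLLM's actual code), assuming output_counts holds per-vocab-entry counts of previously generated tokens:

import torch

def apply_penalties(logits: torch.Tensor,        # [vocab_size]
                    output_counts: torch.Tensor, # [vocab_size], counts of generated tokens
                    presence_penalty: float = 0.0,
                    frequency_penalty: float = 0.0,
                    repetition_penalty: float = 1.0) -> torch.Tensor:
    # OpenAI-style penalties are additive: a fixed subtraction per occurrence
    # (frequency) or once per distinct seen token (presence).
    logits = logits - frequency_penalty * output_counts
    logits = logits - presence_penalty * (output_counts > 0).float()
    # HF-style repetition_penalty is multiplicative: divide positive logits
    # and multiply negative ones for every token that has already appeared.
    seen = output_counts > 0
    logits = torch.where(seen & (logits > 0), logits / repetition_penalty, logits)
    logits = torch.where(seen & (logits <= 0), logits * repetition_penalty, logits)
    return logits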

@zwj536 commented Nov 1, 2023

> [quoting @resorcap's reply above: use repetition_penalty instead of frequency_penalty in vLLM, and the penalized input_ids are still inconsistent]

@resorcap Hi, the result with repetition_penalty is still inconsistent:

"""
# hf
# repetition_penalty=1.2
Hello, what is orange?   \nHello, what are you doing here?\n— _The Wizard of Oz_ , L. Frank Baum (1900)\n#  **What Is a Computer?**\nA computer is an electronic device that can store and process information in the

# vllm
# frequency_penalty=1.2
Hello, what is apple?  \nHello, what is apple?  \nHello, what is apple?  \nHello, what is apple?  \nWhat\'s that you say?  What\'s that you say?"  And so on. The child will repeat the question until he gets an answer

# presence_penalty=1.2
Hello, what is apple?  \nHello, what is apple?  \nHello, what is apple?  \nHello, what is apple?  \nHello, what is apple?  \nHello, what is apple?  \nHello, what is apple?  \nHello, what is apple? 

# repetition_penalty=1.2,
Hello, what is apple?  \nWhat's that you say?  It's an orange.   \nNo it isn't! No it isn't! I know a banana when I see one and this ain't no banana. This here is an apple all right but not
"""

@beginlner (Contributor, Author)

Hi @resorcap @zwj536, thank you for the correction; I have fixed it in #1577.
