Tutorial for Batch Decoding and Obtaining Log Probs #81

Closed
aflah02 opened this issue Jan 23, 2024 · 24 comments
@aflah02 commented Jan 23, 2024

Hi,
Thanks for the great library!
I have a use case which I think will benefit a lot from RadixAttention. I need to obtain log probs for around 100K sequences, which can be binned into groups of 100 sharing a prefix like 'Wikipedia originated in' but having 100 different suffixes. I do not need to generate anything; I only need the log probs for the input. Is there a tutorial for such a use case?

@merrymercy commented Jan 23, 2024

Yes, RadixAttention can help your case a lot. We do not have this interface/tutorial ready, but I can easily make one for you.

What do you need specifically?

Given an input token sequence [a, b, c, d]:

  1. the probability of all tokens at each location, which is a tensor of shape [4, 32000]
  2. the logprob for each selected token, which is a tensor of shape [4, 1]
  3. the logprob for the whole sequence (sum), which is a scalar
  4. the logprob for the whole sequence (mean), which is a scalar
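
For reference, a minimal sketch of how these four quantities relate, assuming PyTorch (the names logits and input_ids are illustrative, not part of any sglang API, and the sketch ignores the usual one-position shift between logits and the tokens they score):

import torch

vocab_size = 32000
input_ids = torch.tensor([101, 202, 303, 404])  # the 4 tokens [a, b, c, d] (illustrative ids)
logits = torch.randn(4, vocab_size)             # model output at each of the 4 positions

# 1. logprobs of all tokens at each location: shape [4, 32000]
all_logprobs = torch.log_softmax(logits, dim=-1)

# 2. logprob of each selected token: shape [4, 1]
token_logprobs = all_logprobs.gather(1, input_ids.unsqueeze(1))

# 3. logprob of the whole sequence (sum): a scalar
seq_logprob_sum = token_logprobs.sum()

# 4. logprob of the whole sequence (mean), i.e. the normalized logprob: a scalar
seq_logprob_mean = token_logprobs.mean()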

@aflah02 commented Jan 23, 2024

@merrymercy Thanks for the quick response!
I just need the log probs of the selected tokens at each position in the given sequence (option 2). Here is what I currently do with vLLM:

from vllm import LLM, SamplingParams

sampling_params = SamplingParams(max_tokens=1, prompt_logprobs=1)
llm = LLM(model=model_path, tokenizer=model_path)
outputs = llm.generate(all_input_texts, sampling_params=sampling_params)
# log probs of the prompt tokens for the i-th sample
outputs[i].prompt_logprobs

@merrymercy commented Jan 23, 2024

Do you also need the logprob for the shared prefix?
Unfortunately, we do not store the logprob for the prefix. We only store the KV cache, so it is not very easy to also return the logprob for the shared prefix.

@aflah02 commented Jan 23, 2024

@merrymercy Nope, I don't need it for the shared prefix, only for the non-shared portions.

For example, if the sentences are "Wikipedia originated in India", "Wikipedia originated in U.S.A.", etc., I need it only for "India", "U.S.A.", etc.

@merrymercy

Great! This is easier. What do you do with the logprob? Do you compute the normalized logprob for selection purposes?
Actually, choices in sglang are implemented by comparing the logprobs of these choices:

s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "

@aflah02 commented Jan 23, 2024

Yeah I use the normalized logprobs and store them for later analysis. This example looks very relevant. If I understand correctly, something like this would populate the most likely choice from the options, right?

import sglang as sgl

@sgl.function
def tool_use(s, question):
    s += question
    s += sgl.gen("tool", choices=["U.S.A", "India"])

runtime = sgl.Runtime(model_path='Model_Saves/teknium--OpenHermes-2.5-Mistral-7B')
sgl.set_default_backend(runtime)

state = tool_use.run(question="Wikipedia originated in ")

How can I now access the logprobs as well?

@merrymercy commented Jan 23, 2024

Yes, it will populate the most likely choice from the options based on the normalized logprob (the sum of the token logprobs divided by the number of tokens).
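For instance, if a choice tokenizes into two tokens with logprobs [-4.42, -0.0002], its normalized logprob is (-4.42 + (-0.0002)) / 2 ≈ -2.21, and the choice with the highest normalized logprob is selected.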

I am working on some examples and interface updates for you to easily get the logprobs. I will upload them very soon!

@aflah02 commented Jan 23, 2024

Thank you for taking the time! That would be really helpful!

@merrymercy

@aflah02 Could you try this with the main branch? Does it meet your needs?

https://github.com/sgl-project/sglang/blob/main/examples/usage/choices_logprob.py

Output:

questions: What is 5 + 5?
choice: calculator
logprobs of choice 1 [-4.4240264892578125, -0.0002205615019192919]
logprobs of choice 2 [-12.680765151977539, -0.08715292066335678]
--------------------------------------------------
questions: What is 5 + 6?
choice: calculator
logprobs of choice 1 [-5.266744136810303, -0.00022354240354616195]
logprobs of choice 2 [-12.893030166625977, -0.09100916236639023]
--------------------------------------------------
questions: Who is Michael Jordan?
choice: search engine
logprobs of choice 1 [-10.858648300170898, -0.002947198925539851]
logprobs of choice 2 [-6.427036762237549, -0.00434991717338562]
--------------------------------------------------

@aflah02 commented Jan 23, 2024

Thanks a lot for sharing this! I need to install from source and then try this, right?
I'll do this by tomorrow and let you know!

@merrymercy

Yes

@Ja1Zhou commented Jan 25, 2024

Hi, I see that there is a parameter that can be passed via requests here to return logprobs:

req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len

Is there a way we could specify this from the Python end with sgl.gen or SglFunction.run?

@merrymercy commented Jan 25, 2024

@Ja1Zhou It is possible. I can work on an interface for this later.
What kind of logprob do you need?

Do you need the logprob of prompts, the logprob of generation, the logprob of selected tokens, or the logprob of top-5 tokens?

@Ja1Zhou commented Jan 25, 2024

Many thanks! Currently I would need logprobs of the top-5 (or top-n, passed as a parameter) tokens for each generated token. The scenario is essentially the same as passing the top_logprobs=n parameter to the OpenAI API.

An example would be the top_logprobs field in this discussion.

One related question is whether the regex constraint is going to affect the top_logprobs returned.

Thanks again for the swift reply. I would also love to look into supporting this logprobs feature!
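
For reference, a minimal sketch (assuming PyTorch; logits is an illustrative name) of what "top-5 logprobs per generated token" means:

import torch

# Hypothetical logits for one generated position over a 32000-token vocabulary.
logits = torch.randn(32000)
log_probs = torch.log_softmax(logits, dim=-1)

# Top-5 logprobs and their token ids, analogous to OpenAI's top_logprobs field.
top = torch.topk(log_probs, k=5)
print(top.values)   # the 5 largest logprobs, in descending order
print(top.indices)  # the corresponding token ids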

@merrymercy

Great! If you are interested, please go ahead. Our bandwidth is limited so your help would be great.

You can start from:

import requests

url = "http://localhost:30000"  # assuming the server was launched locally on port 30000

response = requests.post(
    url + "/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
        # "return_logprob": True,
        # "logprob_start_len": 0,
    },
)
print(response.json())

@aflah02 commented Jan 26, 2024

Sorry for the delay @merrymercy
Thanks a lot!
This works really well.
Leaving the issue open though as it seems there's another ongoing discussion, but my original issues have been resolved.

@mlinegar

@merrymercy in your example, the exponentiated log probs don't seem to sum to one. I've been running this locally with Mistral 7B:

# launch server:
# python -m sglang.launch_server --model-path /user/models/Mistral-7B-Instruct-v0.2-AWQ --port 30000
import numpy as np
import sglang as sgl
from sglang import RuntimeEndpoint, set_default_backend

set_default_backend(RuntimeEndpoint("http://localhost:30000"))
@sgl.function
def tool_use(s, question):
    s += "To answer this question: " + question + ", "
    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])
# Run one case
question = "What is 5 + 5?"
state = tool_use.run(question)
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
print("probs of choice 1", np.exp(meta_info["prompt_logprob"][0]))
print("probs of choice 2", np.exp(meta_info["prompt_logprob"][1]))
print("prob sum", np.exp(meta_info["prompt_logprob"][0][0]) + np.exp(meta_info["prompt_logprob"][0][1]))
print("prob sum", np.exp(meta_info["prompt_logprob"][1][0]) + np.exp(meta_info["prompt_logprob"][1][1]))
print('-' * 50)
# Run a batch
questions = [
    "What is 5 + 6?",
    "Who is Michael Jordan?",
]
states = tool_use.run_batch([{"question": q} for q in questions])
for question, state in zip(questions, states):
    print("questions:", question)
    print("choice:", state["tool"])
    meta_info = state.get_meta_info("tool")
    print("logprobs of choice 1", meta_info["prompt_logprob"][0])
    print("logprobs of choice 2", meta_info["prompt_logprob"][1])
    print("probs of choice 1", np.exp(meta_info["prompt_logprob"][0]))
    print("probs of choice 2", np.exp(meta_info["prompt_logprob"][1]))
    print("prob sum", np.exp(meta_info["prompt_logprob"][0][0]) + np.exp(meta_info["prompt_logprob"][0][1]))
    print("prob sum", np.exp(meta_info["prompt_logprob"][1][0]) + np.exp(meta_info["prompt_logprob"][1][1]))    
    print('-' * 50)

With output:

questions: What is 5 + 5?
choice: calculator
logprobs of choice 1 [-8.053388595581055, -0.011829511262476444]
logprobs of choice 2 [-12.069820404052734, -0.0010686860186979175]
probs of choice 1 [3.18022445e-04 9.88240182e-01]
probs of choice 2 [5.72985459e-06 9.98931885e-01]
prob sum 0.9885582047666096
prob sum 0.9989376146774299
--------------------------------------------------
questions: What is 5 + 6?
choice: calculator
logprobs of choice 1 [-6.829063892364502, -0.009234977886080742]
logprobs of choice 2 [-11.620172500610352, -0.0014222837053239346]
probs of choice 1 [0.00108187 0.99080753]
probs of choice 2 [8.98303733e-06 9.98578727e-01]
prob sum 0.9918894039483331
prob sum 0.9985877102981204
--------------------------------------------------
questions: Who is Michael Jordan?
choice: search engine
logprobs of choice 1 [-11.84224796295166, -0.2913426160812378]
logprobs of choice 2 [-9.289620399475098, -0.0018001894932240248]
probs of choice 1 [7.19410971e-06 7.47259611e-01]
probs of choice 2 [9.23781204e-05 9.98201430e-01]
prob sum 0.7472668051043772
prob sum 0.9982938079964204

Am I doing anything wrong? Ideally, I would expect the exponentiated log probs of a binary choice to sum to one.

@merrymercy

@mlinegar They are not binary log probs. They are log probs over the whole vocabulary, with the same meaning as the logprobs defined in the OpenAI API.
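For example, "calculator" and "search engine" are only two entries in a vocabulary of roughly 32000 tokens; each probability comes from the same softmax over the full vocabulary, so the two together can sum to anything between 0 and 1.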

For any new questions, please open a new issue.

@merrymercy

@aflah02 Did you notice any performance improvement vs vllm or other libraries?

@aflah02 commented Jan 30, 2024

@merrymercy Yep, it's a very significant speed-up over vLLM for my use case :)
Thanks for this library. My only pain point is that some of the models I'm using are not supported in sglang for now, so I need to use vLLM for them. It would be great to have support for the Pythia, OPT and Falcon models.
I also fail to load Mixtral, but I haven't had the time to open an issue yet. Let me do that.

@merrymercy

We currently do not have the bandwidth to add these models. If you are interested, you can help us contribute them.

Adding a new model is very easy. We use an architecture very similar to vLLM. Here are the steps to add a new model:

  1. Compare these two files (https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama2.py, https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). You can learn how to convert a model implementation from vLLM to SGLang. We need to replace PagedAttention with RadixAttention; the other parts are almost the same (see the sketch after this list).
  2. Convert models like OPT, Falcon, and Pythia from vLLM to SGLang.
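
As a rough illustration of step 1 only (the argument lists below are simplified guesses, not exact signatures; the two linked files are the source of truth):

# vLLM-style attention layer (simplified):
#   from vllm.model_executor.layers.attention import PagedAttention
#   self.attn = PagedAttention(num_heads, head_dim, scaling, num_kv_heads=num_kv_heads)
#
# SGLang-style equivalent (simplified):
#   from sglang.srt.layers.radix_attention import RadixAttention
#   self.attn = RadixAttention(num_heads, head_dim, scaling, num_kv_heads, layer_id)
#
# The rest of the model definition (embeddings, MLP, layer norms, weight loading)
# carries over from vLLM almost unchanged.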

@aflah02 commented Jan 30, 2024

Thanks! I'll take a look at this

@aliencaocao

Hi, how do I get the last_logits out? I don't need the logprob for every token, just the last one.
I am using LLaVA-1.6 Mistral 7B, and it has a bug where I get an image encoding tensor dimension error whenever I use the choices argument, so I cannot use it. But if I don't use it, how do I get the logits? I checked the source code and there seems to be a part where I can get the last logit instead of the logprob for every token. How do I achieve that?

@FredericOdermatt commented Oct 1, 2024

As a note for anyone coming to this issue: #1495 merged two days ago adds return_text_in_logprobs: True to the select call. The tutorial now works nicely, returning both token probabilities and tokens in text space:

sglang/examples/frontend_language/usage/choices_logprob.py with Mixtral22B:

questions: What is 5 + 5?
choice: calculator
logprobs of choice 1 [[-1.6109180450439453, 5668, 'calcul'], [-0.0017866615671664476, 1796, 'ator']]
logprobs of choice 2 [[-7.110918045043945, 4240, 'search'], [-0.05527007207274437, 5224, 'engine']]
--------------------------------------------------
questions: What is 5 + 6?
choice: calculator
logprobs of choice 1 [[-1.3703675270080566, 5668, 'calcul'], [-0.0008510305196978152, 1796, 'ator']]
logprobs of choice 2 [[-7.495367527008057, 4240, 'search'], [-0.05314142629504204, 5224, 'engine']]
--------------------------------------------------
questions: Who is Michael Jordan?
choice: search engine
logprobs of choice 1 [[-9.60036849975586, 5668, 'calcul'], [-0.0083265770226717, 1796, 'ator']]
logprobs of choice 2 [[-1.8503687381744385, 4240, 'search'], [-0.021910740062594414, 5224, 'engine']]
