From 7e6175f5b1c51e691be394b72720ced3630be3b6 Mon Sep 17 00:00:00 2001 From: femto Date: Sun, 3 Nov 2024 16:07:50 +0800 Subject: [PATCH] git --- optillm/bon.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/optillm/bon.py b/optillm/bon.py index 8ee752a..7f38ed9 100644 --- a/optillm/bon.py +++ b/optillm/bon.py @@ -9,18 +9,30 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st {"role": "user", "content": initial_query}] completions = [] - - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=4096, - n=n, - temperature=1 - ) - completions = [choice.message.content for choice in response.choices] + + try: + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=4096, + n=n, + temperature=1 + ) + completions = [choice.message.content for choice in response.choices] + bon_completion_tokens += response.usage.completion_tokens + except: + for _ in range(n): + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=4096, + n=1, + temperature=1 + ) + completions.extend([choice.message.content for choice in response.choices]) + bon_completion_tokens += response.usage.completion_tokens logger.info(f"Generated {len(completions)} initial completions. Tokens used: {response.usage.completion_tokens}") - bon_completion_tokens += response.usage.completion_tokens - + # Rate the completions rating_messages = messages.copy() rating_messages.append({"role": "system", "content": "Rate the following responses on a scale from 0 to 10, where 0 is poor and 10 is excellent. Consider factors such as relevance, coherence, and helpfulness. Respond with only a number."})