beam_utils.py

import asyncio
from enum import Enum
from typing import Callable

from openai import AsyncOpenAI
from pydantic import BaseModel, ValidationError
from tqdm import tqdm
from transformers import LlamaTokenizer
from vllm import LLM, SamplingParams

from hallu_detection_prompt import PROMPT
from utils import find_label, get_inference_funcs

token_to_id = {
    "PASS": 51935,  # The actual token is " PASS"
    "FAIL": 34207,  # The actual token is " FAIL"
}

flip_label: Callable = lambda x: "FAIL" if x == "PASS" else "PASS"
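
# A quick sketch (an assumption, not from this file) of how the hard-coded
# IDs above could be recovered for a given tokenizer -- the leading space
# matters, since " PASS" and "PASS" encode to different tokens:
#
#   ids = tokenizer.encode(" PASS", add_special_tokens=False)
#   assert len(ids) == 1  # expect a single-token encoding
#   token_to_id["PASS"] = ids[0]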


class YesNo(Enum):
    YES = "Yes"
    NO = "No"


class SimilarityResponse(BaseModel):
    similar: YesNo


async def mutate(prompt, client: AsyncOpenAI, model):
    mutate_prompt = (
        "Your job is to slightly perturb prompts while preserving their "
        "original meaning. You can feel free to modify syntactic and "
        "grammatical structure, and feel free to use reasonably similar "
        "synonyms or phrases to replace existing words. Do not, however, "
        "change the original intent of the question."
    )
    completion = await client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": mutate_prompt},
            {"role": "user", "content": f"Perturb the following prompt: {prompt}"},
        ],
        temperature=0.85,
    )
    return completion.choices[0].message.content


async def similar(prompt, client: AsyncOpenAI, model, og_question):
    # Checks whether a perturbed prompt still asks the same underlying
    # question as the original.
    similar_prompt = (
        "Your job is to compare two sentences and determine whether or not "
        "they are asking the same underlying question."
    )
    try:
        completion: SimilarityResponse = await client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": similar_prompt},
                {
                    "role": "user",
                    "content": (
                        f"Consider Sentence A: {og_question}.\n"
                        f"And now consider Sentence B: {prompt}.\n"
                        "Do the two Sentences ask very precisely the same "
                        "question? Answer either Yes or No."
                    ),
                },
            ],
            max_retries=5,
            response_model=SimilarityResponse,
        )
        return completion.similar
    except ValidationError as e:
        print(e)
        return None
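
# Note: `response_model` and `max_retries` above are not part of the raw
# OpenAI SDK; they look like the `instructor` library's patched interface.
# A sketch of the client setup this function appears to assume:
#
#   import instructor
#   from openai import AsyncOpenAI
#   client = instructor.patch(AsyncOpenAI())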


def batch_gpt_call(
    prompts: list[str], op: Callable, client: AsyncOpenAI, model="gpt-3.5-turbo"
):
    # Fans `op` (e.g. `mutate`) out over all prompts concurrently.
    async def handler(prompts):
        tasks = [op(prompt, client, model) for prompt in prompts]
        responses = await asyncio.gather(*tasks)
        return responses

    return asyncio.run(handler(prompts))
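
# A minimal usage sketch (hypothetical prompts; client setup as above):
#
#   client = instructor.patch(AsyncOpenAI())
#   perturbed = batch_gpt_call(["What is the capital of France?"], mutate, client)
#
# `op` is called as `op(prompt, client, model)`, so `similar`, which also
# needs `og_question`, has to be bound first, e.g.:
#
#   from functools import partial
#   verdicts = batch_gpt_call(perturbed, partial(similar, og_question="..."), client)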


def remove_indices(lst, indices):
    # Pop in descending index order so earlier removals don't shift the
    # positions of later ones.
    for index in sorted(indices, reverse=True):
        if 0 <= index < len(lst):
            lst.pop(index)
    return lst
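
# For example (hypothetical values): remove_indices(["a", "b", "c", "d"], [1, 3])
# returns ["a", "c"]. Note the input list is mutated in place as well as returned.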


def get_labels(
    target_model: LLM,
    model_name: str,
    tokenizer: LlamaTokenizer,
    questions: list[str],
    answer: str,
    context: str,
    og_label: str,
    # TODO: temp closer to 0
    sample_params: SamplingParams = SamplingParams(
        max_tokens=8000, logprobs=20, temperature=0
    ),
    batch_size: int = 100,
    use_logprobs: bool = False,
) -> tuple[list, list, list]:
    messages = [
        [
            {
                "role": "user",
                "content": PROMPT.format(
                    question=question, context=context, answer=answer
                ),
            }
        ]
        for question in questions
    ]
    messages, generate, parse_responses = get_inference_funcs(
        messages, target_model, tokenizer, model_name, sample_params
    )
    labels, all_responses, scores = [], [], []
    for i in tqdm(range(0, len(messages), batch_size), desc="Processing batches"):
        responses = generate(messages[i : i + batch_size])
        text_responses = parse_responses(responses)
        idx_to_remove = []
        # Use a separate index variable so the outer batch counter `i` isn't
        # shadowed.
        for j, resp in enumerate(responses):
            try:
                if use_logprobs:
                    scores.append(
                        resp.outputs[0]
                        # The 3rd-to-last token is FAIL or PASS.
                        .logprobs[-3][token_to_id[flip_label(og_label)]].logprob
                    )
                else:
                    # This devolves to random search.
                    # TODO: use logprobs from the OpenAI API. It is harder in
                    # general to know which position the PASS or FAIL token
                    # occupies, except for models trained on the same output
                    # format.
                    scores.append(10)
            except Exception as e:
                print(e)
                idx_to_remove.append(j)
                continue
        text_responses = remove_indices(text_responses, idx_to_remove)
        labels.extend(map(find_label, text_responses))
        all_responses.extend(text_responses)
    return labels, all_responses, scores
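
# A rough end-to-end sketch (hypothetical model name and data; PROMPT,
# find_label, and get_inference_funcs come from this repo's own modules):
#
#   name = "meta-llama/Llama-2-7b-chat-hf"
#   target_model = LLM(model=name)
#   tokenizer = LlamaTokenizer.from_pretrained(name)
#   labels, responses, scores = get_labels(
#       target_model,
#       name,
#       tokenizer,
#       questions=["Did the witness sign the contract?"],
#       answer="Yes, on page 3.",
#       context="...",
#       og_label="PASS",
#       use_logprobs=True,
#   )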