# NOTE(review): removed non-code residue from a GitHub web-UI scrape
# (navigation chrome and the rendered line-number gutter). The actual
# module source begins at the docstring below.
"""
Sorry guys...
But also you're welcome for free adversarial data :)
"""
import instructor
import json
import numpy as np
import random
from anthropic import AsyncAnthropic
from datasets import load_dataset
from dotenv import load_dotenv
from functools import partial
from itertools import chain
from openai import AsyncOpenAI
from rich.console import Console
from transformers import LlamaTokenizer
from typing import Union
from vllm import LLM
from beam_utils import YesNo, batch_gpt_call, get_labels, mutate, similar
from utils import find_label, setup_target_llm
load_dotenv()
console = Console()
def beam_search(
    init: str,
    answer: str,
    context: str,
    og_label: str,
    client: Union[AsyncAnthropic, AsyncOpenAI],
    instructor_client: Union[AsyncAnthropic, AsyncOpenAI],
    llm_under_haize: LLM,
    model_name: str,
    tokenizer: LlamaTokenizer,
    beam_size: int = 10,
    explore_size: int = 10,
    max_iters: int = 10,
    desired_screwups: int = 20,
) -> list[dict]:
    """Adversarially beam-search over paraphrases of ``init`` until the
    hallucination-detection model under test flips its original label.

    Each iteration expands every beam node into ``explore_size`` mutants
    (via GPT-3.5), keeps only mutants judged semantically equivalent to the
    original question (via GPT-4o), scores them against the target model,
    records any mutant whose predicted label differs from ``og_label``, and
    keeps the ``beam_size`` highest-scoring mutants as the next beam.

    Args:
        init: The original question to mutate.
        answer: Fixed answer string fed to the detector with each mutant.
        context: Fixed context/passage fed to the detector.
        og_label: The detector's label on the original example.
        client: Async client used for mutation calls.
        instructor_client: Instructor-patched client for similarity checks.
        llm_under_haize: The vLLM model being attacked.
        model_name: Name of the model under attack.
        tokenizer: Tokenizer matching ``llm_under_haize``.
        beam_size: Number of candidates carried to the next iteration.
        explore_size: Mutants generated per beam node per iteration.
        max_iters: Hard cap on search iterations.
        desired_screwups: Stop early once this many label flips are found.

    Returns:
        A list of dicts, one per label-flipping mutant, with keys
        ``haize_variant``, ``variant_response``, and ``variant_label``.
    """
    beam: list[str] = [init]
    flipped_screwups: list[dict] = []  # Candidates that flip the boundary. Oopsies!
    flipped_strings: set[str] = set()  # Dedup set so a flip is reported once.
    i = 0
    while i < max_iters and len(flipped_screwups) < desired_screwups:
        console.print(
            f"\n\n---------------------- Iter {i} -----------------------", style="blue"
        )
        i += 1
        mutant_prompts = []
        while not mutant_prompts:
            # Expand every node in the beam into explore_size mutants;
            # retry until the batched call yields at least one mutant.
            mutant_prompts = batch_gpt_call(
                list(chain.from_iterable([node] * explore_size for node in beam)),
                mutate,
                client,
                "gpt-3.5-turbo",
            )
        # Ensure semantic similarity of question
        semantic_sims = batch_gpt_call(
            mutant_prompts,
            partial(similar, og_question=init),
            instructor_client,
            "gpt-4o",
        )
        # A falsy/None similarity result can never equal YesNo.YES, so the
        # equality test alone covers the failed-call case too.
        mutant_prompts = [
            m
            for j, m in enumerate(mutant_prompts)
            if semantic_sims[j] == YesNo.YES and m not in flipped_strings
        ]
        if not mutant_prompts:
            # Everything was filtered out; keep the current beam and retry
            # rather than scoring an empty batch (which would empty the beam).
            continue
        # Get hallucination detection model's predicted labels
        labels, raw_responses, scores = get_labels(
            llm_under_haize,
            model_name,
            tokenizer,
            mutant_prompts,
            answer,
            context,
            og_label,
        )
        new_flipped = []
        for mutant, label, resp in zip(mutant_prompts, labels, raw_responses):
            # Check if hallucination detection model is being silly
            if label != og_label:
                new_flipped.append(
                    {
                        "haize_variant": mutant,
                        "variant_response": resp,
                        "variant_label": find_label(resp),
                    }
                )
                flipped_strings.add(mutant)
        flipped_screwups.extend(new_flipped)
        console.print(f"\nTOTAL FAILURES ==> {len(flipped_screwups)}", style="red")
        console.print("== NET NEW FAILURES ==", style="red")
        console.print(new_flipped, style="red")
        # Beam me up scotty.
        # BUG FIX: `scores` aligns with `mutant_prompts`, but the original
        # extended `beam` with the mutants and then indexed the *combined*
        # list with argsort indices computed over `scores` — selecting
        # arbitrary stale nodes from the front of the beam instead of the
        # top-scoring mutants. Select directly from the scored mutants.
        order = np.argsort(-np.asarray(scores))
        beam = [mutant_prompts[j] for j in order[:beam_size]]
        console.print(
            f"\n-------------------------------------------------------", style="blue"
        )
    return flipped_screwups
if __name__ == "__main__":
    # Example run: pull one random HaluBench test example and attack it.
    dataset = load_dataset("PatronusAI/HaluBench")["test"]
    example = random.choice(dataset)
    question, answer, context, label = (
        example["question"],
        example["answer"],
        example["passage"],
        example["label"],
    )
    console.print(
        "------ Original Question, Answer, Context, Label ------", style="green"
    )
    print("\n<<QUESTION>>")
    print(question)
    print("\n<<ANSWER>>")
    print(answer)
    print("\n<<CONTEXT>>")
    print(context)
    print("\n<<LABEL>>")
    print(label)
    console.print(
        "\n-------------------------------------------------------\n\n", style="green"
    )
    # Separate clients: a plain one for mutation calls and an
    # instructor-patched one for structured similarity judgments.
    client = AsyncOpenAI()
    model_name = "PatronusAI/Llama-3-Patronus-Lynx-8B-Instruct"
    instructor_client = instructor.patch(AsyncOpenAI())
    tokenizer, llm_under_haize, _ = setup_target_llm()
    flipped_screwups = beam_search(
        question,
        answer,
        context,
        label,
        client,
        instructor_client,
        llm_under_haize,
        model_name,
        tokenizer,
        beam_size=15,
        explore_size=20,
        max_iters=30,
    )
    final = {
        "question": question,
        "context": context,
        "answer": answer,
        "og_label": label,
        "screwups": flipped_screwups,
    }
    # BUG FIX: the original passed a bare open() handle to json.dump,
    # leaking the file descriptor and risking an unflushed file on error.
    with open("beam_results.json", "w") as f:
        json.dump(final, f, indent=4)