-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagent.py
53 lines (44 loc) · 1.68 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os

from openai import OpenAI
from transformers import ReactCodeAgent, HfApiEngine
from transformers.agents.llm_engine import MessageRole, get_clean_message_list

from prompts import *
from prompts import FOCUSED_SQUAD_REACT_CODE_SYSTEM_PROMPT
from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
# Tools given to the agent when the caller of get_agent() passes no toolbox.
# NOTE: SquadQueryTool is imported but currently disabled (it was commented
# out of this list); append `SquadQueryTool()` here to re-enable it.
DEFAULT_TASK_SOLVING_TOOLBOX = [
    SquadRetrieverTool(),
]

# Role remapping applied to the agent's chat history before it is sent to
# OpenAI: messages tagged TOOL_RESPONSE are relabelled as USER messages
# (presumably because OpenAI's chat API does not accept that role as-is).
openai_role_conversions = dict([(MessageRole.TOOL_RESPONSE, MessageRole.USER)])
class OpenAIModel:
    """LLM engine for transformers agents backed by OpenAI chat completions.

    An instance is a callable taking ``(messages, stop_sequences)`` and
    returning the generated text, which is the shape ``get_agent`` passes to
    ``ReactCodeAgent`` as ``llm_engine``.

    Args:
        model_name: OpenAI model identifier sent with every request.
        temperature: Sampling temperature sent with every request
            (defaults to the previously hard-coded 0.5).
        api_key: OpenAI API key. When ``None`` (the default) the value of the
            ``OPENAI_API_KEY`` environment variable is used.
    """

    def __init__(self, model_name="gpt-4o-mini-2024-07-18", temperature=0.5, api_key=None):
        self.model_name = model_name
        self.temperature = temperature
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY")
        self.client = OpenAI(
            api_key=api_key,
        )

    def __call__(self, messages, stop_sequences=None):
        """Send the agent's chat history to OpenAI and return the reply text.

        Args:
            messages: Agent message list; normalised with
                ``get_clean_message_list`` (using ``openai_role_conversions``)
                before being sent.
            stop_sequences: Stop sequence(s) forwarded as OpenAI's ``stop``
                parameter. ``None`` (the default) sends an empty list, which
                matches the previous ``[]`` default.

        Returns:
            The ``message.content`` of the first completion choice.
        """
        # None sentinel instead of a mutable `[]` default, so no list object
        # is shared across calls; explicit values are forwarded untouched.
        stop = [] if stop_sequences is None else stop_sequences
        messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            stop=stop,
            temperature=self.temperature,
        )
        return response.choices[0].message.content
def get_agent(
    model_name=None,
    system_prompt=FOCUSED_SQUAD_REACT_CODE_SYSTEM_PROMPT,
    toolbox=DEFAULT_TASK_SOLVING_TOOLBOX,
    use_openai=True,
    openai_model_name="gpt-4o-mini-2024-07-18",
    additional_authorized_imports=None,
):
    """Build a ``ReactCodeAgent`` backed by OpenAI or a Hugging Face engine.

    Args:
        model_name: Model id / endpoint given to ``HfApiEngine``. Only used
            when ``use_openai`` is false; ``None`` means
            ``"http://localhost:1234/v1"``.
        system_prompt: System prompt passed to the agent.
        toolbox: Tools the agent may call. The default is the module-level
            ``DEFAULT_TASK_SOLVING_TOOLBOX`` list, so the same list (and tool
            instances) is shared by every agent built with the default.
        use_openai: When true, use ``OpenAIModel(openai_model_name)``;
            otherwise use ``HfApiEngine(model_name)``.
        openai_model_name: OpenAI model id, used only when ``use_openai``.
        additional_authorized_imports: Passed to ``ReactCodeAgent`` as
            ``additional_authorized_imports``. ``None`` keeps the previous
            default of ``["PIL"]``.

    Returns:
        The configured ``ReactCodeAgent``.
    """
    DEFAULT_MODEL_NAME = "http://localhost:1234/v1"

    if use_openai:
        llm_engine = OpenAIModel(openai_model_name)
    else:
        # The HF endpoint default is only relevant on this branch.
        if model_name is None:
            model_name = DEFAULT_MODEL_NAME
        llm_engine = HfApiEngine(model_name)

    if additional_authorized_imports is None:
        # Build a fresh list on each call rather than sharing a default object.
        additional_authorized_imports = ["PIL"]

    # Wire the engine, the provided toolbox (by default only
    # SquadRetrieverTool) and the system prompt into the agent.
    agent = ReactCodeAgent(
        tools=toolbox,
        llm_engine=llm_engine,
        system_prompt=system_prompt,
        additional_authorized_imports=additional_authorized_imports,
    )
    return agent