diff --git a/plugin.json b/plugin.json
index a4dceb2..fedbfc5 100644
--- a/plugin.json
+++ b/plugin.json
@@ -27,6 +27,6 @@
         "openai"
     ]
   },
-  "version": "2.0.1",
+  "version": "3.0.1",
   "minimumbinaryninjaversion": 3200
 }
\ No newline at end of file
diff --git a/src/agent.py b/src/agent.py
index 7007c8a..d4bbb29 100644
--- a/src/agent.py
+++ b/src/agent.py
@@ -5,8 +5,7 @@
 from pathlib import Path
 
 import openai
-from openai.api_resources.model import Model
-from openai.error import APIError
+from openai import APIError
 
 from binaryninja.function import Function
 from binaryninja.lowlevelil import LowLevelILFunction
@@ -18,6 +17,7 @@
 from . query import Query
 from . c import Pseudo_C
+from . exceptions import NoAPIKeyException
 
 
 class Agent:
@@ -45,7 +45,7 @@ def __init__(self, path_to_api_key: Optional[Path]=None) -> None:
 
         # Read the API key from the environment variable.
-        openai.api_key = self.read_api_key(path_to_api_key)
+        self.client = openai.OpenAI(api_key=self.read_api_key(filename=path_to_api_key))
 
         assert bv is not None, 'BinaryView is None. Check how you called this function.'
 
         # Set instance attributes.
@@ -87,12 +87,12 @@ def read_api_key(self, filename: Optional[Path]=None) -> str:
 
         except FileNotFoundError:
             log.log_error(f'Could not find API key file at {filename}.')
-            raise APIError('No API key found. Refer to the documentation to add the '
+            raise NoAPIKeyException('No API key found. Refer to the documentation to add the '
                            'API key.')
 
     def is_valid_model(self, model: str) -> bool:
         '''Checks if the model is valid by querying the OpenAI API.'''
-        models: list[Model] = openai.Model.list().data
+        models: list = self.client.models.list().data
         return model in [m.id for m in models]
 
     def get_model(self) -> str:
@@ -206,7 +206,8 @@ def rename_variable(self, response: str) -> None:
 
     def send_query(self, query: str, callback: Optional[Callable]=None) -> None:
         '''Sends a query to the engine and prints the response.'''
-        query = Query(query_string=query,
+        query = Query(client=self.client,
+                      query_string=query,
                       model=self.model,
                       max_token_count=self.get_token_count(),
                       callback_function=callback)
diff --git a/src/exceptions.py b/src/exceptions.py
index 432c98e..98b8366 100644
--- a/src/exceptions.py
+++ b/src/exceptions.py
@@ -1,3 +1,6 @@
+class NoAPIKeyException(Exception):
+    pass
+
 class RegisterSettingsGroupException(Exception):
     pass
 
diff --git a/src/query.py b/src/query.py
index abf43f2..04d3e4a 100644
--- a/src/query.py
+++ b/src/query.py
@@ -1,17 +1,18 @@
 from __future__ import annotations
 from collections.abc import Callable
 from typing import Optional
-import openai
+from openai import Client
 
 from binaryninja.plugin import BackgroundTaskThread
 from binaryninja.log import log_debug, log_info
 
 class Query(BackgroundTaskThread):
-    def __init__(self, query_string: str, model: str,
+    def __init__(self, client: Client, query_string: str, model: str,
                  max_token_count: int,
                  callback_function: Optional[Callable]=None) -> None:
         BackgroundTaskThread.__init__(self, initial_progress_text="",
                                       can_cancel=False)
+        self.client: Client = client
         self.query_string: str = query_string
         self.model: str = model
         self.max_token_count: int = max_token_count
@@ -23,7 +24,7 @@ def run(self) -> None:
         log_debug(f'Sending query: {self.query_string}')
 
         if self.model in ["gpt-3.5-turbo","gpt-4","gpt-4-32k"]:
-            response = openai.ChatCompletion.create(
+            response = self.client.chat.completions.create(
                 model=self.model,
                 messages=[{"role":"user","content":self.query_string}],
                 max_tokens=self.max_token_count,
@@ -31,7 +32,7 @@
             # Get the response text.
             result: str = response.choices[0].message.content
         else:
-            response = openai.Completion.create(
+            response = self.client.completions.create(
                 model=self.model,
                 prompt=self.query_string,
                 max_tokens=self.max_token_count,