New Pdb code, start of lldb with the same way to do conversations. #34

Merged 11 commits on Feb 19, 2024.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
a.out

# C extensions
*.so
1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ classifiers = [

[project.scripts]
chatdbg = "chatdbg.__main__:main"
ichatdbg = "chatdbg.__imain__:main"

[project.urls]
"Homepage" = "https://github.com/plasma-umass/ChatDBG"
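
After an editable install, a quick way to confirm that both console scripts were registered is to query the package metadata. A minimal sketch (the entry_points(group=...) signature requires Python 3.10+):

from importlib.metadata import entry_points

# List the console scripts this package registers.
scripts = entry_points(group="console_scripts")
print([ep.value for ep in scripts if ep.name in ("chatdbg", "ichatdbg")])
# Expected: ['chatdbg.__main__:main', 'chatdbg.__imain__:main']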
16 changes: 16 additions & 0 deletions src/chatdbg/__imain__.py
@@ -0,0 +1,16 @@
import os
import pathlib
import sys

the_path = pathlib.Path(__file__).parent.resolve()

sys.path.insert(0, os.path.abspath(the_path))

from .chatdbg_ipdb import *

import ipdb

ipdb.__main__._get_debugger_cls = lambda: ChatDBG

def main():
ipdb.__main__.main()
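
The file above works by monkey-patching ipdb's private _get_debugger_cls hook so that the stock ipdb CLI drives ChatDBG instead of the default debugger class. A minimal sketch of the same pattern with a hypothetical subclass (not part of this PR; it assumes an ipdb version that exposes ipdb.__main__._get_debugger_cls, as the code above does):

import ipdb
from IPython.terminal.debugger import TerminalPdb

class GreetingPdb(TerminalPdb):
    # do_* methods become debugger commands through cmd.Cmd dispatch.
    def do_greet(self, arg):
        print("hello from GreetingPdb")

ipdb.__main__._get_debugger_cls = lambda: GreetingPdb
ipdb.__main__.main()  # the normal ipdb CLI, now driving GreetingPdb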
6 changes: 4 additions & 2 deletions src/chatdbg/__main__.py
@@ -6,6 +6,8 @@

sys.path.insert(0, os.path.abspath(the_path))

from . import chatdbg
from .chatdbg import *

Collaborator commented:
I am really not a fan of import *. It can shadow existing imports/variables silently and heavily messes with any static analysis tool (not that we have any currently). I'll just fix it after this merges.


chatdbg.main()
# from . import chatdbg

# chatdbg.main()
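
The import * concern in the comment above is easy to demonstrate with a toy example (module names hypothetical):

# helpers.py
def print(*args):  # deliberately shadows the builtin
    pass

# main.py
from helpers import *  # silently rebinds print in this module
print("hello")  # now calls helpers.print, not the builtin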
235 changes: 235 additions & 0 deletions src/chatdbg/assistant/assistant.py
@@ -0,0 +1,235 @@
import atexit
import inspect
import json
import textwrap
import time
import sys

import llm_utils
from openai import *
from pydantic import BaseModel

Collaborator commented:
Let's add pydantic to the pyproject.toml:

dependencies = ["llm_utils>=0.2.6", "openai>=1.6.1", "rich>=13.7.0"]

nicovank (Collaborator) commented on Feb 14, 2024:
Unless it's not used, which may be the case here from a quick search; in that case, just remove the import 👍.

EDIT: It is used.
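
A sketch of the amended dependency list the comment suggests (the pydantic version pin here is an assumption, not from this PR):

dependencies = ["llm_utils>=0.2.6", "openai>=1.6.1", "pydantic>=2.0", "rich>=13.7.0"]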



class Assistant:
"""
An Assistant is a wrapper around OpenAI's assistant API. Example usage:

assistant = Assistant("Assistant Name", instructions,
model='gpt-4-1106-preview', debug=True)
assistant.add_function(my_func)
response = assistant.run(user_prompt)

Name can be any name you want.

If debug is True, it will create a log of all messages and JSON responses in
json.txt.
"""

def __init__(self, name, instructions, model="gpt-3.5-turbo-1106", debug=True):

if debug:
self.json = open('json.txt', 'w')
else:
self.json = None

try:
self.client = OpenAI(timeout=30)
except OpenAIError:
print(textwrap.dedent("""\
You need an OpenAI key to use this tool.
You can get a key here: https://platform.openai.com/api-keys
Set the environment variable OPENAI_API_KEY to your key value.
"""))
sys.exit(1)


self.assistants = self.client.beta.assistants
self.threads = self.client.beta.threads
self.functions = dict()

self.assistant = self.assistants.create(name=name,
instructions=instructions,
model=model)

self._log(self.assistant)

atexit.register(self._delete_assistant)

self.thread = self.threads.create()
self._log(self.thread)

def _delete_assistant(self):
    if self.assistant is not None:
        id = self.assistant.id
        try:
            response = self.assistants.delete(id)
            self._log(response)
            assert response.deleted
        except Exception as e:
            print(f'Assistant {id} was not deleted ({e}).\nYou can do so at https://platform.openai.com/assistants.')

def add_function(self, function):
"""
Add a new function to the list of function tools for the assistant.
The function should have the necessary JSON spec as its pydoc string.
"""
function_json = json.loads(function.__doc__)
assert 'name' in function_json, "Bad JSON in pydoc for function tool."
try:
name = function_json['name']
self.functions[name] = function

tools = [
{
"type": "function",
"function": json.loads(function.__doc__)
} for function in self.functions.values()
]

assistant = self.assistants.update(self.assistant.id,
tools=tools)
self._log(assistant)
except OpenAIError as e:
print(f"*** OpenAI Error: {e}")


def _make_call(self, tool_call):
name = tool_call.function.name
args = tool_call.function.arguments

# There is a sketchy case that happens occasionally because
# the API produces a bad call...
try:
args = json.loads(args)
function = self.functions[name]
result = function(**args)
except Exception as e:
result = f"Ill-formed function call ({e})\n"

return result

def _print_messages(self, messages, client_print):
client_print()
for i, m in enumerate(messages):
message_text = m.content[0].text.value
if i == 0:
message_text = '(Message) ' + message_text
client_print(message_text)



def _wait_on_run(self, run, thread, client_print):
try:
while run.status == "queued" or run.status == "in_progress":
run = self.threads.runs.retrieve(
thread_id=thread.id,
run_id=run.id,
)
time.sleep(0.5)
return run
finally:
if run.status == 'in_progress':
client_print("Cancelling message that's in progress.")
self.threads.runs.cancel(thread_id=thread.id, run_id=run.id)

def run(self, prompt, client_print=print):
    """
    Give the prompt to the assistant and get the response, which may include
    intermediate function calls.
    All output is printed via the given client_print function.
    """
start_time = time.perf_counter()

try:
if self.assistant is None:
return 0,0,0

assert len(prompt) <= 32768

message = self.threads.messages.create(thread_id=self.thread.id,
role="user",
content=prompt)
self._log(message)

last_printed_message_id = message.id


run = self.threads.runs.create(thread_id=self.thread.id,
assistant_id=self.assistant.id)
self._log(run)

run = self._wait_on_run(run, self.thread, client_print)
self._log(run)

while run.status == "requires_action":

messages = self.threads.messages.list(thread_id=self.thread.id,
after=last_printed_message_id,
order='asc')

mlist = list(messages)
if len(mlist) > 0:
self._print_messages(mlist, client_print)
last_printed_message_id = mlist[-1].id
client_print()


outputs = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
output = self._make_call(tool_call)
self._log(output)
outputs += [ { 'tool_call_id' : tool_call.id, 'output' : output } ]

try:
run = self.threads.runs.submit_tool_outputs(thread_id=self.thread.id,
run_id=run.id,
tool_outputs=outputs)
self._log(run)
except Exception as e:
self._log(run, f'FAILED to submit tool call results: {e}')

run = self._wait_on_run(run, self.thread, client_print)
self._log(run)

if run.status == 'failed':
message = f"\n**Internal Failure ({run.last_error.code}):** {run.last_error.message}"
client_print(message)
return 0,0,0

messages = self.threads.messages.list(thread_id=self.thread.id,
after=last_printed_message_id,
order='asc')
self._print_messages(messages, client_print)

end_time = time.perf_counter()
elapsed_time = end_time - start_time

cost = llm_utils.calculate_cost(run.usage.prompt_tokens,
run.usage.completion_tokens,
self.assistant.model)
client_print()
client_print(f'[Cost: ~${cost:.2f} USD]')
return run.usage.total_tokens, cost, elapsed_time
except OpenAIError as e:
client_print(f"*** OpenAI Error: {e}")
return 0,0,0




def _log(self, obj, title=''):
if self.json is not None:
stack = inspect.stack()
caller_frame_record = stack[1]
lineno, function = caller_frame_record[2:4]
loc = f'{function}:{lineno}'

print('-' * 70, file=self.json)
print(f'{loc} {title}', file=self.json)
if isinstance(obj, BaseModel):
json_obj = json.loads(obj.model_dump_json())
else:
json_obj = obj
print(f'\n{json.dumps(json_obj, indent=2)}\n', file=self.json)
self.json.flush()
return obj
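
For reference, a sketch of what a function tool passed to add_function could look like, with the JSON spec carried in the pydoc string as the class expects. The function name and body here are hypothetical, not part of this PR, and running it assumes OPENAI_API_KEY is set:

def run_lldb_command(command):
    """
    {
        "name": "run_lldb_command",
        "description": "Run one lldb command and return its output.",
        "parameters": {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "The lldb command to execute."
                }
            },
            "required": ["command"]
        }
    }
    """
    return f"(pretend output of: {command})"  # hypothetical stand-in

assistant = Assistant("Debugging Assistant", "You are a debugging assistant.")
assistant.add_function(run_lldb_command)
tokens, cost, elapsed = assistant.run("Why did the program crash?")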