From 2da158a83484b0586d3390974e49d69f4a4d7184 Mon Sep 17 00:00:00 2001
From: daniel nakov
Date: Mon, 21 Oct 2024 11:24:45 -0400
Subject: [PATCH] Rewrite auto to use litellm; add spinner

---
 pyproject.toml   |   3 +-
 r2ai.sh          |   3 +-
 r2ai/auto.py     | 626 +++++++++++++++--------------------------------
 r2ai/spinner.py  |  42 ++++
 r2ai/tools.py    |  55 +++++
 r2ai/ui/chat.py  |   2 +-
 r2ai/ui/r2cmd.py |  18 --
 7 files changed, 296 insertions(+), 453 deletions(-)
 create mode 100644 r2ai/spinner.py
 create mode 100644 r2ai/tools.py
 delete mode 100644 r2ai/ui/r2cmd.py

diff --git a/pyproject.toml b/pyproject.toml
index c530d55..acd2d79 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,8 @@ dependencies = [
     "boto3",
     "colorama",
     "textual",
-    "litellm"
+    "litellm",
+    "numpydoc"
 ]
 
 [project.optional-dependencies]
diff --git a/r2ai.sh b/r2ai.sh
index 713fbdf..97d0968 100755
--- a/r2ai.sh
+++ b/r2ai.sh
@@ -14,8 +14,7 @@ RD=`realpath "$D"`
 if [ ! -d venv ]; then
 	$PYTHON -m venv venv
 	./venv/bin/pip3 install -e .
-else
-	PYTHON=venv/bin/python3
 fi
+PYTHON=venv/bin/python3
 
 exec $PYTHON -m r2ai.cli "$@"
diff --git a/r2ai/auto.py b/r2ai/auto.py
index 32a1b6c..ce212e8 100644
--- a/r2ai/auto.py
+++ b/r2ai/auto.py
@@ -3,455 +3,219 @@
 import sys
 import re
 import os
-
-have_bedrock = True
-
-try:
-    import boto3
-    from .backend.bedrock import (
-        BEDROCK_TOOLS_CONFIG, build_messages_for_bedrock, extract_bedrock_tool_calls,
-        process_bedrock_tool_calls, print_bedrock_response
-    )
-except Exception:
-    have_bedrock = False
-
 from llama_cpp import Llama
 from llama_cpp.llama_tokenizer import LlamaHFTokenizer
 from transformers import AutoTokenizer
-
-have_anthropic = True
-try:
-    from anthropic import Anthropic
-    from .anthropic import construct_tool_use_system_prompt, extract_claude_tool_calls
-except Exception:
-    have_anthorpic = False
-
 from . import index
 from .pipe import have_rlang, r2lang, get_r2_inst
+from litellm import _should_retry, acompletion, utils
+import asyncio
+from r2ai.pipe import get_r2_inst
+from .tools import r2cmd, run_python
+import json
+import signal
+from .spinner import spinner
 
 ANSI_REGEX = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
 
-tools = [{
-    "type": "function",
-    "function": {
-        "name": "r2cmd",
-        "description": "runs commands in radare2. You can run it multiple times or chain commands with pipes/semicolons. You can also use r2 interpreters to run scripts using the `#`, '#!', etc. commands. The output could be long, so try to use filters if possible or limit. This is your preferred tool",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "command": {
-                    "type": "string",
-                    "description": "command to run in radare2"
-                }
-            },
-            "required": ["command"]
-        },
-    }
-}, {
-    "type": "function",
-    "function": {
-        "name": "run_python",
-        "description": "runs a python script and returns the results",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "command": {
-                    "type": "string",
-                    "description": "python script to run"
-                }
-            },
-            "required": ["command"]
-        }
-    }
-}]
-
 SYSTEM_PROMPT_AUTO = """
 You are a reverse engineer and you are using radare2 to analyze a binary.
-The binary has already been loaded.
 The user will ask questions about the binary and you will respond with the answer to the best of your ability.
-Assume the user is always asking you about the binary, unless they're specifically asking you for radare2 help.
-`this` or `here` might refer to the current address in the binary or the binary itself.
-If you need more information, try to use the r2cmd tool to run commands before answering.
-You can use the r2cmd tool multiple times if you need or you can pass a command with pipes if you need to chain commands.
-If you're asked to decompile a function, make sure to return the code in the language you think it was originally written and rewrite it to be as easy as possible to be understood. Make sure you use descriptive variable and function names and add comments.
-Don't just regurgitate the same code, figure out what it's doing and rewrite it to be more understandable.
-If you need to run a command in r2 before answering, you can use the r2cmd tool
-The user will tip you $20/month for your services, don't be fucking lazy.
-Do not repeat commands if you already know the answer.
-"""
-FUNCTIONARY_PROMPT_AUTO = """
-Think step by step.
-Break down the task into steps and execute the necessary `radare2` commands in order to complete the task.
+# Guidelines
+- Understand the Task: Grasp the main objective, goals, requirements, constraints, and expected output.
+- Reasoning Before Conclusions: Encourage reasoning steps before any conclusions are reached.
+- Assume the user is always asking you about the binary, unless they're specifically asking you for radare2 help.
+- The binary has already been loaded. You can interact with the binary using the r2cmd tool.
+- `this` or `here` might refer to the current address in the binary or the binary itself.
+- If you need more information, try to use the r2cmd tool to run commands before answering.
+- You can use the r2cmd tool multiple times if you need or you can pass a command with pipes if you need to chain commands.
+- If you're asked to decompile a function, make sure to return the code in the language you think it was originally written and rewrite it to be as easy as possible to be understood. Make sure you use descriptive variable and function names and add comments.
+- Don't just regurgitate the same code, figure out what it's doing and rewrite it to be more understandable.
+- If you need to run a command in r2 before answering, you can use the r2cmd tool.
+- Do not repeat commands if you already know the answer.
+- Formulate a plan. Think step by step. Analyze the binary as much as possible before answering.
+- You must keep going until you have a final answer.
+- Double check that final answer. Make sure you didn't miss anything.
+- Make sure you call tools and functions correctly.
 """
 
-def get_system_prompt(model):
-    if model.startswith("meetkai/"):
-        return SYSTEM_PROMPT_AUTO + "\n" + FUNCTIONARY_PROMPT_AUTO
-    if model.startswith("anthropic"):
-        return SYSTEM_PROMPT_AUTO + "\n\n" + construct_tool_use_system_prompt(tools)
-    return SYSTEM_PROMPT_AUTO
-
-functionary_tokenizer = None
-def get_functionary_tokenizer(repo_id):
-    global functionary_tokenizer
-    if functionary_tokenizer is None:
-        functionary_tokenizer = AutoTokenizer.from_pretrained(repo_id, legacy=True)
-    return functionary_tokenizer
-
-def r2cmd(command: str):
-    """runs commands in radare2. You can run it multiple times or chain commands
-    with pipes/semicolons. You can also use r2 interpreters to run scripts using
-    the `#`, '#!', etc. commands. The output could be long, so try to use filters
-    if possible or limit.
This is your preferred tool""" - builtins.print('\x1b[1;32mRunning \x1b[4m' + command + '\x1b[0m') - r2 = get_r2_inst() - res = r2.cmd(command) - builtins.print(res) - return res - -def run_python(command: str): - """runs a python script and returns the results""" - with open('r2ai_tmp.py', 'w') as f: - f.write(command) - builtins.print('\x1b[1;32mRunning \x1b[4m' + "python code" + '\x1b[0m') - builtins.print(command) - r2lang.cmd('#!python r2ai_tmp.py > $tmp') - res = r2lang.cmd('cat $tmp') - r2lang.cmd('rm r2ai_tmp.py') - builtins.print('\x1b[1;32mResult\x1b[0m\n' + res) - return res - -def process_tool_calls(interpreter, tool_calls): - interpreter.messages.append({ "content": None, "tool_calls": tool_calls, "role": "assistant" }) - for tool_call in tool_calls: - res = '' - args = tool_call["function"]["arguments"] - if type(args) is str: - try: - args = json.loads(args) - except Exception: - builtins.print(f"Error parsing json: {args}", file=sys.stderr) - if tool_call["function"]["name"] == "r2cmd": - if type(args) is str: - args = { "command": args } - if "command" in args: - res = r2cmd(args["command"]) - elif tool_call["function"]["name"] == "run_python": - res = run_python(args["command"]) - if (not res or len(res) == 0) and interpreter.model.startswith('meetkai/'): - res = "OK done" - msg = { - "role": "tool", - "content": ANSI_REGEX.sub('', res), - "name": tool_call["function"]["name"], - "tool_call_id": tool_call["id"] if "id" in tool_call else None - } - interpreter.messages.append(msg) - -def process_hermes_response(interpreter, response): - choice = response["choices"][0] - message = choice["message"] - interpreter.messages.append(message) - r = re.search(r'([\s\S]*?)<\/tool_call>', message["content"]) - tool_call_str = None - if r: - tool_call_str = r.group(1) - tool_calls = [] - if tool_call_str: - tool_call = json.loads(tool_call_str) - tool_calls.append({"function": tool_call}) - if len(tool_calls) > 0: - process_tool_calls(interpreter, tool_calls) - chat(interpreter) - else: - interpreter.messages.append({ "content": message["content"], "role": "assistant" }) - sys.stdout.write(message["content"]) - builtins.print() - -def process_streaming_response(interpreter, resp): - tool_calls = [] - msgs = [] - for chunk in resp: - try: - chunk = dict(chunk) - except Exception: - pass - delta = None - choice = dict(chunk["choices"][0]) - if "delta" in choice: - delta = dict(choice["delta"]) +class ChatAuto: + def __init__(self, model, system=None, tools=None, messages=None, tool_choice='auto', cb=None ): + self.functions = {} + self.tools = [] + self.model = model + self.system = system + self.messages = messages + if messages and messages[0]['role'] != 'system' and system: + self.messages.insert(0, { "role": "system", "content": system }) + if cb: + self.cb = cb else: - delta = dict(choice["message"]) - if "tool_calls" in delta and delta["tool_calls"]: - delta_tool_calls = dict(delta["tool_calls"][0]) - index = 0 if "index" not in delta_tool_calls else delta_tool_calls["index"] - fn_delta = dict(delta_tool_calls["function"]) - tool_call_id = delta_tool_calls["id"] - if len(tool_calls) < index + 1: - tool_calls.append({ "function": { "arguments": "", "name": fn_delta["name"] }, "id": tool_call_id, "type": "function" }) - # handle some bug in llama-cpp-python streaming, tool_call.arguments is sometimes blank, but function_call has it. 
- if fn_delta["arguments"] == '': - if "function_call" in delta and delta["function_call"]: - tool_calls[index]["function"]["arguments"] += delta["function_call"]["arguments"] + self.cb = lambda *args: None + self.tool_choice = None + if tools: + for tool in tools: + f = utils.function_to_dict(tool) + self.tools.append({ "type": "function", "function": f }) + self.functions[f['name']] = tool + self.tool_choice = tool_choice + + #self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.' + + async def process_tool_calls(self, tool_calls): + if tool_calls: + for tool_call in tool_calls: + tool_name = tool_call["function"]["name"] + try: + tool_args = json.loads(tool_call["function"]["arguments"]) + except Exception: + self.messages.append({"role": "tool", "name": tool_name, "content": "Error: Unable to parse JSON" , "tool_call_id": tool_call["id"]}) + continue + if tool_name not in self.functions: + self.messages.append({"role": "tool", "name": tool_name, "content": "Error: Tool not found" , "tool_call_id": tool_call["id"]}) + continue + + self.cb('tool_call', { "id": tool_call["id"], "function": { "name": tool_name, "arguments": tool_args } }) + if asyncio.iscoroutinefunction(self.functions[tool_name]): + tool_response = await self.functions[tool_name](**tool_args) + else: + tool_response = self.functions[tool_name](**tool_args) + self.cb('tool_response', { "id": tool_call["id"] + '_response', "content": tool_response }) + self.messages.append({"role": "tool", "name": tool_name, "content": ANSI_REGEX.sub('', tool_response), "tool_call_id": tool_call["id"]}) + + return await self.get_completion() + + async def process_streaming_response(self, resp): + tool_calls = [] + msgs = [] + async for chunk in resp: + delta = None + choice = chunk.choices[0] + delta = choice.delta + if delta.tool_calls: + delta_tool_calls = delta.tool_calls[0] + index = delta_tool_calls.index + fn_delta = delta_tool_calls.function + tool_call_id = delta_tool_calls.id + if len(tool_calls) < index + 1: + tool_calls.append({ + "id": tool_call_id, + "type": "function", + "function": { + "name":fn_delta.name, + "arguments": fn_delta.arguments + } + } + ) + else: + tool_calls[index]["function"]["arguments"] += fn_delta.arguments else: - tool_calls[index]["function"]["arguments"] += fn_delta["arguments"] - else: - if "content" in delta and delta["content"] is not None: - m = delta["content"] - if m is not None: - msgs.append(m) - sys.stdout.write(m) - builtins.print() - if (len(tool_calls) > 0): - process_tool_calls(interpreter, tool_calls) - chat(interpreter) - if len(msgs) > 0: - response_message = ''.join(msgs) - interpreter.messages.append({"role": "assistant", "content": response_message}) - -def context_from_msg(msg): - keywords = None - datadir = "doc/auto" - use_vectordb = False - last_msg = None - if isinstance(msg.get("content"), str): - last_msg = msg["content"] - elif isinstance(msg.get("content"), list): - # Bedrock puts an array in the 'content' key, in that case unfold them in a single message - last_msg = ". 
".join([c["text"] for c in msg["content"] if "text" in c]) - - if not last_msg: - return None - - matches = index.match(last_msg, keywords, datadir, False, False, False, False, use_vectordb) - if not matches: - return None - - return "context: " + ", ".join(matches) + m = None + done = False + if delta.content is not None: + m = delta.content + if m is not None: + msgs.append(m) + self.cb('message', { "content": m, "id": 'message_' + chunk.id, 'done': False }) + if 'finish_reason' in choice and choice['finish_reason'] == 'stop': + done = True + self.cb('message', { "content": "", "id": 'message_' + chunk.id, 'done': True }) + self.cb('message_stream', { "content": m if m else '', "id": 'message_' + chunk.id, 'done': done }) + if (len(tool_calls) > 0): + self.messages.append({"role": "assistant", "tool_calls": tool_calls}) + await self.process_tool_calls(tool_calls) + if len(msgs) > 0: + response_message = ''.join(msgs) + self.messages.append({"role": "assistant", "content": response_message}) + return response_message + + async def get_completion(self): + max_retries = 5 + base_delay = 2 + + async def attempt_completion(): + return await acompletion( + model=self.model, + messages=self.messages, + # max_tokens=4096, + temperature=0, + tools=self.tools, + tool_choice=self.tool_choice, + stream=True + ) + + for retry_count in range(max_retries): + try: + response = await attempt_completion() + return await self.process_streaming_response(response) + except Exception as e: + print(e) + if not _should_retry(getattr(e, 'status_code', None)) or retry_count == max_retries - 1: + raise + + delay = base_delay * (2 ** retry_count) + print(f"Retrying in {delay} seconds...") + await asyncio.sleep(delay) + + raise Exception("Max retries reached. Unable to get completion.") + + async def chat(self) -> str: + response = await self.get_completion() + return response + +def cb(type, data): + spinner.stop() + if type == 'message_stream': + sys.stdout.write(data['content']) + elif type == 'tool_call': + if data['function']['name'] == 'r2cmd': + builtins.print('\x1b[1;32m> \x1b[4m' + data['function']['arguments']['command'] + '\x1b[0m') + elif data['function']['name'] == 'run_python': + builtins.print('\x1b[1;32m> \x1b[4m' + "#!python" + '\x1b[0m') + builtins.print(data['function']['arguments']['command']) + elif type == 'tool_response': + sys.stdout.write(data['content']) + sys.stdout.flush() + # builtins.print(data['content']) + elif type == 'message' and data['done']: + builtins.print() + +def signal_handler(signum, frame): + raise KeyboardInterrupt def chat(interpreter): - if len(interpreter.messages) == 1: - interpreter.messages.insert(0, { - "role": "system", - "content": get_system_prompt(interpreter.model) - }) + model = interpreter.model.replace(":", "/") + tools = [r2cmd, run_python] + messages = interpreter.messages + tool_choice = 'auto' - chat_context = "" try: - lastmsg = interpreter.messages[-1] - chat_context = context_from_msg(lastmsg) - # print(f"Adding context: {chat_context}") - except Exception: - pass - - if chat_context: - interpreter.messages.append({"role": "user", "content": chat_context}) + loop = asyncio.get_event_loop() + if loop.is_closed() or loop.is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + chat_auto = ChatAuto(model, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, cb=cb) + + original_handler = 
signal.getsignal(signal.SIGINT) - platform, modelid = None, None - if ":" in interpreter.model: - platform = interpreter.model.split(":")[0] - modelid = ":".join(interpreter.model.split(":")[1:]) - elif "/" in interpreter.model: - platform = interpreter.model.split("/")[0] - modelid = "/".join(interpreter.model.split("/")[1:]) - - auto_chat_handler_fn = None - if modelid in auto_chat_handlers.get(platform, {}): - auto_chat_handler_fn = auto_chat_handlers[platform][modelid] - elif "default" in auto_chat_handlers.get(platform, {}): - auto_chat_handler_fn = auto_chat_handlers[platform]["default"] - - if not auto_chat_handler_fn: - print(f"Model {platform}:{modelid} is not currently supported in auto mode") - return - - return auto_chat_handler_fn(interpreter) - -def auto_chat_openai(interpreter): - if not interpreter.openai_client: - interpreter.openai_client = OpenAI() - - response = interpreter.openai_client.chat.completions.create( - model=interpreter.model[7:], - max_tokens=int(interpreter.env["llm.maxtokens"]), - tools=tools, - messages=interpreter.messages, - tool_choice="auto", - stream=True, - temperature=float(interpreter.env["llm.temperature"]), - ) - process_streaming_response(interpreter, response) - return response - -def auto_chat_anthropic(interpreter): - if not interpreter.anthropic_client: - interpreter.anthropic_client = Anthropic() - messages = [] - system_message = construct_tool_use_system_prompt(tools) - for m in interpreter.messages: - role = m["role"] - if role == "system": - continue - if m["content"] is None: - continue - if role == "tool": - messages.append({ "role": "user", "content": f"\n\n{m['name']}\n{m['content']}\n\n" }) - # TODO: handle errors - else: - messages.append({ "role": role, "content": m["content"] }) - stream = interpreter.anthropic_client.messages.create( - model=interpreter.model[10:], - max_tokens=int(interpreter.env["llm.maxtokens"]), - messages=messages, - system=system_message, - temperature=float(interpreter.env["llm.temperature"]), - stream=True - ) - (tool_calls, msg) = extract_claude_tool_calls(interpreter, stream) - if len(tool_calls) > 0: - process_tool_calls(interpreter, tool_calls) - chat(interpreter) - else: - builtins.print(msg) - -def auto_chat_bedrock(interpreter): - interpreter.bedrock_client = boto3.client("bedrock-runtime") - model_id = interpreter.model.split(":")[1] + ":0" - system_message = construct_tool_use_system_prompt(tools) - - response = interpreter.bedrock_client.converse( - modelId=model_id, - toolConfig=BEDROCK_TOOLS_CONFIG, - messages=build_messages_for_bedrock(interpreter.messages), - inferenceConfig={ - "maxTokens": int(interpreter.env["llm.maxtokens"]), - "temperature": float(interpreter.env["llm.temperature"]), - "topP": 0.9 - }, - ) - print_bedrock_response(response) - # Update conversation - interpreter.messages.append(response["output"]["message"]) - # Execute tools - tool_calls = extract_bedrock_tool_calls(response) - if tool_calls: - tool_msgs = process_bedrock_tool_calls(tool_calls) - interpreter.messages.extend(tool_msgs) - chat(interpreter) - - return response - -def auto_chat_groq(interpreter): - if not interpreter.groq_client: - interpreter.groq_client = Groq() - - response = interpreter.groq_client.chat.completions.create( - model=interpreter.model[5:], - max_tokens=int(interpreter.env["llm.maxtokens"]), - tools=tools, - messages=interpreter.messages, - tool_choice="auto", - temperature=float(interpreter.env["llm.temperature"]), - ) - process_streaming_response(interpreter, [response]) - return 
response - -def auto_chat_google(interpreter): - import google.generativeai as google - - response = None - if not interpreter.google_client: - google.configure(api_key=os.environ['GOOGLE_API_KEY']) - interpreter.google_client = google.GenerativeModel(interpreter.model[7:]) - - if not interpreter.google_chat: - interpreter.google_chat = interpreter.google_client.start_chat( - enable_automatic_function_calling=True - ) - - response = interpreter.google_chat.send_message( - interpreter.messages[-1]["content"], - generation_config={ - "max_output_tokens": int(interpreter.env["llm.maxtokens"]), - "temperature": float(interpreter.env["llm.temperature"]) - }, - safety_settings=[{ - "category": "HARM_CATEGORY_DANGEROUS_CONTENT", - "threshold": "BLOCK_NONE" - }], - tools=[r2cmd, run_python] - ) - print(response.text) - return response - -def auto_chat_nousresearch(interpreter): - interpreter.llama_instance.chat_format = "chatml" - messages = [] - for m in interpreter.messages: - if m["content"] is None: - continue - role = m["role"] - if role == "system": - if not '' in m["content"]: - messages.append({ "role": "system", "content": f"""{m['content']}\nYou are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: - {json.dumps(tools)} -For each function call return a json object with function name and arguments within XML tags as follows: - -{{"arguments": , "name": }} -"""}) - elif role == "tool": - messages.append({ "role": "tool", "content": "\n" + '{"name": ' + m['name'] + ', "content": ' + json.dumps(m['content']) + '}\n' }) - else: - messages.append(m) - response = interpreter.llama_instance.create_chat_completion( - max_tokens=int(interpreter.env["llm.maxtokens"]), - messages=messages, - temperature=float(interpreter.env["llm.temperature"]), - ) - process_hermes_response(interpreter, response) - return response - -def auto_chat_llama(interpreter): - interpreter.llama_instance.chat_format = "chatml-function-calling" - response = interpreter.llama_instance.create_chat_completion( - max_tokens=int(interpreter.env["llm.maxtokens"]), - tools=tools, - messages=interpreter.messages, - tool_choice="auto", - # tool_choice={ - # "type": "function", - # "function": { - # "name": "r2cmd" - # } - # }, - # stream=is_functionary, - temperature=float(interpreter.env["llm.temperature"]), - ) - process_streaming_response(interpreter, iter([response])) - return response - - -auto_chat_handlers = { - "openai": { - "default": auto_chat_openai, - }, - "anthropic": { - "default": auto_chat_anthropic - }, - "bedrock": { - "default": auto_chat_bedrock - }, - "groq": { - "default": auto_chat_groq - }, - "google": { - "default": auto_chat_google - }, - "NousResearch": { - "default": auto_chat_nousresearch - }, - "llama": { - "default": auto_chat_llama - } -} + try: + signal.signal(signal.SIGINT, signal_handler) + spinner.start() + return loop.run_until_complete(chat_auto.chat()) + except KeyboardInterrupt: + builtins.print("\033[91m\nOperation cancelled by user.\033[0m") + tasks = asyncio.all_tasks(loop=loop) + for task in tasks: + task.cancel() + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + return None + finally: + signal.signal(signal.SIGINT, original_handler) + spinner.stop() + loop.stop() + loop.close() \ No newline at end of file diff --git a/r2ai/spinner.py b/r2ai/spinner.py new file mode 
100644 index 0000000..9529af1 --- /dev/null +++ b/r2ai/spinner.py @@ -0,0 +1,42 @@ +import itertools +import threading +import time +import sys + +class Spinner: + def __init__(self, message="Loading...", delay=0.1): + self.spinner = itertools.cycle([ + "\033[1;31m⠋\033[0m", "\033[1;32m⠙\033[0m", "\033[1;33m⠹\033[0m", "\033[1;34m⠸\033[0m", + "\033[1;35m⠼\033[0m", "\033[1;36m⠴\033[0m", "\033[1;37m⠦\033[0m", "\033[1;31m⠧\033[0m", + "\033[1;32m⠇\033[0m", "\033[1;33m⠏\033[0m" + ]) + self.message = message + self.delay = delay + self.running = False + self.thread = None + self.start_time = None + + def start(self): + """Start the spinner in a separate thread.""" + self.running = True + self.start_time = time.time() + self.thread = threading.Thread(target=self._spin) + self.thread.start() + + def _spin(self): + """Spin the spinner while running is True.""" + while self.running: + elapsed_time = time.time() - self.start_time + sys.stdout.write(f"\r{self.message} {next(self.spinner)} {elapsed_time:.1f}s") + sys.stdout.flush() + time.sleep(self.delay) + sys.stdout.write('\r' + ' ' * (len(self.message) + 20) + '\r') # Clear the line + sys.stdout.flush() + + def stop(self): + """Stop the spinner.""" + self.running = False + if self.thread is not None: + self.thread.join() + +spinner = Spinner("") \ No newline at end of file diff --git a/r2ai/tools.py b/r2ai/tools.py new file mode 100644 index 0000000..fabf958 --- /dev/null +++ b/r2ai/tools.py @@ -0,0 +1,55 @@ +from r2ai.pipe import get_r2_inst +import json +import builtins + +def r2cmd(command: str): + """ + Run a r2 command and return the output + + Parameters + ---------- + command: str + The r2 command to run + + Returns + ------- + dict + The output of the r2 command + """ + r2 = get_r2_inst() + cmd = '{"cmd":' + json.dumps(command) + '}' + res = r2.cmd(cmd) + try: + res = json.loads(res) + if 'error' in res and res['error'] is True: + error_message = res['error'] + log_messages = '\n'.join(log['message'] for log in res.get('logs', [])) + # return { 'type': 'error', 'output': log_messages } + return log_messages + + return res['res'] + except Exception as e: + # return { 'type': 'error', 'output': f"Error running r2cmd: {e}\nCommand: {command}\nResponse: {res}" } + return f"Error running r2cmd: {e}\nCommand: {command}\nResponse: {res}" + +def run_python(command: str): + """ + Run a python script and return the output + + Parameters + ---------- + command: str + The python script to run + + Returns + ------- + str + The output of the python script + """ + r2 = get_r2_inst() + with open('r2ai_tmp.py', 'w') as f: + f.write(command) + r2 = get_r2_inst() + res = r2.cmd('#!python r2ai_tmp.py') + r2.cmd('rm r2ai_tmp.py') + return res diff --git a/r2ai/ui/chat.py b/r2ai/ui/chat.py index 4baa385..dd16767 100644 --- a/r2ai/ui/chat.py +++ b/r2ai/ui/chat.py @@ -2,7 +2,7 @@ import asyncio from .db import get_env from r2ai.pipe import get_r2_inst -from .r2cmd import r2cmd +from r2ai.tools import r2cmd import json from r2ai.repl import r2ai_singleton diff --git a/r2ai/ui/r2cmd.py b/r2ai/ui/r2cmd.py deleted file mode 100644 index 6df7d01..0000000 --- a/r2ai/ui/r2cmd.py +++ /dev/null @@ -1,18 +0,0 @@ -from r2ai.pipe import get_r2_inst -import json -import builtins - -def r2cmd(command: str): - r2 = get_r2_inst() - cmd = '{"cmd":"' + command + '"}' - res = r2.cmd(cmd) - try: - res = json.loads(res) - if 'error' in res and res['error'] is True: - error_message = res['error'] - log_messages = '\n'.join(log['message'] for log in res.get('logs', [])) - return { 
'type': 'error', 'output': log_messages }
-
-        return { 'type': 'success', 'output': res['res'] }
-    except Exception as e:
-        raise Exception(f"Error running r2cmd: {e}\nCommand: {command}\nResponse: {res}")
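
For reviewers, a rough sketch of how the pieces introduced above fit together. litellm's utils.function_to_dict (the reason numpydoc joins the dependencies) converts the numpydoc-style docstrings in r2ai/tools.py into tool schemas, and ChatAuto streams the completion and routes tool calls back into radare2. The model string, the user question, and the assumption that an r2 session is already attached are illustrative, not part of the patch:

import asyncio
from litellm import utils
from r2ai.auto import ChatAuto, SYSTEM_PROMPT_AUTO, cb
from r2ai.tools import r2cmd, run_python

# function_to_dict parses the numpydoc docstring of r2cmd into an
# OpenAI-style tool schema; ChatAuto.__init__ does the same for every
# function passed in `tools`.
print(utils.function_to_dict(r2cmd))

# Assumes radare2 is already attached through r2ai.pipe and that the
# provider credentials for the chosen litellm model are set in the environment.
chat = ChatAuto(
    "openai/gpt-4o",  # any litellm-style "provider/model" string
    system=SYSTEM_PROMPT_AUTO,
    tools=[r2cmd, run_python],
    messages=[{"role": "user", "content": "What does the entrypoint do?"}],
    cb=cb,  # callback defined in auto.py: prints streamed text and tool activity
)
print(asyncio.run(chat.chat()))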