diff --git a/pyproject.toml b/pyproject.toml index 598eecc..081feee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ { name="Stephen Freund", email="sfreund@williams.edu" }, ] dependencies = [ - "llm-utils>=0.2.6", + "llm-utils>=0.2.8", "openai>=1.6.1", "rich>=13.7.0", "ansicolors>=1.1.8", @@ -22,6 +22,7 @@ dependencies = [ "litellm>=1.26.6", "PyYAML>=6.0.1", "ipyflow>=0.0.130", + "numpy>=1.26.3" ] description = "AI-assisted debugging. Uses AI to answer 'why'." readme = "README.md" diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py index 213be0d..c3043cf 100644 --- a/src/chatdbg/assistant/assistant.py +++ b/src/chatdbg/assistant/assistant.py @@ -213,7 +213,7 @@ def run(self, prompt, client_print=print): ) client_print() client_print(f"[Cost: ~${cost:.2f} USD]") - return run.usage.total_tokens, cost, elapsed_time + return run.usage.total_tokens,run.usage.prompt_tokens, run.usage.completion_tokens, cost, elapsed_time except OpenAIError as e: client_print(f"*** OpenAI Error: {e}") sys.exit(-1) diff --git a/src/chatdbg/chatdbg_lldb.py b/src/chatdbg/chatdbg_lldb.py index 4005c03..6badd58 100644 --- a/src/chatdbg/chatdbg_lldb.py +++ b/src/chatdbg/chatdbg_lldb.py @@ -8,7 +8,6 @@ import json import llm_utils -import openai from assistant.lite_assistant import LiteAssistant import chatdbg_utils @@ -234,7 +233,7 @@ def why( sys.exit(1) the_prompt = buildPrompt(debugger) - args, _ = chatdbg_utils.parse_known_args(command) + args, _ = chatdbg_utils.parse_known_args(command.split()) chatdbg_utils.explain(the_prompt[0], the_prompt[1], the_prompt[2], args) @@ -389,7 +388,12 @@ def _instructions(): You are an assistant debugger. The user is having an issue with their code, and you are trying to help them find the root cause. They will provide a short summary of the issue and a question to be answered. + Call the `lldb` function to run lldb debugger commands on the stopped program. + Call the `get_code_surrounding` function to retrieve user code and give more context back to the user on their problem. + Call the `find_definition` function to retrieve the definition of a particular symbol. + You should call `find_definition` on every symbol that could be linked to the issue. + Don't hesitate to use as many function calls as needed to give the best possible answer. Once you have identified the root cause of the problem, explain it and provide a way to fix the issue if you can. """ @@ -440,52 +444,6 @@ def get_code_surrounding(filename: str, lineno: int) -> str: (lines, first) = llm_utils.read_lines(filename, lineno - 7, lineno + 3) return llm_utils.number_group_of_lines(lines, first) - clangd = clangd_lsp_integration.clangd() - - def find_definition(filename: str, lineno: int, character: int) -> str: - """ - { - "name": "find_definition", - "description": "Returns the definition for the symbol at the given source location.", - "parameters": { - "type": "object", - "properties": { - "filename": { - "type": "string", - "description": "The filename the code location is from." - }, - "lineno": { - "type": "integer", - "description": "The line number where the symbol is present." - }, - "character": { - "type": "integer", - "description": "The column number where the symbol is present." - } - }, - "required": [ "filename", "lineno", "character" ] - } - } - """ - clangd.didOpen(filename, "c" if filename.endswith(".c") else "cpp") - definition = clangd.definition(filename, lineno, character) - clangd.didClose(filename) - - if "result" not in definition or not definition["result"]: - return "No definition found." - - path = clangd_lsp_integration.uri_to_path(definition["result"][0]["uri"]) - start_lineno = definition["result"][0]["range"]["start"]["line"] + 1 - end_lineno = definition["result"][0]["range"]["end"]["line"] + 1 - (lines, first) = llm_utils.read_lines(path, start_lineno - 5, end_lineno + 5) - content = llm_utils.number_group_of_lines(lines, first) - line_string = ( - f"line {start_lineno}" - if start_lineno == end_lineno - else f"lines {start_lineno}-{end_lineno}" - ) - return f"""File '{path}' at {line_string}:\n```\n{content}\n```""" - assistant = LiteAssistant( _instructions(), model=args.llm, @@ -500,6 +458,62 @@ def find_definition(filename: str, lineno: int, character: int) -> str: print("[WARNING] clangd is not available.") print("[WARNING] The `find_definition` function will not be made available.") else: + clangd = clangd_lsp_integration.clangd() + + def find_definition(filename: str, lineno: int, symbol: str) -> str: + """ + { + "name": "find_definition", + "description": "Returns the definition for the given symbol at the given source line number.", + "parameters": { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "The filename the symbol is from." + }, + "lineno": { + "type": "integer", + "description": "The line number where the symbol is present." + }, + "symbol": { + "type": "string", + "description": "The symbol to lookup." + } + }, + "required": [ "filename", "lineno", "symbol" ] + } + } + """ + # We just return the first match here. Maybe we should find all definitions. + with open(filename, "r") as file: + lines = file.readlines() + if lineno - 1 >= len(lines): + return "Symbol not found at that location!" + character = lines[lineno - 1].find(symbol) + if character == -1: + return "Symbol not found at that location!" + clangd.didOpen(filename, "c" if filename.endswith(".c") else "cpp") + definition = clangd.definition(filename, lineno, character + 1) + clangd.didClose(filename) + + if "result" not in definition or not definition["result"]: + return "No definition found." + + path = clangd_lsp_integration.uri_to_path(definition["result"][0]["uri"]) + start_lineno = definition["result"][0]["range"]["start"]["line"] + 1 + end_lineno = definition["result"][0]["range"]["end"]["line"] + 1 + (lines, first) = llm_utils.read_lines( + path, start_lineno - 5, end_lineno + 5 + ) + content = llm_utils.number_group_of_lines(lines, first) + line_string = ( + f"line {start_lineno}" + if start_lineno == end_lineno + else f"lines {start_lineno}-{end_lineno}" + ) + return f"""File '{path}' at {line_string}:\n```\n{content}\n```""" + assistant.add_function(find_definition) return assistant @@ -517,6 +531,8 @@ def get_frame_summary() -> str: summaries = [] for i, frame in enumerate(thread): + if not frame.GetDisplayFunctionName(): + continue name = frame.GetDisplayFunctionName().split("(")[0] arguments = [] for j in range( diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 5e4feef..e3b8d3e 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -21,6 +21,7 @@ from .ipdb_util.logging import ChatDBGLog, CopyingTextIOWrapper from .ipdb_util.prompts import pdb_instructions from .ipdb_util.text import * +from .ipdb_util.locals import * _valid_models = [ "gpt-4-turbo-preview", @@ -194,7 +195,7 @@ def onecmd(self, line: str) -> bool: output = strip_color(hist_file.getvalue()) if line not in [ 'quit', 'EOF']: self._log.user_command(line, output) - if line not in [ 'hist', 'test_prompt' ] and not self.was_chat: + if line not in [ 'hist', 'test_prompt', 'c', 'continue' ] and not self.was_chat: self._history += [ (line, output) ] def message(self, msg) -> None: @@ -389,56 +390,27 @@ def print_stack_trace(self, context=None, locals=None): pass - def _get_defined_locals_and_params(self, frame): - - class SymbolFinder(ast.NodeVisitor): - def __init__(self): - self.defined_symbols = set() - - def visit_Assign(self, node): - for target in node.targets: - if isinstance(target, ast.Name): - self.defined_symbols.add(target.id) - self.generic_visit(node) - - def visit_For(self, node): - if isinstance(node.target, ast.Name): - self.defined_symbols.add(node.target.id) - self.generic_visit(node) - - def visit_comprehension(self, node): - if isinstance(node.target, ast.Name): - self.defined_symbols.add(node.target.id) - self.generic_visit(node) - - - try: - source = textwrap.dedent(inspect.getsource(frame)) - tree = ast.parse(source) - - finder = SymbolFinder() - finder.visit(tree) - - args, varargs, keywords, locals = inspect.getargvalues(frame) - parameter_symbols = set(args + [ varargs, keywords ]) - parameter_symbols.discard(None) - - return (finder.defined_symbols | parameter_symbols) & locals.keys() - except OSError as e: - # yipes -silent fail if getsource fails - return set() - def _print_locals(self, frame): locals = frame.f_locals - defined_locals = self._get_defined_locals_and_params(frame) + in_global_scope = locals is frame.f_globals + defined_locals = extract_locals(frame) + # if in_global_scope and "In" in locals: # in notebook + # defined_locals = defined_locals | extract_nb_globals(locals) if len(defined_locals) > 0: - if locals is frame.f_globals: + if in_global_scope: print(f' Global variables:', file=self.stdout) else: print(f' Variables in this frame:', file=self.stdout) for name in sorted(defined_locals): value = locals[name] - print(f" {name}= {format_limited(value, limit=20)}", file=self.stdout) + prefix = f' {name}= ' + rep = format_limited(value, limit=20).split('\n') + if len(rep) > 1: + rep = prefix + rep[0] + '\n' + textwrap.indent('\n'.join(rep[1:]), + prefix = ' ' * len(prefix)) + else: + rep = prefix + rep[0] + print(rep, file=self.stdout) print(file=self.stdout) def _stack_prompt(self): @@ -499,8 +471,8 @@ def client_print(line=""): full_prompt = truncate_proportionally(full_prompt) self._log.push_chat(arg, full_prompt) - tokens, cost, time = self._assistant.run(full_prompt, client_print) - self._log.pop_chat(tokens, cost, time) + total_tokens, prompt_tokens, completion_tokens, cost, time = self._assistant.run(full_prompt, client_print) + self._log.pop_chat(total_tokens, prompt_tokens, completion_tokens, cost, time) def do_mark(self, arg): marks = [ 'Full', 'Partial', 'Wrong', 'None', '?' ] diff --git a/src/chatdbg/chatdbg_utils.py b/src/chatdbg/chatdbg_utils.py index cbab44e..59f8e74 100644 --- a/src/chatdbg/chatdbg_utils.py +++ b/src/chatdbg/chatdbg_utils.py @@ -1,5 +1,4 @@ import argparse -import os import textwrap from typing import Any, List, Optional, Tuple diff --git a/src/chatdbg/clangd_lsp_integration.py b/src/chatdbg/clangd_lsp_integration.py index edaa49c..a4d7532 100644 --- a/src/chatdbg/clangd_lsp_integration.py +++ b/src/chatdbg/clangd_lsp_integration.py @@ -61,13 +61,12 @@ def uri_to_path(uri): return urllib.parse.unquote(path) # clangd seems to escape paths. -def is_available(): +def is_available(executable="clangd"): try: - clangd = subprocess.Popen( - ["clangd", "--version"], + clangd = subprocess.run( + [executable, "--version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, - text=True, ) return clangd.returncode == 0 except FileNotFoundError: diff --git a/src/chatdbg/ipdb_util/locals.py b/src/chatdbg/ipdb_util/locals.py new file mode 100644 index 0000000..9c704ef --- /dev/null +++ b/src/chatdbg/ipdb_util/locals.py @@ -0,0 +1,52 @@ +import ast +import inspect +import textwrap + +class SymbolFinder(ast.NodeVisitor): + def __init__(self): + self.defined_symbols = set() + + def visit_Assign(self, node): + for target in node.targets: + if isinstance(target, ast.Name): + self.defined_symbols.add(target.id) + self.generic_visit(node) + + def visit_For(self, node): + if isinstance(node.target, ast.Name): + self.defined_symbols.add(node.target.id) + self.generic_visit(node) + + def visit_comprehension(self, node): + if isinstance(node.target, ast.Name): + self.defined_symbols.add(node.target.id) + self.generic_visit(node) + +def extract_locals(frame): + try: + source = textwrap.dedent(inspect.getsource(frame)) + tree = ast.parse(source) + + finder = SymbolFinder() + finder.visit(tree) + + args, varargs, keywords, locals = inspect.getargvalues(frame) + parameter_symbols = set(args + [ varargs, keywords ]) + parameter_symbols.discard(None) + + return (finder.defined_symbols | parameter_symbols) & locals.keys() + except: + # ipes + return set() + +def extract_nb_globals(globals): + result = set() + for source in globals["In"]: + try: + tree = ast.parse(source) + finder = SymbolFinder() + finder.visit(tree) + result = result | (finder.defined_symbols & globals.keys()) + except Exception as e: + pass + return result \ No newline at end of file diff --git a/src/chatdbg/ipdb_util/logging.py b/src/chatdbg/ipdb_util/logging.py index 739fa25..51dd97b 100644 --- a/src/chatdbg/ipdb_util/logging.py +++ b/src/chatdbg/ipdb_util/logging.py @@ -106,9 +106,11 @@ def push_chat(self, line, full_prompt): } } - def pop_chat(self, tokens, cost, time): + def pop_chat(self, total_tokens, prompt_tokens, completion_tokens, cost, time): self.chat_step['stats'] = { - 'tokens' : tokens, + 'tokens' : total_tokens, + 'prompt' : prompt_tokens, + 'completion' : completion_tokens, 'cost' : cost, 'time' : time } diff --git a/src/chatdbg/ipdb_util/prompts.py b/src/chatdbg/ipdb_util/prompts.py index 669fab8..e355f99 100644 --- a/src/chatdbg/ipdb_util/prompts.py +++ b/src/chatdbg/ipdb_util/prompts.py @@ -34,12 +34,11 @@ """ _general_instructions=f"""\ -The root cause of any error is likely due to a problem in the source code within -the {os.getcwd()} directory. +The root cause of any error is likely due to a problem in the source code from the user. Explain why each variable contributing to the error has been set to the value that it has. -Keep your answers under 10 paragraphs. +Continue with your explanations until you reach the root cause of the error. Your answer may be as long as necessary. End your answer with a section titled "##### Recommendation\\n" that contains one of: * a fix if you have identified the root cause diff --git a/src/chatdbg/ipdb_util/text.py b/src/chatdbg/ipdb_util/text.py index fb9dbe3..349df11 100644 --- a/src/chatdbg/ipdb_util/text.py +++ b/src/chatdbg/ipdb_util/text.py @@ -2,6 +2,7 @@ import itertools import inspect import numbers +import numpy as np def make_arrow(pad): """generate the leading arrow in front of traceback or debugger""" @@ -22,6 +23,14 @@ def _is_iterable(obj): except TypeError: return False + +def _repr_if_defined(obj): + if obj.__class__ in [ np.ndarray, dict, list, tuple ]: + # handle these at iterables to truncate reasonably + return False + result = "__repr__" in dir(obj.__class__) and obj.__class__.__repr__ is not object.__repr__ + return result + def format_limited(value, limit=10, depth=3): def format_tuple(t, depth): @@ -50,7 +59,7 @@ def helper(value, depth): return format_dict(value.items(), depth-1) elif isinstance(value, (str,bytes)): if len(value) > 254: - value = value[0:253] + "..." + value = str(value)[0:253] + "..." return value elif isinstance(value, tuple): if len(value) > limit: @@ -59,6 +68,11 @@ def helper(value, depth): return format_tuple(value, depth-1) elif value is None or isinstance(value, (int, float, bool, type, numbers.Number)): return value + elif isinstance(value, np.ndarray): + with np.printoptions(threshold=limit): + return np.array_repr(value) + elif inspect.isclass(type(value)) and _repr_if_defined(value): + return repr(value) elif _is_iterable(value): value = list(itertools.islice(value, 0, limit + 1)) if len(value) > limit: @@ -71,8 +85,8 @@ def helper(value, depth): return value result = str(helper(value, depth=3)).replace("Ellipsis", "...") - if len(result) > 1024: - result = result[:1024-3] + '...' + if len(result) > 1024 * 2: + result = result[:1024 * 2 - 3] + '...' if type(value) == str: return "'" + result + "'" else: