Skip to content

Commit

Permalink
Merge branch 'pdb'
Browse files Browse the repository at this point in the history
  • Loading branch information
pldi21 committed Mar 5, 2024
2 parents 8224a5c + 44d966a commit d40780e
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 49 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"litellm>=1.26.6",
"PyYAML>=6.0.1",
"ipyflow>=0.0.130",
"numpy>=1.26.3"
]
description = "AI-assisted debugging. Uses AI to answer 'why'."
readme = "README.md"
Expand Down
58 changes: 15 additions & 43 deletions src/chatdbg/chatdbg_pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .ipdb_util.logging import ChatDBGLog, CopyingTextIOWrapper
from .ipdb_util.prompts import pdb_instructions
from .ipdb_util.text import *
from .ipdb_util.locals import *

_valid_models = [
"gpt-4-turbo-preview",
Expand Down Expand Up @@ -194,7 +195,7 @@ def onecmd(self, line: str) -> bool:
output = strip_color(hist_file.getvalue())
if line not in [ 'quit', 'EOF']:
self._log.user_command(line, output)
if line not in [ 'hist', 'test_prompt' ] and not self.was_chat:
if line not in [ 'hist', 'test_prompt', 'c', 'continue' ] and not self.was_chat:
self._history += [ (line, output) ]

def message(self, msg) -> None:
Expand Down Expand Up @@ -389,56 +390,27 @@ def print_stack_trace(self, context=None, locals=None):
pass


def _get_defined_locals_and_params(self, frame):
    """Return the names in ``frame`` that are parameters or are assigned in its source.

    The source for ``frame`` is fetched with ``inspect.getsource``, dedented,
    and parsed; every name bound by an assignment, ``for`` loop, or
    comprehension is collected, combined with the frame's parameter names,
    and intersected with the names currently present in the frame's locals.

    Args:
        frame: the frame object to inspect.

    Returns:
        A set of variable names.  The set is empty when the source cannot be
        retrieved or cannot be parsed.
    """

    class SymbolFinder(ast.NodeVisitor):
        """Collect names bound by assignments, ``for`` loops, and comprehensions."""

        def __init__(self):
            self.defined_symbols = set()

        def _add_target(self, target):
            # Walk the whole target so unpacking such as ``a, (b, *c) = ...``
            # contributes every bound name.  Only Store-context names count:
            # in ``obj.attr = 1`` or ``xs[i] = 1`` the names are merely read.
            for node in ast.walk(target):
                if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
                    self.defined_symbols.add(node.id)

        def visit_Assign(self, node):
            for target in node.targets:
                self._add_target(target)
            self.generic_visit(node)

        def visit_For(self, node):
            self._add_target(node.target)
            self.generic_visit(node)

        def visit_comprehension(self, node):
            self._add_target(node.target)
            self.generic_visit(node)

    try:
        source = textwrap.dedent(inspect.getsource(frame))
        tree = ast.parse(source)
    except (OSError, SyntaxError):
        # OSError: the source is unavailable.  SyntaxError (which includes
        # IndentationError): the retrieved fragment does not parse on its
        # own.  Either way, report no names rather than crash.
        return set()

    finder = SymbolFinder()
    finder.visit(tree)

    args, varargs, keywords, frame_locals = inspect.getargvalues(frame)
    parameter_symbols = set(args + [varargs, keywords])
    parameter_symbols.discard(None)  # absent *args / **kwargs are reported as None

    return (finder.defined_symbols | parameter_symbols) & frame_locals.keys()

def _print_locals(self, frame):
locals = frame.f_locals
defined_locals = self._get_defined_locals_and_params(frame)
in_global_scope = locals is frame.f_globals
defined_locals = extract_locals(frame)
# if in_global_scope and "In" in locals: # in notebook
# defined_locals = defined_locals | extract_nb_globals(locals)
if len(defined_locals) > 0:
if locals is frame.f_globals:
if in_global_scope:
print(f' Global variables:', file=self.stdout)
else:
print(f' Variables in this frame:', file=self.stdout)
for name in sorted(defined_locals):
value = locals[name]
print(f" {name}= {format_limited(value, limit=20)}", file=self.stdout)
prefix = f' {name}= '
rep = format_limited(value, limit=20).split('\n')
if len(rep) > 1:
rep = prefix + rep[0] + '\n' + textwrap.indent('\n'.join(rep[1:]),
prefix = ' ' * len(prefix))
else:
rep = prefix + rep[0]
print(rep, file=self.stdout)
print(file=self.stdout)

def _stack_prompt(self):
Expand Down
52 changes: 52 additions & 0 deletions src/chatdbg/ipdb_util/locals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import ast
import inspect
import textwrap

class SymbolFinder(ast.NodeVisitor):
    """AST visitor that records the names bound by a piece of source code.

    After ``visit(tree)``, ``defined_symbols`` holds every name that appears
    as a binding target of an assignment (plain, annotated with a value, or
    walrus), a ``for`` loop, a comprehension, or a ``with ... as`` clause.
    Unpacking targets such as ``a, (b, *c)`` contribute each name.
    Attribute and subscript targets (``obj.x = 1``, ``xs[i] = 1``)
    contribute nothing, since they bind no new name.
    """

    def __init__(self):
        self.defined_symbols = set()

    def _add_target(self, target):
        # Walk the whole target so nested tuple/list/starred unpacking is
        # covered; keep only names in Store context, i.e. the ones bound.
        for node in ast.walk(target):
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
                self.defined_symbols.add(node.id)

    def visit_Assign(self, node):
        for target in node.targets:
            self._add_target(target)
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        # A bare annotation (``x: int``) has no value and binds nothing.
        if node.value is not None:
            self._add_target(node.target)
        self.generic_visit(node)

    def visit_NamedExpr(self, node):
        self._add_target(node.target)
        self.generic_visit(node)

    def visit_For(self, node):
        self._add_target(node.target)
        self.generic_visit(node)

    visit_AsyncFor = visit_For

    def visit_With(self, node):
        for item in node.items:
            if item.optional_vars is not None:
                self._add_target(item.optional_vars)
        self.generic_visit(node)

    visit_AsyncWith = visit_With

    def visit_comprehension(self, node):
        self._add_target(node.target)
        self.generic_visit(node)

def extract_locals(frame):
    """Return the names in ``frame`` that are parameters or are assigned in its source.

    The frame's source is fetched with ``inspect.getsource``, dedented, and
    parsed; the names found by ``SymbolFinder`` are combined with the
    frame's parameter names and intersected with the names currently present
    in the frame's locals.

    Args:
        frame: the frame object to inspect.

    Returns:
        A set of variable names; an empty set if anything goes wrong while
        retrieving or analysing the source.
    """
    try:
        source = textwrap.dedent(inspect.getsource(frame))
        tree = ast.parse(source)

        finder = SymbolFinder()
        finder.visit(tree)

        args, varargs, keywords, frame_locals = inspect.getargvalues(frame)
        parameter_symbols = set(args + [varargs, keywords])
        parameter_symbols.discard(None)  # absent *args / **kwargs are reported as None

        return (finder.defined_symbols | parameter_symbols) & frame_locals.keys()
    except Exception:
        # Deliberately best effort: unavailable or unparsable source must
        # not break the caller.  ``Exception`` rather than a bare ``except``
        # so KeyboardInterrupt and SystemExit still propagate.
        return set()

def extract_nb_globals(globals):
    """Return the global names that are assigned in the sources listed under ``globals["In"]``.

    Each entry of ``globals["In"]`` is parsed as Python source; the names
    found by ``SymbolFinder`` are kept when they are also keys of
    ``globals``.

    Args:
        globals: a mapping of global names to values.
            NOTE(review): presumably the IPython user namespace, where "In"
            is the list of executed cell sources -- confirm against caller.

    Returns:
        A set of names; empty when ``globals`` has no "In" entry.
    """
    result = set()
    # ``.get`` so a namespace without "In" yields an empty result instead of
    # raising KeyError.
    for source in globals.get("In", ()):
        try:
            tree = ast.parse(source)
            finder = SymbolFinder()
            finder.visit(tree)
            result |= finder.defined_symbols & globals.keys()
        except Exception:
            # Best effort: skip any entry that cannot be parsed or analysed.
            continue
    return result
5 changes: 2 additions & 3 deletions src/chatdbg/ipdb_util/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@
"""

_general_instructions=f"""\
The root cause of any error is likely due to a problem in the source code within
the {os.getcwd()} directory.
The root cause of any error is likely due to a problem in the source code from the user.
Explain why each variable contributing to the error has been set to the value that it has.
Keep your answers under 10 paragraphs.
Continue with your explanations until you reach the root cause of the error. Your answer may be as long as necessary.
End your answer with a section titled "##### Recommendation\\n" that contains one of:
* a fix if you have identified the root cause
Expand Down
20 changes: 17 additions & 3 deletions src/chatdbg/ipdb_util/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools
import inspect
import numbers
import numpy as np

def make_arrow(pad):
"""generate the leading arrow in front of traceback or debugger"""
Expand All @@ -22,6 +23,14 @@ def _is_iterable(obj):
except TypeError:
return False


def _repr_if_defined(obj):
    """Return True when ``obj``'s class supplies its own ``__repr__``.

    Instances whose class is exactly ``np.ndarray``, ``dict``, ``list``, or
    ``tuple`` always yield False, as does any object whose class still uses
    ``object.__repr__``.
    """
    cls = obj.__class__
    if cls in (np.ndarray, dict, list, tuple):
        # handle these as iterables to truncate reasonably
        return False
    # Every class inherits ``__repr__`` from ``object``, so the identity
    # comparison alone tells whether it has been overridden; there is no
    # need to build ``dir(cls)`` first.
    return cls.__repr__ is not object.__repr__

def format_limited(value, limit=10, depth=3):

def format_tuple(t, depth):
Expand Down Expand Up @@ -50,7 +59,7 @@ def helper(value, depth):
return format_dict(value.items(), depth-1)
elif isinstance(value, (str,bytes)):
if len(value) > 254:
value = value[0:253] + "..."
value = str(value)[0:253] + "..."
return value
elif isinstance(value, tuple):
if len(value) > limit:
Expand All @@ -59,6 +68,11 @@ def helper(value, depth):
return format_tuple(value, depth-1)
elif value is None or isinstance(value, (int, float, bool, type, numbers.Number)):
return value
elif isinstance(value, np.ndarray):
with np.printoptions(threshold=limit):
return np.array_repr(value)
elif inspect.isclass(type(value)) and _repr_if_defined(value):
return repr(value)
elif _is_iterable(value):
value = list(itertools.islice(value, 0, limit + 1))
if len(value) > limit:
Expand All @@ -71,8 +85,8 @@ def helper(value, depth):
return value

result = str(helper(value, depth=3)).replace("Ellipsis", "...")
if len(result) > 1024:
result = result[:1024-3] + '...'
if len(result) > 1024 * 2:
result = result[:1024 * 2 - 3] + '...'
if type(value) == str:
return "'" + result + "'"
else:
Expand Down

0 comments on commit d40780e

Please sign in to comment.