diff --git a/src/chatdbg/__main__.py b/src/chatdbg/__main__.py index 349c5fa..8ff3a9f 100644 --- a/src/chatdbg/__main__.py +++ b/src/chatdbg/__main__.py @@ -1,6 +1,7 @@ import os import pathlib import sys + the_path = pathlib.Path(__file__).parent.resolve() sys.path.insert(0, os.path.abspath(the_path)) @@ -8,4 +9,3 @@ from . import chatdbg chatdbg.main() - diff --git a/src/chatdbg/chatdbg.py b/src/chatdbg/chatdbg.py index 0a2357d..573519c 100644 --- a/src/chatdbg/chatdbg.py +++ b/src/chatdbg/chatdbg.py @@ -15,10 +15,12 @@ import chatdbg_pdb import chatdbg_why + class ChatDBG(chatdbg_pdb.Pdb): def do_why(self, arg): asyncio.run(chatdbg_why.why(self, arg)) + import importlib.metadata _usage = f"""\ diff --git a/src/chatdbg/chatdbg_gdb.py b/src/chatdbg/chatdbg_gdb.py index 06b5d50..83f2da0 100644 --- a/src/chatdbg/chatdbg_gdb.py +++ b/src/chatdbg/chatdbg_gdb.py @@ -9,6 +9,7 @@ import textwrap import pathlib + the_path = pathlib.Path(__file__).parent.resolve() sys.path.append(os.path.abspath(the_path)) @@ -18,6 +19,7 @@ import chatdbg_utils + def read_lines_list(file_path: str, start_line: int, end_line: int) -> [str]: """ Read lines from a file and return a list containing the lines between start_line and end_line. @@ -32,7 +34,7 @@ def read_lines_list(file_path: str, start_line: int, end_line: int) -> [str]: """ # open the file for reading - with open(file_path, 'r') as f: + with open(file_path, "r") as f: # read all the lines from the file lines = f.readlines() # remove trailing newline characters @@ -44,30 +46,34 @@ def read_lines_list(file_path: str, start_line: int, end_line: int) -> [str]: # return the requested lines as a list return lines[start_line:end_line] + # Set the prompt to gdb-ChatDBG gdb.prompt_hook = lambda x: "(gdb-ChatDBG) " last_error_type = "" + def stop_handler(event): """Sets last error type so we can report it later.""" # Check if the event is a stop event global last_error_type - if not hasattr(event, 'stop_signal'): - last_error_type = "" # Not a real error (e.g., a breakpoint) + if not hasattr(event, "stop_signal"): + last_error_type = "" # Not a real error (e.g., a breakpoint) return if event.stop_signal is not None: last_error_type = event.stop_signal + gdb.events.stop.connect(stop_handler) # Implement the command `why` class Why(gdb.Command): """Provides root cause analysis for a failure.""" + def __init__(self): gdb.Command.__init__(self, "why", gdb.COMMAND_USER) - def invoke(self, arg, from_tty, really_run = True): + def invoke(self, arg, from_tty, really_run=True): try: frame = gdb.selected_frame() except: @@ -77,18 +83,24 @@ def invoke(self, arg, from_tty, really_run = True): if not last_error_type: # Assume we are running from a core dump, # which _probably_ means a SEGV. - last_error_type = 'SIGSEGV' + last_error_type = "SIGSEGV" the_prompt = buildPrompt() if the_prompt: # Call `explain` function with pieces of the_prompt as arguments. - asyncio.run(chatdbg_utils.explain(the_prompt[0], the_prompt[1], the_prompt[2], really_run)) - + asyncio.run( + chatdbg_utils.explain( + the_prompt[0], the_prompt[1], the_prompt[2], really_run + ) + ) + + Why() + def buildPrompt() -> str: thread = gdb.selected_thread() if not thread: - return '' + return "" stack_trace = "" source_code = "" @@ -99,7 +111,7 @@ def buildPrompt() -> str: # magic number - don't bother walking up more than this many frames. # This is just to prevent overwhelming OpenAI (or to cope with a stack overflow!). max_frames = 10 - + # Walk the stack and build up the frames list. while frame is not None and max_frames > 0: func_name = frame.name() @@ -118,7 +130,9 @@ def buildPrompt() -> str: try: block = frame.block() except RuntimeError: - print('Your program must be compiled with debug information (`-g`) to use `why`.') + print( + "Your program must be compiled with debug information (`-g`) to use `why`." + ) return "" for symbol in block: if symbol.is_argument: @@ -136,15 +150,17 @@ def buildPrompt() -> str: line_num = frame_info[3] arg_list = [] for arg in frame_info[2]: - arg_list.append(str(arg[1])) # Note: arg[0] is the name of the argument - stack_trace += f'frame {i}: {func_name}({",".join(arg_list)}) at {file_name}:{line_num}\n' + arg_list.append(str(arg[1])) # Note: arg[0] is the name of the argument + stack_trace += ( + f'frame {i}: {func_name}({",".join(arg_list)}) at {file_name}:{line_num}\n' + ) try: - source_code += f'/* frame {i} */\n' + source_code += f"/* frame {i} */\n" lines = read_lines_list(file_name, line_num - 10, line_num) - source_code += '\n'.join(lines) + '\n' + source_code += "\n".join(lines) + "\n" # Get the spaces before the last line. num_spaces = len(lines[-1]) - len(lines[-1].lstrip()) - source_code += ' ' * num_spaces + '^' + '-' * (79 - num_spaces) + '\n' + source_code += " " * num_spaces + "^" + "-" * (79 - num_spaces) + "\n" except: # Couldn't find source for some reason. Skip file. pass @@ -157,7 +173,5 @@ def buildPrompt() -> str: last_error_type = panic_log + "\n" + last_error_type except: pass - - return (source_code, stack_trace, last_error_type) - + return (source_code, stack_trace, last_error_type) diff --git a/src/chatdbg/chatdbg_lldb.py b/src/chatdbg/chatdbg_lldb.py index 28e8750..c9db2b4 100644 --- a/src/chatdbg/chatdbg_lldb.py +++ b/src/chatdbg/chatdbg_lldb.py @@ -6,6 +6,7 @@ import sys import pathlib + the_path = pathlib.Path(__file__).parent.resolve() # The file produced by the panic handler if the Rust program is using the chatdbg crate. @@ -17,10 +18,12 @@ from typing import Tuple, Union + def __lldb_init_module(debugger: lldb.SBDebugger, internal_dict: dict) -> None: # Update the prompt. debugger.HandleCommand("settings set prompt '(ChatDBG lldb) '") + def is_debug_build(debugger, command, result, internal_dict) -> bool: """Returns False if not compiled with debug information.""" target = debugger.GetSelectedTarget() @@ -36,6 +39,7 @@ def is_debug_build(debugger, command, result, internal_dict) -> bool: break return has_debug_symbols + def is_debug_build_prev(debugger, command, result, internal_dict) -> bool: target = debugger.GetSelectedTarget() if target: @@ -46,9 +50,10 @@ def is_debug_build_prev(debugger, command, result, internal_dict) -> bool: return True return False - + # From lldbinit + def get_process() -> Union[None, lldb.SBProcess]: """ Get the process that the current target owns. @@ -56,6 +61,7 @@ def get_process() -> Union[None, lldb.SBProcess]: """ return get_target().process + def get_frame() -> lldb.SBFrame: """ Get the current frame of the running process. @@ -66,7 +72,10 @@ def get_frame() -> lldb.SBFrame: frame = None for thread in get_process(): # Loop through the threads in the process - if thread.GetStopReason() != lldb.eStopReasonNone and thread.GetStopReason() != lldb.eStopReasonInvalid: + if ( + thread.GetStopReason() != lldb.eStopReasonNone + and thread.GetStopReason() != lldb.eStopReasonInvalid + ): # If the stop reason is not "none" or "invalid", get the frame at index 0 and break the loop. frame = thread.GetFrameAtIndex(0) break @@ -78,6 +87,7 @@ def get_frame() -> lldb.SBFrame: # Return the current frame. return frame + def get_thread() -> lldb.SBThread: """ Returns the currently stopped thread in the debugged process. @@ -87,39 +97,46 @@ def get_thread() -> lldb.SBThread: # Iterate over threads in the process for _thread in get_process(): # Check if thread is stopped for a valid reason - if _thread.GetStopReason() != lldb.eStopReasonNone and _thread.GetStopReason() != lldb.eStopReasonInvalid: + if ( + _thread.GetStopReason() != lldb.eStopReasonNone + and _thread.GetStopReason() != lldb.eStopReasonInvalid + ): thread = _thread if not thread: # No stopped thread was found pass return thread + def get_target() -> lldb.SBTarget: target = lldb.debugger.GetSelectedTarget() if not target: return None return target + def truncate_string(string, n): if len(string) <= n: return string else: return string[:n] + "..." + def buildPrompt(debugger: any) -> Tuple[str, str, str]: import os + target = get_target() if not target: - return '' + return "" thread = get_thread() if not thread: - return '' + return "" if thread.GetStopReason() == lldb.eStopReasonBreakpoint: - return '' + return "" frame = thread.GetFrameAtIndex(0) - stack_trace = '' - source_code = '' - + stack_trace = "" + source_code = "" + # magic number - don't bother walking up more than this many frames. # This is just to prevent overwhelming OpenAI (or to cope with a stack overflow!). max_frames = 10 @@ -132,7 +149,7 @@ def buildPrompt(debugger: any) -> Tuple[str, str, str]: if not function: continue full_func_name = frame.GetFunctionName() - func_name = full_func_name.split('(')[0] + func_name = full_func_name.split("(")[0] arg_list = [] type_list = [] @@ -141,14 +158,14 @@ def buildPrompt(debugger: any) -> Tuple[str, str, str]: arg = frame.FindVariable(frame.GetFunction().GetArgumentName(i)) if not arg: continue - arg_name = str(arg).split('=')[0].strip() - arg_val = str(arg).split('=')[1].strip() + arg_name = str(arg).split("=")[0].strip() + arg_val = str(arg).split("=")[1].strip() arg_list.append(f"{arg_name} = {arg_val}") - + # Get the frame variables variables = frame.GetVariables(True, True, True, True) var_list = [] - + for var in variables: name = var.GetName() value = var.GetValue() @@ -158,10 +175,12 @@ def buildPrompt(debugger: any) -> Tuple[str, str, str]: # Attempt to dereference the pointer try: deref_value = var.Dereference().GetValue() - var_list.append(f"{type} {name} = {value} (*{name} = {deref_value})") + var_list.append( + f"{type} {name} = {value} (*{name} = {deref_value})" + ) except: var_list.append(f"{type} {name} = {value}") - + line_entry = frame.GetLineEntry() file_spec = line_entry.GetFileSpec() file_name = file_spec.GetFilename() @@ -171,15 +190,27 @@ def buildPrompt(debugger: any) -> Tuple[str, str, str]: col_num = line_entry.GetColumn() max_line_length = 100 - + try: lines = chatdbg_utils.read_lines(full_file_name, line_num - 10, line_num) - stack_trace += truncate_string(f'frame {index}: {func_name}({",".join(arg_list)}) at {file_name}:{line_num}:{col_num}\n', max_line_length - 3) + '\n' # 3 accounts for ellipsis + stack_trace += ( + truncate_string( + f'frame {index}: {func_name}({",".join(arg_list)}) at {file_name}:{line_num}:{col_num}\n', + max_line_length - 3, + ) + + "\n" + ) # 3 accounts for ellipsis if len(var_list) > 0: - stack_trace += "Local variables: " + truncate_string(','.join(var_list), max_line_length) + '\n' - source_code += f'/* frame {index} in {file_name} */\n' - source_code += lines + '\n' - source_code += '-' * (chatdbg_utils.read_lines_width() + col_num - 1) + '^' + '\n\n' + stack_trace += ( + "Local variables: " + + truncate_string(",".join(var_list), max_line_length) + + "\n" + ) + source_code += f"/* frame {index} in {file_name} */\n" + source_code += lines + "\n" + source_code += ( + "-" * (chatdbg_utils.read_lines_width() + col_num - 1) + "^" + "\n\n" + ) except: # Couldn't find the source for some reason. Skip the file. continue @@ -194,33 +225,46 @@ def buildPrompt(debugger: any) -> Tuple[str, str, str]: pass return (source_code, stack_trace, error_reason) + @lldb.command("why") -def why(debugger: lldb.SBDebugger, command: str, result: str, internal_dict: dict, really_run = True) -> None: +def why( + debugger: lldb.SBDebugger, + command: str, + result: str, + internal_dict: dict, + really_run=True, +) -> None: """ Root cause analysis for an error. """ # Check if there is debug info. if not is_debug_build(debugger, command, result, internal_dict): - print('Your program must be compiled with debug information (`-g`) to use `why`.') + print( + "Your program must be compiled with debug information (`-g`) to use `why`." + ) return # Check if program is attached to a debugger. if not get_target(): - print('Must be attached to a program to ask `why`.') + print("Must be attached to a program to ask `why`.") return # Check if code has been run before executing the `why` command. thread = get_thread() if not thread: - print('Must run the code first to ask `why`.') + print("Must run the code first to ask `why`.") return # Check why code stopped running. if thread.GetStopReason() == lldb.eStopReasonBreakpoint: # Check if execution stopped at a breakpoint or an error. - print('Execution stopped at a breakpoint, not an error.') + print("Execution stopped at a breakpoint, not an error.") return the_prompt = buildPrompt(debugger) - asyncio.run(chatdbg_utils.explain(the_prompt[0], the_prompt[1], the_prompt[2], really_run)) + asyncio.run( + chatdbg_utils.explain(the_prompt[0], the_prompt[1], the_prompt[2], really_run) + ) + @lldb.command("why_prompt") -def why_prompt(debugger: lldb.SBDebugger, command: str, result: str, internal_dict: dict) -> None: +def why_prompt( + debugger: lldb.SBDebugger, command: str, result: str, internal_dict: dict +) -> None: why(debugger, command, result, internal_dict, really_run=False) - diff --git a/src/chatdbg/chatdbg_pdb.py b/src/chatdbg/chatdbg_pdb.py index 6674fb8..0a7e61e 100644 --- a/src/chatdbg/chatdbg_pdb.py +++ b/src/chatdbg/chatdbg_pdb.py @@ -89,13 +89,25 @@ class Restart(Exception): """Causes a debugger to be restarted for the debugged python program.""" + pass -__all__ = ["run", "pm", "Pdb", "runeval", "runctx", "runcall", "set_trace", - "post_mortem", "help"] + +__all__ = [ + "run", + "pm", + "Pdb", + "runeval", + "runctx", + "runcall", + "set_trace", + "post_mortem", + "help", +] + def find_function(funcname, filename): - cre = re.compile(r'def\s+%s\s*[(]' % re.escape(funcname)) + cre = re.compile(r"def\s+%s\s*[(]" % re.escape(funcname)) try: fp = tokenize.open(filename) except OSError: @@ -107,6 +119,7 @@ def find_function(funcname, filename): return funcname, filename, lineno return None + def lasti2lineno(code, lasti): linestarts = list(dis.findlinestarts(code)) linestarts.reverse() @@ -118,6 +131,7 @@ def lasti2lineno(code, lasti): class _rstr(str): """String that doesn't quote its repr.""" + def __repr__(self): return self @@ -134,7 +148,7 @@ def __new__(cls, val): def check(self): if not os.path.exists(self): - print('Error:', self.orig, 'does not exist') + print("Error:", self.orig, "does not exist") sys.exit(1) # Replace pdb's dir with script's dir in front of module search path. @@ -147,7 +161,7 @@ def filename(self): @property def namespace(self): return dict( - __name__='__main__', + __name__="__main__", __file__=self, __builtins__=__builtins__, ) @@ -169,6 +183,7 @@ def check(self): @functools.cached_property def _details(self): import runpy + return runpy._get_module_details(self) @property @@ -188,7 +203,7 @@ def _spec(self): @property def namespace(self): return dict( - __name__='__main__', + __name__="__main__", __file__=os.path.normcase(os.path.abspath(self.filename)), __package__=self._spec.parent, __loader__=self._spec.loader, @@ -202,30 +217,39 @@ def namespace(self): # be to your liking. You can set it once pdb is imported using the # command "pdb.line_prefix = '\n% '". # line_prefix = ': ' # Use this to get the old situation back -line_prefix = '\n-> ' # Probably a better default +line_prefix = "\n-> " # Probably a better default + class Pdb(bdb.Bdb, cmd.Cmd): _previous_sigint_handler = None - def __init__(self, completekey='tab', stdin=None, stdout=None, skip=None, - nosigint=False, readrc=True): + def __init__( + self, + completekey="tab", + stdin=None, + stdout=None, + skip=None, + nosigint=False, + readrc=True, + ): bdb.Bdb.__init__(self, skip=skip) cmd.Cmd.__init__(self, completekey, stdin, stdout) sys.audit("pdb.Pdb") if stdout: self.use_rawinput = 0 - self.prompt = '(ChatDBG Pdb) ' + self.prompt = "(ChatDBG Pdb) " self.aliases = {} self.displaying = {} - self.mainpyfile = '' + self.mainpyfile = "" self._wait_for_mainpyfile = False self.tb_lineno = {} # Try to load readline if it exists try: import readline + # remove some common file name delimiters - readline.set_completer_delims(' \t\n`@#$%^&*()=+[{]}\\|;:\'",<>?') + readline.set_completer_delims(" \t\n`@#$%^&*()=+[{]}\\|;:'\",<>?") except ImportError: pass self.allow_kbdint = False @@ -235,25 +259,25 @@ def __init__(self, completekey='tab', stdin=None, stdout=None, skip=None, self.rcLines = [] if readrc: try: - with open(os.path.expanduser('~/.pdbrc'), encoding='utf-8') as rcFile: + with open(os.path.expanduser("~/.pdbrc"), encoding="utf-8") as rcFile: self.rcLines.extend(rcFile) except OSError: pass try: - with open(".pdbrc", encoding='utf-8') as rcFile: + with open(".pdbrc", encoding="utf-8") as rcFile: self.rcLines.extend(rcFile) except OSError: pass - self.commands = {} # associates a command list to breakpoint numbers - self.commands_doprompt = {} # for each bp num, tells if the prompt - # must be disp. after execing the cmd list - self.commands_silent = {} # for each bp num, tells if the stack trace - # must be disp. after execing the cmd list - self.commands_defining = False # True while in the process of defining - # a command list - self.commands_bnum = None # The breakpoint number for which we are - # defining a list + self.commands = {} # associates a command list to breakpoint numbers + self.commands_doprompt = {} # for each bp num, tells if the prompt + # must be disp. after execing the cmd list + self.commands_silent = {} # for each bp num, tells if the stack trace + # must be disp. after execing the cmd list + self.commands_defining = False # True while in the process of defining + # a command list + self.commands_bnum = None # The breakpoint number for which we are + # defining a list def sigint_handler(self, signum, frame): if self.allow_kbdint: @@ -301,7 +325,7 @@ def execRcLines(self): self.rcLines = [] while rcLines: line = rcLines.pop().strip() - if line and line[0] != '#': + if line and line[0] != "#": if self.onecmd(line): # if onecmd returns True, the command wants to exit # from the interaction, save leftover rc lines @@ -317,14 +341,16 @@ def user_call(self, frame, argument_list): if self._wait_for_mainpyfile: return if self.stop_here(frame): - self.message('--Call--') + self.message("--Call--") self.interaction(frame, None) def user_line(self, frame): """This function is called when we stop or break at this line.""" if self._wait_for_mainpyfile: - if (self.mainpyfile != self.canonic(frame.f_code.co_filename) - or frame.f_lineno <= 0): + if ( + self.mainpyfile != self.canonic(frame.f_code.co_filename) + or frame.f_lineno <= 0 + ): return self._wait_for_mainpyfile = False if self.bp_commands(frame): @@ -337,8 +363,7 @@ def bp_commands(self, frame): Returns True if the normal interaction function must be called, False otherwise.""" # self.currentbp is set in bdb in Bdb.break_here if a breakpoint was hit - if getattr(self, "currentbp", False) and \ - self.currentbp in self.commands: + if getattr(self, "currentbp", False) and self.currentbp in self.commands: currentbp = self.currentbp self.currentbp = 0 lastcmd_back = self.lastcmd @@ -358,8 +383,8 @@ def user_return(self, frame, return_value): """This function is called when a return trap is set here.""" if self._wait_for_mainpyfile: return - frame.f_locals['__return__'] = return_value - self.message('--Return--') + frame.f_locals["__return__"] = return_value + self.message("--Return--") self.interaction(frame, None) def user_exception(self, frame, exc_info): @@ -368,17 +393,20 @@ def user_exception(self, frame, exc_info): if self._wait_for_mainpyfile: return exc_type, exc_value, exc_traceback = exc_info - frame.f_locals['__exception__'] = exc_type, exc_value + frame.f_locals["__exception__"] = exc_type, exc_value # An 'Internal StopIteration' exception is an exception debug event # issued by the interpreter when handling a subgenerator run with # 'yield from' or a generator controlled by a for loop. No exception has # actually occurred in this case. The debugger uses this debug event to # stop when the debuggee is returning from such generators. - prefix = 'Internal ' if (not exc_traceback - and exc_type is StopIteration) else '' - self.message('%s%s' % (prefix, - traceback.format_exception_only(exc_type, exc_value)[-1].strip())) + prefix = ( + "Internal " if (not exc_traceback and exc_type is StopIteration) else "" + ) + self.message( + "%s%s" + % (prefix, traceback.format_exception_only(exc_type, exc_value)[-1].strip()) + ) self.interaction(frame, exc_traceback) # General interaction function @@ -392,7 +420,7 @@ def _cmdloop(self): self.allow_kbdint = False break except KeyboardInterrupt: - self.message('--KeyboardInterrupt--') + self.message("--KeyboardInterrupt--") # Called before loop, handles display expressions def preloop(self): @@ -405,8 +433,9 @@ def preloop(self): # fields are changed to be displayed if newvalue is not oldvalue and newvalue != oldvalue: displaying[expr] = newvalue - self.message('display %s: %r [old: %r]' % - (expr, newvalue, oldvalue)) + self.message( + "display %s: %r [old: %r]" % (expr, newvalue, oldvalue) + ) def interaction(self, frame, traceback): # Restore the previous signal handler at the Pdb prompt. @@ -435,11 +464,12 @@ def displayhook(self, obj): self.message(repr(obj)) def default(self, line): - if line[:1] == '!': line = line[1:] + if line[:1] == "!": + line = line[1:] locals = self.curframe_locals globals = self.curframe.f_globals try: - code = compile(line + '\n', '', 'single') + code = compile(line + "\n", "", "single") save_stdout = sys.stdout save_stdin = sys.stdin save_displayhook = sys.displayhook @@ -464,18 +494,17 @@ def precmd(self, line): line = self.aliases[args[0]] ii = 1 for tmpArg in args[1:]: - line = line.replace("%" + str(ii), - tmpArg) + line = line.replace("%" + str(ii), tmpArg) ii += 1 - line = line.replace("%*", ' '.join(args[1:])) + line = line.replace("%*", " ".join(args[1:])) args = line.split() # split into ';;' separated commands # unless it's an alias command - if args[0] != 'alias': - marker = line.find(';;') + if args[0] != "alias": + marker = line.find(";;") if marker >= 0: # queue up everything after marker - next = line[marker+2:].lstrip() + next = line[marker + 2 :].lstrip() self.cmdqueue.append(next) line = line[:marker].rstrip() return line @@ -497,20 +526,20 @@ def handle_command_def(self, line): cmd, arg, line = self.parseline(line) if not cmd: return - if cmd == 'silent': + if cmd == "silent": self.commands_silent[self.commands_bnum] = True - return # continue to handle other cmd def in the cmd list - elif cmd == 'end': + return # continue to handle other cmd def in the cmd list + elif cmd == "end": self.cmdqueue = [] - return 1 # end of cmd list + return 1 # end of cmd list cmdlist = self.commands[self.commands_bnum] if arg: - cmdlist.append(cmd+' '+arg) + cmdlist.append(cmd + " " + arg) else: cmdlist.append(cmd) # Determine if we must stop try: - func = getattr(self, 'do_' + cmd) + func = getattr(self, "do_" + cmd) except AttributeError: func = self.default # one of the resuming commands @@ -526,14 +555,14 @@ def message(self, msg): print(msg, file=self.stdout) def error(self, msg): - print('***', msg, file=self.stdout) + print("***", msg, file=self.stdout) # Generic completion functions. Individual complete_foo methods can be # assigned below to one of these functions. def _complete_location(self, text, line, begidx, endidx): # Complete a file/module/function location for break/tbreak/clear. - if line.strip().endswith((':', ',')): + if line.strip().endswith((":", ",")): # Here comes a line number or a condition which we can't complete. return [] # First, try to find matching functions (i.e. expressions). @@ -542,20 +571,23 @@ def _complete_location(self, text, line, begidx, endidx): except Exception: ret = [] # Then, try to complete file names as well. - globs = glob.glob(glob.escape(text) + '*') + globs = glob.glob(glob.escape(text) + "*") for fn in globs: if os.path.isdir(fn): - ret.append(fn + '/') - elif os.path.isfile(fn) and fn.lower().endswith(('.py', '.pyw')): - ret.append(fn + ':') + ret.append(fn + "/") + elif os.path.isfile(fn) and fn.lower().endswith((".py", ".pyw")): + ret.append(fn + ":") return ret def _complete_bpnumber(self, text, line, begidx, endidx): # Complete a breakpoint number. (This would be more helpful if we could # display additional info along with the completions, such as file/line # of the breakpoint.) - return [str(i) for i, bp in enumerate(bdb.Breakpoint.bpbynumber) - if bp is not None and str(i).startswith(text)] + return [ + str(i) + for i, bp in enumerate(bdb.Breakpoint.bpbynumber) + if bp is not None and str(i).startswith(text) + ] def _complete_expression(self, text, line, begidx, endidx): # Complete an arbitrary expression. @@ -565,18 +597,18 @@ def _complete_expression(self, text, line, begidx, endidx): # complete builtins, and they clutter the namespace quite heavily, so we # leave them out. ns = {**self.curframe.f_globals, **self.curframe_locals} - if '.' in text: + if "." in text: # Walk an attribute chain up to the last part, similar to what # rlcompleter does. This will bail if any of the parts are not # simple attribute access, which is what we want. - dotted = text.split('.') + dotted = text.split(".") try: obj = ns[dotted[0]] for part in dotted[1:-1]: obj = getattr(obj, part) except (KeyError, AttributeError): return [] - prefix = '.'.join(dotted[:-1]) + '.' + prefix = ".".join(dotted[:-1]) + "." return [prefix + n for n in dir(obj) if n.startswith(dotted[-1])] else: # Complete a simple name. @@ -634,15 +666,17 @@ def do_commands(self, arg): try: self.get_bpbynumber(bnum) except ValueError as err: - self.error('cannot set commands: %s' % err) + self.error("cannot set commands: %s" % err) return self.commands_bnum = bnum # Save old definitions for the case of a keyboard interrupt. if bnum in self.commands: - old_command_defs = (self.commands[bnum], - self.commands_doprompt[bnum], - self.commands_silent[bnum]) + old_command_defs = ( + self.commands[bnum], + self.commands_doprompt[bnum], + self.commands_silent[bnum], + ) else: old_command_defs = None self.commands[bnum] = [] @@ -650,7 +684,7 @@ def do_commands(self, arg): self.commands_silent[bnum] = False prompt_back = self.prompt - self.prompt = '(com) ' + self.prompt = "(com) " self.commands_defining = True try: self.cmdloop() @@ -664,14 +698,14 @@ def do_commands(self, arg): del self.commands[bnum] del self.commands_doprompt[bnum] del self.commands_silent[bnum] - self.error('command definition aborted, old commands restored') + self.error("command definition aborted, old commands restored") finally: self.commands_defining = False self.prompt = prompt_back complete_commands = _complete_bpnumber - def do_break(self, arg, temporary = 0): + def do_break(self, arg, temporary=0): """b(reak) [ ([filename:]lineno | function) [, condition] ] Without argument, list all breaks. @@ -698,27 +732,27 @@ def do_break(self, arg, temporary = 0): filename = None lineno = None cond = None - comma = arg.find(',') + comma = arg.find(",") if comma > 0: # parse stuff after comma: "condition" - cond = arg[comma+1:].lstrip() + cond = arg[comma + 1 :].lstrip() arg = arg[:comma].rstrip() # parse stuff before comma: [filename:]lineno | function - colon = arg.rfind(':') + colon = arg.rfind(":") funcname = None if colon >= 0: filename = arg[:colon].rstrip() f = self.lookupmodule(filename) if not f: - self.error('%r not found from sys.path' % filename) + self.error("%r not found from sys.path" % filename) return else: filename = f - arg = arg[colon+1:].lstrip() + arg = arg[colon + 1 :].lstrip() try: lineno = int(arg) except ValueError: - self.error('Bad lineno: %s' % arg) + self.error("Bad lineno: %s" % arg) return else: # no colon; can be lineno or function @@ -726,17 +760,15 @@ def do_break(self, arg, temporary = 0): lineno = int(arg) except ValueError: try: - func = eval(arg, - self.curframe.f_globals, - self.curframe_locals) + func = eval(arg, self.curframe.f_globals, self.curframe_locals) except: func = arg try: - if hasattr(func, '__func__'): + if hasattr(func, "__func__"): func = func.__func__ code = func.__code__ - #use co_name to identify the bkpt (function names - #could be aliased, but co_name is invariant) + # use co_name to identify the bkpt (function names + # could be aliased, but co_name is invariant) funcname = code.co_name lineno = code.co_firstlineno filename = code.co_filename @@ -744,10 +776,12 @@ def do_break(self, arg, temporary = 0): # last thing to try (ok, filename, ln) = self.lineinfo(arg) if not ok: - self.error('The specified object %r is not a function ' - 'or was not found along sys.path.' % arg) + self.error( + "The specified object %r is not a function " + "or was not found along sys.path." % arg + ) return - funcname = ok # ok contains a function name + funcname = ok # ok contains a function name lineno = int(ln) if not filename: filename = self.defaultFile() @@ -760,14 +794,13 @@ def do_break(self, arg, temporary = 0): self.error(err) else: bp = self.get_breaks(filename, line)[-1] - self.message("Breakpoint %d at %s:%d" % - (bp.number, bp.file, bp.line)) + self.message("Breakpoint %d at %s:%d" % (bp.number, bp.file, bp.line)) # To be overridden in derived debuggers def defaultFile(self): """Produce a reasonable default.""" filename = self.curframe.f_code.co_filename - if filename == '' and self.mainpyfile: + if filename == "" and self.mainpyfile: filename = self.mainpyfile return filename @@ -797,10 +830,11 @@ def lineinfo(self, identifier): id = idstring[1].strip() else: return failed - if id == '': return failed - parts = id.split('.') + if id == "": + return failed + parts = id.split(".") # Protection for derived debuggers - if parts[0] == 'self': + if parts[0] == "self": del parts[0] if len(parts) == 0: return failed @@ -826,17 +860,16 @@ def checkline(self, filename, lineno): """ # this method should be callable before starting debugging, so default # to "no globals" if there is no current frame - frame = getattr(self, 'curframe', None) + frame = getattr(self, "curframe", None) globs = frame.f_globals if frame else None line = linecache.getline(filename, lineno, globs) if not line: - self.message('End of file') + self.message("End of file") return 0 line = line.strip() # Don't allow setting breakpoint at a blank line - if (not line or (line[0] == '#') or - (line[:3] == '"""') or line[:3] == "'''"): - self.error('Blank or comment') + if not line or (line[0] == "#") or (line[:3] == '"""') or line[:3] == "'''": + self.error("Blank or comment") return 0 return lineno @@ -853,7 +886,7 @@ def do_enable(self, arg): self.error(err) else: bp.enable() - self.message('Enabled %s' % bp) + self.message("Enabled %s" % bp) complete_enable = _complete_bpnumber @@ -873,7 +906,7 @@ def do_disable(self, arg): self.error(err) else: bp.disable() - self.message('Disabled %s' % bp) + self.message("Disabled %s" % bp) complete_disable = _complete_bpnumber @@ -884,7 +917,7 @@ def do_condition(self, arg): condition is absent, any existing condition is removed; i.e., the breakpoint is made unconditional. """ - args = arg.split(' ', 1) + args = arg.split(" ", 1) try: cond = args[1] except IndexError: @@ -892,15 +925,15 @@ def do_condition(self, arg): try: bp = self.get_bpbynumber(args[0].strip()) except IndexError: - self.error('Breakpoint number expected') + self.error("Breakpoint number expected") except ValueError as err: self.error(err) else: bp.cond = cond if not cond: - self.message('Breakpoint %d is now unconditional.' % bp.number) + self.message("Breakpoint %d is now unconditional." % bp.number) else: - self.message('New condition set for breakpoint %d.' % bp.number) + self.message("New condition set for breakpoint %d." % bp.number) complete_condition = _complete_bpnumber @@ -921,21 +954,23 @@ def do_ignore(self, arg): try: bp = self.get_bpbynumber(args[0].strip()) except IndexError: - self.error('Breakpoint number expected') + self.error("Breakpoint number expected") except ValueError as err: self.error(err) else: bp.ignore = count if count > 0: if count > 1: - countstr = '%d crossings' % count + countstr = "%d crossings" % count else: - countstr = '1 crossing' - self.message('Will ignore next %s of breakpoint %d.' % - (countstr, bp.number)) + countstr = "1 crossing" + self.message( + "Will ignore next %s of breakpoint %d." % (countstr, bp.number) + ) else: - self.message('Will stop next time breakpoint %d is reached.' - % bp.number) + self.message( + "Will stop next time breakpoint %d is reached." % bp.number + ) complete_ignore = _complete_bpnumber @@ -948,21 +983,21 @@ def do_clear(self, arg): """ if not arg: try: - reply = input('Clear all breaks? ') + reply = input("Clear all breaks? ") except EOFError: - reply = 'no' + reply = "no" reply = reply.strip().lower() - if reply in ('y', 'yes'): + if reply in ("y", "yes"): bplist = [bp for bp in bdb.Breakpoint.bpbynumber if bp] self.clear_all_breaks() for bp in bplist: - self.message('Deleted %s' % bp) + self.message("Deleted %s" % bp) return - if ':' in arg: + if ":" in arg: # Make sure it works for "clear C:\foo\bar.py:12" - i = arg.rfind(':') + i = arg.rfind(":") filename = arg[:i] - arg = arg[i+1:] + arg = arg[i + 1 :] try: lineno = int(arg) except ValueError: @@ -974,7 +1009,7 @@ def do_clear(self, arg): self.error(err) else: for bp in bplist: - self.message('Deleted %s' % bp) + self.message("Deleted %s" % bp) return numberlist = arg.split() for i in numberlist: @@ -984,8 +1019,9 @@ def do_clear(self, arg): self.error(err) else: self.clear_bpbynumber(i) - self.message('Deleted %s' % bp) - do_cl = do_clear # 'c' is already an abbreviation for 'continue' + self.message("Deleted %s" % bp) + + do_cl = do_clear # 'c' is already an abbreviation for 'continue' complete_clear = _complete_location complete_cl = _complete_location @@ -997,6 +1033,7 @@ def do_where(self, arg): context of most commands. 'bt' is an alias for this command. """ self.print_stack_trace() + do_w = do_where do_bt = do_where @@ -1014,18 +1051,19 @@ def do_up(self, arg): stack trace (to an older frame). """ if self.curindex == 0: - self.error('Oldest frame') + self.error("Oldest frame") return try: count = int(arg or 1) except ValueError: - self.error('Invalid frame count (%s)' % arg) + self.error("Invalid frame count (%s)" % arg) return if count < 0: newframe = 0 else: newframe = max(0, self.curindex - count) self._select_frame(newframe) + do_u = do_up def do_down(self, arg): @@ -1034,18 +1072,19 @@ def do_down(self, arg): stack trace (to a newer frame). """ if self.curindex + 1 == len(self.stack): - self.error('Newest frame') + self.error("Newest frame") return try: count = int(arg or 1) except ValueError: - self.error('Invalid frame count (%s)' % arg) + self.error("Invalid frame count (%s)" % arg) return if count < 0: newframe = len(self.stack) - 1 else: newframe = min(len(self.stack) - 1, self.curindex + count) self._select_frame(newframe) + do_d = do_down def do_until(self, arg): @@ -1060,16 +1099,16 @@ def do_until(self, arg): try: lineno = int(arg) except ValueError: - self.error('Error in argument: %r' % arg) + self.error("Error in argument: %r" % arg) return if lineno <= self.curframe.f_lineno: - self.error('"until" line number is smaller than current ' - 'line number') + self.error('"until" line number is smaller than current ' "line number") return else: lineno = None self.set_until(self.curframe, lineno) return 1 + do_unt = do_until def do_step(self, arg): @@ -1080,6 +1119,7 @@ def do_step(self, arg): """ self.set_step() return 1 + do_s = do_step def do_next(self, arg): @@ -1089,6 +1129,7 @@ def do_next(self, arg): """ self.set_next(self.curframe) return 1 + do_n = do_next def do_run(self, arg): @@ -1100,11 +1141,12 @@ def do_run(self, arg): """ if arg: import shlex + argv0 = sys.argv[0:1] try: sys.argv = shlex.split(arg) except ValueError as e: - self.error('Cannot run %s: %s' % (arg, e)) + self.error("Cannot run %s: %s" % (arg, e)) return sys.argv[:0] = argv0 # this is caught in the main debugger loop @@ -1118,6 +1160,7 @@ def do_return(self, arg): """ self.set_return(self.curframe) return 1 + do_r = do_return def do_continue(self, arg): @@ -1126,8 +1169,9 @@ def do_continue(self, arg): """ if not self.nosigint: try: - Pdb._previous_sigint_handler = \ - signal.signal(signal.SIGINT, self.sigint_handler) + Pdb._previous_sigint_handler = signal.signal( + signal.SIGINT, self.sigint_handler + ) except ValueError: # ValueError happens when do_continue() is invoked from # a non-main thread in which case we just continue without @@ -1136,6 +1180,7 @@ def do_continue(self, arg): pass self.set_continue() return 1 + do_c = do_cont = do_continue def do_jump(self, arg): @@ -1150,7 +1195,7 @@ def do_jump(self, arg): for loop or out of a finally clause. """ if self.curindex + 1 != len(self.stack): - self.error('You can only jump within the bottom frame') + self.error("You can only jump within the bottom frame") return try: arg = int(arg) @@ -1164,7 +1209,8 @@ def do_jump(self, arg): self.stack[self.curindex] = self.stack[self.curindex][0], arg self.print_stack_entry(self.stack[self.curindex]) except ValueError as e: - self.error('Jump failed: %s' % e) + self.error("Jump failed: %s" % e) + do_j = do_jump def do_debug(self, arg): @@ -1204,7 +1250,7 @@ def do_EOF(self, arg): """EOF Handles the receipt of EOF as a command. """ - self.message('') + self.message("") self._user_requested_quit = True self.set_quit() return 1 @@ -1216,24 +1262,28 @@ def do_args(self, arg): co = self.curframe.f_code dict = self.curframe_locals n = co.co_argcount + co.co_kwonlyargcount - if co.co_flags & inspect.CO_VARARGS: n = n+1 - if co.co_flags & inspect.CO_VARKEYWORDS: n = n+1 + if co.co_flags & inspect.CO_VARARGS: + n = n + 1 + if co.co_flags & inspect.CO_VARKEYWORDS: + n = n + 1 for i in range(n): name = co.co_varnames[i] if name in dict: - self.message('%s = %r' % (name, dict[name])) + self.message("%s = %r" % (name, dict[name])) else: - self.message('%s = *** undefined ***' % (name,)) + self.message("%s = *** undefined ***" % (name,)) + do_a = do_args def do_retval(self, arg): """retval Print the return value for the last return of a function. """ - if '__return__' in self.curframe_locals: - self.message(repr(self.curframe_locals['__return__'])) + if "__return__" in self.curframe_locals: + self.message(repr(self.curframe_locals["__return__"])) else: - self.error('Not yet returned!') + self.error("Not yet returned!") + do_rv = do_retval def _getval(self, arg): @@ -1252,7 +1302,7 @@ def _getval_except(self, arg, frame=None): except: exc_info = sys.exc_info()[:2] err = traceback.format_exception_only(*exc_info)[-1].strip() - return _rstr('** raised %s **' % err) + return _rstr("** raised %s **" % err) def _error_exc(self): exc_info = sys.exc_info()[:2] @@ -1299,12 +1349,12 @@ def do_list(self, arg): exception was originally raised or propagated is indicated by ">>", if it differs from the current line. """ - self.lastcmd = 'list' + self.lastcmd = "list" last = None - if arg and arg != '.': + if arg and arg != ".": try: - if ',' in arg: - first, last = arg.split(',') + if "," in arg: + first, last = arg.split(",") first = int(first.strip()) last = int(last.strip()) if last < first: @@ -1314,9 +1364,9 @@ def do_list(self, arg): first = int(arg.strip()) first = max(1, first - 5) except ValueError: - self.error('Error in argument: %r' % arg) + self.error("Error in argument: %r" % arg) return - elif self.lineno is None or arg == '.': + elif self.lineno is None or arg == ".": first = max(1, self.curframe.f_lineno - 5) else: first = self.lineno + 1 @@ -1332,13 +1382,13 @@ def do_list(self, arg): breaklist = self.get_file_breaks(filename) try: lines = linecache.getlines(filename, self.curframe.f_globals) - self._print_lines(lines[first-1:last], first, breaklist, - self.curframe) + self._print_lines(lines[first - 1 : last], first, breaklist, self.curframe) self.lineno = min(last, len(lines)) if len(lines) < last: - self.message('[EOF]') + self.message("[EOF]") except KeyboardInterrupt: pass + do_l = do_list def do_longlist(self, arg): @@ -1353,6 +1403,7 @@ def do_longlist(self, arg): self.error(err) return self._print_lines(lines, lineno, breaklist, self.curframe) + do_ll = do_longlist def do_source(self, arg): @@ -1382,16 +1433,16 @@ def _print_lines(self, lines, start, breaks=(), frame=None): for lineno, line in enumerate(lines, start): s = str(lineno).rjust(3) if len(s) < 4: - s += ' ' + s += " " if lineno in breaks: - s += 'B' + s += "B" else: - s += ' ' + s += " " if lineno == current_lineno: - s += '->' + s += "->" elif lineno == exc_lineno: - s += '>>' - self.message(s + '\t' + line.rstrip()) + s += ">>" + self.message(s + "\t" + line.rstrip()) def do_whatis(self, arg): """whatis arg @@ -1409,7 +1460,7 @@ def do_whatis(self, arg): except Exception: pass if code: - self.message('Method %s' % code.co_name) + self.message("Method %s" % code.co_name) return # Is it a function? try: @@ -1417,11 +1468,11 @@ def do_whatis(self, arg): except Exception: pass if code: - self.message('Function %s' % code.co_name) + self.message("Function %s" % code.co_name) return # Is it a class? if value.__class__ is type: - self.message('Class %s.%s' % (value.__module__, value.__qualname__)) + self.message("Class %s.%s" % (value.__module__, value.__qualname__)) return # None of the above... self.message(type(value)) @@ -1437,13 +1488,13 @@ def do_display(self, arg): Without expression, list all display expressions for the current frame. """ if not arg: - self.message('Currently displaying:') + self.message("Currently displaying:") for item in self.displaying.get(self.curframe, {}).items(): - self.message('%s: %r' % item) + self.message("%s: %r" % item) else: val = self._getval_except(arg) self.displaying.setdefault(self.curframe, {})[arg] = val - self.message('display %s: %r' % (arg, val)) + self.message("display %s: %r" % (arg, val)) complete_display = _complete_expression @@ -1458,13 +1509,12 @@ def do_undisplay(self, arg): try: del self.displaying.get(self.curframe, {})[arg] except KeyError: - self.error('not displaying %s' % arg) + self.error("not displaying %s" % arg) else: self.displaying.pop(self.curframe, None) def complete_undisplay(self, text, line, begidx, endidx): - return [e for e in self.displaying.get(self.curframe, {}) - if e.startswith(text)] + return [e for e in self.displaying.get(self.curframe, {}) if e.startswith(text)] def do_interact(self, arg): """interact @@ -1508,14 +1558,15 @@ def do_alias(self, arg): if args[0] in self.aliases and len(args) == 1: self.message("%s = %s" % (args[0], self.aliases[args[0]])) else: - self.aliases[args[0]] = ' '.join(args[1:]) + self.aliases[args[0]] = " ".join(args[1:]) def do_unalias(self, arg): """unalias name Delete the specified alias. """ args = arg.split() - if len(args) == 0: return + if len(args) == 0: + return if args[0] in self.aliases: del self.aliases[args[0]] @@ -1523,8 +1574,14 @@ def complete_unalias(self, text, line, begidx, endidx): return [a for a in self.aliases if a.startswith(text)] # List of all the commands making the program resume execution. - commands_resuming = ['do_continue', 'do_step', 'do_next', 'do_return', - 'do_quit', 'do_jump'] + commands_resuming = [ + "do_continue", + "do_step", + "do_next", + "do_return", + "do_quit", + "do_jump", + ] # Print a traceback starting at the top stack frame. # The most recently entered frame is printed last; @@ -1544,11 +1601,10 @@ def print_stack_trace(self): def print_stack_entry(self, frame_lineno, prompt_prefix=line_prefix): frame, lineno = frame_lineno if frame is self.curframe: - prefix = '> ' + prefix = "> " else: - prefix = ' ' - self.message(prefix + - self.format_stack_entry(frame_lineno, prompt_prefix)) + prefix = " " + self.message(prefix + self.format_stack_entry(frame_lineno, prompt_prefix)) # Provide help @@ -1563,19 +1619,21 @@ def do_help(self, arg): return cmd.Cmd.do_help(self, arg) try: try: - topic = getattr(self, 'help_' + arg) + topic = getattr(self, "help_" + arg) return topic() except AttributeError: - command = getattr(self, 'do_' + arg) + command = getattr(self, "do_" + arg) except AttributeError: - self.error('No help for %r' % arg) + self.error("No help for %r" % arg) else: if sys.flags.optimize >= 2: - self.error('No help for %r; please do not run Python with -OO ' - 'if you need command help' % arg) + self.error( + "No help for %r; please do not run Python with -OO " + "if you need command help" % arg + ) return if command.__doc__ is None: - self.error('No help for %r; __doc__ string missing' % arg) + self.error("No help for %r; __doc__ string missing" % arg) return self.message(command.__doc__.rstrip()) @@ -1591,7 +1649,7 @@ def help_exec(self): (Pdb) global list_options; list_options = ['-l'] (Pdb) """ - self.message((self.help_exec.__doc__ or '').strip()) + self.message((self.help_exec.__doc__ or "").strip()) def help_pdb(self): help() @@ -1604,14 +1662,14 @@ def lookupmodule(self, filename): lookupmodule() translates (possibly incomplete) file or module name into an absolute file name. """ - if os.path.isabs(filename) and os.path.exists(filename): + if os.path.isabs(filename) and os.path.exists(filename): return filename f = os.path.join(sys.path[0], filename) - if os.path.exists(f) and self.canonic(f) == self.mainpyfile: + if os.path.exists(f) and self.canonic(f) == self.mainpyfile: return f root, ext = os.path.splitext(filename) - if ext == '': - filename = filename + '.py' + if ext == "": + filename = filename + ".py" if os.path.isabs(filename): return filename for dirname in sys.path: @@ -1637,6 +1695,7 @@ def _run(self, target: Union[_ModuleTarget, _ScriptTarget]): # __main__ will break). Clear __main__ and replace with # the target namespace. import __main__ + __main__.__dict__.clear() __main__.__dict__.update(target.namespace) @@ -1648,15 +1707,44 @@ def _run(self, target: Union[_ModuleTarget, _ScriptTarget]): if __doc__ is not None: # unfortunately we can't guess this order from the class definition _help_order = [ - 'help', 'where', 'down', 'up', 'break', 'tbreak', 'clear', 'disable', - 'enable', 'ignore', 'condition', 'commands', 'step', 'next', 'until', - 'jump', 'return', 'retval', 'run', 'continue', 'list', 'longlist', - 'args', 'p', 'pp', 'whatis', 'source', 'display', 'undisplay', - 'interact', 'alias', 'unalias', 'debug', 'quit', + "help", + "where", + "down", + "up", + "break", + "tbreak", + "clear", + "disable", + "enable", + "ignore", + "condition", + "commands", + "step", + "next", + "until", + "jump", + "return", + "retval", + "run", + "continue", + "list", + "longlist", + "args", + "p", + "pp", + "whatis", + "source", + "display", + "undisplay", + "interact", + "alias", + "unalias", + "debug", + "quit", ] for _command in _help_order: - __doc__ += getattr(Pdb, 'do_' + _command).__doc__.strip() + '\n\n' + __doc__ += getattr(Pdb, "do_" + _command).__doc__.strip() + "\n\n" __doc__ += Pdb.help_exec.__doc__ del _help_order, _command @@ -1664,27 +1752,34 @@ def _run(self, target: Union[_ModuleTarget, _ScriptTarget]): # Simplified interface + def run(statement, globals=None, locals=None): Pdb().run(statement, globals, locals) + def runeval(expression, globals=None, locals=None): return Pdb().runeval(expression, globals, locals) + def runctx(statement, globals, locals): # B/W compatibility run(statement, globals, locals) + def runcall(*args, **kwds): return Pdb().runcall(*args, **kwds) + def set_trace(*, header=None): pdb = Pdb() if header is not None: pdb.message(header) pdb.set_trace(sys._getframe().f_back) + # Post-Mortem interface + def post_mortem(t=None): # handling the default if t is None: @@ -1692,29 +1787,35 @@ def post_mortem(t=None): # being handled, otherwise it returns None t = sys.exc_info()[2] if t is None: - raise ValueError("A valid traceback must be passed if no " - "exception is being handled") + raise ValueError( + "A valid traceback must be passed if no " "exception is being handled" + ) p = Pdb() p.reset() p.interaction(None, t) + def pm(): post_mortem(sys.last_traceback) # Main program for testing -TESTCMD = 'import x; x.main()' +TESTCMD = "import x; x.main()" + def test(): run(TESTCMD) + # print help def help(): import pydoc + pydoc.pager(__doc__) + _usage = """\ usage: pdb.py [-c command] ... [-m module | pyfile] [arg] ... @@ -1734,25 +1835,25 @@ def help(): def main(): import getopt - opts, args = getopt.getopt(sys.argv[1:], 'mhc:', ['help', 'command=']) + opts, args = getopt.getopt(sys.argv[1:], "mhc:", ["help", "command="]) if not args: print(_usage) sys.exit(2) - if any(opt in ['-h', '--help'] for opt, optarg in opts): + if any(opt in ["-h", "--help"] for opt, optarg in opts): print(_usage) sys.exit() - commands = [optarg for opt, optarg in opts if opt in ['-c', '--command']] + commands = [optarg for opt, optarg in opts if opt in ["-c", "--command"]] - module_indicated = any(opt in ['-m'] for opt, optarg in opts) + module_indicated = any(opt in ["-m"] for opt, optarg in opts) cls = _ModuleTarget if module_indicated else _ScriptTarget target = cls(args[0]) target.check() - sys.argv[:] = args # Hide "pdb.py" and pdb options from argument list + sys.argv[:] = args # Hide "pdb.py" and pdb options from argument list # Note on saving/restoring sys.argv: it's a good idea when sys.argv was # modified by the script being debugged. It's a bad idea when it was @@ -1771,7 +1872,7 @@ def main(): print("\t" + " ".join(sys.argv[1:])) except SystemExit: # In most cases SystemExit does not warrant a post-mortem session. - print("The program exited via sys.exit(). Exit status:", end=' ') + print("The program exited via sys.exit(). Exit status:", end=" ") print(sys.exc_info()[1]) except SyntaxError: traceback.print_exc() @@ -1782,11 +1883,11 @@ def main(): print("Running 'cont' or 'step' will restart the program") t = sys.exc_info()[2] pdb.interaction(None, t) - print("Post mortem debugger finished. The " + target + - " will be restarted") + print("Post mortem debugger finished. The " + target + " will be restarted") # When invoked as main program, invoke the debugger on a script -if __name__ == '__main__': +if __name__ == "__main__": import pdb + pdb.main() diff --git a/src/chatdbg/chatdbg_utils.py b/src/chatdbg/chatdbg_utils.py index 79fcc98..4a0aa0d 100644 --- a/src/chatdbg/chatdbg_utils.py +++ b/src/chatdbg/chatdbg_utils.py @@ -4,18 +4,21 @@ import sys import textwrap + def get_model() -> str: - all_models = ['gpt-4', 'gpt-3.5-turbo'] - - if not 'OPENAI_API_MODEL' in os.environ: - model = 'gpt-4' + all_models = ["gpt-4", "gpt-3.5-turbo"] + + if not "OPENAI_API_MODEL" in os.environ: + model = "gpt-4" else: - model = os.environ['OPENAI_API_MODEL'] + model = os.environ["OPENAI_API_MODEL"] if model not in all_models: - print(f'The environment variable OPENAI_API_MODEL is currently set to "{model}".') - print(f'The only valid values are {all_models}.') + print( + f'The environment variable OPENAI_API_MODEL is currently set to "{model}".' + ) + print(f"The only valid values are {all_models}.") return "" - + return model @@ -65,6 +68,7 @@ def word_wrap_except_code_blocks(text: str) -> str: wrapped_text = "\n\n".join(wrapped_paragraphs) return wrapped_text + def word_wrap_except_code_blocks_previous(text: str) -> str: """Wraps text except for code blocks. @@ -79,23 +83,23 @@ def word_wrap_except_code_blocks_previous(text: str) -> str: The wrapped text. """ # Split text into paragraphs - paragraphs = text.split('\n\n') + paragraphs = text.split("\n\n") wrapped_paragraphs = [] # Check if currently in a code block. in_code_block = False # Loop through each paragraph and apply appropriate wrapping. for paragraph in paragraphs: # If this paragraph starts and ends with a code block, add it as is. - if paragraph.startswith('```') and paragraph.endswith('```'): + if paragraph.startswith("```") and paragraph.endswith("```"): wrapped_paragraphs.append(paragraph) continue # If this is the beginning of a code block add it as is. - if paragraph.startswith('```'): + if paragraph.startswith("```"): in_code_block = True wrapped_paragraphs.append(paragraph) continue # If this is the end of a code block stop skipping text. - if paragraph.endswith('```'): + if paragraph.endswith("```"): in_code_block = False wrapped_paragraphs.append(paragraph) continue @@ -107,12 +111,14 @@ def word_wrap_except_code_blocks_previous(text: str) -> str: wrapped_paragraph = textwrap.fill(paragraph) wrapped_paragraphs.append(wrapped_paragraph) # Join all paragraphs into a single string - wrapped_text = '\n\n'.join(wrapped_paragraphs) + wrapped_text = "\n\n".join(wrapped_paragraphs) return wrapped_text + def read_lines_width() -> int: return 10 + def read_lines(file_path: str, start_line: int, end_line: int) -> str: """ Read lines from a file and return a string containing the lines between start_line and end_line. @@ -127,7 +133,7 @@ def read_lines(file_path: str, start_line: int, end_line: int) -> str: """ # open the file for reading - with open(file_path, 'r') as f: + with open(file_path, "r") as f: # read all the lines from the file lines = f.readlines() # remove trailing newline characters @@ -139,44 +145,59 @@ def read_lines(file_path: str, start_line: int, end_line: int) -> str: # ensure end_line is within range end_line = min(len(lines), end_line) # return the requested lines as a string - return '\n'.join(lines[start_line:end_line]) + return "\n".join(lines[start_line:end_line]) + -async def explain(source_code: str, traceback: str, exception: str, really_run = True) -> None: +async def explain( + source_code: str, traceback: str, exception: str, really_run=True +) -> None: import httpx + user_prompt = "Explain what the root cause of this error is, given the following source code context for each stack frame and a traceback, and propose a fix. In your response, never refer to the frames given below (as in, 'frame 0'). Instead, always refer only to specific lines and filenames of source code.\n" - user_prompt += '\n' - user_prompt += 'Source code for each stack frame:\n```\n' - user_prompt += source_code + '\n```\n' - user_prompt += traceback + '\n\n' - user_prompt += 'stop reason = ' + exception + '\n' - text = '' - + user_prompt += "\n" + user_prompt += "Source code for each stack frame:\n```\n" + user_prompt += source_code + "\n```\n" + user_prompt += traceback + "\n\n" + user_prompt += "stop reason = " + exception + "\n" + text = "" + if not really_run: print(user_prompt) return - if not 'OPENAI_API_KEY' in os.environ: - print('You need a valid OpenAI key to use ChatDBG. You can get a key here: https://openai.com/api/') - print('Set the environment variable OPENAI_API_KEY to your key value.') + if not "OPENAI_API_KEY" in os.environ: + print( + "You need a valid OpenAI key to use ChatDBG. You can get a key here: https://openai.com/api/" + ) + print("Set the environment variable OPENAI_API_KEY to your key value.") return model = get_model() if not model: return - + try: - completion = await openai_async.chat_complete(openai.api_key, timeout=30, payload={'model': f'{model}', 'messages': [{'role': 'user', 'content': user_prompt}]}) + completion = await openai_async.chat_complete( + openai.api_key, + timeout=30, + payload={ + "model": f"{model}", + "messages": [{"role": "user", "content": user_prompt}], + }, + ) json_payload = completion.json() - text = json_payload['choices'][0]['message']['content'] + text = json_payload["choices"][0]["message"]["content"] except (openai.error.AuthenticationError, httpx.LocalProtocolError, KeyError): # Something went wrong. print() - print('You need a valid OpenAI key to use ChatDBG. You can get a key here: https://openai.com/api/') - print('Set the environment variable OPENAI_API_KEY to your key value.') + print( + "You need a valid OpenAI key to use ChatDBG. You can get a key here: https://openai.com/api/" + ) + print("Set the environment variable OPENAI_API_KEY to your key value.") import sys + sys.exit(1) except Exception as e: - print(f'EXCEPTION {e}, {type(e)}') + print(f"EXCEPTION {e}, {type(e)}") pass print(word_wrap_except_code_blocks(text)) - diff --git a/src/chatdbg/chatdbg_why.py b/src/chatdbg/chatdbg_why.py index 02904e2..3aa371c 100644 --- a/src/chatdbg/chatdbg_why.py +++ b/src/chatdbg/chatdbg_why.py @@ -6,6 +6,7 @@ import chatdbg_utils + async def why(self, arg): user_prompt = "Explain what the root cause of this error is, given the following source code and traceback, and generate code that fixes the error." user_prompt += "\n" @@ -52,7 +53,7 @@ async def why(self, arg): + " " * (positions.col_offset - leading_spaces) + "^" * (positions.end_col_offset - positions.col_offset) + "\n" - ) + ) if index >= lineno: break except: @@ -63,15 +64,15 @@ async def why(self, arg): user_prompt += f"```\n{stack_trace}```\n" user_prompt += f"Exception: {exception_name} ({exception_value})\n" - #print(user_prompt) - #return - + # print(user_prompt) + # return + import httpx model = chatdbg_utils.get_model() if not model: return - + text = "" try: completion = await openai_async.chat_complete( @@ -83,7 +84,7 @@ async def why(self, arg): }, ) json_payload = completion.json() - if not 'choices' in json_payload: + if not "choices" in json_payload: raise openai.error.AuthenticationError text = json_payload["choices"][0]["message"]["content"] except (openai.error.AuthenticationError, httpx.LocalProtocolError):