diff --git a/assets/screenshot.png b/assets/screenshot.png index cd12313..d8c3a91 100644 Binary files a/assets/screenshot.png and b/assets/screenshot.png differ diff --git a/pytest_rich.py b/pytest_rich.py index a84e687..092157f 100644 --- a/pytest_rich.py +++ b/pytest_rich.py @@ -13,17 +13,37 @@ import attr import pytest +from _pytest._code.code import ExceptionChainRepr from _pytest._code.code import ExceptionRepr +from _pytest._code.code import ReprFuncArgs +from pygments.token import Comment +from pygments.token import Keyword +from pygments.token import Name +from pygments.token import Number +from pygments.token import Operator +from pygments.token import String +from pygments.token import Text as TextToken +from pygments.token import Token +from rich._loop import loop_last from rich.columns import Columns from rich.console import Console +from rich.console import ConsoleOptions +from rich.console import ConsoleRenderable from rich.console import Group +from rich.console import group +from rich.console import RenderResult +from rich.highlighter import ReprHighlighter from rich.live import Live -from rich.markdown import Markdown from rich.panel import Panel from rich.progress import Progress from rich.progress import SpinnerColumn from rich.progress import TaskID from rich.rule import Rule +from rich.style import Style +from rich.syntax import Syntax +from rich.text import Text +from rich.theme import Theme +from rich.traceback import PathHighlighter if sys.version_info < (3, 8): from typing_extensions import Literal @@ -215,10 +235,11 @@ def pytest_sessionfinish( self.runtest_progress.stop() self.runtest_progress = None self.runtest_tasks_per_file.clear() - for nodeid, report in self.failed_reports.items(): - m = Markdown(f"```python-traceback\n{report.longrepr}\n```") - self.console.print(Rule(f"[magenta]{nodeid}[/magenta]", style="red")) - self.console.print(m) + if self.failed_reports: + self.console.print(Rule("FAILURES", style="red")) + for nodeid, report in self.failed_reports.items(): + tb = RichExceptionChainRepr(nodeid, report.longrepr) + self.console.print(tb) def pytest_keyboard_interrupt( self, excinfo: pytest.ExceptionInfo[BaseException] @@ -227,3 +248,214 @@ def pytest_keyboard_interrupt( def pytest_unconfigure(self) -> None: ... + + +@attr.s(auto_attribs=True) +class RichExceptionChainRepr: + """ + A rich representation of an ExceptionChainRepr produced by pytest. + + This is needed because pytest does not provide the actual traceback + object, which Rich's `Traceback` class requires. + """ + + nodeid: str + chain: ExceptionChainRepr + extra_lines: int = 3 + theme: Optional[str] = "ansi_dark" + word_wrap: bool = True + indent_guides: bool = True + + def __attrs_post_init__(self): + self.theme = Syntax.get_theme(self.theme) + + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + theme = self.theme + background_style = theme.get_background_style() + token_style = theme.get_style_for_token + + traceback_theme = Theme( + { + "pretty": token_style(TextToken), + "pygments.text": token_style(Token), + "pygments.string": token_style(String), + "pygments.function": token_style(Name.Function), + "pygments.number": token_style(Number), + "repr.indent": token_style(Comment) + Style(dim=True), + "repr.str": token_style(String), + "repr.brace": token_style(TextToken) + Style(bold=True), + "repr.number": token_style(Number), + "repr.bool_true": token_style(Keyword.Constant), + "repr.bool_false": token_style(Keyword.Constant), + "repr.none": token_style(Keyword.Constant), + "scope.border": token_style(String.Delimiter), + "scope.equals": token_style(Operator), + "scope.key": token_style(Name), + "scope.key.special": token_style(Name.Constant) + Style(dim=True), + }, + inherit=False, + ) + + stack_renderable: ConsoleRenderable = Panel( + self._render_chain(self.chain, options), + title=f"[magenta]{self.nodeid}[/magenta]", + style=background_style, + border_style="traceback.border", + expand=True, + padding=(0, 1), + ) + with console.use_theme(traceback_theme): + yield stack_renderable + + path_highlighter = PathHighlighter() + for entry in self.chain.reprtraceback.reprentries: + if entry.reprfileloc.message: + yield Text.assemble( + path_highlighter( + Text(entry.reprfileloc.path, style="pygments.string") + ), + (":", "pygments.text"), + (str(entry.reprfileloc.lineno), "pygments.number"), + (": ", "pygments.text"), + Text(entry.reprfileloc.message, style="pygments.string"), + style="pygments.text", + ) + yield "" + + @group() + def _render_chain( + self, chain: ExceptionChainRepr, options: ConsoleOptions + ) -> RenderResult: + path_highlighter = PathHighlighter() + repr_highlighter = ReprHighlighter() + theme = self.theme + code_cache: Dict[str, str] = {} + + def read_code(filename: str) -> str: + """ + Read files and cache results on filename. + + Args: + filename (str): Filename to read + + Returns: + str: Contents of file + """ + code = code_cache.get(filename) + if not code: + with open( + filename, "rt", encoding="utf-8", errors="replace" + ) as code_file: + code = code_file.read() + code_cache[filename] = code + return code + + def guess_funcname(lineno: int, filename: str) -> str: + """ + Get the nearest function name + + Args: + lineno (int): Line number to start searching from + filename (str): Filename to read + + Returns: + str: Function name + """ + code = read_code(filename) + lines = code.splitlines() + while True: + line = lines[lineno - 1] + if line.startswith("def "): + return line.split("def ")[1].split("(")[0] + lineno -= 1 + if lineno == 0: + return "???" + + def get_args(reprfuncargs: ReprFuncArgs) -> str: + args = Text("") + for arg in reprfuncargs.args: + args.append( + Text.assemble( + (arg[0], "name.variable"), + (" = ", "repr.equals"), + (arg[1], "token"), + ) + ) + if reprfuncargs.args[-1] != arg: + args.append(Text(", ")) + return args + + def get_error_source(lines: List[str]) -> str: + for line in lines: + if line.startswith(">"): + return line.split(">")[1].strip() + + def get_err_msgs(lines: List[str]) -> str: + err_lines = [] + for line in lines: + if line.startswith("E"): + err_lines.append(line[1:].strip()) + return err_lines + + for last, entry in loop_last(chain.reprtraceback.reprentries): + filename = entry.reprfileloc.path + lineno = entry.reprfileloc.lineno + funcname = guess_funcname(lineno, filename) + message = entry.reprfileloc.message + + text = Text.assemble( + path_highlighter(Text(filename, style="pygments.string")), + (":", "pygments.text"), + (str(lineno), "pygments.number"), + " in ", + (funcname, "pygments.function"), + style="pygments.text", + ) + yield text + + args = get_args(entry.reprfuncargs) + if args: + yield args + + code = read_code(filename) + syntax = Syntax( + code, + "python", + theme=theme, + line_numbers=True, + line_range=( + lineno - self.extra_lines, + lineno + self.extra_lines, + ), + highlight_lines={lineno}, + word_wrap=self.word_wrap, + code_width=88, + indent_guides=self.indent_guides, + dedent=False, + ) + yield "" + yield syntax + + if message: + line_pointer = "> " if options.legacy_windows else "❱ " + yield "" + yield Text.assemble( + (str(lineno), "pygments.number"), + ": ", + (message, "traceback.exc_type"), + ) + yield Text.assemble( + (line_pointer, Style(color="red")), + repr_highlighter(get_error_source(entry.lines)), + ) + for err_msg in get_err_msgs(entry.lines): + yield Text.assemble( + ("E ", Style(color="red")), + repr_highlighter(err_msg), + ) + + if not last: + yield "" + yield Rule(style=Style(color="red", dim=True))