Skip to content

Commit

Permalink
Add RichExceptionChainRepr for richer exception highlighting (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuadavidthomas authored Feb 23, 2022
1 parent 6aeb9b1 commit 810befd
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 5 deletions.
Binary file modified assets/screenshot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
242 changes: 237 additions & 5 deletions pytest_rich.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,37 @@

import attr
import pytest
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionRepr
from _pytest._code.code import ReprFuncArgs
from pygments.token import Comment
from pygments.token import Keyword
from pygments.token import Name
from pygments.token import Number
from pygments.token import Operator
from pygments.token import String
from pygments.token import Text as TextToken
from pygments.token import Token
from rich._loop import loop_last
from rich.columns import Columns
from rich.console import Console
from rich.console import ConsoleOptions
from rich.console import ConsoleRenderable
from rich.console import Group
from rich.console import group
from rich.console import RenderResult
from rich.highlighter import ReprHighlighter
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.progress import Progress
from rich.progress import SpinnerColumn
from rich.progress import TaskID
from rich.rule import Rule
from rich.style import Style
from rich.syntax import Syntax
from rich.text import Text
from rich.theme import Theme
from rich.traceback import PathHighlighter

if sys.version_info < (3, 8):
from typing_extensions import Literal
Expand Down Expand Up @@ -215,10 +235,11 @@ def pytest_sessionfinish(
self.runtest_progress.stop()
self.runtest_progress = None
self.runtest_tasks_per_file.clear()
for nodeid, report in self.failed_reports.items():
m = Markdown(f"```python-traceback\n{report.longrepr}\n```")
self.console.print(Rule(f"[magenta]{nodeid}[/magenta]", style="red"))
self.console.print(m)
if self.failed_reports:
self.console.print(Rule("FAILURES", style="red"))
for nodeid, report in self.failed_reports.items():
tb = RichExceptionChainRepr(nodeid, report.longrepr)
self.console.print(tb)

def pytest_keyboard_interrupt(
self, excinfo: pytest.ExceptionInfo[BaseException]
Expand All @@ -227,3 +248,214 @@ def pytest_keyboard_interrupt(

def pytest_unconfigure(self) -> None:
...


@attr.s(auto_attribs=True)
class RichExceptionChainRepr:
"""
A rich representation of an ExceptionChainRepr produced by pytest.
This is needed because pytest does not provide the actual traceback
object, which Rich's `Traceback` class requires.
"""

nodeid: str
chain: ExceptionChainRepr
extra_lines: int = 3
theme: Optional[str] = "ansi_dark"
word_wrap: bool = True
indent_guides: bool = True

def __attrs_post_init__(self):
self.theme = Syntax.get_theme(self.theme)

def __rich_console__(
self, console: Console, options: ConsoleOptions
) -> RenderResult:
theme = self.theme
background_style = theme.get_background_style()
token_style = theme.get_style_for_token

traceback_theme = Theme(
{
"pretty": token_style(TextToken),
"pygments.text": token_style(Token),
"pygments.string": token_style(String),
"pygments.function": token_style(Name.Function),
"pygments.number": token_style(Number),
"repr.indent": token_style(Comment) + Style(dim=True),
"repr.str": token_style(String),
"repr.brace": token_style(TextToken) + Style(bold=True),
"repr.number": token_style(Number),
"repr.bool_true": token_style(Keyword.Constant),
"repr.bool_false": token_style(Keyword.Constant),
"repr.none": token_style(Keyword.Constant),
"scope.border": token_style(String.Delimiter),
"scope.equals": token_style(Operator),
"scope.key": token_style(Name),
"scope.key.special": token_style(Name.Constant) + Style(dim=True),
},
inherit=False,
)

stack_renderable: ConsoleRenderable = Panel(
self._render_chain(self.chain, options),
title=f"[magenta]{self.nodeid}[/magenta]",
style=background_style,
border_style="traceback.border",
expand=True,
padding=(0, 1),
)
with console.use_theme(traceback_theme):
yield stack_renderable

path_highlighter = PathHighlighter()
for entry in self.chain.reprtraceback.reprentries:
if entry.reprfileloc.message:
yield Text.assemble(
path_highlighter(
Text(entry.reprfileloc.path, style="pygments.string")
),
(":", "pygments.text"),
(str(entry.reprfileloc.lineno), "pygments.number"),
(": ", "pygments.text"),
Text(entry.reprfileloc.message, style="pygments.string"),
style="pygments.text",
)
yield ""

@group()
def _render_chain(
self, chain: ExceptionChainRepr, options: ConsoleOptions
) -> RenderResult:
path_highlighter = PathHighlighter()
repr_highlighter = ReprHighlighter()
theme = self.theme
code_cache: Dict[str, str] = {}

def read_code(filename: str) -> str:
"""
Read files and cache results on filename.
Args:
filename (str): Filename to read
Returns:
str: Contents of file
"""
code = code_cache.get(filename)
if not code:
with open(
filename, "rt", encoding="utf-8", errors="replace"
) as code_file:
code = code_file.read()
code_cache[filename] = code
return code

def guess_funcname(lineno: int, filename: str) -> str:
"""
Get the nearest function name
Args:
lineno (int): Line number to start searching from
filename (str): Filename to read
Returns:
str: Function name
"""
code = read_code(filename)
lines = code.splitlines()
while True:
line = lines[lineno - 1]
if line.startswith("def "):
return line.split("def ")[1].split("(")[0]
lineno -= 1
if lineno == 0:
return "???"

def get_args(reprfuncargs: ReprFuncArgs) -> str:
args = Text("")
for arg in reprfuncargs.args:
args.append(
Text.assemble(
(arg[0], "name.variable"),
(" = ", "repr.equals"),
(arg[1], "token"),
)
)
if reprfuncargs.args[-1] != arg:
args.append(Text(", "))
return args

def get_error_source(lines: List[str]) -> str:
for line in lines:
if line.startswith(">"):
return line.split(">")[1].strip()

def get_err_msgs(lines: List[str]) -> str:
err_lines = []
for line in lines:
if line.startswith("E"):
err_lines.append(line[1:].strip())
return err_lines

for last, entry in loop_last(chain.reprtraceback.reprentries):
filename = entry.reprfileloc.path
lineno = entry.reprfileloc.lineno
funcname = guess_funcname(lineno, filename)
message = entry.reprfileloc.message

text = Text.assemble(
path_highlighter(Text(filename, style="pygments.string")),
(":", "pygments.text"),
(str(lineno), "pygments.number"),
" in ",
(funcname, "pygments.function"),
style="pygments.text",
)
yield text

args = get_args(entry.reprfuncargs)
if args:
yield args

code = read_code(filename)
syntax = Syntax(
code,
"python",
theme=theme,
line_numbers=True,
line_range=(
lineno - self.extra_lines,
lineno + self.extra_lines,
),
highlight_lines={lineno},
word_wrap=self.word_wrap,
code_width=88,
indent_guides=self.indent_guides,
dedent=False,
)
yield ""
yield syntax

if message:
line_pointer = "> " if options.legacy_windows else "❱ "
yield ""
yield Text.assemble(
(str(lineno), "pygments.number"),
": ",
(message, "traceback.exc_type"),
)
yield Text.assemble(
(line_pointer, Style(color="red")),
repr_highlighter(get_error_source(entry.lines)),
)
for err_msg in get_err_msgs(entry.lines):
yield Text.assemble(
("E ", Style(color="red")),
repr_highlighter(err_msg),
)

if not last:
yield ""
yield Rule(style=Style(color="red", dim=True))

0 comments on commit 810befd

Please sign in to comment.