Skip to content

Commit

Permalink
fix: Stop exiting interpreter on error
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Jan 25, 2024
1 parent 8221385 commit 4765969
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 27 deletions.
70 changes: 47 additions & 23 deletions guppylang/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
import functools
import sys
import textwrap
from collections.abc import Callable, Sequence
from collections.abc import Callable, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field
from types import TracebackType
from typing import Any, TypeVar, cast

from guppylang.ast_util import AstNode, get_file, get_line_offset, get_source
from guppylang.gtypes import BoundTypeVar, ExistentialTypeVar, FunctionType, GuppyType
from guppylang.hugr.hugr import Node, OutPortV

# Whether the interpreter should exit when a Guppy error occurs
EXIT_ON_ERROR: bool = True


@dataclass(frozen=True)
class SourceLoc:
Expand Down Expand Up @@ -136,6 +135,18 @@ def unsolved_vars(self) -> set[ExistentialTypeVar]:
return set()


ExceptHook = Callable[[type[BaseException], BaseException, TracebackType | None], Any]


@contextmanager
def exception_hook(hook: ExceptHook) -> Iterator[None]:
"""Sets a custom `excepthook` for the scope of a 'with' block."""
old_hook = sys.excepthook
sys.excepthook = hook
yield
sys.excepthook = old_hook


def format_source_location(
loc: ast.AST,
num_lines: int = 3,
Expand Down Expand Up @@ -169,27 +180,40 @@ def format_source_location(
def pretty_errors(f: FuncT) -> FuncT:
"""Decorator to print custom error banners when a `GuppyError` occurs."""

def hook(
excty: type[BaseException], err: BaseException, traceback: TracebackType | None
) -> None:
"""Custom `excepthook` that intercepts `GuppyExceptions` for pretty printing."""
# Fall back to default hook if it's not a GuppyException or we're missing an
# error location
if not isinstance(err, GuppyError) or err.location is None:
sys.__excepthook__(excty, err, traceback)
return

loc = err.location
file, line_offset = get_file(loc), get_line_offset(loc)
assert file is not None
assert line_offset is not None
line = line_offset + loc.lineno - 1
sys.stderr.write(
f"Guppy compilation failed. Error in file {file}:{line}\n\n"
f"{format_source_location(loc)}\n"
f"{err.__class__.__name__}: {err.get_msg()}\n",
)

@functools.wraps(f)
def pretty_errors_wrapped(*args: Any, **kwargs: Any) -> Any:
try:
return f(*args, **kwargs)
except GuppyError as err:
# Reraise if we're missing a location
if not err.location:
with exception_hook(hook):
try:
return f(*args, **kwargs)
except GuppyError as err:
# For normal usage, this `try` block is not necessary since the
# excepthook is automatically invoked when the exception (which is being
# reraised below) is not handled. However, when running tests, we have
# to manually invoke the hook to print the error message, since the
# tests always have to capture exceptions.
if "pytest" in sys.modules:
hook(type(err), err, err.__traceback__)
raise
loc = err.location
file, line_offset = get_file(loc), get_line_offset(loc)
assert file is not None
assert line_offset is not None
line = line_offset + loc.lineno - 1
print( # noqa: T201
f"Guppy compilation failed. Error in file {file}:{line}\n\n"
f"{format_source_location(loc)}\n"
f"{err.__class__.__name__}: {err.get_msg()}",
file=sys.stderr,
)
if EXIT_ON_ERROR:
sys.exit(1)
return None

return cast(FuncT, pretty_errors_wrapped)
6 changes: 2 additions & 4 deletions tests/error/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import pathlib
import pytest

from typing import Any
from collections.abc import Callable

from guppylang.error import GuppyError
from guppylang.hugr import tys
from guppylang.hugr.tys import TypeBound
from guppylang.module import GuppyModule
Expand All @@ -17,7 +15,7 @@ def run_error_test(file, capsys):
spec = importlib.util.spec_from_file_location("test_module", file)
py_module = importlib.util.module_from_spec(spec)

with pytest.raises(SystemExit):
with pytest.raises(GuppyError):
spec.loader.exec_module(py_module)

err = capsys.readouterr().err
Expand Down

0 comments on commit 4765969

Please sign in to comment.