From 4765969af7406163b236360a4258967230db8d06 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 25 Jan 2024 15:05:01 +0000 Subject: [PATCH] fix: Stop exiting interpreter on error --- guppylang/error.py | 70 ++++++++++++++++++++++++++++++--------------- tests/error/util.py | 6 ++-- 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/guppylang/error.py b/guppylang/error.py index 8e8ae4bc..5f7c4e7c 100644 --- a/guppylang/error.py +++ b/guppylang/error.py @@ -2,17 +2,16 @@ import functools import sys import textwrap -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterator, Sequence +from contextlib import contextmanager from dataclasses import dataclass, field +from types import TracebackType from typing import Any, TypeVar, cast from guppylang.ast_util import AstNode, get_file, get_line_offset, get_source from guppylang.gtypes import BoundTypeVar, ExistentialTypeVar, FunctionType, GuppyType from guppylang.hugr.hugr import Node, OutPortV -# Whether the interpreter should exit when a Guppy error occurs -EXIT_ON_ERROR: bool = True - @dataclass(frozen=True) class SourceLoc: @@ -136,6 +135,18 @@ def unsolved_vars(self) -> set[ExistentialTypeVar]: return set() +ExceptHook = Callable[[type[BaseException], BaseException, TracebackType | None], Any] + + +@contextmanager +def exception_hook(hook: ExceptHook) -> Iterator[None]: + """Sets a custom `excepthook` for the scope of a 'with' block.""" + old_hook = sys.excepthook + sys.excepthook = hook + yield + sys.excepthook = old_hook + + def format_source_location( loc: ast.AST, num_lines: int = 3, @@ -169,27 +180,40 @@ def format_source_location( def pretty_errors(f: FuncT) -> FuncT: """Decorator to print custom error banners when a `GuppyError` occurs.""" + def hook( + excty: type[BaseException], err: BaseException, traceback: TracebackType | None + ) -> None: + """Custom `excepthook` that intercepts `GuppyExceptions` for pretty printing.""" + # Fall back to default hook if it's not a GuppyException or we're missing an + # error location + if not isinstance(err, GuppyError) or err.location is None: + sys.__excepthook__(excty, err, traceback) + return + + loc = err.location + file, line_offset = get_file(loc), get_line_offset(loc) + assert file is not None + assert line_offset is not None + line = line_offset + loc.lineno - 1 + sys.stderr.write( + f"Guppy compilation failed. Error in file {file}:{line}\n\n" + f"{format_source_location(loc)}\n" + f"{err.__class__.__name__}: {err.get_msg()}\n", + ) + @functools.wraps(f) def pretty_errors_wrapped(*args: Any, **kwargs: Any) -> Any: - try: - return f(*args, **kwargs) - except GuppyError as err: - # Reraise if we're missing a location - if not err.location: + with exception_hook(hook): + try: + return f(*args, **kwargs) + except GuppyError as err: + # For normal usage, this `try` block is not necessary since the + # excepthook is automatically invoked when the exception (which is being + # reraised below) is not handled. However, when running tests, we have + # to manually invoke the hook to print the error message, since the + # tests always have to capture exceptions. + if "pytest" in sys.modules: + hook(type(err), err, err.__traceback__) raise - loc = err.location - file, line_offset = get_file(loc), get_line_offset(loc) - assert file is not None - assert line_offset is not None - line = line_offset + loc.lineno - 1 - print( # noqa: T201 - f"Guppy compilation failed. Error in file {file}:{line}\n\n" - f"{format_source_location(loc)}\n" - f"{err.__class__.__name__}: {err.get_msg()}", - file=sys.stderr, - ) - if EXIT_ON_ERROR: - sys.exit(1) - return None return cast(FuncT, pretty_errors_wrapped) diff --git a/tests/error/util.py b/tests/error/util.py index 821aa2f4..c77141e4 100644 --- a/tests/error/util.py +++ b/tests/error/util.py @@ -2,9 +2,7 @@ import pathlib import pytest -from typing import Any -from collections.abc import Callable - +from guppylang.error import GuppyError from guppylang.hugr import tys from guppylang.hugr.tys import TypeBound from guppylang.module import GuppyModule @@ -17,7 +15,7 @@ def run_error_test(file, capsys): spec = importlib.util.spec_from_file_location("test_module", file) py_module = importlib.util.module_from_spec(spec) - with pytest.raises(SystemExit): + with pytest.raises(GuppyError): spec.loader.exec_module(py_module) err = capsys.readouterr().err