Skip to content

Commit

Permalink
feat: Update remaining code to use new diagnostics (#606)
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch authored Nov 11, 2024
1 parent c4a2aca commit 130282d
Show file tree
Hide file tree
Showing 25 changed files with 240 additions and 170 deletions.
38 changes: 27 additions & 11 deletions guppylang/cfg/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import copy
import itertools
from collections.abc import Iterator
from typing import NamedTuple
from dataclasses import dataclass
from typing import ClassVar, NamedTuple

from guppylang.ast_util import (
AstVisitor,
Expand All @@ -15,6 +16,8 @@
from guppylang.cfg.bb import BB, BBStatement
from guppylang.cfg.cfg import CFG
from guppylang.checker.core import Globals
from guppylang.checker.errors.generic import ExpectedError, UnsupportedError
from guppylang.diagnostic import Error
from guppylang.error import GuppyError, InternalGuppyError
from guppylang.experimental import check_lists_enabled
from guppylang.nodes import (
Expand Down Expand Up @@ -47,6 +50,12 @@ class Jumps(NamedTuple):
break_bb: BB | None


@dataclass(frozen=True)
class UnreachableError(Error):
title: ClassVar[str] = "Unreachable"
span_label: ClassVar[str] = "This code is not reachable"


class CFGBuilder(AstVisitor[BB | None]):
"""Constructs a CFG from ast nodes."""

Expand All @@ -71,7 +80,7 @@ def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) ->
# an implicit void return
if final_bb is not None:
if not returns_none:
raise GuppyError("Expected return statement", nodes[-1])
raise GuppyError(ExpectedError(nodes[-1], "return statement"))
self.cfg.link(final_bb, self.cfg.exit_bb)

return self.cfg
Expand All @@ -81,7 +90,7 @@ def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> BB | None:
next_functional = False
for node in nodes:
if bb_opt is None:
raise GuppyError("Unreachable code", node)
raise GuppyError(UnreachableError(node))
if is_functional_annotation(node):
next_functional = True
continue
Expand Down Expand Up @@ -241,7 +250,7 @@ def visit_FunctionDef(
def generic_visit(self, node: ast.AST, bb: BB, jumps: Jumps) -> BB | None:
# When adding support for new statements, we have to remember to use the
# ExprBuilder to transform all included expressions!
raise GuppyError("Statement is not supported", node)
raise GuppyError(UnsupportedError(node, "This statement", singular=True))


class ExprBuilder(ast.NodeTransformer):
Expand Down Expand Up @@ -309,16 +318,20 @@ def visit_ListComp(self, node: ast.ListComp) -> ast.AST:
# Check for illegal expressions
illegals = find_nodes(is_illegal_in_list_comp, node)
if illegals:
raise GuppyError(
"Expression is not supported inside a list comprehension", illegals[0]
err = UnsupportedError(
illegals[0],
"This expression",
singular=True,
unsupported_in="a list comprehension",
)
raise GuppyError(err)

# Desugar into statements that create the iterator, check for a next element,
# get the next element, and finalise the iterator.
gens = []
for g in node.generators:
if g.is_async:
raise GuppyError("Async generators are not supported", g)
raise GuppyError(UnsupportedError(g, "Async generators"))
g.iter = self.visit(g.iter)
it = make_var(next(tmp_vars), g.iter)
hasnext = make_var(next(tmp_vars), g.iter)
Expand Down Expand Up @@ -479,6 +492,12 @@ def is_functional_annotation(stmt: ast.stmt) -> bool:
return False


@dataclass(frozen=True)
class EmptyPyExprError(Error):
title: ClassVar[str] = "Invalid Python expression"
span_label: ClassVar[str] = "Compile-time `py(...)` expression requires an argument"


def is_py_expression(node: ast.AST) -> PyExpr | None:
"""Checks if the given node is a compile-time `py(...)` expression and turns it into
a `PyExpr` AST node.
Expand All @@ -492,10 +511,7 @@ def is_py_expression(node: ast.AST) -> PyExpr | None:
):
match node.args:
case []:
raise GuppyError(
"Compile-time `py(...)` expression requires an argument",
node,
)
raise GuppyError(EmptyPyExprError(node))
case [arg]:
pass
case args:
Expand Down
15 changes: 8 additions & 7 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from guppylang.ast_util import AstVisitor, get_type, with_loc, with_type
from guppylang.cfg.builder import tmp_vars
from guppylang.checker.core import Variable
from guppylang.checker.errors.generic import UnsupportedError
from guppylang.checker.linearity_checker import contains_subscript
from guppylang.compiler.core import CompilerBase, DFContainer
from guppylang.compiler.hugr_extension import PartialOp
Expand Down Expand Up @@ -188,11 +189,11 @@ def visit_GlobalName(self, node: GlobalName) -> Wire:
defn = self.globals[node.def_id]
assert isinstance(defn, CompiledValueDef)
if isinstance(defn, CompiledCallableDef) and defn.ty.parametrized:
raise GuppyError(
"Usage of polymorphic functions as dynamic higher-order values is not "
"supported yet",
node,
# TODO: This should be caught during checking
err = UnsupportedError(
node, "Polymorphic functions as dynamic higher-order values"
)
raise GuppyError(err)
return defn.load(self.dfg, self.globals, node)

def visit_Name(self, node: ast.Name) -> Wire:
Expand Down Expand Up @@ -379,10 +380,10 @@ def visit_TypeApply(self, node: TypeApply) -> Wire:
# TODO: We would need to do manual monomorphisation in that case to obtain a
# function that returns two ports as expected
if instantiation_needs_unpacking(defn.ty, node.inst):
raise GuppyError(
"Generic function instantiations returning rows are not supported yet",
node,
err = UnsupportedError(
node, "Generic function instantiations returning rows"
)
raise GuppyError(err)

return defn.load_with_args(node.inst, self.dfg, self.globals, node)

Expand Down
12 changes: 6 additions & 6 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from guppylang.definition.parameter import ConstVarDef, TypeVarDef
from guppylang.definition.struct import RawStructDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import GuppyError, MissingModuleError, pretty_errors
from guppylang.error import MissingModuleError, pretty_errors
from guppylang.ipython_inspect import get_ipython_globals, is_running_ipython
from guppylang.module import (
GuppyModule,
Expand Down Expand Up @@ -148,7 +148,7 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier:
break
frame = frame.f_back
else:
raise GuppyError("Could not find a caller for the `@guppy` decorator")
raise RuntimeError("Could not find a caller for the `@guppy` decorator")

# Jupyter notebook cells all get different dummy filenames. However,
# we want the whole notebook to correspond to a single implicit
Expand All @@ -172,7 +172,7 @@ def init_module(self, import_builtins: bool = True) -> None:
module_id = self._get_python_caller()
if module_id in self._modules:
msg = f"Module {module_id.name} is already initialised"
raise GuppyError(msg)
raise ValueError(msg)
self._modules[module_id] = GuppyModule(module_id.name, import_builtins)

@pretty_errors
Expand Down Expand Up @@ -426,7 +426,7 @@ def get_module(
other_module = find_guppy_module_in_py_module(value)
if other_module and other_module != module:
defs[x] = value
except GuppyError:
except ValueError:
pass
module.load(**defs)
return module
Expand All @@ -447,7 +447,7 @@ def compile_function(self, f_def: RawFunctionDef) -> FuncDefnPointer:
"""Compiles a single function definition."""
module = f_def.id.module
if not module:
raise GuppyError("Function definition must belong to a module")
raise ValueError("Function definition must belong to a module")
compiled_module = module.compile()
assert module._compiled is not None, "Module should be compiled"
globs = module._compiled.globs
Expand Down Expand Up @@ -476,7 +476,7 @@ def _parse_expr_string(ty_str: str, parse_err: str, sources: SourceMap) -> ast.e
try:
expr_ast = ast.parse(ty_str, mode="eval").body
except SyntaxError:
raise GuppyError(parse_err) from None
raise SyntaxError(parse_err) from None

# Try to annotate the type AST with source information. This requires us to
# inspect the stack frame of the caller
Expand Down
7 changes: 4 additions & 3 deletions guppylang/definition/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class Suggestion(Help):
"Add a `@guppy` annotation to turn `{method_name}` into a Guppy method"
)

def __post_init__(self) -> None:
self.add_sub_diagnostic(NonGuppyMethodError.Suggestion(None))


@dataclass(frozen=True)
class RawStructDef(TypeDef, ParsableDef):
Expand Down Expand Up @@ -132,9 +135,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef":
case _, ast.FunctionDef(name=name) as node:
v = getattr(self.python_class, name)
if not isinstance(v, Definition):
err = NonGuppyMethodError(node, self.name, name)
err.add_sub_diagnostic(NonGuppyMethodError.Suggestion(None))
raise GuppyError(err)
raise GuppyError(NonGuppyMethodError(node, self.name, name))
used_func_names[name] = node
if name in used_field_names:
raise GuppyError(DuplicateFieldError(node, self.name, name))
Expand Down
33 changes: 23 additions & 10 deletions guppylang/experimental.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from ast import expr
from dataclasses import dataclass
from types import TracebackType
from typing import ClassVar

from guppylang.ast_util import AstNode
from guppylang.diagnostic import Error, Help
from guppylang.error import GuppyError

EXPERIMENTAL_FEATURES_ENABLED = False
Expand Down Expand Up @@ -55,19 +58,29 @@ def __exit__(
EXPERIMENTAL_FEATURES_ENABLED = self.original


@dataclass(frozen=True)
class ExperimentalFeatureError(Error):
title: ClassVar[str] = "Experimental feature"
span_label: ClassVar[str] = "{things} are an experimental feature"
things: str

@dataclass(frozen=True)
class Suggestion(Help):
message: ClassVar[str] = (
"Experimental features are currently disabled. You can enable them by "
"calling `guppylang.enable_experimental_features()`, however note that "
"these features are unstable and might break in the future."
)

def __post_init__(self) -> None:
self.add_sub_diagnostic(ExperimentalFeatureError.Suggestion(None))


def check_function_tensors_enabled(node: expr | None = None) -> None:
if not EXPERIMENTAL_FEATURES_ENABLED:
raise GuppyError(
"Function tensors are an experimental feature. Use "
"`guppylang.enable_experimental_features()` to enable them.",
node,
)
raise GuppyError(ExperimentalFeatureError(node, "Function tensors"))


def check_lists_enabled(loc: AstNode | None = None) -> None:
if not EXPERIMENTAL_FEATURES_ENABLED:
raise GuppyError(
"Lists are an experimental feature and not fully supported yet. Use "
"`guppylang.enable_experimental_features()` to enable them.",
loc,
)
raise GuppyError(ExperimentalFeatureError(loc, "Lists"))
10 changes: 5 additions & 5 deletions guppylang/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from guppylang.definition.parameter import ParamDef
from guppylang.definition.struct import CheckedStructDef
from guppylang.definition.ty import TypeDef
from guppylang.error import GuppyError, pretty_errors
from guppylang.error import pretty_errors
from guppylang.experimental import enable_experimental_features

if TYPE_CHECKING:
Expand Down Expand Up @@ -143,7 +143,7 @@ def load(
imports.append((alias, mod))
else:
msg = f"Only Guppy definitions or modules can be imported. Got `{imp}`"
raise GuppyError(msg)
raise TypeError(msg)

# Also include any impls that are defined by the imported modules
impls: dict[DefId, dict[str, DefId]] = {}
Expand Down Expand Up @@ -180,7 +180,7 @@ def load_all(self, mod: GuppyModule | ModuleType) -> None:
self.load_all(find_guppy_module_in_py_module(mod))
else:
msg = f"Only Guppy definitions or modules can be imported. Got `{mod}`"
raise GuppyError(msg)
raise TypeError(msg)

def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None:
"""Registers a definition with this module.
Expand Down Expand Up @@ -416,13 +416,13 @@ def find_guppy_module_in_py_module(module: ModuleType) -> GuppyModule:

if not mods:
msg = f"No Guppy modules found in `{module.__name__}`"
raise GuppyError(msg)
raise ValueError(msg)
if len(mods) > 1:
msg = (
f"Python module `{module.__name__}` contains multiple Guppy modules. "
"Cannot decide which one to import."
)
raise GuppyError(msg)
raise ValueError(msg)
return mods[0]


Expand Down
Loading

0 comments on commit 130282d

Please sign in to comment.