From 9f972a3cd3fab77ac369eedcfb586e574740e56a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Fri, 2 Feb 2024 09:05:37 +0000 Subject: [PATCH] feat: Improve import system --- guppylang/checker/core.py | 44 ++++- guppylang/checker/func_checker.py | 22 ++- guppylang/compiler/core.py | 4 +- guppylang/compiler/expr_compiler.py | 5 +- guppylang/compiler/func_compiler.py | 8 +- guppylang/custom.py | 6 + guppylang/declared.py | 8 +- guppylang/decorator.py | 20 ++- guppylang/gtypes.py | 6 + guppylang/module.py | 173 +++++++++++++++---- guppylang/nodes.py | 4 +- tests/error/import_errors/__init__.py | 0 tests/error/import_errors/alias_different.py | 15 ++ tests/error/import_errors/alias_methods.py | 15 ++ tests/error/linear_errors/branch_use.err | 8 +- tests/error/linear_errors/branch_use.py | 5 - tests/error/linear_errors/break_unused.py | 4 +- tests/error/linear_errors/continue_unused.py | 4 +- tests/error/test_import_errors.py | 80 +++++++++ tests/integration/modules/__init__.py | 0 tests/integration/modules/mod_a.py | 23 +++ tests/integration/modules/mod_b.py | 24 +++ tests/integration/modules/mod_c.py | 19 ++ tests/integration/test_comprehension.py | 4 - tests/integration/test_import.py | 99 +++++++++++ tests/integration/test_linear.py | 4 - 26 files changed, 521 insertions(+), 83 deletions(-) create mode 100644 tests/error/import_errors/__init__.py create mode 100644 tests/error/import_errors/alias_different.py create mode 100644 tests/error/import_errors/alias_methods.py create mode 100644 tests/error/test_import_errors.py create mode 100644 tests/integration/modules/__init__.py create mode 100644 tests/integration/modules/mod_a.py create mode 100644 tests/integration/modules/mod_b.py create mode 100644 tests/integration/modules/mod_c.py create mode 100644 tests/integration/test_import.py diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index ac5a4288..8744d3ea 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple from guppylang.ast_util import AstNode, name_nodes_in_ast from guppylang.gtypes import ( @@ -19,6 +19,9 @@ TupleType, ) +if TYPE_CHECKING: + from guppylang.module import GuppyModule + @dataclass class Variable: @@ -31,11 +34,28 @@ class Variable: @dataclass -class CallableVariable(ABC, Variable): +class GlobalVariable(Variable): + """Class holding data associated with a module-level variable.""" + + module: "GuppyModule | None" + + @property + def qualname(self) -> str: + """The qualified name of this global variable.""" + return f"{self.module.name}.{self.name}" if self.module else self.name + + +@dataclass +class CallableVariable(ABC, GlobalVariable): """Abstract base class for global variables that can be called.""" ty: FunctionType + @property + def is_method(self) -> bool: + """Returns whether this variable is an instance method.""" + return "." in self.name + @abstractmethod def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context" @@ -67,7 +87,7 @@ class Globals(NamedTuple): constants), to types, or to instance functions belonging to types. """ - values: dict[str, Variable] + values: dict[str, GlobalVariable] types: dict[str, type[GuppyType]] type_vars: dict[str, TypeVarDecl] python_scope: PyScope @@ -91,7 +111,7 @@ def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None Returns `None` if the name doesn't exist or isn't a function. """ - qualname = qualified_name(ty.__class__, name) + qualname = qualified_instance_name(ty.__class__, name) if qualname in self.values: val = self.values[qualname] if isinstance(val, CallableVariable): @@ -203,7 +223,17 @@ def __contains__(self, key: object) -> bool: return super().__contains__(key) -def qualified_name(ty: type[GuppyType] | str, name: str) -> str: +def qualified_name(module: "GuppyModule | None", name: str) -> str: + """Returns a name qualified by a module.""" + module_name = module.name if module else "builtins" + return f"{module_name}.{name}" + + +def instance_name(ty: type[GuppyType], name: str) -> str: + """Returns a name for an instance function on a type.""" + return f"{ty.name}.{name}" + + +def qualified_instance_name(ty: type[GuppyType], name: str) -> str: """Returns a qualified name for an instance function on a type.""" - ty_name = ty if isinstance(ty, str) else ty.name - return f"{ty_name}.{name}" + return qualified_name(ty.module, instance_name(ty, name)) diff --git a/guppylang/checker/func_checker.py b/guppylang/checker/func_checker.py index f224bd99..f7410eac 100644 --- a/guppylang/checker/func_checker.py +++ b/guppylang/checker/func_checker.py @@ -7,12 +7,19 @@ import ast from dataclasses import dataclass +from typing import TYPE_CHECKING from guppylang.ast_util import AstNode, return_nodes_in_ast, with_loc from guppylang.cfg.bb import BB from guppylang.cfg.builder import CFGBuilder from guppylang.checker.cfg_checker import CheckedCFG, check_cfg -from guppylang.checker.core import CallableVariable, Context, Globals, Variable +from guppylang.checker.core import ( + CallableVariable, + Context, + Globals, + GlobalVariable, + Variable, +) from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.error import GuppyError from guppylang.gtypes import ( @@ -25,9 +32,12 @@ ) from guppylang.nodes import CheckedNestedFunctionDef, GlobalCall, NestedFunctionDef +if TYPE_CHECKING: + from guppylang.module import GuppyModule + @dataclass -class DefinedFunction(CallableVariable): +class DefinedFunction(CallableVariable, GlobalVariable): """A user-defined function""" ty: FunctionType @@ -35,14 +45,14 @@ class DefinedFunction(CallableVariable): @staticmethod def from_ast( - func_def: ast.FunctionDef, name: str, globals: Globals + func_def: ast.FunctionDef, name: str, module: "GuppyModule", globals: Globals ) -> "DefinedFunction": ty = check_signature(func_def, globals) if ty.quantified: raise GuppyError( "Generic function definitions are not supported yet", func_def ) - return DefinedFunction(name, ty, func_def, None) + return DefinedFunction(name, ty, func_def, None, module) def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context @@ -79,7 +89,7 @@ def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFun for x, ty, loc in zip(func.ty.arg_names, func.ty.args, args) ] cfg = check_cfg(cfg, inputs, func.ty.returns, globals) - return CheckedFunction(func_def.name, func.ty, func_def, None, cfg) + return CheckedFunction(func.name, func.ty, func_def, None, func.module, cfg) def check_nested_func_def( @@ -141,7 +151,7 @@ def check_nested_func_def( if func_def.name in cfg.live_before[cfg.entry_bb]: if not captured: # If there are no captured vars, we treat the function like a global name - func = DefinedFunction(func_def.name, func_ty, func_def, None) + func = DefinedFunction(func_def.name, func_ty, func_def, None, None) globals = ctx.globals | Globals({func_def.name: func}, {}, {}, {}) else: diff --git a/guppylang/compiler/core.py b/guppylang/compiler/core.py index 9e9026a8..d9328d9e 100644 --- a/guppylang/compiler/core.py +++ b/guppylang/compiler/core.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from guppylang.ast_util import AstNode -from guppylang.checker.core import CallableVariable, Variable +from guppylang.checker.core import CallableVariable, GlobalVariable, Variable from guppylang.gtypes import FunctionType, Inst from guppylang.hugr.hugr import DFContainingNode, Hugr, OutPortV @@ -28,7 +28,7 @@ def __init__( object.__setattr__(self, "port", port) -class CompiledVariable(ABC, Variable): +class CompiledVariable(ABC, GlobalVariable): """Abstract base class for compiled global module-level variables.""" @abstractmethod diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 8d05fbed..369d28de 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -147,7 +147,8 @@ def visit_LocalName(self, node: LocalName) -> OutPortV: return self.dfg[node.id].port def visit_GlobalName(self, node: GlobalName) -> OutPortV: - return self.globals[node.id].load(self.dfg, self.graph, self.globals, node) + name = node.value.qualname + return self.globals[name].load(self.dfg, self.graph, self.globals, node) def visit_Name(self, node: ast.Name) -> OutPortV: raise InternalGuppyError("Node should have been removed during type checking.") @@ -179,7 +180,7 @@ def visit_LocalCall(self, node: LocalCall) -> OutPortV: return self._pack_returns(rets) def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: - func = self.globals[node.func.name] + func = self.globals[node.func.qualname] assert isinstance(func, CompiledFunction) args = [self.visit(arg) for arg in node.args] diff --git a/guppylang/compiler/func_compiler.py b/guppylang/compiler/func_compiler.py index e67b900b..699a5c0c 100644 --- a/guppylang/compiler/func_compiler.py +++ b/guppylang/compiler/func_compiler.py @@ -60,7 +60,9 @@ def compile_global_func_def( parent=def_node, ) - return CompiledFunctionDef(func.name, func.ty, func.defined_at, None, def_node) + return CompiledFunctionDef( + func.name, func.ty, func.defined_at, None, func.module, def_node + ) def compile_local_func_def( @@ -98,7 +100,9 @@ def compile_local_func_def( else: # Otherwise, we treat the function like a normal global variable globals = globals | { - func.name: CompiledFunctionDef(func.name, func.ty, func, None, def_node) + func.name: CompiledFunctionDef( + func.name, func.ty, func, None, None, def_node + ) } # Compile the CFG diff --git a/guppylang/custom.py b/guppylang/custom.py index e1142448..ac07d39b 100644 --- a/guppylang/custom.py +++ b/guppylang/custom.py @@ -1,5 +1,6 @@ import ast from abc import ABC, abstractmethod +from typing import TYPE_CHECKING from guppylang.ast_util import AstNode, get_type, with_loc, with_type from guppylang.checker.core import Context, Globals @@ -16,6 +17,9 @@ from guppylang.hugr.hugr import DFContainingVNode, Hugr, Node, OutPortV from guppylang.nodes import GlobalCall +if TYPE_CHECKING: + from guppylang.module import GuppyModule + class CustomFunction(CompiledFunction): """A function whose type checking and compilation behaviour can be customised.""" @@ -35,6 +39,7 @@ class CustomFunction(CompiledFunction): def __init__( self, name: str, + module: "GuppyModule", defined_at: ast.FunctionDef | None, compiler: "CustomCallCompiler", checker: "CustomCallChecker", @@ -42,6 +47,7 @@ def __init__( ty: FunctionType | None = None, ): self.name = name + self.module = module self.defined_at = defined_at self.higher_order_value = higher_order_value self.call_compiler = compiler diff --git a/guppylang/declared.py b/guppylang/declared.py index 7fc5923e..fc6a3356 100644 --- a/guppylang/declared.py +++ b/guppylang/declared.py @@ -1,5 +1,6 @@ import ast from dataclasses import dataclass +from typing import TYPE_CHECKING from guppylang.ast_util import AstNode, has_empty_body, with_loc from guppylang.checker.core import Context, Globals @@ -11,6 +12,9 @@ from guppylang.hugr.hugr import Hugr, Node, OutPortV, VNode from guppylang.nodes import GlobalCall +if TYPE_CHECKING: + from guppylang.module import GuppyModule + @dataclass class DeclaredFunction(CompiledFunction): @@ -20,14 +24,14 @@ class DeclaredFunction(CompiledFunction): @staticmethod def from_ast( - func_def: ast.FunctionDef, name: str, globals: Globals + func_def: ast.FunctionDef, name: str, module: "GuppyModule", globals: Globals ) -> "DeclaredFunction": ty = check_signature(func_def, globals) if not has_empty_body(func_def): raise GuppyError( "Body of function declaration must be empty", func_def.body[0] ) - return DeclaredFunction(name, ty, func_def, None) + return DeclaredFunction(name, ty, func_def, None, module) def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context diff --git a/guppylang/decorator.py b/guppylang/decorator.py index cb1836a6..96f68031 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from pathlib import Path from types import ModuleType -from typing import Any, ClassVar, TypeVar +from typing import Any, ClassVar, Protocol, TypeVar from guppylang.ast_util import AstNode, has_empty_body from guppylang.custom import ( @@ -26,6 +26,12 @@ ClassDecorator = Callable[[type], type] +class ClassWithGuppyType(Protocol): + """Mypy protocol for a class that is annotated with a Guppy type.""" + + _guppy_type: type[GuppyType] + + @dataclass(frozen=True) class ModuleIdentifier: """Identifier for the Python file/module that called the decorator.""" @@ -113,12 +119,17 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier: return ModuleIdentifier(Path(filename), module) @pretty_errors - def extend_type(self, module: GuppyModule, ty: type[GuppyType]) -> ClassDecorator: + def extend_type( + self, module: GuppyModule, ty: type[GuppyType] | ClassWithGuppyType + ) -> ClassDecorator: """Decorator to add new instance functions to a type.""" module._instance_func_buffer = {} + guppy_ty = ( + ty if isinstance(ty, type) and issubclass(ty, GuppyType) else ty._guppy_type + ) def dec(c: type) -> type: - module._register_buffered_instance_funcs(ty) + module._register_buffered_instance_funcs(guppy_ty) return c return dec @@ -142,11 +153,13 @@ def type( def dec(c: type) -> type: _name = name or c.__name__ + _module = module @dataclass(frozen=True) class NewType(GuppyType): args: Sequence[GuppyType] name: ClassVar[str] = _name + module: ClassVar[GuppyModule | None] = _module @staticmethod def build(*args: GuppyType, node: AstNode | None = None) -> "GuppyType": @@ -220,6 +233,7 @@ def dec(f: PyFunc) -> CustomFunction: call_checker = checker or DefaultCallChecker() func = CustomFunction( name or func_ast.name, + module, func_ast, compiler or DefaultCallCompiler(), call_checker, diff --git a/guppylang/gtypes.py b/guppylang/gtypes.py index 04d1c873..2c4e09df 100644 --- a/guppylang/gtypes.py +++ b/guppylang/gtypes.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from guppylang.checker.core import Globals + from guppylang.module import GuppyModule Subst = dict["ExistentialTypeVar", "GuppyType"] @@ -28,6 +29,7 @@ class GuppyType(ABC): """ name: ClassVar[str] + module: ClassVar["GuppyModule | None"] = None # Cache for free variables _unsolved_vars: set["ExistentialTypeVar"] = field(init=False, repr=False) @@ -51,6 +53,10 @@ def __post_init__(self) -> None: vs |= arg.unsolved_vars object.__setattr__(self, "_unsolved_vars", vs) + @classmethod + def qualname(cls) -> str: + return f"{cls.module.name}.{cls.name}" if cls.module else cls.name + @staticmethod @abstractmethod def build(*args: "GuppyType", node: AstNode | None = None) -> "GuppyType": diff --git a/guppylang/module.py b/guppylang/module.py index d24d1cd7..1c961cb3 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -3,11 +3,19 @@ import sys import textwrap from collections.abc import Callable -from types import ModuleType -from typing import Any, Union +from typing import TYPE_CHECKING, Any from guppylang.ast_util import AstNode, annotate_location -from guppylang.checker.core import Globals, PyScope, TypeVarDecl, qualified_name +from guppylang.checker.core import ( + CallableVariable, + Globals, + GlobalVariable, + PyScope, + TypeVarDecl, + instance_name, + qualified_instance_name, + qualified_name, +) from guppylang.checker.func_checker import DefinedFunction, check_global_func_def from guppylang.compiler.core import CompiledGlobals from guppylang.compiler.func_compiler import ( @@ -16,10 +24,17 @@ ) from guppylang.custom import CustomFunction from guppylang.declared import DeclaredFunction -from guppylang.error import GuppyError, pretty_errors +from guppylang.error import ( + GuppyError, + pretty_errors, +) from guppylang.gtypes import GuppyType from guppylang.hugr.hugr import Hugr +if TYPE_CHECKING: + from types import ModuleType + + PyFunc = Callable[..., Any] PyFuncDefOrDecl = tuple[bool, PyFunc] @@ -69,29 +84,63 @@ def __init__(self, name: str, import_builtins: bool = True): self.load(builtins) - def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: - """Imports another Guppy module.""" + def load(self, module: "GuppyModule | ModuleType") -> None: + """Imports every value and type from another module.""" self._check_not_yet_compiled() - if isinstance(m, GuppyModule): - # Compile module if it isn't compiled yet - if not m.compiled: - m.compile() - - # For now, we can only import custom functions - if any( - not isinstance(v, CustomFunction) for v in m._compiled_globals.values() - ): - raise GuppyError( - "Importing modules with defined functions is not supported yet" - ) - - self._imported_globals |= m._globals - self._imported_compiled_globals |= m._compiled_globals + if isinstance(module, GuppyModule): + if not module.compiled: + module.compile() + for ty in module._globals.types.values(): + self._import_type(ty) + for val in module._globals.values.values(): + if not (isinstance(val, CallableVariable) and val.is_method): + self._import_value(val) + self._import_instance_methods(module) else: - for val in m.__dict__.values(): + for val in module.__dict__.values(): if isinstance(val, GuppyModule): self.load(val) + def import_(self, module: "GuppyModule", name: str, alias: str = "") -> None: + """Imports a Guppy value or type from another module by name.""" + self._check_not_yet_compiled() + if not module.compiled: + module.compile() + if name in module._globals.types: + self._import_type(module._globals.types[name], alias) + elif name in module._globals.values: + self._import_value(module._globals.values[name], alias) + else: + raise GuppyError(f"Could not find `{name}` in module `{module.name}`") + self._import_instance_methods(module) + + def _import_type(self, ty: type[GuppyType], alias: str = "") -> None: + """Imports a Guppy type from a different module.""" + name = alias or ty.name + self._check_type_name_available(name, None) + self._imported_globals.types[name] = ty + + def _import_value(self, v: GlobalVariable, alias: str = "") -> None: + """Imports a Guppy value from a different module.""" + assert v.module is not None + name = alias or v.name + self._check_name_available(name, None) + self._imported_globals.values[name] = v + self._imported_compiled_globals[name] = v.module._compiled_globals[v.name] + + def _import_instance_methods(self, module: "GuppyModule") -> None: + """Transitively imports all instance methods from a module.""" + for x, v in (module._globals.values | module._imported_globals.values).items(): + if ( + isinstance(v, CallableVariable) + and v.is_method + and x not in self._imported_compiled_globals + ): + assert v.module is not None + self._check_name_available(v.qualname, None) + self._imported_globals.values[x] = v + self._imported_compiled_globals[x] = v.module._compiled_globals[x] + def register_func_def( self, f: PyFunc, instance: type[GuppyType] | None = None ) -> None: @@ -101,9 +150,14 @@ def register_func_def( if self._instance_func_buffer is not None: self._instance_func_buffer[func_ast.name] = (True, f) else: - name = ( - qualified_name(instance, func_ast.name) if instance else func_ast.name - ) + if instance: + # Instance methods are called `Type.method_name`, so that's what we + # update the name to. However, to disambiguate we actually use the + # qualified name `module.Type.method_name` as the dict key. + name = qualified_instance_name(instance, func_ast.name) + func_ast.name = instance_name(instance, func_ast.name) + else: + name = func_ast.name self._check_name_available(name, func_ast) self._func_defs[name] = func_ast, get_py_scope(f) @@ -116,9 +170,14 @@ def register_func_decl( if self._instance_func_buffer is not None: self._instance_func_buffer[func_ast.name] = (False, f) else: - name = ( - qualified_name(instance, func_ast.name) if instance else func_ast.name - ) + if instance: + # Instance methods are called `Type.method_name`, so that's what we + # update the name to. However, to disambiguate we actually use the + # qualified name `module.Type.method_name` as the dict key. + func_ast.name = instance_name(instance, func_ast.name) + name = qualified_name(instance.module, func_ast.name) + else: + name = func_ast.name self._check_name_available(name, func_ast) self._func_decls[name] = func_ast @@ -131,9 +190,15 @@ def register_custom_func( self._instance_func_buffer[func.name] = func else: if instance: - func.name = qualified_name(instance, func.name) - self._check_name_available(func.name, func.defined_at) - self._custom_funcs[func.name] = func + # Instance methods are called `Type.method_name`, so that's what we + # update `func.name` to. However, to disambiguate we actually use the + # qualified name `module.Type.method_name` as the dict key. + name = qualified_instance_name(instance, func.name) + func.name = instance_name(instance, func.name) + else: + name = func.name + self._check_name_available(name, func.defined_at) + self._custom_funcs[name] = func def register_type(self, name: str, ty: type[GuppyType]) -> None: """Registers an existing Guppy type as belonging to this Guppy module.""" @@ -175,11 +240,15 @@ def compile(self) -> Hugr | None: for func in self._custom_funcs.values(): func.check_type(self._imported_globals | self._globals) defined_funcs = { - x: DefinedFunction.from_ast(f, x, self._imported_globals | self._globals) + x: DefinedFunction.from_ast( + f, f.name, self, self._imported_globals | self._globals + ) for x, (f, _) in self._func_defs.items() } declared_funcs = { - x: DeclaredFunction.from_ast(f, x, self._imported_globals | self._globals) + x: DeclaredFunction.from_ast( + f, f.name, self, self._imported_globals | self._globals + ) for x, f in self._func_decls.items() } self._globals.values.update(self._custom_funcs) @@ -205,22 +274,43 @@ def compile(self) -> Hugr | None: # Prepare `FunctionDef` nodes for all function definitions def_nodes = {x: graph.add_def(f.ty, module_node, x) for x, f in checked.items()} + + # Store the compiled functions, so they can be imported by other modules self._compiled_globals |= ( self._custom_funcs | declared_funcs | { - x: CompiledFunctionDef(x, f.ty, f.defined_at, None, def_nodes[x]) + x: CompiledFunctionDef( + f.name, f.ty, f.defined_at, None, self, def_nodes[x] + ) for x, f in checked.items() } ) + # Construct a mapping from fully qualified names to compiled functions. This + # will be used to compile `GlobalName` and `GlobalCall` nodes. + compiled_globals: CompiledGlobals = {} + for x, v in self._compiled_globals.items(): + # Note that for instance methods, the dict key `x` is already fully + # qualified + if isinstance(v, CallableVariable) and v.is_method: + compiled_globals[x] = v + else: + compiled_globals[v.qualname] = v + for v in self._imported_compiled_globals.values(): + # Prepare `FuncDecl` nodes for all imported functions + if isinstance(v, DeclaredFunction | DefinedFunction): + v = DeclaredFunction(v.name, v.ty, v.defined_at, None, v.module) + v.add_to_graph(graph, module_node) + compiled_globals[v.qualname] = v + # Compile function definitions to Hugr for x, f in checked.items(): compile_global_func_def( f, def_nodes[x], graph, - self._imported_compiled_globals | self._compiled_globals, + compiled_globals, ) self._compiled = True @@ -244,11 +334,22 @@ def _check_name_available(self, name: str, node: AstNode | None) -> None: f"Module `{self.name}` already contains a function named `{name}`", node, ) + if name in self._imported_globals.values: + raise GuppyError( + f"A function named `{name}` has already been imported", + node, + ) def _check_type_name_available(self, name: str, node: AstNode | None) -> None: if name in self._globals.types: raise GuppyError( - f"Module `{self.name}` already contains a type `{name}`", + f"Module `{self.name}` already contains a type named `{name}`", + node, + ) + + if name in self._imported_globals.types: + raise GuppyError( + f"A type named `{name}` has already been imported", node, ) diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 908fe55f..047a5e2e 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from guppylang.cfg.cfg import CFG from guppylang.checker.cfg_checker import CheckedCFG - from guppylang.checker.core import CallableVariable, Variable + from guppylang.checker.core import CallableVariable, GlobalVariable, Variable class LocalName(ast.Name): @@ -20,7 +20,7 @@ class LocalName(ast.Name): class GlobalName(ast.Name): id: str - value: "Variable" + value: "GlobalVariable" _fields = ( "id", diff --git a/tests/error/import_errors/__init__.py b/tests/error/import_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/import_errors/alias_different.py b/tests/error/import_errors/alias_different.py new file mode 100644 index 00000000..5a2334e3 --- /dev/null +++ b/tests/error/import_errors/alias_different.py @@ -0,0 +1,15 @@ +from guppylang import GuppyModule, guppy +from tests.integration.modules.mod_a import mod_a, MyType +from tests.integration.modules.mod_b import mod_b + +module = GuppyModule("test") +module.import_(mod_a, "MyType") +module.import_(mod_b, "MyType", alias="MyType2") + + +@guppy(module) +def foo(x: MyType) -> "MyType2": + return x + + +module.compile() diff --git a/tests/error/import_errors/alias_methods.py b/tests/error/import_errors/alias_methods.py new file mode 100644 index 00000000..cf007694 --- /dev/null +++ b/tests/error/import_errors/alias_methods.py @@ -0,0 +1,15 @@ +from guppylang import GuppyModule, guppy +from tests.integration.modules.mod_a import mod_a, MyType +from tests.integration.modules.mod_b import mod_b + +module = GuppyModule("test") +module.import_(mod_a, "MyType") +module.import_(mod_b, "MyType", alias="MyType2") + + +@guppy(module) +def foo(x: MyType) -> MyType: + return +x + + +module.compile() diff --git a/tests/error/linear_errors/branch_use.err b/tests/error/linear_errors/branch_use.err index c096aa16..9b8f5d21 100644 --- a/tests/error/linear_errors/branch_use.err +++ b/tests/error/linear_errors/branch_use.err @@ -1,7 +1,7 @@ -Guppy compilation failed. Error in file $FILE:23 +Guppy compilation failed. Error in file $FILE:18 -21: @guppy(module) -22: def foo(b: bool) -> bool: -23: q = new_qubit() +16: @guppy(module) +17: def foo(b: bool) -> bool: +18: q = new_qubit() ^ GuppyError: Variable `q` with linear type `Qubit` is not used on all control-flow paths diff --git a/tests/error/linear_errors/branch_use.py b/tests/error/linear_errors/branch_use.py index de233c7d..fd5f7c08 100644 --- a/tests/error/linear_errors/branch_use.py +++ b/tests/error/linear_errors/branch_use.py @@ -13,11 +13,6 @@ def new_qubit() -> Qubit: ... -@guppy.declare(module) -def measure(q: Qubit) -> bool: - ... - - @guppy(module) def foo(b: bool) -> bool: q = new_qubit() diff --git a/tests/error/linear_errors/break_unused.py b/tests/error/linear_errors/break_unused.py index 3c7f0060..c043bbb7 100644 --- a/tests/error/linear_errors/break_unused.py +++ b/tests/error/linear_errors/break_unused.py @@ -14,7 +14,7 @@ def new_qubit() -> Qubit: @guppy.declare(module) -def measure() -> bool: +def measure_() -> bool: ... @@ -26,7 +26,7 @@ def foo(i: int) -> bool: if i == 0: break i -= 1 - b ^= measure(q) + b ^= measure_(q) return b diff --git a/tests/error/linear_errors/continue_unused.py b/tests/error/linear_errors/continue_unused.py index e31dabff..52704632 100644 --- a/tests/error/linear_errors/continue_unused.py +++ b/tests/error/linear_errors/continue_unused.py @@ -14,7 +14,7 @@ def new_qubit() -> Qubit: @guppy.declare(module) -def measure() -> bool: +def measure_() -> bool: ... @@ -26,7 +26,7 @@ def foo(i: int) -> bool: if i % 10 == 0: break i -= 1 - b ^= measure(q) + b ^= measure_(q) return b diff --git a/tests/error/test_import_errors.py b/tests/error/test_import_errors.py new file mode 100644 index 00000000..ec298b2a --- /dev/null +++ b/tests/error/test_import_errors.py @@ -0,0 +1,80 @@ +import pathlib +import pytest + +from guppylang import GuppyModule, guppy +from guppylang.error import GuppyError +from guppylang.gtypes import BoolType +from tests.integration.modules.mod_a import mod_a +from tests.integration.modules.mod_b import mod_b + +from tests.error.util import run_error_test + + +def test_doesnt_exist(): + module = GuppyModule("test") + + with pytest.raises(GuppyError, match="Could not find `h` in module `mod_a`"): + module.import_(mod_a, "h") + + +def test_func_already_defined(): + module = GuppyModule("test") + + @guppy(module) + def f() -> None: + return + + with pytest.raises(GuppyError, match="Module `test` already contains a function named `f`"): + module.import_(mod_a, "f") + + +def test_type_already_defined(): + module = GuppyModule("test") + + @guppy.type(module, BoolType().to_hugr()) + class MyType: + pass + + with pytest.raises(GuppyError, match="Module `test` already contains a type named `MyType`"): + module.import_(mod_a, "MyType") + + +def test_func_already_imported(): + module = GuppyModule("test") + module.import_(mod_a, "f") + + with pytest.raises(GuppyError, match="A function named `f` has already been imported"): + module.import_(mod_b, "f") + + +def test_type_already_imported(): + module = GuppyModule("test") + module.import_(mod_a, "MyType") + + with pytest.raises(GuppyError, match="A type named `MyType` has already been imported"): + module.import_(mod_b, "MyType") + + +def test_already_imported_alias(): + module = GuppyModule("test") + module.import_(mod_a, "f", alias="h") + + with pytest.raises(GuppyError, match="A function named `h` has already been imported"): + module.import_(mod_b, "h") + + +path = pathlib.Path(__file__).parent.resolve() / "import_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_import_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/integration/modules/__init__.py b/tests/integration/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/modules/mod_a.py b/tests/integration/modules/mod_a.py new file mode 100644 index 00000000..b9abfe75 --- /dev/null +++ b/tests/integration/modules/mod_a.py @@ -0,0 +1,23 @@ +from guppylang import GuppyModule, guppy +from guppylang.gtypes import BoolType + +mod_a = GuppyModule("mod_a") + + +@guppy(mod_a) +def f(x: int) -> int: + return x + + +@guppy.declare(mod_a) +def g() -> int: + ... + + +@guppy.type(mod_a, BoolType().to_hugr()) +class MyType: + + @guppy.declare(mod_a) + def __neg__(self: "MyType") -> "MyType": + ... + diff --git a/tests/integration/modules/mod_b.py b/tests/integration/modules/mod_b.py new file mode 100644 index 00000000..aee314cc --- /dev/null +++ b/tests/integration/modules/mod_b.py @@ -0,0 +1,24 @@ +from guppylang import GuppyModule, guppy +from guppylang.gtypes import BoolType +from guppylang.hugr import ops + +mod_b = GuppyModule("mod_b") + + +@guppy.declare(mod_b) +def f(x: bool) -> bool: + ... + + +@guppy.hugr_op(mod_b, ops.DummyOp(name="dummy")) +def h() -> int: + ... + + +@guppy.type(mod_b, BoolType().to_hugr()) +class MyType: + + @guppy.declare(mod_b) + def __pos__(self: "MyType") -> "MyType": + ... + diff --git a/tests/integration/modules/mod_c.py b/tests/integration/modules/mod_c.py new file mode 100644 index 00000000..4b70b87c --- /dev/null +++ b/tests/integration/modules/mod_c.py @@ -0,0 +1,19 @@ +from guppylang import GuppyModule, guppy +from tests.integration.modules.mod_a import mod_a, MyType + +mod_c = GuppyModule("mod_c") +mod_c.import_(mod_a, "MyType") + + +@guppy.declare(mod_c) +def g() -> MyType: + ... + + +@guppy.extend_type(mod_c, MyType) +class _: + + @guppy(mod_c) + def __int__(self: "MyType") -> int: + return 0 + diff --git a/tests/integration/test_comprehension.py b/tests/integration/test_comprehension.py index 76da7749..c5f1cce9 100644 --- a/tests/integration/test_comprehension.py +++ b/tests/integration/test_comprehension.py @@ -144,10 +144,6 @@ def test_linear_discard(validate): module = GuppyModule("test") module.load(quantum) - @guppy.declare(module) - def discard(q: Qubit) -> None: - ... - @guppy(module) def test(qs: linst[Qubit]) -> list[None]: return [discard(q) for q in qs] diff --git a/tests/integration/test_import.py b/tests/integration/test_import.py new file mode 100644 index 00000000..a2881356 --- /dev/null +++ b/tests/integration/test_import.py @@ -0,0 +1,99 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from tests.integration.modules.mod_a import mod_a, MyType, f, g +from tests.integration.modules.mod_b import mod_b +from tests.integration.modules.mod_c import mod_c + + +def test_import_func(validate): + module = GuppyModule("test") + module.import_(mod_a, "f") + module.import_(mod_a, "g") + + @guppy(module) + def test(x: int) -> int: + return f(x) + g() + + validate(module.compile()) + + +def test_import_type(validate): + module = GuppyModule("test") + module.import_(mod_a, "MyType") + + @guppy(module) + def test(x: MyType) -> MyType: + return -x + + validate(module.compile()) + + +def test_func_alias(validate): + module = GuppyModule("test") + module.import_(mod_a, "f", alias="g") + + @guppy(module) + def test(x: int) -> int: + return g(x) + + validate(module.compile()) + + +def test_type_alias(validate): + module = GuppyModule("test") + module.import_(mod_a, "MyType", alias="MyType_Alias") + + @guppy(module) + def test(x: "MyType_Alias") -> "MyType_Alias": + return -x + + validate(module.compile()) + + +def test_conflict_alias(validate): + module = GuppyModule("test") + module.import_(mod_a, "f") + module.import_(mod_b, "f", alias="f_b") + + @guppy(module) + def test(x: int, y: bool) -> tuple[int, bool]: + return f(x), f_b(y) + + validate(module.compile()) + + +def test_conflict_alias_type(validate): + module = GuppyModule("test") + module.import_(mod_a, "MyType") + module.import_(mod_b, "MyType", alias="MyTypeB") + + @guppy(module) + def test(x: MyType, y: "MyTypeB") -> tuple[MyType, "MyTypeB"]: + return -x, +y + + validate(module.compile()) + + +def test_type_transitive(validate): + module = GuppyModule("test") + module.import_(mod_c, "g") # `g` returns a type that was defined in `mod_a` + + @guppy(module) + def test() -> int: + x = g() + return int(-x) # Use instance method that was defined in `mod_a` + + validate(module.compile()) + + +def test_type_transitive_conflict(validate): + module = GuppyModule("test") + module.import_(mod_b, "MyType") + module.import_(mod_c, "g") + + @guppy(module) + def test(ty_b: MyType) -> MyType: + ty_a = g() + return +ty_b + + validate(module.compile()) diff --git a/tests/integration/test_linear.py b/tests/integration/test_linear.py index 197257b7..1302853a 100644 --- a/tests/integration/test_linear.py +++ b/tests/integration/test_linear.py @@ -262,10 +262,6 @@ class MyType: def __iter__(self: "MyType") -> MyIter: ... - @guppy.declare(module) - def measure(q: Qubit) -> bool: - ... - @guppy(module) def test(mt: MyType, xs: list[int]) -> None: # We can break, since `mt` itself is not linear