Skip to content

Commit

Permalink
feat: Improve import system
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Feb 2, 2024
1 parent 8c9e4d2 commit 9f972a3
Show file tree
Hide file tree
Showing 26 changed files with 521 additions and 83 deletions.
44 changes: 37 additions & 7 deletions guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple

from guppylang.ast_util import AstNode, name_nodes_in_ast
from guppylang.gtypes import (
Expand All @@ -19,6 +19,9 @@
TupleType,
)

if TYPE_CHECKING:
from guppylang.module import GuppyModule


@dataclass
class Variable:
Expand All @@ -31,11 +34,28 @@ class Variable:


@dataclass
class CallableVariable(ABC, Variable):
class GlobalVariable(Variable):
"""Class holding data associated with a module-level variable."""

module: "GuppyModule | None"

@property
def qualname(self) -> str:
"""The qualified name of this global variable."""
return f"{self.module.name}.{self.name}" if self.module else self.name


@dataclass
class CallableVariable(ABC, GlobalVariable):
"""Abstract base class for global variables that can be called."""

ty: FunctionType

@property
def is_method(self) -> bool:
"""Returns whether this variable is an instance method."""
return "." in self.name

@abstractmethod
def check_call(
self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context"
Expand Down Expand Up @@ -67,7 +87,7 @@ class Globals(NamedTuple):
constants), to types, or to instance functions belonging to types.
"""

values: dict[str, Variable]
values: dict[str, GlobalVariable]
types: dict[str, type[GuppyType]]
type_vars: dict[str, TypeVarDecl]
python_scope: PyScope
Expand All @@ -91,7 +111,7 @@ def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None
Returns `None` if the name doesn't exist or isn't a function.
"""
qualname = qualified_name(ty.__class__, name)
qualname = qualified_instance_name(ty.__class__, name)
if qualname in self.values:
val = self.values[qualname]
if isinstance(val, CallableVariable):
Expand Down Expand Up @@ -203,7 +223,17 @@ def __contains__(self, key: object) -> bool:
return super().__contains__(key)


def qualified_name(ty: type[GuppyType] | str, name: str) -> str:
def qualified_name(module: "GuppyModule | None", name: str) -> str:
"""Returns a name qualified by a module."""
module_name = module.name if module else "builtins"
return f"{module_name}.{name}"


def instance_name(ty: type[GuppyType], name: str) -> str:
"""Returns a name for an instance function on a type."""
return f"{ty.name}.{name}"


def qualified_instance_name(ty: type[GuppyType], name: str) -> str:
"""Returns a qualified name for an instance function on a type."""
ty_name = ty if isinstance(ty, str) else ty.name
return f"{ty_name}.{name}"
return qualified_name(ty.module, instance_name(ty, name))
22 changes: 16 additions & 6 deletions guppylang/checker/func_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@

import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING

from guppylang.ast_util import AstNode, return_nodes_in_ast, with_loc
from guppylang.cfg.bb import BB
from guppylang.cfg.builder import CFGBuilder
from guppylang.checker.cfg_checker import CheckedCFG, check_cfg
from guppylang.checker.core import CallableVariable, Context, Globals, Variable
from guppylang.checker.core import (
CallableVariable,
Context,
Globals,
GlobalVariable,
Variable,
)
from guppylang.checker.expr_checker import check_call, synthesize_call
from guppylang.error import GuppyError
from guppylang.gtypes import (
Expand All @@ -25,24 +32,27 @@
)
from guppylang.nodes import CheckedNestedFunctionDef, GlobalCall, NestedFunctionDef

if TYPE_CHECKING:
from guppylang.module import GuppyModule


@dataclass
class DefinedFunction(CallableVariable):
class DefinedFunction(CallableVariable, GlobalVariable):
"""A user-defined function"""

ty: FunctionType
defined_at: ast.FunctionDef

@staticmethod
def from_ast(
func_def: ast.FunctionDef, name: str, globals: Globals
func_def: ast.FunctionDef, name: str, module: "GuppyModule", globals: Globals
) -> "DefinedFunction":
ty = check_signature(func_def, globals)
if ty.quantified:
raise GuppyError(
"Generic function definitions are not supported yet", func_def
)
return DefinedFunction(name, ty, func_def, None)
return DefinedFunction(name, ty, func_def, None, module)

def check_call(
self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context
Expand Down Expand Up @@ -79,7 +89,7 @@ def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFun
for x, ty, loc in zip(func.ty.arg_names, func.ty.args, args)
]
cfg = check_cfg(cfg, inputs, func.ty.returns, globals)
return CheckedFunction(func_def.name, func.ty, func_def, None, cfg)
return CheckedFunction(func.name, func.ty, func_def, None, func.module, cfg)


def check_nested_func_def(
Expand Down Expand Up @@ -141,7 +151,7 @@ def check_nested_func_def(
if func_def.name in cfg.live_before[cfg.entry_bb]:
if not captured:
# If there are no captured vars, we treat the function like a global name
func = DefinedFunction(func_def.name, func_ty, func_def, None)
func = DefinedFunction(func_def.name, func_ty, func_def, None, None)
globals = ctx.globals | Globals({func_def.name: func}, {}, {}, {})

else:
Expand Down
4 changes: 2 additions & 2 deletions guppylang/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass

from guppylang.ast_util import AstNode
from guppylang.checker.core import CallableVariable, Variable
from guppylang.checker.core import CallableVariable, GlobalVariable, Variable
from guppylang.gtypes import FunctionType, Inst
from guppylang.hugr.hugr import DFContainingNode, Hugr, OutPortV

Expand All @@ -28,7 +28,7 @@ def __init__(
object.__setattr__(self, "port", port)


class CompiledVariable(ABC, Variable):
class CompiledVariable(ABC, GlobalVariable):
"""Abstract base class for compiled global module-level variables."""

@abstractmethod
Expand Down
5 changes: 3 additions & 2 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def visit_LocalName(self, node: LocalName) -> OutPortV:
return self.dfg[node.id].port

def visit_GlobalName(self, node: GlobalName) -> OutPortV:
return self.globals[node.id].load(self.dfg, self.graph, self.globals, node)
name = node.value.qualname
return self.globals[name].load(self.dfg, self.graph, self.globals, node)

def visit_Name(self, node: ast.Name) -> OutPortV:
raise InternalGuppyError("Node should have been removed during type checking.")
Expand Down Expand Up @@ -179,7 +180,7 @@ def visit_LocalCall(self, node: LocalCall) -> OutPortV:
return self._pack_returns(rets)

def visit_GlobalCall(self, node: GlobalCall) -> OutPortV:
func = self.globals[node.func.name]
func = self.globals[node.func.qualname]
assert isinstance(func, CompiledFunction)

args = [self.visit(arg) for arg in node.args]
Expand Down
8 changes: 6 additions & 2 deletions guppylang/compiler/func_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def compile_global_func_def(
parent=def_node,
)

return CompiledFunctionDef(func.name, func.ty, func.defined_at, None, def_node)
return CompiledFunctionDef(
func.name, func.ty, func.defined_at, None, func.module, def_node
)


def compile_local_func_def(
Expand Down Expand Up @@ -98,7 +100,9 @@ def compile_local_func_def(
else:
# Otherwise, we treat the function like a normal global variable
globals = globals | {
func.name: CompiledFunctionDef(func.name, func.ty, func, None, def_node)
func.name: CompiledFunctionDef(
func.name, func.ty, func, None, None, def_node
)
}

# Compile the CFG
Expand Down
6 changes: 6 additions & 0 deletions guppylang/custom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from guppylang.ast_util import AstNode, get_type, with_loc, with_type
from guppylang.checker.core import Context, Globals
Expand All @@ -16,6 +17,9 @@
from guppylang.hugr.hugr import DFContainingVNode, Hugr, Node, OutPortV
from guppylang.nodes import GlobalCall

if TYPE_CHECKING:
from guppylang.module import GuppyModule


class CustomFunction(CompiledFunction):
"""A function whose type checking and compilation behaviour can be customised."""
Expand All @@ -35,13 +39,15 @@ class CustomFunction(CompiledFunction):
def __init__(
self,
name: str,
module: "GuppyModule",
defined_at: ast.FunctionDef | None,
compiler: "CustomCallCompiler",
checker: "CustomCallChecker",
higher_order_value: bool = True,
ty: FunctionType | None = None,
):
self.name = name
self.module = module
self.defined_at = defined_at
self.higher_order_value = higher_order_value
self.call_compiler = compiler
Expand Down
8 changes: 6 additions & 2 deletions guppylang/declared.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING

from guppylang.ast_util import AstNode, has_empty_body, with_loc
from guppylang.checker.core import Context, Globals
Expand All @@ -11,6 +12,9 @@
from guppylang.hugr.hugr import Hugr, Node, OutPortV, VNode
from guppylang.nodes import GlobalCall

if TYPE_CHECKING:
from guppylang.module import GuppyModule


@dataclass
class DeclaredFunction(CompiledFunction):
Expand All @@ -20,14 +24,14 @@ class DeclaredFunction(CompiledFunction):

@staticmethod
def from_ast(
func_def: ast.FunctionDef, name: str, globals: Globals
func_def: ast.FunctionDef, name: str, module: "GuppyModule", globals: Globals
) -> "DeclaredFunction":
ty = check_signature(func_def, globals)
if not has_empty_body(func_def):
raise GuppyError(
"Body of function declaration must be empty", func_def.body[0]
)
return DeclaredFunction(name, ty, func_def, None)
return DeclaredFunction(name, ty, func_def, None, module)

def check_call(
self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context
Expand Down
20 changes: 17 additions & 3 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
from typing import Any, ClassVar, TypeVar
from typing import Any, ClassVar, Protocol, TypeVar

from guppylang.ast_util import AstNode, has_empty_body
from guppylang.custom import (
Expand All @@ -26,6 +26,12 @@
ClassDecorator = Callable[[type], type]


class ClassWithGuppyType(Protocol):
"""Mypy protocol for a class that is annotated with a Guppy type."""

_guppy_type: type[GuppyType]


@dataclass(frozen=True)
class ModuleIdentifier:
"""Identifier for the Python file/module that called the decorator."""
Expand Down Expand Up @@ -113,12 +119,17 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier:
return ModuleIdentifier(Path(filename), module)

@pretty_errors
def extend_type(self, module: GuppyModule, ty: type[GuppyType]) -> ClassDecorator:
def extend_type(
self, module: GuppyModule, ty: type[GuppyType] | ClassWithGuppyType
) -> ClassDecorator:
"""Decorator to add new instance functions to a type."""
module._instance_func_buffer = {}
guppy_ty = (
ty if isinstance(ty, type) and issubclass(ty, GuppyType) else ty._guppy_type
)

def dec(c: type) -> type:
module._register_buffered_instance_funcs(ty)
module._register_buffered_instance_funcs(guppy_ty)
return c

return dec
Expand All @@ -142,11 +153,13 @@ def type(

def dec(c: type) -> type:
_name = name or c.__name__
_module = module

@dataclass(frozen=True)
class NewType(GuppyType):
args: Sequence[GuppyType]
name: ClassVar[str] = _name
module: ClassVar[GuppyModule | None] = _module

@staticmethod
def build(*args: GuppyType, node: AstNode | None = None) -> "GuppyType":
Expand Down Expand Up @@ -220,6 +233,7 @@ def dec(f: PyFunc) -> CustomFunction:
call_checker = checker or DefaultCallChecker()
func = CustomFunction(
name or func_ast.name,
module,
func_ast,
compiler or DefaultCallCompiler(),
call_checker,
Expand Down
6 changes: 6 additions & 0 deletions guppylang/gtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

if TYPE_CHECKING:
from guppylang.checker.core import Globals
from guppylang.module import GuppyModule


Subst = dict["ExistentialTypeVar", "GuppyType"]
Expand All @@ -28,6 +29,7 @@ class GuppyType(ABC):
"""

name: ClassVar[str]
module: ClassVar["GuppyModule | None"] = None

# Cache for free variables
_unsolved_vars: set["ExistentialTypeVar"] = field(init=False, repr=False)
Expand All @@ -51,6 +53,10 @@ def __post_init__(self) -> None:
vs |= arg.unsolved_vars
object.__setattr__(self, "_unsolved_vars", vs)

@classmethod
def qualname(cls) -> str:
return f"{cls.module.name}.{cls.name}" if cls.module else cls.name

@staticmethod
@abstractmethod
def build(*args: "GuppyType", node: AstNode | None = None) -> "GuppyType":
Expand Down
Loading

0 comments on commit 9f972a3

Please sign in to comment.