From 73e29f25ec90b8dfcc6517b961d6d1d13f694cb6 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Tue, 19 Mar 2024 17:26:09 +0000 Subject: [PATCH] feat: New type representation with parameters (#174) * Prepare support for types and functions that are generic over constant values (e.g. bounded nats): - Generic types and functions are now defined in terms of `Parameter`s and `Argument`s that can be either types or constants (see `tys/param.py` and `tys/arg.py`). - Implementations for `ConstParam` and `ConstArg` will follow in the future * Improved pretty printing of types (see `tys/printing.py`) * Add a notion of `TypeDefinition` (see `tys/definition.py`) that replaces the ad-hoc creation of Python classes to define new types * `BoolType` is no longer a `SumType`. This was a hugr implementation detail and not relevant for the Guppy type system. Drive by renaming: * Rename function type members from `args/return` to `inputs/output` since `args` is already used now * Rename `GuppyType` to `Type` --- guppylang/ast_util.py | 12 +- guppylang/cfg/builder.py | 4 +- guppylang/checker/cfg_checker.py | 12 +- guppylang/checker/core.py | 80 ++- guppylang/checker/expr_checker.py | 272 +++++---- guppylang/checker/func_checker.py | 63 +- guppylang/checker/stmt_checker.py | 14 +- guppylang/compiler/cfg_compiler.py | 9 +- guppylang/compiler/core.py | 3 +- guppylang/compiler/expr_compiler.py | 39 +- guppylang/compiler/func_compiler.py | 19 +- guppylang/compiler/stmt_compiler.py | 2 +- guppylang/custom.py | 30 +- guppylang/declared.py | 9 +- guppylang/decorator.py | 57 +- guppylang/error.py | 54 -- guppylang/gtypes.py | 654 --------------------- guppylang/hugr/hugr.py | 46 +- guppylang/hugr/ops.py | 1 - guppylang/module.py | 37 +- guppylang/nodes.py | 9 +- guppylang/prelude/_internal.py | 41 +- guppylang/prelude/builtins.py | 8 +- guppylang/tys/__init__.py | 0 guppylang/tys/arg.py | 87 +++ guppylang/tys/common.py | 60 ++ guppylang/tys/const.py | 62 ++ guppylang/tys/definition.py | 210 +++++++ guppylang/tys/param.py | 171 ++++++ guppylang/tys/parsing.py | 116 ++++ guppylang/tys/printing.py | 126 ++++ guppylang/tys/subst.py | 64 ++ guppylang/tys/ty.py | 516 ++++++++++++++++ guppylang/tys/var.py | 49 ++ tests/error/poly_errors/non_linear2.err | 2 +- tests/error/poly_errors/pass_poly_free.err | 2 +- tests/hugr/test_dummy_nodes.py | 17 +- tests/hugr/test_ports.py | 14 - 38 files changed, 1863 insertions(+), 1108 deletions(-) delete mode 100644 guppylang/gtypes.py create mode 100644 guppylang/tys/__init__.py create mode 100644 guppylang/tys/arg.py create mode 100644 guppylang/tys/common.py create mode 100644 guppylang/tys/const.py create mode 100644 guppylang/tys/definition.py create mode 100644 guppylang/tys/param.py create mode 100644 guppylang/tys/parsing.py create mode 100644 guppylang/tys/printing.py create mode 100644 guppylang/tys/subst.py create mode 100644 guppylang/tys/ty.py create mode 100644 guppylang/tys/var.py delete mode 100644 tests/hugr/test_ports.py diff --git a/guppylang/ast_util.py b/guppylang/ast_util.py index 6c435f8e..67dd9d40 100644 --- a/guppylang/ast_util.py +++ b/guppylang/ast_util.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, cast if TYPE_CHECKING: - from guppylang.gtypes import GuppyType + from guppylang.tys.ty import Type AstNode = ( ast.AST @@ -286,24 +286,24 @@ def with_loc(loc: ast.AST, node: A) -> A: return node -def with_type(ty: "GuppyType", node: A) -> A: +def with_type(ty: "Type", node: A) -> A: """Annotates an AST node with a type.""" node.type = ty # type: ignore[attr-defined] return node -def get_type_opt(node: AstNode) -> Optional["GuppyType"]: +def get_type_opt(node: AstNode) -> Optional["Type"]: """Tries to retrieve a type annotation from an AST node.""" - from guppylang.gtypes import GuppyType + from guppylang.tys.ty import Type, TypeBase try: ty = node.type # type: ignore[union-attr] - return ty if isinstance(ty, GuppyType) else None + return cast(Type, ty) if isinstance(ty, TypeBase) else None except AttributeError: return None -def get_type(node: AstNode) -> "GuppyType": +def get_type(node: AstNode) -> "Type": """Retrieve a type annotation from an AST node. Fails if the node is not annotated. diff --git a/guppylang/cfg/builder.py b/guppylang/cfg/builder.py index 1a2490da..60fc13c1 100644 --- a/guppylang/cfg/builder.py +++ b/guppylang/cfg/builder.py @@ -15,7 +15,6 @@ from guppylang.cfg.cfg import CFG from guppylang.checker.core import Globals from guppylang.error import GuppyError, InternalGuppyError -from guppylang.gtypes import NoneType from guppylang.nodes import ( DesugaredGenerator, DesugaredListComp, @@ -26,6 +25,7 @@ NestedFunctionDef, PyExpr, ) +from guppylang.tys.ty import NoneType # In order to build expressions, need an endless stream of unique temporary variables # to store intermediate results @@ -213,7 +213,7 @@ def visit_FunctionDef( from guppylang.checker.func_checker import check_signature func_ty = check_signature(node, self.globals) - returns_none = isinstance(func_ty.returns, NoneType) + returns_none = isinstance(func_ty.output, NoneType) cfg = CFGBuilder().build(node.body, returns_none, self.globals) new_node = NestedFunctionDef( diff --git a/guppylang/checker/cfg_checker.py b/guppylang/checker/cfg_checker.py index d4f81084..ab3b7a3b 100644 --- a/guppylang/checker/cfg_checker.py +++ b/guppylang/checker/cfg_checker.py @@ -16,7 +16,7 @@ from guppylang.checker.expr_checker import ExprSynthesizer, to_bool from guppylang.checker.stmt_checker import StmtChecker from guppylang.error import GuppyError -from guppylang.gtypes import GuppyType +from guppylang.tys.ty import Type VarRow = Sequence[Variable] @@ -44,17 +44,17 @@ class CheckedBB(BB): class CheckedCFG(BaseCFG[CheckedBB]): - input_tys: list[GuppyType] - output_ty: GuppyType + input_tys: list[Type] + output_ty: Type - def __init__(self, input_tys: list[GuppyType], output_ty: GuppyType) -> None: + def __init__(self, input_tys: list[Type], output_ty: Type) -> None: super().__init__([]) self.input_tys = input_tys self.output_ty = output_ty def check_cfg( - cfg: CFG, inputs: VarRow, return_ty: GuppyType, globals: Globals + cfg: CFG, inputs: VarRow, return_ty: Type, globals: Globals ) -> CheckedCFG: """Type checks a control-flow graph. @@ -121,7 +121,7 @@ def check_bb( bb: BB, checked_cfg: CheckedCFG, inputs: VarRow, - return_ty: GuppyType, + return_ty: Type, globals: Globals, ) -> CheckedBB: cfg = bb.containing_cfg diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index ac5a4288..76349d10 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -6,17 +6,29 @@ from dataclasses import dataclass from typing import Any, NamedTuple +from typing_extensions import assert_never + from guppylang.ast_util import AstNode, name_nodes_in_ast -from guppylang.gtypes import ( - BoolType, +from guppylang.tys.definition import ( + TypeDef, + bool_type_def, + callable_type_def, + linst_type_def, + list_type_def, + none_type_def, + tuple_type_def, +) +from guppylang.tys.param import Parameter +from guppylang.tys.subst import Subst +from guppylang.tys.ty import ( + BoundTypeVar, + ExistentialTypeVar, FunctionType, - GuppyType, - LinstType, - ListType, NoneType, - Subst, + OpaqueType, SumType, TupleType, + Type, ) @@ -25,7 +37,7 @@ class Variable: """Class holding data associated with a variable.""" name: str - ty: GuppyType + ty: Type defined_at: AstNode | None used: AstNode | None @@ -38,14 +50,14 @@ class CallableVariable(ABC, Variable): @abstractmethod def check_call( - self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: "Context" + self, args: list[ast.expr], ty: Type, node: AstNode, ctx: "Context" ) -> tuple[ast.expr, Subst]: """Checks the return type of a function call against a given type.""" @abstractmethod def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: "Context" - ) -> tuple[ast.expr, GuppyType]: + ) -> tuple[ast.expr, Type]: """Synthesizes the return type of a function call.""" @@ -68,30 +80,44 @@ class Globals(NamedTuple): """ values: dict[str, Variable] - types: dict[str, type[GuppyType]] - type_vars: dict[str, TypeVarDecl] + type_defs: dict[str, TypeDef] + param_vars: dict[str, Parameter] python_scope: PyScope @staticmethod def default() -> "Globals": """Generates a `Globals` instance that is populated with all core types""" - tys: dict[str, type[GuppyType]] = { - FunctionType.name: FunctionType, - TupleType.name: TupleType, - SumType.name: SumType, - NoneType.name: NoneType, - BoolType.name: BoolType, - ListType.name: ListType, - LinstType.name: LinstType, + type_defs = { + "Callable": callable_type_def, + "tuple": tuple_type_def, + "None": none_type_def, + "bool": bool_type_def, + "list": list_type_def, + "linst": linst_type_def, } - return Globals({}, tys, {}, {}) + return Globals({}, type_defs, {}, {}) - def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None: + def get_instance_func(self, ty: Type, name: str) -> CallableVariable | None: """Looks up an instance function with a given name for a type. Returns `None` if the name doesn't exist or isn't a function. """ - qualname = qualified_name(ty.__class__, name) + defn: TypeDef + match ty: + case BoundTypeVar() | ExistentialTypeVar() | SumType(): + return None + case FunctionType(): + defn = callable_type_def + case OpaqueType() as ty: + defn = ty.defn + case TupleType(): + defn = tuple_type_def + case NoneType(): + defn = none_type_def + case _: + assert_never(ty) + + qualname = qualified_name(defn.name, name) if qualname in self.values: val = self.values[qualname] if isinstance(val, CallableVariable): @@ -101,15 +127,15 @@ def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None def __or__(self, other: "Globals") -> "Globals": return Globals( self.values | other.values, - self.types | other.types, - self.type_vars | other.type_vars, + self.type_defs | other.type_defs, + self.param_vars | other.param_vars, self.python_scope | other.python_scope, ) def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034 self.values.update(other.values) - self.types.update(other.types) - self.type_vars.update(other.type_vars) + self.type_defs.update(other.type_defs) + self.param_vars.update(other.param_vars) return self @@ -203,7 +229,7 @@ def __contains__(self, key: object) -> bool: return super().__contains__(key) -def qualified_name(ty: type[GuppyType] | str, name: str) -> str: +def qualified_name(ty: TypeDef | str, name: str) -> str: """Returns a qualified name for an instance function on a type.""" ty_name = ty if isinstance(ty, str) else ty.name return f"{ty_name}.{name}" diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index d5cc069d..5bce35a8 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -49,20 +49,6 @@ GuppyTypeInferenceError, InternalGuppyError, ) -from guppylang.gtypes import ( - BoolType, - ExistentialTypeVar, - FunctionType, - GuppyType, - Inst, - LinstType, - ListType, - NoneType, - Subst, - TupleType, - row_to_type, - unify, -) from guppylang.nodes import ( DesugaredGenerator, DesugaredListComp, @@ -76,6 +62,29 @@ PyExpr, TypeApply, ) +from guppylang.tys.arg import TypeArg +from guppylang.tys.definition import ( + bool_type, + get_element_type, + is_bool_type, + is_linst_type, + is_list_type, + linst_type, + list_type, +) +from guppylang.tys.param import TypeParam +from guppylang.tys.subst import Inst, Subst +from guppylang.tys.ty import ( + ExistentialTypeVar, + FunctionType, + NoneType, + OpaqueType, + TupleType, + Type, + TypeBase, + row_to_type, + unify, +) # Mapping from unary AST op to dunder method and display name unary_table: dict[type[ast.unaryop], tuple[str, str]] = { @@ -128,12 +137,12 @@ def __init__(self, ctx: Context) -> None: def _fail( self, - expected: GuppyType, - actual: ast.expr | GuppyType, + expected: Type, + actual: ast.expr | Type, loc: AstNode | None = None, ) -> NoReturn: """Raises a type error indicating that the type doesn't match.""" - if not isinstance(actual, GuppyType): + if not isinstance(actual, TypeBase): loc = loc or actual _, actual = self._synthesize(actual, allow_free_vars=True) if loc is None: @@ -143,7 +152,7 @@ def _fail( ) def check( - self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + self, expr: ast.expr, ty: Type, kind: str = "expression" ) -> tuple[ast.expr, Subst]: """Checks an expression against a type. @@ -173,11 +182,11 @@ def check( def _synthesize( self, node: ast.expr, allow_free_vars: bool - ) -> tuple[ast.expr, GuppyType]: + ) -> tuple[ast.expr, Type]: """Invokes the type synthesiser""" return ExprSynthesizer(self.ctx).synthesize(node, allow_free_vars) - def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> tuple[ast.expr, Subst]: + def visit_Tuple(self, node: ast.Tuple, ty: Type) -> tuple[ast.expr, Subst]: if not isinstance(ty, TupleType) or len(ty.element_types) != len(node.elts): return self._fail(ty, node) subst: Subst = {} @@ -186,28 +195,29 @@ def visit_Tuple(self, node: ast.Tuple, ty: GuppyType) -> tuple[ast.expr, Subst]: subst |= s return node, subst - def visit_List(self, node: ast.List, ty: GuppyType) -> tuple[ast.expr, Subst]: - if not isinstance(ty, ListType | LinstType): + def visit_List(self, node: ast.List, ty: Type) -> tuple[ast.expr, Subst]: + if not is_list_type(ty) and not is_linst_type(ty): return self._fail(ty, node) + el_ty = get_element_type(ty) subst: Subst = {} for i, el in enumerate(node.elts): - node.elts[i], s = self.check(el, ty.element_type.substitute(subst)) + node.elts[i], s = self.check(el, el_ty.substitute(subst)) subst |= s return node, subst def visit_DesugaredListComp( - self, node: DesugaredListComp, ty: GuppyType + self, node: DesugaredListComp, ty: Type ) -> tuple[ast.expr, Subst]: - if not isinstance(ty, ListType | LinstType): + if not is_list_type(ty) and not is_linst_type(ty): return self._fail(ty, node) node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) - subst = unify(ty.element_type, elt_ty, {}) + subst = unify(get_element_type(ty), elt_ty, {}) if subst is None: - actual = LinstType(elt_ty) if elt_ty.linear else ListType(elt_ty) + actual = linst_type(elt_ty) if elt_ty.linear else list_type(elt_ty) return self._fail(ty, actual, node) return node, subst - def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: + def visit_Call(self, node: ast.Call, ty: Type) -> tuple[ast.expr, Subst]: if len(node.keywords) > 0: raise GuppyError( "Argument passing by keyword is not supported", node.keywords[0] @@ -231,7 +241,7 @@ def visit_Call(self, node: ast.Call, ty: GuppyType) -> tuple[ast.expr, Subst]: else: raise GuppyTypeError(f"Expected function type, got `{func_ty}`", node.func) - def visit_PyExpr(self, node: PyExpr, ty: GuppyType) -> tuple[ast.expr, Subst]: + def visit_PyExpr(self, node: PyExpr, ty: Type) -> tuple[ast.expr, Subst]: python_val = eval_py_expr(node, self.ctx) if act := python_value_to_guppy_type(python_val, node, self.ctx.globals): subst = unify(ty, act, {}) @@ -246,19 +256,19 @@ def visit_PyExpr(self, node: PyExpr, ty: GuppyType) -> tuple[ast.expr, Subst]: node, ) - def generic_visit(self, node: ast.expr, ty: GuppyType) -> tuple[ast.expr, Subst]: + def generic_visit(self, node: ast.expr, ty: Type) -> tuple[ast.expr, Subst]: # Try to synthesize and then check if we can unify it with the given type node, synth = self._synthesize(node, allow_free_vars=False) subst, inst = check_type_against(synth, ty, node, self._kind) # Apply instantiation of quantified type variables if inst: - node = with_loc(node, TypeApply(value=node, tys=inst)) + node = with_loc(node, TypeApply(value=node, inst=inst)) return node, subst -class ExprSynthesizer(AstVisitor[tuple[ast.expr, GuppyType]]): +class ExprSynthesizer(AstVisitor[tuple[ast.expr, Type]]): ctx: Context def __init__(self, ctx: Context) -> None: @@ -266,7 +276,7 @@ def __init__(self, ctx: Context) -> None: def synthesize( self, node: ast.expr, allow_free_vars: bool = False - ) -> tuple[ast.expr, GuppyType]: + ) -> tuple[ast.expr, Type]: """Tries to synthesise a type for the given expression. Also returns a new desugared expression with type annotations. @@ -281,18 +291,18 @@ def synthesize( return with_type(ty, node), ty def _check( - self, expr: ast.expr, ty: GuppyType, kind: str = "expression" + self, expr: ast.expr, ty: Type, kind: str = "expression" ) -> tuple[ast.expr, Subst]: """Checks an expression against a given type""" return ExprChecker(self.ctx).check(expr, ty, kind) - def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, GuppyType]: + def visit_Constant(self, node: ast.Constant) -> tuple[ast.expr, Type]: ty = python_value_to_guppy_type(node.value, node, self.ctx.globals) if ty is None: raise GuppyError("Unsupported constant", node) return node, ty - def visit_Name(self, node: ast.Name) -> tuple[ast.Name, GuppyType]: + def visit_Name(self, node: ast.Name) -> tuple[ast.Name, Type]: x = node.id if x in self.ctx.locals: var = self.ctx.locals[x] @@ -314,28 +324,26 @@ def visit_Name(self, node: ast.Name) -> tuple[ast.Name, GuppyType]: f"been caught by program analysis!" ) - def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, GuppyType]: + def visit_Tuple(self, node: ast.Tuple) -> tuple[ast.expr, Type]: elems = [self.synthesize(elem) for elem in node.elts] node.elts = [n for n, _ in elems] return node, TupleType([ty for _, ty in elems]) - def visit_List(self, node: ast.List) -> tuple[ast.expr, GuppyType]: + def visit_List(self, node: ast.List) -> tuple[ast.expr, Type]: if len(node.elts) == 0: raise GuppyTypeInferenceError( "Cannot infer type variable in expression of type `list[?T]`", node ) node.elts[0], el_ty = self.synthesize(node.elts[0]) node.elts[1:] = [self._check(el, el_ty)[0] for el in node.elts[1:]] - return node, LinstType(el_ty) if el_ty.linear else ListType(el_ty) + return node, linst_type(el_ty) if el_ty.linear else list_type(el_ty) - def visit_DesugaredListComp( - self, node: DesugaredListComp - ) -> tuple[ast.expr, GuppyType]: + def visit_DesugaredListComp(self, node: DesugaredListComp) -> tuple[ast.expr, Type]: node, elt_ty = synthesize_comprehension(node, node.generators, self.ctx) - result_ty = LinstType(elt_ty) if elt_ty.linear else ListType(elt_ty) + result_ty = linst_type(elt_ty) if elt_ty.linear else list_type(elt_ty) return node, result_ty - def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, GuppyType]: + def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, Type]: # We need to synthesise the argument type, so we can look up dunder methods node.operand, op_ty = self.synthesize(node.operand) @@ -358,7 +366,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> tuple[ast.expr, GuppyType]: def _synthesize_binary( self, left_expr: ast.expr, right_expr: ast.expr, op: AstOp, node: ast.expr - ) -> tuple[ast.expr, GuppyType]: + ) -> tuple[ast.expr, Type]: """Helper method to compile binary operators by calling out to dunder methods. For example, first try calling `__add__` on the left operand. If that fails, try @@ -392,7 +400,7 @@ def _synthesize_instance_func( err: str, exp_sig: FunctionType | None = None, give_reason: bool = False, - ) -> tuple[ast.expr, GuppyType]: + ) -> tuple[ast.expr, Type]: """Helper method for expressions that are implemented via instance methods. Raises a `GuppyTypeError` if the given instance method is not defined. The error @@ -412,16 +420,16 @@ def _synthesize_instance_func( ) if exp_sig and unify(exp_sig, func.ty.unquantified()[0], {}) is None: raise GuppyError( - f"Method `{ty.name}.{func_name}` has signature `{func.ty}`, but " + f"Method `{ty}.{func_name}` has signature `{func.ty}`, but " f"expected `{exp_sig}`", node, ) return func.synthesize_call([node, *args], node, self.ctx) - def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, GuppyType]: + def visit_BinOp(self, node: ast.BinOp) -> tuple[ast.expr, Type]: return self._synthesize_binary(node.left, node.right, node.op, node) - def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: + def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, Type]: if len(node.comparators) != 1 or len(node.ops) != 1: raise InternalGuppyError( "BB contains chained comparison. Should have been removed during CFG " @@ -430,17 +438,17 @@ def visit_Compare(self, node: ast.Compare) -> tuple[ast.expr, GuppyType]: left_expr, [op], [right_expr] = node.left, node.ops, node.comparators return self._synthesize_binary(left_expr, right_expr, op, node) - def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, GuppyType]: + def visit_Subscript(self, node: ast.Subscript) -> tuple[ast.expr, Type]: node.value, ty = self.synthesize(node.value) exp_sig = FunctionType( - [ty, ExistentialTypeVar.new("Key", False)], - ExistentialTypeVar.new("Val", False), + [ty, ExistentialTypeVar.fresh("Key", False)], + ExistentialTypeVar.fresh("Val", False), ) return self._synthesize_instance_func( node.value, [node.slice], "__getitem__", "not subscriptable", exp_sig ) - def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: + def visit_Call(self, node: ast.Call) -> tuple[ast.expr, Type]: if len(node.keywords) > 0: raise GuppyError("Keyword arguments are not supported", node.keywords[0]) node.func, ty = self.synthesize(node.func) @@ -461,9 +469,9 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: else: raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) - def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]: + def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, Type]: node.value, ty = self.synthesize(node.value) - exp_sig = FunctionType([ty], ExistentialTypeVar.new("Iter", False)) + exp_sig = FunctionType([ty], ExistentialTypeVar.fresh("Iter", False)) expr, ty = self._synthesize_instance_func( node.value, [], "__iter__", "not iterable", exp_sig ) @@ -483,36 +491,36 @@ def visit_MakeIter(self, node: MakeIter) -> tuple[ast.expr, GuppyType]: ) return expr, ty - def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, GuppyType]: + def visit_IterHasNext(self, node: IterHasNext) -> tuple[ast.expr, Type]: node.value, ty = self.synthesize(node.value) - exp_sig = FunctionType([ty], TupleType([BoolType(), ty])) + exp_sig = FunctionType([ty], TupleType([bool_type(), ty])) return self._synthesize_instance_func( node.value, [], "__hasnext__", "not an iterator", exp_sig, True ) - def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, GuppyType]: + def visit_IterNext(self, node: IterNext) -> tuple[ast.expr, Type]: node.value, ty = self.synthesize(node.value) exp_sig = FunctionType( - [ty], TupleType([ExistentialTypeVar.new("T", False), ty]) + [ty], TupleType([ExistentialTypeVar.fresh("T", False), ty]) ) return self._synthesize_instance_func( node.value, [], "__next__", "not an iterator", exp_sig, True ) - def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, GuppyType]: + def visit_IterEnd(self, node: IterEnd) -> tuple[ast.expr, Type]: node.value, ty = self.synthesize(node.value) exp_sig = FunctionType([ty], NoneType()) return self._synthesize_instance_func( node.value, [], "__end__", "not an iterator", exp_sig, True ) - def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, GuppyType]: + def visit_ListComp(self, node: ast.ListComp) -> tuple[ast.expr, Type]: raise InternalGuppyError( "BB contains `ListComp`. Should have been removed during CFG" f"construction: `{ast.unparse(node)}`" ) - def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]: + def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, Type]: python_val = eval_py_expr(node, self.ctx) if ty := python_value_to_guppy_type(python_val, node, self.ctx.globals): return with_loc(node, ast.Constant(value=python_val)), ty @@ -522,19 +530,19 @@ def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]: node, ) - def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]: + def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, Type]: raise InternalGuppyError( "BB contains `NamedExpr`. Should have been removed during CFG" f"construction: `{ast.unparse(node)}`" ) - def visit_BoolOp(self, node: ast.BoolOp) -> tuple[ast.expr, GuppyType]: + def visit_BoolOp(self, node: ast.BoolOp) -> tuple[ast.expr, Type]: raise InternalGuppyError( "BB contains `BoolOp`. Should have been removed during CFG construction: " f"`{ast.unparse(node)}`" ) - def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, GuppyType]: + def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, Type]: raise InternalGuppyError( "BB contains `IfExp`. Should have been removed during CFG construction: " f"`{ast.unparse(node)}`" @@ -542,42 +550,42 @@ def visit_IfExp(self, node: ast.IfExp) -> tuple[ast.expr, GuppyType]: def check_type_against( - act: GuppyType, exp: GuppyType, node: AstNode, kind: str = "expression" + act: Type, exp: Type, node: AstNode, kind: str = "expression" ) -> tuple[Subst, Inst]: """Checks a type against another type. Returns a substitution for the free variables the expected type and an instantiation - for the quantified variables in the actual type. Note that the expected type may not - be quantified and the actual type may not contain free unification variables. + for the parameters in the actual type. Note that the expected type may not be + parametrised and the actual type may not contain free unification variables. """ - assert not isinstance(exp, FunctionType) or not exp.quantified + assert not isinstance(exp, FunctionType) or not exp.parametrized assert not act.unsolved_vars - # The actual type may be quantified. In that case, we have to find an instantiation - # to avoid higher-rank types. + # The actual type may be parametrised. In that case, we have to find an + # instantiation to avoid higher-rank types. subst: Subst | None - if isinstance(act, FunctionType) and act.quantified: + if isinstance(act, FunctionType) and act.parametrized: unquantified, free_vars = act.unquantified() subst = unify(exp, unquantified, {}) if subst is None: raise GuppyTypeError(f"Expected {kind} of type `{exp}`, got `{act}`", node) - # Check that we have found a valid instantiation for all quantified vars + # Check that we have found a valid instantiation for all params for i, v in enumerate(free_vars): if v not in subst: raise GuppyTypeInferenceError( f"Expected {kind} of type `{exp}`, got `{act}`. Couldn't infer an " - f"instantiation for type variable `{act.quantified[i]}` (higher-" - "rank polymorphic types are not supported)", + f"instantiation for parameter `{act.params[i].name}` (higher-rank " + "polymorphic types are not supported)", node, ) if subst[v].unsolved_vars: raise GuppyTypeError( f"Expected {kind} of type `{exp}`, got `{act}`. Can't instantiate " - f"type variable `{act.quantified[i]}` with type `{subst[v]}` " - "containing free variables", + f"parameter `{act.params[i]}` with type `{subst[v]}` containing " + "free variables", node, ) - inst = [subst[v] for v in free_vars] + inst = [TypeArg(subst[v]) for v in free_vars] subst = {v: t for v, t in subst.items() if v in exp.unsolved_vars} # Finally, check that the instantiation respects the linearity requirements @@ -608,7 +616,7 @@ def check_num_args(exp: int, act: int, node: AstNode) -> None: def type_check_args( - args: list[ast.expr], + inputs: list[ast.expr], func_ty: FunctionType, subst: Subst, ctx: Context, @@ -616,27 +624,27 @@ def type_check_args( ) -> tuple[list[ast.expr], Subst]: """Checks the arguments of a function call and infers free type variables. - We expect that quantified variables have been replaced with free unification - variables. Checks that all unification variables can be inferred. + We expect that parameters have been replaced with free unification variables. + Checks that all unification variables can be inferred. """ - assert not func_ty.quantified - check_num_args(len(func_ty.args), len(args), node) + assert not func_ty.parametrized + check_num_args(len(func_ty.inputs), len(inputs), node) new_args: list[ast.expr] = [] - for arg, ty in zip(args, func_ty.args): - a, s = ExprChecker(ctx).check(arg, ty.substitute(subst), "argument") + for inp, ty in zip(inputs, func_ty.inputs): + a, s = ExprChecker(ctx).check(inp, ty.substitute(subst), "argument") new_args.append(a) subst |= s # If the argument check succeeded, this means that we must have found instantiations - # for all unification variables occurring in the argument types - assert all(set.issubset(arg.unsolved_vars, subst.keys()) for arg in func_ty.args) + # for all unification variables occurring in the input types + assert all(set.issubset(inp.unsolved_vars, subst.keys()) for inp in func_ty.inputs) # We also have to check that we found instantiations for all vars in the return type - if not set.issubset(func_ty.returns.unsolved_vars, subst.keys()): + if not set.issubset(func_ty.output.unsolved_vars, subst.keys()): raise GuppyTypeInferenceError( f"Cannot infer type variable in expression of type " - f"`{func_ty.returns.substitute(subst)}`", + f"`{func_ty.output.substitute(subst)}`", node, ) @@ -645,14 +653,14 @@ def type_check_args( def synthesize_call( func_ty: FunctionType, args: list[ast.expr], node: AstNode, ctx: Context -) -> tuple[list[ast.expr], GuppyType, Inst]: +) -> tuple[list[ast.expr], Type, Inst]: """Synthesizes the return type of a function call. Returns an annotated argument list, the synthesized return type, and an instantiation for the quantifiers in the function type. """ assert not func_ty.unsolved_vars - check_num_args(len(func_ty.args), len(args), node) + check_num_args(len(func_ty.inputs), len(args), node) # Replace quantified variables with free unification variables and try to infer an # instantiation by checking the arguments @@ -661,18 +669,18 @@ def synthesize_call( # Success implies that the substitution is closed assert all(not t.unsolved_vars for t in subst.values()) - inst = [subst[v] for v in free_vars] + inst = [TypeArg(subst[v]) for v in free_vars] # Finally, check that the instantiation respects the linearity requirements check_inst(func_ty, inst, node) - return args, unquantified.returns.substitute(subst), inst + return args, unquantified.output.substitute(subst), inst def check_call( func_ty: FunctionType, - args: list[ast.expr], - ty: GuppyType, + inputs: list[ast.expr], + ty: Type, node: AstNode, ctx: Context, kind: str = "expression", @@ -683,7 +691,7 @@ def check_call( expected type, and an instantiation for the quantifiers in the function type. """ assert not func_ty.unsolved_vars - check_num_args(len(func_ty.args), len(args), node) + check_num_args(len(func_ty.inputs), len(inputs), node) # When checking, we can use the information from the expected return type to infer # some type arguments. However, this pushes errors inwards. For example, given a @@ -705,9 +713,9 @@ def check_call( # in practice. Can we do better than that? # First, try to synthesize - res: tuple[GuppyType, Inst] | None = None + res: tuple[Type, Inst] | None = None try: - args, synth, inst = synthesize_call(func_ty, args, node, ctx) + inputs, synth, inst = synthesize_call(func_ty, inputs, node, ctx) res = synth, inst except GuppyTypeInferenceError: pass @@ -716,38 +724,38 @@ def check_call( subst = unify(ty, synth, {}) if subst is None: raise GuppyTypeError(f"Expected {kind} of type `{ty}`, got `{synth}`", node) - return args, subst, inst + return inputs, subst, inst # If synthesis fails, we try again, this time also using information from the # expected return type unquantified, free_vars = func_ty.unquantified() - subst = unify(ty, unquantified.returns, {}) + subst = unify(ty, unquantified.output, {}) if subst is None: raise GuppyTypeError( - f"Expected {kind} of type `{ty}`, got `{unquantified.returns}`", node + f"Expected {kind} of type `{ty}`, got `{unquantified.output}`", node ) # Try to infer more by checking against the arguments - args, subst = type_check_args(args, unquantified, subst, ctx, node) + inputs, subst = type_check_args(inputs, unquantified, subst, ctx, node) # Also make sure we found an instantiation for all free vars in the type we're # checking against if not set.issubset(ty.unsolved_vars, subst.keys()): raise GuppyTypeInferenceError( f"Expected expression of type `{ty}`, got " - f"`{func_ty.returns.substitute(subst)}`. Couldn't infer type variables", + f"`{func_ty.output.substitute(subst)}`. Couldn't infer type variables", node, ) # Success implies that the substitution is closed assert all(not t.unsolved_vars for t in subst.values()) - inst = [subst[v] for v in free_vars] + inst = [TypeArg(subst[v]) for v in free_vars] subst = {v: t for v, t in subst.items() if v in ty.unsolved_vars} # Finally, check that the instantiation respects the linearity requirements check_inst(func_ty, inst, node) - return args, subst, inst + return inputs, subst, inst def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None: @@ -755,29 +763,35 @@ def check_inst(func_ty: FunctionType, inst: Inst, node: AstNode) -> None: Makes sure that the linearity requirements are satisfied. """ - for var, ty in zip(func_ty.quantified, inst): - if not var.linear and ty.linear: + for param, arg in zip(func_ty.params, inst, strict=True): + # Give a more informative error message for linearity issues + if ( + isinstance(param, TypeParam) + and isinstance(arg, TypeArg) + and arg.ty.linear + and not param.can_be_linear + ): raise GuppyTypeError( - f"Cannot instantiate non-linear type variable `{var}` in type " - f"`{func_ty}` with linear type `{ty}`", + f"Cannot instantiate non-linear type variable `{param.name}` in type " + f"`{func_ty}` with linear type `{arg.ty}`", node, ) + # For everything else, we fall back to the default checking implementation + param.check_arg(arg, node) def instantiate_poly(node: ast.expr, ty: FunctionType, inst: Inst) -> ast.expr: """Instantiates quantified type arguments in a function.""" - assert len(ty.quantified) == len(inst) + assert len(ty.params) == len(inst) if len(inst) > 0: - node = with_loc(node, TypeApply(value=with_type(ty, node), tys=inst)) + node = with_loc(node, TypeApply(value=with_type(ty, node), inst=inst)) return with_type(ty.instantiate(inst), node) return node -def to_bool( - node: ast.expr, node_ty: GuppyType, ctx: Context -) -> tuple[ast.expr, GuppyType]: +def to_bool(node: ast.expr, node_ty: Type, ctx: Context) -> tuple[ast.expr, Type]: """Tries to turn a node into a bool""" - if isinstance(node_ty, BoolType): + if is_bool_type(node_ty): return node, node_ty func = ctx.globals.get_instance_func(node_ty, "__bool__") @@ -790,7 +804,7 @@ def to_bool( # We could check the return type against bool, but we can give a better error # message if we synthesise and compare to bool by hand call, return_ty = func.synthesize_call([node], node, ctx) - if not isinstance(return_ty, BoolType): + if not is_bool_type(return_ty): raise GuppyTypeError( f"`__bool__` on type `{node_ty}` returns `{return_ty}` instead of `bool`", node, @@ -800,7 +814,7 @@ def to_bool( def synthesize_comprehension( node: DesugaredListComp, gens: list[DesugaredGenerator], ctx: Context -) -> tuple[DesugaredListComp, GuppyType]: +) -> tuple[DesugaredListComp, Type]: """Helper function to synthesise the element type of a list comprehension.""" from guppylang.checker.stmt_checker import StmtChecker @@ -919,25 +933,23 @@ def eval_py_expr(node: PyExpr, ctx: Context) -> Any: return python_val -def python_value_to_guppy_type( - v: Any, node: ast.expr, globals: Globals -) -> GuppyType | None: +def python_value_to_guppy_type(v: Any, node: ast.expr, globals: Globals) -> Type | None: """Turns a primitive Python value into a Guppy type. Returns `None` if the Python value cannot be represented in Guppy. """ match v: case bool(): - return globals.types["bool"].build(node=node) + return globals.type_defs["bool"].check_instantiate([]) case int(): - return globals.types["int"].build(node=node) + return globals.type_defs["int"].check_instantiate([]) case float(): - return globals.types["float"].build(node=node) + return globals.type_defs["float"].check_instantiate([]) case tuple(elts): tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts] if any(ty is None for ty in tys): return None - return TupleType(cast(list[GuppyType], tys)) + return TupleType(cast(list[Type], tys)) case list(): return _python_list_to_guppy_type(v, node, globals) case _: @@ -950,10 +962,12 @@ def python_value_to_guppy_type( try: import tket2 # type: ignore[import-untyped, import-not-found, unused-ignore] # noqa: F401 - qubit = globals.types["qubit"].build() + qubit = globals.type_defs["qubit"].check_instantiate([]) return FunctionType( [qubit] * v.n_qubits, - row_to_type([qubit] * v.n_qubits + [BoolType()] * v.n_bits), + row_to_type( + [qubit] * v.n_qubits + [bool_type()] * v.n_bits + ), ) except ImportError: raise GuppyError( @@ -968,14 +982,14 @@ def python_value_to_guppy_type( def _python_list_to_guppy_type( vs: list[Any], node: ast.expr, globals: Globals -) -> ListType | None: +) -> OpaqueType | None: """Turns a Python list into a Guppy type. Returns `None` if the list contains different types or types that are not representable in Guppy. """ if len(vs) == 0: - return ListType(ExistentialTypeVar.new("T", False)) + return list_type(ExistentialTypeVar.fresh("T", False)) # All the list elements must have a unifiable types v, *rest = vs @@ -989,4 +1003,4 @@ def _python_list_to_guppy_type( if (subst := unify(ty, el_ty, {})) is None: raise GuppyError("Python list contains elements with different types", node) el_ty = el_ty.substitute(subst) - return ListType(el_ty) + return list_type(el_ty) diff --git a/guppylang/checker/func_checker.py b/guppylang/checker/func_checker.py index f224bd99..246d0627 100644 --- a/guppylang/checker/func_checker.py +++ b/guppylang/checker/func_checker.py @@ -7,6 +7,7 @@ import ast from dataclasses import dataclass +from typing import TYPE_CHECKING from guppylang.ast_util import AstNode, return_nodes_in_ast, with_loc from guppylang.cfg.bb import BB @@ -15,15 +16,13 @@ from guppylang.checker.core import CallableVariable, Context, Globals, Variable from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.error import GuppyError -from guppylang.gtypes import ( - BoundTypeVar, - FunctionType, - GuppyType, - NoneType, - Subst, - type_from_ast, -) from guppylang.nodes import CheckedNestedFunctionDef, GlobalCall, NestedFunctionDef +from guppylang.tys.parsing import type_from_ast +from guppylang.tys.subst import Subst +from guppylang.tys.ty import FunctionType, NoneType, Type + +if TYPE_CHECKING: + from guppylang.tys.param import Parameter @dataclass @@ -38,14 +37,14 @@ def from_ast( func_def: ast.FunctionDef, name: str, globals: Globals ) -> "DefinedFunction": ty = check_signature(func_def, globals) - if ty.quantified: + if ty.parametrized: raise GuppyError( "Generic function definitions are not supported yet", func_def ) return DefinedFunction(name, ty, func_def, None) def check_call( - self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker args, subst, inst = check_call(self.ty, args, ty, node, ctx) @@ -53,7 +52,7 @@ def check_call( def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context - ) -> tuple[GlobalCall, GuppyType]: + ) -> tuple[GlobalCall, Type]: # Use default implementation from the expression checker args, ty, inst = synthesize_call(self.ty, args, node, ctx) return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), ty @@ -70,15 +69,15 @@ def check_global_func_def(func: DefinedFunction, globals: Globals) -> CheckedFun """Type checks a top-level function definition.""" func_def = func.defined_at args = func_def.args.args - returns_none = isinstance(func.ty.returns, NoneType) - assert func.ty.arg_names is not None + returns_none = isinstance(func.ty.output, NoneType) + assert func.ty.input_names is not None cfg = CFGBuilder().build(func_def.body, returns_none, globals) inputs = [ Variable(x, ty, loc, None) - for x, ty, loc in zip(func.ty.arg_names, func.ty.args, args) + for x, ty, loc in zip(func.ty.input_names, func.ty.inputs, args) ] - cfg = check_cfg(cfg, inputs, func.ty.returns, globals) + cfg = check_cfg(cfg, inputs, func.ty.output, globals) return CheckedFunction(func_def.name, func.ty, func_def, None, cfg) @@ -87,7 +86,7 @@ def check_nested_func_def( ) -> CheckedNestedFunctionDef: """Type checks a local (nested) function definition.""" func_ty = check_signature(func_def, ctx.globals) - assert func_ty.arg_names is not None + assert func_ty.input_names is not None # We've already built the CFG for this function while building the CFG of the # enclosing function @@ -95,13 +94,13 @@ def check_nested_func_def( # Find captured variables parent_cfg = bb.containing_cfg - def_ass_before = set(func_ty.arg_names) | ctx.locals.keys() + def_ass_before = set(func_ty.input_names) | ctx.locals.keys() maybe_ass_before = def_ass_before | parent_cfg.maybe_ass_before[bb] cfg.analyze(def_ass_before, maybe_ass_before) captured = { x: ctx.locals[x] for x in cfg.live_before[cfg.entry_bb] - if x not in func_ty.arg_names and x in ctx.locals + if x not in func_ty.input_names and x in ctx.locals } # Captured variables may not be linear @@ -131,7 +130,7 @@ def check_nested_func_def( # Construct inputs for checking the body CFG inputs = list(captured.values()) + [ Variable(x, ty, func_def.args.args[i], None) - for i, (x, ty) in enumerate(zip(func_ty.arg_names, func_ty.args)) + for i, (x, ty) in enumerate(zip(func_ty.input_names, func_ty.inputs)) ] globals = ctx.globals @@ -148,7 +147,7 @@ def check_nested_func_def( # Otherwise, we treat it like a local name inputs.append(Variable(func_def.name, func_def.ty, func_def, None)) - checked_cfg = check_cfg(cfg, inputs, func_ty.returns, globals) + checked_cfg = check_cfg(cfg, inputs, func_ty.output, globals) checked_def = CheckedNestedFunctionDef( checked_cfg, func_ty, @@ -188,20 +187,20 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType raise GuppyError("Return type must be annotated", func_def) # TODO: Prepopulate mapping when using Python 3.12 style generic functions - type_var_mapping: dict[str, BoundTypeVar] = {} - arg_tys = [] - arg_names = [] - for _i, arg in enumerate(func_def.args.args): - if arg.annotation is None: - raise GuppyError("Argument type must be annotated", arg) - ty = type_from_ast(arg.annotation, globals, type_var_mapping) - arg_tys.append(ty) - arg_names.append(arg.arg) - + type_var_mapping: dict[str, "Parameter"] = {} + input_tys = [] + input_names = [] + for inp in func_def.args.args: + if inp.annotation is None: + raise GuppyError("Argument type must be annotated", inp) + ty = type_from_ast(inp.annotation, globals, type_var_mapping) + input_tys.append(ty) + input_names.append(inp.arg) ret_type = type_from_ast(func_def.returns, globals, type_var_mapping) + return FunctionType( - arg_tys, + input_tys, ret_type, - arg_names, + input_names, sorted(type_var_mapping.values(), key=lambda v: v.idx), ) diff --git a/guppylang/checker/stmt_checker.py b/guppylang/checker/stmt_checker.py index 686003fb..272814c5 100644 --- a/guppylang/checker/stmt_checker.py +++ b/guppylang/checker/stmt_checker.py @@ -16,17 +16,19 @@ from guppylang.checker.core import Context, Variable from guppylang.checker.expr_checker import ExprChecker, ExprSynthesizer from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppylang.gtypes import GuppyType, NoneType, Subst, TupleType, type_from_ast from guppylang.nodes import NestedFunctionDef +from guppylang.tys.parsing import type_from_ast +from guppylang.tys.subst import Subst +from guppylang.tys.ty import NoneType, TupleType, Type class StmtChecker(AstVisitor[BBStatement]): ctx: Context bb: BB | None - return_ty: GuppyType | None + return_ty: Type | None def __init__( - self, ctx: Context, bb: BB | None = None, return_ty: GuppyType | None = None + self, ctx: Context, bb: BB | None = None, return_ty: Type | None = None ) -> None: assert not return_ty or not return_ty.unsolved_vars self.ctx = ctx @@ -36,15 +38,15 @@ def __init__( def check_stmts(self, stmts: Sequence[BBStatement]) -> list[BBStatement]: return [self.visit(s) for s in stmts] - def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, GuppyType]: + def _synth_expr(self, node: ast.expr) -> tuple[ast.expr, Type]: return ExprSynthesizer(self.ctx).synthesize(node) def _check_expr( - self, node: ast.expr, ty: GuppyType, kind: str = "expression" + self, node: ast.expr, ty: Type, kind: str = "expression" ) -> tuple[ast.expr, Subst]: return ExprChecker(self.ctx).check(node, ty, kind) - def _check_assign(self, lhs: ast.expr, ty: GuppyType, node: ast.stmt) -> None: + def _check_assign(self, lhs: ast.expr, ty: Type, node: ast.stmt) -> None: """Helper function to check assignments with patterns.""" match lhs: # Easiest case is if the LHS pattern is a single variable. diff --git a/guppylang/compiler/cfg_compiler.py b/guppylang/compiler/cfg_compiler.py index 2af09078..4ef9a170 100644 --- a/guppylang/compiler/cfg_compiler.py +++ b/guppylang/compiler/cfg_compiler.py @@ -12,8 +12,9 @@ ) from guppylang.compiler.expr_compiler import ExprCompiler from guppylang.compiler.stmt_compiler import StmtCompiler -from guppylang.gtypes import SumType, TupleType, type_to_row from guppylang.hugr.hugr import CFNode, Hugr, Node, OutPortV +from guppylang.tys.definition import is_bool_type +from guppylang.tys.ty import SumType, TupleType, type_to_row if TYPE_CHECKING: from collections.abc import Sequence @@ -135,8 +136,10 @@ def choose_vars_for_tuple_sum( Given `unit_sum: Sum((), (), ...)` and output variable sets `#s1, #s2, ...`, constructs a TupleSum value of type `Sum(Tuple(#s1), Tuple(#s2), ...)`. """ - assert isinstance(unit_sum.ty, SumType) - assert len(unit_sum.ty.element_types) == len(output_vars) + assert isinstance(unit_sum.ty, SumType) or is_bool_type(unit_sum.ty) + assert len(output_vars) == ( + len(unit_sum.ty.element_types) if isinstance(unit_sum.ty, SumType) else 2 + ) tuples = [ graph.add_make_tuple( inputs=[dfg[v.name].port for v in sort_vars(vs) if v.name in dfg], diff --git a/guppylang/compiler/core.py b/guppylang/compiler/core.py index 9e9026a8..f1fc942a 100644 --- a/guppylang/compiler/core.py +++ b/guppylang/compiler/core.py @@ -4,8 +4,9 @@ from guppylang.ast_util import AstNode from guppylang.checker.core import CallableVariable, Variable -from guppylang.gtypes import FunctionType, Inst from guppylang.hugr.hugr import DFContainingNode, Hugr, OutPortV +from guppylang.tys.subst import Inst +from guppylang.tys.ty import FunctionType @dataclass diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 8d05fbed..8c8f64e0 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -13,17 +13,6 @@ PortVariable, ) from guppylang.error import GuppyError, InternalGuppyError -from guppylang.gtypes import ( - BoolType, - BoundTypeVar, - FunctionType, - GuppyType, - Inst, - ListType, - NoneType, - TupleType, - type_to_row, -) from guppylang.hugr import ops, val from guppylang.hugr.hugr import DFContainingNode, OutPortV, VNode from guppylang.nodes import ( @@ -35,6 +24,16 @@ LocalName, TypeApply, ) +from guppylang.tys.definition import bool_type, get_element_type, is_list_type +from guppylang.tys.subst import Inst +from guppylang.tys.ty import ( + BoundTypeVar, + FunctionType, + NoneType, + TupleType, + Type, + type_to_row, +) class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): @@ -175,7 +174,7 @@ def visit_LocalCall(self, node: LocalCall) -> OutPortV: args = [self.visit(arg) for arg in node.args] call = self.graph.add_indirect_call(func, args) - rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.returns)))] + rets = [call.out_port(i) for i in range(len(type_to_row(func.ty.output)))] return self._pack_returns(rets) def visit_GlobalCall(self, node: GlobalCall) -> OutPortV: @@ -194,7 +193,7 @@ def visit_Call(self, node: ast.Call) -> OutPortV: def visit_TypeApply(self, node: TypeApply) -> OutPortV: func = self.visit(node.value) assert isinstance(func.ty, FunctionType) - ta = self.graph.add_type_apply(func, node.tys, self.dfg.node).out_port(0) + ta = self.graph.add_type_apply(func, node.inst, self.dfg.node).out_port(0) # We have to be very careful here: If we instantiate `foo: forall T. T -> T` # with a tuple type `tuple[A, B]`, we get the type `tuple[A, B] -> tuple[A, B]`. @@ -203,7 +202,7 @@ def visit_TypeApply(self, node: TypeApply) -> OutPortV: # function with a single output port typed `tuple[A, B]`. # TODO: We would need to do manual monomorphisation in that case to obtain a # function that returns two ports as expected - if instantiation_needs_unpacking(func.ty, node.tys): + if instantiation_needs_unpacking(func.ty, node.inst): raise GuppyError( "Generic function instantiations returning rows are not supported yet", node, @@ -218,7 +217,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> OutPortV: arg = self.visit(node.operand) return self.graph.add_node( ops.CustomOp(extension="logic", op_name="Not", args=[]), inputs=[arg] - ).add_out_port(BoolType()) + ).add_out_port(bool_type()) raise InternalGuppyError("Node should have been removed during type checking.") @@ -291,13 +290,13 @@ def expr_to_row(expr: ast.expr) -> list[ast.expr]: def instantiation_needs_unpacking(func_ty: FunctionType, inst: Inst) -> bool: """Checks if instantiating a polymorphic makes it return a row.""" - if isinstance(func_ty.returns, BoundTypeVar): - return_ty = inst[func_ty.returns.idx] + if isinstance(func_ty.output, BoundTypeVar): + return_ty = inst[func_ty.output.idx] return isinstance(return_ty, TupleType | NoneType) return False -def python_value_to_hugr(v: Any, exp_ty: GuppyType) -> val.Value | None: +def python_value_to_hugr(v: Any, exp_ty: Type) -> val.Value | None: """Turns a Python value into a Hugr value. Returns None if the Python value cannot be represented in Guppy. @@ -326,9 +325,9 @@ def python_value_to_hugr(v: Any, exp_ty: GuppyType) -> val.Value | None: return None return val.Tuple(vs=vs) case list(elts): - assert isinstance(exp_ty, ListType) + assert is_list_type(exp_ty) return list_value( - [python_value_to_hugr(elt, exp_ty.element_type) for elt in elts] + [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts] ) case _: # Pytket conversion is an optional feature diff --git a/guppylang/compiler/func_compiler.py b/guppylang/compiler/func_compiler.py index e67b900b..6dfe53ad 100644 --- a/guppylang/compiler/func_compiler.py +++ b/guppylang/compiler/func_compiler.py @@ -9,9 +9,10 @@ DFContainer, PortVariable, ) -from guppylang.gtypes import FunctionType, Inst, type_to_row from guppylang.hugr.hugr import DFContainingVNode, Hugr, OutPortV from guppylang.nodes import CheckedNestedFunctionDef +from guppylang.tys.subst import Inst +from guppylang.tys.ty import FunctionType, type_to_row @dataclass @@ -40,7 +41,7 @@ def compile_call( call = graph.add_indirect_call(func.out_port(0), args, dfg.node) else: call = graph.add_call(self.node.out_port(0), args, dfg.node) - return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] + return [call.out_port(i) for i in range(len(type_to_row(self.ty.output)))] def compile_global_func_def( @@ -50,7 +51,7 @@ def compile_global_func_def( globals: CompiledGlobals, ) -> CompiledFunctionDef: """Compiles a top-level function definition to Hugr.""" - _, ports = graph.add_input_with_ports(list(func.ty.args), def_node) + _, ports = graph.add_input_with_ports(list(func.ty.inputs), def_node) cfg_node = graph.add_cfg(def_node, ports) compile_cfg(func.cfg, graph, cfg_node, globals) @@ -70,20 +71,22 @@ def compile_local_func_def( globals: CompiledGlobals, ) -> PortVariable: """Compiles a local (nested) function definition to Hugr.""" - assert func.ty.arg_names is not None + assert func.ty.input_names is not None # Pick an order for the captured variables captured = list(func.captured.values()) # Prepend captured variables to the function arguments closure_ty = FunctionType( - [v.ty for v in captured] + list(func.ty.args), - func.ty.returns, - [v.name for v in captured] + list(func.ty.arg_names), + [v.ty for v in captured] + list(func.ty.inputs), + func.ty.output, + [v.name for v in captured] + list(func.ty.input_names), ) def_node = graph.add_def(closure_ty, dfg.node, func.name) - def_input, input_ports = graph.add_input_with_ports(list(closure_ty.args), def_node) + def_input, input_ports = graph.add_input_with_ports( + list(closure_ty.inputs), def_node + ) # If we have captured variables and the body contains a recursive occurrence of # the function itself, then we provide the partially applied function as a local diff --git a/guppylang/compiler/stmt_compiler.py b/guppylang/compiler/stmt_compiler.py index 85e2ae53..f5890746 100644 --- a/guppylang/compiler/stmt_compiler.py +++ b/guppylang/compiler/stmt_compiler.py @@ -11,9 +11,9 @@ ) from guppylang.compiler.expr_compiler import ExprCompiler from guppylang.error import InternalGuppyError -from guppylang.gtypes import TupleType from guppylang.hugr.hugr import Hugr, OutPortV from guppylang.nodes import CheckedNestedFunctionDef +from guppylang.tys.ty import TupleType class StmtCompiler(CompilerBase, AstVisitor[None]): diff --git a/guppylang/custom.py b/guppylang/custom.py index e1142448..c7085d63 100644 --- a/guppylang/custom.py +++ b/guppylang/custom.py @@ -6,15 +6,12 @@ from guppylang.checker.expr_checker import check_call, synthesize_call from guppylang.checker.func_checker import check_signature from guppylang.compiler.core import CompiledFunction, CompiledGlobals, DFContainer -from guppylang.error import ( - GuppyError, - InternalGuppyError, - UnknownFunctionType, -) -from guppylang.gtypes import FunctionType, GuppyType, Inst, Subst, type_to_row +from guppylang.error import GuppyError, InternalGuppyError from guppylang.hugr import ops from guppylang.hugr.hugr import DFContainingVNode, Hugr, Node, OutPortV from guppylang.nodes import GlobalCall +from guppylang.tys.subst import Inst, Subst +from guppylang.tys.ty import FunctionType, NoneType, Type, type_to_row class CustomFunction(CompiledFunction): @@ -53,7 +50,10 @@ def __init__( @property # type: ignore[override] def ty(self) -> FunctionType: if self._ty is None: - return UnknownFunctionType() + # If we don't have a specified type, then the extension writer has to + # provide their own type-checking code. Therefore, it doesn't matter which + # type we return here since it will never be inspected. + return FunctionType([], NoneType()) return self._ty @ty.setter @@ -83,7 +83,7 @@ def check_type(self, globals: Globals) -> None: raise def check_call( - self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context ) -> tuple[ast.expr, Subst]: self.call_checker._setup(ctx, node, self) new_node, subst = self.call_checker.check(args, ty) @@ -91,7 +91,7 @@ def check_call( def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: "Context" - ) -> tuple[ast.expr, GuppyType]: + ) -> tuple[ast.expr, Type]: self.call_checker._setup(ctx, node, self) new_node, ty = self.call_checker.synthesize(args) return with_type(ty, with_loc(node, new_node)), ty @@ -123,7 +123,7 @@ def load( node, ) - if self._ty.quantified: + if self._ty.parametrized: raise InternalGuppyError( "Can't yet generate higher-order versions of custom functions. This " "requires generic function *definitions*" @@ -143,7 +143,7 @@ def load( # to the function, and returns the results. if module not in self._defined: def_node = graph.add_def(self.ty, module, self.name) - _, inp_ports = graph.add_input_with_ports(list(self.ty.args), def_node) + _, inp_ports = graph.add_input_with_ports(list(self.ty.inputs), def_node) returns = self.compile_call( inp_ports, [], DFContainer(def_node, {}), graph, globals, node ) @@ -169,14 +169,14 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: self.func = func @abstractmethod - def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: """Checks the return value against a given type. Returns a (possibly) transformed and annotated AST node for the call. """ @abstractmethod - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: """Synthesizes a type for the return value of a call. Also returns a (possibly) transformed and annotated argument list. @@ -214,12 +214,12 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: class DefaultCallChecker(CustomCallChecker): """Checks function calls by comparing to a type signature.""" - def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker args, subst, inst = check_call(self.func.ty, args, ty, self.node, self.ctx) return GlobalCall(func=self.func, args=args, type_args=inst), subst - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: # Use default implementation from the expression checker args, ty, inst = synthesize_call(self.func.ty, args, self.node, self.ctx) return GlobalCall(func=self.func, args=args, type_args=inst), ty diff --git a/guppylang/declared.py b/guppylang/declared.py index 7fc5923e..5a8baf3a 100644 --- a/guppylang/declared.py +++ b/guppylang/declared.py @@ -7,9 +7,10 @@ from guppylang.checker.func_checker import check_signature from guppylang.compiler.core import CompiledFunction, CompiledGlobals, DFContainer from guppylang.error import GuppyError -from guppylang.gtypes import GuppyType, Inst, Subst, type_to_row from guppylang.hugr.hugr import Hugr, Node, OutPortV, VNode from guppylang.nodes import GlobalCall +from guppylang.tys.subst import Inst, Subst +from guppylang.tys.ty import Type, type_to_row @dataclass @@ -30,7 +31,7 @@ def from_ast( return DeclaredFunction(name, ty, func_def, None) def check_call( - self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context + self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context ) -> tuple[ast.expr, Subst]: # Use default implementation from the expression checker args, subst, inst = check_call(self.ty, args, ty, node, ctx) @@ -38,7 +39,7 @@ def check_call( def synthesize_call( self, args: list[ast.expr], node: AstNode, ctx: Context - ) -> tuple[GlobalCall, GuppyType]: + ) -> tuple[GlobalCall, Type]: # Use default implementation from the expression checker args, ty, inst = synthesize_call(self.ty, args, node, ctx) return with_loc(node, GlobalCall(func=self, args=args, type_args=inst)), ty @@ -70,4 +71,4 @@ def compile_call( call = graph.add_indirect_call(func.out_port(0), args, dfg.node) else: call = graph.add_call(self.node.out_port(0), args, dfg.node) - return [call.out_port(i) for i in range(len(type_to_row(self.ty.returns)))] + return [call.out_port(i) for i in range(len(type_to_row(self.ty.output)))] diff --git a/guppylang/decorator.py b/guppylang/decorator.py index cb1836a6..3d37a4c6 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -1,12 +1,12 @@ import functools import inspect -from collections.abc import Callable, Iterator, Sequence +from collections.abc import Callable from dataclasses import dataclass from pathlib import Path from types import ModuleType -from typing import Any, ClassVar, TypeVar +from typing import Any, TypeVar -from guppylang.ast_util import AstNode, has_empty_body +from guppylang.ast_util import has_empty_body from guppylang.custom import ( CustomCallChecker, CustomCallCompiler, @@ -16,10 +16,10 @@ OpCompiler, ) from guppylang.error import GuppyError, MissingModuleError, pretty_errors -from guppylang.gtypes import GuppyType, TypeTransformer from guppylang.hugr import ops, tys from guppylang.hugr.hugr import Hugr from guppylang.module import GuppyModule, PyFunc, parse_py_func +from guppylang.tys.definition import OpaqueTypeDef, TypeDef FuncDecorator = Callable[[PyFunc], PyFunc | Hugr] CustomFuncDecorator = Callable[[PyFunc], CustomFunction] @@ -113,12 +113,12 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier: return ModuleIdentifier(Path(filename), module) @pretty_errors - def extend_type(self, module: GuppyModule, ty: type[GuppyType]) -> ClassDecorator: + def extend_type(self, module: GuppyModule, defn: TypeDef) -> ClassDecorator: """Decorator to add new instance functions to a type.""" module._instance_func_buffer = {} def dec(c: type) -> type: - module._register_buffered_instance_funcs(ty) + module._register_buffered_instance_funcs(defn) return c return dec @@ -142,48 +142,9 @@ def type( def dec(c: type) -> type: _name = name or c.__name__ - - @dataclass(frozen=True) - class NewType(GuppyType): - args: Sequence[GuppyType] - name: ClassVar[str] = _name - - @staticmethod - def build(*args: GuppyType, node: AstNode | None = None) -> "GuppyType": - # At the moment, custom types don't support type arguments. - if len(args) > 0: - raise GuppyError( - f"Type `{_name}` does not accept type parameters.", node - ) - return NewType([]) - - @property - def type_args(self) -> Iterator[GuppyType]: - return iter(self.args) - - @property - def linear(self) -> bool: - return linear - - def to_hugr(self) -> tys.Type: - return hugr_ty - - def hugr_bound(self) -> tys.TypeBound: - return bound or super().hugr_bound() - - def transform(self, transformer: TypeTransformer) -> GuppyType: - return transformer.transform(self) or NewType( - [ty.transform(transformer) for ty in self.args] - ) - - def __str__(self) -> str: - return _name - - NewType.__name__ = name - NewType.__qualname__ = _name - module.register_type(_name, NewType) - module._register_buffered_instance_funcs(NewType) - c._guppy_type = NewType # type: ignore[attr-defined] + defn = OpaqueTypeDef(_name, [], linear, lambda _: hugr_ty, bound) + module.register_type(_name, defn) + module._register_buffered_instance_funcs(defn) return c return dec diff --git a/guppylang/error.py b/guppylang/error.py index 51ad044b..7bb1b455 100644 --- a/guppylang/error.py +++ b/guppylang/error.py @@ -10,8 +10,6 @@ from typing import Any, TypeVar, cast from guppylang.ast_util import AstNode, get_file, get_line_offset, get_source -from guppylang.gtypes import BoundTypeVar, ExistentialTypeVar, FunctionType, GuppyType -from guppylang.hugr.hugr import Node, OutPortV @dataclass(frozen=True) @@ -84,58 +82,6 @@ class InternalGuppyError(Exception): """Exception for internal problems during compilation.""" -class UndefinedPort(OutPortV): - """Dummy port for undefined variables. - - Raises an `InternalGuppyError` if one tries to access one of its properties. - """ - - def __init__(self, ty: GuppyType): - self._ty = ty - - @property - def ty(self) -> GuppyType: - return self._ty - - @property - def node(self) -> Node: - raise InternalGuppyError("Tried to access undefined Port") - - @property - def offset(self) -> int: - raise InternalGuppyError("Tried to access undefined Port") - - -class UnknownFunctionType(FunctionType): - """Dummy function type for custom functions without an expressible type. - - Raises an `InternalGuppyError` if one tries to access one of its members. - """ - - def __init__(self) -> None: - pass - - @property - def args(self) -> Sequence[GuppyType]: - raise InternalGuppyError("Tried to access unknown function type") - - @property - def returns(self) -> GuppyType: - raise InternalGuppyError("Tried to access unknown function type") - - @property - def args_names(self) -> Sequence[str] | None: - raise InternalGuppyError("Tried to access unknown function type") - - @property - def quantified(self) -> Sequence[BoundTypeVar]: - raise InternalGuppyError("Tried to access unknown function type") - - @property - def unsolved_vars(self) -> set[ExistentialTypeVar]: - return set() - - ExceptHook = Callable[[type[BaseException], BaseException, TracebackType | None], Any] diff --git a/guppylang/gtypes.py b/guppylang/gtypes.py deleted file mode 100644 index 04d1c873..00000000 --- a/guppylang/gtypes.py +++ /dev/null @@ -1,654 +0,0 @@ -import ast -import itertools -from abc import ABC, abstractmethod -from collections.abc import Iterator, Sequence -from dataclasses import dataclass, field -from typing import ( - TYPE_CHECKING, - ClassVar, - Literal, -) - -import guppylang.hugr.tys as tys -from guppylang.ast_util import AstNode - -if TYPE_CHECKING: - from guppylang.checker.core import Globals - - -Subst = dict["ExistentialTypeVar", "GuppyType"] -Inst = Sequence["GuppyType"] - - -@dataclass(frozen=True) -class GuppyType(ABC): - """Base class for all Guppy types. - - Note that all instances of `GuppyType` subclasses are expected to be immutable. - """ - - name: ClassVar[str] - - # Cache for free variables - _unsolved_vars: set["ExistentialTypeVar"] = field(init=False, repr=False) - - def __post_init__(self) -> None: - # Make sure that we don't have higher-rank polymorphic types - for arg in self.type_args: - if isinstance(arg, FunctionType) and arg.quantified: - from guppylang.error import InternalGuppyError - - raise InternalGuppyError( - "Tried to construct a higher-rank polymorphic type!" - ) - - # Compute free variables - if isinstance(self, ExistentialTypeVar): - vs = {self} - else: - vs = set() - for arg in self.type_args: - vs |= arg.unsolved_vars - object.__setattr__(self, "_unsolved_vars", vs) - - @staticmethod - @abstractmethod - def build(*args: "GuppyType", node: AstNode | None = None) -> "GuppyType": - pass - - @property - @abstractmethod - def type_args(self) -> Iterator["GuppyType"]: - pass - - @property - @abstractmethod - def linear(self) -> bool: - pass - - @abstractmethod - def to_hugr(self) -> tys.Type: - pass - - @abstractmethod - def transform(self, transformer: "TypeTransformer") -> "GuppyType": - pass - - def hugr_bound(self) -> tys.TypeBound: - if self.linear: - return tys.TypeBound.Any - return tys.TypeBound.join(*(ty.hugr_bound() for ty in self.type_args)) - - @property - def unsolved_vars(self) -> set["ExistentialTypeVar"]: - return self._unsolved_vars - - def substitute(self, s: Subst) -> "GuppyType": - return self.transform(Substituter(s)) - - -@dataclass(frozen=True) -class BoundTypeVar(GuppyType): - """Bound type variable, identified with a de Bruijn index.""" - - idx: int - display_name: str - linear: bool = False - - name: ClassVar[Literal["BoundTypeVar"]] = "BoundTypeVar" - - @staticmethod - def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType: - raise NotImplementedError - - @property - def type_args(self) -> Iterator["GuppyType"]: - return iter(()) - - def hugr_bound(self) -> tys.TypeBound: - # We shouldn't make variables equatable, since we also want to substitute types - # like `float` - return tys.TypeBound.Any if self.linear else tys.TypeBound.Copyable - - def transform(self, transformer: "TypeTransformer") -> GuppyType: - return transformer.transform(self) or self - - def __str__(self) -> str: - return self.display_name - - def to_hugr(self) -> tys.Type: - return tys.Variable(i=self.idx, b=self.hugr_bound()) - - -@dataclass(frozen=True) -class ExistentialTypeVar(GuppyType): - """Existential type variable, identified with a globally unique id. - - Is solved during type checking. - """ - - id: int - display_name: str - linear: bool = False - - name: ClassVar[Literal["ExistentialTypeVar"]] = "ExistentialTypeVar" - - _id_generator: ClassVar[Iterator[int]] = itertools.count() - - @classmethod - def new(cls, display_name: str, linear: bool) -> "ExistentialTypeVar": - return ExistentialTypeVar(next(cls._id_generator), display_name, linear) - - @staticmethod - def build(*rgs: GuppyType, node: AstNode | None = None) -> GuppyType: - raise NotImplementedError - - @property - def type_args(self) -> Iterator["GuppyType"]: - return iter(()) - - def transform(self, transformer: "TypeTransformer") -> GuppyType: - return transformer.transform(self) or self - - def __str__(self) -> str: - return "?" + self.display_name - - def __hash__(self) -> int: - return self.id - - def to_hugr(self) -> tys.Type: - from guppylang.error import InternalGuppyError - - raise InternalGuppyError("Tried to convert free type variable to Hugr") - - -@dataclass(frozen=True) -class FunctionType(GuppyType): - args: Sequence[GuppyType] - returns: GuppyType - arg_names: Sequence[str] | None = field( - default=None, - compare=False, # Argument names are not taken into account for type equality - ) - quantified: Sequence[BoundTypeVar] = field(default_factory=list) - - name: ClassVar[Literal["%function"]] = "%function" - linear = False - - def __str__(self) -> str: - prefix = ( - "forall " + ", ".join(str(v) for v in self.quantified) + ". " - if self.quantified - else "" - ) - if len(self.args) == 1: - [arg] = self.args - return prefix + f"{arg} -> {self.returns}" - else: - return ( - prefix + f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" - ) - - @staticmethod - def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: - # Function types cannot be constructed using `build`. The type parsing code - # has a special case for function types. - raise NotImplementedError - - @property - def type_args(self) -> Iterator[GuppyType]: - return itertools.chain(iter(self.args), iter((self.returns,))) - - def to_hugr(self) -> tys.PolyFuncType: - ins = [t.to_hugr() for t in self.args] - outs = [t.to_hugr() for t in type_to_row(self.returns)] - func_ty = tys.FunctionType(input=ins, output=outs, extension_reqs=[]) - return tys.PolyFuncType( - params=[tys.TypeTypeParam(b=v.hugr_bound()) for v in self.quantified], - body=func_ty, - ) - - def hugr_bound(self) -> tys.TypeBound: - # Functions are not equatable, only copyable - return tys.TypeBound.Copyable - - def transform(self, transformer: "TypeTransformer") -> GuppyType: - return transformer.transform(self) or FunctionType( - [ty.transform(transformer) for ty in self.args], - self.returns.transform(transformer), - self.arg_names, - ) - - def instantiate(self, tys: Sequence[GuppyType]) -> "FunctionType": - """Instantiates quantified type variables.""" - assert len(tys) == len(self.quantified) - - # Set the `preserve` flag for instantiated tuples and None - preserved_tys: list[GuppyType] = [] - for ty in tys: - if isinstance(ty, TupleType): - ty = TupleType(ty.element_types, preserve=True) - elif isinstance(ty, NoneType): - ty = NoneType(preserve=True) - preserved_tys.append(ty) - - inst = Instantiator(preserved_tys) - return FunctionType( - [ty.transform(inst) for ty in self.args], - self.returns.transform(inst), - self.arg_names, - ) - - def unquantified(self) -> tuple["FunctionType", Sequence[ExistentialTypeVar]]: - """Replaces all quantified variables with free type variables.""" - inst = [ - ExistentialTypeVar.new(v.display_name, v.linear) for v in self.quantified - ] - return self.instantiate(inst), inst - - -@dataclass(frozen=True) -class TupleType(GuppyType): - element_types: Sequence[GuppyType] - - # Flag to avoid turning the tuple into row when calling `type_to_row()`. This is - # used to make sure that type vars instantiated to tuples are not broken up into - # rows when generating a Hugr - preserve: bool = field(default=False, compare=False) - - name: ClassVar[Literal["tuple"]] = "tuple" - - @staticmethod - def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: - from guppylang.error import GuppyError - - # TODO: Parse empty tuples via `tuple[()]` - if len(args) == 0: - raise GuppyError("Tuple type requires generic type arguments", node) - return TupleType(list(args)) - - def __str__(self) -> str: - return f"({', '.join(str(e) for e in self.element_types)})" - - @property - def linear(self) -> bool: - return any(t.linear for t in self.element_types) - - @property - def type_args(self) -> Iterator[GuppyType]: - return iter(self.element_types) - - def to_hugr(self) -> tys.Type: - ts = [t.to_hugr() for t in self.element_types] - return tys.TupleType(inner=ts) - - def transform(self, transformer: "TypeTransformer") -> GuppyType: - return transformer.transform(self) or TupleType( - [ty.transform(transformer) for ty in self.element_types] - ) - - -@dataclass(frozen=True) -class SumType(GuppyType): - element_types: Sequence[GuppyType] - - name: ClassVar[str] = "%tuple" - - @staticmethod - def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: - # Sum types cannot be parsed and constructed using `build` since they cannot be - # written by the user - raise NotImplementedError - - def __str__(self) -> str: - return f"Sum({', '.join(str(e) for e in self.element_types)})" - - @property - def linear(self) -> bool: - return any(t.linear for t in self.element_types) - - @property - def type_args(self) -> Iterator[GuppyType]: - return iter(self.element_types) - - def to_hugr(self) -> tys.Type: - if all( - isinstance(e, TupleType) and len(e.element_types) == 0 - for e in self.element_types - ): - return tys.UnitSum(size=len(self.element_types)) - return tys.GeneralSum(row=[t.to_hugr() for t in self.element_types]) - - def transform(self, transformer: "TypeTransformer") -> GuppyType: - return transformer.transform(self) or SumType( - [ty.transform(transformer) for ty in self.element_types] - ) - - -@dataclass(frozen=True) -class ListType(GuppyType): - element_type: GuppyType - - name: ClassVar[Literal["list"]] = "list" - linear: bool = field(default=False, init=False) - - @staticmethod - def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: - from guppylang.error import GuppyError - - if len(args) == 0: - raise GuppyError("Missing type parameter for generic type `list`", node) - if len(args) > 1: - raise GuppyError("Too many type arguments for generic type `list`", node) - (arg,) = args - if arg.linear: - raise GuppyError( - "Type `list` cannot store linear data, use `linst` instead", node - ) - return ListType(arg) - - def __str__(self) -> str: - return f"list[{self.element_type}]" - - @property - def type_args(self) -> Iterator[GuppyType]: - return iter((self.element_type,)) - - def to_hugr(self) -> tys.Type: - return tys.Opaque( - extension="Collections", - id="List", - args=[tys.TypeTypeArg(ty=self.element_type.to_hugr())], - bound=self.hugr_bound(), - ) - - def transform(self, transformer: "TypeTransformer") -> GuppyType: - return transformer.transform(self) or ListType( - self.element_type.transform(transformer) - ) - - -@dataclass(frozen=True) -class LinstType(GuppyType): - element_type: GuppyType - - name: ClassVar[Literal["linst"]] = "linst" - - @staticmethod - def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: - from guppylang.error import GuppyError - - if len(args) == 0: - raise GuppyError("Missing type parameter for generic type `linst`", node) - if len(args) > 1: - raise GuppyError("Too many type arguments for generic type `linst`", node) - return LinstType(args[0]) - - def __str__(self) -> str: - return f"linst[{self.element_type}]" - - @property - def linear(self) -> bool: - return self.element_type.linear - - @property - def type_args(self) -> Iterator[GuppyType]: - return iter((self.element_type,)) - - def to_hugr(self) -> tys.Type: - return tys.Opaque( - extension="Collections", - id="List", - args=[tys.TypeTypeArg(ty=self.element_type.to_hugr())], - bound=self.hugr_bound(), - ) - - def transform(self, transformer: "TypeTransformer") -> GuppyType: - return transformer.transform(self) or LinstType( - self.element_type.transform(transformer) - ) - - -@dataclass(frozen=True) -class NoneType(GuppyType): - name: ClassVar[Literal["None"]] = "None" - linear: bool = False - - # Flag to avoid turning the type into a row when calling `type_to_row()`. This is - # used to make sure that type vars instantiated to Nones are not broken up into - # empty rows when generating a Hugr - preserve: bool = field(default=False, compare=False) - - @staticmethod - def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: - if len(args) > 0: - from guppylang.error import GuppyError - - raise GuppyError("Type `None` is not generic", node) - return NoneType() - - @property - def type_args(self) -> Iterator[GuppyType]: - return iter(()) - - def substitute(self, s: Subst) -> GuppyType: - return self - - def __str__(self) -> str: - return "None" - - def to_hugr(self) -> tys.Type: - return tys.TupleType(inner=[]) - - def transform(self, transformer: "TypeTransformer") -> GuppyType: - return transformer.transform(self) or self - - -@dataclass(frozen=True) -class BoolType(SumType): - """The type of booleans.""" - - linear: bool = False - name: ClassVar[Literal["bool"]] = "bool" - - def __init__(self) -> None: - # Hugr bools are encoded as Sum((), ()) - super().__init__([TupleType([]), TupleType([])]) - - @staticmethod - def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: - if len(args) > 0: - from guppylang.error import GuppyError - - raise GuppyError("Type `bool` is not generic", node) - return BoolType() - - def __str__(self) -> str: - return "bool" - - def transform(self, transformer: "TypeTransformer") -> GuppyType: - return transformer.transform(self) or self - - -class TypeTransformer(ABC): - """Abstract base class for a type visitor that transforms types.""" - - @abstractmethod - def transform(self, ty: GuppyType) -> GuppyType | None: - """This method is called for each visited type. - - Return a transformed type or `None` to continue the recursive visit. - """ - - -class Substituter(TypeTransformer): - """Type transformer that substitutes free type variables.""" - - subst: Subst - - def __init__(self, subst: Subst) -> None: - self.subst = subst - - def transform(self, ty: GuppyType) -> GuppyType | None: - if isinstance(ty, ExistentialTypeVar): - return self.subst.get(ty, None) - return None - - -class Instantiator(TypeTransformer): - """Type transformer that instantiates bound type variables.""" - - tys: Sequence[GuppyType] - - def __init__(self, tys: Sequence[GuppyType]) -> None: - self.tys = tys - - def transform(self, ty: GuppyType) -> GuppyType | None: - if isinstance(ty, BoundTypeVar): - # Instantiate if type for the index is available - if ty.idx < len(self.tys): - return self.tys[ty.idx] - - # Otherwise, lower the de Bruijn index - return BoundTypeVar(ty.idx - len(self.tys), ty.display_name, ty.linear) - return None - - -def unify(s: GuppyType, t: GuppyType, subst: Subst | None) -> Subst | None: - """Computes a most general unifier for two types. - - Return a substitutions `subst` such that `s[subst] == t[subst]` or `None` if this - not possible. - """ - if subst is None: - return None - if s == t: - return subst - if isinstance(s, ExistentialTypeVar): - return _unify_var(s, t, subst) - if isinstance(t, ExistentialTypeVar): - return _unify_var(t, s, subst) - if type(s) == type(t): - sargs, targs = list(s.type_args), list(t.type_args) - if len(sargs) == len(targs): - for sa, ta in zip(sargs, targs): - subst = unify(sa, ta, subst) - return subst - return None - - -def _unify_var(var: ExistentialTypeVar, t: GuppyType, subst: Subst) -> Subst | None: - """Helper function for unification of type variables.""" - if var in subst: - return unify(subst[var], t, subst) - if isinstance(t, ExistentialTypeVar) and t in subst: - return unify(var, subst[t], subst) - if var in t.unsolved_vars: - return None - return {var: t, **subst} - - -def type_from_ast( - node: AstNode, - globals: "Globals", - type_var_mapping: dict[str, BoundTypeVar] | None = None, -) -> GuppyType: - """Turns an AST expression into a Guppy type.""" - from guppylang.error import GuppyError - - if isinstance(node, ast.Name): - x = node.id - if x in globals.types: - return globals.types[x].build(node=node) - if x in globals.type_vars: - if type_var_mapping is None: - raise GuppyError( - "Free type variable. Only function types can be generic", node - ) - var_decl = globals.type_vars[x] - if var_decl.name not in type_var_mapping: - type_var_mapping[var_decl.name] = BoundTypeVar( - len(type_var_mapping), var_decl.name, var_decl.linear - ) - return type_var_mapping[var_decl.name] - raise GuppyError("Unknown type", node) - - if isinstance(node, ast.Constant): - v = node.value - if v is None: - return NoneType() - if isinstance(v, str): - try: - [stmt] = ast.parse(v).body - if not isinstance(stmt, ast.Expr): - raise GuppyError("Invalid Guppy type", node) - return type_from_ast(stmt.value, globals, type_var_mapping) - except (SyntaxError, ValueError): - raise GuppyError("Invalid Guppy type", node) from None - raise GuppyError(f"Constant `{v}` is not a valid type", node) - - if isinstance(node, ast.Tuple): - return TupleType( - [type_from_ast(el, globals, type_var_mapping) for el in node.elts] - ) - - if ( - isinstance(node, ast.Subscript) - and isinstance(node.value, ast.Name) - and node.value.id == "Callable" - and isinstance(node.slice, ast.Tuple) - and len(node.slice.elts) == 2 - ): - # TODO: Do we want to allow polymorphic Callable types? - [func_args, ret] = node.slice.elts - if isinstance(func_args, ast.List): - return FunctionType( - [type_from_ast(a, globals, type_var_mapping) for a in func_args.elts], - type_from_ast(ret, globals, type_var_mapping), - ) - - if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name): - x = node.value.id - if x in globals.types: - args = ( - node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] - ) - return globals.types[x].build( - *(type_from_ast(a, globals, type_var_mapping) for a in args), node=node - ) - - raise GuppyError("Not a valid Guppy type", node) - - -def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[GuppyType]: - """Turns an AST expression into a Guppy type row. - - This is needed to interpret the return type annotation of functions. - """ - # The return type `-> None` is represented in the ast as `ast.Constant(value=None)` - if isinstance(node, ast.Constant) and node.value is None: - return [] - ty = type_from_ast(node, globals) - if isinstance(ty, TupleType): - return ty.element_types - else: - return [ty] - - -def row_to_type(row: Sequence[GuppyType]) -> GuppyType: - """Turns a row of types into a single type by packing into a tuple.""" - if len(row) == 0: - return NoneType() - elif len(row) == 1: - return row[0] - else: - return TupleType(row) - - -def type_to_row(ty: GuppyType) -> Sequence[GuppyType]: - """Turns a type into a row of types by unpacking top-level tuples.""" - if isinstance(ty, NoneType) and not ty.preserve: - return [] - if isinstance(ty, TupleType) and not ty.preserve: - return ty.element_types - return [ty] diff --git a/guppylang/hugr/hugr.py b/guppylang/hugr/hugr.py index 01a6cd30..dcd23c7c 100644 --- a/guppylang/hugr/hugr.py +++ b/guppylang/hugr/hugr.py @@ -9,16 +9,16 @@ import guppylang.hugr.ops as ops import guppylang.hugr.raw as raw -from guppylang.gtypes import ( +from guppylang.hugr import val +from guppylang.tys.subst import Inst +from guppylang.tys.ty import ( FunctionType, - GuppyType, - Inst, SumType, TupleType, + Type, row_to_type, type_to_row, ) -from guppylang.hugr import tys, val NodeIdx = int PortOffset = int @@ -44,7 +44,7 @@ class OutPort(Port, ABC): class InPortV(InPort): """A typed value input port.""" - ty: GuppyType + ty: Type offset: PortOffset @@ -52,7 +52,7 @@ class InPortV(InPort): class OutPortV(OutPort): """A typed value output port.""" - ty: GuppyType + ty: Type offset: PortOffset @@ -70,7 +70,7 @@ class OutPortCF(OutPort): Edge = tuple[OutPort, InPort] -TypeList = list[GuppyType] +TypeList = list[Type] @dataclass @@ -138,13 +138,13 @@ def num_out_ports(self) -> int: """The number of output ports on this node.""" return len(self.out_port_types) - def add_in_port(self, ty: GuppyType) -> InPortV: + def add_in_port(self, ty: Type) -> InPortV: """Adds an input port at the end of the node and returns the port.""" p = InPortV(self, self.num_in_ports, ty) self.in_port_types.append(ty) return p - def add_out_port(self, ty: GuppyType) -> OutPortV: + def add_out_port(self, ty: Type) -> OutPortV: """Adds an output port at the end of the node and returns the port.""" p = OutPortV(self, self.num_out_ports, ty) self.out_port_types.append(ty) @@ -354,7 +354,7 @@ def set_root_name(self, name: str) -> VNode: return self.root def add_constant( - self, value: val.Value, ty: GuppyType, parent: Node | None = None + self, value: val.Value, ty: Type, parent: Node | None = None ) -> VNode: """Adds a constant node holding a given value to the graph.""" return self.add_node( @@ -372,7 +372,7 @@ def add_input( return node def add_input_with_ports( - self, output_tys: Sequence[GuppyType], parent: Node | None = None + self, output_tys: Sequence[Type], parent: Node | None = None ) -> tuple[VNode, list[OutPortV]]: """Adds an `Input` node to the graph.""" node = self.add_input(None, parent) @@ -495,7 +495,7 @@ def add_call( return self.add_node( ops.Call(), None, - list(type_to_row(def_port.ty.returns)), + list(type_to_row(def_port.ty.output)), parent, [*args, def_port], ) @@ -508,27 +508,27 @@ def add_indirect_call( return self.add_node( ops.CallIndirect(), None, - list(type_to_row(fun_port.ty.returns)), + list(type_to_row(fun_port.ty.output)), parent, [fun_port, *args], ) def add_partial( - self, def_port: OutPortV, args: list[OutPortV], parent: Node | None = None + self, def_port: OutPortV, inputs: list[OutPortV], parent: Node | None = None ) -> VNode: """Adds a `Partial` evaluation node to the graph.""" assert isinstance(def_port.ty, FunctionType) - assert len(def_port.ty.args) >= len(args) - assert [p.ty for p in args] == def_port.ty.args[: len(args)] + assert len(def_port.ty.inputs) >= len(inputs) + assert [p.ty for p in inputs] == def_port.ty.inputs[: len(inputs)] new_ty = FunctionType( - def_port.ty.args[len(args) :], - def_port.ty.returns, - def_port.ty.arg_names[len(args) :] - if def_port.ty.arg_names is not None + def_port.ty.inputs[len(inputs) :], + def_port.ty.output, + def_port.ty.input_names[len(inputs) :] + if def_port.ty.input_names is not None else None, ) return self.add_node( - ops.DummyOp(name="partial"), None, [new_ty], parent, [*args, def_port] + ops.DummyOp(name="partial"), None, [new_ty], parent, [*inputs, def_port] ) def add_type_apply( @@ -536,11 +536,11 @@ def add_type_apply( ) -> VNode: """Adds a `TypeApply` node to the graph.""" assert isinstance(func_port.ty, FunctionType) - assert len(func_port.ty.quantified) == len(args) + assert len(func_port.ty.params) == len(args) result_ty = func_port.ty.instantiate(args) ta = ops.TypeApplication( input=func_port.ty.to_hugr(), - args=[tys.TypeTypeArg(ty=ty.to_hugr()) for ty in args], + args=[arg.to_hugr() for arg in args], output=result_ty.to_hugr(), ) return self.add_node( diff --git a/guppylang/hugr/ops.py b/guppylang/hugr/ops.py index 6bb0d067..a3ae730e 100644 --- a/guppylang/hugr/ops.py +++ b/guppylang/hugr/ops.py @@ -422,7 +422,6 @@ class TypeApplication(BaseModel): ( Module | Case - | Module | FuncDefn | FuncDecl | Const diff --git a/guppylang/module.py b/guppylang/module.py index d24d1cd7..b46f7b5f 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -7,7 +7,7 @@ from typing import Any, Union from guppylang.ast_util import AstNode, annotate_location -from guppylang.checker.core import Globals, PyScope, TypeVarDecl, qualified_name +from guppylang.checker.core import Globals, PyScope, qualified_name from guppylang.checker.func_checker import DefinedFunction, check_global_func_def from guppylang.compiler.core import CompiledGlobals from guppylang.compiler.func_compiler import ( @@ -17,8 +17,9 @@ from guppylang.custom import CustomFunction from guppylang.declared import DeclaredFunction from guppylang.error import GuppyError, pretty_errors -from guppylang.gtypes import GuppyType from guppylang.hugr.hugr import Hugr +from guppylang.tys.definition import TypeDef +from guppylang.tys.param import TypeParam PyFunc = Callable[..., Any] PyFuncDefOrDecl = tuple[bool, PyFunc] @@ -92,9 +93,7 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: if isinstance(val, GuppyModule): self.load(val) - def register_func_def( - self, f: PyFunc, instance: type[GuppyType] | None = None - ) -> None: + def register_func_def(self, f: PyFunc, instance: TypeDef | None = None) -> None: """Registers a Python function definition as belonging to this Guppy module.""" self._check_not_yet_compiled() func_ast = parse_py_func(f) @@ -107,9 +106,7 @@ def register_func_def( self._check_name_available(name, func_ast) self._func_defs[name] = func_ast, get_py_scope(f) - def register_func_decl( - self, f: PyFunc, instance: type[GuppyType] | None = None - ) -> None: + def register_func_decl(self, f: PyFunc, instance: TypeDef | None = None) -> None: """Registers a Python function declaration as belonging to this Guppy module.""" self._check_not_yet_compiled() func_ast = parse_py_func(f) @@ -123,7 +120,7 @@ def register_func_decl( self._func_decls[name] = func_ast def register_custom_func( - self, func: CustomFunction, instance: type[GuppyType] | None = None + self, func: CustomFunction, instance: TypeDef | None = None ) -> None: """Registers a custom function as belonging to this Guppy module.""" self._check_not_yet_compiled() @@ -135,31 +132,33 @@ def register_custom_func( self._check_name_available(func.name, func.defined_at) self._custom_funcs[func.name] = func - def register_type(self, name: str, ty: type[GuppyType]) -> None: + def register_type(self, name: str, defn: TypeDef) -> None: """Registers an existing Guppy type as belonging to this Guppy module.""" self._check_not_yet_compiled() self._check_type_name_available(name, None) - self._globals.types[name] = ty + self._globals.type_defs[name] = defn def register_type_var(self, name: str, linear: bool) -> None: """Registers a new type variable""" self._check_not_yet_compiled() self._check_type_name_available(name, None) - self._globals.type_vars[name] = TypeVarDecl(name, linear) + self._globals.param_vars[name] = TypeParam( + len(self._globals.param_vars), name, linear + ) - def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None: + def _register_buffered_instance_funcs(self, defn: TypeDef) -> None: assert self._instance_func_buffer is not None buffer = self._instance_func_buffer self._instance_func_buffer = None for f in buffer.values(): if isinstance(f, CustomFunction): - self.register_custom_func(f, instance) + self.register_custom_func(f, defn) else: is_def, pyfunc = f if is_def: - self.register_func_def(pyfunc, instance) + self.register_func_def(pyfunc, defn) else: - self.register_func_decl(pyfunc, instance) + self.register_func_decl(pyfunc, defn) @property def compiled(self) -> bool: @@ -232,7 +231,7 @@ def contains_function(self, name: str) -> bool: def contains_type(self, name: str) -> bool: """Returns 'True' if the module contains a type with the given name.""" - return name in self._globals.types or name in self._globals.type_vars + return name in self._globals.type_defs or name in self._globals.param_vars def _check_not_yet_compiled(self) -> None: if self._compiled: @@ -246,13 +245,13 @@ def _check_name_available(self, name: str, node: AstNode | None) -> None: ) def _check_type_name_available(self, name: str, node: AstNode | None) -> None: - if name in self._globals.types: + if name in self._globals.type_defs: raise GuppyError( f"Module `{self.name}` already contains a type `{name}`", node, ) - if name in self._globals.type_vars: + if name in self._globals.param_vars: raise GuppyError( f"Module `{self.name}` already contains a type variable `{name}`", node, diff --git a/guppylang/nodes.py b/guppylang/nodes.py index 908fe55f..14b39b52 100644 --- a/guppylang/nodes.py +++ b/guppylang/nodes.py @@ -1,10 +1,11 @@ """Custom AST nodes used by Guppy""" import ast -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from guppylang.gtypes import FunctionType, GuppyType, Inst +from guppylang.tys.subst import Inst +from guppylang.tys.ty import FunctionType if TYPE_CHECKING: from guppylang.cfg.cfg import CFG @@ -52,11 +53,11 @@ class GlobalCall(ast.expr): class TypeApply(ast.expr): value: ast.expr - tys: Sequence[GuppyType] + inst: Inst _fields = ( "value", - "tys", + "inst", ) diff --git a/guppylang/prelude/_internal.py b/guppylang/prelude/_internal.py index 3da1fef3..84804722 100644 --- a/guppylang/prelude/_internal.py +++ b/guppylang/prelude/_internal.py @@ -13,10 +13,12 @@ DefaultCallChecker, ) from guppylang.error import GuppyError, GuppyTypeError -from guppylang.gtypes import BoolType, FunctionType, GuppyType, Subst, unify from guppylang.hugr import ops, tys, val from guppylang.hugr.hugr import OutPortV from guppylang.nodes import GlobalCall +from guppylang.tys.definition import bool_type +from guppylang.tys.subst import Subst +from guppylang.tys.ty import FunctionType, OpaqueType, Type, unify INT_WIDTH = 6 # 2^6 = 64 bit @@ -103,17 +105,22 @@ def float_op(op_name: str, ext: str = "arithmetic.float") -> ops.OpType: class CoercingChecker(DefaultCallChecker): """Function call type checker that automatically coerces arguments to float.""" - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: from .builtins import Int for i in range(len(args)): args[i], ty = ExprSynthesizer(self.ctx).synthesize(args[i]) - if isinstance(ty, self.ctx.globals.types["int"]): + if ( + isinstance(ty, OpaqueType) + and ty.defn == self.ctx.globals.type_defs["int"] + ): call = with_loc( self.node, GlobalCall(func=Int.__float__, args=[args[i]], type_args=[]), ) - args[i] = with_type(self.ctx.globals.types["float"].build(), call) + args[i] = with_type( + self.ctx.globals.type_defs["float"].check_instantiate([]), call + ) return super().synthesize(args) @@ -129,13 +136,13 @@ def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: super()._setup(ctx, node, func) self.base_checker._setup(ctx, node, func) - def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: expr, subst = self.base_checker.check(args, ty) if isinstance(expr, GlobalCall): expr.args = list(reversed(args)) return expr, subst - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: expr, ty = self.base_checker.synthesize(args) if isinstance(expr, GlobalCall): expr.args = list(reversed(args)) @@ -151,10 +158,10 @@ class FailingChecker(CustomCallChecker): def __init__(self, msg: str) -> None: self.msg = msg - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: raise GuppyError(self.msg, self.node) - def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: raise GuppyError(self.msg, self.node) @@ -164,12 +171,12 @@ class UnsupportedChecker(CustomCallChecker): Gives the uses a nicer error message when they try to use an unsupported feature. """ - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: raise GuppyError( f"Builtin method `{self.func.name}` is not supported by Guppy", self.node ) - def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: raise GuppyError( f"Builtin method `{self.func.name}` is not supported by Guppy", self.node ) @@ -201,11 +208,11 @@ def _get_func( ) return [fst, *rest], func - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: args, func = self._get_func(args) return func.synthesize_call(args, self.node, self.ctx) - def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: args, func = self._get_func(args) return func.check_call(args, ty, self.node, self.ctx) @@ -213,7 +220,7 @@ def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: class CallableChecker(CustomCallChecker): """Call checker for the builtin `callable` function""" - def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: check_num_args(1, len(args), self.node) [arg] = args arg, ty = ExprSynthesizer(self.ctx).synthesize(arg) @@ -222,11 +229,11 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, GuppyType]: or self.ctx.globals.get_instance_func(ty, "__call__") is not None ) const = with_loc(self.node, ast.Constant(value=is_callable)) - return const, BoolType() + return const, bool_type() - def check(self, args: list[ast.expr], ty: GuppyType) -> tuple[ast.expr, Subst]: + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: args, _ = self.synthesize(args) - subst = unify(ty, BoolType(), {}) + subst = unify(ty, bool_type(), {}) if subst is None: raise GuppyTypeError( f"Expected expression of type `{ty}`, got `bool`", self.node @@ -336,4 +343,4 @@ def compile(self, args: list[OutPortV]) -> list[OutPortV]: self.graph.add_node( quantum_op("QFree"), inputs=[measure.add_out_port(qubit.ty)] ) - return [measure.add_out_port(BoolType())] + return [measure.add_out_port(bool_type())] diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 194e0000..23a16843 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -4,7 +4,6 @@ from guppylang.custom import DefaultCallChecker, NoopCompiler from guppylang.decorator import guppy -from guppylang.gtypes import BoolType, LinstType, ListType from guppylang.hugr import ops, tys from guppylang.module import GuppyModule from guppylang.prelude._internal import ( @@ -25,6 +24,7 @@ int_op, logic_op, ) +from guppylang.tys.definition import bool_type_def, linst_type_def, list_type_def builtins = GuppyModule("builtins", import_builtins=False) @@ -32,7 +32,7 @@ L = guppy.type_var(builtins, "L", linear=True) -@guppy.extend_type(builtins, BoolType) +@guppy.extend_type(builtins, bool_type_def) class Bool: @guppy.hugr_op(builtins, logic_op("And", [tys.BoundedNatArg(n=2)])) def __and__(self: bool, other: bool) -> bool: ... @@ -290,7 +290,7 @@ def __truediv__(self: float, other: float) -> float: ... def __trunc__(self: float) -> float: ... -@guppy.extend_type(builtins, ListType) +@guppy.extend_type(builtins, list_type_def) class List: @guppy.hugr_op(builtins, ops.DummyOp(name="Concat")) def __add__(self: list[T], other: list[T]) -> list[T]: ... @@ -365,7 +365,7 @@ def sort(self: list[T]) -> None: ... linst = list -@guppy.extend_type(builtins, LinstType) +@guppy.extend_type(builtins, linst_type_def) class Linst: @guppy.hugr_op(builtins, ops.DummyOp(name="Append")) def __add__(self: linst[L], other: linst[L]) -> linst[L]: ... diff --git a/guppylang/tys/__init__.py b/guppylang/tys/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/guppylang/tys/arg.py b/guppylang/tys/arg.py new file mode 100644 index 00000000..ddcd7a71 --- /dev/null +++ b/guppylang/tys/arg.py @@ -0,0 +1,87 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypeAlias + +from guppylang.hugr import tys +from guppylang.tys.common import ToHugr, Transformable, Transformer, Visitor +from guppylang.tys.const import Const +from guppylang.tys.var import ExistentialVar + +if TYPE_CHECKING: + from guppylang.tys.ty import Type + + +# We define the `Argument` type as a union of all `ArgumentBase` subclasses defined +# below. This models an algebraic data type and enables exhaustiveness checking in +# pattern matches etc. +# Note that this might become obsolete in case the `@sealed` decorator is added: +# * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types +# * https://github.com/johnthagen/sealed-typing-pep +Argument: TypeAlias = "TypeArg | ConstArg" + + +@dataclass(frozen=True) +class ArgumentBase(ToHugr[tys.TypeArg], Transformable["Argument"], ABC): + """Abstract base class for arguments of parametrized types. + + For example, in the type `array[int, 42]` we have two arguments `int` and `42`. + """ + + @property + @abstractmethod + def unsolved_vars(self) -> set[ExistentialVar]: + """The existential type variables contained in this argument.""" + + +@dataclass(frozen=True) +class TypeArg(ArgumentBase): + """Argument that can be instantiated for a `TypeParameter`.""" + + # The type to instantiate + ty: "Type" + + @property + def unsolved_vars(self) -> set[ExistentialVar]: + """The existential type variables contained in this argument.""" + return self.ty.unsolved_vars + + def to_hugr(self) -> tys.TypeArg: + """Computes the Hugr representation of the argument.""" + return tys.TypeTypeArg(ty=self.ty.to_hugr()) + + def visit(self, visitor: Visitor) -> None: + """Accepts a visitor on this argument.""" + if not visitor.visit(self): + self.ty.visit(visitor) + + def transform(self, transformer: Transformer) -> Argument: + """Accepts a transformer on this argument.""" + return transformer.transform(self) or TypeArg(self.ty.transform(transformer)) + + +@dataclass(frozen=True) +class ConstArg(ArgumentBase): + """Argument that can be substituted for a `ConstParameter`. + + Note that support for this kind is not implemented yet. + """ + + # Hugr value to instantiate + const: Const + + @property + def unsolved_vars(self) -> set[ExistentialVar]: + """The existential type variables contained in this argument.""" + raise NotImplementedError + + def to_hugr(self) -> tys.TypeArg: + """Computes the Hugr representation of the argument.""" + raise NotImplementedError + + def visit(self, visitor: Visitor) -> None: + """Accepts a visitor on this argument.""" + raise NotImplementedError + + def transform(self, transformer: Transformer) -> Argument: + """Accepts a transformer on this argument.""" + raise NotImplementedError diff --git a/guppylang/tys/common.py b/guppylang/tys/common.py new file mode 100644 index 00000000..0058f272 --- /dev/null +++ b/guppylang/tys/common.py @@ -0,0 +1,60 @@ +from abc import ABC, abstractmethod +from typing import Any, Generic, TypeVar + +T = TypeVar("T") + + +class ToHugr(ABC, Generic[T]): + """Abstract base class for objects that have a Hugr representation.""" + + @abstractmethod + def to_hugr(self) -> T: + """Computes the Hugr representation of the object.""" + + +class Visitor(ABC): + """Abstract base class for a type visitor that transforms types.""" + + @abstractmethod + def visit(self, arg: Any, /) -> bool: + """This method is called for each visited type. + + Return `False` to continue the recursive descent. + """ + + +class Transformer(ABC): + """Abstract base class for a type visitor that transforms types.""" + + @abstractmethod + def transform(self, arg: Any, /) -> Any | None: + """Transforms the object. + + Return a transformed type or `None` to continue the recursive visit. + """ + + +class Visitable(ABC): + """Abstract base class for objects that can be recursively visited.""" + + @abstractmethod + def visit(self, visitor: Visitor, /) -> None: + """Accepts a visitor on the object. + + Implementors of this method should first invoke the visitor on the object + itself. If the visitor doesn't handle the object, the visitor should be passed + to all relevant members of the object. + """ + + +class Transformable(Visitable, ABC, Generic[T]): + """Abstract base class for objects that can be recursively transformed.""" + + @abstractmethod + def transform(self, transformer: Transformer, /) -> T: + """Accepts a transformer on the object. + + Implementors of this method should first invoke the transformer on the object + itself. If the visitor doesn't handle the object, the visitor should be used to + transform all relevant members of the object. + """ diff --git a/guppylang/tys/const.py b/guppylang/tys/const.py new file mode 100644 index 00000000..3d2feacb --- /dev/null +++ b/guppylang/tys/const.py @@ -0,0 +1,62 @@ +from abc import ABC +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from guppylang.error import InternalGuppyError +from guppylang.hugr import val +from guppylang.tys.var import BoundVar, ExistentialVar + +if TYPE_CHECKING: + from guppylang.tys.ty import Type + + +@dataclass(frozen=True) +class Const(ABC): + """Abstract base class for constants arguments in the type system. + + In principle, we can allow constants of any type representable in the type system. + For now, we will only support basic consts like `int` or `bool`, but in the future + we could have struct constants etc. + """ + + ty: "Type" + + def __post_init__(self) -> None: + if self.ty.unsolved_vars: + raise InternalGuppyError("Attempted to create constant with unsolved type") + + +@dataclass(frozen=True) +class ConstValue(Const): + """A constant value in the type system. + + For example, in the type `array[int, 5]` the second argument is a `ConstArg` that + contains a `ConstValue(5)`. + """ + + # Hugr encoding of the value + # TODO: We might need a Guppy representation of this... + value: val.Value + + +@dataclass(frozen=True) +class BoundConstVar(BoundVar, Const): + """Bound variable referencing a `ConstParam`. + + For example, in the function type `forall n: int. array[float, n] -> array[int, n]`, + we represent the int argument to `array` as a `ConstArg` containing a + `BoundConstVar(idx=0)`. + """ + + +@dataclass(frozen=True) +class ExistentialConstVar(ExistentialVar, Const): + """Existential constant variable. + + During type checking we try to solve all existential constant variables and + substitute them with concrete constants. + """ + + @classmethod + def fresh(cls, display_name: str, ty: "Type") -> "ExistentialConstVar": + return ExistentialConstVar(ty, display_name, next(cls._fresh_id)) diff --git a/guppylang/tys/definition.py b/guppylang/tys/definition.py new file mode 100644 index 00000000..9d144409 --- /dev/null +++ b/guppylang/tys/definition.py @@ -0,0 +1,210 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from typing import Literal + +from guppylang.ast_util import AstNode +from guppylang.error import GuppyError +from guppylang.hugr import tys +from guppylang.tys.arg import Argument, TypeArg +from guppylang.tys.param import Parameter, TypeParam +from guppylang.tys.ty import FunctionType, NoneType, OpaqueType, TupleType, Type + + +@dataclass(frozen=True) +class TypeDef(ABC): + """Abstract base class for type definitions.""" + + name: str + + @abstractmethod + def check_instantiate( + self, args: Sequence[Argument], loc: AstNode | None = None + ) -> Type: + """Checks if the type definition can be instantiated with the given arguments. + + Returns the resulting concrete type or raises a user error if the arguments are + invalid. + """ + + +@dataclass(frozen=True) +class OpaqueTypeDef(TypeDef): + """An opaque type definition that is backed by some Hugr type.""" + + params: Sequence[Parameter] + always_linear: bool + to_hugr: Callable[[Sequence[Argument]], tys.Type] + bound: tys.TypeBound | None = None + + def check_instantiate( + self, args: Sequence[Argument], loc: AstNode | None = None + ) -> OpaqueType: + """Checks if the type definition can be instantiated with the given arguments. + + Returns the resulting concrete type or raises a user error if the arguments are + invalid. + """ + exp, act = len(self.params), len(args) + if exp > act: + raise GuppyError(f"Missing parameter for type `{self.name}`", loc) + elif 0 == exp < act: + raise GuppyError(f"Type `{self.name}` is not parameterized", loc) + elif 0 < exp < act: + raise GuppyError(f"Too many parameters for type `{self.name}`", loc) + + # Now check that the kinds match up + for param, arg in zip(self.params, args, strict=True): + # TODO: The error location is bad. We want the location of `arg`, not of the + # whole thing. + param.check_arg(arg, loc) + return OpaqueType(args, self) + + +@dataclass(frozen=True) +class _CallableTypeDef(TypeDef): + """Type definition associated with the builtin `Callable` type. + + Any impls on functions can be registered with this definition. + """ + + name: Literal["Callable"] = field(default="Callable", init=False) + + def check_instantiate( + self, args: Sequence[Argument], loc: AstNode | None = None + ) -> FunctionType: + # We get the inputs/output as a flattened list: `args = [*inputs, output]`. + if not args: + raise GuppyError(f"Missing parameter for type `{self.name}`", loc) + args = [ + # TODO: Better error location + TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty + for i, arg in enumerate(args) + ] + *inputs, output = args + return FunctionType(inputs, output) + + +@dataclass(frozen=True) +class _TupleTypeDef(TypeDef): + """Type definition associated with the builtin `tuple` type. + + Any impls on tuples can be registered with this definition. + """ + + name: Literal["tuple"] = field(default="tuple", init=False) + + def check_instantiate( + self, args: Sequence[Argument], loc: AstNode | None = None + ) -> TupleType: + # We accept any number of arguments. If users just write `tuple`, we give them + # the empty tuple type. We just have to make sure that the args are of kind type + args = [ + # TODO: Better error location + TypeParam(0, f"T{i}", can_be_linear=True).check_arg(arg, loc).ty + for i, arg in enumerate(args) + ] + return TupleType(args) + + +@dataclass(frozen=True) +class _NoneTypeDef(TypeDef): + """Type definition associated with the builtin `None` type. + + Any impls on None can be registered with this definition. + """ + + name: Literal["None"] = field(default="None", init=False) + + def check_instantiate( + self, args: Sequence[Argument], loc: AstNode | None = None + ) -> NoneType: + if args: + raise GuppyError("Type `None` is not parameterized", loc) + return NoneType() + + +@dataclass(frozen=True) +class _ListTypeDef(OpaqueTypeDef): + """Type definition associated with the builtin `list` type. + + We have a custom definition to give a nicer error message if the user tries to put + linear data into a regular list. + """ + + def check_instantiate( + self, args: Sequence[Argument], loc: AstNode | None = None + ) -> OpaqueType: + if len(args) == 1: + [arg] = args + if isinstance(arg, TypeArg) and arg.ty.linear: + raise GuppyError( + "Type `list` cannot store linear data, use `linst` instead", loc + ) + return super().check_instantiate(args, loc) + + +def _list_to_hugr(args: Sequence[Argument]) -> tys.Opaque: + return tys.Opaque( + extension="Collections", + id="List", + args=[arg.to_hugr() for arg in args], + bound=tys.TypeBound.join( + *(arg.ty.hugr_bound for arg in args if isinstance(arg, TypeArg)) + ), + ) + + +callable_type_def = _CallableTypeDef() +tuple_type_def = _TupleTypeDef() +none_type_def = _NoneTypeDef() +bool_type_def = OpaqueTypeDef( + name="bool", + params=[], + always_linear=False, + to_hugr=lambda _: tys.UnitSum(size=2), +) +linst_type_def = OpaqueTypeDef( + name="linst", + params=[TypeParam(0, "T", can_be_linear=True)], + always_linear=False, + to_hugr=_list_to_hugr, +) +list_type_def = _ListTypeDef( + name="list", + params=[TypeParam(0, "T", can_be_linear=False)], + always_linear=False, + to_hugr=_list_to_hugr, +) + + +def bool_type() -> OpaqueType: + return OpaqueType([], bool_type_def) + + +def list_type(element_ty: Type) -> OpaqueType: + return OpaqueType([TypeArg(element_ty)], list_type_def) + + +def linst_type(element_ty: Type) -> OpaqueType: + return OpaqueType([TypeArg(element_ty)], linst_type_def) + + +def is_bool_type(ty: Type) -> bool: + return isinstance(ty, OpaqueType) and ty.defn == bool_type_def + + +def is_list_type(ty: Type) -> bool: + return isinstance(ty, OpaqueType) and ty.defn == list_type_def + + +def is_linst_type(ty: Type) -> bool: + return isinstance(ty, OpaqueType) and ty.defn == linst_type_def + + +def get_element_type(ty: Type) -> Type: + assert isinstance(ty, OpaqueType) + assert ty.defn in (list_type_def, linst_type_def) + (arg,) = ty.args + assert isinstance(arg, TypeArg) + return arg.ty diff --git a/guppylang/tys/param.py b/guppylang/tys/param.py new file mode 100644 index 00000000..e398992c --- /dev/null +++ b/guppylang/tys/param.py @@ -0,0 +1,171 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, TypeAlias + +from typing_extensions import Self + +from guppylang.ast_util import AstNode +from guppylang.error import GuppyTypeError, InternalGuppyError +from guppylang.hugr import tys +from guppylang.hugr.tys import TypeBound +from guppylang.tys.arg import Argument, ConstArg, TypeArg +from guppylang.tys.common import ToHugr +from guppylang.tys.var import ExistentialVar + +if TYPE_CHECKING: + from guppylang.tys.ty import Type + + +# We define the `Parameter` type as a union of all `ParameterBase` subclasses defined +# below. This models an algebraic data type and enables exhaustiveness checking in +# pattern matches etc. +# Note that this might become obsolete in case the `@sealed` decorator is added: +# * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types +# * https://github.com/johnthagen/sealed-typing-pep +Parameter: TypeAlias = "TypeParam | ConstParam" + + +@dataclass(frozen=True) +class ParameterBase(ToHugr[tys.TypeParam], ABC): + """Abstract base class for parameters used in function and type definitions. + + For example, when defining a struct type + + ``` + @guppy.struct + class Foo[T, n: int]: + ... + ``` + + we generate an `StructDef` that depends on the parameters `T` and `n`. From + this, we obtain a proper `StructType` by providing arguments that are substituted + for the parameters (for example `Foo[int, 42]`). + """ + + idx: int + name: str + + @abstractmethod + def with_idx(self, idx: int) -> Self: + """Returns a copy of the parameter with a new index.""" + + @abstractmethod + def check_arg(self, arg: Argument, loc: AstNode | None = None) -> Argument: + """Checks that this parameter can be instantiated with a given argument. + + Raises a user error if the argument is not valid. + """ + + @abstractmethod + def to_existential(self) -> tuple[Argument, ExistentialVar]: + """Creates a fresh existential variable that can be instantiated for this + parameter. + + Returns both the argument and the created variable. + """ + + @abstractmethod + def to_bound(self, idx: int | None = None) -> Argument: + """Creates a bound variable with a given index that can be instantiated for this + parameter. + """ + + +@dataclass(frozen=True) +class TypeParam(ParameterBase): + """A parameter of kind type. Used to define generic functions and types.""" + + can_be_linear: bool + + def with_idx(self, idx: int) -> "TypeParam": + """Returns a copy of the parameter with a new index.""" + return TypeParam(idx, self.name, self.can_be_linear) + + def check_arg(self, arg: Argument, loc: AstNode | None = None) -> TypeArg: + """Checks that this parameter can be instantiated with a given argument. + + Raises a user error if the argument is not valid. + """ + match arg: + case ConstArg(const): + raise GuppyTypeError( + f"Expected a type, got value of type {const.ty}", loc + ) + case TypeArg(ty): + if not self.can_be_linear and ty.linear: + raise GuppyTypeError( + f"Expected a non-linear type, got value of type {ty}", loc + ) + return arg + + def to_existential(self) -> tuple[Argument, ExistentialVar]: + """Creates a fresh existential variable that can be instantiated for this + parameter. + + Returns both the argument and the created variable. + """ + from guppylang.tys.ty import ExistentialTypeVar + + var = ExistentialTypeVar.fresh(self.name, self.can_be_linear) + return TypeArg(var), var + + def to_bound(self, idx: int | None = None) -> Argument: + """Creates a bound variable with a given index that can be instantiated for this + parameter. + """ + from guppylang.tys.ty import BoundTypeVar + + if idx is None: + idx = self.idx + return TypeArg(BoundTypeVar(self.name, idx, self.can_be_linear)) + + def to_hugr(self) -> tys.TypeParam: + """Computes the Hugr representation of the parameter.""" + return tys.TypeTypeParam( + b=tys.TypeBound.Any if self.can_be_linear else TypeBound.Copyable + ) + + +@dataclass(frozen=True) +class ConstParam(ParameterBase): + """A parameter of kind constant. Used to define fixed-size arrays etc. + + Note that support for this kind is not implemented yet. + """ + + ty: "Type" + + def __post_init__(self) -> None: + if self.ty.unsolved_vars: + raise InternalGuppyError( + "Attempted to create constant param with unsolved type" + ) + + def with_idx(self, idx: int) -> "ConstParam": + """Returns a copy of the parameter with a new index.""" + return ConstParam(idx, self.name, self.ty) + + def check_arg(self, arg: Argument, loc: AstNode | None = None) -> ConstArg: + """Checks that this parameter can be instantiated with a given argument. + + Raises a user error if the argument is not valid. + """ + raise NotImplementedError + + def to_existential(self) -> tuple[Argument, ExistentialVar]: + """Creates a fresh existential variable that can be instantiated for this + parameter. + + Returns both the argument and the created variable. + """ + raise NotImplementedError + + def to_bound(self, idx: int | None = None) -> Argument: + """Creates a bound variable with a given index that can be instantiated for this + parameter. + """ + raise NotImplementedError + + def to_hugr(self) -> tys.TypeParam: + """Computes the Hugr representation of the parameter.""" + raise NotImplementedError diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py new file mode 100644 index 00000000..a587ccd3 --- /dev/null +++ b/guppylang/tys/parsing.py @@ -0,0 +1,116 @@ +import ast +from collections.abc import Sequence + +from guppylang.ast_util import AstNode +from guppylang.checker.core import Globals +from guppylang.error import GuppyError +from guppylang.tys.arg import Argument, TypeArg +from guppylang.tys.param import Parameter, TypeParam +from guppylang.tys.ty import NoneType, TupleType, Type + + +def arg_from_ast( + node: AstNode, + globals: Globals, + param_var_mapping: dict[str, Parameter] | None = None, +) -> Argument: + """Turns an AST expression into an argument.""" + # A single identifier + if isinstance(node, ast.Name): + x = node.id + # Either a defined type (e.g. `int`, `bool`, ...) + if x in globals.type_defs: + ty = globals.type_defs[x].check_instantiate([], node) + return TypeArg(ty) + # Or a parameter (e.g. `T`, `n`, ...) + if x in globals.param_vars: + if param_var_mapping is None: + raise GuppyError( + "Free type variable. Only function types can be generic", node + ) + var = globals.param_vars[x] + if var.name not in param_var_mapping: + param_var_mapping[var.name] = var.with_idx(len(param_var_mapping)) + return param_var_mapping[var.name].to_bound() + raise GuppyError("Unknown identifier", node) + + # A parametrised type, e.g. `list[??]` + if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name): + x = node.value.id + if x in globals.type_defs: + arg_nodes = ( + node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice] + ) + # Hack: Flatten argument lists to support the `Callable` type. For example, + # we turn `Callable[[int, int], bool]` into `Callable[int, int, bool]`. + # TODO: We can get rid of this once we added support for variadic params + arg_nodes = [ + n + for arg in arg_nodes + for n in (arg.elts if isinstance(arg, ast.List) else (arg,)) + ] + args = [ + arg_from_ast(arg_node, globals, param_var_mapping) + for arg_node in arg_nodes + ] + ty = globals.type_defs[x].check_instantiate(args, node) + return TypeArg(ty) + # We don't allow parametrised variables like `T[int]` + if x in globals.param_vars: + raise GuppyError( + f"Variable `{x}` is not parameterized. Higher-kinded types are not " + f"supported", + node, + ) + + # We allow tuple types to be written as `(int, bool)` + if isinstance(node, ast.Tuple): + ty = TupleType( + [type_from_ast(el, globals, param_var_mapping) for el in node.elts] + ) + return TypeArg(ty) + + # `None` is represented as a `ast.Constant` node with value `None` + if isinstance(node, ast.Constant) and node.value is None: + return TypeArg(NoneType()) + + # Finally, we also support delayed annotations in strings + if isinstance(node, ast.Constant) and isinstance(node.value, str): + try: + [stmt] = ast.parse(node.value).body + if not isinstance(stmt, ast.Expr): + raise GuppyError("Invalid Guppy type", node) + return arg_from_ast(stmt.value, globals, param_var_mapping) + except (SyntaxError, ValueError): + raise GuppyError("Invalid Guppy type", node) from None + + raise GuppyError("Not a valid type argument", node) + + +_type_param = TypeParam(0, "T", True) + + +def type_from_ast( + node: AstNode, + globals: Globals, + param_var_mapping: dict[str, Parameter] | None = None, +) -> Type: + """Turns an AST expression into a Guppy type.""" + # Parse an argument and check that it's valid for a `TypeParam` + arg = arg_from_ast(node, globals, param_var_mapping) + return _type_param.check_arg(arg, node).ty + + +def type_row_from_ast(node: ast.expr, globals: "Globals") -> Sequence[Type]: + """Turns an AST expression into a Guppy type row. + + This is needed to interpret the return type annotation of functions. + """ + # The return type `-> None` is represented in the ast as `ast.Constant(value=None)` + if isinstance(node, ast.Constant) and node.value is None: + return [] + ty = type_from_ast(node, globals) + if isinstance(ty, TupleType): + return ty.element_types + else: + return [ty] diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py new file mode 100644 index 00000000..252efb54 --- /dev/null +++ b/guppylang/tys/printing.py @@ -0,0 +1,126 @@ +from functools import singledispatchmethod + +from guppylang.error import InternalGuppyError +from guppylang.tys.arg import ConstArg, TypeArg +from guppylang.tys.param import ConstParam, TypeParam +from guppylang.tys.ty import ( + FunctionType, + NoneType, + OpaqueType, + SumType, + TupleType, + Type, +) +from guppylang.tys.var import BoundVar, ExistentialVar, UniqueId + + +class TypePrinter: + """Visitor that pretty prints types. + + Takes care of inserting minimal parentheses and renaming variables to make them + unique. + """ + + # Store how often each user-picked display name is used to stand for different + # variables + used: dict[str, int] + + # Already chosen names for bound and existential variables + bound_names: list[str] + existential_names: dict[UniqueId, str] + + # Count how often the user has picked the same name to stand for different variables + counter: dict[str, int] + + def __init__(self) -> None: + self.used = {} + self.bound_names = [] + self.existential_names = {} + self.counter = {} + + def _fresh_name(self, display_name: str) -> str: + if display_name not in self.counter: + self.counter[display_name] = 1 + return display_name + + # If the display name `T` has already been used, we start adding indices: `T`, + # `T'1`, `T'2`, ... + indexed = f"{display_name}'{self.counter[display_name]}" + self.counter[display_name] += 1 + return indexed + + def visit(self, ty: Type) -> str: + return self._visit(ty, False) + + @singledispatchmethod + def _visit(self, ty: Type, inside_row: bool) -> str: + raise InternalGuppyError(f"Tried to pretty-print unknown type: {ty!r}") + + @_visit.register + def _visit_BoundVar(self, var: BoundVar, inside_row: bool) -> str: + return self.bound_names[var.idx] + + @_visit.register + def _visit_ExistentialVar(self, var: ExistentialVar, inside_row: bool) -> str: + if var.id not in self.existential_names: + self.existential_names[var.id] = self._fresh_name(var.display_name) + return f"?{self.existential_names[var.id]}" + + @_visit.register + def _visit_FunctionType(self, ty: FunctionType, inside_row: bool) -> str: + if ty.parametrized: + for p in ty.params: + self.bound_names.append(self._fresh_name(p.name)) + inputs = ", ".join([self._visit(inp, True) for inp in ty.inputs]) + if len(ty.inputs) != 1: + inputs = f"({inputs})" + output = self._visit(ty.output, True) + if ty.parametrized: + quantified = ", ".join([self._visit(param, False) for param in ty.params]) + del self.bound_names[: -len(ty.params)] + return _wrap(f"forall {quantified}. {inputs} -> {output}", inside_row) + return _wrap(f"{inputs} -> {output}", inside_row) + + @_visit.register + def _visit_OpaqueType(self, ty: OpaqueType, inside_row: bool) -> str: + if ty.args: + args = ", ".join(self._visit(arg, True) for arg in ty.args) + return f"{ty.defn.name}[{args}]" + return ty.defn.name + + @_visit.register + def _visit_TupleType(self, ty: TupleType, inside_row: bool) -> str: + args = ", ".join(self._visit(arg, True) for arg in ty.args) + return f"({args})" + + @_visit.register + def _visit_SumType(self, ty: SumType, inside_row: bool) -> str: + args = ", ".join(self._visit(arg, True) for arg in ty.args) + return f"Sum[{args}]" + + @_visit.register + def _visit_NoneType(self, ty: NoneType, inside_row: bool) -> str: + return "None" + + @_visit.register + def _visit_TypeParam(self, param: TypeParam, inside_row: bool) -> str: + # TODO: Print linearity? + return self.bound_names[-param.idx - 1] + + @_visit.register + def _visit_ConstParam(self, param: ConstParam, inside_row: bool) -> str: + kind = self._visit(param.ty, True) + name = self.bound_names[-param.idx - 1] + return f"{name}: {kind}" + + @_visit.register + def _visit_TypeArg(self, arg: TypeArg, inside_row: bool) -> str: + return self._visit(arg.ty, inside_row) + + @_visit.register + def _visit_ConstArg(self, arg: ConstArg, inside_row: bool) -> str: + return self._visit(arg.const, inside_row) + + +def _wrap(s: str, inside_row: bool) -> str: + return f"({s})" if inside_row else s diff --git a/guppylang/tys/subst.py b/guppylang/tys/subst.py new file mode 100644 index 00000000..118b3f50 --- /dev/null +++ b/guppylang/tys/subst.py @@ -0,0 +1,64 @@ +import functools +from collections.abc import Sequence +from typing import Any + +from guppylang.error import InternalGuppyError +from guppylang.tys.arg import Argument, TypeArg +from guppylang.tys.common import Transformer +from guppylang.tys.const import BoundConstVar, ExistentialConstVar +from guppylang.tys.ty import BoundTypeVar, ExistentialTypeVar, FunctionType, Type +from guppylang.tys.var import ExistentialVar + +Subst = dict[ExistentialVar, Type] # TODO: `GuppyType | Const` or `Argument` ?? +Inst = Sequence[Argument] + + +class Substituter(Transformer): + """Type transformer that applies a substitution of existential variables.""" + + def __init__(self, subst: Subst) -> None: + self.subst = subst + + @functools.singledispatchmethod + def transform(self, ty: Any) -> Any | None: # type: ignore[override] + return None + + @transform.register + def _transform_ExistentialTypeVar(self, ty: ExistentialTypeVar) -> Type | None: + return self.subst.get(ty, None) + + @transform.register + def _transform_ExistentialConstVar(self, ty: ExistentialConstVar) -> Type | None: + raise NotImplementedError + + +class Instantiator(Transformer): + """Type transformer that instantiates bound variables.""" + + def __init__(self, inst: Inst) -> None: + self.inst = inst + + @functools.singledispatchmethod + def transform(self, ty: Any) -> Any | None: # type: ignore[override] + return None + + @transform.register + def _transform_BoundTypeVar(self, ty: BoundTypeVar) -> Type | None: + # Instantiate if type for the index is available + if ty.idx < len(self.inst): + arg = self.inst[ty.idx] + assert isinstance(arg, TypeArg) + return arg.ty + + # Otherwise, lower the de Bruijn index + return BoundTypeVar(ty.display_name, ty.idx - len(self.inst), ty.linear) + + @transform.register + def _transform_BoundConstVar(self, ty: BoundConstVar) -> Type | None: + raise NotImplementedError + + @transform.register + def _transform_FunctionType(self, ty: FunctionType) -> Type | None: + if ty.parametrized: + raise InternalGuppyError("Tried to instantiate under binder") + return None diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py new file mode 100644 index 00000000..973dcef2 --- /dev/null +++ b/guppylang/tys/ty.py @@ -0,0 +1,516 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from dataclasses import dataclass, field +from functools import cached_property +from typing import TYPE_CHECKING, TypeAlias, cast + +from guppylang.error import InternalGuppyError +from guppylang.hugr import tys +from guppylang.hugr.tys import TypeBound +from guppylang.tys.arg import Argument, ConstArg, TypeArg +from guppylang.tys.common import ToHugr, Transformable, Transformer, Visitor +from guppylang.tys.const import ExistentialConstVar +from guppylang.tys.param import Parameter +from guppylang.tys.var import BoundVar, ExistentialVar + +if TYPE_CHECKING: + from guppylang.tys.definition import OpaqueTypeDef + from guppylang.tys.subst import Inst, Subst + + +@dataclass(frozen=True) +class TypeBase(ToHugr[tys.Type], Transformable["Type"], ABC): + """Abstract base class for all Guppy types. + + Note that all subclasses are expected to be immutable. + """ + + @cached_property + @abstractmethod + def linear(self) -> bool: + """Whether this type should be treated linearly.""" + + @cached_property + @abstractmethod + def hugr_bound(self) -> tys.TypeBound: + """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`. + + This needs to be specified explicitly, since opaque nonlinear types in a Hugr + extension could be either declared as copyable or equatable. If we don't get the + bound exactly right during serialisation, the Hugr validator will complain. + """ + + @cached_property + def unsolved_vars(self) -> set[ExistentialVar]: + """The existential type variables contained in this type.""" + return set() + + def substitute(self, subst: "Subst") -> "Type": + """Substitutes existential variables in this type.""" + from guppylang.tys.subst import Substituter + + return self.transform(Substituter(subst)) + + def __str__(self) -> str: + """Returns a human-readable representation of the type.""" + from guppylang.tys.printing import TypePrinter + + # We use a custom printer that takes care of inserting parentheses and choosing + # unique names + return TypePrinter().visit(cast(Type, self)) + + +@dataclass(frozen=True) +class ParametrizedTypeBase(TypeBase, ABC): + """Abstract base class for types that depend on parameters. + + For example, `list`, `tuple`, etc. require arguments in order to be turned into a + proper type. + + Note that all subclasses are expected to be immutable. + """ + + args: Sequence[Argument] + + def __post_init__(self) -> None: + # Make sure that we don't have nested generic functions + for arg in self.args: + match arg: + case TypeArg(ty=FunctionType(parametrized=True)): + raise InternalGuppyError( + "Tried to construct a higher-rank polymorphic type!" + ) + + @property + @abstractmethod + def intrinsically_linear(self) -> bool: + """Whether this type is linear, independent of the arguments. + + For example, a parametrized struct containing a qubit is linear, no matter what + the arguments are. + """ + + @cached_property + def linear(self) -> bool: + """Whether this type should be treated linearly.""" + return self.intrinsically_linear or any( + isinstance(arg, TypeArg) and arg.ty.linear for arg in self.args + ) + + @cached_property + def unsolved_vars(self) -> set[ExistentialVar]: + """The existential type variables contained in this type.""" + unsolved = set() + for arg in self.args: + match arg: + case TypeArg(ty): + unsolved |= ty.unsolved_vars + case ConstArg(c) if isinstance(c, ExistentialConstVar): + unsolved.add(c) + return unsolved + + @cached_property + def hugr_bound(self) -> tys.TypeBound: + """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.""" + if self.linear: + return tys.TypeBound.Any + return tys.TypeBound.join( + *(arg.ty.hugr_bound for arg in self.args if isinstance(arg, TypeArg)) + ) + + def visit(self, visitor: Visitor) -> None: + """Accepts a visitor on this type.""" + if not visitor.visit(self): + for arg in self.args: + visitor.visit(arg) + + +@dataclass(frozen=True) +class BoundTypeVar(TypeBase, BoundVar): + """Bound type variable, referencing a parameter of kind `Type`. + + For example, in the function type `forall T. list[T] -> T` we represent `T` as a + `BoundTypeVar(idx=0)`. + + A bound type variables can be instantiated with a `TypeArg` argument. + """ + + linear: bool + + @cached_property + def hugr_bound(self) -> tys.TypeBound: + """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.""" + if self.linear: + return TypeBound.Any + # We're conservative and don't require equatability for non-linear variables. + # This is fine since Guppy doesn't use the equatable feature anyways. + return TypeBound.Copyable + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + return tys.Variable(i=self.idx, b=self.hugr_bound) + + def visit(self, visitor: Visitor) -> None: + """Accepts a visitor on this type.""" + visitor.visit(self) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or self + + def __str__(self) -> str: + """Returns a human-readable representation of the type.""" + return self.display_name + + +@dataclass(frozen=True) +class ExistentialTypeVar(ExistentialVar, TypeBase): + """Existential type variable. + + For example, the empty list literal `[]` is typed as `list[?T]` where `?T` stands + for an existential type variable. + + During type checking we try to solve all existential type variables and substitute + them with concrete types. + """ + + linear: bool + + @classmethod + def fresh(cls, display_name: str, linear: bool) -> "ExistentialTypeVar": + return ExistentialTypeVar(display_name, next(cls._fresh_id), linear) + + @cached_property + def unsolved_vars(self) -> set[ExistentialVar]: + """The existential type variables contained in this type.""" + return {self} + + @cached_property + def hugr_bound(self) -> tys.TypeBound: + """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.""" + raise InternalGuppyError( + "Tried to compute bound of unsolved existential type variable" + ) + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + raise InternalGuppyError( + "Tried to convert unsolved existential type variable to Hugr" + ) + + def visit(self, visitor: Visitor) -> None: + """Accepts a visitor on this type.""" + visitor.visit(self) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or self + + +@dataclass(frozen=True) +class NoneType(TypeBase): + """Type of tuples.""" + + linear: bool = field(default=False, init=False) + hugr_bound: tys.TypeBound = field(default=tys.TypeBound.Eq, init=False) + + # Flag to avoid turning the type into a row when calling `type_to_row()`. This is + # used to make sure that type vars instantiated to Nones are not broken up into + # empty rows when generating a Hugr + preserve: bool = field(default=False, compare=False) + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + return tys.TupleType(inner=[]) + + def visit(self, visitor: Visitor) -> None: + """Accepts a visitor on this type.""" + visitor.visit(self) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or self + + +@dataclass(frozen=True, init=False) +class FunctionType(ParametrizedTypeBase): + """Type of (potentially generic) functions.""" + + inputs: Sequence["Type"] + output: "Type" + params: Sequence[Parameter] + input_names: Sequence[str] | None + + args: Sequence[Argument] = field(init=False) + linear: bool = field(default=False, init=False) + intrinsically_linear: bool = field(default=False, init=False) + hugr_bound: tys.TypeBound = field(default=TypeBound.Copyable, init=False) + + def __init__( + self, + inputs: Sequence["Type"], + output: "Type", + input_names: Sequence[str] | None = None, + params: Sequence[Parameter] | None = None, + ) -> None: + # We need a custom __init__ to set the args + args = [TypeArg(ty) for ty in inputs] + args.append(TypeArg(output)) + object.__setattr__(self, "args", args) + object.__setattr__(self, "inputs", inputs) + object.__setattr__(self, "output", output) + object.__setattr__(self, "input_names", input_names or []) + object.__setattr__(self, "params", params or []) + + @property + def parametrized(self) -> bool: + """Whether the function is parametrized.""" + return len(self.params) > 0 + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + ins = [t.to_hugr() for t in self.inputs] + outs = [t.to_hugr() for t in type_to_row(self.output)] + func_ty = tys.FunctionType(input=ins, output=outs, extension_reqs=[]) + return tys.PolyFuncType(params=[p.to_hugr() for p in self.params], body=func_ty) + + def visit(self, visitor: Visitor) -> None: + """Accepts a visitor on this type.""" + if not visitor.visit(self): + for inp in self.inputs: + visitor.visit(inp) + visitor.visit(self.output) + for param in self.params: + visitor.visit(param) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or FunctionType( + [inp.transform(transformer) for inp in self.inputs], + self.output.transform(transformer), + self.input_names, + self.params, + ) + + def instantiate(self, args: "Inst") -> "FunctionType": + """Instantiates all function parameters with concrete types.""" + from guppylang.tys.subst import Instantiator + + assert len(args) == len(self.params) + + # Set the `preserve` flag for instantiated tuples and None + preserved_args: list[Argument] = [] + for arg in args: + if isinstance(arg, TypeArg): + if isinstance(arg.ty, TupleType): + arg = TypeArg(TupleType(arg.ty.element_types, preserve=True)) + elif isinstance(arg.ty, NoneType): + arg = TypeArg(NoneType(preserve=True)) + preserved_args.append(arg) + + inst = Instantiator(preserved_args) + return FunctionType( + [ty.transform(inst) for ty in self.inputs], + self.output.transform(inst), + self.input_names, + ) + + def unquantified(self) -> tuple["FunctionType", Sequence[ExistentialVar]]: + """Instantiates all parameters with existential variables.""" + exs = [param.to_existential() for param in self.params] + return self.instantiate([arg for arg, _ in exs]), [var for _, var in exs] + + +@dataclass(frozen=True, init=False) +class TupleType(ParametrizedTypeBase): + """Type of tuples.""" + + element_types: Sequence["Type"] + + # Flag to avoid turning the tuple into a row when calling `type_to_row()`. This is + # used to make sure that type vars instantiated to tuples are not broken up into + # rows when generating a Hugr + preserve: bool = field(default=False, compare=False) + + def __init__(self, element_types: Sequence["Type"], preserve: bool = False) -> None: + # We need a custom __init__ to set the args + args = [TypeArg(ty) for ty in element_types] + object.__setattr__(self, "args", args) + object.__setattr__(self, "element_types", element_types) + object.__setattr__(self, "preserve", preserve) + + @property + def intrinsically_linear(self) -> bool: + return False + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + return tys.TupleType(inner=[ty.to_hugr() for ty in self.element_types]) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or TupleType( + [ty.transform(transformer) for ty in self.element_types], self.preserve + ) + + +@dataclass(frozen=True, init=False) +class SumType(ParametrizedTypeBase): + """Type of sums. + + Note that this type is only used internally when constructing the Hugr. Users cannot + write down this type. + """ + + element_types: Sequence["Type"] + + def __init__(self, element_types: Sequence["Type"]) -> None: + # We need a custom __init__ to set the args + args = [TypeArg(ty) for ty in element_types] + object.__setattr__(self, "args", args) + object.__setattr__(self, "element_types", element_types) + + @property + def intrinsically_linear(self) -> bool: + return False + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + if all( + isinstance(e, TupleType) and len(e.element_types) == 0 + for e in self.element_types + ): + return tys.UnitSum(size=len(self.element_types)) + return tys.GeneralSum(row=[t.to_hugr() for t in self.element_types]) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or SumType( + [ty.transform(transformer) for ty in self.element_types] + ) + + +@dataclass(frozen=True) +class OpaqueType(ParametrizedTypeBase): + """Type that is directly backed by a Hugr opaque type. + + For example, many builtin types like `int`, `float`, `list` etc. are directly backed + by a Hugr extension. + """ + + defn: "OpaqueTypeDef" + + @property + def intrinsically_linear(self) -> bool: + """Whether this type is linear, independent of the arguments.""" + return self.defn.always_linear + + @property + def hugr_bound(self) -> tys.TypeBound: + """The Hugr bound of this type, i.e. `Any`, `Copyable`, or `Equatable`.""" + if self.defn.bound is not None: + return self.defn.bound + return super().hugr_bound + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + return self.defn.to_hugr(self.args) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or OpaqueType( + [arg.transform(transformer) for arg in self.args], self.defn + ) + + +# We define the `Type` type as a union of all `TypeBase` subclasses defined above. This +# models an algebraic data type and enables exhaustiveness checking in pattern matches +# etc. +# Note that this might become obsolete in case the `@sealed` decorator is added: +# * https://peps.python.org/pep-0622/#sealed-classes-as-algebraic-data-types +# * https://github.com/johnthagen/sealed-typing-pep +ParametrizedType: TypeAlias = FunctionType | TupleType | SumType | OpaqueType +Type: TypeAlias = BoundTypeVar | ExistentialTypeVar | NoneType | ParametrizedType + +TypeRow: TypeAlias = Sequence[Type] + + +def row_to_type(row: TypeRow) -> Type: + """Turns a row of types into a single type by packing into a tuple.""" + if len(row) == 0: + return NoneType() + elif len(row) == 1: + return row[0] + else: + return TupleType(row) + + +def type_to_row(ty: Type) -> TypeRow: + """Turns a type into a row of types by unpacking top-level tuples.""" + if isinstance(ty, NoneType) and not ty.preserve: + return [] + if isinstance(ty, TupleType) and not ty.preserve: + return ty.element_types + return [ty] + + +def unify(s: Type, t: Type, subst: "Subst | None") -> "Subst | None": + """Computes a most general unifier for two types. + + Return a substitutions `subst` such that `s[subst] == t[subst]` or `None` if this + not possible. + """ + if subst is None: + return None + match s, t: + case ExistentialTypeVar(id=s_id), ExistentialTypeVar(id=t_id) if s_id == t_id: + return subst + case ExistentialTypeVar() as s, t: + return _unify_var(s, t, subst) + case s, ExistentialTypeVar() as t: + return _unify_var(t, s, subst) + case BoundTypeVar(idx=s_idx), BoundTypeVar(idx=t_idx) if s_idx == t_idx: + return subst + case NoneType(), NoneType(): + return subst + case FunctionType() as s, FunctionType() as t if s.params == t.params: + return _unify_args(s, t, subst) + case TupleType() as s, TupleType() as t: + return _unify_args(s, t, subst) + case SumType() as s, SumType() as t: + return _unify_args(s, t, subst) + case OpaqueType() as s, OpaqueType() as t if s.defn == t.defn: + return _unify_args(s, t, subst) + case _: + return None + + +def _unify_var(var: ExistentialTypeVar, t: Type, subst: "Subst") -> "Subst | None": + """Helper function for unification of type variables.""" + if var in subst: + return unify(subst[var], t, subst) + if isinstance(t, ExistentialTypeVar) and t in subst: + return unify(var, subst[t], subst) + if var in t.unsolved_vars: + return None + return {var: t, **subst} + + +def _unify_args( + s: ParametrizedType, t: ParametrizedType, subst: "Subst" +) -> "Subst | None": + """Helper function for unification of type arguments of parametrised types.""" + if len(s.args) != len(t.args): + return None + for sa, ta in zip(s.args, t.args, strict=True): + match sa, ta: + case TypeArg(ty=sa_ty), TypeArg(ty=ta_ty): + res = unify(sa_ty, ta_ty, subst) + if res is None: + return None + subst = res + case ConstArg(), ConstArg(): + raise NotImplementedError + case _: + return None + return subst diff --git a/guppylang/tys/var.py b/guppylang/tys/var.py new file mode 100644 index 00000000..43555d6f --- /dev/null +++ b/guppylang/tys/var.py @@ -0,0 +1,49 @@ +import itertools +from abc import ABC +from collections.abc import Iterator +from dataclasses import dataclass +from typing import ClassVar + +# Type of de Bruijn indicies +DeBruijn = int + +# Type of unique variable identifiers +UniqueId = int + + +@dataclass(frozen=True) +class Var(ABC): + """Abstract base class for variables that occur in types. + + A variable can either occur as a type itself (see subclasses `BoundTypeVar` and + `ExistentialTypeVar`) or as an argument to a parametrised type. + """ + + # Name that is used when showing the variable to the user + display_name: str + + +@dataclass(frozen=True) +class BoundVar(Var, ABC): + """Variable that is bound to a parameter of kind `Const`. + + Identified by a de Bruijn index. + """ + + idx: DeBruijn + + +@dataclass(frozen=True) +class ExistentialVar(Var, ABC): + """Existential variable, referencing a parameter of kind `Const`. + + Identified by a globally unique id. + + During type checking we try to solve all existential variables and substitute + them with concrete consts. + """ + + id: UniqueId + + # Generator of fresh unique ids + _fresh_id: ClassVar[Iterator[UniqueId]] = itertools.count() diff --git a/tests/error/poly_errors/non_linear2.err b/tests/error/poly_errors/non_linear2.err index f45e934f..5c6827f2 100644 --- a/tests/error/poly_errors/non_linear2.err +++ b/tests/error/poly_errors/non_linear2.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:23 22: def main() -> None: 23: foo(h) ^^^^^^ -GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T. T -> T -> None` with linear type `qubit` +GuppyTypeError: Cannot instantiate non-linear type variable `T` in type `forall T. (T -> T) -> None` with linear type `qubit` diff --git a/tests/error/poly_errors/pass_poly_free.err b/tests/error/poly_errors/pass_poly_free.err index 5154e2df..945cafb0 100644 --- a/tests/error/poly_errors/pass_poly_free.err +++ b/tests/error/poly_errors/pass_poly_free.err @@ -4,4 +4,4 @@ Guppy compilation failed. Error in file $FILE:24 23: def main() -> None: 24: foo(bar) ^^^ -GuppyTypeInferenceError: Expected argument of type `?T -> ?T`, got `forall T. T -> T`. Couldn't infer an instantiation for type variable `T` (higher-rank polymorphic types are not supported) +GuppyTypeInferenceError: Expected argument of type `?T -> ?T`, got `forall T. T -> T`. Couldn't infer an instantiation for parameter `T` (higher-rank polymorphic types are not supported) diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py index c7e4e266..009c1a4e 100644 --- a/tests/hugr/test_dummy_nodes.py +++ b/tests/hugr/test_dummy_nodes.py @@ -1,15 +1,16 @@ -from guppylang.gtypes import FunctionType, BoolType, TupleType +from guppylang.tys.definition import bool_type +from guppylang.tys.ty import FunctionType, TupleType from guppylang.hugr import ops from guppylang.hugr.hugr import Hugr def test_single_dummy(): g = Hugr() - defn = g.add_def(FunctionType([BoolType()], BoolType()), g.root, "test") + defn = g.add_def(FunctionType([bool_type()], bool_type()), g.root, "test") dfg = g.add_dfg(defn) - inp = g.add_input([BoolType()], dfg).out_port(0) + inp = g.add_input([bool_type()], dfg).out_port(0) dummy = g.add_node( - ops.DummyOp(name="dummy"), inputs=[inp], output_types=[BoolType()], parent=dfg + ops.DummyOp(name="dummy"), inputs=[inp], output_types=[bool_type()], parent=dfg ) g.add_output([dummy.out_port(0)], parent=dfg) @@ -21,15 +22,15 @@ def test_single_dummy(): def test_unique_names(): g = Hugr() defn = g.add_def( - FunctionType([BoolType()], TupleType([BoolType(), BoolType()])), g.root, "test" + FunctionType([bool_type()], TupleType([bool_type(), bool_type()])), g.root, "test" ) dfg = g.add_dfg(defn) - inp = g.add_input([BoolType()], dfg).out_port(0) + inp = g.add_input([bool_type()], dfg).out_port(0) dummy1 = g.add_node( - ops.DummyOp(name="dummy"), inputs=[inp], output_types=[BoolType()], parent=dfg + ops.DummyOp(name="dummy"), inputs=[inp], output_types=[bool_type()], parent=dfg ) dummy2 = g.add_node( - ops.DummyOp(name="dummy"), inputs=[inp], output_types=[BoolType()], parent=dfg + ops.DummyOp(name="dummy"), inputs=[inp], output_types=[bool_type()], parent=dfg ) g.add_output([dummy1.out_port(0), dummy2.out_port(0)], parent=dfg) diff --git a/tests/hugr/test_ports.py b/tests/hugr/test_ports.py deleted file mode 100644 index 5d04671a..00000000 --- a/tests/hugr/test_ports.py +++ /dev/null @@ -1,14 +0,0 @@ -import pytest - -from guppylang.error import UndefinedPort, InternalGuppyError -from guppylang.gtypes import BoolType - - -def test_undefined_port(): - ty = BoolType() - p = UndefinedPort(ty) - assert p.ty == ty - with pytest.raises(InternalGuppyError, match="Tried to access undefined Port"): - p.node - with pytest.raises(InternalGuppyError, match="Tried to access undefined Port"): - p.offset