From f7adb85bfbc7498047471cdf6b232c6b5056e19e Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Wed, 15 May 2024 14:47:35 +0100 Subject: [PATCH] feat: Add struct types (#207) --- guppylang/ast_util.py | 19 ++ guppylang/checker/core.py | 3 + guppylang/checker/func_checker.py | 8 +- guppylang/decorator.py | 15 + guppylang/definition/common.py | 2 +- guppylang/definition/struct.py | 269 ++++++++++++++++++ guppylang/definition/ty.py | 17 +- guppylang/module.py | 2 +- guppylang/tys/param.py | 29 +- guppylang/tys/parsing.py | 23 +- guppylang/tys/printing.py | 8 +- guppylang/tys/ty.py | 39 ++- tests/error/struct_errors/__init__.py | 0 tests/error/struct_errors/default.err | 7 + tests/error/struct_errors/default.py | 13 + tests/error/struct_errors/duplicate_field.err | 7 + tests/error/struct_errors/duplicate_field.py | 14 + .../struct_errors/func_overrides_field1.err | 7 + .../struct_errors/func_overrides_field1.py | 17 ++ .../struct_errors/func_overrides_field2.err | 7 + .../struct_errors/func_overrides_field2.py | 17 ++ tests/error/struct_errors/inheritance.err | 6 + tests/error/struct_errors/inheritance.py | 13 + tests/error/struct_errors/invalid_generic.err | 6 + tests/error/struct_errors/invalid_generic.py | 18 ++ .../struct_errors/invalid_instantiate1.err | 6 + .../struct_errors/invalid_instantiate1.py | 21 ++ .../struct_errors/invalid_instantiate2.err | 6 + .../struct_errors/invalid_instantiate2.py | 21 ++ tests/error/struct_errors/keywords.err | 6 + tests/error/struct_errors/keywords.py | 13 + .../error/struct_errors/mutual_recursive.err | 7 + tests/error/struct_errors/mutual_recursive.py | 18 ++ tests/error/struct_errors/non_guppy_func.err | 7 + tests/error/struct_errors/non_guppy_func.py | 16 ++ tests/error/struct_errors/recursive.err | 7 + tests/error/struct_errors/recursive.py | 13 + tests/error/struct_errors/stray_docstring.err | 7 + tests/error/struct_errors/stray_docstring.py | 15 + tests/error/struct_errors/type_missing1.err | 7 + tests/error/struct_errors/type_missing1.py | 16 ++ tests/error/struct_errors/type_missing2.err | 7 + tests/error/struct_errors/type_missing2.py | 13 + tests/error/test_struct_errors.py | 20 ++ tests/error/util.py | 4 +- tests/integration/test_struct.py | 97 +++++++ 46 files changed, 858 insertions(+), 35 deletions(-) create mode 100644 guppylang/definition/struct.py create mode 100644 tests/error/struct_errors/__init__.py create mode 100644 tests/error/struct_errors/default.err create mode 100644 tests/error/struct_errors/default.py create mode 100644 tests/error/struct_errors/duplicate_field.err create mode 100644 tests/error/struct_errors/duplicate_field.py create mode 100644 tests/error/struct_errors/func_overrides_field1.err create mode 100644 tests/error/struct_errors/func_overrides_field1.py create mode 100644 tests/error/struct_errors/func_overrides_field2.err create mode 100644 tests/error/struct_errors/func_overrides_field2.py create mode 100644 tests/error/struct_errors/inheritance.err create mode 100644 tests/error/struct_errors/inheritance.py create mode 100644 tests/error/struct_errors/invalid_generic.err create mode 100644 tests/error/struct_errors/invalid_generic.py create mode 100644 tests/error/struct_errors/invalid_instantiate1.err create mode 100644 tests/error/struct_errors/invalid_instantiate1.py create mode 100644 tests/error/struct_errors/invalid_instantiate2.err create mode 100644 tests/error/struct_errors/invalid_instantiate2.py create mode 100644 tests/error/struct_errors/keywords.err create mode 100644 tests/error/struct_errors/keywords.py create mode 100644 tests/error/struct_errors/mutual_recursive.err create mode 100644 tests/error/struct_errors/mutual_recursive.py create mode 100644 tests/error/struct_errors/non_guppy_func.err create mode 100644 tests/error/struct_errors/non_guppy_func.py create mode 100644 tests/error/struct_errors/recursive.err create mode 100644 tests/error/struct_errors/recursive.py create mode 100644 tests/error/struct_errors/stray_docstring.err create mode 100644 tests/error/struct_errors/stray_docstring.py create mode 100644 tests/error/struct_errors/type_missing1.err create mode 100644 tests/error/struct_errors/type_missing1.py create mode 100644 tests/error/struct_errors/type_missing2.err create mode 100644 tests/error/struct_errors/type_missing2.py create mode 100644 tests/error/test_struct_errors.py create mode 100644 tests/integration/test_struct.py diff --git a/guppylang/ast_util.py b/guppylang/ast_util.py index 67dd9d40..309496bf 100644 --- a/guppylang/ast_util.py +++ b/guppylang/ast_util.py @@ -250,6 +250,25 @@ def annotate_location( annotate_location(value, source, file, line_offset, recurse) +def shift_loc(node: ast.AST, delta_lineno: int, delta_col_offset: int) -> None: + """Shifts all line and column number in the AST node by the given amount.""" + if hasattr(node, "lineno"): + node.lineno += delta_lineno + if hasattr(node, "end_lineno") and node.end_lineno is not None: + node.end_lineno += delta_lineno + if hasattr(node, "col_offset"): + node.col_offset += delta_col_offset + if hasattr(node, "end_col_offset") and node.end_col_offset is not None: + node.end_col_offset += delta_col_offset + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + shift_loc(item, delta_lineno, delta_col_offset) + elif isinstance(value, ast.AST): + shift_loc(value, delta_lineno, delta_col_offset) + + def get_file(node: AstNode) -> str | None: """Tries to retrieve a file annotation from an AST node.""" try: diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index 8f653ff8..cba8930f 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -25,6 +25,7 @@ FunctionType, NoneType, OpaqueType, + StructType, SumType, TupleType, Type, @@ -88,6 +89,8 @@ def get_instance_func(self, ty: Type | TypeDef, name: str) -> CallableDef | None type_defn = callable_type_def case OpaqueType() as ty: type_defn = ty.defn + case StructType() as ty: + type_defn = ty.defn case TupleType(): type_defn = tuple_type_def case NoneType(): diff --git a/guppylang/checker/func_checker.py b/guppylang/checker/func_checker.py index 9538723c..072fb5ef 100644 --- a/guppylang/checker/func_checker.py +++ b/guppylang/checker/func_checker.py @@ -152,20 +152,20 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType raise GuppyError("Return type must be annotated", func_def) # TODO: Prepopulate mapping when using Python 3.12 style generic functions - type_var_mapping: dict[DefId, "Parameter"] = {} + param_var_mapping: dict[str, "Parameter"] = {} input_tys = [] input_names = [] for inp in func_def.args.args: if inp.annotation is None: raise GuppyError("Argument type must be annotated", inp) - ty = type_from_ast(inp.annotation, globals, type_var_mapping) + ty = type_from_ast(inp.annotation, globals, param_var_mapping) input_tys.append(ty) input_names.append(inp.arg) - ret_type = type_from_ast(func_def.returns, globals, type_var_mapping) + ret_type = type_from_ast(func_def.returns, globals, param_var_mapping) return FunctionType( input_tys, ret_type, input_names, - sorted(type_var_mapping.values(), key=lambda v: v.idx), + sorted(param_var_mapping.values(), key=lambda v: v.idx), ) diff --git a/guppylang/decorator.py b/guppylang/decorator.py index a2be4918..4339f787 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -18,6 +18,7 @@ from guppylang.definition.declaration import RawFunctionDecl from guppylang.definition.function import RawFunctionDef, parse_py_func from guppylang.definition.parameter import TypeVarDef +from guppylang.definition.struct import RawStructDef from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import GuppyError, MissingModuleError, pretty_errors from guppylang.hugr import ops, tys @@ -28,6 +29,7 @@ FuncDeclDecorator = Callable[[PyFunc], RawFunctionDecl] CustomFuncDecorator = Callable[[PyFunc], RawCustomFunctionDef] ClassDecorator = Callable[[type], type] +StructDecorator = Callable[[type], RawStructDef] @dataclass(frozen=True) @@ -145,6 +147,19 @@ def dec(c: type) -> type: return dec + @pretty_errors + def struct(self, module: GuppyModule) -> StructDecorator: + """Decorator to define a new struct.""" + module._instance_func_buffer = {} + + def dec(cls: type) -> RawStructDef: + defn = RawStructDef(DefId.fresh(module), cls.__name__, None, cls) + module.register_def(defn) + module._register_buffered_instance_funcs(defn) + return defn + + return dec + @pretty_errors def type_var(self, module: GuppyModule, name: str, linear: bool = False) -> TypeVar: """Creates a new type variable in a module.""" diff --git a/guppylang/definition/common.py b/guppylang/definition/common.py index de79a424..f6e3988f 100644 --- a/guppylang/definition/common.py +++ b/guppylang/definition/common.py @@ -47,7 +47,7 @@ class Definition(ABC): id: DefId name: str - defined_at: ast.FunctionDef | None + defined_at: ast.AST | None @property @abstractmethod diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py new file mode 100644 index 00000000..508bc212 --- /dev/null +++ b/guppylang/definition/struct.py @@ -0,0 +1,269 @@ +import ast +import inspect +import textwrap +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any + +from guppylang.ast_util import AstNode, annotate_location +from guppylang.checker.core import Globals +from guppylang.definition.common import ( + CheckableDef, + CompiledDef, + DefId, + Definition, + ParsableDef, +) +from guppylang.definition.parameter import ParamDef +from guppylang.definition.ty import TypeDef +from guppylang.error import GuppyError, InternalGuppyError +from guppylang.tys.arg import Argument +from guppylang.tys.param import Parameter, check_all_args +from guppylang.tys.parsing import type_from_ast +from guppylang.tys.ty import StructType, Type + + +@dataclass(frozen=True) +class UncheckedStructField: + """A single field on a struct whose type has not been checked yet.""" + + name: str + type_ast: ast.expr + + +@dataclass(frozen=True) +class StructField: + """A single field on a struct.""" + + name: str + ty: Type + + +@dataclass(frozen=True) +class RawStructDef(TypeDef, ParsableDef): + """A raw struct type definition that has not been parsed yet.""" + + python_class: type + + def __getitem__(self, item: Any) -> "RawStructDef": + """Dummy implementation to enable subscripting in the Python runtime. + + For example, if users write `MyStruct[int]` in a function signature, the + interpreter will try to execute the expression which would fail if this function + weren't implemented. + """ + return self + + def parse(self, globals: Globals) -> "ParsedStructDef": + """Parses the raw class object into an AST and checks that it is well-formed.""" + cls_def = parse_py_class(self.python_class) + if cls_def.keywords: + raise GuppyError("Unexpected keyword", cls_def.keywords[0]) + + # The only base we allow is `Generic[...]` to specify generic parameters + # TODO: This will become obsolete once we have Python 3.12 style generic classes + params: list[Parameter] + match cls_def.bases: + case []: + params = [] + case [base] if elems := try_parse_generic_base(base): + params = params_from_ast(elems, globals) + case bases: + raise GuppyError("Struct inheritance is not supported", bases[0]) + + fields: list[UncheckedStructField] = [] + used_field_names: set[str] = set() + used_func_names: dict[str, ast.FunctionDef] = {} + for i, node in enumerate(cls_def.body): + match i, node: + # We allow `pass` statements to define empty structs + case _, ast.Pass(): + pass + # Docstrings are also fine if they occur at the start + case 0, ast.Expr(value=ast.Constant(value=v)) if isinstance(v, str): + pass + # Ensure that all function definitions are Guppy functions + case _, ast.FunctionDef(name=name) as node: + v = getattr(self.python_class, name) + if not isinstance(v, Definition): + raise GuppyError( + "Add a `@guppy` decorator to this function to add it to " + f"the struct `{self.name}`", + node, + ) + used_func_names[name] = node + if name in used_field_names: + raise GuppyError( + f"Struct `{self.name}` already contains a field named " + f"`{name}`", + node, + ) + # Struct fields are declared via annotated assignments without value + case _, ast.AnnAssign(target=ast.Name(id=field_name)) as node: + if node.value: + raise GuppyError( + "Default struct values are not supported", node.value + ) + if field_name in used_field_names: + raise GuppyError( + f"Struct `{self.name}` already contains a field named " + f"`{field_name}`", + node.target, + ) + fields.append(UncheckedStructField(field_name, node.annotation)) + used_field_names.add(field_name) + case _, node: + raise GuppyError("Unexpected statement in struct", node) + + # Ensure that functions don't override struct fields + if overriden := used_field_names.intersection(used_func_names.keys()): + x = overriden.pop() + raise GuppyError( + f"Struct `{self.name}` already contains a field named `{x}`", + used_func_names[x], + ) + + return ParsedStructDef(self.id, self.name, cls_def, params, fields) + + def check_instantiate( + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + ) -> Type: + raise InternalGuppyError("Tried to instantiate raw struct definition") + + +@dataclass(frozen=True) +class ParsedStructDef(TypeDef, CheckableDef): + """A struct definition whose fields have not been checked yet.""" + + defined_at: ast.ClassDef + params: Sequence[Parameter] + fields: Sequence[UncheckedStructField] + + def check(self, globals: Globals) -> "CheckedStructDef": + """Checks that all struct fields have valid types.""" + # Before checking the fields, make sure that this definition is not recursive, + # otherwise the code below would not terminate. + # TODO: This is not ideal (see todo in `check_instantiate`) + check_not_recursive(self, globals) + + param_var_mapping = {p.name: p for p in self.params} + fields = [ + StructField(f.name, type_from_ast(f.type_ast, globals, param_var_mapping)) + for f in self.fields + ] + return CheckedStructDef( + self.id, self.name, self.defined_at, self.params, fields + ) + + def check_instantiate( + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + ) -> Type: + """Checks if the struct can be instantiated with the given arguments.""" + check_all_args(self.params, args, self.name, loc) + # Obtain a checked version of this struct definition so we can construct a + # `StructType` instance + # TODO: This is quite bad: If we have a cyclic definition this will not + # terminate, so we have to check for cycles in every call to `check`. The + # proper way to deal with this is changing `StructType` such that it only + # takes a `DefId` instead of a `CheckedStructDef`. But this will be a bigger + # refactor... + checked_def = self.check(globals) + return StructType(args, checked_def) + + +@dataclass(frozen=True) +class CheckedStructDef(TypeDef, CompiledDef): + """A struct definition that has been fully checked.""" + + defined_at: ast.ClassDef + params: Sequence[Parameter] + fields: Sequence[StructField] + + def check_instantiate( + self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None + ) -> Type: + """Checks if the struct can be instantiated with the given arguments.""" + check_all_args(self.params, args, self.name, loc) + return StructType(args, self) + + +def parse_py_class(cls: type) -> ast.ClassDef: + """Parses a Python class object into an AST.""" + source_lines, line_offset = inspect.getsourcelines(cls) + source = "".join(source_lines) # Lines already have trailing \n's + source = textwrap.dedent(source) + cls_ast = ast.parse(source).body[0] + file = inspect.getsourcefile(cls) + if file is None: + raise GuppyError("Couldn't determine source file for class") + annotate_location(cls_ast, source, file, line_offset) + if not isinstance(cls_ast, ast.ClassDef): + raise GuppyError("Expected a class definition", cls_ast) + return cls_ast + + +def try_parse_generic_base(node: ast.expr) -> list[ast.expr] | None: + """Checks if an AST node corresponds to a `Generic[T1, ..., Tn]` base class. + + Returns the generic parameters or `None` if the AST has a different shape + """ + match node: + case ast.Subscript(value=ast.Name(id="Generic"), slice=elem): + return elem.elts if isinstance(elem, ast.Tuple) else [elem] + case _: + return None + + +def params_from_ast(nodes: Sequence[ast.expr], globals: Globals) -> list[Parameter]: + """Parses a list of AST nodes into unique type parameters. + + Raises user errors if the AST nodes don't correspond to parameters or parameters + occur multiple times. + """ + params: list[Parameter] = [] + params_set: set[DefId] = set() + for node in nodes: + if isinstance(node, ast.Name) and node.id in globals: + defn = globals[node.id] + if isinstance(defn, ParamDef): + if defn.id in params_set: + raise GuppyError( + f"Parameter `{node.id}` cannot be used multiple times", node + ) + params.append(defn.to_param(len(params))) + params_set.add(defn.id) + continue + raise GuppyError("Not a parameter", node) + return params + + +def check_not_recursive(defn: ParsedStructDef, globals: Globals) -> None: + """Throws a user error if the given struct definition is recursive.""" + + # TODO: The implementation below hijacks the type parsing logic to detect recursive + # structs. This is not great since it repeats the work done during checking. We can + # get rid of this after resolving the todo in `ParsedStructDef.check_instantiate()` + + @dataclass(frozen=True) + class DummyStructDef(TypeDef): + """Dummy definition that throws an error when trying to instantiate it. + + By replacing the defn with this, we can detect recursive occurrences during + type parsing. + """ + + def check_instantiate( + self, + args: Sequence[Argument], + globals: "Globals", + loc: AstNode | None = None, + ) -> Type: + raise GuppyError("Recursive structs are not supported", loc) + + dummy_defs = { + **globals.defs, + defn.id: DummyStructDef(defn.id, defn.name, defn.defined_at), + } + dummy_globals = globals.update_defs(dummy_defs) + for field in defn.fields: + type_from_ast(field.type_ast, dummy_globals, {}) diff --git a/guppylang/definition/ty.py b/guppylang/definition/ty.py index 36fc4c35..4d1c6026 100644 --- a/guppylang/definition/ty.py +++ b/guppylang/definition/ty.py @@ -5,11 +5,10 @@ from guppylang.ast_util import AstNode from guppylang.definition.common import CompiledDef, Definition -from guppylang.error import GuppyError from guppylang.hugr import tys from guppylang.hugr.tys import Type from guppylang.tys.arg import Argument -from guppylang.tys.param import Parameter +from guppylang.tys.param import Parameter, check_all_args from guppylang.tys.ty import OpaqueType if TYPE_CHECKING: @@ -50,17 +49,5 @@ def check_instantiate( Returns the resulting concrete type or raises a user error if the arguments are invalid. """ - exp, act = len(self.params), len(args) - if exp > act: - raise GuppyError(f"Missing parameter for type `{self.name}`", loc) - elif 0 == exp < act: - raise GuppyError(f"Type `{self.name}` is not parameterized", loc) - elif 0 < exp < act: - raise GuppyError(f"Too many parameters for type `{self.name}`", loc) - - # Now check that the kinds match up - for param, arg in zip(self.params, args, strict=True): - # TODO: The error location is bad. We want the location of `arg`, not of the - # whole thing. - param.check_arg(arg, loc) + check_all_args(self.params, args, self.name, loc) return OpaqueType(args, self) diff --git a/guppylang/module.py b/guppylang/module.py index 4ab1eec6..dacc89d6 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -106,7 +106,7 @@ def register_def(self, defn: RawDef, instance: TypeDef | None = None) -> None: self._instance_func_buffer[defn.name] = defn else: self._check_name_available(defn.name, defn.defined_at) - if isinstance(defn, TypeDef): + if isinstance(defn, TypeDef | TypeVarDef): self._raw_type_defs[defn.id] = defn else: self._raw_defs[defn.id] = defn diff --git a/guppylang/tys/param.py b/guppylang/tys/param.py index e398992c..8b53bfee 100644 --- a/guppylang/tys/param.py +++ b/guppylang/tys/param.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, TypeAlias from typing_extensions import Self from guppylang.ast_util import AstNode -from guppylang.error import GuppyTypeError, InternalGuppyError +from guppylang.error import GuppyError, GuppyTypeError, InternalGuppyError from guppylang.hugr import tys from guppylang.hugr.tys import TypeBound from guppylang.tys.arg import Argument, ConstArg, TypeArg @@ -169,3 +170,29 @@ def to_bound(self, idx: int | None = None) -> Argument: def to_hugr(self) -> tys.TypeParam: """Computes the Hugr representation of the parameter.""" raise NotImplementedError + + +def check_all_args( + params: Sequence[Parameter], + args: Sequence[Argument], + type_name: str, + loc: AstNode | None = None, +) -> None: + """Checks a list of arguments against the given parameters. + + Raises a user error if number of arguments doesn't match or one of the argument is + invalid. + """ + exp, act = len(params), len(args) + if exp > act: + raise GuppyError(f"Missing parameter for type `{type_name}`", loc) + elif 0 == exp < act: + raise GuppyError(f"Type `{type_name}` is not parameterized", loc) + elif 0 < exp < act: + raise GuppyError(f"Too many parameters for type `{type_name}`", loc) + + # Now check that the kinds match up + for param, arg in zip(params, args, strict=True): + # TODO: The error location is bad. We want the location of `arg`, not of the + # whole thing. + param.check_arg(arg, loc) diff --git a/guppylang/tys/parsing.py b/guppylang/tys/parsing.py index 4124d6b9..4154505c 100644 --- a/guppylang/tys/parsing.py +++ b/guppylang/tys/parsing.py @@ -1,9 +1,12 @@ import ast from collections.abc import Sequence -from guppylang.ast_util import AstNode +from guppylang.ast_util import ( + AstNode, + set_location_from, + shift_loc, +) from guppylang.checker.core import Globals -from guppylang.definition.common import DefId from guppylang.definition.parameter import ParamDef from guppylang.definition.ty import TypeDef from guppylang.error import GuppyError @@ -15,7 +18,7 @@ def arg_from_ast( node: AstNode, globals: Globals, - param_var_mapping: dict[DefId, Parameter] | None = None, + param_var_mapping: dict[str, Parameter] | None = None, ) -> Argument: """Turns an AST expression into an argument.""" # A single identifier @@ -33,9 +36,9 @@ def arg_from_ast( raise GuppyError( "Free type variable. Only function types can be generic", node ) - if defn.id not in param_var_mapping: - param_var_mapping[defn.id] = defn.to_param(len(param_var_mapping)) - return param_var_mapping[defn.id].to_bound() + if x not in param_var_mapping: + param_var_mapping[x] = defn.to_param(len(param_var_mapping)) + return param_var_mapping[x].to_bound() case defn: raise GuppyError( f"Expected a type, got {defn.description} `{defn.name}`", node @@ -92,6 +95,12 @@ def arg_from_ast( [stmt] = ast.parse(node.value).body if not isinstance(stmt, ast.Expr): raise GuppyError("Invalid Guppy type", node) + set_location_from(stmt, loc=node) + shift_loc( + stmt, + delta_lineno=node.lineno - 1, # -1 since lines start at 1 + delta_col_offset=node.col_offset + 1, # +1 to remove the `"` + ) return arg_from_ast(stmt.value, globals, param_var_mapping) except (SyntaxError, ValueError): raise GuppyError("Invalid Guppy type", node) from None @@ -105,7 +114,7 @@ def arg_from_ast( def type_from_ast( node: AstNode, globals: Globals, - param_var_mapping: dict[DefId, Parameter] | None = None, + param_var_mapping: dict[str, Parameter] | None = None, ) -> Type: """Turns an AST expression into a Guppy type.""" # Parse an argument and check that it's valid for a `TypeParam` diff --git a/guppylang/tys/printing.py b/guppylang/tys/printing.py index 252efb54..8ac82be6 100644 --- a/guppylang/tys/printing.py +++ b/guppylang/tys/printing.py @@ -7,6 +7,7 @@ FunctionType, NoneType, OpaqueType, + StructType, SumType, TupleType, Type, @@ -81,8 +82,11 @@ def _visit_FunctionType(self, ty: FunctionType, inside_row: bool) -> str: return _wrap(f"forall {quantified}. {inputs} -> {output}", inside_row) return _wrap(f"{inputs} -> {output}", inside_row) - @_visit.register - def _visit_OpaqueType(self, ty: OpaqueType, inside_row: bool) -> str: + @_visit.register(OpaqueType) + @_visit.register(StructType) + def _visit_OpaqueType_StructType( + self, ty: OpaqueType | StructType, inside_row: bool + ) -> str: if ty.args: args = ", ".join(self._visit(arg, True) for arg in ty.args) return f"{ty.defn.name}[{args}]" diff --git a/guppylang/tys/ty.py b/guppylang/tys/ty.py index 4e114663..c4d4bfd9 100644 --- a/guppylang/tys/ty.py +++ b/guppylang/tys/ty.py @@ -14,6 +14,7 @@ from guppylang.tys.var import BoundVar, ExistentialVar if TYPE_CHECKING: + from guppylang.definition.struct import CheckedStructDef, StructField from guppylang.definition.ty import OpaqueTypeDef from guppylang.tys.subst import Inst, Subst @@ -423,8 +424,42 @@ def transform(self, transformer: Transformer) -> "Type": ) +@dataclass(frozen=True) +class StructType(ParametrizedTypeBase): + """A struct type.""" + + defn: "CheckedStructDef" + + @cached_property + def fields(self) -> list["StructField"]: + """The fields of this struct type.""" + from guppylang.definition.struct import StructField + from guppylang.tys.subst import Instantiator + + inst = Instantiator(self.args) + return [StructField(f.name, f.ty.transform(inst)) for f in self.defn.fields] + + @cached_property + def intrinsically_linear(self) -> bool: + """Whether this type is linear, independent of the arguments.""" + return any(f.ty.linear for f in self.defn.fields) + + def to_hugr(self) -> tys.Type: + """Computes the Hugr representation of the type.""" + + return tys.TupleType(inner=[f.ty.to_hugr() for f in self.fields]) + + def transform(self, transformer: Transformer) -> "Type": + """Accepts a transformer on this type.""" + return transformer.transform(self) or StructType( + [arg.transform(transformer) for arg in self.args], self.defn + ) + + #: The type of parametrized Guppy types. -ParametrizedType: TypeAlias = FunctionType | TupleType | SumType | OpaqueType +ParametrizedType: TypeAlias = ( + FunctionType | TupleType | SumType | OpaqueType | StructType +) #: The type of Guppy types. #: @@ -487,6 +522,8 @@ def unify(s: Type, t: Type, subst: "Subst | None") -> "Subst | None": return _unify_args(s, t, subst) case OpaqueType() as s, OpaqueType() as t if s.defn == t.defn: return _unify_args(s, t, subst) + case StructType() as s, StructType() as t if s.defn == t.defn: + return _unify_args(s, t, subst) case _: return None diff --git a/tests/error/struct_errors/__init__.py b/tests/error/struct_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/struct_errors/default.err b/tests/error/struct_errors/default.err new file mode 100644 index 00000000..62d6dcf0 --- /dev/null +++ b/tests/error/struct_errors/default.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:10 + +8: @guppy.struct(module) +9: class MyStruct: +10: x: int = 42 + ^^ +GuppyError: Default struct values are not supported diff --git a/tests/error/struct_errors/default.py b/tests/error/struct_errors/default.py new file mode 100644 index 00000000..f8526bbf --- /dev/null +++ b/tests/error/struct_errors/default.py @@ -0,0 +1,13 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + x: int = 42 + + +module.compile() diff --git a/tests/error/struct_errors/duplicate_field.err b/tests/error/struct_errors/duplicate_field.err new file mode 100644 index 00000000..7058083c --- /dev/null +++ b/tests/error/struct_errors/duplicate_field.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:11 + +9: class MyStruct: +10: x: int +11: x: bool + ^ +GuppyError: Struct `MyStruct` already contains a field named `x` diff --git a/tests/error/struct_errors/duplicate_field.py b/tests/error/struct_errors/duplicate_field.py new file mode 100644 index 00000000..56533d54 --- /dev/null +++ b/tests/error/struct_errors/duplicate_field.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + x: int + x: bool + + +module.compile() diff --git a/tests/error/struct_errors/func_overrides_field1.err b/tests/error/struct_errors/func_overrides_field1.err new file mode 100644 index 00000000..9dd60e4c --- /dev/null +++ b/tests/error/struct_errors/func_overrides_field1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: +12: @guppy(module) +13: def x(self: "MyStruct") -> int: + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +GuppyError: Struct `MyStruct` already contains a field named `x` diff --git a/tests/error/struct_errors/func_overrides_field1.py b/tests/error/struct_errors/func_overrides_field1.py new file mode 100644 index 00000000..35679d5c --- /dev/null +++ b/tests/error/struct_errors/func_overrides_field1.py @@ -0,0 +1,17 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + x: int + + @guppy(module) + def x(self: "MyStruct") -> int: + return 0 + + +module.compile() diff --git a/tests/error/struct_errors/func_overrides_field2.err b/tests/error/struct_errors/func_overrides_field2.err new file mode 100644 index 00000000..bada6583 --- /dev/null +++ b/tests/error/struct_errors/func_overrides_field2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:11 + +9: class MyStruct: +10: @guppy(module) +11: def x(self: "MyStruct") -> int: + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +GuppyError: Struct `MyStruct` already contains a field named `x` diff --git a/tests/error/struct_errors/func_overrides_field2.py b/tests/error/struct_errors/func_overrides_field2.py new file mode 100644 index 00000000..fb63e6fb --- /dev/null +++ b/tests/error/struct_errors/func_overrides_field2.py @@ -0,0 +1,17 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + @guppy(module) + def x(self: "MyStruct") -> int: + return 0 + + x: int + + +module.compile() diff --git a/tests/error/struct_errors/inheritance.err b/tests/error/struct_errors/inheritance.err new file mode 100644 index 00000000..5f76b29b --- /dev/null +++ b/tests/error/struct_errors/inheritance.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: @guppy.struct(module) +8: class MyStruct(int): + ^^^ +GuppyError: Struct inheritance is not supported diff --git a/tests/error/struct_errors/inheritance.py b/tests/error/struct_errors/inheritance.py new file mode 100644 index 00000000..5d02cd08 --- /dev/null +++ b/tests/error/struct_errors/inheritance.py @@ -0,0 +1,13 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct(int): + x: bool + + +module.compile() diff --git a/tests/error/struct_errors/invalid_generic.err b/tests/error/struct_errors/invalid_generic.err new file mode 100644 index 00000000..975a726e --- /dev/null +++ b/tests/error/struct_errors/invalid_generic.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:14 + +12: @guppy.struct(module) +13: class MyStruct(Generic[X]): + ^ +GuppyError: Not a parameter diff --git a/tests/error/struct_errors/invalid_generic.py b/tests/error/struct_errors/invalid_generic.py new file mode 100644 index 00000000..dbe3de12 --- /dev/null +++ b/tests/error/struct_errors/invalid_generic.py @@ -0,0 +1,18 @@ +from typing import Generic, TypeVar + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +X = TypeVar("X") # This is a Python type variable, not a Guppy one! + + +@guppy.struct(module) +class MyStruct(Generic[X]): + x: int + + +module.compile() diff --git a/tests/error/struct_errors/invalid_instantiate1.err b/tests/error/struct_errors/invalid_instantiate1.err new file mode 100644 index 00000000..0d78a3d4 --- /dev/null +++ b/tests/error/struct_errors/invalid_instantiate1.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def foo(s: MyStruct) -> None: + ^^^^^^^^ +GuppyError: Missing parameter for type `MyStruct` diff --git a/tests/error/struct_errors/invalid_instantiate1.py b/tests/error/struct_errors/invalid_instantiate1.py new file mode 100644 index 00000000..7ef70786 --- /dev/null +++ b/tests/error/struct_errors/invalid_instantiate1.py @@ -0,0 +1,21 @@ +from typing import Generic, TypeVar + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") +T = guppy.type_var(module, "T") + + +@guppy.struct(module) +class MyStruct(Generic[T]): + x: list[T] + + +@guppy(module) +def foo(s: MyStruct) -> None: + pass + + +module.compile() diff --git a/tests/error/struct_errors/invalid_instantiate2.err b/tests/error/struct_errors/invalid_instantiate2.err new file mode 100644 index 00000000..61c04f56 --- /dev/null +++ b/tests/error/struct_errors/invalid_instantiate2.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:17 + +15: @guppy(module) +16: def foo(s: MyStruct[int, bool]) -> None: + ^^^^^^^^^^^^^^^^^^^ +GuppyError: Too many parameters for type `MyStruct` diff --git a/tests/error/struct_errors/invalid_instantiate2.py b/tests/error/struct_errors/invalid_instantiate2.py new file mode 100644 index 00000000..326e0e20 --- /dev/null +++ b/tests/error/struct_errors/invalid_instantiate2.py @@ -0,0 +1,21 @@ +from typing import Generic, TypeVar + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") +T = guppy.type_var(module, "T") + + +@guppy.struct(module) +class MyStruct(Generic[T]): + x: list[T] + + +@guppy(module) +def foo(s: MyStruct[int, bool]) -> None: + pass + + +module.compile() diff --git a/tests/error/struct_errors/keywords.err b/tests/error/struct_errors/keywords.err new file mode 100644 index 00000000..2abaeb4e --- /dev/null +++ b/tests/error/struct_errors/keywords.err @@ -0,0 +1,6 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: @guppy.struct(module) +8: class MyStruct(metaclass=type): + ^^^^^^^^^^^^^^ +GuppyError: Unexpected keyword diff --git a/tests/error/struct_errors/keywords.py b/tests/error/struct_errors/keywords.py new file mode 100644 index 00000000..b6800e93 --- /dev/null +++ b/tests/error/struct_errors/keywords.py @@ -0,0 +1,13 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct(metaclass=type): + x: int + + +module.compile() diff --git a/tests/error/struct_errors/mutual_recursive.err b/tests/error/struct_errors/mutual_recursive.err new file mode 100644 index 00000000..dccc76ef --- /dev/null +++ b/tests/error/struct_errors/mutual_recursive.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:15 + +13: @guppy.struct(module) +14: class StructB: +15: y: "StructA" + ^^^^^^^ +GuppyError: Recursive structs are not supported diff --git a/tests/error/struct_errors/mutual_recursive.py b/tests/error/struct_errors/mutual_recursive.py new file mode 100644 index 00000000..3958457a --- /dev/null +++ b/tests/error/struct_errors/mutual_recursive.py @@ -0,0 +1,18 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class StructA: + x: "list[StructB]" + + +@guppy.struct(module) +class StructB: + y: "StructA" + + +module.compile() diff --git a/tests/error/struct_errors/non_guppy_func.err b/tests/error/struct_errors/non_guppy_func.err new file mode 100644 index 00000000..65770dce --- /dev/null +++ b/tests/error/struct_errors/non_guppy_func.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:12 + +10: x: int +11: +12: def f(self: "MyStruct") -> None: + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +GuppyError: Add a `@guppy` decorator to this function to add it to the struct `MyStruct` diff --git a/tests/error/struct_errors/non_guppy_func.py b/tests/error/struct_errors/non_guppy_func.py new file mode 100644 index 00000000..7c44c232 --- /dev/null +++ b/tests/error/struct_errors/non_guppy_func.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + x: int + + def f(self: "MyStruct") -> None: + pass + + +module.compile() diff --git a/tests/error/struct_errors/recursive.err b/tests/error/struct_errors/recursive.err new file mode 100644 index 00000000..1be8dfe0 --- /dev/null +++ b/tests/error/struct_errors/recursive.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:10 + +8: @guppy.struct(module) +9: class MyStruct: +10: x: "tuple[MyStruct, int]" + ^^^^^^^^ +GuppyError: Recursive structs are not supported diff --git a/tests/error/struct_errors/recursive.py b/tests/error/struct_errors/recursive.py new file mode 100644 index 00000000..49d6151d --- /dev/null +++ b/tests/error/struct_errors/recursive.py @@ -0,0 +1,13 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + x: "tuple[MyStruct, int]" + + +module.compile() diff --git a/tests/error/struct_errors/stray_docstring.err b/tests/error/struct_errors/stray_docstring.err new file mode 100644 index 00000000..368aae82 --- /dev/null +++ b/tests/error/struct_errors/stray_docstring.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:11 + +9: class MyStruct: +10: x: int +11: """Docstring in wrong position""" + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +GuppyError: Unexpected statement in struct diff --git a/tests/error/struct_errors/stray_docstring.py b/tests/error/struct_errors/stray_docstring.py new file mode 100644 index 00000000..0c173d74 --- /dev/null +++ b/tests/error/struct_errors/stray_docstring.py @@ -0,0 +1,15 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + x: int + """Docstring in wrong position""" + y: bool + + +module.compile() diff --git a/tests/error/struct_errors/type_missing1.err b/tests/error/struct_errors/type_missing1.err new file mode 100644 index 00000000..afff66af --- /dev/null +++ b/tests/error/struct_errors/type_missing1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:13 + +11: @guppy.struct(module) +12: class MyStruct: +13: x + ^ +GuppyError: Unexpected statement in struct diff --git a/tests/error/struct_errors/type_missing1.py b/tests/error/struct_errors/type_missing1.py new file mode 100644 index 00000000..7513b948 --- /dev/null +++ b/tests/error/struct_errors/type_missing1.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +x = 42 + + +@guppy.struct(module) +class MyStruct: + x + + +module.compile() diff --git a/tests/error/struct_errors/type_missing2.err b/tests/error/struct_errors/type_missing2.err new file mode 100644 index 00000000..f70c330b --- /dev/null +++ b/tests/error/struct_errors/type_missing2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:10 + +8: @guppy.struct(module) +9: class MyStruct: +10: x = 42 + ^^^^^^ +GuppyError: Unexpected statement in struct diff --git a/tests/error/struct_errors/type_missing2.py b/tests/error/struct_errors/type_missing2.py new file mode 100644 index 00000000..f748bab1 --- /dev/null +++ b/tests/error/struct_errors/type_missing2.py @@ -0,0 +1,13 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +module = GuppyModule("test") + + +@guppy.struct(module) +class MyStruct: + x = 42 + + +module.compile() diff --git a/tests/error/test_struct_errors.py b/tests/error/test_struct_errors.py new file mode 100644 index 00000000..3a13c416 --- /dev/null +++ b/tests/error/test_struct_errors.py @@ -0,0 +1,20 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "struct_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_struct_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/error/util.py b/tests/error/util.py index c77141e4..c8d1090b 100644 --- a/tests/error/util.py +++ b/tests/error/util.py @@ -12,11 +12,9 @@ def run_error_test(file, capsys): file = pathlib.Path(file) - spec = importlib.util.spec_from_file_location("test_module", file) - py_module = importlib.util.module_from_spec(spec) with pytest.raises(GuppyError): - spec.loader.exec_module(py_module) + importlib.import_module(f"tests.error.{file.parent.name}.{file.name}") err = capsys.readouterr().err diff --git a/tests/integration/test_struct.py b/tests/integration/test_struct.py new file mode 100644 index 00000000..a37c12bc --- /dev/null +++ b/tests/integration/test_struct.py @@ -0,0 +1,97 @@ +from typing import Generic + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +def test_basic_defs(validate): + module = GuppyModule("module") + + @guppy.struct(module) + class EmptyStruct: + pass + + @guppy.struct(module) + class OneMemberStruct: + x: int + + @guppy.struct(module) + class TwoMemberStruct: + x: tuple[bool, int] + y: float + + @guppy.struct(module) + class DocstringStruct: + """This is struct with a docstring!""" + + x: int + + @guppy(module) + def main(a: EmptyStruct, b: OneMemberStruct, c: TwoMemberStruct, d: DocstringStruct) -> None: + pass + + validate(module.compile()) + + +def test_backward_ref(validate): + module = GuppyModule("module") + + @guppy.struct(module) + class StructA: + x: int + + @guppy.struct(module) + class StructB: + y: StructA + + @guppy(module) + def main(a: StructA, b: StructB) -> None: + pass + + validate(module.compile()) + + +def test_forward_ref(validate): + module = GuppyModule("module") + + @guppy.struct(module) + class StructA: + x: "StructB" + + @guppy.struct(module) + class StructB: + y: int + + @guppy(module) + def main(a: StructA, b: StructB) -> None: + pass + + validate(module.compile()) + + +def test_generic(validate): + module = GuppyModule("module") + S = guppy.type_var(module, "S") + T = guppy.type_var(module, "T") + + @guppy.struct(module) + class StructA(Generic[T]): + x: tuple[int, T] + + @guppy.struct(module) + class StructC: + a: StructA[int] + b: StructA[list[bool]] + c: "StructB[float, StructB[bool, int]]" + + @guppy.struct(module) + class StructB(Generic[S, T]): + x: S + y: StructA[T] + + @guppy(module) + def main(a: StructA[StructA[float]], b: StructB[int, bool], c: StructC) -> None: + pass + + validate(module.compile()) +