From 45ea6b7086f75f017eb4830f55dca9c87d9f599b Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Tue, 10 Dec 2024 09:33:38 +0000 Subject: [PATCH] feat: Add Option type to standard library (#696) Closes #667. Also generalises the `@guppy.type` decorator to handle generic types. --- guppylang/decorator.py | 17 ++++-- guppylang/definition/ty.py | 8 ++- guppylang/std/_internal/compiler/option.py | 60 +++++++++++++++++++++ guppylang/std/option.py | 61 ++++++++++++++++++++++ tests/integration/test_option.py | 36 +++++++++++++ 5 files changed, 177 insertions(+), 5 deletions(-) create mode 100644 guppylang/std/_internal/compiler/option.py create mode 100644 guppylang/std/option.py create mode 100644 tests/integration/test_option.py diff --git a/guppylang/decorator.py b/guppylang/decorator.py index b0ea43fe..5b7d6361 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -1,6 +1,6 @@ import ast import inspect -from collections.abc import Callable, KeysView +from collections.abc import Callable, KeysView, Sequence from dataclasses import dataclass, field from pathlib import Path from types import ModuleType @@ -48,6 +48,8 @@ sphinx_running, ) from guppylang.span import SourceMap +from guppylang.tys.arg import Argument +from guppylang.tys.param import Parameter from guppylang.tys.subst import Inst from guppylang.tys.ty import NumericType @@ -197,10 +199,11 @@ def dec(c: type) -> type: @pretty_errors def type( self, - hugr_ty: ht.Type, + hugr_ty: ht.Type | Callable[[Sequence[Argument]], ht.Type], name: str = "", linear: bool = False, bound: ht.TypeBound | None = None, + params: Sequence[Parameter] | None = None, module: GuppyModule | None = None, ) -> OpaqueTypeDecorator: """Decorator to annotate a class definitions as Guppy types. @@ -208,18 +211,24 @@ def type( Requires the static Hugr translation of the type. Additionally, the type can be marked as linear. All `@guppy` annotated functions on the class are turned into instance functions. + + For non-generic types, the Hugr representation can be passed as a static value. + For generic types, a callable may be passed that takes the type arguments of a + concrete instantiation. """ mod = module or self.get_module() mod._instance_func_buffer = {} + mk_hugr_ty = (lambda _: hugr_ty) if isinstance(hugr_ty, ht.Type) else hugr_ty + def dec(c: type) -> OpaqueTypeDef: defn = OpaqueTypeDef( DefId.fresh(mod), name or c.__name__, None, - [], + params or [], linear, - lambda _: hugr_ty, + mk_hugr_ty, bound, ) mod.register_def(defn) diff --git a/guppylang/definition/ty.py b/guppylang/definition/ty.py index e50fe4fd..123b99c6 100644 --- a/guppylang/definition/ty.py +++ b/guppylang/definition/ty.py @@ -1,7 +1,7 @@ from abc import abstractmethod from collections.abc import Callable, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from hugr import tys @@ -41,6 +41,12 @@ class OpaqueTypeDef(TypeDef, CompiledDef): to_hugr: Callable[[Sequence[Argument]], tys.Type] bound: tys.TypeBound | None = None + def __getitem__(self, item: Any) -> "OpaqueTypeDef": + """Dummy implementation to allow generic instantiations in type signatures that + are evaluated by the Python interpreter. + """ + return self + def check_instantiate( self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None ) -> OpaqueType: diff --git a/guppylang/std/_internal/compiler/option.py b/guppylang/std/_internal/compiler/option.py new file mode 100644 index 00000000..3ec5c118 --- /dev/null +++ b/guppylang/std/_internal/compiler/option.py @@ -0,0 +1,60 @@ +from abc import ABC + +from hugr import Wire, ops +from hugr import tys as ht +from hugr import val as hv + +from guppylang.definition.custom import CustomCallCompiler, CustomInoutCallCompiler +from guppylang.definition.value import CallReturnWires +from guppylang.error import InternalGuppyError +from guppylang.std._internal.compiler.prelude import build_unwrap +from guppylang.tys.arg import TypeArg + + +class OptionCompiler(CustomInoutCallCompiler, ABC): + """Abstract base class for compilers for `Option` methods.""" + + @property + def option_ty(self) -> ht.Option: + match self.type_args: + case [TypeArg(ty)]: + return ht.Option(ty.to_hugr()) + case _: + raise InternalGuppyError("Invalid type args for Option op") + + +class OptionConstructor(OptionCompiler, CustomCallCompiler): + """Compiler for the `Option` constructors `nothing` and `some`.""" + + def __init__(self, tag: int): + self.tag = tag + + def compile(self, args: list[Wire]) -> list[Wire]: + return [self.builder.add_op(ops.Tag(self.tag, self.option_ty), *args)] + + +class OptionTestCompiler(OptionCompiler): + """Compiler for the `Option.is_nothing` and `Option.is_some` methods.""" + + def __init__(self, tag: int): + self.tag = tag + + def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires: + [opt] = args + cond = self.builder.add_conditional(opt) + for i in [0, 1]: + with cond.add_case(i) as case: + val = hv.TRUE if i == self.tag else hv.FALSE + opt = case.add_op(ops.Tag(i, self.option_ty), *case.inputs()) + case.set_outputs(case.load(val), opt) + [res, opt] = cond.outputs() + return CallReturnWires(regular_returns=[res], inout_returns=[opt]) + + +class OptionUnwrapCompiler(OptionCompiler, CustomCallCompiler): + """Compiler for the `Option.unwrap` method.""" + + def compile(self, args: list[Wire]) -> list[Wire]: + [opt] = args + err = "Option.unwrap: value is `Nothing`" + return list(build_unwrap(self.builder, opt, err).outputs()) diff --git a/guppylang/std/option.py b/guppylang/std/option.py new file mode 100644 index 00000000..9359e168 --- /dev/null +++ b/guppylang/std/option.py @@ -0,0 +1,61 @@ +from collections.abc import Sequence +from typing import Generic, no_type_check + +import hugr.tys as ht + +from guppylang.decorator import guppy +from guppylang.error import InternalGuppyError +from guppylang.std._internal.compiler.option import ( + OptionConstructor, + OptionTestCompiler, + OptionUnwrapCompiler, +) +from guppylang.std.builtins import owned +from guppylang.tys.arg import Argument, TypeArg +from guppylang.tys.param import TypeParam + + +def _option_to_hugr(args: Sequence[Argument]) -> ht.Type: + match args: + case [TypeArg(ty)]: + return ht.Option(ty.to_hugr()) + case _: + raise InternalGuppyError("Invalid type args for Option") + + +T = guppy.type_var("T", linear=True) + + +@guppy.type(_option_to_hugr, params=[TypeParam(0, "T", can_be_linear=True)]) +class Option(Generic[T]): # type: ignore[misc] + """Represents an optional value.""" + + @guppy.custom(OptionTestCompiler(0)) + @no_type_check + def is_nothing(self: "Option[T]") -> bool: + """Returns `True` if the option is a `nothing` value.""" + + @guppy.custom(OptionTestCompiler(1)) + @no_type_check + def is_some(self: "Option[T]") -> bool: + """Returns `True` if the option is a `some` value.""" + + @guppy.custom(OptionUnwrapCompiler()) + @no_type_check + def unwrap(self: "Option[T]" @ owned) -> T: + """Returns the contained `some` value, consuming `self`. + + Panics if the option is a `nothing` value. + """ + + +@guppy.custom(OptionConstructor(0)) +@no_type_check +def nothing() -> Option[T]: + """Constructs a `nothing` optional value.""" + + +@guppy.custom(OptionConstructor(1)) +@no_type_check +def some(value: T @ owned) -> Option[T]: + """Constructs a `some` optional value.""" diff --git a/tests/integration/test_option.py b/tests/integration/test_option.py new file mode 100644 index 00000000..db491643 --- /dev/null +++ b/tests/integration/test_option.py @@ -0,0 +1,36 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.option import Option, nothing, some + + +def test_none(validate, run_int_fn): + module = GuppyModule("test_range") + module.load(Option, nothing) + + @guppy(module) + def main() -> int: + x: Option[int] = nothing() + is_none = 10 if x.is_nothing() else 0 + is_some = 1 if x.is_some() else 0 + return is_none + is_some + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, expected=10) + + +def test_some_unwrap(validate, run_int_fn): + module = GuppyModule("test_range") + module.load(Option, some) + + @guppy(module) + def main() -> int: + x: Option[int] = some(42) + is_none = 1 if x.is_nothing() else 0 + is_some = x.unwrap() if x.is_some() else 0 + return is_none + is_some + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, expected=42) +