Skip to content

Commit

Permalink
feat: Add Option type to standard library (#696)
Browse files Browse the repository at this point in the history
Closes #667.

Also generalises the `@guppy.type` decorator to handle generic types.
  • Loading branch information
mark-koch authored Dec 10, 2024
1 parent 78e366b commit 45ea6b7
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 5 deletions.
17 changes: 13 additions & 4 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ast
import inspect
from collections.abc import Callable, KeysView
from collections.abc import Callable, KeysView, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -48,6 +48,8 @@
sphinx_running,
)
from guppylang.span import SourceMap
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter
from guppylang.tys.subst import Inst
from guppylang.tys.ty import NumericType

Expand Down Expand Up @@ -197,29 +199,36 @@ def dec(c: type) -> type:
@pretty_errors
def type(
self,
hugr_ty: ht.Type,
hugr_ty: ht.Type | Callable[[Sequence[Argument]], ht.Type],
name: str = "",
linear: bool = False,
bound: ht.TypeBound | None = None,
params: Sequence[Parameter] | None = None,
module: GuppyModule | None = None,
) -> OpaqueTypeDecorator:
"""Decorator to annotate a class definitions as Guppy types.
Requires the static Hugr translation of the type. Additionally, the type can be
marked as linear. All `@guppy` annotated functions on the class are turned into
instance functions.
For non-generic types, the Hugr representation can be passed as a static value.
For generic types, a callable may be passed that takes the type arguments of a
concrete instantiation.
"""
mod = module or self.get_module()
mod._instance_func_buffer = {}

mk_hugr_ty = (lambda _: hugr_ty) if isinstance(hugr_ty, ht.Type) else hugr_ty

def dec(c: type) -> OpaqueTypeDef:
defn = OpaqueTypeDef(
DefId.fresh(mod),
name or c.__name__,
None,
[],
params or [],
linear,
lambda _: hugr_ty,
mk_hugr_ty,
bound,
)
mod.register_def(defn)
Expand Down
8 changes: 7 additions & 1 deletion guppylang/definition/ty.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from hugr import tys

Expand Down Expand Up @@ -41,6 +41,12 @@ class OpaqueTypeDef(TypeDef, CompiledDef):
to_hugr: Callable[[Sequence[Argument]], tys.Type]
bound: tys.TypeBound | None = None

def __getitem__(self, item: Any) -> "OpaqueTypeDef":
"""Dummy implementation to allow generic instantiations in type signatures that
are evaluated by the Python interpreter.
"""
return self

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
Expand Down
60 changes: 60 additions & 0 deletions guppylang/std/_internal/compiler/option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from abc import ABC

from hugr import Wire, ops
from hugr import tys as ht
from hugr import val as hv

from guppylang.definition.custom import CustomCallCompiler, CustomInoutCallCompiler
from guppylang.definition.value import CallReturnWires
from guppylang.error import InternalGuppyError
from guppylang.std._internal.compiler.prelude import build_unwrap
from guppylang.tys.arg import TypeArg


class OptionCompiler(CustomInoutCallCompiler, ABC):
"""Abstract base class for compilers for `Option` methods."""

@property
def option_ty(self) -> ht.Option:
match self.type_args:
case [TypeArg(ty)]:
return ht.Option(ty.to_hugr())
case _:
raise InternalGuppyError("Invalid type args for Option op")


class OptionConstructor(OptionCompiler, CustomCallCompiler):
"""Compiler for the `Option` constructors `nothing` and `some`."""

def __init__(self, tag: int):
self.tag = tag

def compile(self, args: list[Wire]) -> list[Wire]:
return [self.builder.add_op(ops.Tag(self.tag, self.option_ty), *args)]


class OptionTestCompiler(OptionCompiler):
"""Compiler for the `Option.is_nothing` and `Option.is_some` methods."""

def __init__(self, tag: int):
self.tag = tag

def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
[opt] = args
cond = self.builder.add_conditional(opt)
for i in [0, 1]:
with cond.add_case(i) as case:
val = hv.TRUE if i == self.tag else hv.FALSE
opt = case.add_op(ops.Tag(i, self.option_ty), *case.inputs())
case.set_outputs(case.load(val), opt)
[res, opt] = cond.outputs()
return CallReturnWires(regular_returns=[res], inout_returns=[opt])


class OptionUnwrapCompiler(OptionCompiler, CustomCallCompiler):
"""Compiler for the `Option.unwrap` method."""

def compile(self, args: list[Wire]) -> list[Wire]:
[opt] = args
err = "Option.unwrap: value is `Nothing`"
return list(build_unwrap(self.builder, opt, err).outputs())
61 changes: 61 additions & 0 deletions guppylang/std/option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from collections.abc import Sequence
from typing import Generic, no_type_check

import hugr.tys as ht

from guppylang.decorator import guppy
from guppylang.error import InternalGuppyError
from guppylang.std._internal.compiler.option import (
OptionConstructor,
OptionTestCompiler,
OptionUnwrapCompiler,
)
from guppylang.std.builtins import owned
from guppylang.tys.arg import Argument, TypeArg
from guppylang.tys.param import TypeParam


def _option_to_hugr(args: Sequence[Argument]) -> ht.Type:
match args:
case [TypeArg(ty)]:
return ht.Option(ty.to_hugr())
case _:
raise InternalGuppyError("Invalid type args for Option")


T = guppy.type_var("T", linear=True)


@guppy.type(_option_to_hugr, params=[TypeParam(0, "T", can_be_linear=True)])
class Option(Generic[T]): # type: ignore[misc]
"""Represents an optional value."""

@guppy.custom(OptionTestCompiler(0))
@no_type_check
def is_nothing(self: "Option[T]") -> bool:
"""Returns `True` if the option is a `nothing` value."""

@guppy.custom(OptionTestCompiler(1))
@no_type_check
def is_some(self: "Option[T]") -> bool:
"""Returns `True` if the option is a `some` value."""

@guppy.custom(OptionUnwrapCompiler())
@no_type_check
def unwrap(self: "Option[T]" @ owned) -> T:
"""Returns the contained `some` value, consuming `self`.
Panics if the option is a `nothing` value.
"""


@guppy.custom(OptionConstructor(0))
@no_type_check
def nothing() -> Option[T]:
"""Constructs a `nothing` optional value."""


@guppy.custom(OptionConstructor(1))
@no_type_check
def some(value: T @ owned) -> Option[T]:
"""Constructs a `some` optional value."""
36 changes: 36 additions & 0 deletions tests/integration/test_option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.option import Option, nothing, some


def test_none(validate, run_int_fn):
module = GuppyModule("test_range")
module.load(Option, nothing)

@guppy(module)
def main() -> int:
x: Option[int] = nothing()
is_none = 10 if x.is_nothing() else 0
is_some = 1 if x.is_some() else 0
return is_none + is_some

compiled = module.compile()
validate(compiled)
run_int_fn(compiled, expected=10)


def test_some_unwrap(validate, run_int_fn):
module = GuppyModule("test_range")
module.load(Option, some)

@guppy(module)
def main() -> int:
x: Option[int] = some(42)
is_none = 1 if x.is_nothing() else 0
is_some = x.unwrap() if x.is_some() else 0
return is_none + is_some

compiled = module.compile()
validate(compiled)
run_int_fn(compiled, expected=42)

0 comments on commit 45ea6b7

Please sign in to comment.