From 1c1fa1b1e356e4f480238942440762d9147e6b95 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 6 Jun 2024 17:27:49 +0100 Subject: [PATCH] feat(hugr-py): define type hierarchy separate from serialized Closes #1172 --- hugr-py/src/hugr/_dfg.py | 2 +- hugr-py/src/hugr/_hugr.py | 2 +- hugr-py/src/hugr/_ops.py | 25 +-- hugr-py/src/hugr/_tys.py | 271 ++++++++++++++++++++++++++ hugr-py/src/hugr/serialization/ops.py | 15 +- hugr-py/src/hugr/serialization/tys.py | 161 ++++++++++++--- hugr-py/src/hugr/utils.py | 21 +- hugr-py/tests/test_hugr_build.py | 63 +++--- 8 files changed, 485 insertions(+), 75 deletions(-) create mode 100644 hugr-py/src/hugr/_tys.py diff --git a/hugr-py/src/hugr/_dfg.py b/hugr-py/src/hugr/_dfg.py index f083e8ae0..4c4ab3da7 100644 --- a/hugr-py/src/hugr/_dfg.py +++ b/hugr-py/src/hugr/_dfg.py @@ -5,7 +5,7 @@ from ._ops import Op, Command, Input, Output, DFG from ._exceptions import NoSiblingAncestor -from hugr.serialization.tys import FunctionType, Type +from hugr._tys import FunctionType, Type @dataclass() diff --git a/hugr-py/src/hugr/_hugr.py b/hugr-py/src/hugr/_hugr.py index d42f2edf1..1a6632af6 100644 --- a/hugr-py/src/hugr/_hugr.py +++ b/hugr-py/src/hugr/_hugr.py @@ -19,7 +19,7 @@ from hugr.serialization.serial_hugr import SerialHugr from hugr.serialization.ops import OpType as SerialOp -from hugr.serialization.tys import Type +from hugr._tys import Type from hugr._ops import Op from hugr.utils import BiMap diff --git a/hugr-py/src/hugr/_ops.py b/hugr-py/src/hugr/_ops.py index 693c8ef31..9724162dd 100644 --- a/hugr-py/src/hugr/_ops.py +++ b/hugr-py/src/hugr/_ops.py @@ -4,7 +4,8 @@ from typing import Generic, Protocol, TypeVar, TYPE_CHECKING from hugr.serialization.ops import BaseOp import hugr.serialization.ops as sops -import hugr.serialization.tys as tys +from hugr.utils import ser_it +import hugr._tys as tys if TYPE_CHECKING: from hugr._hugr import Hugr, Node, Wire @@ -43,14 +44,14 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T: @dataclass() class Input(Op): - types: list[tys.Type] + types: tys.TypeRow @property def num_out(self) -> int | None: return len(self.types) def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input: - return sops.Input(parent=parent.idx, types=self.types) + return sops.Input(parent=parent.idx, types=ser_it(self.types)) def __call__(self) -> Command: return super().__call__() @@ -58,10 +59,10 @@ def __call__(self) -> Command: @dataclass() class Output(Op): - types: list[tys.Type] + types: tys.TypeRow def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output: - return sops.Output(parent=parent.idx, types=self.types) + return sops.Output(parent=parent.idx, types=ser_it(self.types)) @dataclass() @@ -81,21 +82,21 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp: parent=parent.idx, extension=self.extension, op_name=self.op_name, - signature=self.signature, + signature=self.signature.to_serial(), description=self.description, - args=self.args, + args=ser_it(self.args), ) @dataclass() class MakeTuple(Op): - types: list[tys.Type] + types: tys.TypeRow num_out: int | None = 1 def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple: return sops.MakeTuple( parent=parent.idx, - tys=self.types, + tys=ser_it(self.types), ) def __call__(self, *elements: Wire) -> Command: @@ -104,7 +105,7 @@ def __call__(self, *elements: Wire) -> Command: @dataclass() class UnpackTuple(Op): - types: list[tys.Type] + types: tys.TypeRow @property def num_out(self) -> int | None: @@ -113,7 +114,7 @@ def num_out(self) -> int | None: def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple: return sops.UnpackTuple( parent=parent.idx, - tys=self.types, + tys=ser_it(self.types), ) def __call__(self, tuple_: Wire) -> Command: @@ -131,5 +132,5 @@ def num_out(self) -> int | None: def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG: return sops.DFG( parent=parent.idx, - signature=self.signature, + signature=self.signature.to_serial(), ) diff --git a/hugr-py/src/hugr/_tys.py b/hugr-py/src/hugr/_tys.py new file mode 100644 index 000000000..7cbd73f2f --- /dev/null +++ b/hugr-py/src/hugr/_tys.py @@ -0,0 +1,271 @@ +from __future__ import annotations +from dataclasses import dataclass, field +import hugr.serialization.tys as stys +from hugr.utils import ser_it +from typing import Any, Protocol, runtime_checkable + +ExtensionId = stys.ExtensionId +ExtensionSet = stys.ExtensionSet +TypeBound = stys.TypeBound + + +class TypeParam(Protocol): + """A type parameter.""" + + def to_serial(self) -> stys.BaseTypeParam: ... + + def to_serial_root(self) -> stys.TypeParam: + return stys.TypeParam(root=self.to_serial()) # type: ignore[arg-type] + + +class TypeArg(Protocol): + """A type argument.""" + + def to_serial(self) -> stys.BaseTypeArg: ... + + def to_serial_root(self) -> stys.TypeArg: + return stys.TypeArg(root=self.to_serial()) # type: ignore[arg-type] + + +@runtime_checkable +class Type(Protocol): + """A type.""" + + def to_serial(self) -> stys.BaseType: ... + + def to_serial_root(self) -> stys.Type: + return stys.Type(root=self.to_serial()) # type: ignore[arg-type] + + +TypeRow = list[Type] + +# -------------------------------------------- +# --------------- TypeParam ------------------ +# -------------------------------------------- + + +@dataclass(frozen=True) +class TypeTypeParam(TypeParam): + bound: TypeBound + + def to_serial(self) -> stys.TypeTypeParam: + return stys.TypeTypeParam(b=self.bound) + + +@dataclass(frozen=True) +class BoundedNatParam(TypeParam): + upper_bound: int | None + + def to_serial(self) -> stys.BoundedNatParam: + return stys.BoundedNatParam(bound=self.upper_bound) + + +@dataclass(frozen=True) +class OpaqueParam(TypeParam): + ty: Opaque + + def to_serial(self) -> stys.OpaqueParam: + return stys.OpaqueParam(ty=self.ty.to_serial()) + + +@dataclass(frozen=True) +class ListParam(TypeParam): + param: TypeParam + + def to_serial(self) -> stys.ListParam: + return stys.ListParam(param=self.param.to_serial_root()) + + +@dataclass(frozen=True) +class TupleParam(TypeParam): + params: list[TypeParam] + + def to_serial(self) -> stys.TupleParam: + return stys.TupleParam(params=ser_it(self.params)) + + +@dataclass(frozen=True) +class ExtensionsParam(TypeParam): + def to_serial(self) -> stys.ExtensionsParam: + return stys.ExtensionsParam() + + +# ------------------------------------------ +# --------------- TypeArg ------------------ +# ------------------------------------------ + + +@dataclass(frozen=True) +class TypeTypeArg(TypeArg): + ty: Type + + def to_serial(self) -> stys.TypeTypeArg: + return stys.TypeTypeArg(ty=self.ty.to_serial_root()) + + +@dataclass(frozen=True) +class BoundedNatArg(TypeArg): + n: int + + def to_serial(self) -> stys.BoundedNatArg: + return stys.BoundedNatArg(n=self.n) + + +@dataclass(frozen=True) +class OpaqueArg(TypeArg): + ty: Opaque + value: Any + + def to_serial(self) -> stys.OpaqueArg: + return stys.OpaqueArg(typ=self.ty.to_serial(), value=self.value) + + +@dataclass(frozen=True) +class SequenceArg(TypeArg): + elems: list[TypeArg] + + def to_serial(self) -> stys.SequenceArg: + return stys.SequenceArg(elems=ser_it(self.elems)) + + +@dataclass(frozen=True) +class ExtensionsArg(TypeArg): + extensions: ExtensionSet + + def to_serial(self) -> stys.ExtensionsArg: + return stys.ExtensionsArg(es=self.extensions) + + +@dataclass(frozen=True) +class VariableArg(TypeArg): + idx: int + param: TypeParam + + def to_serial(self) -> stys.VariableArg: + return stys.VariableArg(idx=self.idx, cached_decl=self.param.to_serial_root()) + + +# ---------------------------------------------- +# --------------- Type ------------------------- +# ---------------------------------------------- + + +@dataclass(frozen=True) +class Array(Type): + ty: Type + size: int + + def to_serial(self) -> stys.Array: + return stys.Array(ty=self.ty.to_serial_root(), len=self.size) + + +@dataclass(frozen=True) +class UnitSum(Type): + size: int + + def to_serial(self) -> stys.UnitSum: + return stys.UnitSum(size=self.size) + + +@dataclass() +class Sum(Type): + variant_rows: list[TypeRow] + + def to_serial(self) -> stys.GeneralSum: + return stys.GeneralSum(rows=[ser_it(row) for row in self.variant_rows]) + + def as_tuple(self) -> Tuple: + assert ( + len(self.variant_rows) == 1 + ), "Sum type must have exactly one row to be converted to a Tuple" + return Tuple(*self.variant_rows[0]) + + +@dataclass() +class Tuple(Sum): + def __init__(self, *tys: Type): + self.variant_rows = [list(tys)] + + +@dataclass(frozen=True) +class Variable(Type): + idx: int + bound: TypeBound + + def to_serial(self) -> stys.Variable: + return stys.Variable(i=self.idx, b=self.bound) + + +@dataclass(frozen=True) +class RowVariable(Type): + idx: int + bound: TypeBound + + def to_serial(self) -> stys.RowVar: + return stys.RowVar(i=self.idx, b=self.bound) + + +@dataclass(frozen=True) +class USize(Type): + def to_serial(self) -> stys.USize: + return stys.USize() + + +@dataclass(frozen=True) +class Alias(Type): + name: str + bound: TypeBound + + def to_serial(self) -> stys.Alias: + return stys.Alias(name=self.name, bound=self.bound) + + +@dataclass(frozen=True) +class FunctionType(Type): + input: list[Type] + output: list[Type] + extension_reqs: ExtensionSet = field(default_factory=ExtensionSet) + + def to_serial(self) -> stys.FunctionType: + return stys.FunctionType(input=ser_it(self.input), output=ser_it(self.output)) + + @classmethod + def empty(cls) -> FunctionType: + return cls(input=[], output=[]) + + +@dataclass(frozen=True) +class PolyFuncType(Type): + params: list[TypeParam] + body: FunctionType + + def to_serial(self) -> stys.PolyFuncType: + return stys.PolyFuncType( + params=[p.to_serial_root() for p in self.params], body=self.body.to_serial() + ) + + +@dataclass +class Opaque(Type): + id: str + bound: TypeBound + args: list[TypeArg] = field(default_factory=list) + extension: ExtensionId = "" + + def to_serial(self) -> stys.Opaque: + return stys.Opaque( + extension=self.extension, + id=self.id, + args=[arg.to_serial_root() for arg in self.args], + bound=self.bound, + ) + + +@dataclass +class QubitDef(Type): + def to_serial(self) -> stys.Qubit: + return stys.Qubit() + + +Qubit = QubitDef() +Bool = UnitSum(size=2) diff --git a/hugr-py/src/hugr/serialization/ops.py b/hugr-py/src/hugr/serialization/ops.py index ed3b74086..e467c8a45 100644 --- a/hugr-py/src/hugr/serialization/ops.py +++ b/hugr-py/src/hugr/serialization/ops.py @@ -20,6 +20,7 @@ classes as tys_classes, model_rebuild as tys_model_rebuild, ) +from hugr.utils import deser_it NodeID = int @@ -215,7 +216,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.types = list(out_types) def deserialize(self) -> _ops.Input: - return _ops.Input(types=self.types) + return _ops.Input(types=[t.deserialize() for t in self.types]) class Output(DataflowOp): @@ -229,7 +230,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.types = list(in_types) def deserialize(self) -> _ops.Output: - return _ops.Output(types=self.types) + return _ops.Output(types=deser_it(self.types)) class Call(DataflowOp): @@ -304,7 +305,7 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: ) def deserialize(self) -> _ops.DFG: - return _ops.DFG(self.signature) + return _ops.DFG(self.signature.deserialize()) # ------------------------------------------------ @@ -406,8 +407,8 @@ def deserialize(self) -> _ops.Custom: return _ops.Custom( extension=self.extension, op_name=self.op_name, - signature=self.signature, - args=self.args, + signature=self.signature.deserialize(), + args=deser_it(self.args), ) model_config = ConfigDict( @@ -447,7 +448,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.tys = list(in_types) def deserialize(self) -> _ops.MakeTuple: - return _ops.MakeTuple(self.tys) + return _ops.MakeTuple(deser_it(self.tys)) class UnpackTuple(DataflowOp): @@ -460,7 +461,7 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: self.tys = list(out_types) def deserialize(self) -> _ops.UnpackTuple: - return _ops.UnpackTuple(self.tys) + return _ops.UnpackTuple(deser_it(self.tys)) class Tag(DataflowOp): diff --git a/hugr-py/src/hugr/serialization/tys.py b/hugr-py/src/hugr/serialization/tys.py index f19e27340..e29b0b76d 100644 --- a/hugr-py/src/hugr/serialization/tys.py +++ b/hugr-py/src/hugr/serialization/tys.py @@ -1,8 +1,12 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod import inspect import sys from enum import Enum from typing import Annotated, Any, Literal, Union, Mapping +from hugr.utils import deser_it from pydantic import ( BaseModel, Field, @@ -57,34 +61,57 @@ def update_model_config(cls, config: ConfigDict): # -------------------------------------------- -class TypeTypeParam(ConfiguredBaseModel): +class BaseTypeParam(ABC, ConfiguredBaseModel): + @abstractmethod + def deserialize(self) -> _tys.TypeParam: ... + + +class TypeTypeParam(BaseTypeParam): tp: Literal["Type"] = "Type" b: "TypeBound" + def deserialize(self) -> _tys.TypeTypeParam: + return _tys.TypeTypeParam(bound=self.b) -class BoundedNatParam(ConfiguredBaseModel): + +class BoundedNatParam(BaseTypeParam): tp: Literal["BoundedNat"] = "BoundedNat" bound: int | None + def deserialize(self) -> _tys.BoundedNatParam: + return _tys.BoundedNatParam(upper_bound=self.bound) + -class OpaqueParam(ConfiguredBaseModel): +class OpaqueParam(BaseTypeParam): tp: Literal["Opaque"] = "Opaque" ty: "Opaque" + def deserialize(self) -> _tys.OpaqueParam: + return _tys.OpaqueParam(ty=self.ty.deserialize()) -class ListParam(ConfiguredBaseModel): + +class ListParam(BaseTypeParam): tp: Literal["List"] = "List" param: "TypeParam" + def deserialize(self) -> _tys.ListParam: + return _tys.ListParam(param=self.param.deserialize()) + -class TupleParam(ConfiguredBaseModel): +class TupleParam(BaseTypeParam): tp: Literal["Tuple"] = "Tuple" params: list["TypeParam"] + def deserialize(self) -> _tys.TupleParam: + return _tys.TupleParam(params=deser_it(self.params)) -class ExtensionsParam(ConfiguredBaseModel): + +class ExtensionsParam(BaseTypeParam): tp: Literal["Extensions"] = "Extensions" + def deserialize(self) -> _tys.ExtensionsParam: + return _tys.ExtensionsParam() + class TypeParam(RootModel): """A type parameter.""" @@ -101,43 +128,69 @@ class TypeParam(RootModel): model_config = ConfigDict(json_schema_extra={"required": ["tp"]}) + def deserialize(self) -> _tys.TypeParam: + return self.root.deserialize() + # ------------------------------------------ # --------------- TypeArg ------------------ # ------------------------------------------ -class TypeTypeArg(ConfiguredBaseModel): +class BaseTypeArg(ABC, ConfiguredBaseModel): + @abstractmethod + def deserialize(self) -> _tys.TypeArg: ... + + +class TypeTypeArg(BaseTypeArg): tya: Literal["Type"] = "Type" ty: "Type" + def deserialize(self) -> _tys.TypeTypeArg: + return _tys.TypeTypeArg(ty=self.ty.deserialize()) -class BoundedNatArg(ConfiguredBaseModel): + +class BoundedNatArg(BaseTypeArg): tya: Literal["BoundedNat"] = "BoundedNat" n: int + def deserialize(self) -> _tys.BoundedNatArg: + return _tys.BoundedNatArg(n=self.n) + -class OpaqueArg(ConfiguredBaseModel): +class OpaqueArg(BaseTypeArg): tya: Literal["Opaque"] = "Opaque" typ: "Opaque" value: Any + def deserialize(self) -> _tys.OpaqueArg: + return _tys.OpaqueArg(ty=self.typ.deserialize(), value=self.value) -class SequenceArg(ConfiguredBaseModel): + +class SequenceArg(BaseTypeArg): tya: Literal["Sequence"] = "Sequence" elems: list["TypeArg"] + def deserialize(self) -> _tys.SequenceArg: + return _tys.SequenceArg(elems=deser_it(self.elems)) + -class ExtensionsArg(ConfiguredBaseModel): +class ExtensionsArg(BaseTypeArg): tya: Literal["Extensions"] = "Extensions" es: ExtensionSet + def deserialize(self) -> _tys.ExtensionsArg: + return _tys.ExtensionsArg(extensions=self.es) -class VariableArg(BaseModel): + +class VariableArg(BaseTypeArg): tya: Literal["Variable"] = "Variable" idx: int cached_decl: TypeParam + def deserialize(self) -> _tys.VariableArg: + return _tys.VariableArg(idx=self.idx, param=self.cached_decl.deserialize()) + class TypeArg(RootModel): """A type argument.""" @@ -154,13 +207,21 @@ class TypeArg(RootModel): model_config = ConfigDict(json_schema_extra={"required": ["tya"]}) + def deserialize(self) -> _tys.TypeArg: + return self.root.deserialize() + # -------------------------------------------- # --------------- Container ------------------ # -------------------------------------------- -class MultiContainer(ConfiguredBaseModel): +class BaseType(ABC, ConfiguredBaseModel): + @abstractmethod + def deserialize(self) -> _tys.Type: ... + + +class MultiContainer(BaseType): ty: "Type" @@ -170,22 +231,31 @@ class Array(MultiContainer): t: Literal["Array"] = "Array" len: int + def deserialize(self) -> _tys.Array: + return _tys.Array(ty=self.ty.deserialize(), size=self.len) -class UnitSum(ConfiguredBaseModel): + +class UnitSum(BaseType): """Simple sum type where all variants are empty tuples.""" t: Literal["Sum"] = "Sum" s: Literal["Unit"] = "Unit" size: int + def deserialize(self) -> _tys.UnitSum: + return _tys.UnitSum(size=self.size) + -class GeneralSum(ConfiguredBaseModel): +class GeneralSum(BaseType): """General sum type that explicitly stores the types of the variants.""" t: Literal["Sum"] = "Sum" s: Literal["General"] = "General" rows: list["TypeRow"] + def deserialize(self) -> _tys.Sum: + return _tys.Sum(variant_rows=[[t.deserialize() for t in r] for r in self.rows]) + class SumType(RootModel): root: Annotated[Union[UnitSum, GeneralSum], Field(discriminator="s")] @@ -197,21 +267,27 @@ def t(self) -> str: model_config = ConfigDict(json_schema_extra={"required": ["s"]}) + def deserialize(self) -> _tys.Sum | _tys.UnitSum: + return self.root.deserialize() + # ---------------------------------------------- # --------------- ClassicType ------------------ # ---------------------------------------------- -class Variable(ConfiguredBaseModel): +class Variable(BaseType): """A type variable identified by an index into the array of TypeParams.""" t: Literal["V"] = "V" i: int b: "TypeBound" + def deserialize(self) -> _tys.Variable: + return _tys.Variable(idx=self.i, bound=self.b) -class RowVar(ConfiguredBaseModel): + +class RowVar(BaseType): """A variable standing for a row of some (unknown) number of types. May occur only within a row; not a node input/output.""" @@ -219,14 +295,20 @@ class RowVar(ConfiguredBaseModel): i: int b: "TypeBound" + def deserialize(self) -> _tys.RowVariable: + return _tys.RowVariable(idx=self.i, bound=self.b) + -class USize(ConfiguredBaseModel): +class USize(BaseType): """Unsigned integer size type.""" t: Literal["I"] = "I" + def deserialize(self) -> _tys.USize: + return _tys.USize() -class FunctionType(ConfiguredBaseModel): + +class FunctionType(BaseType): """A graph encoded as a value. It contains a concrete signature and a set of required resources.""" @@ -241,6 +323,13 @@ class FunctionType(ConfiguredBaseModel): def empty(cls) -> "FunctionType": return FunctionType(input=[], output=[], extension_reqs=[]) + def deserialize(self) -> _tys.FunctionType: + return _tys.FunctionType( + input=deser_it(self.input), + output=deser_it(self.output), + extension_reqs=self.extension_reqs, + ) + model_config = ConfigDict( # Needed to avoid random '\n's in the pydantic description json_schema_extra={ @@ -252,7 +341,7 @@ def empty(cls) -> "FunctionType": ) -class PolyFuncType(ConfiguredBaseModel): +class PolyFuncType(BaseType): """A polymorphic type scheme, i.e. of a FuncDecl, FuncDefn or OpDef. (Nodes/operations in the Hugr are not polymorphic.)""" @@ -268,6 +357,12 @@ class PolyFuncType(ConfiguredBaseModel): def empty(cls) -> "PolyFuncType": return PolyFuncType(params=[], body=FunctionType.empty()) + def deserialize(self) -> _tys.PolyFuncType: + return _tys.PolyFuncType( + params=deser_it(self.params), + body=self.body.deserialize(), + ) + model_config = ConfigDict( # Needed to avoid random '\n's in the pydantic description json_schema_extra={ @@ -296,7 +391,7 @@ def join(*bs: "TypeBound") -> "TypeBound": return res -class Opaque(ConfiguredBaseModel): +class Opaque(BaseType): """An opaque Type that can be downcasted by the extensions that define it.""" t: Literal["Opaque"] = "Opaque" @@ -305,25 +400,39 @@ class Opaque(ConfiguredBaseModel): args: list[TypeArg] bound: TypeBound + def deserialize(self) -> _tys.Opaque: + return _tys.Opaque( + extension=self.extension, + id=self.id, + args=deser_it(self.args), + bound=self.bound, + ) -class Alias(ConfiguredBaseModel): + +class Alias(BaseType): """An Alias Type""" t: Literal["Alias"] = "Alias" bound: TypeBound name: str + def deserialize(self) -> _tys.Alias: + return _tys.Alias(name=self.name, bound=self.bound) + # ---------------------------------------------- # --------------- LinearType ------------------- # ---------------------------------------------- -class Qubit(ConfiguredBaseModel): +class Qubit(BaseType): """A qubit.""" t: Literal["Q"] = "Q" + def deserialize(self) -> _tys.QubitDef: + return _tys.Qubit + class Type(RootModel): """A HUGR type.""" @@ -344,6 +453,9 @@ class Type(RootModel): model_config = ConfigDict(json_schema_extra={"required": ["t"]}) + def deserialize(self) -> _tys.Type: + return self.root.deserialize() + # ------------------------------------------- # --------------- TypeRow ------------------- @@ -391,3 +503,6 @@ def model_rebuild( model_rebuild(dict(classes)) + + +from hugr import _tys # noqa: E402 # needed to avoid circular imports diff --git a/hugr-py/src/hugr/utils.py b/hugr-py/src/hugr/utils.py index c4bd21124..dd4bdae1a 100644 --- a/hugr-py/src/hugr/utils.py +++ b/hugr-py/src/hugr/utils.py @@ -1,6 +1,6 @@ from collections.abc import Hashable, ItemsView, MutableMapping from dataclasses import dataclass, field -from typing import Generic, TypeVar +from typing import Generic, Iterable, Protocol, TypeVar L = TypeVar("L", bound=Hashable) @@ -54,3 +54,22 @@ def delete_left(self, key: L) -> None: def delete_right(self, key: R) -> None: del self.fwd[self.bck[key]] del self.bck[key] + + +S = TypeVar("S", covariant=True) + + +class SerCollection(Protocol[S]): + def to_serial_root(self) -> S: ... + + +class DeserCollection(Protocol[S]): + def deserialize(self) -> S: ... + + +def ser_it(it: Iterable[SerCollection[S]]) -> list[S]: + return [v.to_serial_root() for v in it] + + +def deser_it(it: Iterable[DeserCollection[S]]) -> list[S]: + return [v.deserialize() for v in it] diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 52f7a2b07..71bc424ce 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -8,26 +8,26 @@ from hugr._ops import Custom, Command import hugr._ops as ops from hugr.serialization import SerialHugr -import hugr.serialization.tys as stys +import hugr._tys as tys import pytest import json -BOOL_T = stys.Type(stys.SumType(stys.UnitSum(size=2))) -QB_T = stys.Type(stys.Qubit()) -ARG_5 = stys.TypeArg(stys.BoundedNatArg(n=5)) -INT_T = stys.Type( - stys.Opaque( + +def int_t(width: int) -> tys.Opaque: + return tys.Opaque( extension="arithmetic.int.types", id="int", - args=[ARG_5], - bound=stys.TypeBound.Eq, + args=[tys.BoundedNatArg(n=width)], + bound=tys.TypeBound.Eq, ) -) + + +INT_T = int_t(5) @dataclass class LogicOps(Custom): - extension: stys.ExtensionId = "logic" + extension: tys.ExtensionId = "logic" # TODO get from YAML @@ -35,8 +35,8 @@ class LogicOps(Custom): class NotDef(LogicOps): num_out: int | None = 1 op_name: str = "Not" - signature: stys.FunctionType = field( - default_factory=lambda: stys.FunctionType(input=[BOOL_T], output=[BOOL_T]) + signature: tys.FunctionType = field( + default_factory=lambda: tys.FunctionType(input=[tys.Bool], output=[tys.Bool]) ) def __call__(self, a: Wire) -> Command: @@ -48,18 +48,21 @@ def __call__(self, a: Wire) -> Command: @dataclass class IntOps(Custom): - extension: stys.ExtensionId = "arithmetic.int" + extension: tys.ExtensionId = "arithmetic.int" + + +ARG_5 = tys.BoundedNatArg(n=5) @dataclass class DivModDef(IntOps): num_out: int | None = 2 - extension: stys.ExtensionId = "arithmetic.int" + extension: tys.ExtensionId = "arithmetic.int" op_name: str = "idivmod_u" - signature: stys.FunctionType = field( - default_factory=lambda: stys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) + signature: tys.FunctionType = field( + default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) ) - args: list[stys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5]) + args: list[tys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5]) DivMod = DivModDef() @@ -115,7 +118,7 @@ def test_stable_indices(): def test_simple_id(): - h = Dfg.endo([QB_T] * 2) + h = Dfg.endo([tys.Qubit] * 2) a, b = h.inputs() h.set_outputs(a, b) @@ -123,7 +126,7 @@ def test_simple_id(): def test_multiport(): - h = Dfg([BOOL_T], [BOOL_T] * 2) + h = Dfg([tys.Bool], [tys.Bool] * 2) (a,) = h.inputs() h.set_outputs(a, a) in_n, ou_n = h.input_node, h.output_node @@ -146,7 +149,7 @@ def test_multiport(): def test_add_op(): - h = Dfg.endo([BOOL_T]) + h = Dfg.endo([tys.Bool]) (a,) = h.inputs() nt = h.add_op(Not, a) h.set_outputs(nt) @@ -155,7 +158,7 @@ def test_add_op(): def test_tuple(): - row = [BOOL_T, QB_T] + row = [tys.Bool, tys.Qubit] h = Dfg.endo(row) a, b = h.inputs() t = h.add(ops.MakeTuple(row)(a, b)) @@ -182,7 +185,7 @@ def test_multi_out(): def test_insert(): - h1 = Dfg.endo([BOOL_T]) + h1 = Dfg.endo([tys.Bool]) (a1,) = h1.inputs() nt = h1.add(Not(a1)) h1.set_outputs(nt) @@ -195,12 +198,12 @@ def test_insert(): def test_insert_nested(): - h1 = Dfg.endo([BOOL_T]) + h1 = Dfg.endo([tys.Bool]) (a1,) = h1.inputs() nt = h1.add(Not(a1)) h1.set_outputs(nt) - h = Dfg.endo([BOOL_T]) + h = Dfg.endo([tys.Bool]) (a,) = h.inputs() nested = h.insert_nested(h1, a) h.set_outputs(nested) @@ -214,9 +217,9 @@ def _nested_nop(dfg: Dfg): nt = dfg.add(Not(a1)) dfg.set_outputs(nt) - h = Dfg.endo([BOOL_T]) + h = Dfg.endo([tys.Bool]) (a,) = h.inputs() - nested = h.add_nested([BOOL_T], [BOOL_T], a) + nested = h.add_nested([tys.Bool], [tys.Bool], a) _nested_nop(nested) @@ -226,9 +229,9 @@ def _nested_nop(dfg: Dfg): def test_build_inter_graph(): - h = Dfg.endo([BOOL_T, BOOL_T]) + h = Dfg.endo([tys.Bool, tys.Bool]) (a, b) = h.inputs() - nested = h.add_nested([], [BOOL_T]) + nested = h.add_nested([], [tys.Bool]) nt = nested.add(Not(a)) nested.set_outputs(nt) @@ -245,9 +248,9 @@ def test_build_inter_graph(): def test_ancestral_sibling(): - h = Dfg.endo([BOOL_T]) + h = Dfg.endo([tys.Bool]) (a,) = h.inputs() - nested = h.add_nested([], [BOOL_T]) + nested = h.add_nested([], [tys.Bool]) nt = nested.add(Not(a))