Skip to content

Commit

Permalink
feat(hugr-py): define type hierarchy separate from serialized (#1176)
Browse files Browse the repository at this point in the history
Closes #1172

Changes in first commit.
Note second commit is just running the schema update, which appears to
be trivial and non-breaking.

Coverage not great because this hierarchy is comprehensive (unlike the
Ops) but not all are tested. Deemed not worth adding a load of boiler
plate tests for now.
  • Loading branch information
ss2165 authored Jun 14, 2024
1 parent cbe31be commit 10f4c42
Show file tree
Hide file tree
Showing 13 changed files with 648 additions and 253 deletions.
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ._ops import Op, Command, Input, Output, DFG
from ._exceptions import NoSiblingAncestor
from hugr.serialization.tys import FunctionType, Type
from hugr._tys import FunctionType, Type


@dataclass()
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from typing_extensions import Self

from hugr._ops import Op
from hugr._tys import Type
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.tys import Type
from hugr.utils import BiMap

from ._exceptions import ParentBeforeChild
Expand Down
25 changes: 13 additions & 12 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
import hugr.serialization.tys as tys
from hugr.utils import ser_it
import hugr._tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node, Wire
Expand Down Expand Up @@ -43,25 +44,25 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T:

@dataclass()
class Input(Op):
types: list[tys.Type]
types: tys.TypeRow

@property
def num_out(self) -> int | None:
return len(self.types)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input:
return sops.Input(parent=parent.idx, types=self.types)
return sops.Input(parent=parent.idx, types=ser_it(self.types))

def __call__(self) -> Command:
return super().__call__()


@dataclass()
class Output(Op):
types: list[tys.Type]
types: tys.TypeRow

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:
return sops.Output(parent=parent.idx, types=self.types)
return sops.Output(parent=parent.idx, types=ser_it(self.types))


@dataclass()
Expand All @@ -81,21 +82,21 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp:
parent=parent.idx,
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
signature=self.signature.to_serial(),
description=self.description,
args=self.args,
args=ser_it(self.args),
)


@dataclass()
class MakeTuple(Op):
types: list[tys.Type]
types: tys.TypeRow
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
return sops.MakeTuple(
parent=parent.idx,
tys=self.types,
tys=ser_it(self.types),
)

def __call__(self, *elements: Wire) -> Command:
Expand All @@ -104,7 +105,7 @@ def __call__(self, *elements: Wire) -> Command:

@dataclass()
class UnpackTuple(Op):
types: list[tys.Type]
types: tys.TypeRow

@property
def num_out(self) -> int | None:
Expand All @@ -113,7 +114,7 @@ def num_out(self) -> int | None:
def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple:
return sops.UnpackTuple(
parent=parent.idx,
tys=self.types,
tys=ser_it(self.types),
)

def __call__(self, tuple_: Wire) -> Command:
Expand All @@ -131,5 +132,5 @@ def num_out(self) -> int | None:
def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG:
return sops.DFG(
parent=parent.idx,
signature=self.signature,
signature=self.signature.to_serial(),
)
271 changes: 271 additions & 0 deletions hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
from __future__ import annotations
from dataclasses import dataclass, field
import hugr.serialization.tys as stys
from hugr.utils import ser_it
from typing import Any, Protocol, runtime_checkable

ExtensionId = stys.ExtensionId
ExtensionSet = stys.ExtensionSet
TypeBound = stys.TypeBound


class TypeParam(Protocol):
"""A type parameter."""

def to_serial(self) -> stys.BaseTypeParam: ...

def to_serial_root(self) -> stys.TypeParam:
return stys.TypeParam(root=self.to_serial()) # type: ignore[arg-type]


class TypeArg(Protocol):
"""A type argument."""

def to_serial(self) -> stys.BaseTypeArg: ...

def to_serial_root(self) -> stys.TypeArg:
return stys.TypeArg(root=self.to_serial()) # type: ignore[arg-type]


@runtime_checkable
class Type(Protocol):
"""A type."""

def to_serial(self) -> stys.BaseType: ...

def to_serial_root(self) -> stys.Type:
return stys.Type(root=self.to_serial()) # type: ignore[arg-type]


TypeRow = list[Type]

# --------------------------------------------
# --------------- TypeParam ------------------
# --------------------------------------------


@dataclass(frozen=True)
class TypeTypeParam(TypeParam):
bound: TypeBound

def to_serial(self) -> stys.TypeTypeParam:
return stys.TypeTypeParam(b=self.bound)


@dataclass(frozen=True)
class BoundedNatParam(TypeParam):
upper_bound: int | None

def to_serial(self) -> stys.BoundedNatParam:
return stys.BoundedNatParam(bound=self.upper_bound)


@dataclass(frozen=True)
class OpaqueParam(TypeParam):
ty: Opaque

def to_serial(self) -> stys.OpaqueParam:
return stys.OpaqueParam(ty=self.ty.to_serial())


@dataclass(frozen=True)
class ListParam(TypeParam):
param: TypeParam

def to_serial(self) -> stys.ListParam:
return stys.ListParam(param=self.param.to_serial_root())


@dataclass(frozen=True)
class TupleParam(TypeParam):
params: list[TypeParam]

def to_serial(self) -> stys.TupleParam:
return stys.TupleParam(params=ser_it(self.params))


@dataclass(frozen=True)
class ExtensionsParam(TypeParam):
def to_serial(self) -> stys.ExtensionsParam:
return stys.ExtensionsParam()


# ------------------------------------------
# --------------- TypeArg ------------------
# ------------------------------------------


@dataclass(frozen=True)
class TypeTypeArg(TypeArg):
ty: Type

def to_serial(self) -> stys.TypeTypeArg:
return stys.TypeTypeArg(ty=self.ty.to_serial_root())


@dataclass(frozen=True)
class BoundedNatArg(TypeArg):
n: int

def to_serial(self) -> stys.BoundedNatArg:
return stys.BoundedNatArg(n=self.n)


@dataclass(frozen=True)
class OpaqueArg(TypeArg):
ty: Opaque
value: Any

def to_serial(self) -> stys.OpaqueArg:
return stys.OpaqueArg(typ=self.ty.to_serial(), value=self.value)


@dataclass(frozen=True)
class SequenceArg(TypeArg):
elems: list[TypeArg]

def to_serial(self) -> stys.SequenceArg:
return stys.SequenceArg(elems=ser_it(self.elems))


@dataclass(frozen=True)
class ExtensionsArg(TypeArg):
extensions: ExtensionSet

def to_serial(self) -> stys.ExtensionsArg:
return stys.ExtensionsArg(es=self.extensions)


@dataclass(frozen=True)
class VariableArg(TypeArg):
idx: int
param: TypeParam

def to_serial(self) -> stys.VariableArg:
return stys.VariableArg(idx=self.idx, cached_decl=self.param.to_serial_root())


# ----------------------------------------------
# --------------- Type -------------------------
# ----------------------------------------------


@dataclass(frozen=True)
class Array(Type):
ty: Type
size: int

def to_serial(self) -> stys.Array:
return stys.Array(ty=self.ty.to_serial_root(), len=self.size)


@dataclass(frozen=True)
class UnitSum(Type):
size: int

def to_serial(self) -> stys.UnitSum:
return stys.UnitSum(size=self.size)


@dataclass()
class Sum(Type):
variant_rows: list[TypeRow]

def to_serial(self) -> stys.GeneralSum:
return stys.GeneralSum(rows=[ser_it(row) for row in self.variant_rows])

def as_tuple(self) -> Tuple:
assert (
len(self.variant_rows) == 1
), "Sum type must have exactly one row to be converted to a Tuple"
return Tuple(*self.variant_rows[0])


@dataclass()
class Tuple(Sum):
def __init__(self, *tys: Type):
self.variant_rows = [list(tys)]


@dataclass(frozen=True)
class Variable(Type):
idx: int
bound: TypeBound

def to_serial(self) -> stys.Variable:
return stys.Variable(i=self.idx, b=self.bound)


@dataclass(frozen=True)
class RowVariable(Type):
idx: int
bound: TypeBound

def to_serial(self) -> stys.RowVar:
return stys.RowVar(i=self.idx, b=self.bound)


@dataclass(frozen=True)
class USize(Type):
def to_serial(self) -> stys.USize:
return stys.USize()


@dataclass(frozen=True)
class Alias(Type):
name: str
bound: TypeBound

def to_serial(self) -> stys.Alias:
return stys.Alias(name=self.name, bound=self.bound)


@dataclass(frozen=True)
class FunctionType(Type):
input: list[Type]
output: list[Type]
extension_reqs: ExtensionSet = field(default_factory=ExtensionSet)

def to_serial(self) -> stys.FunctionType:
return stys.FunctionType(input=ser_it(self.input), output=ser_it(self.output))

@classmethod
def empty(cls) -> FunctionType:
return cls(input=[], output=[])


@dataclass(frozen=True)
class PolyFuncType(Type):
params: list[TypeParam]
body: FunctionType

def to_serial(self) -> stys.PolyFuncType:
return stys.PolyFuncType(
params=[p.to_serial_root() for p in self.params], body=self.body.to_serial()
)


@dataclass
class Opaque(Type):
id: str
bound: TypeBound
args: list[TypeArg] = field(default_factory=list)
extension: ExtensionId = ""

def to_serial(self) -> stys.Opaque:
return stys.Opaque(
extension=self.extension,
id=self.id,
args=[arg.to_serial_root() for arg in self.args],
bound=self.bound,
)


@dataclass
class QubitDef(Type):
def to_serial(self) -> stys.Qubit:
return stys.Qubit()


Qubit = QubitDef()
Bool = UnitSum(size=2)
Loading

0 comments on commit 10f4c42

Please sign in to comment.