Skip to content

Commit

Permalink
feat(hugr-py): define type hierarchy separate from serialized
Browse files Browse the repository at this point in the history
Closes #1172
  • Loading branch information
ss2165 committed Jun 6, 2024
1 parent 5da06e1 commit 1c1fa1b
Show file tree
Hide file tree
Showing 8 changed files with 485 additions and 75 deletions.
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ._ops import Op, Command, Input, Output, DFG
from ._exceptions import NoSiblingAncestor
from hugr.serialization.tys import FunctionType, Type
from hugr._tys import FunctionType, Type


@dataclass()
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.tys import Type
from hugr._tys import Type
from hugr._ops import Op
from hugr.utils import BiMap

Expand Down
25 changes: 13 additions & 12 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
import hugr.serialization.tys as tys
from hugr.utils import ser_it
import hugr._tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node, Wire
Expand Down Expand Up @@ -43,25 +44,25 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T:

@dataclass()
class Input(Op):
types: list[tys.Type]
types: tys.TypeRow

@property
def num_out(self) -> int | None:
return len(self.types)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input:
return sops.Input(parent=parent.idx, types=self.types)
return sops.Input(parent=parent.idx, types=ser_it(self.types))

def __call__(self) -> Command:
return super().__call__()


@dataclass()
class Output(Op):
types: list[tys.Type]
types: tys.TypeRow

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:
return sops.Output(parent=parent.idx, types=self.types)
return sops.Output(parent=parent.idx, types=ser_it(self.types))


@dataclass()
Expand All @@ -81,21 +82,21 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp:
parent=parent.idx,
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
signature=self.signature.to_serial(),
description=self.description,
args=self.args,
args=ser_it(self.args),
)


@dataclass()
class MakeTuple(Op):
types: list[tys.Type]
types: tys.TypeRow
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
return sops.MakeTuple(
parent=parent.idx,
tys=self.types,
tys=ser_it(self.types),
)

def __call__(self, *elements: Wire) -> Command:
Expand All @@ -104,7 +105,7 @@ def __call__(self, *elements: Wire) -> Command:

@dataclass()
class UnpackTuple(Op):
types: list[tys.Type]
types: tys.TypeRow

@property
def num_out(self) -> int | None:
Expand All @@ -113,7 +114,7 @@ def num_out(self) -> int | None:
def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple:
return sops.UnpackTuple(
parent=parent.idx,
tys=self.types,
tys=ser_it(self.types),
)

def __call__(self, tuple_: Wire) -> Command:
Expand All @@ -131,5 +132,5 @@ def num_out(self) -> int | None:
def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG:
return sops.DFG(
parent=parent.idx,
signature=self.signature,
signature=self.signature.to_serial(),
)
271 changes: 271 additions & 0 deletions hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
from __future__ import annotations
from dataclasses import dataclass, field
import hugr.serialization.tys as stys
from hugr.utils import ser_it
from typing import Any, Protocol, runtime_checkable

ExtensionId = stys.ExtensionId
ExtensionSet = stys.ExtensionSet
TypeBound = stys.TypeBound


class TypeParam(Protocol):
"""A type parameter."""

def to_serial(self) -> stys.BaseTypeParam: ...

def to_serial_root(self) -> stys.TypeParam:
return stys.TypeParam(root=self.to_serial()) # type: ignore[arg-type]

Check warning on line 18 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L18

Added line #L18 was not covered by tests


class TypeArg(Protocol):
"""A type argument."""

def to_serial(self) -> stys.BaseTypeArg: ...

def to_serial_root(self) -> stys.TypeArg:
return stys.TypeArg(root=self.to_serial()) # type: ignore[arg-type]


@runtime_checkable
class Type(Protocol):
"""A type."""

def to_serial(self) -> stys.BaseType: ...

def to_serial_root(self) -> stys.Type:
return stys.Type(root=self.to_serial()) # type: ignore[arg-type]


TypeRow = list[Type]

# --------------------------------------------
# --------------- TypeParam ------------------
# --------------------------------------------


@dataclass(frozen=True)
class TypeTypeParam(TypeParam):
bound: TypeBound

def to_serial(self) -> stys.TypeTypeParam:
return stys.TypeTypeParam(b=self.bound)

Check warning on line 52 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L52

Added line #L52 was not covered by tests


@dataclass(frozen=True)
class BoundedNatParam(TypeParam):
upper_bound: int | None

def to_serial(self) -> stys.BoundedNatParam:
return stys.BoundedNatParam(bound=self.upper_bound)

Check warning on line 60 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L60

Added line #L60 was not covered by tests


@dataclass(frozen=True)
class OpaqueParam(TypeParam):
ty: Opaque

def to_serial(self) -> stys.OpaqueParam:
return stys.OpaqueParam(ty=self.ty.to_serial())

Check warning on line 68 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L68

Added line #L68 was not covered by tests


@dataclass(frozen=True)
class ListParam(TypeParam):
param: TypeParam

def to_serial(self) -> stys.ListParam:
return stys.ListParam(param=self.param.to_serial_root())

Check warning on line 76 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L76

Added line #L76 was not covered by tests


@dataclass(frozen=True)
class TupleParam(TypeParam):
params: list[TypeParam]

def to_serial(self) -> stys.TupleParam:
return stys.TupleParam(params=ser_it(self.params))

Check warning on line 84 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L84

Added line #L84 was not covered by tests


@dataclass(frozen=True)
class ExtensionsParam(TypeParam):
def to_serial(self) -> stys.ExtensionsParam:
return stys.ExtensionsParam()

Check warning on line 90 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L90

Added line #L90 was not covered by tests


# ------------------------------------------
# --------------- TypeArg ------------------
# ------------------------------------------


@dataclass(frozen=True)
class TypeTypeArg(TypeArg):
ty: Type

def to_serial(self) -> stys.TypeTypeArg:
return stys.TypeTypeArg(ty=self.ty.to_serial_root())

Check warning on line 103 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L103

Added line #L103 was not covered by tests


@dataclass(frozen=True)
class BoundedNatArg(TypeArg):
n: int

def to_serial(self) -> stys.BoundedNatArg:
return stys.BoundedNatArg(n=self.n)


@dataclass(frozen=True)
class OpaqueArg(TypeArg):
ty: Opaque
value: Any

def to_serial(self) -> stys.OpaqueArg:
return stys.OpaqueArg(typ=self.ty.to_serial(), value=self.value)

Check warning on line 120 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L120

Added line #L120 was not covered by tests


@dataclass(frozen=True)
class SequenceArg(TypeArg):
elems: list[TypeArg]

def to_serial(self) -> stys.SequenceArg:
return stys.SequenceArg(elems=ser_it(self.elems))

Check warning on line 128 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L128

Added line #L128 was not covered by tests


@dataclass(frozen=True)
class ExtensionsArg(TypeArg):
extensions: ExtensionSet

def to_serial(self) -> stys.ExtensionsArg:
return stys.ExtensionsArg(es=self.extensions)

Check warning on line 136 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L136

Added line #L136 was not covered by tests


@dataclass(frozen=True)
class VariableArg(TypeArg):
idx: int
param: TypeParam

def to_serial(self) -> stys.VariableArg:
return stys.VariableArg(idx=self.idx, cached_decl=self.param.to_serial_root())

Check warning on line 145 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L145

Added line #L145 was not covered by tests


# ----------------------------------------------
# --------------- Type -------------------------
# ----------------------------------------------


@dataclass(frozen=True)
class Array(Type):
ty: Type
size: int

def to_serial(self) -> stys.Array:
return stys.Array(ty=self.ty.to_serial_root(), len=self.size)

Check warning on line 159 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L159

Added line #L159 was not covered by tests


@dataclass(frozen=True)
class UnitSum(Type):
size: int

def to_serial(self) -> stys.UnitSum:
return stys.UnitSum(size=self.size)


@dataclass()
class Sum(Type):
variant_rows: list[TypeRow]

def to_serial(self) -> stys.GeneralSum:
return stys.GeneralSum(rows=[ser_it(row) for row in self.variant_rows])

Check warning on line 175 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L175

Added line #L175 was not covered by tests

def as_tuple(self) -> Tuple:
assert (

Check warning on line 178 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L178

Added line #L178 was not covered by tests
len(self.variant_rows) == 1
), "Sum type must have exactly one row to be converted to a Tuple"
return Tuple(*self.variant_rows[0])

Check warning on line 181 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L181

Added line #L181 was not covered by tests


@dataclass()
class Tuple(Sum):
def __init__(self, *tys: Type):
self.variant_rows = [list(tys)]

Check warning on line 187 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L187

Added line #L187 was not covered by tests


@dataclass(frozen=True)
class Variable(Type):
idx: int
bound: TypeBound

def to_serial(self) -> stys.Variable:
return stys.Variable(i=self.idx, b=self.bound)

Check warning on line 196 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L196

Added line #L196 was not covered by tests


@dataclass(frozen=True)
class RowVariable(Type):
idx: int
bound: TypeBound

def to_serial(self) -> stys.RowVar:
return stys.RowVar(i=self.idx, b=self.bound)

Check warning on line 205 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L205

Added line #L205 was not covered by tests


@dataclass(frozen=True)
class USize(Type):
def to_serial(self) -> stys.USize:
return stys.USize()

Check warning on line 211 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L211

Added line #L211 was not covered by tests


@dataclass(frozen=True)
class Alias(Type):
name: str
bound: TypeBound

def to_serial(self) -> stys.Alias:
return stys.Alias(name=self.name, bound=self.bound)

Check warning on line 220 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L220

Added line #L220 was not covered by tests


@dataclass(frozen=True)
class FunctionType(Type):
input: list[Type]
output: list[Type]
extension_reqs: ExtensionSet = field(default_factory=ExtensionSet)

def to_serial(self) -> stys.FunctionType:
return stys.FunctionType(input=ser_it(self.input), output=ser_it(self.output))

@classmethod
def empty(cls) -> FunctionType:
return cls(input=[], output=[])


@dataclass(frozen=True)
class PolyFuncType(Type):
params: list[TypeParam]
body: FunctionType

def to_serial(self) -> stys.PolyFuncType:
return stys.PolyFuncType(

Check warning on line 243 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L243

Added line #L243 was not covered by tests
params=[p.to_serial_root() for p in self.params], body=self.body.to_serial()
)


@dataclass
class Opaque(Type):
id: str
bound: TypeBound
args: list[TypeArg] = field(default_factory=list)
extension: ExtensionId = ""

def to_serial(self) -> stys.Opaque:
return stys.Opaque(
extension=self.extension,
id=self.id,
args=[arg.to_serial_root() for arg in self.args],
bound=self.bound,
)


@dataclass
class QubitDef(Type):
def to_serial(self) -> stys.Qubit:
return stys.Qubit()


Qubit = QubitDef()
Bool = UnitSum(size=2)
Loading

0 comments on commit 1c1fa1b

Please sign in to comment.