Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): define type hierarchy separate from serialized #1176

Merged
merged 2 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ._ops import Op, Command, Input, Output, DFG
from ._exceptions import NoSiblingAncestor
from hugr.serialization.tys import FunctionType, Type
from hugr._tys import FunctionType, Type


@dataclass()
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from typing_extensions import Self

from hugr._ops import Op
from hugr._tys import Type
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.serialization.tys import Type
from hugr.utils import BiMap

from ._exceptions import ParentBeforeChild
Expand Down
25 changes: 13 additions & 12 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
import hugr.serialization.tys as tys
from hugr.utils import ser_it
import hugr._tys as tys

if TYPE_CHECKING:
from hugr._hugr import Hugr, Node, Wire
Expand Down Expand Up @@ -43,25 +44,25 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T:

@dataclass()
class Input(Op):
types: list[tys.Type]
types: tys.TypeRow

@property
def num_out(self) -> int | None:
return len(self.types)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Input:
return sops.Input(parent=parent.idx, types=self.types)
return sops.Input(parent=parent.idx, types=ser_it(self.types))

def __call__(self) -> Command:
return super().__call__()


@dataclass()
class Output(Op):
types: list[tys.Type]
types: tys.TypeRow

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Output:
return sops.Output(parent=parent.idx, types=self.types)
return sops.Output(parent=parent.idx, types=ser_it(self.types))


@dataclass()
Expand All @@ -81,21 +82,21 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.CustomOp:
parent=parent.idx,
extension=self.extension,
op_name=self.op_name,
signature=self.signature,
signature=self.signature.to_serial(),
description=self.description,
args=self.args,
args=ser_it(self.args),
)


@dataclass()
class MakeTuple(Op):
types: list[tys.Type]
types: tys.TypeRow
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.MakeTuple:
return sops.MakeTuple(
parent=parent.idx,
tys=self.types,
tys=ser_it(self.types),
)

def __call__(self, *elements: Wire) -> Command:
Expand All @@ -104,7 +105,7 @@ def __call__(self, *elements: Wire) -> Command:

@dataclass()
class UnpackTuple(Op):
types: list[tys.Type]
types: tys.TypeRow

@property
def num_out(self) -> int | None:
Expand All @@ -113,7 +114,7 @@ def num_out(self) -> int | None:
def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.UnpackTuple:
return sops.UnpackTuple(
parent=parent.idx,
tys=self.types,
tys=ser_it(self.types),
)

def __call__(self, tuple_: Wire) -> Command:
Expand All @@ -131,5 +132,5 @@ def num_out(self) -> int | None:
def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.DFG:
return sops.DFG(
parent=parent.idx,
signature=self.signature,
signature=self.signature.to_serial(),
)
271 changes: 271 additions & 0 deletions hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
from __future__ import annotations
from dataclasses import dataclass, field
import hugr.serialization.tys as stys
from hugr.utils import ser_it
from typing import Any, Protocol, runtime_checkable

ExtensionId = stys.ExtensionId
ExtensionSet = stys.ExtensionSet
TypeBound = stys.TypeBound


class TypeParam(Protocol):
"""A type parameter."""

def to_serial(self) -> stys.BaseTypeParam: ...

def to_serial_root(self) -> stys.TypeParam:
return stys.TypeParam(root=self.to_serial()) # type: ignore[arg-type]

Check warning on line 18 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L18

Added line #L18 was not covered by tests


class TypeArg(Protocol):
"""A type argument."""

def to_serial(self) -> stys.BaseTypeArg: ...

def to_serial_root(self) -> stys.TypeArg:
return stys.TypeArg(root=self.to_serial()) # type: ignore[arg-type]


@runtime_checkable
class Type(Protocol):
"""A type."""

def to_serial(self) -> stys.BaseType: ...

def to_serial_root(self) -> stys.Type:
return stys.Type(root=self.to_serial()) # type: ignore[arg-type]


TypeRow = list[Type]

# --------------------------------------------
# --------------- TypeParam ------------------
# --------------------------------------------


@dataclass(frozen=True)
class TypeTypeParam(TypeParam):
bound: TypeBound

def to_serial(self) -> stys.TypeTypeParam:
return stys.TypeTypeParam(b=self.bound)

Check warning on line 52 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L52

Added line #L52 was not covered by tests


@dataclass(frozen=True)
class BoundedNatParam(TypeParam):
upper_bound: int | None

def to_serial(self) -> stys.BoundedNatParam:
return stys.BoundedNatParam(bound=self.upper_bound)

Check warning on line 60 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L60

Added line #L60 was not covered by tests


@dataclass(frozen=True)
class OpaqueParam(TypeParam):
ty: Opaque

def to_serial(self) -> stys.OpaqueParam:
return stys.OpaqueParam(ty=self.ty.to_serial())

Check warning on line 68 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L68

Added line #L68 was not covered by tests


@dataclass(frozen=True)
class ListParam(TypeParam):
param: TypeParam

def to_serial(self) -> stys.ListParam:
return stys.ListParam(param=self.param.to_serial_root())

Check warning on line 76 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L76

Added line #L76 was not covered by tests


@dataclass(frozen=True)
class TupleParam(TypeParam):
params: list[TypeParam]

def to_serial(self) -> stys.TupleParam:
return stys.TupleParam(params=ser_it(self.params))

Check warning on line 84 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L84

Added line #L84 was not covered by tests


@dataclass(frozen=True)
class ExtensionsParam(TypeParam):
def to_serial(self) -> stys.ExtensionsParam:
return stys.ExtensionsParam()

Check warning on line 90 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L90

Added line #L90 was not covered by tests


# ------------------------------------------
# --------------- TypeArg ------------------
# ------------------------------------------


@dataclass(frozen=True)
class TypeTypeArg(TypeArg):
ty: Type

def to_serial(self) -> stys.TypeTypeArg:
return stys.TypeTypeArg(ty=self.ty.to_serial_root())

Check warning on line 103 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L103

Added line #L103 was not covered by tests


@dataclass(frozen=True)
class BoundedNatArg(TypeArg):
n: int

def to_serial(self) -> stys.BoundedNatArg:
return stys.BoundedNatArg(n=self.n)


@dataclass(frozen=True)
class OpaqueArg(TypeArg):
ty: Opaque
value: Any

def to_serial(self) -> stys.OpaqueArg:
return stys.OpaqueArg(typ=self.ty.to_serial(), value=self.value)

Check warning on line 120 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L120

Added line #L120 was not covered by tests


@dataclass(frozen=True)
class SequenceArg(TypeArg):
elems: list[TypeArg]

def to_serial(self) -> stys.SequenceArg:
return stys.SequenceArg(elems=ser_it(self.elems))

Check warning on line 128 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L128

Added line #L128 was not covered by tests


@dataclass(frozen=True)
class ExtensionsArg(TypeArg):
extensions: ExtensionSet

def to_serial(self) -> stys.ExtensionsArg:
return stys.ExtensionsArg(es=self.extensions)

Check warning on line 136 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L136

Added line #L136 was not covered by tests


@dataclass(frozen=True)
class VariableArg(TypeArg):
idx: int
param: TypeParam

def to_serial(self) -> stys.VariableArg:
return stys.VariableArg(idx=self.idx, cached_decl=self.param.to_serial_root())

Check warning on line 145 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L145

Added line #L145 was not covered by tests


# ----------------------------------------------
# --------------- Type -------------------------
# ----------------------------------------------


@dataclass(frozen=True)
class Array(Type):
ty: Type
size: int

def to_serial(self) -> stys.Array:
return stys.Array(ty=self.ty.to_serial_root(), len=self.size)

Check warning on line 159 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L159

Added line #L159 was not covered by tests


@dataclass(frozen=True)
class UnitSum(Type):
size: int

def to_serial(self) -> stys.UnitSum:
return stys.UnitSum(size=self.size)


@dataclass()
class Sum(Type):
variant_rows: list[TypeRow]

def to_serial(self) -> stys.GeneralSum:
return stys.GeneralSum(rows=[ser_it(row) for row in self.variant_rows])

Check warning on line 175 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L175

Added line #L175 was not covered by tests

def as_tuple(self) -> Tuple:
assert (

Check warning on line 178 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L178

Added line #L178 was not covered by tests
len(self.variant_rows) == 1
), "Sum type must have exactly one row to be converted to a Tuple"
return Tuple(*self.variant_rows[0])

Check warning on line 181 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L181

Added line #L181 was not covered by tests


@dataclass()
class Tuple(Sum):
def __init__(self, *tys: Type):
self.variant_rows = [list(tys)]

Check warning on line 187 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L187

Added line #L187 was not covered by tests


@dataclass(frozen=True)
class Variable(Type):
idx: int
bound: TypeBound

def to_serial(self) -> stys.Variable:
return stys.Variable(i=self.idx, b=self.bound)

Check warning on line 196 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L196

Added line #L196 was not covered by tests


@dataclass(frozen=True)
class RowVariable(Type):
idx: int
bound: TypeBound

def to_serial(self) -> stys.RowVar:
return stys.RowVar(i=self.idx, b=self.bound)

Check warning on line 205 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L205

Added line #L205 was not covered by tests


@dataclass(frozen=True)
class USize(Type):
def to_serial(self) -> stys.USize:
return stys.USize()

Check warning on line 211 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L211

Added line #L211 was not covered by tests


@dataclass(frozen=True)
class Alias(Type):
name: str
bound: TypeBound

def to_serial(self) -> stys.Alias:
return stys.Alias(name=self.name, bound=self.bound)

Check warning on line 220 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L220

Added line #L220 was not covered by tests


@dataclass(frozen=True)
class FunctionType(Type):
input: list[Type]
output: list[Type]
extension_reqs: ExtensionSet = field(default_factory=ExtensionSet)

def to_serial(self) -> stys.FunctionType:
return stys.FunctionType(input=ser_it(self.input), output=ser_it(self.output))

@classmethod
def empty(cls) -> FunctionType:
return cls(input=[], output=[])


@dataclass(frozen=True)
class PolyFuncType(Type):
params: list[TypeParam]
body: FunctionType

def to_serial(self) -> stys.PolyFuncType:
return stys.PolyFuncType(

Check warning on line 243 in hugr-py/src/hugr/_tys.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_tys.py#L243

Added line #L243 was not covered by tests
params=[p.to_serial_root() for p in self.params], body=self.body.to_serial()
)


@dataclass
class Opaque(Type):
id: str
bound: TypeBound
args: list[TypeArg] = field(default_factory=list)
extension: ExtensionId = ""

def to_serial(self) -> stys.Opaque:
return stys.Opaque(
extension=self.extension,
id=self.id,
args=[arg.to_serial_root() for arg in self.args],
bound=self.bound,
)


@dataclass
class QubitDef(Type):
def to_serial(self) -> stys.Qubit:
return stys.Qubit()


Qubit = QubitDef()
Bool = UnitSum(size=2)
Loading
Loading