Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): add values and constants #1203

Merged
merged 7 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, TypeRow, Type
import hugr._val as val


class Block(_DfBase[ops.DataflowBlock]):
def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

def set_single_successor_outputs(self, *outputs: Wire) -> None:
# TODO requires constants
raise NotImplementedError
def set_single_succ_outputs(self, *outputs: Wire) -> None:
u = self.add_load_const(val.Unit)
self.set_outputs(u, *outputs)

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
src = p.out_port()
Expand Down
26 changes: 21 additions & 5 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import (
Iterable,
TYPE_CHECKING,
Iterable,
TypeVar,
)
from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder

from typing_extensions import Self

import hugr._ops as ops
from hugr._tys import TypeRow
import hugr._val as val
from hugr._tys import Type, TypeRow

from ._exceptions import NoSiblingAncestor
from ._hugr import ToNode
from hugr._tys import Type
from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire

if TYPE_CHECKING:
from ._cfg import Cfg
Expand Down Expand Up @@ -113,6 +114,21 @@ def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))

def add_const(self, val: val.Value) -> Node:
return self.hugr.add_const(val, self.parent_node)

def load_const(self, const_node: ToNode) -> Node:
const_op = self.hugr._get_typed_op(const_node, ops.Const)
load_op = ops.LoadConst(const_op.val.type_())

load = self.add(load_op())
self.hugr.add_link(const_node.out_port(), load.inp(0))

return load

def add_load_const(self, val: val.Value) -> Node:
return self.load_const(self.add_const(val))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
Expand Down
6 changes: 5 additions & 1 deletion hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

from typing_extensions import Self

from hugr._ops import Op, DataflowOp
from hugr._ops import Op, DataflowOp, Const
from hugr._tys import Type, Kind
from hugr._val import Value
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.utils import BiMap
Expand Down Expand Up @@ -228,6 +229,9 @@ def add_node(
parent = parent or self.root
return self._add_node(op, parent, num_outs)

def add_const(self, value: Value, parent: ToNode | None = None) -> Node:
return self.add_node(Const(value), parent)

def delete_node(self, node: ToNode) -> NodeData | None:
node = node.to_node()
parent = self[node].parent
Expand Down
56 changes: 35 additions & 21 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Generic, Protocol, TypeVar, TYPE_CHECKING, runtime_checkable
from typing import Protocol, TYPE_CHECKING, runtime_checkable, TypeVar
from hugr.serialization.ops import BaseOp
import hugr.serialization.ops as sops
from hugr.utils import ser_it
import hugr._tys as tys
import hugr._val as val
from ._exceptions import IncompleteOp

if TYPE_CHECKING:
Expand Down Expand Up @@ -55,23 +56,6 @@
incoming: list[Wire]


T = TypeVar("T", bound=BaseOp)


@dataclass()
class SerWrap(Op, Generic[T]):
# catch all for serial ops that don't have a corresponding Op class
_serial_op: T

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> T:
root = self._serial_op.model_copy()
root.parent = parent.idx
return root

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
raise NotImplementedError


@dataclass()
class Input(DataflowOp):
types: tys.TypeRow
Expand Down Expand Up @@ -304,9 +288,7 @@

@property
def other_outputs(self) -> tys.TypeRow:
if self._other_outputs is None:
raise IncompleteOp()
return self._other_outputs
return _check_complete(self._other_outputs)

@property
def num_out(self) -> int | None:
Expand Down Expand Up @@ -359,3 +341,35 @@

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
return tys.CFKind()


@dataclass
class Const(Op):
val: val.Value
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Const:
return sops.Const(
parent=parent.idx,
v=self.val.to_serial_root(),
)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
return tys.ConstKind(self.val.type_())

Check warning on line 358 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L358

Added line #L358 was not covered by tests


@dataclass
class LoadConst(DataflowOp):
typ: tys.Type | None = None

def type_(self) -> tys.Type:
return _check_complete(self.typ)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant:
return sops.LoadConstant(
parent=parent.idx,
datatype=self.type_().to_serial_root(),
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=[], output=[self.type_()])
103 changes: 103 additions & 0 deletions hugr-py/src/hugr/_val.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING
import hugr.serialization.ops as sops
import hugr.serialization.tys as stys
import hugr._tys as tys
from hugr.utils import ser_it

if TYPE_CHECKING:
from hugr._hugr import Hugr

Check warning on line 10 in hugr-py/src/hugr/_val.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_val.py#L10

Added line #L10 was not covered by tests


@runtime_checkable
class Value(Protocol):
def to_serial(self) -> sops.BaseValue: ...
def to_serial_root(self) -> sops.Value:
return sops.Value(root=self.to_serial()) # type: ignore[arg-type]

def type_(self) -> tys.Type: ...


@dataclass
class Sum(Value):
tag: int
typ: tys.Sum
vals: list[Value]

def type_(self) -> tys.Sum:
return self.typ

def to_serial(self) -> sops.SumValue:
return sops.SumValue(
tag=self.tag,
typ=stys.SumType(root=self.type_().to_serial()),
vs=ser_it(self.vals),
)


def bool_value(b: bool) -> Sum:
return Sum(
tag=int(b),
typ=tys.Bool,
vals=[],
)


Unit = Sum(0, tys.Unit, [])
TRUE = bool_value(True)
FALSE = bool_value(False)


@dataclass
class Tuple(Value):
vals: list[Value]

def type_(self) -> tys.Tuple:
return tys.Tuple(*(v.type_() for v in self.vals))

def to_serial(self) -> sops.TupleValue:
return sops.TupleValue(
vs=ser_it(self.vals),
)


@dataclass
class Function(Value):
body: Hugr

def type_(self) -> tys.FunctionType:
return self.body.root_op().inner_signature()

def to_serial(self) -> sops.FunctionValue:
return sops.FunctionValue(
hugr=self.body.to_serial(),
)


@dataclass
class Extension(Value):
name: str
typ: tys.Type
val: Any
extensions: tys.ExtensionSet = field(default_factory=tys.ExtensionSet)

def type_(self) -> tys.Type:
return self.typ

def to_serial(self) -> sops.ExtensionValue:
return sops.ExtensionValue(
typ=self.typ.to_serial_root(),
value=sops.CustomConst(c=self.name, v=self.val),
extensions=self.extensions,
)


class ExtensionValue(Value, Protocol):
def to_value(self) -> Extension: ...

def type_(self) -> tys.Type:
return self.to_value().type_()

def to_serial(self) -> sops.ExtensionValue:
return self.to_value().to_serial()
Loading
Loading