Skip to content

Commit

Permalink
feat(hugr-py): add values ans constants
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 18, 2024
1 parent 2bb079f commit 8d4de2d
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 31 deletions.
7 changes: 4 additions & 3 deletions hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._tys import FunctionType, TypeRow, Type
import hugr._val as val


class Block(_DfBase[ops.DataflowBlock]):
def set_block_outputs(self, branching: Wire, *other_outputs: Wire) -> None:
self.set_outputs(branching, *other_outputs)

def set_single_successor_outputs(self, *outputs: Wire) -> None:
# TODO requires constants
raise NotImplementedError
def set_single_succ_outputs(self, *outputs: Wire) -> None:
u = self.add_load_const(val.Unit)
self.set_outputs(u, *outputs)

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
src = p.out_port()
Expand Down
26 changes: 21 additions & 5 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import (
Iterable,
TYPE_CHECKING,
Iterable,
TypeVar,
)
from ._hugr import Hugr, Node, Wire, OutPort, ParentBuilder

from typing_extensions import Self

import hugr._ops as ops
from hugr._tys import TypeRow
import hugr._val as val
from hugr._tys import Type, TypeRow

from ._exceptions import NoSiblingAncestor
from ._hugr import ToNode
from hugr._tys import Type
from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire

if TYPE_CHECKING:
from ._cfg import Cfg
Expand Down Expand Up @@ -113,6 +114,21 @@ def add_state_order(self, src: Node, dst: Node) -> None:
# adds edge to the right of all existing edges
self.hugr.add_link(src.out(-1), dst.inp(-1))

def add_const(self, val: val.Value) -> Node:
return self.hugr.add_const(val, self.parent_node)

def load_const(self, const_node: ToNode) -> Node:
const_op = self.hugr._get_typed_op(const_node, ops.Const)
load_op = ops.LoadConst(const_op.val.type_())

load = self.add(load_op())
self.hugr.add_link(const_node.out_port(), load.inp(0))

return load

def add_load_const(self, val: val.Value) -> Node:
return self.load_const(self.add_const(val))

def _wire_up(self, node: Node, ports: Iterable[Wire]):
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
Expand Down
6 changes: 5 additions & 1 deletion hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

from typing_extensions import Self

from hugr._ops import Op, DataflowOp
from hugr._ops import Op, DataflowOp, Const
from hugr._tys import Type, Kind
from hugr._val import Value
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.utils import BiMap
Expand Down Expand Up @@ -228,6 +229,9 @@ def add_node(
parent = parent or self.root
return self._add_node(op, parent, num_outs)

def add_const(self, value: Value, parent: ToNode | None = None) -> Node:
return self.add_node(Const(value), parent)

def delete_node(self, node: ToNode) -> NodeData | None:
node = node.to_node()
parent = self[node].parent
Expand Down
35 changes: 35 additions & 0 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import hugr.serialization.ops as sops
from hugr.utils import ser_it
import hugr._tys as tys
import hugr._val as val
from ._exceptions import IncompleteOp

if TYPE_CHECKING:
Expand Down Expand Up @@ -360,3 +361,37 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.ExitBlock:

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
return tys.CFKind()


@dataclass
class Const(Op):
val: val.Value
num_out: int | None = 1

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Const:
return sops.Const(
parent=parent.idx,
v=self.val.to_serial_root(),
)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
return tys.ConstKind(self.val.type_())

Check warning on line 378 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L378

Added line #L378 was not covered by tests


@dataclass
class LoadConst(DataflowOp):
typ: tys.Type | None = None

def type_(self) -> tys.Type:
if self.typ is None:
raise IncompleteOp()

Check warning on line 387 in hugr-py/src/hugr/_ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_ops.py#L387

Added line #L387 was not covered by tests
return self.typ

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant:
return sops.LoadConstant(
parent=parent.idx,
datatype=self.type_().to_serial_root(),
)

def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=[], output=[self.type_()])
103 changes: 103 additions & 0 deletions hugr-py/src/hugr/_val.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable, TYPE_CHECKING
import hugr.serialization.ops as sops
import hugr.serialization.tys as stys
import hugr._tys as tys
from hugr.utils import ser_it

if TYPE_CHECKING:
from hugr._hugr import Hugr

Check warning on line 10 in hugr-py/src/hugr/_val.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_val.py#L10

Added line #L10 was not covered by tests


@runtime_checkable
class Value(Protocol):
def to_serial(self) -> sops.BaseValue: ...
def to_serial_root(self) -> sops.Value:
return sops.Value(root=self.to_serial()) # type: ignore[arg-type]

def type_(self) -> tys.Type: ...


@dataclass
class Sum(Value):
tag: int
typ: tys.Sum
vals: list[Value]

def type_(self) -> tys.Sum:
return self.typ

def to_serial(self) -> sops.SumValue:
return sops.SumValue(
tag=self.tag,
typ=stys.SumType(root=self.type_().to_serial()),
vs=ser_it(self.vals),
)


def bool_value(b: bool) -> Sum:
return Sum(
tag=int(b),
typ=tys.Bool,
vals=[],
)


Unit = Sum(0, tys.Unit, [])
TRUE = bool_value(True)
FALSE = bool_value(False)


@dataclass
class Tuple(Value):
vals: list[Value]

def type_(self) -> tys.Tuple:
return tys.Tuple(*(v.type_() for v in self.vals))

Check warning on line 57 in hugr-py/src/hugr/_val.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_val.py#L57

Added line #L57 was not covered by tests

def to_serial(self) -> sops.TupleValue:
return sops.TupleValue(

Check warning on line 60 in hugr-py/src/hugr/_val.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_val.py#L60

Added line #L60 was not covered by tests
vs=ser_it(self.vals),
)


@dataclass
class Function(Value):
body: Hugr

def type_(self) -> tys.FunctionType:
return self.body.root_op().inner_signature()

Check warning on line 70 in hugr-py/src/hugr/_val.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_val.py#L70

Added line #L70 was not covered by tests

def to_serial(self) -> sops.FunctionValue:
return sops.FunctionValue(

Check warning on line 73 in hugr-py/src/hugr/_val.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_val.py#L73

Added line #L73 was not covered by tests
hugr=self.body.to_serial(),
)


@dataclass
class Extension(Value):
name: str
typ: tys.Type
val: Any
extensions: tys.ExtensionSet = field(default_factory=tys.ExtensionSet)

def type_(self) -> tys.Type:
return self.typ

def to_serial(self) -> sops.ExtensionValue:
return sops.ExtensionValue(
typ=self.typ.to_serial_root(),
value=sops.CustomConst(c=self.name, v=self.val),
extensions=self.extensions,
)


class ExtensionValue(Value, Protocol):
def to_value(self) -> Extension: ...

def type_(self) -> tys.Type:
return self.to_value().type_()

def to_serial(self) -> sops.ExtensionValue:
return self.to_value().to_serial()
41 changes: 32 additions & 9 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
import inspect
import sys
from abc import ABC
from abc import ABC, abstractmethod
from typing import Any, Literal

from pydantic import Field, RootModel, ConfigDict
Expand Down Expand Up @@ -80,36 +80,50 @@ class CustomConst(ConfiguredBaseModel):
v: Any


class ExtensionValue(ConfiguredBaseModel):
class BaseValue(ABC, ConfiguredBaseModel):
@abstractmethod
def deserialize(self) -> _val.Value: ...


class ExtensionValue(BaseValue):
"""An extension constant value, that can check it is of a given [CustomType]."""

v: Literal["Extension"] = Field("Extension", title="ValueTag")
v: Literal["Extension"] = Field(default="Extension", title="ValueTag")
extensions: ExtensionSet
typ: Type
value: CustomConst

def deserialize(self) -> _val.Value:
return _val.Extension(self.value.c, self.typ.deserialize(), self.value)

Check warning on line 97 in hugr-py/src/hugr/serialization/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/serialization/ops.py#L97

Added line #L97 was not covered by tests

class FunctionValue(ConfiguredBaseModel):

class FunctionValue(BaseValue):
"""A higher-order function value."""

v: Literal["Function"] = Field("Function", title="ValueTag")
v: Literal["Function"] = Field(default="Function", title="ValueTag")
hugr: Any # TODO

def deserialize(self) -> _val.Value:
return _val.Function(self.hugr)

Check warning on line 107 in hugr-py/src/hugr/serialization/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/serialization/ops.py#L107

Added line #L107 was not covered by tests


class TupleValue(ConfiguredBaseModel):
class TupleValue(BaseValue):
"""A constant tuple value."""

v: Literal["Tuple"] = Field("Tuple", title="ValueTag")
v: Literal["Tuple"] = Field(default="Tuple", title="ValueTag")
vs: list["Value"]

def deserialize(self) -> _val.Value:
return _val.Tuple(deser_it((v.root for v in self.vs)))

Check warning on line 117 in hugr-py/src/hugr/serialization/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/serialization/ops.py#L117

Added line #L117 was not covered by tests

class SumValue(ConfiguredBaseModel):

class SumValue(BaseValue):
"""A Sum variant
For any Sum type where this value meets the type of the variant indicated by the tag
"""

v: Literal["Sum"] = Field("Sum", title="ValueTag")
v: Literal["Sum"] = Field(default="Sum", title="ValueTag")
tag: int
typ: SumType
vs: list["Value"]
Expand All @@ -122,6 +136,11 @@ class SumValue(ConfiguredBaseModel):
}
)

def deserialize(self) -> _val.Value:
return _val.Sum(

Check warning on line 140 in hugr-py/src/hugr/serialization/ops.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/serialization/ops.py#L140

Added line #L140 was not covered by tests
self.tag, self.typ.deserialize(), deser_it((v.root for v in self.vs))
)


class Value(RootModel):
"""A constant Value."""
Expand Down Expand Up @@ -282,6 +301,9 @@ class LoadConstant(DataflowOp):
op: Literal["LoadConstant"] = "LoadConstant"
datatype: Type

def deserialize(self) -> _ops.LoadConst:
return _ops.LoadConst(self.datatype.deserialize())


class LoadFunction(DataflowOp):
"""Load a static function in to the local dataflow graph."""
Expand Down Expand Up @@ -560,3 +582,4 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True):
tys_model_rebuild(dict(classes))

from hugr import _ops # noqa: E402 # needed to avoid circular imports
from hugr import _val # noqa: E402 # needed to avoid circular imports
Loading

0 comments on commit 8d4de2d

Please sign in to comment.