Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hugr-py): builder for function definition/declaration and call #1212

Merged
merged 3 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hugr-py/src/hugr/_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from ._dfg import _DfBase
from ._exceptions import NoSiblingAncestor, NotInSameCfg, MismatchedExit
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._hugr import Hugr, ParentBuilder
from ._node_port import Node, Wire, ToNode
from ._tys import TypeRow, Type
import hugr._val as val

Expand Down
4 changes: 3 additions & 1 deletion hugr-py/src/hugr/_cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import hugr._ops as ops

from ._dfg import _DfBase
from ._hugr import Hugr, Node, ParentBuilder, ToNode, Wire
from ._hugr import Hugr, ParentBuilder
from ._node_port import Node, Wire, ToNode

from ._tys import Sum, TypeRow


Expand Down
27 changes: 25 additions & 2 deletions hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

import hugr._ops as ops
import hugr._val as val
from hugr._tys import Type, TypeRow, get_first_sum
from hugr._tys import Type, TypeRow, get_first_sum, FunctionType, TypeArg, FunctionKind

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire
from ._hugr import Hugr, ParentBuilder
from ._node_port import Node, OutPort, Wire, ToNode

if TYPE_CHECKING:
from ._cfg import Cfg
Expand Down Expand Up @@ -164,6 +165,28 @@

return load

def call(
self,
func: ToNode,
*args: Wire,
instantiation: FunctionType | None = None,
type_args: list[TypeArg] | None = None,
) -> Node:
f_op = self.hugr[func]
f_kind = f_op.op.port_kind(func.out(0))
match f_kind:
case FunctionKind(sig):
signature = sig
case _:
raise ValueError(f"Expected function type, got {f_kind}")

Check warning on line 181 in hugr-py/src/hugr/_dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_dfg.py#L180-L181

Added lines #L180 - L181 were not covered by tests
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
call_op = ops.Call(signature, instantiation, type_args)
call_n = self.hugr.add_node(call_op, self.parent_node, call_op.num_out)
self.hugr.add_link(func.out(0), call_n.inp(call_op.function_port_offset()))

self._wire_up(call_n, args)

return call_n

def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
Expand Down
49 changes: 49 additions & 0 deletions hugr-py/src/hugr/_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from dataclasses import dataclass

import hugr._ops as ops
import hugr._val as val

from ._dfg import _DfBase
from hugr._node_port import Node
from ._hugr import Hugr
from ._tys import TypeRow, TypeParam, PolyFuncType


@dataclass
class Function(_DfBase[ops.FuncDefn]):
def __init__(
self,
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
) -> None:
root_op = ops.FuncDefn(name, input_types, type_params or [])
super().__init__(root_op)

Check warning on line 23 in hugr-py/src/hugr/_function.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_function.py#L22-L23

Added lines #L22 - L23 were not covered by tests


@dataclass
class Module:
hugr: Hugr

def __init__(self) -> None:
self.hugr = Hugr(ops.Module())

def define_function(
self,
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
) -> Function:
parent_op = ops.FuncDefn(name, input_types, type_params or [])
return Function.new_nested(parent_op, self.hugr)

def define_main(self, input_types: TypeRow) -> Function:
return self.define_function("main", input_types)

def declare_function(self, name: str, signature: PolyFuncType) -> Node:
return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.root)

def add_const(self, value: val.Value) -> Node:
return self.hugr.add_node(ops.Const(value), self.hugr.root)

Check warning on line 49 in hugr-py/src/hugr/_function.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_function.py#L49

Added line #L49 was not covered by tests
119 changes: 10 additions & 109 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,121 +2,28 @@

from collections.abc import Mapping
from dataclasses import dataclass, field, replace
from enum import Enum
from typing import (
ClassVar,
Generic,
Iterable,
Iterator,
Protocol,
TypeVar,
cast,
overload,
Type as PyType,
)

from typing_extensions import Self

from hugr._ops import Op, DataflowOp, Const
from hugr._tys import Type, Kind
from hugr._ops import Op, DataflowOp, Const, Call
from hugr._tys import Type, Kind, ValueKind
from hugr._val import Value
from hugr._node_port import Direction, InPort, OutPort, ToNode, Node, _SubPort
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.utils import BiMap

from ._exceptions import ParentBeforeChild


class Direction(Enum):
INCOMING = 0
OUTGOING = 1


@dataclass(frozen=True, eq=True, order=True)
class _Port:
node: Node
offset: int
direction: ClassVar[Direction]


@dataclass(frozen=True, eq=True, order=True)
class InPort(_Port):
direction: ClassVar[Direction] = Direction.INCOMING


class Wire(Protocol):
def out_port(self) -> OutPort: ...


@dataclass(frozen=True, eq=True, order=True)
class OutPort(_Port, Wire):
direction: ClassVar[Direction] = Direction.OUTGOING

def out_port(self) -> OutPort:
return self


class ToNode(Wire, Protocol):
def to_node(self) -> Node: ...

@overload
def __getitem__(self, index: int) -> OutPort: ...
@overload
def __getitem__(self, index: slice) -> Iterator[OutPort]: ...
@overload
def __getitem__(self, index: tuple[int, ...]) -> Iterator[OutPort]: ...

def __getitem__(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
return self.to_node()._index(index)

def out_port(self) -> "OutPort":
return OutPort(self.to_node(), 0)

def inp(self, offset: int) -> InPort:
return InPort(self.to_node(), offset)

def out(self, offset: int) -> OutPort:
return OutPort(self.to_node(), offset)

def port(self, offset: int, direction: Direction) -> InPort | OutPort:
if direction == Direction.INCOMING:
return self.inp(offset)
else:
return self.out(offset)


@dataclass(frozen=True, eq=True, order=True)
class Node(ToNode):
idx: int
_num_out_ports: int | None = field(default=None, compare=False)

def _index(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
match index:
case int(index):
if self._num_out_ports is not None:
if index >= self._num_out_ports:
raise IndexError("Index out of range")
return self.out(index)
case slice():
start = index.start or 0
stop = index.stop or self._num_out_ports
if stop is None:
raise ValueError(
"Stop must be specified when number of outputs unknown"
)
step = index.step or 1
return (self[i] for i in range(start, stop, step))
case tuple(xs):
return (self[i] for i in xs)

def to_node(self) -> Node:
return self


@dataclass()
class NodeData:
op: Op
Expand All @@ -131,25 +38,15 @@ def to_serial(self, node: Node, hugr: Hugr) -> SerialOp:
return SerialOp(root=o) # type: ignore[arg-type]


_SO = _SubPort[OutPort]
_SI = _SubPort[InPort]

P = TypeVar("P", InPort, OutPort)
K = TypeVar("K", InPort, OutPort)
OpVar = TypeVar("OpVar", bound=Op)
OpVar2 = TypeVar("OpVar2", bound=Op)


@dataclass(frozen=True, eq=True, order=True)
class _SubPort(Generic[P]):
port: P
sub_offset: int = 0

def next_sub_offset(self) -> Self:
return replace(self, sub_offset=self.sub_offset + 1)


_SO = _SubPort[OutPort]
_SI = _SubPort[InPort]


class ParentBuilder(ToNode, Protocol[OpVar]):
hugr: Hugr[OpVar]
parent_node: Node
Expand Down Expand Up @@ -360,6 +257,10 @@ def port_type(self, port: InPort | OutPort) -> Type | None:
op = self[port.node].op
if isinstance(op, DataflowOp):
return op.port_type(port)
if isinstance(op, Call) and isinstance(port, OutPort):
kind = self.port_kind(port)
if isinstance(kind, ValueKind):
return kind.ty
return None

def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]:
Expand Down
115 changes: 115 additions & 0 deletions hugr-py/src/hugr/_node_port.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

from dataclasses import dataclass, field, replace
from enum import Enum
from typing import (
ClassVar,
Iterator,
Protocol,
overload,
TypeVar,
Generic,
)
from typing_extensions import Self


class Direction(Enum):
INCOMING = 0
OUTGOING = 1


@dataclass(frozen=True, eq=True, order=True)
class _Port:
node: Node
offset: int
direction: ClassVar[Direction]


@dataclass(frozen=True, eq=True, order=True)
class InPort(_Port):
direction: ClassVar[Direction] = Direction.INCOMING


class Wire(Protocol):
def out_port(self) -> OutPort: ...


@dataclass(frozen=True, eq=True, order=True)
class OutPort(_Port, Wire):
direction: ClassVar[Direction] = Direction.OUTGOING

def out_port(self) -> OutPort:
return self


class ToNode(Wire, Protocol):
def to_node(self) -> Node: ...

@overload
def __getitem__(self, index: int) -> OutPort: ...
@overload
def __getitem__(self, index: slice) -> Iterator[OutPort]: ...
@overload
def __getitem__(self, index: tuple[int, ...]) -> Iterator[OutPort]: ...

def __getitem__(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
return self.to_node()._index(index)

def out_port(self) -> "OutPort":
return OutPort(self.to_node(), 0)

def inp(self, offset: int) -> InPort:
return InPort(self.to_node(), offset)

def out(self, offset: int) -> OutPort:
return OutPort(self.to_node(), offset)

def port(self, offset: int, direction: Direction) -> InPort | OutPort:
if direction == Direction.INCOMING:
return self.inp(offset)
else:
return self.out(offset)


@dataclass(frozen=True, eq=True, order=True)
class Node(ToNode):
idx: int
_num_out_ports: int | None = field(default=None, compare=False)

def _index(
self, index: int | slice | tuple[int, ...]
) -> OutPort | Iterator[OutPort]:
match index:
case int(index):
if self._num_out_ports is not None:
if index >= self._num_out_ports:
raise IndexError("Index out of range")
return self.out(index)
case slice():
start = index.start or 0
stop = index.stop or self._num_out_ports
if stop is None:
raise ValueError(

Check warning on line 94 in hugr-py/src/hugr/_node_port.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/_node_port.py#L94

Added line #L94 was not covered by tests
"Stop must be specified when number of outputs unknown"
)
step = index.step or 1
return (self[i] for i in range(start, stop, step))
case tuple(xs):
return (self[i] for i in xs)

def to_node(self) -> Node:
return self


P = TypeVar("P", InPort, OutPort)


@dataclass(frozen=True, eq=True, order=True)
class _SubPort(Generic[P]):
port: P
sub_offset: int = 0

def next_sub_offset(self) -> Self:
return replace(self, sub_offset=self.sub_offset + 1)
Loading
Loading