Skip to content

Commit

Permalink
feat(hugr-py): builder for function definition/declaration and call
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jun 21, 2024
1 parent afadde1 commit 197c52c
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 16 deletions.
25 changes: 24 additions & 1 deletion hugr-py/src/hugr/_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from hugr._tys import Type, TypeRow, get_first_sum, FunctionType, TypeArg, FunctionKind

from ._exceptions import NoSiblingAncestor
from ._hugr import Hugr, Node, OutPort, ParentBuilder, ToNode, Wire
from ._hugr import Hugr, ParentBuilder
from ._node_port import Node, OutPort, Wire, ToNode

if TYPE_CHECKING:
from ._cfg import Cfg
Expand Down Expand Up @@ -164,6 +165,28 @@ def load(self, const: ToNode | val.Value) -> Node:

return load

def call(
self,
func: ToNode,
*args: Wire,
instantiation: FunctionType | None = None,
type_args: list[TypeArg] | None = None,
) -> Node:
f_op = self.hugr[func]
f_kind = f_op.op.port_kind(func.out(0))
match f_kind:
case FunctionKind(sig):
signature = sig
case _:
raise ValueError(f"Expected function type, got {f_kind}")
call_op = ops.Call(signature, instantiation, type_args)
call_n = self.hugr.add_node(call_op, self.parent_node, call_op.num_out)
self.hugr.add_link(func.out(0), call_n.inp(call_op.function_port_offset()))

self._wire_up(call_n, args)

return call_n

def _wire_up(self, node: Node, ports: Iterable[Wire]) -> TypeRow:
tys = [self._wire_up_port(node, i, p) for i, p in enumerate(ports)]
if isinstance(op := self.hugr[node].op, ops.PartialOp):
Expand Down
49 changes: 49 additions & 0 deletions hugr-py/src/hugr/_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from dataclasses import dataclass

import hugr._ops as ops
import hugr._val as val

from ._dfg import _DfBase
from hugr._node_port import Node
from ._hugr import Hugr
from ._tys import TypeRow, TypeParam, PolyFuncType


@dataclass
class Function(_DfBase[ops.FuncDefn]):
def __init__(
self,
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
) -> None:
root_op = ops.FuncDefn(name, input_types, type_params or [])
super().__init__(root_op)


@dataclass
class Module:
hugr: Hugr

def __init__(self) -> None:
self.hugr = Hugr(ops.Module())

def define_function(
self,
name: str,
input_types: TypeRow,
type_params: list[TypeParam] | None = None,
) -> Function:
parent_op = ops.FuncDefn(name, input_types, type_params or [])
return Function.new_nested(parent_op, self.hugr)

def define_main(self, input_types: TypeRow) -> Function:
return self.define_function("main", input_types)

def declare_function(self, name: str, signature: PolyFuncType) -> Node:
return self.hugr.add_node(ops.FuncDecl(name, signature), self.hugr.root)

def add_const(self, value: val.Value) -> Node:
return self.hugr.add_node(ops.Const(value), self.hugr.root)
8 changes: 6 additions & 2 deletions hugr-py/src/hugr/_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
)


from hugr._ops import Op, DataflowOp, Const
from hugr._tys import Type, Kind
from hugr._ops import Op, DataflowOp, Const, Call
from hugr._tys import Type, Kind, ValueKind
from hugr._val import Value
from hugr._node_port import Direction, InPort, OutPort, ToNode, Node, _SubPort
from hugr.serialization.ops import OpType as SerialOp
Expand Down Expand Up @@ -257,6 +257,10 @@ def port_type(self, port: InPort | OutPort) -> Type | None:
op = self[port.node].op
if isinstance(op, DataflowOp):
return op.port_type(port)
if isinstance(op, Call) and isinstance(port, OutPort):
kind = self.port_kind(port)
assert isinstance(kind, ValueKind)
return kind.ty
return None

def insert_hugr(self, hugr: Hugr, parent: ToNode | None = None) -> dict[Node, Node]:
Expand Down
175 changes: 163 additions & 12 deletions hugr-py/src/hugr/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
from hugr._hugr import Hugr


@dataclass
class InvalidPort(Exception):
port: InPort | OutPort

@property
def msg(self) -> str:
return f"Invalid port {self.port}"


@runtime_checkable
class Op(Protocol):
@property
Expand All @@ -25,6 +34,14 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> BaseOp: ...
def port_kind(self, port: InPort | OutPort) -> tys.Kind: ...


def _sig_port_type(sig: tys.FunctionType, port: InPort | OutPort) -> tys.Type:
from hugr._hugr import Direction

if port.direction == Direction.INCOMING:
return sig.input[port.offset]
return sig.output[port.offset]


@runtime_checkable
class DataflowOp(Op, Protocol):
def outer_signature(self) -> tys.FunctionType: ...
Expand All @@ -35,12 +52,7 @@ def port_kind(self, port: InPort | OutPort) -> tys.Kind:
return tys.ValueKind(self.port_type(port))

def port_type(self, port: InPort | OutPort) -> tys.Type:
from hugr._hugr import Direction

sig = self.outer_signature()
if port.direction == Direction.INCOMING:
return sig.input[port.offset]
return sig.output[port.offset]
return _sig_port_type(self.outer_signature(), port)

def __call__(self, *args) -> Command:
return Command(self, list(args))
Expand Down Expand Up @@ -359,7 +371,11 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Const:
)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
return tys.ConstKind(self.val.type_())
match port:
case OutPort(_, 0):
return tys.ConstKind(self.val.type_())
case _:
raise InvalidPort(port)


@dataclass
Expand All @@ -378,6 +394,15 @@ def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.LoadConstant:
def outer_signature(self) -> tys.FunctionType:
return tys.FunctionType(input=[], output=[self.type_()])

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
match port:
case InPort(_, 0):
return tys.ConstKind(self.type_())
case OutPort(_, 0):
return tys.ValueKind(self.type_())
case _:
raise InvalidPort(port)


@dataclass()
class Conditional(DataflowOp):
Expand Down Expand Up @@ -417,15 +442,12 @@ def nth_inputs(self, n: int) -> tys.TypeRow:
class Case(DfParentOp):
inputs: tys.TypeRow
_outputs: tys.TypeRow | None = None
num_out: int | None = 0

@property
def outputs(self) -> tys.TypeRow:
return _check_complete(self._outputs)

@property
def num_out(self) -> int | None:
return 0

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Case:
return sops.Case(
parent=parent.idx, signature=self.inner_signature().to_serial()
Expand All @@ -435,7 +457,7 @@ def inner_signature(self) -> tys.FunctionType:
return tys.FunctionType(self.inputs, self.outputs)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
raise NotImplementedError("Case nodes have no external ports.")
raise InvalidPort(port)

def _set_out_types(self, types: tys.TypeRow) -> None:
self._outputs = types
Expand Down Expand Up @@ -486,3 +508,132 @@ def _set_out_types(self, types: tys.TypeRow) -> None:

def _inputs(self) -> tys.TypeRow:
return self.just_inputs + self.rest


@dataclass
class FuncDefn(DfParentOp):
name: str
inputs: tys.TypeRow
params: list[tys.TypeParam] = field(default_factory=list)
_outputs: tys.TypeRow | None = None
num_out: int | None = 1

@property
def outputs(self) -> tys.TypeRow:
return _check_complete(self._outputs)

@property
def signature(self) -> tys.PolyFuncType:
return tys.PolyFuncType(
self.params, tys.FunctionType(self.inputs, self.outputs)
)

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.FuncDefn:
return sops.FuncDefn(
parent=parent.idx,
name=self.name,
signature=self.signature.to_serial(),
)

def inner_signature(self) -> tys.FunctionType:
return self.signature.body

def _set_out_types(self, types: tys.TypeRow) -> None:
self._outputs = types

def _inputs(self) -> tys.TypeRow:
return self.inputs

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
match port:
case OutPort(_, 0):
return tys.FunctionKind(self.signature)
case _:
raise InvalidPort(port)


@dataclass
class FuncDecl(Op):
name: str
signature: tys.PolyFuncType
num_out: int | None = 0

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.FuncDecl:
return sops.FuncDecl(
parent=parent.idx,
name=self.name,
signature=self.signature.to_serial(),
)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
match port:
case OutPort(_, 0):
return tys.FunctionKind(self.signature)
case _:
raise InvalidPort(port)


@dataclass
class Module(Op):
num_out: int | None = 0

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Module:
return sops.Module(parent=parent.idx)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
raise InvalidPort(port)


class NoConcreteFunc(Exception):
pass


@dataclass
class Call(Op):
signature: tys.PolyFuncType
instantiation: tys.FunctionType
type_args: list[tys.TypeArg]

def __init__(
self,
signature: tys.PolyFuncType,
instantiation: tys.FunctionType | None = None,
type_args: list[tys.TypeArg] | None = None,
) -> None:
self.signature = signature
if len(signature.params) == 0:
self.instantiation = signature.body
self.type_args = []

else:
# TODO substitute type args into signature to get instantiation
if instantiation is None:
raise NoConcreteFunc("Missing instantiation for polymorphic function.")
type_args = type_args or []

if len(signature.params) != len(type_args):
raise NoConcreteFunc("Mismatched number of type arguments.")
self.instantiation = instantiation
self.type_args = type_args

def to_serial(self, node: Node, parent: Node, hugr: Hugr) -> sops.Call:
return sops.Call(
parent=parent.idx,
func_sig=self.signature.to_serial(),
type_args=ser_it(self.type_args),
instantiation=self.instantiation.to_serial(),
)

@property
def num_out(self) -> int | None:
return len(self.signature.body.output)

def function_port_offset(self) -> int:
return len(self.signature.body.input)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
match port:
case InPort(_, offset) if offset == self.function_port_offset():
return tys.FunctionKind(self.signature)
case _:
return tys.ValueKind(_sig_port_type(self.instantiation, port))
3 changes: 3 additions & 0 deletions hugr-py/src/hugr/_tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def to_serial(self) -> stys.BaseType: ...
def to_serial_root(self) -> stys.Type:
return stys.Type(root=self.to_serial()) # type: ignore[arg-type]

def type_arg(self) -> TypeTypeArg:
return TypeTypeArg(self)


TypeRow = list[Type]

Expand Down
Loading

0 comments on commit 197c52c

Please sign in to comment.