Skip to content

Commit

Permalink
feat(hugr-py): move std extension types/ops in to std module (#1288)
Browse files Browse the repository at this point in the history
Not complete, see #1287

Quantum operations left in tests as they are not part of std extensions
set (maybe they should be...?)
  • Loading branch information
ss2165 authored Jul 10, 2024
1 parent af38154 commit 7d82245
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 86 deletions.
1 change: 1 addition & 0 deletions hugr-py/src/hugr/std/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Types and operations from the standard extension set."""
25 changes: 25 additions & 0 deletions hugr-py/src/hugr/std/float.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Floating point types and operations."""

from __future__ import annotations

from dataclasses import dataclass

from hugr import tys, val

#: HUGR 64-bit IEEE 754-2019 floating point type.
FLOAT_T = tys.Opaque(
extension="arithmetic.float.types",
id="float64",
args=[],
bound=tys.TypeBound.Copyable,
)


@dataclass
class FloatVal(val.ExtensionValue):
"""Custom value for a floating point number."""

v: float

def to_value(self) -> val.Extension:
return val.Extension("float", FLOAT_T, self.v)
71 changes: 71 additions & 0 deletions hugr-py/src/hugr/std/int.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""HUGR integer types and operations."""

from __future__ import annotations

from dataclasses import dataclass, field

from hugr import tys, val
from hugr.ops import Custom


def int_t(width: int) -> tys.Opaque:
"""Create an integer type with a given log bit width.
Args:
width: The log bit width of the integer.
Returns:
The integer type.
Examples:
>>> int_t(5).id # 32 bit integer
'int'
"""
return tys.Opaque(
extension="arithmetic.int.types",
id="int",
args=[tys.BoundedNatArg(n=width)],
bound=tys.TypeBound.Eq,
)


#: HUGR 32-bit integer type.
INT_T = int_t(5)


@dataclass
class IntVal(val.ExtensionValue):
"""Custom value for an integer."""

v: int

def to_value(self) -> val.Extension:
return val.Extension("int", INT_T, self.v)


@dataclass(frozen=True)
class IntOps(Custom):
"""Base class for integer operations."""

extension: tys.ExtensionId = "arithmetic.int"


_ARG_I32 = tys.BoundedNatArg(n=5)


@dataclass(frozen=True)
class _DivModDef(IntOps):
"""DivMod operation, has two inputs and two outputs."""

num_out: int = 2
extension: tys.ExtensionId = "arithmetic.int"
op_name: str = "idivmod_u"
signature: tys.FunctionType = field(
default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2)
)
args: list[tys.TypeArg] = field(default_factory=lambda: [_ARG_I32, _ARG_I32])


#: DivMod operation.
DivMod = _DivModDef()
38 changes: 38 additions & 0 deletions hugr-py/src/hugr/std/logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""HUGR logic operations."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from hugr import tys
from hugr.ops import Command, Custom

if TYPE_CHECKING:
from hugr.ops import ComWire


@dataclass(frozen=True)
class LogicOps(Custom):
"""Base class for logic operations."""

extension: tys.ExtensionId = "logic"


_NotSig = tys.FunctionType.endo([tys.Bool])


@dataclass(frozen=True)
class _NotDef(LogicOps):
"""Not operation."""

num_out: int = 1
op_name: str = "Not"
signature: tys.FunctionType = _NotSig

def __call__(self, a: ComWire) -> Command:
return super().__call__(a)


#: Not operation
Not = _NotDef()
85 changes: 3 additions & 82 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,76 +4,19 @@
import os
import pathlib
import subprocess
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import TYPE_CHECKING

from hugr import tys, val
from hugr import tys
from hugr.hugr import Hugr
from hugr.ops import Command, Custom
from hugr.serialization.serial_hugr import SerialHugr
from hugr.std.float import FLOAT_T

if TYPE_CHECKING:
from hugr.ops import ComWire


def int_t(width: int) -> tys.Opaque:
return tys.Opaque(
extension="arithmetic.int.types",
id="int",
args=[tys.BoundedNatArg(n=width)],
bound=tys.TypeBound.Eq,
)


INT_T = int_t(5)


@dataclass
class IntVal(val.ExtensionValue):
v: int

def to_value(self) -> val.Extension:
return val.Extension("int", INT_T, self.v)


FLOAT_T = tys.Opaque(
extension="arithmetic.float.types",
id="float64",
args=[],
bound=tys.TypeBound.Copyable,
)


@dataclass
class FloatVal(val.ExtensionValue):
v: float

def to_value(self) -> val.Extension:
return val.Extension("float", FLOAT_T, self.v)


@dataclass(frozen=True)
class LogicOps(Custom):
extension: tys.ExtensionId = "logic"


_NotSig = tys.FunctionType.endo([tys.Bool])


# TODO get from YAML
@dataclass(frozen=True)
class NotDef(LogicOps):
num_out: int = 1
op_name: str = "Not"
signature: tys.FunctionType = _NotSig

def __call__(self, a: ComWire) -> Command:
return super().__call__(a)


Not = NotDef()


@dataclass(frozen=True)
class QuantumOps(Custom):
extension: tys.ExtensionId = "tket2.quantum"
Expand Down Expand Up @@ -141,28 +84,6 @@ def __call__(self, q: ComWire, fl_wire: ComWire) -> Command:
Rz = RzDef()


@dataclass(frozen=True)
class IntOps(Custom):
extension: tys.ExtensionId = "arithmetic.int"


ARG_5 = tys.BoundedNatArg(n=5)


@dataclass(frozen=True)
class DivModDef(IntOps):
num_out: int = 2
extension: tys.ExtensionId = "arithmetic.int"
op_name: str = "idivmod_u"
signature: tys.FunctionType = field(
default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2)
)
args: list[tys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5])


DivMod = DivModDef()


def validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True):
workspace_dir = pathlib.Path(__file__).parent.parent.parent
# use the HUGR_BIN environment variable if set, otherwise use the debug build
Expand Down
3 changes: 2 additions & 1 deletion hugr-py/tests/test_cfg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from hugr import ops, tys, val
from hugr.cfg import Cfg
from hugr.dfg import Dfg
from hugr.std.int import INT_T, DivMod, IntVal

from .conftest import INT_T, DivMod, IntVal, validate
from .conftest import validate


def build_basic_cfg(cfg: Cfg) -> None:
Expand Down
3 changes: 2 additions & 1 deletion hugr-py/tests/test_cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from hugr import ops, tys, val
from hugr.cond_loop import Conditional, ConditionalError, TailLoop
from hugr.dfg import Dfg
from hugr.std.int import INT_T, IntVal

from .conftest import INT_T, H, IntVal, Measure, validate
from .conftest import H, Measure, validate

SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]])

Expand Down
4 changes: 3 additions & 1 deletion hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from hugr.hugr import Hugr
from hugr.node_port import Node, _SubPort
from hugr.ops import NoConcreteFunc
from hugr.std.int import INT_T, DivMod, IntVal
from hugr.std.logic import Not

from .conftest import INT_T, DivMod, IntVal, Not, validate
from .conftest import validate


def test_stable_indices():
Expand Down
4 changes: 3 additions & 1 deletion hugr-py/tests/test_tracked_dfg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pytest

from hugr import tys
from hugr.std.float import FLOAT_T, FloatVal
from hugr.std.logic import Not
from hugr.tracked_dfg import TrackedDfg

from .conftest import CX, FLOAT_T, FloatVal, H, Measure, Not, Rz, validate
from .conftest import CX, H, Measure, Rz, validate


def test_track_wire():
Expand Down

0 comments on commit 7d82245

Please sign in to comment.