diff --git a/hugr-py/src/hugr/std/__init__.py b/hugr-py/src/hugr/std/__init__.py new file mode 100644 index 000000000..f01d3d421 --- /dev/null +++ b/hugr-py/src/hugr/std/__init__.py @@ -0,0 +1 @@ +"""Types and operations from the standard extension set.""" diff --git a/hugr-py/src/hugr/std/float.py b/hugr-py/src/hugr/std/float.py new file mode 100644 index 000000000..92cb5d79a --- /dev/null +++ b/hugr-py/src/hugr/std/float.py @@ -0,0 +1,25 @@ +"""Floating point types and operations.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from hugr import tys, val + +#: HUGR 64-bit IEEE 754-2019 floating point type. +FLOAT_T = tys.Opaque( + extension="arithmetic.float.types", + id="float64", + args=[], + bound=tys.TypeBound.Copyable, +) + + +@dataclass +class FloatVal(val.ExtensionValue): + """Custom value for a floating point number.""" + + v: float + + def to_value(self) -> val.Extension: + return val.Extension("float", FLOAT_T, self.v) diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py new file mode 100644 index 000000000..20f7b91cd --- /dev/null +++ b/hugr-py/src/hugr/std/int.py @@ -0,0 +1,71 @@ +"""HUGR integer types and operations.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from hugr import tys, val +from hugr.ops import Custom + + +def int_t(width: int) -> tys.Opaque: + """Create an integer type with a given log bit width. + + + Args: + width: The log bit width of the integer. + + Returns: + The integer type. + + Examples: + >>> int_t(5).id # 32 bit integer + 'int' + """ + return tys.Opaque( + extension="arithmetic.int.types", + id="int", + args=[tys.BoundedNatArg(n=width)], + bound=tys.TypeBound.Eq, + ) + + +#: HUGR 32-bit integer type. +INT_T = int_t(5) + + +@dataclass +class IntVal(val.ExtensionValue): + """Custom value for an integer.""" + + v: int + + def to_value(self) -> val.Extension: + return val.Extension("int", INT_T, self.v) + + +@dataclass(frozen=True) +class IntOps(Custom): + """Base class for integer operations.""" + + extension: tys.ExtensionId = "arithmetic.int" + + +_ARG_I32 = tys.BoundedNatArg(n=5) + + +@dataclass(frozen=True) +class _DivModDef(IntOps): + """DivMod operation, has two inputs and two outputs.""" + + num_out: int = 2 + extension: tys.ExtensionId = "arithmetic.int" + op_name: str = "idivmod_u" + signature: tys.FunctionType = field( + default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) + ) + args: list[tys.TypeArg] = field(default_factory=lambda: [_ARG_I32, _ARG_I32]) + + +#: DivMod operation. +DivMod = _DivModDef() diff --git a/hugr-py/src/hugr/std/logic.py b/hugr-py/src/hugr/std/logic.py new file mode 100644 index 000000000..b2890e6e4 --- /dev/null +++ b/hugr-py/src/hugr/std/logic.py @@ -0,0 +1,38 @@ +"""HUGR logic operations.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from hugr import tys +from hugr.ops import Command, Custom + +if TYPE_CHECKING: + from hugr.ops import ComWire + + +@dataclass(frozen=True) +class LogicOps(Custom): + """Base class for logic operations.""" + + extension: tys.ExtensionId = "logic" + + +_NotSig = tys.FunctionType.endo([tys.Bool]) + + +@dataclass(frozen=True) +class _NotDef(LogicOps): + """Not operation.""" + + num_out: int = 1 + op_name: str = "Not" + signature: tys.FunctionType = _NotSig + + def __call__(self, a: ComWire) -> Command: + return super().__call__(a) + + +#: Not operation +Not = _NotDef() diff --git a/hugr-py/tests/conftest.py b/hugr-py/tests/conftest.py index 9f3ed3e62..e2d67b0c2 100644 --- a/hugr-py/tests/conftest.py +++ b/hugr-py/tests/conftest.py @@ -4,76 +4,19 @@ import os import pathlib import subprocess -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING -from hugr import tys, val +from hugr import tys from hugr.hugr import Hugr from hugr.ops import Command, Custom from hugr.serialization.serial_hugr import SerialHugr +from hugr.std.float import FLOAT_T if TYPE_CHECKING: from hugr.ops import ComWire -def int_t(width: int) -> tys.Opaque: - return tys.Opaque( - extension="arithmetic.int.types", - id="int", - args=[tys.BoundedNatArg(n=width)], - bound=tys.TypeBound.Eq, - ) - - -INT_T = int_t(5) - - -@dataclass -class IntVal(val.ExtensionValue): - v: int - - def to_value(self) -> val.Extension: - return val.Extension("int", INT_T, self.v) - - -FLOAT_T = tys.Opaque( - extension="arithmetic.float.types", - id="float64", - args=[], - bound=tys.TypeBound.Copyable, -) - - -@dataclass -class FloatVal(val.ExtensionValue): - v: float - - def to_value(self) -> val.Extension: - return val.Extension("float", FLOAT_T, self.v) - - -@dataclass(frozen=True) -class LogicOps(Custom): - extension: tys.ExtensionId = "logic" - - -_NotSig = tys.FunctionType.endo([tys.Bool]) - - -# TODO get from YAML -@dataclass(frozen=True) -class NotDef(LogicOps): - num_out: int = 1 - op_name: str = "Not" - signature: tys.FunctionType = _NotSig - - def __call__(self, a: ComWire) -> Command: - return super().__call__(a) - - -Not = NotDef() - - @dataclass(frozen=True) class QuantumOps(Custom): extension: tys.ExtensionId = "tket2.quantum" @@ -141,28 +84,6 @@ def __call__(self, q: ComWire, fl_wire: ComWire) -> Command: Rz = RzDef() -@dataclass(frozen=True) -class IntOps(Custom): - extension: tys.ExtensionId = "arithmetic.int" - - -ARG_5 = tys.BoundedNatArg(n=5) - - -@dataclass(frozen=True) -class DivModDef(IntOps): - num_out: int = 2 - extension: tys.ExtensionId = "arithmetic.int" - op_name: str = "idivmod_u" - signature: tys.FunctionType = field( - default_factory=lambda: tys.FunctionType(input=[INT_T] * 2, output=[INT_T] * 2) - ) - args: list[tys.TypeArg] = field(default_factory=lambda: [ARG_5, ARG_5]) - - -DivMod = DivModDef() - - def validate(h: Hugr, mermaid: bool = False, roundtrip: bool = True): workspace_dir = pathlib.Path(__file__).parent.parent.parent # use the HUGR_BIN environment variable if set, otherwise use the debug build diff --git a/hugr-py/tests/test_cfg.py b/hugr-py/tests/test_cfg.py index cc3df5b13..a9c1e871d 100644 --- a/hugr-py/tests/test_cfg.py +++ b/hugr-py/tests/test_cfg.py @@ -1,8 +1,9 @@ from hugr import ops, tys, val from hugr.cfg import Cfg from hugr.dfg import Dfg +from hugr.std.int import INT_T, DivMod, IntVal -from .conftest import INT_T, DivMod, IntVal, validate +from .conftest import validate def build_basic_cfg(cfg: Cfg) -> None: diff --git a/hugr-py/tests/test_cond_loop.py b/hugr-py/tests/test_cond_loop.py index 69b777c8d..2f67f3563 100644 --- a/hugr-py/tests/test_cond_loop.py +++ b/hugr-py/tests/test_cond_loop.py @@ -3,8 +3,9 @@ from hugr import ops, tys, val from hugr.cond_loop import Conditional, ConditionalError, TailLoop from hugr.dfg import Dfg +from hugr.std.int import INT_T, IntVal -from .conftest import INT_T, H, IntVal, Measure, validate +from .conftest import H, Measure, validate SUM_T = tys.Sum([[tys.Qubit], [tys.Qubit, INT_T]]) diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index f7ee1bef9..bf327c9f6 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -8,8 +8,10 @@ from hugr.hugr import Hugr from hugr.node_port import Node, _SubPort from hugr.ops import NoConcreteFunc +from hugr.std.int import INT_T, DivMod, IntVal +from hugr.std.logic import Not -from .conftest import INT_T, DivMod, IntVal, Not, validate +from .conftest import validate def test_stable_indices(): diff --git a/hugr-py/tests/test_tracked_dfg.py b/hugr-py/tests/test_tracked_dfg.py index b5578cc9f..74388946e 100644 --- a/hugr-py/tests/test_tracked_dfg.py +++ b/hugr-py/tests/test_tracked_dfg.py @@ -1,9 +1,11 @@ import pytest from hugr import tys +from hugr.std.float import FLOAT_T, FloatVal +from hugr.std.logic import Not from hugr.tracked_dfg import TrackedDfg -from .conftest import CX, FLOAT_T, FloatVal, H, Measure, Not, Rz, validate +from .conftest import CX, H, Measure, Rz, validate def test_track_wire():