Skip to content

Commit

Permalink
feat: Add missing typing hints (#352)
Browse files Browse the repository at this point in the history
Adds typing info for the currently binded classes and methods.
Also binds a couple classes that where missing.

Fixes #342.
  • Loading branch information
aborgna-q authored May 24, 2024
1 parent 4b24533 commit 4990613
Show file tree
Hide file tree
Showing 10 changed files with 331 additions and 49 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,7 @@ features = ["pyo3/extension-module"]
[tool.pytest.ini_options]
# Lark throws deprecation warnings for `src_parse` and `src_constants`.
filterwarnings = "ignore::DeprecationWarning:lark.*"

[tool.pyright]
# Rust bindings have typing stubs but no python source code.
reportMissingModuleSource = "none"
5 changes: 4 additions & 1 deletion tket2-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
m.add_class::<Dfg>()?;
m.add_class::<PyNode>()?;
m.add_class::<PyWire>()?;
m.add_class::<WireIter>()?;
m.add_class::<PyCircuitCost>()?;
m.add_class::<Tk2Op>()?;
m.add_class::<PyCustom>()?;
m.add_class::<PyHugrType>()?;
m.add_class::<Pauli>()?;
m.add_class::<PyTypeBound>()?;

m.add_function(wrap_pyfunction!(validate_hugr, &m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, &m)?)?;
Expand Down Expand Up @@ -129,7 +131,8 @@ impl fmt::Debug for PyNode {
}

#[pyclass]
struct WireIter {
/// An iterator over the wires of a node.
pub struct WireIter {
node: PyNode,
current: usize,
}
Expand Down
1 change: 1 addition & 0 deletions tket2-py/src/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
m.add_class::<self::portmatching::PyCircuitPattern>()?;
m.add_class::<self::portmatching::PyPatternMatcher>()?;
m.add_class::<self::portmatching::PyPatternMatch>()?;
m.add_class::<self::portmatching::PyPatternID>()?;

m.add(
"InvalidPatternError",
Expand Down
13 changes: 13 additions & 0 deletions tket2-py/src/pattern/portmatching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,19 @@ pub struct PyPatternID {
pub id: PatternID,
}

#[pymethods]
impl PyPatternID {
/// A string representation of the pattern.
pub fn __repr__(&self) -> String {
format!("{:?}", self.id)
}

/// Cast the pattern ID to an integer.
pub fn __int__(&self) -> usize {
self.id.into()
}
}

impl fmt::Display for PyPatternID {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.id.fmt(f)
Expand Down
174 changes: 126 additions & 48 deletions tket2-py/tket2/_tket2/circuit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,115 @@ class Tk2Circuit:

def __init__(self, circ: Circuit) -> None:
"""Create a Tk2Circuit from a pytket Circuit."""

def __hash__(self) -> int:
"""Compute the circuit hash by traversal."""

def __copy__(self) -> Tk2Circuit:
"""Create a copy of the circuit."""

def __deepcopy__(self) -> Tk2Circuit:
"""Create a deep copy of the circuit."""

def hash(self) -> int:
"""Compute the circuit hash by traversal."""

def circuit_cost(self, cost_fn: Callable[[Tk2Op], Any]) -> int:
"""Compute the cost of the circuit. Return value must implement __add__."""

def node_op(self, node: Node) -> CustomOp:
"""If the node corresponds to a custom op, return it. Otherwise, raise an error."""

def to_tket1(self) -> Circuit:
"""Convert to pytket Circuit."""

def apply_rewrite(self, rw) -> None:
"""Apply a rewrite to the circuit."""

def node_inputs(self, node: Node) -> list[Wire]:
"""The incoming wires to a node."""

def node_outputs(self, node: Node) -> list[Wire]:
"""The outgoing wires from a node."""

def input_node(self) -> Node:
"""The input node of the circuit."""

def output_node(self) -> Node:
"""The output node of the circuit."""

def to_hugr_json(self) -> str:
"""Encode the circuit as a HUGR json string."""

@staticmethod
def from_hugr_json(json: str) -> Tk2Circuit:
"""Decode a HUGR json string to a Tk2Circuit."""

def to_tket1_json(self) -> str:
"""Encode the circuit as a pytket json string."""

@staticmethod
def from_tket1_json(json: str) -> Tk2Circuit:
"""Decode a pytket json string to a Tk2Circuit."""

class Dfg:
"""A builder for a HUGR dataflow graph."""

def __init__(
self,
input_types: list[HugrType],
output_types: list[HugrType],
) -> None:
"""Begin building a dataflow graph with specified input and output types."""

def inputs(self) -> list[Wire]:
"""The output wires of the input node in the DFG, one for each input type."""

def add_op(self, op: CustomOp, wires: list[Wire]) -> Node:
"""Add a custom operation to the DFG, wiring in input wires."""

def finish(self, outputs: list[Wire]) -> Tk2Circuit:
"""Finish building the DFG by wiring in output wires to the output node
(one per output type) and return the resulting circuit."""

class Node:
"""Handle to node in HUGR."""

def outs(self, n: int) -> list[Wire]:
"""Generate n output wires from this node."""

def __getitem__(self, i: int) -> Wire:
"""Get the i-th output wire from this node."""

def __iter__(self) -> Any:
"""Iterate over the output wires from this node."""

class WireIter:
"""Iterator for wires from a node."""

def __iter__(self) -> WireIter:
"""Get the iterator."""

def __next__(self) -> Wire:
"""Get the next wire from the node."""

class Wire:
"""An outgoing edge from a node in a HUGR, defined by the node and outgoing port."""

def node(self) -> Node:
"""Source node of wire."""

def port(self) -> int:
"""Source port of wire."""

class CircuitCost:
"""A cost function for circuits."""

def __init__(self, cost: Any) -> None:
"""Create a new circuit cost.
The cost object must implement __add__, __sub__, __eq__, and __lt__."""

class Tk2Op(Enum):
"""A Tket2 built-in operation."""

Expand All @@ -51,75 +141,63 @@ class Tk2Op(Enum):
QFree = auto()
Reset = auto()

class TypeBound(Enum):
"""HUGR type bounds."""
class CustomOp:
"""A HUGR custom operation."""

Any = 0 # Any type
Copyable = 1 # Copyable type
Eq = 2 # Equality-comparable type
def __init__(
self,
extension: str,
op_name: str,
input_types: list[HugrType],
output_types: list[HugrType],
) -> None:
"""Create a new custom operation from name and input/output types."""

def to_custom(self) -> CustomOp:
"""Convert to a custom operation. Identity operation."""

def name(self) -> str:
"""Fully qualified (include extension) name of the operation."""

class HugrType:
"""Value types in HUGR."""

def __init__(self, extension: str, type_name: str, bound: TypeBound) -> None:
"""Create a new named Custom type."""

@staticmethod
def qubit() -> HugrType:
"""Qubit type from HUGR prelude."""

@staticmethod
def linear_bit() -> HugrType:
"""Linear bit type from TKET1 extension."""

@staticmethod
def bool() -> HugrType:
"""Boolean type (HUGR 2-ary unit sum)."""

class Node:
"""Handle to node in HUGR."""
def outs(self, n: int) -> list[Wire]:
"""Generate n output wires from this node."""
def __getitem__(self, i: int) -> Wire:
"""Get the i-th output wire from this node."""
class Pauli(Enum):
"""Simple enum representation of Pauli matrices."""

class Wire:
"""An outgoing edge from a node in a HUGR, defined by the node and outgoing port."""
def node(self) -> Node:
"""Source node of wire."""

def port(self) -> int:
"""Source port of wire."""

class CustomOp:
"""A HUGR custom operation."""
def __init__(
self,
extension: str,
op_name: str,
input_types: list[HugrType],
output_types: list[HugrType],
) -> None:
"""Create a new custom operation from name and input/output types."""
I = auto() # noqa: E741
X = auto()
Y = auto()
Z = auto()

def to_custom(self) -> CustomOp:
"""Convert to a custom operation. Identity operation."""
def name(self) -> str:
"""Fully qualified (include extension) name of the operation."""
class TypeBound(Enum):
"""HUGR type bounds."""

class Dfg:
"""A builder for a HUGR dataflow graph."""
def __init__(
self,
input_types: list[HugrType],
output_types: list[HugrType],
) -> None:
"""Begin building a dataflow graph with specified input and output types."""
def inputs(self) -> list[Wire]:
"""The output wires of the input node in the DFG, one for each input type."""
def add_op(self, op: CustomOp, wires: list[Wire]) -> Node:
"""Add a custom operation to the DFG, wiring in input wires."""
def finish(self, outputs: list[Wire]) -> Tk2Circuit:
"""Finish building the DFG by wiring in output wires to the output node
(one per output type) and return the resulting circuit."""
Any = 0 # Any type
Copyable = 1 # Copyable type
Eq = 2 # Equality-comparable type

def to_hugr_dot(hugr: Tk2Circuit | Circuit) -> str: ...
def to_hugr_mermaid(hugr: Tk2Circuit | Circuit) -> str: ...
def validate_hugr(hugr: Tk2Circuit | Circuit) -> None: ...

class HugrError(Exception): ...
class BuildError(Exception): ...
class ValidationError(Exception): ...
class HUGRSerializationError(Exception): ...
class OpConvertError(Exception): ...
34 changes: 34 additions & 0 deletions tket2-py/tket2/_tket2/optimiser.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from .circuit import Tk2Circuit
from pytket._tket.circuit import Circuit

from pathlib import Path

class BadgerOptimiser:
@staticmethod
def load_precompiled(filename: Path) -> BadgerOptimiser:
"""Load a precompiled rewriter from a file."""

@staticmethod
def compile_eccs(filename: Path) -> BadgerOptimiser:
"""Compile a set of ECCs and create a new rewriter ."""

def optimise(
self,
circ: Tk2Circuit | Circuit,
timeout: int | None = None,
progress_timeout: int | None = None,
n_threads: int | None = None,
split_circ: bool = False,
queue_size: int | None = None,
log_progress: Path | None = None,
) -> Tk2Circuit | Circuit:
"""Optimise a circuit.
:param circ: The circuit to optimise.
:param timeout: Maximum time to spend on the optimisation.
:param progress_timeout: Maximum time to wait between new best results.
:param n_threads: Number of threads to use.
:param split_circ: Split the circuit into subcircuits and optimise them separately.
:param queue_size: Maximum number of circuits to keep in the queue of candidates.
:param log_progress: Log progress to a CSV file.
"""
52 changes: 52 additions & 0 deletions tket2-py/tket2/_tket2/passes.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from pathlib import Path

from .optimiser import BadgerOptimiser
from .circuit import Tk2Circuit
from pytket._tket.circuit import Circuit

class CircuitChunks:
def reassemble(self) -> Circuit | Tk2Circuit:
"""Reassemble the circuit from its chunks."""

def circuits(self) -> list[Circuit | Tk2Circuit]:
"""Returns clones of the split circuits."""

def update_circuit(self, index: int, circ: Circuit | Tk2Circuit) -> None:
"""Replace a circuit chunk with a new version."""

class PullForwardError(Exception):
"""Error from a `PullForward` operation."""

def greedy_depth_reduce(circ: Circuit | Tk2Circuit) -> tuple[Circuit | Tk2Circuit, int]:
"""Greedy depth reduction of a circuit.
Returns the reduced circuit and the depth reduction.
"""

def badger_optimise(
circ: Circuit | Tk2Circuit,
optimiser: BadgerOptimiser,
max_threads: int | None = None,
timeout: int | None = None,
progress_timeout: int | None = None,
log_dir: Path | None = None,
rebase: bool = False,
) -> Circuit | Tk2Circuit:
"""Optimise a circuit using the Badger optimiser.
HyperTKET's best attempt at optimising a circuit using circuit rewriting
and the given Badger optimiser.
By default, the input circuit will be rebased to Nam, i.e. CX + Rz + H before
optimising. This can be deactivated by setting `rebase` to `false`, in which
case the circuit is expected to be in the Nam gate set.
Will use at most `max_threads` threads (plus a constant) and take at most
`timeout` seconds (plus a constant). Default to the number of cpus and
15min respectively.
Log files will be written to the directory `log_dir` if specified.
"""

def chunks(c: Circuit | Tk2Circuit, max_chunk_size: int) -> CircuitChunks:
"""Split a circuit into chunks of at most `max_chunk_size` gates."""
Loading

0 comments on commit 4990613

Please sign in to comment.