Skip to content

Commit

Permalink
feat: EccRewriter bindings (#251)
Browse files Browse the repository at this point in the history
- Moves `tket2-py/src/pattern/rewrite.rs` into `tket2-py/src/rewrite.rs`
- Adds bindings for `PyECCRewriter`
  • Loading branch information
aborgna-q authored Nov 20, 2023
1 parent 4ff3603 commit 97e2e0a
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 26 deletions.
2 changes: 1 addition & 1 deletion tket2-py/src/circuit/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tket2::json::TKETDecode;
use tket2::passes::CircuitChunks;
use tket_json_rs::circuit_json::SerialCircuit;

use crate::pattern::rewrite::PyCircuitRewrite;
use crate::rewrite::PyCircuitRewrite;

/// A manager for tket 2 operations on a tket 1 Circuit.
#[pyclass]
Expand Down
2 changes: 2 additions & 0 deletions tket2-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod circuit;
pub mod optimiser;
pub mod passes;
pub mod pattern;
pub mod rewrite;

use pyo3::prelude::*;

Expand All @@ -14,6 +15,7 @@ fn tket2_py(py: Python, m: &PyModule) -> PyResult<()> {
add_submodule(py, m, optimiser::module(py)?)?;
add_submodule(py, m, passes::module(py)?)?;
add_submodule(py, m, pattern::module(py)?)?;
add_submodule(py, m, rewrite::module(py)?)?;
Ok(())
}

Expand Down
5 changes: 1 addition & 4 deletions tket2-py/src/pattern.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
//! Pattern matching on circuits.
pub mod portmatching;
pub mod rewrite;

use crate::circuit::Tk2Circuit;
use crate::rewrite::PyCircuitRewrite;

use hugr::Hugr;
use pyo3::prelude::*;
use tket2::portmatching::{CircuitPattern, PatternMatcher};

use self::rewrite::PyCircuitRewrite;

/// The module definition
pub fn module(py: Python) -> PyResult<&PyModule> {
let m = PyModule::new(py, "_pattern")?;
m.add_class::<self::rewrite::PyCircuitRewrite>()?;
m.add_class::<Rule>()?;
m.add_class::<RuleMatcher>()?;
m.add_class::<self::portmatching::PyCircuitPattern>()?;
Expand Down
19 changes: 0 additions & 19 deletions tket2-py/src/pattern/rewrite.rs

This file was deleted.

75 changes: 75 additions & 0 deletions tket2-py/src/rewrite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//! PyO3 wrapper for rewriters.
use derive_more::From;
use itertools::Itertools;
use pyo3::prelude::*;
use std::path::PathBuf;
use tket2::rewrite::{CircuitRewrite, ECCRewriter, Rewriter};

use crate::circuit::Tk2Circuit;

/// The module definition
pub fn module(py: Python) -> PyResult<&PyModule> {
let m = PyModule::new(py, "_rewrite")?;
m.add_class::<PyECCRewriter>()?;
m.add_class::<PyCircuitRewrite>()?;
Ok(m)
}

/// A rewrite rule for circuits.
///
/// Python equivalent of [`CircuitRewrite`].
///
/// [`CircuitRewrite`]: tket2::rewrite::CircuitRewrite
#[pyclass]
#[pyo3(name = "CircuitRewrite")]
#[derive(Debug, Clone, From)]
#[repr(transparent)]
pub struct PyCircuitRewrite {
/// Rust representation of the circuit chunks.
pub rewrite: CircuitRewrite,
}

#[pymethods]
impl PyCircuitRewrite {
/// Number of nodes added or removed by the rewrite.
///
/// The difference between the new number of nodes minus the old. A positive
/// number is an increase in node count, a negative number is a decrease.
pub fn node_count_delta(&self) -> isize {
self.rewrite.node_count_delta()
}

/// The replacement subcircuit.
pub fn replacement(&self) -> Tk2Circuit {
self.rewrite.replacement().clone().into()
}
}

/// A rewriter based on circuit equivalence classes.
///
/// In every equivalence class, one circuit is chosen as the representative.
/// Valid rewrites turn a non-representative circuit into its representative,
/// or a representative circuit into any of the equivalent non-representative
#[pyclass(name = "ECCRewriter")]
pub struct PyECCRewriter(ECCRewriter);

#[pymethods]
impl PyECCRewriter {
/// Load a precompiled ecc rewriter from a file.
#[staticmethod]
pub fn load_precompiled(path: PathBuf) -> PyResult<Self> {
Ok(Self(ECCRewriter::load_binary(path).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string())
})?))
}

/// Returns a list of circuit rewrites that can be applied to the given Tk2Circuit.
pub fn get_rewrites(&self, circ: &Tk2Circuit) -> Vec<PyCircuitRewrite> {
self.0
.get_rewrites(&circ.hugr)
.into_iter()
.map_into()
.collect()
}
}
4 changes: 2 additions & 2 deletions tket2-py/tket2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import passes, circuit, optimiser, pattern
from . import passes, circuit, optimiser, pattern, rewrite

__all__ = [circuit, optimiser, passes, pattern]
__all__ = [circuit, optimiser, passes, pattern, rewrite]
18 changes: 18 additions & 0 deletions tket2-py/tket2/rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Re-export native bindings
from .tket2._rewrite import * # noqa: F403
from .tket2 import _rewrite

from pathlib import Path
import importlib

__all__ = [
"default_ecc_rewriter",
*_rewrite.__all__,
]


def default_ecc_rewriter() -> _rewrite.ECCRewriter:
"""Load the default ecc rewriter."""
# TODO: Cite, explain what this is
rewriter = Path(importlib.resources.files("tket2").joinpath("data/nam_6_3.rwr"))
return _rewrite.ECCRewriter.load_precompiled(rewriter)

0 comments on commit 97e2e0a

Please sign in to comment.