-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Moves `tket2-py/src/pattern/rewrite.rs` into `tket2-py/src/rewrite.rs` - Adds bindings for `PyECCRewriter`
- Loading branch information
Showing
7 changed files
with
99 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
//! PyO3 wrapper for rewriters. | ||
use derive_more::From; | ||
use itertools::Itertools; | ||
use pyo3::prelude::*; | ||
use std::path::PathBuf; | ||
use tket2::rewrite::{CircuitRewrite, ECCRewriter, Rewriter}; | ||
|
||
use crate::circuit::Tk2Circuit; | ||
|
||
/// The module definition | ||
pub fn module(py: Python) -> PyResult<&PyModule> { | ||
let m = PyModule::new(py, "_rewrite")?; | ||
m.add_class::<PyECCRewriter>()?; | ||
m.add_class::<PyCircuitRewrite>()?; | ||
Ok(m) | ||
} | ||
|
||
/// A rewrite rule for circuits. | ||
/// | ||
/// Python equivalent of [`CircuitRewrite`]. | ||
/// | ||
/// [`CircuitRewrite`]: tket2::rewrite::CircuitRewrite | ||
#[pyclass] | ||
#[pyo3(name = "CircuitRewrite")] | ||
#[derive(Debug, Clone, From)] | ||
#[repr(transparent)] | ||
pub struct PyCircuitRewrite { | ||
/// Rust representation of the circuit chunks. | ||
pub rewrite: CircuitRewrite, | ||
} | ||
|
||
#[pymethods] | ||
impl PyCircuitRewrite { | ||
/// Number of nodes added or removed by the rewrite. | ||
/// | ||
/// The difference between the new number of nodes minus the old. A positive | ||
/// number is an increase in node count, a negative number is a decrease. | ||
pub fn node_count_delta(&self) -> isize { | ||
self.rewrite.node_count_delta() | ||
} | ||
|
||
/// The replacement subcircuit. | ||
pub fn replacement(&self) -> Tk2Circuit { | ||
self.rewrite.replacement().clone().into() | ||
} | ||
} | ||
|
||
/// A rewriter based on circuit equivalence classes. | ||
/// | ||
/// In every equivalence class, one circuit is chosen as the representative. | ||
/// Valid rewrites turn a non-representative circuit into its representative, | ||
/// or a representative circuit into any of the equivalent non-representative | ||
#[pyclass(name = "ECCRewriter")] | ||
pub struct PyECCRewriter(ECCRewriter); | ||
|
||
#[pymethods] | ||
impl PyECCRewriter { | ||
/// Load a precompiled ecc rewriter from a file. | ||
#[staticmethod] | ||
pub fn load_precompiled(path: PathBuf) -> PyResult<Self> { | ||
Ok(Self(ECCRewriter::load_binary(path).map_err(|e| { | ||
PyErr::new::<pyo3::exceptions::PyIOError, _>(e.to_string()) | ||
})?)) | ||
} | ||
|
||
/// Returns a list of circuit rewrites that can be applied to the given Tk2Circuit. | ||
pub fn get_rewrites(&self, circ: &Tk2Circuit) -> Vec<PyCircuitRewrite> { | ||
self.0 | ||
.get_rewrites(&circ.hugr) | ||
.into_iter() | ||
.map_into() | ||
.collect() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from . import passes, circuit, optimiser, pattern | ||
from . import passes, circuit, optimiser, pattern, rewrite | ||
|
||
__all__ = [circuit, optimiser, passes, pattern] | ||
__all__ = [circuit, optimiser, passes, pattern, rewrite] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Re-export native bindings | ||
from .tket2._rewrite import * # noqa: F403 | ||
from .tket2 import _rewrite | ||
|
||
from pathlib import Path | ||
import importlib | ||
|
||
__all__ = [ | ||
"default_ecc_rewriter", | ||
*_rewrite.__all__, | ||
] | ||
|
||
|
||
def default_ecc_rewriter() -> _rewrite.ECCRewriter: | ||
"""Load the default ecc rewriter.""" | ||
# TODO: Cite, explain what this is | ||
rewriter = Path(importlib.resources.files("tket2").joinpath("data/nam_6_3.rwr")) | ||
return _rewrite.ECCRewriter.load_precompiled(rewriter) |