Skip to content

Commit

Permalink
fix: Make native py modules behave like python's (#212)
Browse files Browse the repository at this point in the history
Cleans up the `tket2-py` structure and makes the modules behave like
normal python ones.

With this we can do
```python
from tket2.circuit import T2Circuit
```

Minor changes:

- Moves some code from `tket2-py/src/lib.rs` to a new `pattern.rs`.
- Deletes some commented-out code.

Closes #209
  • Loading branch information
aborgna-q authored Nov 3, 2023
1 parent 4fe9e38 commit 4220038
Show file tree
Hide file tree
Showing 14 changed files with 197 additions and 389 deletions.
61 changes: 31 additions & 30 deletions tket2-py/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,37 @@ use tket2::passes::CircuitChunks;
use tket2::rewrite::CircuitRewrite;
use tket_json_rs::circuit_json::SerialCircuit;

/// The module definition
pub fn module(py: Python) -> PyResult<&PyModule> {
let m = PyModule::new(py, "_circuit")?;
m.add_class::<T2Circuit>()?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;
m.add_class::<tket2::passes::CircuitChunks>()?;

m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr, m)?)?;
m.add_function(wrap_pyfunction!(chunks, m)?)?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
)?;

Ok(m)
}

/// Apply a fallible function expecting a hugr on a pytket circuit.
pub fn try_with_hugr<T, E, F>(circ: Py<PyAny>, f: F) -> PyResult<T>
where
Expand Down Expand Up @@ -79,33 +110,3 @@ impl T2Circuit {
rw.apply(&mut self.0).expect("Apply error.");
}
}
/// circuit module
pub fn add_circuit_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "circuit")?;
m.add_class::<T2Circuit>()?;
m.add_class::<tket2::T2Op>()?;
m.add_class::<tket2::Pauli>()?;
m.add_class::<tket2::passes::CircuitChunks>()?;

m.add_function(wrap_pyfunction!(validate_hugr, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr_dot, m)?)?;
m.add_function(wrap_pyfunction!(to_hugr, m)?)?;
m.add_function(wrap_pyfunction!(chunks, m)?)?;

m.add("HugrError", py.get_type::<hugr::hugr::PyHugrError>())?;
m.add("BuildError", py.get_type::<hugr::builder::PyBuildError>())?;
m.add(
"ValidationError",
py.get_type::<hugr::hugr::validate::PyValidationError>(),
)?;
m.add(
"HUGRSerializationError",
py.get_type::<hugr::hugr::serialize::PyHUGRSerializationError>(),
)?;
m.add(
"OpConvertError",
py.get_type::<tket2::json::PyOpConvertError>(),
)?;

parent.add_submodule(m)
}
98 changes: 18 additions & 80 deletions tket2-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,94 +3,32 @@

mod circuit;
mod optimiser;
mod pass;
mod passes;
mod pattern;

use circuit::{add_circuit_module, to_hugr, T2Circuit};
use optimiser::add_optimiser_module;
use pass::add_pass_module;

use hugr::Hugr;
use pyo3::prelude::*;
use tket2::portmatching::pyo3::PyPatternMatch;
use tket2::portmatching::{CircuitPattern, PatternMatcher};
use tket2::rewrite::CircuitRewrite;

#[derive(Clone)]
#[pyclass]
/// A rewrite rule defined by a left hand side and right hand side of an equation.
pub struct Rule(pub [Hugr; 2]);

#[pymethods]
impl Rule {
#[new]
fn new_rule(l: PyObject, r: PyObject) -> PyResult<Rule> {
let l = to_hugr(l)?;
let r = to_hugr(r)?;

Ok(Rule([l, r]))
}
}
#[pyclass]
struct RuleMatcher {
matcher: PatternMatcher,
rights: Vec<Hugr>,
}

#[pymethods]
impl RuleMatcher {
#[new]
pub fn from_rules(rules: Vec<Rule>) -> PyResult<Self> {
let (lefts, rights): (Vec<_>, Vec<_>) =
rules.into_iter().map(|Rule([l, r])| (l, r)).unzip();
let patterns: Result<Vec<CircuitPattern>, _> =
lefts.iter().map(CircuitPattern::try_from_circuit).collect();
let matcher = PatternMatcher::from_patterns(patterns?);

Ok(Self { matcher, rights })
}

pub fn find_match(&self, target: &T2Circuit) -> PyResult<Option<CircuitRewrite>> {
let h = &target.0;
let p_match = self.matcher.find_matches_iter(h).next();
if let Some(m) = p_match {
let py_match = PyPatternMatch::try_from_rust(m, h, &self.matcher)?;
let r = self.rights.get(py_match.pattern_id).unwrap().clone();
let rw = py_match.to_rewrite(h, r)?;
Ok(Some(rw))
} else {
Ok(None)
}
}
}

/// The Python bindings to TKET2.
#[pymodule]
#[pyo3(name = "tket2")]
fn tket2_py(py: Python, m: &PyModule) -> PyResult<()> {
add_circuit_module(py, m)?;
add_pattern_module(py, m)?;
add_pass_module(py, m)?;
add_optimiser_module(py, m)?;
add_submodule(py, m, circuit::module(py)?)?;
add_submodule(py, m, optimiser::module(py)?)?;
add_submodule(py, m, passes::module(py)?)?;
add_submodule(py, m, pattern::module(py)?)?;
Ok(())
}

/// portmatching module
fn add_pattern_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "pattern")?;
m.add_class::<tket2::portmatching::CircuitPattern>()?;
m.add_class::<tket2::portmatching::PatternMatcher>()?;
m.add_class::<CircuitRewrite>()?;
m.add_class::<Rule>()?;
m.add_class::<RuleMatcher>()?;

m.add(
"InvalidPatternError",
py.get_type::<tket2::portmatching::pattern::PyInvalidPatternError>(),
)?;
m.add(
"InvalidReplacementError",
py.get_type::<hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError>(),
)?;

parent.add_submodule(m)
fn add_submodule(py: Python, parent: &PyModule, submodule: &PyModule) -> PyResult<()> {
parent.add_submodule(submodule)?;

// Add submodule to sys.modules.
// This is required to be able to do `from parent.submodule import ...`.
//
// See [https://github.com/PyO3/pyo3/issues/759]
let parent_name = parent.name()?;
let submodule_name = submodule.name()?;
let modules = py.import("sys")?.getattr("modules")?;
modules.set_item(format!("{parent_name}.{submodule_name}"), submodule)?;
Ok(())
}
9 changes: 4 additions & 5 deletions tket2-py/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ use tket2::optimiser::{DefaultTasoOptimiser, TasoLogger};

use crate::circuit::update_hugr;

/// The circuit optimisation module.
pub fn add_optimiser_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "optimiser")?;
/// The module definition
pub fn module(py: Python) -> PyResult<&PyModule> {
let m = PyModule::new(py, "_optimiser")?;
m.add_class::<PyTasoOptimiser>()?;

parent.add_submodule(m)
Ok(m)
}

/// Wrapped [`DefaultTasoOptimiser`].
Expand Down
28 changes: 15 additions & 13 deletions tket2-py/src/pass.rs → tket2-py/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,21 @@ use crate::{
optimiser::PyTasoOptimiser,
};

/// The module definition
///
/// This module is re-exported from the python module with the same name.
pub fn module(py: Python) -> PyResult<&PyModule> {
let m = PyModule::new(py, "_passes")?;
m.add_function(wrap_pyfunction!(greedy_depth_reduce, m)?)?;
m.add_function(wrap_pyfunction!(taso_optimise, m)?)?;
m.add_class::<tket2::T2Op>()?;
m.add(
"PullForwardError",
py.get_type::<tket2::passes::PyPullForwardError>(),
)?;
Ok(m)
}

#[pyfunction]
fn greedy_depth_reduce(py_c: PyObject) -> PyResult<(PyObject, u32)> {
try_with_hugr(py_c, |mut h| {
Expand Down Expand Up @@ -119,16 +134,3 @@ fn taso_optimise(
PyResult::Ok(circ)
})
}

pub(crate) fn add_pass_module(py: Python, parent: &PyModule) -> PyResult<()> {
let m = PyModule::new(py, "passes")?;
m.add_function(wrap_pyfunction!(greedy_depth_reduce, m)?)?;
m.add_function(wrap_pyfunction!(taso_optimise, m)?)?;
m.add_class::<tket2::T2Op>()?;
m.add(
"PullForwardError",
py.get_type::<tket2::passes::PyPullForwardError>(),
)?;
parent.add_submodule(m)?;
Ok(())
}
78 changes: 78 additions & 0 deletions tket2-py/src/pattern.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
//!
use crate::circuit::{to_hugr, T2Circuit};

use hugr::Hugr;
use pyo3::prelude::*;
use tket2::portmatching::pyo3::PyPatternMatch;
use tket2::portmatching::{CircuitPattern, PatternMatcher};
use tket2::rewrite::CircuitRewrite;

/// The module definition
pub fn module(py: Python) -> PyResult<&PyModule> {
let m = PyModule::new(py, "_pattern")?;
m.add_class::<tket2::portmatching::CircuitPattern>()?;
m.add_class::<tket2::portmatching::PatternMatcher>()?;
m.add_class::<CircuitRewrite>()?;
m.add_class::<Rule>()?;
m.add_class::<RuleMatcher>()?;

m.add(
"InvalidPatternError",
py.get_type::<tket2::portmatching::pattern::PyInvalidPatternError>(),
)?;
m.add(
"InvalidReplacementError",
py.get_type::<hugr::hugr::views::sibling_subgraph::PyInvalidReplacementError>(),
)?;

Ok(m)
}

#[derive(Clone)]
#[pyclass]
/// A rewrite rule defined by a left hand side and right hand side of an equation.
pub struct Rule(pub [Hugr; 2]);

#[pymethods]
impl Rule {
#[new]
fn new_rule(l: PyObject, r: PyObject) -> PyResult<Rule> {
let l = to_hugr(l)?;
let r = to_hugr(r)?;

Ok(Rule([l, r]))
}
}
#[pyclass]
struct RuleMatcher {
matcher: PatternMatcher,
rights: Vec<Hugr>,
}

#[pymethods]
impl RuleMatcher {
#[new]
pub fn from_rules(rules: Vec<Rule>) -> PyResult<Self> {
let (lefts, rights): (Vec<_>, Vec<_>) =
rules.into_iter().map(|Rule([l, r])| (l, r)).unzip();
let patterns: Result<Vec<CircuitPattern>, _> =
lefts.iter().map(CircuitPattern::try_from_circuit).collect();
let matcher = PatternMatcher::from_patterns(patterns?);

Ok(Self { matcher, rights })
}

pub fn find_match(&self, target: &T2Circuit) -> PyResult<Option<CircuitRewrite>> {
let h = &target.0;
let p_match = self.matcher.find_matches_iter(h).next();
if let Some(m) = p_match {
let py_match = PyPatternMatch::try_from_rust(m, h, &self.matcher)?;
let r = self.rights.get(py_match.pattern_id).unwrap().clone();
let rw = py_match.to_rewrite(h, r)?;
Ok(Some(rw))
} else {
Ok(None)
}
}
}
Loading

0 comments on commit 4220038

Please sign in to comment.