diff --git a/crates/accelerate/src/filter_op_nodes.rs b/crates/accelerate/src/filter_op_nodes.rs new file mode 100644 index 000000000000..7c41391f3788 --- /dev/null +++ b/crates/accelerate/src/filter_op_nodes.rs @@ -0,0 +1,63 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2024 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use pyo3::prelude::*; +use pyo3::wrap_pyfunction; + +use qiskit_circuit::dag_circuit::DAGCircuit; +use qiskit_circuit::packed_instruction::PackedInstruction; +use rustworkx_core::petgraph::stable_graph::NodeIndex; + +#[pyfunction] +#[pyo3(name = "filter_op_nodes")] +pub fn py_filter_op_nodes( + py: Python, + dag: &mut DAGCircuit, + predicate: &Bound, +) -> PyResult<()> { + let callable = |node: NodeIndex| -> PyResult { + let dag_op_node = dag.get_node(py, node)?; + predicate.call1((dag_op_node,))?.extract() + }; + let mut remove_nodes: Vec = Vec::new(); + for node in dag.op_nodes(true) { + if !callable(node)? { + remove_nodes.push(node); + } + } + for node in remove_nodes { + dag.remove_op_node(node); + } + Ok(()) +} + +/// Remove any nodes that have the provided label set +/// +/// Args: +/// dag (DAGCircuit): The dag circuit to filter the ops from +/// label (str): The label to filter nodes on +#[pyfunction] +pub fn filter_labeled_op(dag: &mut DAGCircuit, label: String) { + let predicate = |node: &PackedInstruction| -> bool { + match node.label() { + Some(inst_label) => inst_label != label, + None => false, + } + }; + dag.filter_op_nodes(predicate); +} + +pub fn filter_op_nodes_mod(m: &Bound) -> PyResult<()> { + m.add_wrapped(wrap_pyfunction!(py_filter_op_nodes))?; + m.add_wrapped(wrap_pyfunction!(filter_labeled_op))?; + Ok(()) +} diff --git a/crates/accelerate/src/lib.rs b/crates/accelerate/src/lib.rs index e8760ee2c616..78eea97faad0 100644 --- a/crates/accelerate/src/lib.rs +++ b/crates/accelerate/src/lib.rs @@ -22,6 +22,7 @@ pub mod dense_layout; pub mod edge_collections; pub mod error_map; pub mod euler_one_qubit_decomposer; +pub mod filter_op_nodes; pub mod isometry; pub mod nlayout; pub mod optimize_1q_gates; diff --git a/crates/accelerate/src/remove_diagonal_gates_before_measure.rs b/crates/accelerate/src/remove_diagonal_gates_before_measure.rs index 10916a77fca8..cf2c738f131a 100644 --- a/crates/accelerate/src/remove_diagonal_gates_before_measure.rs +++ b/crates/accelerate/src/remove_diagonal_gates_before_measure.rs @@ -49,7 +49,9 @@ fn run_remove_diagonal_before_measure(dag: &mut DAGCircuit) -> PyResult<()> { let mut nodes_to_remove = Vec::new(); for index in dag.op_nodes(true) { let node = &dag.dag[index]; - let NodeType::Operation(inst) = node else {panic!()}; + let NodeType::Operation(inst) = node else { + panic!() + }; if inst.op.name() == "measure" { let predecessor = (dag.quantum_predecessors(index)) diff --git a/crates/circuit/src/dag_circuit.rs b/crates/circuit/src/dag_circuit.rs index 381ef25b7a7e..32b3a77ed24c 100644 --- a/crates/circuit/src/dag_circuit.rs +++ b/crates/circuit/src/dag_circuit.rs @@ -5794,6 +5794,25 @@ impl DAGCircuit { } } + // Filter any nodes that don't match a given predicate function + pub fn filter_op_nodes(&mut self, mut predicate: F) + where + F: FnMut(&PackedInstruction) -> bool, + { + let mut remove_nodes: Vec = Vec::new(); + for node in self.op_nodes(true) { + let NodeType::Operation(op) = &self.dag[node] else { + unreachable!() + }; + if !predicate(op) { + remove_nodes.push(node); + } + } + for node in remove_nodes { + self.remove_op_node(node); + } + } + pub fn op_nodes_by_py_type<'a>( &'a self, op: &'a Bound, diff --git a/crates/circuit/src/packed_instruction.rs b/crates/circuit/src/packed_instruction.rs index 77ca0c6c02dd..df8f9801314a 100644 --- a/crates/circuit/src/packed_instruction.rs +++ b/crates/circuit/src/packed_instruction.rs @@ -553,6 +553,13 @@ impl PackedInstruction { .and_then(|extra| extra.condition.as_ref()) } + #[inline] + pub fn label(&self) -> Option<&str> { + self.extra_attrs + .as_ref() + .and_then(|extra| extra.label.as_deref()) + } + /// Build a reference to the Python-space operation object (the `Gate`, etc) packed into this /// instruction. This may construct the reference if the `PackedInstruction` is a standard /// gate with no already stored operation. diff --git a/crates/pyext/src/lib.rs b/crates/pyext/src/lib.rs index 1478fb367a13..49e44bffa2ec 100644 --- a/crates/pyext/src/lib.rs +++ b/crates/pyext/src/lib.rs @@ -16,8 +16,9 @@ use qiskit_accelerate::{ circuit_library::circuit_library, commutation_analysis::commutation_analysis, commutation_checker::commutation_checker, convert_2q_block_matrix::convert_2q_block_matrix, dense_layout::dense_layout, error_map::error_map, - euler_one_qubit_decomposer::euler_one_qubit_decomposer, isometry::isometry, nlayout::nlayout, - optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval, + euler_one_qubit_decomposer::euler_one_qubit_decomposer, filter_op_nodes::filter_op_nodes_mod, + isometry::isometry, nlayout::nlayout, optimize_1q_gates::optimize_1q_gates, + pauli_exp_val::pauli_expval, remove_diagonal_gates_before_measure::remove_diagonal_gates_before_measure, results::results, sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op, star_prerouting::star_prerouting, stochastic_swap::stochastic_swap, synthesis::synthesis, @@ -46,6 +47,7 @@ fn _accelerate(m: &Bound) -> PyResult<()> { add_submodule(m, dense_layout, "dense_layout")?; add_submodule(m, error_map, "error_map")?; add_submodule(m, euler_one_qubit_decomposer, "euler_one_qubit_decomposer")?; + add_submodule(m, filter_op_nodes_mod, "filter_op_nodes")?; add_submodule(m, isometry, "isometry")?; add_submodule(m, nlayout, "nlayout")?; add_submodule(m, optimize_1q_gates, "optimize_1q_gates")?; diff --git a/qiskit/__init__.py b/qiskit/__init__.py index d9979c9d4d92..3cc10bf96a31 100644 --- a/qiskit/__init__.py +++ b/qiskit/__init__.py @@ -92,6 +92,7 @@ sys.modules["qiskit._accelerate.commutation_checker"] = _accelerate.commutation_checker sys.modules["qiskit._accelerate.commutation_analysis"] = _accelerate.commutation_analysis sys.modules["qiskit._accelerate.synthesis.linear_phase"] = _accelerate.synthesis.linear_phase +sys.modules["qiskit._accelerate.filter_op_nodes"] = _accelerate.filter_op_nodes from qiskit.exceptions import QiskitError, MissingOptionalLibraryError diff --git a/qiskit/transpiler/passes/utils/filter_op_nodes.py b/qiskit/transpiler/passes/utils/filter_op_nodes.py index 344d2280e3f4..75b824332aee 100644 --- a/qiskit/transpiler/passes/utils/filter_op_nodes.py +++ b/qiskit/transpiler/passes/utils/filter_op_nodes.py @@ -18,6 +18,8 @@ from qiskit.transpiler.basepasses import TransformationPass from qiskit.transpiler.passes.utils import control_flow +from qiskit._accelerate.filter_op_nodes import filter_op_nodes + class FilterOpNodes(TransformationPass): """Remove all operations that match a filter function @@ -59,7 +61,5 @@ def __init__(self, predicate: Callable[[DAGOpNode], bool]): @control_flow.trivial_recurse def run(self, dag: DAGCircuit) -> DAGCircuit: """Run the RemoveBarriers pass on `dag`.""" - for node in dag.op_nodes(): - if not self.predicate(node): - dag.remove_op_node(node) + filter_op_nodes(dag, self.predicate) return dag