From 4e573f36febd01e9147b037f2cb485c232febae7 Mon Sep 17 00:00:00 2001 From: Raynel Sanchez <87539502+raynelfss@users.noreply.github.com> Date: Mon, 7 Oct 2024 18:55:18 -0400 Subject: [PATCH] [Oxidize BasisTranslator]: Move `basis_search` and `BasisSearchVisitor` to rust. (#12811) * Add: Basis search function - Add rust counterpart for `basis_search`. - Consolidated the `BasisSearchVisitor` into the function due to differences in rust behavior. * Fix: Wrong return value for `basis_search` * Fix: Remove `IndexMap` and duplicate declarations. * Fix: Adapt to #12730 * Remove: unused imports * Docs: Edit docstring for rust native `basis_search` * Fix: Use owned Strings. - Due to the nature of `hashbrown` we must use owned Strings instead of `&str`. * Add: mutable graph view that the `BasisTranslator` can access in Rust. - Remove import of `random` in `basis_translator`. * Fix: Review comments - Rename `EquivalenceLibrary`'s `mut_graph` method to `graph_mut` to keep consistent with rust naming conventions. - Use `&HashSet` instead of `HashSet<&str>` to avoid extra conversion. - Use `u32::MAX` as num_qubits for dummy node. - Use for loop instead of foreachj to add edges to dummy node. - Add comment explaining usage of flatten in `initialize_num_gates_remain_for_rule`. - Remove stale comments. * Update crates/accelerate/src/basis/basis_translator/basis_search.rs --------- Co-authored-by: Matthew Treinish --- .../basis/basis_translator/basis_search.rs | 217 ++++++++++++++++++ .../src/basis/basis_translator/mod.rs | 2 + crates/accelerate/src/equivalence.rs | 5 + qiskit/__init__.py | 2 + .../passes/basis/basis_translator.py | 166 +------------- 5 files changed, 229 insertions(+), 163 deletions(-) create mode 100644 crates/accelerate/src/basis/basis_translator/basis_search.rs diff --git a/crates/accelerate/src/basis/basis_translator/basis_search.rs b/crates/accelerate/src/basis/basis_translator/basis_search.rs new file mode 100644 index 000000000000..2810765db741 --- /dev/null +++ b/crates/accelerate/src/basis/basis_translator/basis_search.rs @@ -0,0 +1,217 @@ +// This code is part of Qiskit. +// +// (C) Copyright IBM 2024 +// +// This code is licensed under the Apache License, Version 2.0. You may +// obtain a copy of this license in the LICENSE.txt file in the root directory +// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +// +// Any modifications or derivative works of this code must retain this +// copyright notice, and modified files need to carry a notice indicating +// that they have been altered from the originals. + +use std::cell::RefCell; + +use hashbrown::{HashMap, HashSet}; +use pyo3::prelude::*; + +use crate::equivalence::{EdgeData, Equivalence, EquivalenceLibrary, Key, NodeData}; +use qiskit_circuit::operations::Operation; +use rustworkx_core::petgraph::stable_graph::{EdgeReference, NodeIndex, StableDiGraph}; +use rustworkx_core::petgraph::visit::Control; +use rustworkx_core::traversal::{dijkstra_search, DijkstraEvent}; + +use super::compose_transforms::{BasisTransformIn, GateIdentifier}; + +/// Search for a set of transformations from source_basis to target_basis. +/// Args: +/// equiv_lib (EquivalenceLibrary): Source of valid translations +/// source_basis (Set[Tuple[gate_name: str, gate_num_qubits: int]]): Starting basis. +/// target_basis (Set[gate_name: str]): Target basis. +/// +/// Returns: +/// Optional[List[Tuple[gate, equiv_params, equiv_circuit]]]: List of (gate, +/// equiv_params, equiv_circuit) tuples tuples which, if applied in order +/// will map from source_basis to target_basis. Returns None if no path +/// was found. +#[pyfunction] +#[pyo3(name = "basis_search")] +pub(crate) fn py_basis_search( + py: Python, + equiv_lib: &mut EquivalenceLibrary, + source_basis: HashSet, + target_basis: HashSet, +) -> PyObject { + basis_search(equiv_lib, &source_basis, &target_basis).into_py(py) +} + +type BasisTransforms = Vec<(GateIdentifier, BasisTransformIn)>; +/// Search for a set of transformations from source_basis to target_basis. +/// +/// Performs a Dijkstra search algorithm on the `EquivalenceLibrary`'s core graph +/// to rate and classify different possible equivalent circuits to the provided gates. +/// +/// This is done by connecting all the nodes represented in the `target_basis` to a dummy +/// node, and then traversing the graph until all the nodes described in the `source +/// basis` are reached. +pub(crate) fn basis_search( + equiv_lib: &mut EquivalenceLibrary, + source_basis: &HashSet, + target_basis: &HashSet, +) -> Option { + // Build the visitor attributes: + let mut num_gates_remaining_for_rule: HashMap = HashMap::default(); + let predecessors: RefCell> = + RefCell::new(HashMap::default()); + let opt_cost_map: RefCell> = RefCell::new(HashMap::default()); + let mut basis_transforms: Vec<(GateIdentifier, BasisTransformIn)> = vec![]; + + // Initialize visitor attributes: + initialize_num_gates_remain_for_rule(equiv_lib.graph(), &mut num_gates_remaining_for_rule); + + let mut source_basis_remain: HashSet = source_basis + .iter() + .filter_map(|(gate_name, gate_num_qubits)| { + if !target_basis.contains(gate_name) { + Some(Key { + name: gate_name.to_string(), + num_qubits: *gate_num_qubits, + }) + } else { + None + } + }) + .collect(); + + // If source_basis is empty, no work needs to be done. + if source_basis_remain.is_empty() { + return Some(vec![]); + } + + // This is only necessary since gates in target basis are currently reported by + // their names and we need to have in addition the number of qubits they act on. + let target_basis_keys: Vec = equiv_lib + .keys() + .filter(|&key| target_basis.contains(key.name.as_str())) + .cloned() + .collect(); + + // Dummy node is inserted in the graph. Which is where the search will start + let dummy: NodeIndex = equiv_lib.graph_mut().add_node(NodeData { + equivs: vec![], + key: Key { + name: "key".to_string(), + num_qubits: u32::MAX, + }, + }); + + // Extract indices for the target_basis gates, to avoid borrowing from graph. + let target_basis_indices: Vec = target_basis_keys + .iter() + .map(|key| equiv_lib.node_index(key)) + .collect(); + + // Connect each edge in the target_basis to the dummy node. + for node in target_basis_indices { + equiv_lib.graph_mut().add_edge(dummy, node, None); + } + + // Edge cost function for Visitor + let edge_weight = |edge: EdgeReference>| -> Result { + if edge.weight().is_none() { + return Ok(1); + } + let edge_data = edge.weight().as_ref().unwrap(); + let mut cost_tot = 0; + let borrowed_cost = opt_cost_map.borrow(); + for instruction in edge_data.rule.circuit.0.iter() { + let instruction_op = instruction.op.view(); + cost_tot += borrowed_cost[&( + instruction_op.name().to_string(), + instruction_op.num_qubits(), + )]; + } + Ok(cost_tot + - borrowed_cost[&( + edge_data.source.name.to_string(), + edge_data.source.num_qubits, + )]) + }; + + let basis_transforms = match dijkstra_search( + &equiv_lib.graph(), + [dummy], + edge_weight, + |event: DijkstraEvent, u32>| { + match event { + DijkstraEvent::Discover(n, score) => { + let gate_key = &equiv_lib.graph()[n].key; + let gate = (gate_key.name.to_string(), gate_key.num_qubits); + source_basis_remain.remove(gate_key); + let mut borrowed_cost_map = opt_cost_map.borrow_mut(); + if let Some(entry) = borrowed_cost_map.get_mut(&gate) { + *entry = score; + } else { + borrowed_cost_map.insert(gate.clone(), score); + } + if let Some(rule) = predecessors.borrow().get(&gate) { + basis_transforms.push(( + (gate_key.name.to_string(), gate_key.num_qubits), + (rule.params.clone(), rule.circuit.clone()), + )); + } + + if source_basis_remain.is_empty() { + basis_transforms.reverse(); + return Control::Break(()); + } + } + DijkstraEvent::EdgeRelaxed(_, target, Some(edata)) => { + let gate = &equiv_lib.graph()[target].key; + predecessors + .borrow_mut() + .entry((gate.name.to_string(), gate.num_qubits)) + .and_modify(|value| *value = edata.rule.clone()) + .or_insert(edata.rule.clone()); + } + DijkstraEvent::ExamineEdge(_, target, Some(edata)) => { + num_gates_remaining_for_rule + .entry(edata.index) + .and_modify(|val| *val -= 1) + .or_insert(0); + let target = &equiv_lib.graph()[target].key; + + // If there are gates in this `rule` that we have not yet generated, we can't apply + // this `rule`. if `target` is already in basis, it's not beneficial to use this rule. + if num_gates_remaining_for_rule[&edata.index] > 0 + || target_basis_keys.contains(target) + { + return Control::Prune; + } + } + _ => {} + }; + Control::Continue + }, + ) { + Ok(Control::Break(_)) => Some(basis_transforms), + _ => None, + }; + equiv_lib.graph_mut().remove_node(dummy); + basis_transforms +} + +fn initialize_num_gates_remain_for_rule( + graph: &StableDiGraph>, + source: &mut HashMap, +) { + let mut save_index = usize::MAX; + // When iterating over the edges, ignore any none-valued ones by calling `flatten` + for edge_data in graph.edge_weights().flatten() { + if save_index == edge_data.index { + continue; + } + source.insert(edge_data.index, edge_data.num_gates); + save_index = edge_data.index; + } +} diff --git a/crates/accelerate/src/basis/basis_translator/mod.rs b/crates/accelerate/src/basis/basis_translator/mod.rs index b97f4e37c4b5..18970065267c 100644 --- a/crates/accelerate/src/basis/basis_translator/mod.rs +++ b/crates/accelerate/src/basis/basis_translator/mod.rs @@ -12,10 +12,12 @@ use pyo3::prelude::*; +pub mod basis_search; mod compose_transforms; #[pymodule] pub fn basis_translator(m: &Bound) -> PyResult<()> { + m.add_wrapped(wrap_pyfunction!(basis_search::py_basis_search))?; m.add_wrapped(wrap_pyfunction!(compose_transforms::py_compose_transforms))?; Ok(()) } diff --git a/crates/accelerate/src/equivalence.rs b/crates/accelerate/src/equivalence.rs index 55e1b0336ae4..7ea9161a2a1f 100644 --- a/crates/accelerate/src/equivalence.rs +++ b/crates/accelerate/src/equivalence.rs @@ -697,6 +697,11 @@ impl EquivalenceLibrary { pub fn graph(&self) -> &GraphType { &self.graph } + + /// Expose a mutable view of the inner graph. + pub(crate) fn graph_mut(&mut self) -> &mut GraphType { + &mut self.graph + } } fn raise_if_param_mismatch( diff --git a/qiskit/__init__.py b/qiskit/__init__.py index e20e5a4284b4..196b10da3183 100644 --- a/qiskit/__init__.py +++ b/qiskit/__init__.py @@ -56,6 +56,8 @@ sys.modules["qiskit._accelerate.basis"] = _accelerate.basis sys.modules["qiskit._accelerate.basis.basis_translator"] = _accelerate.basis.basis_translator sys.modules["qiskit._accelerate.converters"] = _accelerate.converters +sys.modules["qiskit._accelerate.basis"] = _accelerate.basis +sys.modules["qiskit._accelerate.basis.basis_translator"] = _accelerate.basis.basis_translator sys.modules["qiskit._accelerate.convert_2q_block_matrix"] = _accelerate.convert_2q_block_matrix sys.modules["qiskit._accelerate.dense_layout"] = _accelerate.dense_layout sys.modules["qiskit._accelerate.equivalence"] = _accelerate.equivalence diff --git a/qiskit/transpiler/passes/basis/basis_translator.py b/qiskit/transpiler/passes/basis/basis_translator.py index 01cbd18c0ee7..a1d3e7f0d39c 100644 --- a/qiskit/transpiler/passes/basis/basis_translator.py +++ b/qiskit/transpiler/passes/basis/basis_translator.py @@ -13,15 +13,12 @@ """Translates gates to a target basis using a given equivalence library.""" -import random import time import logging from functools import singledispatchmethod from collections import defaultdict -import rustworkx - from qiskit.circuit import ( ControlFlowOp, QuantumCircuit, @@ -29,11 +26,10 @@ ) from qiskit.dagcircuit import DAGCircuit, DAGOpNode from qiskit.converters import circuit_to_dag, dag_to_circuit -from qiskit.circuit.equivalence import Key, NodeData, Equivalence from qiskit.transpiler.basepasses import TransformationPass from qiskit.transpiler.exceptions import TranspilerError from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES -from qiskit._accelerate.basis.basis_translator import compose_transforms +from qiskit._accelerate.basis.basis_translator import basis_search, compose_transforms logger = logging.getLogger(__name__) @@ -172,7 +168,7 @@ def run(self, dag): # Search for a path from source to target basis. search_start_time = time.time() - basis_transforms = _basis_search(self._equiv_lib, source_basis, target_basis) + basis_transforms = basis_search(self._equiv_lib, source_basis, target_basis) qarg_local_basis_transforms = {} for qarg, local_source_basis in qargs_local_source_basis.items(): @@ -195,7 +191,7 @@ def run(self, dag): expanded_target, qarg, ) - local_basis_transforms = _basis_search( + local_basis_transforms = basis_search( self._equiv_lib, local_source_basis, expanded_target ) @@ -446,159 +442,3 @@ def _extract_basis_target( qargs_local_source_basis=qargs_local_source_basis, ) return source_basis, qargs_local_source_basis - - -class StopIfBasisRewritable(Exception): - """Custom exception that signals `rustworkx.dijkstra_search` to stop.""" - - -class BasisSearchVisitor(rustworkx.visit.DijkstraVisitor): - """Handles events emitted during `rustworkx.dijkstra_search`.""" - - def __init__(self, graph, source_basis, target_basis): - self.graph = graph - self.target_basis = set(target_basis) - self._source_gates_remain = set(source_basis) - self._num_gates_remain_for_rule = {} - save_index = -1 - for edata in self.graph.edges(): - if save_index == edata.index: - continue - self._num_gates_remain_for_rule[edata.index] = edata.num_gates - save_index = edata.index - - self._basis_transforms = [] - self._predecessors = {} - self._opt_cost_map = {} - - def discover_vertex(self, v, score): - gate = self.graph[v].key - self._source_gates_remain.discard(gate) - self._opt_cost_map[gate] = score - rule = self._predecessors.get(gate, None) - if rule is not None: - logger.debug( - "Gate %s generated using rule \n%s\n with total cost of %s.", - gate.name, - rule.circuit, - score, - ) - self._basis_transforms.append( - ((gate.name, gate.num_qubits), (rule.params, rule.circuit)) - ) - # we can stop the search if we have found all gates in the original circuit. - if not self._source_gates_remain: - # if we start from source gates and apply `basis_transforms` in reverse order, we'll end - # up with gates in the target basis. Note though that `basis_transforms` may include - # additional transformations that are not required to map our source gates to the given - # target basis. - self._basis_transforms.reverse() - raise StopIfBasisRewritable - - def examine_edge(self, edge): - _, target, edata = edge - if edata is None: - return - - self._num_gates_remain_for_rule[edata.index] -= 1 - - target = self.graph[target].key - # if there are gates in this `rule` that we have not yet generated, we can't apply - # this `rule`. if `target` is already in basis, it's not beneficial to use this rule. - if self._num_gates_remain_for_rule[edata.index] > 0 or target in self.target_basis: - raise rustworkx.visit.PruneSearch - - def edge_relaxed(self, edge): - _, target, edata = edge - if edata is not None: - gate = self.graph[target].key - self._predecessors[gate] = edata.rule - - def edge_cost(self, edge_data): - """Returns the cost of an edge. - - This function computes the cost of this edge rule by summing - the costs of all gates in the rule equivalence circuit. In the - end, we need to subtract the cost of the source since `dijkstra` - will later add it. - """ - - if edge_data is None: - # the target of the edge is a gate in the target basis, - # so we return a default value of 1. - return 1 - - cost_tot = 0 - for instruction in edge_data.rule.circuit: - key = Key(name=instruction.name, num_qubits=len(instruction.qubits)) - cost_tot += self._opt_cost_map[key] - - return cost_tot - self._opt_cost_map[edge_data.source] - - @property - def basis_transforms(self): - """Returns the gate basis transforms.""" - return self._basis_transforms - - -def _basis_search(equiv_lib, source_basis, target_basis): - """Search for a set of transformations from source_basis to target_basis. - - Args: - equiv_lib (EquivalenceLibrary): Source of valid translations - source_basis (Set[Tuple[gate_name: str, gate_num_qubits: int]]): Starting basis. - target_basis (Set[gate_name: str]): Target basis. - - Returns: - Optional[List[Tuple[gate, equiv_params, equiv_circuit]]]: List of (gate, - equiv_params, equiv_circuit) tuples tuples which, if applied in order - will map from source_basis to target_basis. Returns None if no path - was found. - """ - - logger.debug("Begining basis search from %s to %s.", source_basis, target_basis) - - source_basis = { - Key(gate_name, gate_num_qubits) - for gate_name, gate_num_qubits in source_basis - if gate_name not in target_basis - } - - # if source basis is empty, no work to be done. - if not source_basis: - return [] - - # This is only necessary since gates in target basis are currently reported by - # their names and we need to have in addition the number of qubits they act on. - target_basis_keys = [key for key in equiv_lib.keys() if key.name in target_basis] - - graph = equiv_lib.graph - vis = BasisSearchVisitor(graph, source_basis, target_basis_keys) - - # we add a dummy node and connect it with gates in the target basis. - # we'll start the search from this dummy node. - dummy = graph.add_node( - NodeData( - key=Key("".join(chr(random.randint(0, 26) + 97) for _ in range(10)), 0), - equivs=[Equivalence([], QuantumCircuit(0, name="dummy starting node"))], - ) - ) - - try: - graph.add_edges_from_no_data( - [(dummy, equiv_lib.node_index(key)) for key in target_basis_keys] - ) - rtn = None - try: - rustworkx.digraph_dijkstra_search(graph, [dummy], vis.edge_cost, vis) - except StopIfBasisRewritable: - rtn = vis.basis_transforms - - logger.debug("Transformation path:") - for (gate_name, gate_num_qubits), (params, equiv) in rtn: - logger.debug("%s/%s => %s\n%s", gate_name, gate_num_qubits, params, equiv) - finally: - # Remove dummy node in order to return graph to original state - graph.remove_node(dummy) - - return rtn