Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrandhsn committed Sep 5, 2024
1 parent b3038a6 commit f3224b3
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 152 deletions.
2 changes: 1 addition & 1 deletion crates/accelerate/src/commutation_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ const MAX_NUM_QUBITS: u32 = 3;
/// commutation_set = {0: [[0], [2, 3], [4], [1]]}
/// node_indices = {(0, 0): 0, (1, 0): 3, (2, 0): 1, (3, 0): 1, (4, 0): 2}
///
fn analyze_commutations_inner(
pub fn analyze_commutations_inner(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
Expand Down
275 changes: 275 additions & 0 deletions crates/accelerate/src/commutation_cancellation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
use crate::commutation_analysis::analyze_commutations_inner;
use crate::commutation_checker::CommutationChecker;
use crate::target_transpiler::Target;
use hashbrown::{HashMap, HashSet};
use pyo3::prelude::*;
use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python};
use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire};
use qiskit_circuit::operations::StandardGate::{PhaseGate, RXGate, RZGate, U1Gate};
use qiskit_circuit::operations::{
Operation, Param, StandardGate,
};
use qiskit_circuit::Qubit;
use rustworkx_core::petgraph::stable_graph::NodeIndex;
use std::f64::consts::PI;

const _CUTOFF_PRECISION: f64 = 1e-5;
static ROTATION_GATES: [&str; 4] = ["p", "u1", "rz", "rx"];
static HALF_TURNS: [&str; 2] = ["z", "x"];
static QUARTER_TURNS: [&str; 1] = ["s"];
static EIGHTH_TURNS: [&str; 1] = ["t"];

const Z_ROTATION: &str = "z_rotation";
const X_ROTATION: &str = "x_rotation";

#[pyfunction]
#[pyo3(signature = (dag, commutation_checker, basis_gates=None, target=None))]
pub(crate) fn cancel_commutations(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
basis_gates: Option<HashSet<String>>,
target: Option<&Target>,
) -> PyResult<()> {
let basis: HashSet<String> = if let Some(tar) = target {
HashSet::from_iter(tar.operation_names().map(String::from))
} else if let Some(basis) = basis_gates {
basis
} else {
HashSet::new()
};

let _var_z_map: HashMap<&str, StandardGate> =
HashMap::from([("rz", RZGate), ("p", PhaseGate), ("u1", U1Gate)]);

let _z_rotations: HashSet<&str> = HashSet::from(["p", "z", "u1", "rz", "t", "s"]);
let _x_rotations: HashSet<&str> = HashSet::from(["x", "rx"]);
let _gates: HashSet<&str> = HashSet::from(["cx", "cy", "cz", "h", "y"]);

let z_var_gate = dag
.op_names
.keys()
.find(|g| _var_z_map.contains_key(g.as_str()))
// Fallback to the first matching key from basis if there is no match in dag.op_names
.or_else(|| basis.iter().find(|g| _var_z_map.contains_key(g.as_str())))
// get the StandardGate associated with that string
.and_then(|key| _var_z_map.get(key.as_str()));

// Gate sets to be cancelled
/* Traverse each qubit to generate the cancel dictionaries
Cancel dictionaries:
- For 1-qubit gates the key is (gate_type, qubit_id, commutation_set_id),
the value is the list of gates that share the same gate type, qubit, commutation set.
- For 2qbit gates the key: (gate_type, first_qbit, sec_qbit, first commutation_set_id,
sec_commutation_set_id), the value is the list gates that share the same gate type,
qubits and commutation sets.
*/
let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?;
let mut single_q_cancellation_sets: HashMap<(String, Qubit, usize), Vec<NodeIndex>> =
HashMap::new();
let mut two_q_cancellation_sets: HashMap<(String, Qubit, Qubit, usize, usize), Vec<NodeIndex>> =
HashMap::new();

(0..dag.num_qubits() as u32).for_each(|qubit| {
let wire = Qubit(qubit);
if let Some(wire_commutation_set) = commutation_set.get(&Wire::Qubit(wire)) {
wire_commutation_set
.iter()
.enumerate()
.for_each(|(com_set_idx, com_set)| {
// This ensures that we only have DAGOPNodes in the current com_set, yuck...
if let NodeType::Operation(_node0) = &dag.dag[*com_set.first().unwrap()] {
com_set.iter().for_each(|node| {
let op = match &dag.dag[*node] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set."),
};
let num_qargs = dag.get_qargs(op.qubits).len();
// no support for cancellation of parameterized gates
if op
.params_view()
.iter()
.all(|p| !matches!(p, Param::ParameterExpression(_)))
{
let op_name = op.op.name().to_string();
if num_qargs == 1usize && _gates.contains(op_name.as_str()) {
single_q_cancellation_sets
.entry((op_name.clone(), wire, com_set_idx))
.or_insert_with(Vec::new)
.push(*node);
}

if num_qargs == 1usize && _z_rotations.contains(op_name.as_str()) {
single_q_cancellation_sets
.entry((Z_ROTATION.to_string(), wire, com_set_idx))
.or_insert_with(Vec::new)
.push(*node);
}
if num_qargs == 1usize && _x_rotations.contains(op_name.as_str()) {
single_q_cancellation_sets
.entry((X_ROTATION.to_string(), wire, com_set_idx))
.or_insert_with(Vec::new)
.push(*node);
}
// Don't deal with Y rotation, because Y rotation doesn't commute with
// CNOT, so it should be dealt with by optimized1qgate pass
if num_qargs == 2usize
&& dag.get_qargs(op.qubits).first().unwrap() == &wire
{
let second_qarg = dag.get_qargs(op.qubits)[1];
let q2_key = (
op_name,
wire,
second_qarg,
com_set_idx,
*node_indices
.get(&(*node, Wire::Qubit(second_qarg)))
.unwrap(),
);
two_q_cancellation_sets
.entry(q2_key)
.or_insert_with(Vec::new)
.push(*node);
}
}
})
}
})
}
});

for (cancel_key, cancel_set) in &two_q_cancellation_sets {
if cancel_set.len() > 1 && _gates.contains(cancel_key.0.as_str()) {
for &c_node in &cancel_set[0..(cancel_set.len() / 2) * 2] {
dag.remove_op_node(c_node);
}
}
}

for (cancel_key, cancel_set) in &single_q_cancellation_sets {
if cancel_key.0 == Z_ROTATION && z_var_gate.is_none() {
continue;
}
if cancel_set.len() > 1 && _gates.contains(cancel_key.0.as_str()) {
for &c_node in &cancel_set[0..(cancel_set.len() / 2) * 2] {
dag.remove_op_node(c_node);
}
} else if cancel_set.len() > 1 && (cancel_key.0 == Z_ROTATION || cancel_key.0 == X_ROTATION)
{
let run_op = match &dag.dag[*cancel_set.first().unwrap()] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set run."),
};

let run_qarg = dag.get_qargs(run_op.qubits).first().unwrap();
let mut total_angle: f64 = 0.0f64;
let mut total_phase: f64 = 0.0f64;
for current_node in cancel_set {
let node_op = match &dag.dag[*current_node] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set run."),
};
let node_op_name = node_op.op.name();

let node_qargs = dag.get_qargs(node_op.qubits);
if node_op
.extra_attrs
.as_deref()
.is_some_and(|attr| attr.condition.is_some())
|| node_qargs.len() > 1
|| &node_qargs[0] != run_qarg
{
panic!("internal error");
}

let node_angle = if ROTATION_GATES.contains(&node_op_name) {
match node_op.params_view().first() {
Some(Param::Float(f)) => *f,
_ => panic!(
"Rotational gate with parameter expression encoutned in cancellation"
),
}
} else if HALF_TURNS.contains(&node_op_name) {
PI
} else if QUARTER_TURNS.contains(&node_op_name) {
PI / 2.0
} else if EIGHTH_TURNS.contains(&node_op_name) {
PI / 4.0
} else {
panic!("Angle for operation {node_op_name} is not defined")
};
total_angle += node_angle;

if let Some(definition) = node_op.op.definition(node_op.params_view()) {
//TODO check for PyNone global phase?
//total_phase += match definition.global_phase {Param::Float(f) => f, Param::Obj(pyop) => , Param::ParameterExpression(_) => panic!("PackedInstruction with definition has global phase set as parameter expression")};
total_phase += match definition.global_phase {Param::Float(f) => f, _ => panic!("PackedInstruction with definition has no global phase set as floating point number")};
}
}

let new_op = if cancel_key.0 == Z_ROTATION {
z_var_gate.unwrap()
} else if cancel_key.0 == X_ROTATION {
&RXGate
} else {
panic!("impossible case!");
};

let gate_angle = mod_2pi(total_angle, 0.);

let new_op_phase = if gate_angle.abs() > _CUTOFF_PRECISION {
let new_index = dag.insert_1q_on_incoming_qubit(
(*new_op, &[total_angle]),
*cancel_set.first().unwrap(),
);
let new_node = match &dag.dag[new_index] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set run."),
};

if let Some(definition) = new_node.op.definition(new_node.params_view()) {
//TODO check for PyNone global phase?
match definition.global_phase {Param::Float(f) => f, _ => panic!("PackedInstruction with definition has no global phase set as floating point number")}
} else {
0.0f64
}
} else {
0.0f64
};

dag.add_global_phase(py, &Param::Float(total_phase - new_op_phase))?;

for node in cancel_set {
dag.remove_op_node(*node);
}

//TODO do we need this due to numerical instability?
/*
if np.mod(total_angle, (2 * np.pi)) < _CUTOFF_PRECISION:
dag.remove_op_node(run[0])
*/
}
}

Ok(())
}

/// Wrap angle into interval [-π,π). If within atol of the endpoint, clamp to -π
#[inline]
fn mod_2pi(angle: f64, atol: f64) -> f64 {
// f64::rem_euclid() isn't exactly the same as Python's % operator, but because
// the RHS here is a constant and positive it is effectively equivalent for
// this case
let wrapped = (angle + PI).rem_euclid(2. * PI) - PI;
if (wrapped - PI).abs() < atol {
-PI
} else {
wrapped
}
}

#[pymodule]
pub fn commutation_cancellation(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(cancel_commutations))?;
Ok(())
}
1 change: 1 addition & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use pyo3::import_exception;

pub mod circuit_library;
pub mod commutation_analysis;
pub mod commutation_cancellation;
pub mod commutation_checker;
pub mod convert_2q_block_matrix;
pub mod dense_layout;
Expand Down
2 changes: 1 addition & 1 deletion crates/circuit/src/circuit_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub struct CircuitData {
clbits: BitData<Clbit>,
param_table: ParameterTable,
#[pyo3(get)]
global_phase: Param,
pub global_phase: Param,
}

impl CircuitData {
Expand Down
7 changes: 4 additions & 3 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ pub struct DAGCircuit {
var_output_map: _VarIndexMap,

/// Operation kind to count
op_names: IndexMap<String, usize, RandomState>,
pub op_names: IndexMap<String, usize, RandomState>,

// Python modules we need to frequently access (for now).
control_flow_module: PyControlFlowModule,
Expand Down Expand Up @@ -3396,7 +3396,7 @@ def _format(operand):
/// DAGCircuitError: If replacement operation was incompatible with
/// location of target node.
#[pyo3(signature = (node, op, inplace=false, propagate_condition=true))]
fn substitute_node(
pub fn substitute_node(
&mut self,
node: &Bound<PyAny>,
op: &Bound<PyAny>,
Expand Down Expand Up @@ -6241,7 +6241,7 @@ impl DAGCircuit {
&mut self,
new_gate: (StandardGate, &[f64]),
old_index: NodeIndex,
) {
) -> NodeIndex {
self.increment_op(new_gate.0.name());
let old_node = &self.dag[old_index];
let inst = if let NodeType::Operation(old_node) = old_node {
Expand All @@ -6268,6 +6268,7 @@ impl DAGCircuit {
self.dag.add_edge(parent_index, new_index, weight.clone());
self.dag.add_edge(new_index, old_index, weight);
self.dag.remove_edge(edge_index);
new_index
}

/// Remove a sequence of 1 qubit nodes from the dag
Expand Down
18 changes: 10 additions & 8 deletions crates/pyext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@ use pyo3::prelude::*;

use qiskit_accelerate::{
circuit_library::circuit_library, commutation_analysis::commutation_analysis,
commutation_checker::commutation_checker, convert_2q_block_matrix::convert_2q_block_matrix,
dense_layout::dense_layout, error_map::error_map,
euler_one_qubit_decomposer::euler_one_qubit_decomposer, isometry::isometry, nlayout::nlayout,
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval, results::results,
sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
star_prerouting::star_prerouting, stochastic_swap::stochastic_swap, synthesis::synthesis,
target_transpiler::target, two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate,
utils::utils, vf2_layout::vf2_layout,
commutation_cancellation::commutation_cancellation, commutation_checker::commutation_checker,
convert_2q_block_matrix::convert_2q_block_matrix, dense_layout::dense_layout,
error_map::error_map, euler_one_qubit_decomposer::euler_one_qubit_decomposer,
isometry::isometry, nlayout::nlayout, optimize_1q_gates::optimize_1q_gates,
pauli_exp_val::pauli_expval, results::results, sabre::sabre, sampled_exp_val::sampled_exp_val,
sparse_pauli_op::sparse_pauli_op, star_prerouting::star_prerouting,
stochastic_swap::stochastic_swap, synthesis::synthesis, target_transpiler::target,
two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate, utils::utils,
vf2_layout::vf2_layout,
};

#[inline(always)]
Expand Down Expand Up @@ -63,5 +64,6 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
add_submodule(m, vf2_layout, "vf2_layout")?;
add_submodule(m, commutation_checker, "commutation_checker")?;
add_submodule(m, commutation_analysis, "commutation_analysis")?;
add_submodule(m, commutation_cancellation, "commutation_cancellation")?;
Ok(())
}
1 change: 1 addition & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
sys.modules["qiskit._accelerate.synthesis.clifford"] = _accelerate.synthesis.clifford
sys.modules["qiskit._accelerate.commutation_checker"] = _accelerate.commutation_checker
sys.modules["qiskit._accelerate.commutation_analysis"] = _accelerate.commutation_analysis
sys.modules["qiskit._accelerate.commutation_cancellation"] = _accelerate.commutation_cancellation
sys.modules["qiskit._accelerate.synthesis.linear_phase"] = _accelerate.synthesis.linear_phase

from qiskit.exceptions import QiskitError, MissingOptionalLibraryError
Expand Down
Loading

0 comments on commit f3224b3

Please sign in to comment.