Skip to content

Commit

Permalink
Avoid Python op creation in commutative cancellation (#12701)
Browse files Browse the repository at this point in the history
* Avoid Python op creation in commutative cancellation

This commit updates the commutative cancellation and commutation
analysis transpiler pass. It builds off of #12692 to adjust access
patterns in the python transpiler path to avoid eagerly creating a
Python space operation object. The goal of this PR is to mitigate the
performance regression on these passes introduced by the extra
conversion cost of #12459.

* Remove stray print

* Don't add __array__ to DAGOpNode or CircuitInstruction
  • Loading branch information
mtreinish authored Jul 3, 2024
1 parent 419f40e commit 9571ea1
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 17 deletions.
6 changes: 6 additions & 0 deletions crates/circuit/src/circuit_instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,12 @@ impl CircuitInstruction {
.and_then(|attrs| attrs.unit.as_deref())
}

pub fn is_parameterized(&self) -> bool {
self.params
.iter()
.any(|x| matches!(x, Param::ParameterExpression(_)))
}

/// Creates a shallow copy with the given fields replaced.
///
/// Returns:
Expand Down
30 changes: 30 additions & 0 deletions crates/circuit/src/dag_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::circuit_instruction::{
convert_py_to_operation_type, operation_type_to_py, CircuitInstruction,
ExtraInstructionAttributes,
};
use crate::imports::QUANTUM_CIRCUIT;
use crate::operations::Operation;
use numpy::IntoPyArray;
use pyo3::prelude::*;
Expand Down Expand Up @@ -228,6 +229,16 @@ impl DAGOpNode {
Ok(())
}

#[getter]
fn num_qubits(&self) -> u32 {
self.instruction.operation.num_qubits()
}

#[getter]
fn num_clbits(&self) -> u32 {
self.instruction.operation.num_clbits()
}

#[getter]
fn get_qargs(&self, py: Python) -> Py<PyTuple> {
self.instruction.qubits.clone_ref(py)
Expand Down Expand Up @@ -259,6 +270,10 @@ impl DAGOpNode {
self.instruction.params.to_object(py)
}

pub fn is_parameterized(&self) -> bool {
self.instruction.is_parameterized()
}

#[getter]
fn matrix(&self, py: Python) -> Option<PyObject> {
let matrix = self.instruction.operation.matrix(&self.instruction.params);
Expand Down Expand Up @@ -325,6 +340,21 @@ impl DAGOpNode {
}
}

#[getter]
fn definition<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
let definition = self
.instruction
.operation
.definition(&self.instruction.params);
definition
.map(|data| {
QUANTUM_CIRCUIT
.get_bound(py)
.call_method1(intern!(py, "_from_circuit_data"), (data,))
})
.transpose()
}

/// Sets the Instruction name corresponding to the op for this node
#[setter]
fn set_name(&mut self, py: Python, new_name: PyObject) -> PyResult<()> {
Expand Down
34 changes: 32 additions & 2 deletions crates/circuit/src/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,8 +715,38 @@ impl Operation for StandardGate {
.expect("Unexpected Qiskit python bug"),
)
}),
Self::RXGate => todo!("Add when we have R"),
Self::RYGate => todo!("Add when we have R"),
Self::RXGate => Python::with_gil(|py| -> Option<CircuitData> {
let theta = &params[0];
Some(
CircuitData::from_standard_gates(
py,
1,
[(
Self::RGate,
smallvec![theta.clone(), FLOAT_ZERO],
smallvec![Qubit(0)],
)],
FLOAT_ZERO,
)
.expect("Unexpected Qiskit python bug"),
)
}),
Self::RYGate => Python::with_gil(|py| -> Option<CircuitData> {
let theta = &params[0];
Some(
CircuitData::from_standard_gates(
py,
1,
[(
Self::RGate,
smallvec![theta.clone(), Param::Float(PI / 2.0)],
smallvec![Qubit(0)],
)],
FLOAT_ZERO,
)
.expect("Unexpected Qiskit python bug"),
)
}),
Self::RZGate => Python::with_gil(|py| -> Option<CircuitData> {
let theta = &params[0];
Some(
Expand Down
35 changes: 34 additions & 1 deletion qiskit/circuit/commutation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from qiskit.circuit.operation import Operation
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES
from qiskit.quantum_info.operators import Operator
from qiskit._accelerate.circuit import StandardGate

_skipped_op_names = {"measure", "reset", "delay", "initialize"}
_no_cache_op_names = {"annotated"}
Expand Down Expand Up @@ -57,6 +58,23 @@ def __init__(self, standard_gate_commutations: dict = None, cache_max_entries: i
self._cache_miss = 0
self._cache_hit = 0

def commute_nodes(
self,
op1,
op2,
max_num_qubits: int = 3,
) -> bool:
"""Checks if two DAGOpNodes commute."""
qargs1 = op1.qargs
cargs1 = op2.cargs
if not isinstance(op1._raw_op, StandardGate):
op1 = op1.op
qargs2 = op2.qargs
cargs2 = op2.cargs
if not isinstance(op2._raw_op, StandardGate):
op2 = op2.op
return self.commute(op1, qargs1, cargs1, op2, qargs2, cargs2, max_num_qubits)

def commute(
self,
op1: Operation,
Expand Down Expand Up @@ -255,9 +273,15 @@ def is_commutation_skipped(op, qargs, max_num_qubits):
if getattr(op, "is_parameterized", False) and op.is_parameterized():
return True

from qiskit.dagcircuit.dagnode import DAGOpNode

# we can proceed if op has defined: to_operator, to_matrix and __array__, or if its definition can be
# recursively resolved by operations that have a matrix. We check this by constructing an Operator.
if (hasattr(op, "to_matrix") and hasattr(op, "__array__")) or hasattr(op, "to_operator"):
if (
isinstance(op, DAGOpNode)
or (hasattr(op, "to_matrix") and hasattr(op, "__array__"))
or hasattr(op, "to_operator")
):
return False

return False
Expand Down Expand Up @@ -409,6 +433,15 @@ def _commute_matmul(
first_qarg = tuple(qarg[q] for q in first_qargs)
second_qarg = tuple(qarg[q] for q in second_qargs)

from qiskit.dagcircuit.dagnode import DAGOpNode

# If we have a DAGOpNode here we've received a StandardGate definition from
# rust and we can manually pull the matrix to use for the Operators
if isinstance(first_ops, DAGOpNode):
first_ops = first_ops.matrix
if isinstance(second_op, DAGOpNode):
second_op = second_op.matrix

# try to generate an Operator out of op, if this succeeds we can determine commutativity, otherwise
# return false
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,7 @@ def run(self, dag):
does_commute = (
isinstance(current_gate, DAGOpNode)
and isinstance(prev_gate, DAGOpNode)
and self.comm_checker.commute(
current_gate.op,
current_gate.qargs,
current_gate.cargs,
prev_gate.op,
prev_gate.qargs,
prev_gate.cargs,
)
and self.comm_checker.commute_nodes(current_gate, prev_gate)
)
if not does_commute:
break
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from qiskit.circuit.library.standard_gates.rx import RXGate
from qiskit.circuit.library.standard_gates.p import PhaseGate
from qiskit.circuit.library.standard_gates.rz import RZGate
from qiskit.circuit import ControlFlowOp
from qiskit.circuit.controlflow import CONTROL_FLOW_OP_NAMES


_CUTOFF_PRECISION = 1e-5
Expand Down Expand Up @@ -138,14 +138,14 @@ def run(self, dag):
total_phase = 0.0
for current_node in run:
if (
getattr(current_node.op, "condition", None) is not None
current_node.condition is not None
or len(current_node.qargs) != 1
or current_node.qargs[0] != run_qarg
):
raise RuntimeError("internal error")

if current_node.name in ["p", "u1", "rz", "rx"]:
current_angle = float(current_node.op.params[0])
current_angle = float(current_node.params[0])
elif current_node.name in ["z", "x"]:
current_angle = np.pi
elif current_node.name == "t":
Expand All @@ -159,8 +159,8 @@ def run(self, dag):

# Compose gates
total_angle = current_angle + total_angle
if current_node.op.definition:
total_phase += current_node.op.definition.global_phase
if current_node.definition:
total_phase += current_node.definition.global_phase

# Replace the data of the first node in the run
if cancel_set_key[0] == "z_rotation":
Expand Down Expand Up @@ -200,7 +200,9 @@ def _handle_control_flow_ops(self, dag):
"""

pass_manager = PassManager([CommutationAnalysis(), self])
for node in dag.op_nodes(ControlFlowOp):
for node in dag.op_nodes():
if node.name not in CONTROL_FLOW_OP_NAMES:
continue
mapped_blocks = []
for block in node.op.blocks:
new_circ = pass_manager.run(block)
Expand Down

0 comments on commit 9571ea1

Please sign in to comment.