Skip to content

Commit

Permalink
Handle atol/rtol, more error propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
Cryoris committed Sep 2, 2024
1 parent 12d0a05 commit 9235004
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 45 deletions.
68 changes: 38 additions & 30 deletions crates/accelerate/src/commutation_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use approx::abs_diff_eq;
use hashbrown::{HashMap, HashSet};
use ndarray::linalg::kron;
use ndarray::Array2;
Expand Down Expand Up @@ -404,34 +403,38 @@ impl CommutationChecker {
"first instructions must have at most as many qubits as the second instruction",
));
};
let first_mat = match get_matrix(py, first_op, first_params) {
let first_mat = match get_matrix(py, first_op, first_params)? {
Some(matrix) => matrix,
None => return Ok(false),
};

let second_mat = match get_matrix(py, second_op, second_params) {
let second_mat = match get_matrix(py, second_op, second_params)? {
Some(matrix) => matrix,
None => return Ok(false),
};

let tol = 1e-8;
let rtol = 1e-5;
let atol = 1e-8;
if first_qarg == second_qarg {
match first_qarg.len() {
1 => Ok(unitary_compose::commute_1q(
&first_mat.view(),
&second_mat.view(),
tol,
rtol,
atol,
)),
2 => Ok(unitary_compose::commute_2q(
&first_mat.view(),
&second_mat.view(),
&[Qubit(0), Qubit(1)],
tol,
rtol,
atol,
)),
_ => Ok(abs_diff_eq!(
second_mat.dot(&first_mat),
first_mat.dot(&second_mat),
epsilon = 1e-8
_ => Ok(unitary_compose::allclose(
&second_mat.dot(&first_mat).view(),
&first_mat.dot(&second_mat).view(),
rtol,
atol,
)),
}
} else {
Expand All @@ -457,7 +460,8 @@ impl CommutationChecker {
&first_mat.view(),
&second_mat.view(),
&second_qarg,
tol,
rtol,
atol,
));
};

Expand All @@ -479,7 +483,12 @@ impl CommutationChecker {
Ok(matrix) => matrix,
Err(e) => return Err(PyRuntimeError::new_err(e)),
};
Ok(abs_diff_eq!(op12, op21, epsilon = 1e-8))
Ok(unitary_compose::allclose(
&op12.view(),
&op21.view(),
rtol,
atol,
))
}
}

Expand Down Expand Up @@ -531,30 +540,29 @@ fn commutation_precheck(
None
}

fn get_matrix(py: Python, operation: &OperationRef, params: &[Param]) -> Option<Array2<Complex64>> {
fn get_matrix(
py: Python,
operation: &OperationRef,
params: &[Param],
) -> PyResult<Option<Array2<Complex64>>> {
match operation.matrix(params) {
Some(matrix) => Some(matrix),
Some(matrix) => Ok(Some(matrix)),
None => match operation {
PyGateType(gate) => matrix_via_operator(py, &gate.gate),
PyOperationType(op) => matrix_via_operator(py, &op.operation),
_ => None,
PyGateType(gate) => Ok(Some(matrix_via_operator(py, &gate.gate)?)),
PyOperationType(op) => Ok(Some(matrix_via_operator(py, &op.operation)?)),
_ => Ok(None),
},
}
}

fn matrix_via_operator(py: Python, py_obj: &PyObject) -> Option<Array2<Complex64>> {
Some(
QI_OPERATOR
.get_bound(py)
.call1((py_obj,))
.ok()?
.getattr(intern!(py, "data"))
.ok()?
.extract::<PyReadonlyArray2<Complex64>>()
.ok()?
.as_array()
.to_owned(),
)
fn matrix_via_operator(py: Python, py_obj: &PyObject) -> PyResult<Array2<Complex64>> {
Ok(QI_OPERATOR
.get_bound(py)
.call1((py_obj,))?
.getattr(intern!(py, "data"))?
.extract::<PyReadonlyArray2<Complex64>>()?
.as_array()
.to_owned())
}

fn is_commutation_skipped<T>(op: &T, params: &[Param]) -> bool
Expand Down
72 changes: 57 additions & 15 deletions crates/accelerate/src/unitary_compose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,33 +153,53 @@ fn _einsum_matmul_index(qubits: &[u32], num_qubits: usize) -> String {
)
}

pub fn commute_1q(left: &ArrayView2<Complex64>, right: &ArrayView2<Complex64>, tol: f64) -> bool {
let values: [Complex64; 4] = [
left[[0, 1]] * right[[1, 0]] - right[[0, 1]] * left[[1, 0]], // top left
(left[[0, 0]] - left[[1, 1]]) * right[[0, 1]]
+ left[[0, 1]] * (right[[1, 1]] - right[[0, 0]]), // top right
left[[1, 0]] * (right[[0, 0]] - right[[1, 1]])
+ (left[[1, 1]] - left[[0, 0]]) * right[[1, 0]], // bottom left
left[[1, 0]] * right[[0, 1]] - right[[1, 0]] * left[[0, 1]], // bottom right
];
!values.iter().any(|value| value.abs() > tol)
pub fn commute_1q(
left: &ArrayView2<Complex64>,
right: &ArrayView2<Complex64>,
rtol: f64,
atol: f64,
) -> bool {
// This could allow for explicit hardcoded formulas, using less FLOPS, if we only
// consider an absolute tolerance. But for backward compatibility we now implement the full
// formula including relative tolerance handling.
for i in 0..2usize {
for j in 0..2usize {
let mut ab = Complex64::zero();
let mut ba = Complex64::zero();
for k in 0..2usize {
ab += left[[i, k]] * right[[k, j]];
ba += right[[i, k]] * left[[k, j]];
}
let sum = ab - ba;
if sum.abs() > atol + ba.abs() * rtol {
return false;
}
}
}
true
}

pub fn commute_2q(
left: &ArrayView2<Complex64>,
right: &ArrayView2<Complex64>,
qargs: &[Qubit],
tol: f64,
rtol: f64,
atol: f64,
) -> bool {
let rev = qargs[0].0 == 1;
for i in 0..4usize {
for j in 0..4usize {
let mut sum = Complex64::zero();
// We compute AB and BA separately, to enable checking the relative difference
// (AB - BA)_ij > atol + rtol * BA_ij. This is due to backward compatibility and could
// maybe be changed in the future to save one complex number allocation.
let mut ab = Complex64::zero();
let mut ba = Complex64::zero();
for k in 0..4usize {
sum += left[[_ind(i, rev), _ind(k, rev)]] * right[[k, j]]
- right[[i, k]] * left[[_ind(k, rev), _ind(j, rev)]];
ab += left[[_ind(i, rev), _ind(k, rev)]] * right[[k, j]];
ba += right[[i, k]] * left[[_ind(k, rev), _ind(j, rev)]];
}
if sum.abs() > tol {
let sum = ab - ba;
if sum.abs() > atol + ba.abs() * rtol {
return false;
}
}
Expand All @@ -196,3 +216,25 @@ fn _ind(i: usize, reversed: bool) -> usize {
i
}
}

/// For equally sized matrices, ``left`` and ``right``, check whether all entries are close
/// by the criterion
///
/// |left_ij - right_ij| <= atol + rtol * right_ij
///
/// This is analogous to NumPy's ``allclose`` function.
pub fn allclose(
left: &ArrayView2<Complex64>,
right: &ArrayView2<Complex64>,
rtol: f64,
atol: f64,
) -> bool {
for i in 0..left.nrows() {
for j in 0..left.ncols() {
if (left[(i, j)] - right[(i, j)]).abs() > atol + rtol * right[(i, j)].abs() {
return false;
}
}
}
true
}

0 comments on commit 9235004

Please sign in to comment.