Skip to content

Commit

Permalink
Ray's review comments
Browse files Browse the repository at this point in the history
Co-authored-by: Raynel Sanchez <[email protected]>
  • Loading branch information
Cryoris and raynelfss committed Sep 2, 2024
1 parent 71ac2fd commit 12d0a05
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 55 deletions.
77 changes: 42 additions & 35 deletions crates/accelerate/src/commutation_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use numpy::PyReadonlyArray2;
use pyo3::exceptions::PyRuntimeError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyDict, PySequence, PyTuple};
use pyo3::types::{IntoPyDict, PyBool, PyDict, PySequence, PyTuple};

use qiskit_circuit::bit_data::BitData;
use qiskit_circuit::circuit_instruction::{ExtraInstructionAttributes, OperationFromPython};
Expand Down Expand Up @@ -55,9 +55,9 @@ where
{
let mut bitdata: BitData<T> = BitData::new(py, "bits".to_string());

bits1.iter().chain(bits2.iter()).for_each(|bit| {
bitdata.add(py, &bit, false).unwrap();
});
for bit in bits1.iter().chain(bits2.iter()) {
bitdata.add(py, &bit, false)?;
}

Ok((
bitdata.map_bits(bits1)?.collect(),
Expand Down Expand Up @@ -191,7 +191,7 @@ impl CommutationChecker {
out_dict.set_item("current_cache_entries", self.current_cache_entries)?;
let cache_dict = PyDict::new_bound(py);
for (key, value) in &self.cache {
cache_dict.set_item(key, commutation_entry_to_pydict(py, value))?;
cache_dict.set_item(key, commutation_entry_to_pydict(py, value)?)?;
}
out_dict.set_item("cache", cache_dict)?;
out_dict.set_item("library", self.library.library.to_object(py))?;
Expand Down Expand Up @@ -397,7 +397,7 @@ impl CommutationChecker {
}

let first_qarg: Vec<Qubit> = Vec::from_iter((0..first_qargs.len() as u32).map(Qubit));
let second_qarg: Vec<Qubit> = second_qargs.iter().map(|q| *qarg.get(q).unwrap()).collect();
let second_qarg: Vec<Qubit> = second_qargs.iter().map(|q| qarg[q]).collect();

if first_qarg.len() > second_qarg.len() {
return Err(QiskitError::new_err(
Expand Down Expand Up @@ -461,14 +461,24 @@ impl CommutationChecker {
));
};

let op12 = unitary_compose::compose(
let op12 = match unitary_compose::compose(
&first_mat.view(),
&second_mat.view(),
&second_qarg,
false,
);
let op21 =
unitary_compose::compose(&first_mat.view(), &second_mat.view(), &second_qarg, true);
) {
Ok(matrix) => matrix,
Err(e) => return Err(PyRuntimeError::new_err(e)),
};
let op21 = match unitary_compose::compose(
&first_mat.view(),
&second_mat.view(),
&second_qarg,
true,
) {
Ok(matrix) => matrix,
Err(e) => return Err(PyRuntimeError::new_err(e)),
};
Ok(abs_diff_eq!(op12, op21, epsilon = 1e-8))
}
}
Expand Down Expand Up @@ -609,8 +619,8 @@ impl CommutationLibrary {
match py_any {
Some(pyob) => CommutationLibrary {
library: pyob
.extract::<Option<HashMap<(String, String), CommutationLibraryEntry>>>()
.unwrap(),
.extract::<HashMap<(String, String), CommutationLibraryEntry>>()
.ok(),
},
None => CommutationLibrary {
library: Some(HashMap::new()),
Expand Down Expand Up @@ -646,20 +656,17 @@ impl ToPyObject for CommutationLibraryEntry {
fn to_object(&self, py: Python) -> PyObject {
match self {
CommutationLibraryEntry::Commutes(b) => b.into_py(py),
CommutationLibraryEntry::QubitMapping(qm) => {
let out_dict = PyDict::new_bound(py);

qm.iter().for_each(|(k, v)| {
out_dict
.set_item(
PyTuple::new_bound(py, k.iter().map(|q| q.map(|t| t.0))),
PyBool::new_bound(py, *v),
)
.ok()
.unwrap()
});
out_dict.unbind().into_any()
}
CommutationLibraryEntry::QubitMapping(qm) => qm
.iter()
.map(|(k, v)| {
(
PyTuple::new_bound(py, k.iter().map(|q| q.map(|t| t.0))),
PyBool::new_bound(py, *v),
)
})
.into_py_dict_bound(py)
.unbind()
.into(),
}
}
}
Expand All @@ -671,20 +678,18 @@ type CacheKey = (

type CommutationCacheEntry = HashMap<CacheKey, bool>;

fn commutation_entry_to_pydict(py: Python, entry: &CommutationCacheEntry) -> Py<PyDict> {
fn commutation_entry_to_pydict(py: Python, entry: &CommutationCacheEntry) -> PyResult<Py<PyDict>> {
let out_dict = PyDict::new_bound(py);
for (k, v) in entry.iter() {
let qubits = PyTuple::new_bound(py, k.0.iter().map(|q| q.map(|t| t.0)));
let params0 = PyTuple::new_bound(py, k.1 .0.iter().map(|pk| pk.0));
let params1 = PyTuple::new_bound(py, k.1 .1.iter().map(|pk| pk.0));
out_dict
.set_item(
PyTuple::new_bound(py, [qubits, PyTuple::new_bound(py, [params0, params1])]),
PyBool::new_bound(py, *v),
)
.expect("Failed to construct commutation cache for serialization");
out_dict.set_item(
PyTuple::new_bound(py, [qubits, PyTuple::new_bound(py, [params0, params1])]),
PyBool::new_bound(py, *v),
)?;
}
out_dict.unbind()
Ok(out_dict.unbind())
}

fn commutation_cache_entry_from_pydict(dict: &Bound<PyDict>) -> PyResult<CommutationCacheEntry> {
Expand Down Expand Up @@ -756,7 +761,9 @@ fn hashable_params(params: &[Param]) -> PyResult<SmallVec<[ParameterKey; 3]>> {
Ok(ParameterKey(*x))
}
} else {
panic!("Unable to hash a non-float instruction parameter.")
Err(QiskitError::new_err(
"Unable to hash a non-float instruction parameter.",
))
}
})
.collect()
Expand Down
41 changes: 21 additions & 20 deletions crates/accelerate/src/unitary_compose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ pub fn compose(
overall_unitary: &ArrayView2<Complex64>,
qubits: &[Qubit],
front: bool,
) -> Array2<Complex64> {
) -> Result<Array2<Complex64>, &'static str> {
let gate_qubits = gate_unitary.shape()[0].ilog2() as usize;

// Full composition of operators
if qubits.is_empty() {
if front {
return gate_unitary.dot(overall_unitary);
return Ok(gate_unitary.dot(overall_unitary));
} else {
return overall_unitary.dot(gate_unitary);
return Ok(overall_unitary.dot(gate_unitary));
}
}
// Compose with other on subsystem
Expand All @@ -61,14 +61,14 @@ pub fn compose(
.collect::<Vec<usize>>();
let num_rows = usize::pow(2, num_indices as u32);

let res = _einsum_matmul(&tensor, &mat, &indices, shift, right_mul)
let res = _einsum_matmul(&tensor, &mat, &indices, shift, right_mul)?
.as_standard_layout()
.into_shape((num_rows, num_rows))
.unwrap()
.into_dimensionality::<ndarray::Ix2>()
.unwrap()
.to_owned();
res
Ok(res)
}

// Reshape an input matrix to (2, 2, ..., 2) depending on its dimensionality
Expand All @@ -86,11 +86,11 @@ fn _einsum_matmul(
indices: &[usize],
shift: usize,
right_mul: bool,
) -> Array<Complex64, IxDyn> {
) -> Result<Array<Complex64, IxDyn>, &'static str> {
let rank = tensor.ndim();
let rank_mat = mat.ndim();
if rank_mat % 2 != 0 {
panic!("Contracted matrix must have an even number of indices.");
return Err("Contracted matrix must have an even number of indices.");
}
// Get einsum indices for tensor
let mut indices_tensor = (0..rank).collect::<Vec<usize>>();
Expand All @@ -110,16 +110,16 @@ fn _einsum_matmul(
[mat_free, mat_contract].concat()
};

let tensor_einsum = String::from_utf8(indices_tensor.iter().map(|c| LOWERCASE[*c]).collect())
.expect("Failed building tensor string.");
let mat_einsum = String::from_utf8(indices_mat.iter().map(|c| LOWERCASE[*c]).collect())
.expect("Failed building matrix string.");
let tensor_einsum = unsafe {
String::from_utf8_unchecked(indices_tensor.iter().map(|c| LOWERCASE[*c]).collect())
};
let mat_einsum =
unsafe { String::from_utf8_unchecked(indices_mat.iter().map(|c| LOWERCASE[*c]).collect()) };

einsum(
format!("{},{}", tensor_einsum, mat_einsum).as_str(),
&[tensor, mat],
)
.unwrap()
}

fn _einsum_matmul_helper(qubits: &[u32], num_qubits: usize) -> [String; 4] {
Expand All @@ -132,19 +132,20 @@ fn _einsum_matmul_helper(qubits: &[u32], num_qubits: usize) -> [String; 4] {
mat_l.push(LOWERCASE[25 - pos]);
tens_out[num_qubits - 1 - *idx as usize] = LOWERCASE[25 - pos];
});
[
String::from_utf8(mat_l).expect("Failed building string."),
String::from_utf8(mat_r).expect("Failed building string."),
String::from_utf8(tens_in).expect("Failed building string."),
String::from_utf8(tens_out).expect("Failed building string."),
]
unsafe {
[
String::from_utf8_unchecked(mat_l),
String::from_utf8_unchecked(mat_r),
String::from_utf8_unchecked(tens_in),
String::from_utf8_unchecked(tens_out),
]
}
}

fn _einsum_matmul_index(qubits: &[u32], num_qubits: usize) -> String {
assert!(num_qubits > 26, "Can't compute unitary of > 26 qubits");

let tens_r =
String::from_utf8(_UPPERCASE[..num_qubits].to_vec()).expect("Failed building string.");
let tens_r = unsafe { String::from_utf8_unchecked(_UPPERCASE[..num_qubits].to_vec()) };
let [mat_l, mat_r, tens_lin, tens_lout] = _einsum_matmul_helper(qubits, num_qubits);
format!(
"{}{}, {}{}->{}{}",
Expand Down

0 comments on commit 12d0a05

Please sign in to comment.