Skip to content

Commit

Permalink
simplify code a little
Browse files Browse the repository at this point in the history
  • Loading branch information
Cryoris committed Jul 2, 2024
1 parent e4b95f8 commit 46a37aa
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 26 deletions.
5 changes: 3 additions & 2 deletions crates/accelerate/src/synthesis/permutation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ pub fn _synth_permutation_basic(py: Python, pattern: PyArrayLike1<i64>) -> PyRes
#[pyfunction]
#[pyo3(signature = (pattern))]
fn _synth_permutation_acg(py: Python, pattern: PyArrayLike1<i64>) -> PyResult<CircuitData> {
let view = pattern.as_array();
let inverted = utils::invert(&pattern.as_array());
let view = inverted.view();
let num_qubits = view.len();
let cycles = utils::pattern_to_cycles(&view, &true);
let cycles = utils::pattern_to_cycles(&view);
let swaps = utils::decompose_cycles(&cycles);

CircuitData::from_standard_gates(
Expand Down
50 changes: 26 additions & 24 deletions crates/accelerate/src/synthesis/permutation/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::vec::Vec;

use qiskit_circuit::slice::{PySequenceIndex, PySequenceIndexError, SequenceIndex};

pub fn validate_permutation(pattern: &ArrayView1<i64>) -> PyResult<()> {
let n = pattern.len();
let mut seen: Vec<bool> = vec![false; n];
Expand Down Expand Up @@ -89,29 +91,22 @@ pub fn get_ordered_swap(pattern: &ArrayView1<i64>) -> Vec<(usize, usize)> {
/// example: let a pattern be [1, 2, 3, 0, 4, 6, 5], then it contains the two
/// cycles [1, 2, 3, 0] and [6, 5]. The index [4] does not perform a permutation and does
/// therefore not create a cycle.
pub fn pattern_to_cycles(pattern: &ArrayView1<i64>, invert_order: &bool) -> Vec<Vec<usize>> {
pub fn pattern_to_cycles(pattern: &ArrayView1<usize>) -> Vec<Vec<usize>> {
// vector keeping track of which elements in the permutation pattern have been visited
let mut explored: Vec<bool> = vec![false; pattern.len()];

// vector to store the cycles
let mut cycles: Vec<Vec<usize>> = Vec::new();

// cast the input pattern in terms of integers to usize, such that it can be used as index
// also invert the bit ordering if ``invert_order`` is true
let permutation: Array1<usize> = if *invert_order {
invert(pattern) // implies cast to usize
} else {
pattern.mapv(|x| x as usize)
};

for mut i in permutation.clone() {
for pos in pattern {
let mut cycle: Vec<usize> = Vec::new();

// follow the cycle until we reached an entry we saw before
let mut i = *pos;
while !explored[i] {
cycle.push(i);
explored[i] = true;
i = permutation[i];
i = pattern[i];
}
// cycles must have more than 1 element
if cycle.len() > 1 {
Expand All @@ -122,6 +117,14 @@ pub fn pattern_to_cycles(pattern: &ArrayView1<i64>, invert_order: &bool) -> Vec<
cycles
}

/// Periodic (or Python-like) access to a vector.
/// Util used below in ``decompose_cycles``.
#[inline]
fn pget(vec: &Vec<usize>, index: isize) -> Result<usize, PySequenceIndexError> {
let SequenceIndex::Int(wrapped) = PySequenceIndex::Int(index).with_len(vec.len())? else {unreachable!()};
Ok(vec[wrapped])
}

/// Given a disjoint cycle decomposition of a permutation pattern (see the function
/// ``pattern_to_cycles``), decomposes every cycle into a series of SWAPs to implement it.
/// In combination with ``pattern_to_cycle``, this function allows to implement a
Expand All @@ -130,20 +133,19 @@ pub fn decompose_cycles(cycles: &Vec<Vec<usize>>) -> Vec<(usize, usize)> {
let mut swaps: Vec<(usize, usize)> = Vec::new();

for cycle in cycles {
let length = cycle.len();

if length > 2 {
// handle first element separately, which accesses the last element
swaps.push((cycle[length - 1], cycle[length - 3]));
for i in 1..(length - 1) / 2 {
swaps.push((cycle[i - 1], cycle[length - 3 - i]));
}
}
let length = cycle.len() as isize;

// no size check needed, cycles always have at least 2 elements
swaps.push((cycle[length - 1], cycle[length - 2]));
for i in 1..length / 2 {
swaps.push((cycle[i - 1], cycle[length - 2 - i]));
for idx in 0..(length - 1) / 2 {
swaps.push((
pget(cycle, idx - 1).unwrap(),
pget(cycle, length - 3 - idx).unwrap(),
));
}
for idx in 0..length / 2 {
swaps.push((
pget(cycle, idx - 1).unwrap(),
pget(cycle, length - 2 - idx).unwrap(),
));
}
}

Expand Down

0 comments on commit 46a37aa

Please sign in to comment.