Skip to content

Commit

Permalink
feat: Replace Circuit::num_gates with num_operations (#384)
Browse files Browse the repository at this point in the history
Closes #105. Closes #108.

`num_gates` used to count every node in the top-level region, giving
unexpected results on results with constants, control flow, or anything
other than simple gates.

`num_operations` now only counts `CustomOp`s, traversing containers as
needed.

I also improved the circuit unit tests, to include circuits in modules
and circuits in `FuncDefn`s (instead of `DFG`s).

Some notes:
- Part of the tests testing parametric operations is commented out until
we solve CQCL/hugr#1166.
- Although the test circuits have function names, `Circuit::name`
returns `None`. I'll address that in another PR.
  • Loading branch information
aborgna-q authored Jun 6, 2024
1 parent 70d18ae commit 093e650
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 25 deletions.
97 changes: 80 additions & 17 deletions tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,31 @@ impl<T: HugrView> Circuit<T> {
.expect("Circuit has no I/O nodes")
}

/// The number of quantum gates in the circuit.
/// The number of operations in the circuit.
///
/// This includes [`Tk2Op`]s, pytket ops, and any other custom operations.
///
/// Nested circuits are traversed to count their operations.
///
/// [`Tk2Op`]: crate::Tk2Op
#[inline]
pub fn num_gates(&self) -> usize
pub fn num_operations(&self) -> usize
where
Self: Sized,
{
// TODO: Discern quantum gates in the commands iterator.
self.hugr().children(self.parent).count() - 2
let mut count = 0;
let mut roots = vec![self.parent];
while let Some(node) = roots.pop() {
for child in self.hugr().children(node) {
let optype = self.hugr().get_optype(child);
if optype.is_custom_op() {
count += 1;
} else if OpTag::DataflowParent.is_superset(optype.tag()) {
roots.push(child);
}
}
}
count
}

/// Count the number of qubits in the circuit.
Expand Down Expand Up @@ -471,6 +488,7 @@ fn update_signature(
#[cfg(test)]
mod tests {
use cool_asserts::assert_matches;
use rstest::{fixture, rstest};

use hugr::types::FunctionType;
use hugr::{
Expand All @@ -479,38 +497,83 @@ mod tests {
};

use super::*;
use crate::utils::build_module_with_circuit;
use crate::{json::load_tk1_json_str, utils::build_simple_circuit, Tk2Op};

fn test_circuit() -> Circuit {
#[fixture]
fn tk1_circuit() -> Circuit {
load_tk1_json_str(
r#"{ "phase": "0",
"bits": [["c", [0]]],
"qubits": [["q", [0]], ["q", [1]]],
"commands": [
{"args": [["q", [0]]], "op": {"type": "H"}},
{"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}},
{"args": [["q", [1]]], "op": {"type": "X"}}
{"args": [["q", [1]]], "op": {"params": ["0.25"], "type": "Rz"}}
],
"implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]]
}"#,
)
.unwrap()
}

#[test]
fn test_circuit_properties() {
let circ = test_circuit();
/// 2-qubit circuit with a Hadamard, a CNOT, and a X gate.
#[fixture]
fn simple_circuit() -> Circuit {
build_simple_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::X, [1])?;

assert_eq!(circ.name(), None);
assert_eq!(circ.circuit_signature().body().input_count(), 3);
assert_eq!(circ.circuit_signature().body().output_count(), 3);
assert_eq!(circ.qubit_count(), 2);
assert_eq!(circ.num_gates(), 3);
// TODO: Replace the `X` with the following once Hugr adds `CircuitBuilder::add_constant`.
// See https://github.com/CQCL/hugr/pull/1168

//let angle = circ.add_constant(ConstF64::new(0.5));
//circ.append_and_consume(
// Tk2Op::RzF64,
// [CircuitUnit::Linear(1), CircuitUnit::Wire(angle)],
//)?;
Ok(())
})
.unwrap()
}

/// 2-qubit circuit with a Hadamard, a CNOT, and a X gate,
/// defined inside a module.
#[fixture]
fn simple_module() -> Circuit {
build_module_with_circuit(2, |circ| {
circ.append(Tk2Op::H, [0])?;
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::X, [1])?;
Ok(())
})
.unwrap()
}

#[rstest]
#[case::simple(simple_circuit(), 2, 0, None)]
#[case::module(simple_module(), 2, 0, None)]
#[case::tk1(tk1_circuit(), 2, 1, None)]
fn test_circuit_properties(
#[case] circ: Circuit,
#[case] qubits: usize,
#[case] bits: usize,
#[case] name: Option<&str>,
) {
assert_eq!(circ.name(), name);
assert_eq!(circ.circuit_signature().body().input_count(), qubits + bits);
assert_eq!(
circ.circuit_signature().body().output_count(),
qubits + bits
);
assert_eq!(circ.qubit_count(), qubits);
assert_eq!(circ.num_operations(), 3);

assert_eq!(circ.units().count(), 3);
assert_eq!(circ.units().count(), qubits + bits);
assert_eq!(circ.nonlinear_units().count(), 0);
assert_eq!(circ.linear_units().count(), 3);
assert_eq!(circ.qubits().count(), 2);
assert_eq!(circ.linear_units().count(), qubits + bits);
assert_eq!(circ.qubits().count(), qubits);
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion tket2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
//! let mut circ: Circuit = tket2::json::load_tk1_json_file("../test_files/barenco_tof_5.json").unwrap();
//!
//! assert_eq!(circ.qubit_count(), 9);
//! assert_eq!(circ.num_gates(), 170);
//! assert_eq!(circ.num_operations(), 170);
//!
//! // Traverse the circuit and print the gates.
//! for command in circ.commands() {
Expand Down
5 changes: 4 additions & 1 deletion tket2/src/optimiser/badger/eq_circ_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ impl EqCircClass {
};

// Find the index for the smallest circuit
let min_index = circs.iter().position_min_by_key(|c| c.num_gates()).unwrap();
let min_index = circs
.iter()
.position_min_by_key(|c| c.num_operations())
.unwrap();
let representative = circs.swap_remove(min_index);
Ok(Self::new(representative, circs))
}
Expand Down
2 changes: 1 addition & 1 deletion tket2/src/portmatching/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl CircuitPattern {
/// Construct a pattern from a circuit.
pub fn try_from_circuit(circuit: &Circuit) -> Result<Self, InvalidPattern> {
let hugr = circuit.hugr();
if circuit.num_gates() == 0 {
if circuit.num_operations() == 0 {
return Err(InvalidPattern::EmptyCircuit);
}
let mut pattern = Pattern::new();
Expand Down
2 changes: 1 addition & 1 deletion tket2/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl CircuitRewrite {
/// The difference between the new number of nodes minus the old. A positive
/// number is an increase in node count, a negative number is a decrease.
pub fn node_count_delta(&self) -> isize {
let new_count = self.replacement().num_gates() as isize;
let new_count = self.replacement().num_operations() as isize;
let old_count = self.subcircuit().node_count() as isize;
new_count - old_count
}
Expand Down
8 changes: 4 additions & 4 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ impl RewriteStrategy for GreedyRewriteStrategy {
}

fn circuit_cost(&self, circ: &Circuit<impl HugrView>) -> Self::Cost {
circ.num_gates()
circ.num_operations()
}

fn op_cost(&self, _op: &OpType) -> Self::Cost {
Expand Down Expand Up @@ -488,7 +488,7 @@ mod tests {
let strategy = GreedyRewriteStrategy;
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
assert_eq!(rewritten.len(), 1);
assert_eq!(rewritten[0].circ.num_gates(), 5);
assert_eq!(rewritten[0].circ.num_operations(), 5);

if REWRITE_TRACING_ENABLED {
assert_eq!(rewritten[0].circ.rewrite_trace().unwrap().len(), 3);
Expand All @@ -511,7 +511,7 @@ mod tests {
let strategy = LexicographicCostFunction::default_cx();
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_gates()).collect();
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_operations()).collect();
assert_eq!(circ_lens, exp_circ_lens);

if REWRITE_TRACING_ENABLED {
Expand Down Expand Up @@ -547,7 +547,7 @@ mod tests {
let strategy = GammaStrategyCost::exhaustive_cx_with_gamma(10.);
let rewritten = strategy.apply_rewrites(rws, &circ);
let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]);
let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_gates()).collect();
let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_operations()).collect();
assert_eq!(circ_lens, exp_circ_lens);
}

Expand Down

0 comments on commit 093e650

Please sign in to comment.