Skip to content

Commit

Permalink
feat: Hugrs as circuits, no more lifetimes
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Sep 7, 2023
1 parent a62ce43 commit 9fc92da
Show file tree
Hide file tree
Showing 14 changed files with 77 additions and 128 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ members = ["pyrs", "compile-matcher"]

[workspace.dependencies]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "5a97a635" }
portgraph = { version = "0.8", features = ["serde"] }
quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "e23323d" }
portgraph = { version = "0.9", features = ["serde"] }
pyo3 = { version = "0.19" }
itertools = { version = "0.11.0" }
4 changes: 2 additions & 2 deletions benches/benchmarks/hash.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use criterion::{black_box, criterion_group, AxisScale, BenchmarkId, Criterion, PlotConfiguration};
use hugr::hugr::views::SiblingGraph;
use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::HugrView;
use tket2::circuit::{CircuitHash, HierarchyView};
use tket2::circuit::CircuitHash;

use super::generators::make_cnot_layers;

Expand Down
2 changes: 1 addition & 1 deletion compile-matcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::fs;
use std::path::Path;

use clap::Parser;
use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::hugr::views::SiblingGraph;
use hugr::ops::handle::DfgID;
use hugr::HugrView;
use itertools::Itertools;
Expand Down
64 changes: 26 additions & 38 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ use hugr::hugr::{CircuitUnit, NodeType};
use hugr::ops::OpTrait;
use hugr::HugrView;

pub use hugr::hugr::views::HierarchyView;
pub use hugr::ops::OpType;
use hugr::types::TypeBound;
pub use hugr::types::{EdgeKind, Signature, Type, TypeRow};
pub use hugr::{Node, Port, Wire};
use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers};

/// An object behaving like a quantum circuit.
//
Expand All @@ -32,9 +30,11 @@ use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers};
// - Vertical slice iterator
// - Gate count map
// - Depth
pub trait Circuit<'circ>: HugrView {
pub trait Circuit: HugrView {
/// An iterator over the commands in the circuit.
type Commands: Iterator<Item = Command>;
type Commands<'a>: Iterator<Item = Command>
where
Self: 'a;

/// An iterator over the commands applied to an unit.
type UnitCommands: Iterator<Item = Command>;
Expand Down Expand Up @@ -67,10 +67,10 @@ pub trait Circuit<'circ>: HugrView {
/// Returns all the commands in the circuit, in some topological order.
///
/// Ignores the Input and Output nodes.
fn commands(&'circ self) -> Self::Commands;
fn commands(&self) -> Self::Commands<'_>;

/// Returns all the commands applied to the given unit, in order.
fn unit_commands(&'circ self) -> Self::UnitCommands;
fn unit_commands(&self) -> Self::UnitCommands;

/// Returns the [`NodeType`] of a command.
fn command_nodetype(&self, command: &Command) -> &NodeType {
Expand All @@ -86,12 +86,11 @@ pub trait Circuit<'circ>: HugrView {
fn num_gates(&self) -> usize;
}

impl<'circ, T> Circuit<'circ> for T
impl<T> Circuit for T
where
T: 'circ + HierarchyView<'circ>,
for<'a> &'a T: GraphBase<NodeId = Node> + IntoNeighborsDirected + IntoNodeIdentifiers,
T: HugrView,
{
type Commands = CommandIterator<'circ, T>;
type Commands<'a> = CommandIterator<'a, T> where Self: 'a;
type UnitCommands = std::iter::Empty<Command>;

#[inline]
Expand Down Expand Up @@ -129,12 +128,12 @@ where
}
}

fn commands(&'circ self) -> Self::Commands {
fn commands(&self) -> Self::Commands<'_> {
// Traverse the circuit in topological order.
CommandIterator::new(self)
}

fn unit_commands(&'circ self) -> Self::UnitCommands {
fn unit_commands(&self) -> Self::UnitCommands {
// TODO Can we associate linear i/o with the corresponding unit without
// doing the full toposort?
unimplemented!()
Expand All @@ -158,35 +157,24 @@ where

#[cfg(test)]
mod tests {
use std::sync::OnceLock;

use hugr::{
hugr::views::{DescendantsGraph, HierarchyView},
ops::handle::DfgID,
Hugr, HugrView,
};
use hugr::Hugr;

use crate::{circuit::Circuit, json::load_tk1_json_str};

static CIRC: OnceLock<Hugr> = OnceLock::new();

fn test_circuit() -> DescendantsGraph<'static, DfgID> {
let hugr = CIRC.get_or_init(|| {
load_tk1_json_str(
r#"{
"phase": "0",
"bits": [],
"qubits": [["q", [0]], ["q", [1]]],
"commands": [
{"args": [["q", [0]]], "op": {"type": "H"}},
{"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}}
],
"implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]]
}"#,
)
.unwrap()
});
DescendantsGraph::new(hugr, hugr.root())
fn test_circuit() -> Hugr {
load_tk1_json_str(
r#"{
"phase": "0",
"bits": [],
"qubits": [["q", [0]], ["q", [1]]],
"commands": [
{"args": [["q", [0]]], "op": {"type": "H"}},
{"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}}
],
"implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]]
}"#,
)
.unwrap()
}

#[test]
Expand Down
17 changes: 4 additions & 13 deletions src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
use std::collections::HashMap;
use std::iter::FusedIterator;

use hugr::hugr::views::HierarchyView;
use hugr::ops::{OpTag, OpTrait};
use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers};

use super::Circuit;

Expand Down Expand Up @@ -59,8 +57,7 @@ pub struct CommandIterator<'circ, Circ> {

impl<'circ, Circ> CommandIterator<'circ, Circ>
where
Circ: HierarchyView<'circ>,
for<'a> &'a Circ: GraphBase<NodeId = Node> + IntoNeighborsDirected + IntoNodeIdentifiers,
Circ: Circuit,
{
/// Create a new iterator over the commands of a circuit.
pub(super) fn new(circ: &'circ Circ) -> Self {
Expand All @@ -77,7 +74,7 @@ where
})
.collect();

let nodes = petgraph::algo::toposort(circ, None).unwrap();
let nodes = petgraph::algo::toposort(&circ.as_petgraph(), None).unwrap();
Self {
circ,
nodes,
Expand Down Expand Up @@ -157,8 +154,7 @@ where

impl<'circ, Circ> Iterator for CommandIterator<'circ, Circ>
where
Circ: HierarchyView<'circ>,
for<'a> &'a Circ: GraphBase<NodeId = Node> + IntoNeighborsDirected + IntoNodeIdentifiers,
Circ: Circuit,
{
type Item = Command;

Expand All @@ -182,12 +178,7 @@ where
}
}

impl<'circ, Circ> FusedIterator for CommandIterator<'circ, Circ>
where
Circ: HierarchyView<'circ>,
for<'a> &'a Circ: GraphBase<NodeId = Node> + IntoNeighborsDirected + IntoNodeIdentifiers,
{
}
impl<'circ, Circ> FusedIterator for CommandIterator<'circ, Circ> where Circ: Circuit {}

#[cfg(test)]
mod test {
Expand Down
10 changes: 5 additions & 5 deletions src/circuit/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use core::panic;
use std::hash::{Hash, Hasher};

use fxhash::{FxHashMap, FxHasher64};
use hugr::hugr::views::HierarchyView;
use hugr::ops::{LeafOp, OpName, OpTag, OpTrait, OpType};
use hugr::types::TypeBound;
use hugr::{HugrView, Node, Port};
Expand All @@ -29,14 +28,15 @@ pub trait CircuitHash<'circ>: HugrView {

impl<'circ, T> CircuitHash<'circ> for T
where
T: HugrView + HierarchyView<'circ>,
for<'a> &'a T:
pg::GraphBase<NodeId = Node> + pg::IntoNeighborsDirected + pg::IntoNodeIdentifiers,
T: HugrView,
{
fn circuit_hash(&'circ self) -> u64 {
let mut hash_state = HashState::default();

for node in pg::Topo::new(self).iter(self).filter(|&n| n != self.root()) {
for node in pg::Topo::new(&self.as_petgraph())
.iter(&self.as_petgraph())
.filter(|&n| n != self.root())
{
let hash = hash_node(self, node, &mut hash_state);
hash_state.set_node(self, node, hash);
}
Expand Down
4 changes: 2 additions & 2 deletions src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub trait TKETDecode: Sized {
/// Convert the serialized circuit to a [`Hugr`].
fn decode(self) -> Result<Hugr, Self::DecodeError>;
/// Convert a [`Hugr`] to a new serialized circuit.
fn encode<'circ>(circuit: &'circ impl Circuit<'circ>) -> Result<Self, Self::EncodeError>;
fn encode(circuit: &impl Circuit) -> Result<Self, Self::EncodeError>;
}

impl TKETDecode for SerialCircuit {
Expand All @@ -60,7 +60,7 @@ impl TKETDecode for SerialCircuit {
Ok(decoder.finish())
}

fn encode<'circ>(circ: &'circ impl Circuit<'circ>) -> Result<Self, Self::EncodeError> {
fn encode(circ: &impl Circuit) -> Result<Self, Self::EncodeError> {
let mut encoder = JsonEncoder::new(circ);
for com in circ.commands() {
let optype = circ.command_optype(&com);
Expand Down
2 changes: 1 addition & 1 deletion src/json/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub(super) struct JsonEncoder {

impl JsonEncoder {
/// Create a new [`JsonEncoder`] from a [`Circuit`].
pub fn new<'circ>(circ: &impl Circuit<'circ>) -> Self {
pub fn new(circ: &impl Circuit) -> Self {
let name = circ.name().map(str::to_string);

// Compute the linear qubit and bit registers. Each one have independent
Expand Down
8 changes: 2 additions & 6 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,8 @@ pub(crate) mod test {

use std::sync::Arc;

use hugr::{
extension::OpDef,
hugr::views::{HierarchyView, SiblingGraph},
ops::handle::DfgID,
Hugr, HugrView,
};
use hugr::hugr::views::HierarchyView;
use hugr::{extension::OpDef, hugr::views::SiblingGraph, ops::handle::DfgID, Hugr, HugrView};
use rstest::{fixture, rstest};

use crate::{circuit::Circuit, ops::SimpleOpEnum, utils::build_simple_circuit};
Expand Down
62 changes: 23 additions & 39 deletions src/portmatching/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub struct PatternMatch<'a, C> {
pub(super) root: Node,
}

impl<'a, C: Circuit<'a> + Clone> PatternMatch<'a, C> {
impl<'a, C: Circuit + Clone> PatternMatch<'a, C> {
/// The matcher's pattern ID of the match.
pub fn pattern_id(&self) -> PatternID {
self.pattern
Expand Down Expand Up @@ -187,7 +187,7 @@ impl<'a, C: Circuit<'a> + Clone> PatternMatch<'a, C> {
}
}

impl<'a, C: Circuit<'a>> Debug for PatternMatch<'a, C> {
impl<'a, C: Circuit> Debug for PatternMatch<'a, C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PatternMatch")
.field("root", &self.root)
Expand Down Expand Up @@ -237,10 +237,7 @@ impl PatternMatcher {
}

/// Find all convex pattern matches in a circuit.
pub fn find_matches<'a, C: Circuit<'a> + Clone>(
&self,
circuit: &'a C,
) -> Vec<PatternMatch<'a, C>> {
pub fn find_matches<'a, C: Circuit + Clone>(&self, circuit: &'a C) -> Vec<PatternMatch<'a, C>> {
let mut checker = ConvexChecker::new(circuit);
circuit
.commands()
Expand All @@ -249,7 +246,7 @@ impl PatternMatcher {
}

/// Find all convex pattern matches in a circuit rooted at a given node.
fn find_rooted_matches<'a, C: Circuit<'a> + Clone>(
fn find_rooted_matches<'a, C: Circuit + Clone>(
&self,
circ: &'a C,
root: Node,
Expand Down Expand Up @@ -385,8 +382,8 @@ fn compatible_offsets((_, pout): &(Port, Port), (pin, _): &(Port, Port)) -> bool
}

/// Check if an edge `e` is valid in a portgraph `g` without weights.
pub(crate) fn validate_unweighted_edge<'circ>(
circ: &impl Circuit<'circ>,
pub(crate) fn validate_unweighted_edge(
circ: &impl Circuit,
) -> impl for<'a> Fn(Node, &'a PEdge) -> Option<Node> + '_ {
move |src, &(src_port, tgt_port)| {
let (next_node, _) = circ
Expand All @@ -397,8 +394,8 @@ pub(crate) fn validate_unweighted_edge<'circ>(
}

/// Check if a node `n` is valid in a weighted portgraph `g`.
pub(crate) fn validate_weighted_node<'circ>(
circ: &impl Circuit<'circ>,
pub(crate) fn validate_weighted_node(
circ: &impl Circuit,
) -> impl for<'a> Fn(Node, &PNode) -> bool + '_ {
move |v, prop| {
let v_weight = MatchOp::try_from(circ.get_optype(v).clone());
Expand All @@ -425,43 +422,30 @@ fn handle_match_error<T>(match_res: Result<T, InvalidPatternMatch>, root: Node)

#[cfg(test)]
mod tests {
use std::sync::OnceLock;

use hugr::hugr::views::{DescendantsGraph, HierarchyView};
use hugr::ops::handle::DfgID;
use hugr::{Hugr, HugrView};
use hugr::Hugr;
use itertools::Itertools;

use crate::utils::build_simple_circuit;
use crate::T2Op;

use super::{CircuitPattern, PatternMatcher};

static H_CX: OnceLock<Hugr> = OnceLock::new();
static CX_CX: OnceLock<Hugr> = OnceLock::new();

fn h_cx<'a>() -> DescendantsGraph<'a, DfgID> {
let circ = H_CX.get_or_init(|| {
build_simple_circuit(2, |circ| {
circ.append(T2Op::CX, [0, 1]).unwrap();
circ.append(T2Op::H, [0]).unwrap();
Ok(())
})
.unwrap()
});
DescendantsGraph::new(circ, circ.root())
fn h_cx() -> Hugr {
build_simple_circuit(2, |circ| {
circ.append(T2Op::CX, [0, 1]).unwrap();
circ.append(T2Op::H, [0]).unwrap();
Ok(())
})
.unwrap()
}

fn cx_xc<'a>() -> DescendantsGraph<'a, DfgID> {
let circ = CX_CX.get_or_init(|| {
build_simple_circuit(2, |circ| {
circ.append(T2Op::CX, [0, 1]).unwrap();
circ.append(T2Op::CX, [1, 0]).unwrap();
Ok(())
})
.unwrap()
});
DescendantsGraph::new(circ, circ.root())
fn cx_xc() -> Hugr {
build_simple_circuit(2, |circ| {
circ.append(T2Op::CX, [0, 1]).unwrap();
circ.append(T2Op::CX, [1, 0]).unwrap();
Ok(())
})
.unwrap()
}

#[test]
Expand Down
Loading

0 comments on commit 9fc92da

Please sign in to comment.