Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Track input linear units in Command #310

Merged
merged 5 commits into from
Apr 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 166 additions & 75 deletions tket2/src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use hugr::hugr::NodeType;
use hugr::ops::{OpTag, OpTrait};
use hugr::{IncomingPort, OutgoingPort};
use itertools::Either::{self, Left, Right};
use itertools::{EitherOrBoth, Itertools};
use petgraph::visit as pv;

use super::units::{filter, DefaultUnitLabeller, LinearUnit, UnitLabeller, Units};
Expand All @@ -25,11 +26,10 @@ pub struct Command<'circ, Circ> {
circ: &'circ Circ,
/// The operation node.
node: Node,
/// An assignment of linear units to the node's ports.
//
// We'll need something more complex if `follow_linear_port` stops being a
// direct map from input to output.
linear_units: Vec<LinearUnit>,
/// An assignment of linear units to the node's input ports.
input_linear_units: Vec<LinearUnit>,
/// An assignment of linear units to the node's output ports.
output_linear_units: Vec<LinearUnit>,
}

impl<'circ, Circ: Circuit> Command<'circ, Circ> {
Expand Down Expand Up @@ -165,7 +165,11 @@ impl<'circ, Circ: Circuit> Command<'circ, Circ> {
impl<'a, 'circ, Circ: Circuit> UnitLabeller for &'a Command<'circ, Circ> {
#[inline]
fn assign_linear(&self, _: Node, port: Port, _linear_count: usize) -> LinearUnit {
*self.linear_units.get(port.index()).unwrap_or_else(|| {
let units = match port.direction() {
Direction::Incoming => &self.input_linear_units,
Direction::Outgoing => &self.output_linear_units,
};
*units.get(port.index()).unwrap_or_else(|| {
panic!(
"Could not assign a linear unit to port {port:?} of node {:?}",
self.node
Expand All @@ -190,14 +194,17 @@ impl<'circ, Circ: Circuit> std::fmt::Debug for Command<'circ, Circ> {
f.debug_struct("Command")
.field("circuit name", &self.circ.name())
.field("node", &self.node)
.field("linear_units", &self.linear_units)
.field("input_linear_units", &self.input_linear_units)
.field("output_linear_units", &self.output_linear_units)
.finish()
}
}

impl<'circ, Circ> PartialEq for Command<'circ, Circ> {
fn eq(&self, other: &Self) -> bool {
self.node == other.node && self.linear_units == other.linear_units
self.node == other.node
&& self.input_linear_units == other.input_linear_units
&& self.output_linear_units == other.output_linear_units
}
}

Expand All @@ -208,29 +215,17 @@ impl<'circ, Circ> Clone for Command<'circ, Circ> {
Self {
circ: self.circ,
node: self.node,
linear_units: self.linear_units.clone(),
input_linear_units: self.input_linear_units.clone(),
output_linear_units: self.output_linear_units.clone(),
}
}
}

impl<'circ, Circ> std::hash::Hash for Command<'circ, Circ> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.node.hash(state);
self.linear_units.hash(state);
}
}

impl<'circ, Circ> PartialOrd for Command<'circ, Circ> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl<'circ, Circ> Ord for Command<'circ, Circ> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.node
.cmp(&other.node)
.then(self.linear_units.cmp(&other.linear_units))
self.input_linear_units.hash(state);
self.output_linear_units.hash(state);
}
}

Expand Down Expand Up @@ -356,12 +351,13 @@ where

/// Process a new node, updating wires in `unit_wires`.
///
/// Returns the an option with the `linear_units` used to construct a
/// [`Command`], if the node is not an input or output.
/// Returns the an option with the `input_linear_units` and
/// `output_linear_units` needed to construct a [`Command`], if the node is
/// not an input or output.
///
/// We don't return the command directly to avoid lifetime issues due to the
/// mutable borrow here.
fn process_node(&mut self, node: Node) -> Option<Vec<LinearUnit>> {
fn process_node(&mut self, node: Node) -> Option<(Vec<LinearUnit>, Vec<LinearUnit>)> {
// The root node is ignored.
if node == self.circ.root() {
return None;
Expand All @@ -373,56 +369,59 @@ where
return None;
}

// Collect the linear units passing through this command into the map
// Collect the linear units passing through this command into the maps
// required to construct a `Command`.
//
// Linear input ports are matched sequentially against the linear output
// ports, ignoring any non-linear ports when assigning unit ids. That
// is, the nth linear input is matched against the nth linear output,
// independently of whether there are any other ports mixed in.
//
// Updates the map tracking the last wire of linear units.
let linear_units: Vec<_> = Units::new_outgoing(self.circ, node, DefaultUnitLabeller)
.filter_map(filter::filter_linear)
.map(|(_, port, _)| {
// Find the linear unit id for this port.
let linear_id = self
.follow_linear_port(node, port)
.and_then(|input_port| {
let input_port = input_port.as_incoming().unwrap();
self.circ.linked_outputs(node, input_port).next()
})
.and_then(|(from, from_port)| {
// Remove the old wire from the map (if there was one)
self.wire_unit.remove(&Wire::new(from, from_port))
})
.unwrap_or({
// New linear unit found. Assign it a new id.
self.wire_unit.len()
});
// Update the map tracking the linear units
let new_wire = Wire::new(node, port);
self.wire_unit.insert(new_wire, linear_id);
LinearUnit::new(linear_id)
})
.collect();

Some(linear_units)
}

/// Returns the linear port on the node that corresponds to the same linear unit.
///
/// We assume the linear data uses the same port offsets on both sides of the node.
/// In the future we may want to have a more general mechanism to handle this.
//
// Note that `Command::linear_units` assumes this behaviour.
fn follow_linear_port(&self, node: Node, port: impl Into<Port>) -> Option<Port> {
let port = port.into();
let optype = self.circ.get_optype(node);
if !optype.port_kind(port)?.is_linear() {
return None;
}
let other_port = Port::new(port.direction().reverse(), port.index());
if optype.port_kind(other_port) == optype.port_kind(port) {
Some(other_port)
} else {
None
let mut input_linear_units = Vec::new();
let mut output_linear_units = Vec::new();

let input_units = Units::new_incoming(self.circ, node, DefaultUnitLabeller)
.filter_map(filter::filter_linear);
let output_units = Units::new_outgoing(self.circ, node, DefaultUnitLabeller)
.filter_map(filter::filter_linear);
for ports in input_units.zip_longest(output_units) {
// Terminate the input linear unit.
// Returns the linear id of the terminated unit.
let mut terminate_input =
|port: IncomingPort, wire_unit: &mut HashMap<Wire, usize>| -> Option<usize> {
let linear_id = self.circ.single_linked_output(node, port).and_then(
|(wire_node, wire_port)| wire_unit.remove(&Wire::new(wire_node, wire_port)),
)?;
input_linear_units.push(LinearUnit::new(linear_id));
Some(linear_id)
};

// Add a new linear unit for this output port.
let mut register_output =
|unit: usize, port: OutgoingPort, wire_unit: &mut HashMap<Wire, usize>| {
let wire = Wire::new(node, port);
wire_unit.insert(wire, unit);
output_linear_units.push(LinearUnit::new(unit));
};

match ports {
EitherOrBoth::Right((_, out_port, _)) => {
let new_id = self.wire_unit.len();
register_output(new_id, out_port, &mut self.wire_unit);
}
EitherOrBoth::Left((_, in_port, _)) => {
terminate_input(in_port, &mut self.wire_unit);
}
EitherOrBoth::Both((_, in_port, _), (_, out_port, _)) => {
if let Some(linear_id) = terminate_input(in_port, &mut self.wire_unit) {
register_output(linear_id, out_port, &mut self.wire_unit);
}
}
}
}

Some((input_linear_units, output_linear_units))
}
}

Expand All @@ -437,12 +436,13 @@ where
loop {
let node = self.next_node()?;
// Process the node, returning a command if it's not an input or output.
if let Some(linear_units) = self.process_node(node) {
if let Some((input_linear_units, output_linear_units)) = self.process_node(node) {
self.remaining -= 1;
return Some(Command {
circ: self.circ,
node,
linear_units,
input_linear_units,
output_linear_units,
});
}
}
Expand Down Expand Up @@ -476,7 +476,10 @@ mod test {
use hugr::std_extensions::arithmetic::float_types::ConstF64;
use hugr::types::FunctionType;
use itertools::Itertools;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

use crate::extension::REGISTRY;
use crate::utils::build_simple_circuit;
use crate::Tk2Op;

Expand Down Expand Up @@ -602,4 +605,92 @@ mod test {
[CircuitUnit::Linear(0)],
);
}

/// Commands that allocate and free linear units.
///
/// Creates the following circuit:
/// ```plaintext
/// -------------[ ]---[QFree]
/// [CX]
/// [QAlloc]---[ ]-------------
/// ```
/// and checks that every command is correctly generated, and correctly
/// computes input/output units.
#[test]
fn alloc_free() -> Result<(), Box<dyn std::error::Error>> {
let qb_row = vec![QB_T; 1];
let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row))?;

let [q_in] = h.input_wires_arr();

let alloc = h.add_dataflow_op(Tk2Op::QAlloc, [])?;
let [q_new] = alloc.outputs_arr();

let cx = h.add_dataflow_op(Tk2Op::CX, [q_in, q_new])?;
let [q_in, q_new] = cx.outputs_arr();

let free = h.add_dataflow_op(Tk2Op::QFree, [q_in])?;

let circ = h.finish_hugr_with_outputs([q_new], &REGISTRY)?;

let mut cmds = circ.commands();

let alloc_cmd = cmds.next().unwrap();
assert_eq!(alloc_cmd.node(), alloc.node());
assert_eq!(
alloc_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(),
[]
);
assert_eq!(
alloc_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(1)]
);

let cx_cmd = cmds.next().unwrap();
assert_eq!(cx_cmd.node(), cx.node());
assert_eq!(
cx_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(0), CircuitUnit::Linear(1)]
);
assert_eq!(
cx_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(0), CircuitUnit::Linear(1)]
);

let free_cmd = cmds.next().unwrap();
assert_eq!(free_cmd.node(), free.node());
assert_eq!(
free_cmd.inputs().map(|(unit, _, _)| unit).collect_vec(),
[CircuitUnit::Linear(0)]
);
assert_eq!(
free_cmd.outputs().map(|(unit, _, _)| unit).collect_vec(),
[]
);

Ok(())
}

/// Test the manual trait implementations of `Command`.
#[test]
fn test_impls() -> Result<(), Box<dyn std::error::Error>> {
let qb_row = vec![QB_T; 1];
let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), vec![]))?;
let [q_in] = h.input_wires_arr();
h.add_dataflow_op(Tk2Op::QFree, [q_in])?;
let circ = h.finish_hugr_with_outputs([], &REGISTRY)?;

let cmd1 = circ.commands().next().unwrap();
let cmd2 = circ.commands().next().unwrap();

assert_eq!(cmd1, cmd2);

let mut hasher1 = DefaultHasher::new();
cmd1.hash(&mut hasher1);
let mut hasher2 = DefaultHasher::new();
cmd2.hash(&mut hasher2);
assert_eq!(hasher1.finish(), hasher2.finish());

Ok(())
}
}
Loading