Skip to content

Commit

Permalink
Merge branch 'main' into feat/num-gates
Browse files Browse the repository at this point in the history
# Conflicts:
#	src/circuit.rs
  • Loading branch information
lmondada committed Aug 30, 2023
2 parents 1dcd79c + b13dcea commit b39a6fb
Show file tree
Hide file tree
Showing 13 changed files with 352 additions and 134 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ portgraph = { workspace = true }
pyo3 = { workspace = true, optional = true, features = [
"multiple-pymethods",
] }
strum_macros = "0.25.2"
strum = "0.25.0"

[features]
pyo3 = ["dep:pyo3", "tket-json-rs/pyo3", "tket-json-rs/tket2ops", "portgraph/pyo3", "quantinuum-hugr/pyo3"]
Expand All @@ -57,4 +59,4 @@ members = ["pyrs"]

quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", branch = "fix/no-resource-validation" }
portgraph = "0.8"
pyo3 = { version = "0.19" }
pyo3 = { version = "0.19" }
31 changes: 21 additions & 10 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub mod command;
use self::command::{Command, CommandIterator};

use hugr::extension::prelude::QB_T;
use hugr::hugr::CircuitUnit;
use hugr::hugr::{CircuitUnit, NodeType};
use hugr::ops::OpTrait;
use hugr::HugrView;

Expand All @@ -31,18 +31,18 @@ use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers};
// - Depth
pub trait Circuit<'circ>: HugrView {
/// An iterator over the commands in the circuit.
type Commands: Iterator<Item = Command<'circ>>;
type Commands: Iterator<Item = Command>;

/// An iterator over the commands applied to an unit.
type UnitCommands: Iterator<Item = Command<'circ>>;
type UnitCommands: Iterator<Item = Command>;

/// Return the name of the circuit
fn name(&self) -> Option<&str>;

/// Get the linear inputs of the circuit and their types.
fn units(&self) -> Vec<(CircuitUnit, Type)>;

/// Returns the ports corresponding to qubits inputs to the circuit.
/// Returns the units corresponding to qubits inputs to the circuit.
#[inline]
fn qubits(&self) -> Vec<CircuitUnit> {
self.units()
Expand All @@ -52,6 +52,12 @@ pub trait Circuit<'circ>: HugrView {
.collect()
}

/// Returns the input node to the circuit.
fn input(&self) -> Node;

/// Returns the output node to the circuit.
fn output(&self) -> Node;

/// Given a linear port in a node, returns the corresponding port on the other side of the node (if any).
fn follow_linear_port(&self, node: Node, port: Port) -> Option<Port>;

Expand All @@ -63,14 +69,19 @@ pub trait Circuit<'circ>: HugrView {
/// Returns all the commands applied to the given unit, in order.
fn unit_commands<'a: 'circ>(&'a self) -> Self::UnitCommands;

/// Returns the input node to the circuit.
fn input(&self) -> Node;
/// Returns the [`NodeType`] of a command.
fn command_nodetype(&self, command: &Command) -> &NodeType {
self.get_nodetype(command.node())
}

/// Returns the output node to the circuit.
fn output(&self) -> Node;
/// Returns the [`OpType`] of a command.
fn command_optype(&self, command: &Command) -> &OpType {
self.get_optype(command.node())
}

/// The number of gates in the circuit.
fn num_gates(&self) -> usize;

}

impl<'circ, T> Circuit<'circ> for T
Expand All @@ -79,7 +90,7 @@ where
for<'a> &'a T: GraphBase<NodeId = Node> + IntoNeighborsDirected + IntoNodeIdentifiers,
{
type Commands = CommandIterator<'circ, T>;
type UnitCommands = std::iter::Empty<Command<'circ>>;
type UnitCommands = std::iter::Empty<Command>;

#[inline]
fn name(&self) -> Option<&str> {
Expand Down Expand Up @@ -124,7 +135,7 @@ where
fn unit_commands<'a: 'circ>(&'a self) -> Self::UnitCommands {
// TODO Can we associate linear i/o with the corresponding unit without
// doing the full toposort?
todo!()
unimplemented!()
}

#[inline]
Expand Down
33 changes: 24 additions & 9 deletions src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,31 @@ pub use hugr::types::{EdgeKind, Signature, Type, TypeRow};
pub use hugr::{Node, Port, Wire};

/// An operation applied to specific wires.
pub struct Command<'circ> {
/// The operation.
pub op: &'circ OpType,
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Command {
/// The operation node.
pub node: Node,
node: Node,
/// The input units to the operation.
pub inputs: Vec<CircuitUnit>,
inputs: Vec<CircuitUnit>,
/// The output units to the operation.
pub outputs: Vec<CircuitUnit>,
outputs: Vec<CircuitUnit>,
}

impl Command {
/// Returns the node corresponding to this command.
pub fn node(&self) -> Node {
self.node
}

/// Returns the output units of this command.
pub fn outputs(&self) -> &Vec<CircuitUnit> {
&self.outputs
}

/// Returns the output units of this command.
pub fn inputs(&self) -> &Vec<CircuitUnit> {
&self.inputs
}
}

/// An iterator over the commands of a circuit.
Expand Down Expand Up @@ -72,7 +88,7 @@ where

/// Process a new node, updating wires in `unit_wires` and returns the
/// command for the node if it's not an input or output.
fn process_node(&mut self, node: Node) -> Option<Command<'circ>> {
fn process_node(&mut self, node: Node) -> Option<Command> {
let optype = self.circ.get_optype(node);
let sig = optype.signature();

Expand Down Expand Up @@ -120,7 +136,6 @@ where
.collect();

Some(Command {
op: optype,
node,
inputs,
outputs,
Expand All @@ -133,7 +148,7 @@ where
Circ: HierarchyView<'circ>,
for<'a> &'a Circ: GraphBase<NodeId = Node> + IntoNeighborsDirected + IntoNodeIdentifiers,
{
type Item = Command<'circ>;
type Item = Command;

fn next(&mut self) -> Option<Self::Item> {
loop {
Expand Down
3 changes: 2 additions & 1 deletion src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ impl TKETDecode for SerialCircuit {
fn encode<'circ>(circ: &'circ impl Circuit<'circ>) -> Result<Self, Self::EncodeError> {
let mut encoder = JsonEncoder::new(circ);
for com in circ.commands() {
encoder.add_command(com)?;
let optype = circ.command_optype(&com);
encoder.add_command(com, optype)?;
}
Ok(encoder.finish())
}
Expand Down
18 changes: 9 additions & 9 deletions src/json/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,20 @@ impl JsonEncoder {
}

/// Add a circuit command to the serialization.
pub fn add_command(&mut self, command: Command) -> Result<(), OpConvertError> {
pub fn add_command(&mut self, command: Command, optype: &OpType) -> Result<(), OpConvertError> {
// Register any output of the command that can be used as a TKET1 parameter.
self.record_parameters(&command);
self.record_parameters(&command, optype);

let args = command
.inputs
.inputs()
.iter()
.filter_map(|&u| self.unit_to_register(u))
.collect();

// TODO Restore the opgroup (once the decoding supports it)
let opgroup = None;

let op: JsonOp = command.op.try_into()?;
let op: JsonOp = optype.try_into()?;
let op: circuit_json::Operation = op.into_operation();

// TODO: Update op.params. Leave untouched the ones that contain free variables.
Expand All @@ -126,21 +126,21 @@ impl JsonEncoder {
/// Record any output of the command that can be used as a TKET1 parameter.
///
/// Associates the output wires with the parameter expression.
fn record_parameters(&mut self, command: &Command) {
fn record_parameters(&mut self, command: &Command, optype: &OpType) {
// Only consider commands where all inputs are parameters.
let inputs = command
.inputs
.inputs()
.iter()
.filter_map(|unit| match unit {
CircuitUnit::Wire(wire) => self.parameters.get(wire),
CircuitUnit::Linear(_) => None,
})
.collect_vec();
if inputs.len() != command.inputs.len() {
if inputs.len() != command.inputs().len() {
return;
}

let param = match command.op {
let param = match optype {
OpType::Const(const_op) => {
// New constant, register it if it can be interpreted as a parameter.
match const_op.value() {
Expand All @@ -166,7 +166,7 @@ impl JsonEncoder {
}
};

for unit in &command.outputs {
for unit in command.outputs() {
if let CircuitUnit::Wire(wire) = unit {
self.parameters.insert(*wire, param.clone());
}
Expand Down
12 changes: 7 additions & 5 deletions src/json/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
//! circuits by ensuring they always define a signature, and computing the
//! explicit count of qubits and linear bits.

use crate::ops::EXTENSION_ID as QUANTUM_EXTENSION_ID;
use hugr::extension::prelude::QB_T;
use hugr::extension::ExtensionSet;
use hugr::ops::custom::ExternalOp;
use hugr::ops::{LeafOp, OpTrait, OpType};
use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE;
use hugr::std_extensions::quantum::EXTENSION_ID as QUANTUM_EXTENSION_ID;
use hugr::types::FunctionType;

use itertools::Itertools;
Expand All @@ -20,7 +20,7 @@ use tket_json_rs::optype::OpType as JsonOpType;

use super::{try_param_to_constant, OpConvertError};
use crate::extension::{try_unwrap_json_op, LINEAR_BIT, TKET1_EXTENSION_ID};
use crate::utils::{cx_gate, h_gate};
use crate::T2Op;

/// A serialized operation, containing the operation type and all its attributes.
///
Expand Down Expand Up @@ -184,8 +184,11 @@ impl From<&JsonOp> for OpType {
fn from(json_op: &JsonOp) -> Self {
match json_op.op.op_type {
// JsonOpType::X => LeafOp::X.into(),
JsonOpType::H => h_gate().into(),
JsonOpType::CX => cx_gate().into(),
JsonOpType::H => T2Op::H.into(),
JsonOpType::CX => T2Op::CX.into(),
JsonOpType::T => T2Op::T.into(),
JsonOpType::Tdg => T2Op::Tdg.into(),
JsonOpType::X => T2Op::X.into(),
JsonOpType::noop => LeafOp::Noop { ty: QB_T }.into(),
// TODO TKET1 measure takes a bit as input, HUGR measure does not
//JsonOpType::Measure => LeafOp::Measure.into(),
Expand Down Expand Up @@ -239,7 +242,6 @@ impl TryFrom<&OpType> for JsonOp {
ext => {
return try_unwrap_json_op(ext).ok_or_else(err);
} // h_gate() => JsonOpType::H,
// cx_gate() => JsonOpType::CX,
// LeafOp::ZZMax => JsonOpType::ZZMax,
// LeafOp::Reset => JsonOpType::Reset,
// //LeafOp::Measure => JsonOpType::Measure,
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
pub mod circuit;
pub mod extension;
pub mod json;
mod ops;
pub mod passes;
pub use ops::{Pauli, T2Op};

#[cfg(feature = "portmatching")]
pub mod portmatching;
Expand Down
Loading

0 comments on commit b39a6fb

Please sign in to comment.