Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: supporting encoding const-parametrized ops #80

Merged
merged 6 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions src/circuit/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::iter::FusedIterator;
use hugr::hugr::views::HierarchyView;
use hugr::ops::{OpTag, OpTrait};
use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers};
use portgraph::PortOffset;

use super::Circuit;

Expand Down Expand Up @@ -99,8 +100,14 @@ where

// Get the wire corresponding to each input unit.
// TODO: Add this to HugrView?
let inputs = sig
let inputs: Vec<_> = sig
.input_ports()
.chain(
// add the static input port
optype
.static_input()
.map(|_| PortOffset::new_incoming(sig.input.len()).into()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we shouldn't need to access portgraph primitives, nor should we set here what's the offset of the static input.

I can add an issue in hugr to implement an OpType::static_input_port (similar to OpType::other_port).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, sounds good

)
.filter_map(|port| {
let (from, from_port) = self.circ.linked_ports(node, port).next()?;
let wire = Wire::new(from, from_port);
Expand All @@ -116,15 +123,14 @@ where
}
})
.collect();

// The units in `self.wire_units` have been updated.
// Now we can early return if the node should be ignored.
let tag = optype.tag();
if tag == OpTag::Input || tag == OpTag::Output {
return None;
}

let outputs = sig
let mut outputs: Vec<_> = sig
.output_ports()
.map(|port| {
let wire = Wire::new(node, port);
Expand All @@ -134,7 +140,14 @@ where
}
})
.collect();

if let OpType::Const(_) = optype {
// add the static output port from a const.
let offset = outputs.len();
outputs.push(CircuitUnit::Wire(Wire::new(
node,
PortOffset::new_outgoing(offset).into(),
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
)))
}
Some(Command {
node,
inputs,
Expand Down
43 changes: 29 additions & 14 deletions src/json/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

use std::collections::HashMap;

use downcast_rs::Downcast;
use hugr::extension::prelude::QB_T;
use hugr::ops::OpType;
use hugr::std_extensions::arithmetic::float_types::ConstF64;
use hugr::values::{PrimValue, Value};
use hugr::Wire;
use itertools::Itertools;
use tket_json_rs::circuit_json::{self, Permutation, SerialCircuit};
use itertools::{Either, Itertools};
use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit};

use crate::circuit::command::{CircuitUnit, Command};
use crate::circuit::Circuit;
Expand Down Expand Up @@ -92,19 +91,35 @@ impl JsonEncoder {
// Register any output of the command that can be used as a TKET1 parameter.
self.record_parameters(&command, optype);

let args = command
.inputs()
.iter()
.filter_map(|&u| self.unit_to_register(u))
.collect();
if let OpType::Const(_) | OpType::LoadConstant(_) = optype {
return Ok(());
}
let (args, params): (Vec<Register>, Vec<Wire>) =
command
.inputs()
.iter()
.partition_map(|&u| match self.unit_to_register(u) {
Some(r) => Either::Left(r),
None => match u {
CircuitUnit::Wire(w) => Either::Right(w),
CircuitUnit::Linear(_) => panic!("Should have been a register."),
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
},
});

// TODO Restore the opgroup (once the decoding supports it)
let opgroup = None;

let op: JsonOp = optype.try_into()?;
let op: circuit_json::Operation = op.into_operation();

// TODO: Update op.params. Leave untouched the ones that contain free variables.
let mut op: circuit_json::Operation = op.into_operation();
if !params.is_empty() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this we recover all parameters that have been encoded as inputs. It should be easy to support transparent pass-through of parameters we don't understand yet (e.g. x+1) and retrieve them from the metadata here.

This is OK for now, but I'll add a followup issue.

op.params = Some(
params
.into_iter()
.filter_map(|w| self.parameters.get(&w))
.cloned()
.collect(),
)
}
// TODO: ops that contain free variables.
// (update decoder to ignore them too, but store them in the wrapped op)

let command = circuit_json::Command { op, args, opgroup };
Expand Down Expand Up @@ -144,8 +159,8 @@ impl JsonEncoder {
OpType::Const(const_op) => {
// New constant, register it if it can be interpreted as a parameter.
match const_op.value() {
Value::Prim(PrimValue::Extension(v)) => {
if let Some(f) = v.as_any().downcast_ref::<ConstF64>() {
Value::Prim(PrimValue::Extension((v,))) => {
if let Some(f) = v.downcast_ref::<ConstF64>() {
f.to_string()
} else {
return;
Expand Down
9 changes: 3 additions & 6 deletions src/json/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ impl From<&JsonOp> for OpType {
JsonOpType::Tdg => T2Op::Tdg.into(),
JsonOpType::X => T2Op::X.into(),
JsonOpType::Rz => T2Op::RzF64.into(),
JsonOpType::TK1 => T2Op::TK1.into(),
JsonOpType::noop => LeafOp::Noop { ty: QB_T }.into(),
_ => LeafOp::CustomOp(Box::new(json_op.as_opaque_op())).into(),
}
Expand All @@ -214,7 +215,8 @@ impl TryFrom<&OpType> for JsonOp {
T2Op::CX => JsonOpType::CX,
T2Op::H => JsonOpType::H,
T2Op::Measure => JsonOpType::Measure,
T2Op::RzF64 => JsonOpType::RzF64,
T2Op::RzF64 => JsonOpType::Rz,
T2Op::TK1 => JsonOpType::TK1,
_ => return Err(err()),
}
} else if let LeafOp::CustomOp(b) = leaf {
Expand All @@ -237,11 +239,6 @@ impl TryFrom<&OpType> for JsonOp {
}
}

if num_params > 0 {
unimplemented!("Native parametric operation encoding is not supported yet.")
// TODO: Gather parameter values from the `OpType` to encode in the `JsonOpType`.
}

Ok(JsonOp::new_with_counts(
json_optype,
num_qubits,
Expand Down
62 changes: 36 additions & 26 deletions src/json/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@ use std::collections::HashSet;
use hugr::hugr::views::{HierarchyView, SiblingGraph};
use hugr::ops::handle::DfgID;
use hugr::{Hugr, HugrView};
use rstest::rstest;
use tket_json_rs::circuit_json::{self, SerialCircuit};

use crate::circuit::Circuit;
use crate::json::TKETDecode;

#[test]
fn read_json_simple() {
let circ_s = r#"{
const SIMPLE_JSON: &str = r#"{
"phase": "0",
"bits": [],
"qubits": [["q", [0]], ["q", [1]]],
Expand All @@ -23,44 +22,46 @@ fn read_json_simple() {
"implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]]
}"#;

let ser: circuit_json::SerialCircuit = serde_json::from_str(circ_s).unwrap();
assert_eq!(ser.commands.len(), 2);

let hugr: Hugr = ser.clone().decode().unwrap();
let circ: SiblingGraph<'_, DfgID> = SiblingGraph::new(&hugr, hugr.root());

assert_eq!(circ.qubits().len(), 2);

let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap();
compare_serial_circs(&ser, &reser);
}

#[test]
fn read_json_unknown_op() {
// test ops that are not native to tket-2 are correctly captured as
// custom and output

let circ_s = r#"{
const UNKNOWN_OP: &str = r#"{
"phase": "1/2",
"bits": [["c", [0]], ["c", [1]]],
"qubits": [["q", [0]], ["q", [1]], ["q", [2]]],
"commands": [
{"args": [["q", [0]], ["q", [1]], ["q", [2]]], "op": {"type": "CSWAP"}},
{"args": [["q", [1]], ["c", [1]]], "op": {"type": "Measure"}},
{"args": [["q", [2]], ["c", [0]]], "op": {"type": "Measure"}}
{"args": [["q", [1]], ["c", [1]]], "op": {"type": "Measure"}}
],
"created_qubits": [],
"discarded_qubits": [],
"implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]], [["q", [2]], ["q", [2]]]]
}"#;

let ser: SerialCircuit = serde_json::from_str(circ_s).unwrap();
assert_eq!(ser.commands.len(), 3);
const PARAMETRIZED: &str = r#"{
"phase": "0.0",
"bits": [],
"qubits": [["q", [0]], ["q", [1]]],
"commands": [
{"args":[["q",[0]]],"op":{"type":"H"}},
{"args":[["q",[1]],["q",[0]]],"op":{"type":"CX"}},
{"args":[["q",[0]]],"op":{"params":["0.1"],"type":"Rz"}},
{"args": [["q", [0]]], "op": {"params": ["0.1", "0.2", "0.3"], "type": "TK1"}}
],
"created_qubits": [],
"discarded_qubits": [],
"implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]]
}"#;

#[rstest]
#[case::simple(SIMPLE_JSON, 2, 2)]
#[case::unknown_op(UNKNOWN_OP, 2, 3)]
#[case::parametrized(PARAMETRIZED, 4, 2)]
fn json_roundtrip(#[case] circ_s: &str, #[case] num_commands: usize, #[case] num_qubits: usize) {
let ser: circuit_json::SerialCircuit = serde_json::from_str(circ_s).unwrap();
assert_eq!(ser.commands.len(), num_commands);

let hugr: Hugr = ser.clone().decode().unwrap();
let circ: SiblingGraph<'_, DfgID> = SiblingGraph::new(&hugr, hugr.root());

assert_eq!(circ.qubits().len(), 3);
assert_eq!(circ.qubits().len(), num_qubits);

let reser: SerialCircuit = SerialCircuit::encode(&circ).unwrap();
compare_serial_circs(&ser, &reser);
Expand All @@ -80,5 +81,14 @@ fn compare_serial_circs(a: &SerialCircuit, b: &SerialCircuit) {

assert_eq!(a.implicit_permutation, b.implicit_permutation);

assert_eq!(a.commands.len(), b.commands.len());

// the below only works if both serial circuits share a topological ordering
// of commands.
for (a, b) in a.commands.iter().zip(b.commands.iter()) {
assert_eq!(a.op.op_type, b.op.op_type);
assert_eq!(a.args, b.args);
assert_eq!(a.op.params, b.op.params);
}
// TODO: Check commands equality (they only implement PartialEq)
}
5 changes: 5 additions & 0 deletions src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub enum T2Op {
ZZMax,
Measure,
RzF64,
TK1,
}
#[derive(Clone, Copy, Debug, Serialize, Deserialize, EnumIter, Display, PartialEq, PartialOrd)]
#[allow(missing_docs)]
Expand Down Expand Up @@ -112,6 +113,10 @@ impl SimpleOpEnum for T2Op {
CX | ZZMax => FunctionType::new(two_qb_row.clone(), two_qb_row),
Measure => FunctionType::new(one_qb_row, type_row![QB_T, BOOL_T]),
RzF64 => FunctionType::new(type_row![QB_T, FLOAT64_TYPE], one_qb_row),
TK1 => FunctionType::new(
type_row![QB_T, FLOAT64_TYPE, FLOAT64_TYPE, FLOAT64_TYPE],
one_qb_row,
),
}
}

Expand Down