From 7e4c6752938c0b04a18a4098b3f85fe10f14da67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Tue, 5 Nov 2024 13:51:13 +0000 Subject: [PATCH] feat: Support classical expressions --- Cargo.toml | 2 + src/circuit_json.rs | 9 + src/clexpr.rs | 86 +++++++++ src/clexpr/op.rs | 73 ++++++++ src/clexpr/operator.rs | 72 ++++++++ src/lib.rs | 1 + src/opbox.rs | 2 + src/optype.rs | 11 ++ tests/data/qasm.json | 367 +++++++++++++++++++++++++++++++++++++++ tests/data/qasm.py | 32 ++++ tests/missing_optypes.rs | 60 +++++-- tests/roundtrip.rs | 2 + 12 files changed, 704 insertions(+), 13 deletions(-) create mode 100644 src/clexpr.rs create mode 100644 src/clexpr/op.rs create mode 100644 src/clexpr/operator.rs create mode 100644 tests/data/qasm.json create mode 100644 tests/data/qasm.py diff --git a/Cargo.toml b/Cargo.toml index 0eea50d..e27bfce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ pythonize = { workspace = true, optional = true } strum = { workspace = true, features = ["derive"] } [dev-dependencies] +itertools = { workspace = true } pyo3 = { workspace = true } rstest = { workspace = true } assert-json-diff = { workspace = true } @@ -37,6 +38,7 @@ name = "integration" path = "tests/lib.rs" [workspace.dependencies] +itertools = "0.13.0" pyo3 = "0.22.2" pythonize = "0.22.0" rstest = "0.23.0" diff --git a/src/circuit_json.rs b/src/circuit_json.rs index c4d34b2..1a28a91 100644 --- a/src/circuit_json.rs +++ b/src/circuit_json.rs @@ -1,6 +1,7 @@ //! Contains structs for serializing and deserializing TKET circuits to and from //! JSON. +use crate::clexpr::ClExpr; use crate::opbox::OpBox; use crate::optype::OpType; use serde::{Deserialize, Serialize}; @@ -168,6 +169,12 @@ pub struct Operation

{ #[serde(rename = "box")] #[serde(skip_serializing_if = "Option::is_none")] pub op_box: Option, + /// Classical expression. + /// + /// Required if the operation is of type [`OpType::ClExp`]. + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "expr")] + pub classical_expr: Option, /// The pre-computed signature. #[serde(skip_serializing_if = "Option::is_none")] pub signature: Option>, @@ -233,6 +240,7 @@ impl

Default for Operation

{ data: None, params: None, op_box: None, + classical_expr: None, signature: None, conditional: None, classical: None, @@ -266,6 +274,7 @@ impl

Operation

{ .params .map(|params| params.into_iter().map(f).collect()), op_box: self.op_box, + classical_expr: self.classical_expr, signature: self.signature, conditional: self.conditional, classical: self.classical, diff --git a/src/clexpr.rs b/src/clexpr.rs new file mode 100644 index 0000000..555341e --- /dev/null +++ b/src/clexpr.rs @@ -0,0 +1,86 @@ +//! Classical expressions + +pub mod op; +pub mod operator; + +use operator::ClOperator; +use serde::de::SeqAccess; +use serde::ser::SerializeSeq; +use serde::{Deserialize, Serialize}; + +/// Data encoding a classical expression. +/// +/// A classical expression operates over multi-bit registers, +/// which are identified here by their individual bit positions. +/// +/// This is included in a [`Operation`] when the operation is a [`OpType::ClExpr`]. +/// +/// [`Operation`]: crate::circuit_json::Operation +/// [`OpType::ClExpr`]: crate::optype::OpType::ClExpr +#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)] +#[non_exhaustive] +pub struct ClExpr { + /// TODO: ??? + pub bit_posn: Vec, + /// The encoded expression. + pub expr: ClOperator, + /// The input bits of the expression. + pub reg_posn: Vec, + /// The output bits of the expression. + pub output_posn: ClRegisterBits, +} + +/// An input register for a classical expression. +/// +/// Contains the input index as well as the bits that are part of the register. +/// +/// Serialized as a list with two elements: the index and the bits. +#[derive(Debug, Default, PartialEq, Clone)] +pub struct InputClRegister { + /// The index of the register. + pub index: u32, + /// The individual bit indices that are part of the register. + pub bits: ClRegisterBits, +} + +/// The list of bit indices which are part of a register. +/// +/// Registers are little-endian, so the first bit is the least significant. +#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)] +#[serde(transparent)] +pub struct ClRegisterBits(pub Vec); + +impl Serialize for InputClRegister { + fn serialize(&self, serializer: S) -> Result { + let mut seq = serializer.serialize_seq(Some(2))?; + seq.serialize_element(&self.index)?; + seq.serialize_element(&self.bits)?; + seq.end() + } +} + +impl<'de> Deserialize<'de> for InputClRegister { + fn deserialize>(deserializer: D) -> Result { + struct Visitor; + + impl<'de_vis> serde::de::Visitor<'de_vis> for Visitor { + type Value = InputClRegister; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a list of two elements: the index and the bits") + } + + fn visit_seq>(self, mut seq: A) -> Result { + let index = seq + .next_element::()? + .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?; + let bits = seq + .next_element::()? + .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?; + Ok(InputClRegister { index, bits }) + } + } + + deserializer.deserialize_seq(Visitor) + } +} diff --git a/src/clexpr/op.rs b/src/clexpr/op.rs new file mode 100644 index 0000000..1843810 --- /dev/null +++ b/src/clexpr/op.rs @@ -0,0 +1,73 @@ +//! Classical expression operations. + +use serde::{Deserialize, Serialize}; +use strum::EnumString; + +/// List of supported classical expressions. +/// +/// Corresponds to `pytket.circuit.ClOp`. +#[derive(Deserialize, Serialize, Clone, Debug, Default, PartialEq, Eq, Hash, EnumString)] +#[non_exhaustive] +pub enum ClOp { + /// Invalid operation + #[default] + INVALID, + + /// Bitwise AND + BitAnd, + /// Bitwise OR + BitOr, + /// Bitwise XOR + BitXor, + /// Bitwise equality + BitEq, + /// Bitwise inequality + BitNeq, + /// Bitwise NOT + BitNot, + /// Constant zero bit + BitZero, + /// Constant one bit + BitOne, + + /// Registerwise AND + RegAnd, + /// Registerwise OR + RegOr, + /// Registerwise XOR + RegXor, + /// Registerwise equality + RegEq, + /// Registerwise inequality + RegNeq, + /// Registerwise NOT + RegNot, + /// Constant all-zeros register + RegZero, + /// Constant all-ones register + RegOne, + /// Integer less-than comparison + RegLt, + /// Integer greater-than comparison + RegGt, + /// Integer less-than-or-equal comparison + RegLeq, + /// Integer greater-than-or-equal comparison + RegGeq, + /// Integer addition + RegAdd, + /// Integer subtraction + RegSub, + /// Integer multiplication + RegMul, + /// Integer division + RegDiv, + /// Integer exponentiation + RegPow, + /// Left shift + RegLsh, + /// Right shift + RegRsh, + /// Integer negation + RegNeg, +} diff --git a/src/clexpr/operator.rs b/src/clexpr/operator.rs new file mode 100644 index 0000000..78fe593 --- /dev/null +++ b/src/clexpr/operator.rs @@ -0,0 +1,72 @@ +//! A tree of operators forming a classical expression. + +use serde::{Deserialize, Serialize}; + +use super::op::ClOp; + +/// A node in a classical expression tree. +#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)] +#[non_exhaustive] +pub struct ClOperator { + /// The operation to be performed. + pub op: ClOp, + /// The arguments to the operation. + pub args: Vec, +} + +/// An argument to a classical expression operation. +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[non_exhaustive] +#[serde(tag = "type", content = "input")] +pub enum ClArgument { + /// A terminal argument. + #[serde(rename = "term")] + Terminal(ClTerminal), + /// A sub-expression. + #[serde(rename = "expr")] + Expression(Box), +} + +/// A terminal argument in a classical expression operation. +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] +#[non_exhaustive] +#[serde(tag = "type", content = "term")] +pub enum ClTerminal { + /// A terminal argument. + #[serde(rename = "var")] + Variable(ClVariable), + /// A constant integer. + #[serde(rename = "int")] + Int(u32), +} + +/// A variable terminal argument in a classical expression operation. +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Hash)] +#[non_exhaustive] +#[serde(tag = "type", content = "var")] +pub enum ClVariable { + /// A register variable. + #[serde(rename = "reg")] + Register { + /// The register index. + index: u32, + }, +} + +impl Default for ClArgument { + fn default() -> Self { + ClArgument::Terminal(ClTerminal::default()) + } +} + +impl Default for ClTerminal { + fn default() -> Self { + ClTerminal::Int(0) + } +} + +impl Default for ClVariable { + fn default() -> Self { + ClVariable::Register { index: 0 } + } +} diff --git a/src/lib.rs b/src/lib.rs index 5410829..3238f1f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ //! [TKET](https://github.com/CQCL/tket) quantum compiler. pub mod circuit_json; +pub mod clexpr; pub mod opbox; pub mod optype; #[cfg(feature = "pyo3")] diff --git a/src/opbox.rs b/src/opbox.rs index 794e5a2..9ad7522 100644 --- a/src/opbox.rs +++ b/src/opbox.rs @@ -147,6 +147,8 @@ pub enum OpBox { control_state: u32, }, /// Holding box for abstract expressions on Bits. + /// + /// Deprecated in favour of [`OpType::ClExpr`]. ClassicalExpBox { id: BoxID, n_i: u32, diff --git a/src/optype.rs b/src/optype.rs index 851c98c..1fe2b49 100644 --- a/src/optype.rs +++ b/src/optype.rs @@ -500,9 +500,20 @@ pub enum OpType { /// See [`ClassicalExpBox`] /// + /// Deprecated. Use [`OpType::ClExpBox`] instead. + /// /// [`ClassicalExpBox`]: crate::opbox::OpBox::ClassicalExpBox ClassicalExpBox, + /// Classical expression. + /// + /// An operation of this type is accompanied by a [`ClExpr`] object. + /// + /// This is a replacement of the deprecated [`ClassicalExpBox`]. + /// + /// [`ClExpBox`]: crate::opbox::OpBox::ClExpBox + ClExpr, + /// See [`MultiplexorBox`] /// /// [`MultiplexorBox`]: crate::opbox::OpBox::MultiplexorBox diff --git a/tests/data/qasm.json b/tests/data/qasm.json new file mode 100644 index 0000000..3ea4f58 --- /dev/null +++ b/tests/data/qasm.json @@ -0,0 +1,367 @@ +{ + "bits": [ + [ + "a", + [ + 0 + ] + ], + [ + "a", + [ + 1 + ] + ], + [ + "a", + [ + 2 + ] + ], + [ + "b", + [ + 0 + ] + ], + [ + "b", + [ + 1 + ] + ], + [ + "b", + [ + 2 + ] + ], + [ + "c", + [ + 0 + ] + ], + [ + "c", + [ + 1 + ] + ], + [ + "c", + [ + 2 + ] + ], + [ + "d", + [ + 0 + ] + ], + [ + "d", + [ + 1 + ] + ], + [ + "d", + [ + 2 + ] + ] + ], + "commands": [ + { + "args": [ + [ + "a", + [ + 0 + ] + ], + [ + "a", + [ + 1 + ] + ], + [ + "a", + [ + 2 + ] + ], + [ + "b", + [ + 0 + ] + ], + [ + "b", + [ + 1 + ] + ], + [ + "b", + [ + 2 + ] + ], + [ + "c", + [ + 0 + ] + ], + [ + "c", + [ + 1 + ] + ], + [ + "c", + [ + 2 + ] + ], + [ + "d", + [ + 0 + ] + ], + [ + "d", + [ + 1 + ] + ], + [ + "d", + [ + 2 + ] + ] + ], + "op": { + "expr": { + "bit_posn": [], + "expr": { + "args": [ + { + "input": { + "args": [ + { + "input": { + "args": [ + { + "input": { + "term": { + "type": "reg", + "var": { + "index": 0 + } + }, + "type": "var" + }, + "type": "term" + }, + { + "input": { + "term": { + "type": "reg", + "var": { + "index": 1 + } + }, + "type": "var" + }, + "type": "term" + } + ], + "op": "RegAdd" + }, + "type": "expr" + }, + { + "input": { + "term": 2, + "type": "int" + }, + "type": "term" + } + ], + "op": "RegDiv" + }, + "type": "expr" + }, + { + "input": { + "term": { + "type": "reg", + "var": { + "index": 2 + } + }, + "type": "var" + }, + "type": "term" + } + ], + "op": "RegSub" + }, + "output_posn": [ + 9, + 10, + 11 + ], + "reg_posn": [ + [ + 0, + [ + 0, + 1, + 2 + ] + ], + [ + 1, + [ + 3, + 4, + 5 + ] + ], + [ + 2, + [ + 6, + 7, + 8 + ] + ] + ] + }, + "type": "ClExpr" + } + }, + { + "args": [ + [ + "q", + [ + 0 + ] + ] + ], + "op": { + "type": "H" + } + }, + { + "args": [ + [ + "q", + [ + 2 + ] + ] + ], + "op": { + "type": "Z" + } + }, + { + "args": [ + [ + "q", + [ + 2 + ] + ], + [ + "q", + [ + 1 + ] + ] + ], + "op": { + "type": "CX" + } + } + ], + "created_qubits": [], + "discarded_qubits": [], + "implicit_permutation": [ + [ + [ + "q", + [ + 0 + ] + ], + [ + "q", + [ + 0 + ] + ] + ], + [ + [ + "q", + [ + 1 + ] + ], + [ + "q", + [ + 1 + ] + ] + ], + [ + [ + "q", + [ + 2 + ] + ], + [ + "q", + [ + 2 + ] + ] + ] + ], + "phase": "0.0", + "qubits": [ + [ + "q", + [ + 0 + ] + ], + [ + "q", + [ + 1 + ] + ], + [ + "q", + [ + 2 + ] + ] + ] +} diff --git a/tests/data/qasm.py b/tests/data/qasm.py new file mode 100644 index 0000000..205a022 --- /dev/null +++ b/tests/data/qasm.py @@ -0,0 +1,32 @@ +# /// script +# requires-python = ">=3.13" +# dependencies = [ +# "pytket>=1.34", +# ] +# /// + +import json + +from pytket import Circuit +from pytket.qasm import circuit_from_qasm_str + + +def qasm_circuit() -> Circuit: + qasm = """OPENQASM 2.0; + include "hqslib1.inc"; + qreg q[3]; + creg a[3]; + creg b[3]; + creg c[3]; + creg d[3]; + d = (((a + b) / 2) - c); + + h q[0]; + z q[2]; + cx q[2], q[1]; + """ + return circuit_from_qasm_str(qasm, use_clexpr=True) + + +if __name__ == "__main__": + print(json.dumps(qasm_circuit().to_dict(), indent=2)) diff --git a/tests/missing_optypes.rs b/tests/missing_optypes.rs index 80a274d..1da1242 100644 --- a/tests/missing_optypes.rs +++ b/tests/missing_optypes.rs @@ -4,10 +4,29 @@ use std::str::FromStr; +use itertools::Itertools; use pyo3::prelude::*; use pyo3::types::PyDict; +use tket_json_rs::clexpr::op::ClOp; use tket_json_rs::OpType; +/// Given a python enum, lists the enum variants that cannot be converted into a `T` using `FromStr`. +fn find_missing_variants<'py, T>(py_enum: &Bound<'py, PyAny>) -> impl Iterator + 'py +where + T: FromStr, +{ + let py_members = py_enum.getattr("__members__").unwrap(); + let py_members = py_members.downcast::().unwrap(); + + py_members.into_iter().filter_map(|(name, _class)| { + let name = name.extract::().unwrap(); + match T::from_str(&name) { + Err(_) => Some(name), + Ok(_) => None, + } + }) +} + #[test] #[ignore = "Requires a python environment with `pytket` installed."] fn missing_optypes() -> PyResult<()> { @@ -19,19 +38,7 @@ fn missing_optypes() -> PyResult<()> { panic!("Failed to import `pytket`. Make sure the python library is installed."); }; let py_enum = pytket.getattr("OpType")?; - let py_members = py_enum.getattr("__members__")?; - let py_members = py_members.downcast::()?; - - let missing: Vec = py_members - .into_iter() - .filter_map(|(name, _class)| { - let name = name.extract::().unwrap(); - match OpType::from_str(&name) { - Err(_) => Some(name), - Ok(_) => None, - } - }) - .collect(); + let missing = find_missing_variants::(&py_enum).collect_vec(); if !missing.is_empty() { let msg = "\nMissing optypes in `tket_json_rs`:\n".to_string(); @@ -46,3 +53,30 @@ fn missing_optypes() -> PyResult<()> { Ok(()) }) } + +#[test] +#[ignore = "Requires a python environment with `pytket` installed."] +fn missing_classical_optypes() -> PyResult<()> { + println!("Checking missing classical ops"); + + pyo3::prepare_freethreaded_python(); + Python::with_gil(|py| { + let Ok(pytket) = PyModule::import_bound(py, "pytket") else { + panic!("Failed to import `pytket`. Make sure the python library is installed."); + }; + let py_enum = pytket.getattr("circuit")?.getattr("ClOp")?; + let missing = find_missing_variants::(&py_enum).collect_vec(); + + if !missing.is_empty() { + let msg = "\nMissing classical ops in `tket_json_rs`:\n".to_string(); + let msg = missing + .into_iter() + .fold(msg, |msg, s| msg + " - " + &s + "\n"); + let msg = + msg + "Please add them to the `ClOp` enum in `tket_json_rs/src/clexpr/op.rs`.\n"; + panic!("{msg}"); + } + + Ok(()) + }) +} diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs index 41aa4ea..9b4dd36 100644 --- a/tests/roundtrip.rs +++ b/tests/roundtrip.rs @@ -7,12 +7,14 @@ use tket_json_rs::SerialCircuit; const SIMPLE: &str = include_str!("data/simple.json"); const CLASSICAL: &str = include_str!("data/classical.json"); const DIAGONAL: &str = include_str!("data/diagonal-box.json"); +const QASM: &str = include_str!("data/qasm.json"); const WASM: &str = include_str!("data/wasm.json"); #[rstest] #[case::simple(SIMPLE, 4)] #[case::classical(CLASSICAL, 3)] #[case::diagonal_box(DIAGONAL, 1)] +#[case::qasm_box(QASM, 4)] #[case::wasm_box(WASM, 1)] fn roundtrip(#[case] json: &str, #[case] num_commands: usize) { let initial_json: Value = serde_json::from_str(json).unwrap();