diff --git a/pytket/binders/circuit/clexpr.cpp b/pytket/binders/circuit/clexpr.cpp index fa2f73c227..a0589cb393 100644 --- a/pytket/binders/circuit/clexpr.cpp +++ b/pytket/binders/circuit/clexpr.cpp @@ -34,7 +34,7 @@ namespace tket { static std::string qasm_bit_repr( const ClExprTerm &term, const std::map &input_bits) { - if (const int *n = std::get_if(&term)) { + if (const uint64_t *n = std::get_if(&term)) { switch (*n) { case 0: return "0"; @@ -56,7 +56,7 @@ static std::string qasm_bit_repr( static std::string qasm_reg_repr( const ClExprTerm &term, const std::map &input_regs) { - if (const int *n = std::get_if(&term)) { + if (const uint64_t *n = std::get_if(&term)) { std::stringstream ss; ss << *n; return ss.str(); diff --git a/pytket/conanfile.py b/pytket/conanfile.py index cf6b449cfb..1d2f0718d5 100644 --- a/pytket/conanfile.py +++ b/pytket/conanfile.py @@ -38,7 +38,7 @@ def requirements(self): self.requires("pybind11_json/0.2.14") self.requires("symengine/0.12.0") self.requires("tkassert/0.3.4@tket/stable") - self.requires("tket/1.3.37@tket/stable") + self.requires("tket/1.3.38@tket/stable") self.requires("tklog/0.3.3@tket/stable") self.requires("tkrng/0.3.3@tket/stable") self.requires("tktokenswap/0.3.9@tket/stable") diff --git a/pytket/docs/changelog.rst b/pytket/docs/changelog.rst index 1dcaf780f2..1fe3063f00 100644 --- a/pytket/docs/changelog.rst +++ b/pytket/docs/changelog.rst @@ -8,6 +8,8 @@ Features: * Add `clexpr.check_register_alignments()` method to check register alignments in `ClExprOp`. +* Use `ClExprOp` instead of `ClassicalExpBox` when deconstructing complex + conditions. Fixes: diff --git a/pytket/pytket/circuit/add_condition.py b/pytket/pytket/circuit/add_condition.py index 79abac5209..63d2c21b84 100644 --- a/pytket/pytket/circuit/add_condition.py +++ b/pytket/pytket/circuit/add_condition.py @@ -16,11 +16,8 @@ from typing import Tuple, Union from pytket.circuit import Bit, Circuit, BitRegister -from pytket._tket.unit_id import ( - _TEMP_REG_SIZE, - _TEMP_BIT_NAME, - _TEMP_BIT_REG_BASE, -) +from pytket._tket.unit_id import _TEMP_BIT_NAME, _TEMP_BIT_REG_BASE +from pytket.circuit.clexpr import wired_clexpr_from_logic_exp from pytket.circuit.logic_exp import ( BitLogicExp, Constant, @@ -79,7 +76,8 @@ def _add_condition( circ.add_bit(condition_bit) if isinstance(pred_exp, BitLogicExp): - circ.add_classicalexpbox_bit(pred_exp, [condition_bit]) + wexpr, args = wired_clexpr_from_logic_exp(pred_exp, [condition_bit]) + circ.add_clexpr(wexpr, args) return condition_bit, bool(pred_val) assert isinstance(pred_exp, (RegLogicExp, BitRegister)) @@ -99,10 +97,11 @@ def _add_condition( int(r_name.split("_")[-1]) for r_name in existing_reg_names ) next_index = max(existing_reg_indices, default=-1) + 1 - temp_reg = BitRegister(f"{_TEMP_BIT_REG_BASE}_{next_index}", _TEMP_REG_SIZE) + temp_reg = BitRegister(f"{_TEMP_BIT_REG_BASE}_{next_index}", min_reg_size) circ.add_c_register(temp_reg) - target_bits = temp_reg.to_list()[:min_reg_size] - circ.add_classicalexpbox_register(pred_exp, target_bits) + target_bits = temp_reg.to_list() + wexpr, args = wired_clexpr_from_logic_exp(pred_exp, target_bits) + circ.add_clexpr(wexpr, args) elif isinstance(pred_exp, BitRegister): target_bits = pred_exp.to_list() diff --git a/pytket/pytket/circuit/decompose_classical.py b/pytket/pytket/circuit/decompose_classical.py index 1b525bc4f1..ab1fe8a487 100644 --- a/pytket/pytket/circuit/decompose_classical.py +++ b/pytket/pytket/circuit/decompose_classical.py @@ -358,5 +358,12 @@ def _decompose_expressions(circ: Circuit) -> Tuple[Circuit, bool]: # add_gate doesn't work for metaops newcirc.add_barrier(args) else: + for arg in args: + if ( + isinstance(arg, Bit) + and arg.reg_name != "_w" # workaround: this shouldn't be type Bit + and arg not in newcirc.bits + ): + newcirc.add_bit(arg) newcirc.add_gate(op, args, **kwargs) return newcirc, modified diff --git a/pytket/tests/classical_test.py b/pytket/tests/classical_test.py index e365074270..0139ad8f4e 100644 --- a/pytket/tests/classical_test.py +++ b/pytket/tests/classical_test.py @@ -1037,7 +1037,7 @@ def test_regpredicate(condition: PredicateExp) -> None: circ.add_bit(inp, reject_dups=False) circ.X(qb, condition=condition) - assert circ.n_gates_of_type(OpType.ClassicalExpBox) == 1 + assert circ.n_gates_of_type(OpType.ClassicalExpBox) == 0 newcirc = circ.copy() DecomposeClassicalExp().apply(newcirc) @@ -1115,178 +1115,6 @@ def check_serialization_roundtrip(circ: Circuit) -> None: assert circ_from_dict.to_dict() == circ_dict -def test_decomposition_known() -> None: - bits = [Bit(i) for i in range(10)] - registers = [BitRegister(c, 3) for c in "abdefghijk"] - - qreg = QubitRegister("q_", 10) - circ = Circuit() - conditioned_circ = Circuit() - decomposed_circ = Circuit() - - for c in (circ, conditioned_circ, decomposed_circ): - for b in bits: - c.add_bit(b) - for br in registers: - for b in br.to_list(): - c.add_bit(b, reject_dups=False) - c.add_q_register(qreg.name, qreg.size) - - circ.H(qreg[0], condition=bits[0]) - circ.X(qreg[0], condition=if_bit(bits[1])) - circ.S(qreg[0]) - circ.T(qreg[1], condition=if_not_bit(bits[2])) - circ.Z(qreg[0], condition=(bits[2] & bits[3])) - circ.Z(qreg[1], condition=if_not_bit(bits[3] & bits[4])) - big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8] - # ^ no need for parantheses as python operator precedence - # will enforce correct precedence in LogicExp - circ.CX(qreg[0], qreg[1]) - circ.CX(qreg[1], qreg[2], condition=big_exp) - - circ.add_barrier(qreg.to_list()) - - circ.H(qreg[2], condition=reg_eq(registers[0], 3)) - circ.X(qreg[3], condition=reg_lt(registers[1], 6)) - circ.Y(qreg[4], condition=reg_neq(registers[2], 5)) - circ.Z(qreg[5], condition=reg_gt(registers[3], 3)) - circ.S(qreg[6], condition=reg_leq(registers[4], 6)) - circ.T(qreg[7], condition=reg_geq(registers[5], 3)) - big_reg_exp = registers[4] & registers[3] | registers[6] ^ registers[7] - circ.CX(qreg[3], qreg[4], condition=reg_eq(big_reg_exp, 3)) - - circ.add_classicalexpbox_bit( - bits[4] | bits[5] & bits[3], [bits[0]], condition=bits[1] - ) - check_serialization_roundtrip(circ) - - temp_bits = BitRegister(_TEMP_BIT_NAME, 64) - - def temp_reg(i: int) -> BitRegister: - return BitRegister(f"{_TEMP_BIT_REG_BASE}_{i}", 64) - - for b in (temp_bits[i] for i in range(0, 10)): - conditioned_circ.add_bit(b) - - for t_r in (temp_reg(i) for i in range(0, 1)): - conditioned_circ.add_c_register(t_r.name, t_r.size) - - # relies on existing interface for adding conditionals - # may need a more low level interface for that if we decide to get rid of it - conditioned_circ.H(qreg[0], condition_bits=[bits[0]], condition_value=1) - conditioned_circ.X(qreg[0], condition_bits=[bits[1]], condition_value=1) - conditioned_circ.S(qreg[0]) - conditioned_circ.T(qreg[1], condition_bits=[bits[2]], condition_value=0) - - conditioned_circ.add_classicalexpbox_bit((bits[2] & bits[3]), [temp_bits[0]]) - conditioned_circ.Z(qreg[0], condition_bits=[temp_bits[0]], condition_value=1) - conditioned_circ.add_classicalexpbox_bit((bits[3] & bits[4]), [temp_bits[1]]) - conditioned_circ.Z(qreg[1], condition_bits=[temp_bits[1]], condition_value=0) - conditioned_circ.CX(qreg[0], qreg[1]) - conditioned_circ.add_classicalexpbox_bit(big_exp, [temp_bits[2]]) - conditioned_circ.CX( - qreg[1], qreg[2], condition_bits=[temp_bits[2]], condition_value=1 - ) - - conditioned_circ.add_barrier(qreg.to_list()) - - registers_lists = [reg.to_list() for reg in registers] - - conditioned_circ.add_c_range_predicate(3, 3, registers_lists[0], temp_bits[3]) - conditioned_circ.H(qreg[2], condition_bits=[temp_bits[3]], condition_value=1) - conditioned_circ.add_c_range_predicate(0, 5, registers_lists[1], temp_bits[4]) - conditioned_circ.X(qreg[3], condition_bits=[temp_bits[4]], condition_value=1) - conditioned_circ.add_c_range_predicate(5, 5, registers_lists[2], temp_bits[5]) - conditioned_circ.Y(qreg[4], condition_bits=[temp_bits[5]], condition_value=0) - conditioned_circ.add_c_range_predicate( - 4, 18446744073709551615, registers_lists[3], temp_bits[6] - ) - conditioned_circ.Z(qreg[5], condition_bits=[temp_bits[6]], condition_value=1) - conditioned_circ.add_c_range_predicate(0, 6, registers_lists[4], temp_bits[7]) - conditioned_circ.S(qreg[6], condition_bits=[temp_bits[7]], condition_value=1) - conditioned_circ.add_c_range_predicate( - 3, 18446744073709551615, registers_lists[5], temp_bits[8] - ) - conditioned_circ.T(qreg[7], condition_bits=[temp_bits[8]], condition_value=1) - - temp_reg_bits = [temp_reg(0)[i] for i in range(3)] - conditioned_circ.add_classicalexpbox_register(big_reg_exp, temp_reg_bits) - conditioned_circ.add_c_range_predicate(3, 3, temp_reg_bits, temp_bits[9]) - conditioned_circ.CX( - qreg[3], qreg[4], condition_bits=[temp_bits[9]], condition_value=1 - ) - conditioned_circ.add_classicalexpbox_bit( - bits[4] | bits[5] & bits[3], [bits[0]], condition=bits[1] - ) - - assert compare_commands_box(circ, conditioned_circ) - - for b in (temp_bits[i] for i in range(0, 12)): - decomposed_circ.add_bit(b) - - decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_0", 3)) - decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_1", 64)) - decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_2", 64)) - - decomposed_circ.H(qreg[0], condition_bits=[bits[0]], condition_value=1) - decomposed_circ.X(qreg[0], condition_bits=[bits[1]], condition_value=1) - decomposed_circ.S(qreg[0]) - decomposed_circ.T(qreg[1], condition_bits=[bits[2]], condition_value=0) - decomposed_circ.add_c_and(bits[2], bits[3], temp_bits[0]) - decomposed_circ.Z(qreg[0], condition_bits=[temp_bits[0]], condition_value=1) - decomposed_circ.add_c_and(bits[3], bits[4], temp_bits[1]) - decomposed_circ.Z(qreg[1], condition_bits=[temp_bits[1]], condition_value=0) - decomposed_circ.CX(qreg[0], qreg[1]) - decomposed_circ.add_c_range_predicate(3, 3, registers_lists[0], temp_bits[3]) - decomposed_circ.add_c_range_predicate(0, 5, registers_lists[1], temp_bits[4]) - decomposed_circ.add_c_range_predicate(5, 5, registers_lists[2], temp_bits[5]) - decomposed_circ.add_c_range_predicate( - 4, 18446744073709551615, registers_lists[3], temp_bits[6] - ) - decomposed_circ.add_c_range_predicate(0, 6, registers_lists[4], temp_bits[7]) - decomposed_circ.add_c_range_predicate( - 3, 18446744073709551615, registers_lists[5], temp_bits[8] - ) - - decomposed_circ.add_c_xor(bits[5], bits[6], temp_bits[10]) - decomposed_circ.add_c_and(bits[7], bits[8], temp_bits[11]) - decomposed_circ.add_c_or(bits[4], temp_bits[10], temp_bits[10]) - decomposed_circ.add_c_or(temp_bits[10], temp_bits[11], temp_bits[2]) - decomposed_circ.CX( - qreg[1], qreg[2], condition_bits=[temp_bits[2]], condition_value=1 - ) - - decomposed_circ.add_barrier(qreg.to_list()) - - decomposed_circ.H(qreg[2], condition_bits=[temp_bits[3]], condition_value=1) - decomposed_circ.X(qreg[3], condition_bits=[temp_bits[4]], condition_value=1) - decomposed_circ.Y(qreg[4], condition_bits=[temp_bits[5]], condition_value=0) - decomposed_circ.Z(qreg[5], condition_bits=[temp_bits[6]], condition_value=1) - decomposed_circ.S(qreg[6], condition_bits=[temp_bits[7]], condition_value=1) - decomposed_circ.T(qreg[7], condition_bits=[temp_bits[8]], condition_value=1) - - decomposed_circ.add_c_and_to_registers(registers[4], registers[3], temp_reg(1)) - decomposed_circ.add_c_xor_to_registers(registers[6], registers[7], temp_reg(2)) - decomposed_circ.add_c_or_to_registers( - temp_reg(1), BitRegister(temp_reg(2).name, 3), temp_reg(0) - ) - decomposed_circ.add_c_range_predicate(3, 3, temp_reg(0).to_list()[:3], temp_bits[9]) - decomposed_circ.CX( - qreg[3], qreg[4], condition_bits=[temp_bits[9]], condition_value=1 - ) - decomposed_circ.add_c_and( - bits[5], bits[3], temp_bits[10], condition_bits=[bits[1]], condition_value=1 - ) - decomposed_circ.add_c_or( - bits[4], temp_bits[10], bits[0], condition_bits=[bits[1]], condition_value=1 - ) - check_serialization_roundtrip(decomposed_circ) - circ_copy = circ.copy() - - DecomposeClassicalExp().apply(circ_copy) - assert circ_copy == decomposed_circ - - def test_conditional() -> None: c = Circuit(1, 2) c.H(0, condition_bits=[0, 1], condition_value=3) diff --git a/pytket/tests/compilation_test.py b/pytket/tests/compilation_test.py index fd87e95f2b..0da482dc17 100644 --- a/pytket/tests/compilation_test.py +++ b/pytket/tests/compilation_test.py @@ -309,7 +309,7 @@ def test_resize_scratch_registers() -> None: reg_a = circ.add_c_register("a", 1) reg_b = circ.add_c_register("b", 1) circ.X(0, condition=reg_eq(reg_a ^ reg_b, 1)) - assert circ.get_c_register(f"{_TEMP_BIT_REG_BASE}_0").size == 64 + assert circ.get_c_register(f"{_TEMP_BIT_REG_BASE}_0").size == 1 c_compiled = circ.copy() scratch_reg_resize_pass(10).apply(c_compiled) assert circ == c_compiled diff --git a/tket/conanfile.py b/tket/conanfile.py index 31917b2906..a258c5c5ee 100644 --- a/tket/conanfile.py +++ b/tket/conanfile.py @@ -23,7 +23,7 @@ class TketConan(ConanFile): name = "tket" - version = "1.3.37" + version = "1.3.38" package_type = "library" license = "Apache 2" homepage = "https://github.com/CQCL/tket" diff --git a/tket/include/tket/Ops/ClExpr.hpp b/tket/include/tket/Ops/ClExpr.hpp index 1adf73d696..712aad2af6 100644 --- a/tket/include/tket/Ops/ClExpr.hpp +++ b/tket/include/tket/Ops/ClExpr.hpp @@ -19,6 +19,7 @@ * @brief Classical expressions involving bits and registers */ +#include #include #include #include @@ -124,7 +125,7 @@ void from_json(const nlohmann::json& j, ClExprVar& var); /** * A term in a classical expression (either a constant or a variable) */ -typedef std::variant ClExprTerm; +typedef std::variant ClExprTerm; std::ostream& operator<<(std::ostream& os, const ClExprTerm& term); diff --git a/tket/src/Ops/ClExpr.cpp b/tket/src/Ops/ClExpr.cpp index 0e0f1fdd5a..880bbd986d 100644 --- a/tket/src/Ops/ClExpr.cpp +++ b/tket/src/Ops/ClExpr.cpp @@ -15,6 +15,7 @@ #include "tket/Ops/ClExpr.hpp" #include +#include #include #include #include @@ -131,7 +132,7 @@ void from_json(const nlohmann::json& j, ClExprVar& var) { } std::ostream& operator<<(std::ostream& os, const ClExprTerm& term) { - if (const int* n = std::get_if(&term)) { + if (const uint64_t* n = std::get_if(&term)) { return os << *n; } else { ClExprVar var = std::get(term); @@ -141,7 +142,7 @@ std::ostream& operator<<(std::ostream& os, const ClExprTerm& term) { void to_json(nlohmann::json& j, const ClExprTerm& term) { nlohmann::json inner_j; - if (const int* n = std::get_if(&term)) { + if (const uint64_t* n = std::get_if(&term)) { j["type"] = "int"; inner_j = *n; } else { @@ -155,7 +156,7 @@ void to_json(nlohmann::json& j, const ClExprTerm& term) { void from_json(const nlohmann::json& j, ClExprTerm& term) { const std::string termtype = j.at("type").get(); if (termtype == "int") { - term = j.at("term").get(); + term = j.at("term").get(); } else { TKET_ASSERT(termtype == "var"); term = j.at("term").get(); diff --git a/tket/test/src/test_ClExpr.cpp b/tket/test/src/test_ClExpr.cpp index a49888ce5e..248a8cfbbe 100644 --- a/tket/test/src/test_ClExpr.cpp +++ b/tket/test/src/test_ClExpr.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include #include @@ -127,7 +128,7 @@ SCENARIO("Serialization and stringification") { REQUIRE(var_reg1 == var_reg); } GIVEN("ClExprTerm") { - ClExprTerm term_int = 7; + ClExprTerm term_int = uint64_t{7}; ClExprTerm term_var = ClRegVar{5}; std::stringstream ss; ss << term_int << ", " << term_var; @@ -140,14 +141,14 @@ SCENARIO("Serialization and stringification") { REQUIRE(term_var1 == term_var); } GIVEN("Vector of ClExprArg (1)") { - std::vector args{ClRegVar{2}, int{3}}; + std::vector args{ClRegVar{2}, uint64_t{3}}; nlohmann::json j = args; std::vector args1 = j.get>(); REQUIRE(args == args1); } GIVEN("ClExpr (1)") { // r0 + 7 - ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, int{7}}); + ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, uint64_t{7}}); std::stringstream ss; ss << expr; REQUIRE(ss.str() == "add(r0, 7)"); @@ -156,7 +157,7 @@ SCENARIO("Serialization and stringification") { REQUIRE(expr1 == expr); } GIVEN("Vector of ClExprArg (2)") { - ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, int{8}}); + ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, uint64_t{8}}); std::vector args{expr}; nlohmann::json j = args; std::vector args1 = j.get>(); @@ -165,7 +166,7 @@ SCENARIO("Serialization and stringification") { GIVEN("ClExpr (2)") { // (r0 + r1) / (r2 * 3) ClExpr numer(ClOp::RegAdd, {ClRegVar{0}, ClRegVar{1}}); - ClExpr denom(ClOp::RegMul, {ClRegVar{2}, int{3}}); + ClExpr denom(ClOp::RegMul, {ClRegVar{2}, uint64_t{3}}); ClExpr expr(ClOp::RegDiv, {numer, denom}); std::stringstream ss; ss << expr;