Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deconstruct classical conditions in terms of ClExprOp instead of ClassicalExpBox #1657

Merged
merged 12 commits into from
Nov 6, 2024
4 changes: 2 additions & 2 deletions pytket/binders/circuit/clexpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace tket {

static std::string qasm_bit_repr(
const ClExprTerm &term, const std::map<int, Bit> &input_bits) {
if (const int *n = std::get_if<int>(&term)) {
if (const uint64_t *n = std::get_if<uint64_t>(&term)) {
switch (*n) {
case 0:
return "0";
Expand All @@ -56,7 +56,7 @@ static std::string qasm_bit_repr(

static std::string qasm_reg_repr(
const ClExprTerm &term, const std::map<int, BitRegister> &input_regs) {
if (const int *n = std::get_if<int>(&term)) {
if (const uint64_t *n = std::get_if<uint64_t>(&term)) {
std::stringstream ss;
ss << *n;
return ss.str();
Expand Down
2 changes: 1 addition & 1 deletion pytket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def requirements(self):
self.requires("pybind11_json/0.2.14")
self.requires("symengine/0.12.0")
self.requires("tkassert/0.3.4@tket/stable")
self.requires("tket/1.3.37@tket/stable")
self.requires("tket/1.3.38@tket/stable")
self.requires("tklog/0.3.3@tket/stable")
self.requires("tkrng/0.3.3@tket/stable")
self.requires("tktokenswap/0.3.9@tket/stable")
Expand Down
2 changes: 2 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Features:

* Add `clexpr.check_register_alignments()` method to check register alignments
in `ClExprOp`.
* Use `ClExprOp` instead of `ClassicalExpBox` when deconstructing complex
conditions.

Fixes:

Expand Down
17 changes: 8 additions & 9 deletions pytket/pytket/circuit/add_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@
from typing import Tuple, Union

from pytket.circuit import Bit, Circuit, BitRegister
from pytket._tket.unit_id import (
_TEMP_REG_SIZE,
_TEMP_BIT_NAME,
_TEMP_BIT_REG_BASE,
)
from pytket._tket.unit_id import _TEMP_BIT_NAME, _TEMP_BIT_REG_BASE
from pytket.circuit.clexpr import wired_clexpr_from_logic_exp
from pytket.circuit.logic_exp import (
BitLogicExp,
Constant,
Expand Down Expand Up @@ -79,7 +76,8 @@ def _add_condition(
circ.add_bit(condition_bit)

if isinstance(pred_exp, BitLogicExp):
circ.add_classicalexpbox_bit(pred_exp, [condition_bit])
wexpr, args = wired_clexpr_from_logic_exp(pred_exp, [condition_bit])
circ.add_clexpr(wexpr, args)
return condition_bit, bool(pred_val)

assert isinstance(pred_exp, (RegLogicExp, BitRegister))
Expand All @@ -99,10 +97,11 @@ def _add_condition(
int(r_name.split("_")[-1]) for r_name in existing_reg_names
)
next_index = max(existing_reg_indices, default=-1) + 1
temp_reg = BitRegister(f"{_TEMP_BIT_REG_BASE}_{next_index}", _TEMP_REG_SIZE)
temp_reg = BitRegister(f"{_TEMP_BIT_REG_BASE}_{next_index}", min_reg_size)
circ.add_c_register(temp_reg)
target_bits = temp_reg.to_list()[:min_reg_size]
circ.add_classicalexpbox_register(pred_exp, target_bits)
target_bits = temp_reg.to_list()
wexpr, args = wired_clexpr_from_logic_exp(pred_exp, target_bits)
circ.add_clexpr(wexpr, args)
elif isinstance(pred_exp, BitRegister):
target_bits = pred_exp.to_list()

Expand Down
7 changes: 7 additions & 0 deletions pytket/pytket/circuit/decompose_classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,5 +358,12 @@ def _decompose_expressions(circ: Circuit) -> Tuple[Circuit, bool]:
# add_gate doesn't work for metaops
newcirc.add_barrier(args)
else:
for arg in args:
if (
isinstance(arg, Bit)
and arg.reg_name != "_w" # workaround: this shouldn't be type Bit
and arg not in newcirc.bits
):
newcirc.add_bit(arg)
newcirc.add_gate(op, args, **kwargs)
return newcirc, modified
174 changes: 1 addition & 173 deletions pytket/tests/classical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ def test_regpredicate(condition: PredicateExp) -> None:
circ.add_bit(inp, reject_dups=False)

circ.X(qb, condition=condition)
assert circ.n_gates_of_type(OpType.ClassicalExpBox) == 1
assert circ.n_gates_of_type(OpType.ClassicalExpBox) == 0
newcirc = circ.copy()
DecomposeClassicalExp().apply(newcirc)

Expand Down Expand Up @@ -1115,178 +1115,6 @@ def check_serialization_roundtrip(circ: Circuit) -> None:
assert circ_from_dict.to_dict() == circ_dict


def test_decomposition_known() -> None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Next step after this will be to write a new decomposition pass, at which point I'll add a corresponding test to replace this one.

bits = [Bit(i) for i in range(10)]
registers = [BitRegister(c, 3) for c in "abdefghijk"]

qreg = QubitRegister("q_", 10)
circ = Circuit()
conditioned_circ = Circuit()
decomposed_circ = Circuit()

for c in (circ, conditioned_circ, decomposed_circ):
for b in bits:
c.add_bit(b)
for br in registers:
for b in br.to_list():
c.add_bit(b, reject_dups=False)
c.add_q_register(qreg.name, qreg.size)

circ.H(qreg[0], condition=bits[0])
circ.X(qreg[0], condition=if_bit(bits[1]))
circ.S(qreg[0])
circ.T(qreg[1], condition=if_not_bit(bits[2]))
circ.Z(qreg[0], condition=(bits[2] & bits[3]))
circ.Z(qreg[1], condition=if_not_bit(bits[3] & bits[4]))
big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8]
# ^ no need for parantheses as python operator precedence
# will enforce correct precedence in LogicExp
circ.CX(qreg[0], qreg[1])
circ.CX(qreg[1], qreg[2], condition=big_exp)

circ.add_barrier(qreg.to_list())

circ.H(qreg[2], condition=reg_eq(registers[0], 3))
circ.X(qreg[3], condition=reg_lt(registers[1], 6))
circ.Y(qreg[4], condition=reg_neq(registers[2], 5))
circ.Z(qreg[5], condition=reg_gt(registers[3], 3))
circ.S(qreg[6], condition=reg_leq(registers[4], 6))
circ.T(qreg[7], condition=reg_geq(registers[5], 3))
big_reg_exp = registers[4] & registers[3] | registers[6] ^ registers[7]
circ.CX(qreg[3], qreg[4], condition=reg_eq(big_reg_exp, 3))

circ.add_classicalexpbox_bit(
bits[4] | bits[5] & bits[3], [bits[0]], condition=bits[1]
)
check_serialization_roundtrip(circ)

temp_bits = BitRegister(_TEMP_BIT_NAME, 64)

def temp_reg(i: int) -> BitRegister:
return BitRegister(f"{_TEMP_BIT_REG_BASE}_{i}", 64)

for b in (temp_bits[i] for i in range(0, 10)):
conditioned_circ.add_bit(b)

for t_r in (temp_reg(i) for i in range(0, 1)):
conditioned_circ.add_c_register(t_r.name, t_r.size)

# relies on existing interface for adding conditionals
# may need a more low level interface for that if we decide to get rid of it
conditioned_circ.H(qreg[0], condition_bits=[bits[0]], condition_value=1)
conditioned_circ.X(qreg[0], condition_bits=[bits[1]], condition_value=1)
conditioned_circ.S(qreg[0])
conditioned_circ.T(qreg[1], condition_bits=[bits[2]], condition_value=0)

conditioned_circ.add_classicalexpbox_bit((bits[2] & bits[3]), [temp_bits[0]])
conditioned_circ.Z(qreg[0], condition_bits=[temp_bits[0]], condition_value=1)
conditioned_circ.add_classicalexpbox_bit((bits[3] & bits[4]), [temp_bits[1]])
conditioned_circ.Z(qreg[1], condition_bits=[temp_bits[1]], condition_value=0)
conditioned_circ.CX(qreg[0], qreg[1])
conditioned_circ.add_classicalexpbox_bit(big_exp, [temp_bits[2]])
conditioned_circ.CX(
qreg[1], qreg[2], condition_bits=[temp_bits[2]], condition_value=1
)

conditioned_circ.add_barrier(qreg.to_list())

registers_lists = [reg.to_list() for reg in registers]

conditioned_circ.add_c_range_predicate(3, 3, registers_lists[0], temp_bits[3])
conditioned_circ.H(qreg[2], condition_bits=[temp_bits[3]], condition_value=1)
conditioned_circ.add_c_range_predicate(0, 5, registers_lists[1], temp_bits[4])
conditioned_circ.X(qreg[3], condition_bits=[temp_bits[4]], condition_value=1)
conditioned_circ.add_c_range_predicate(5, 5, registers_lists[2], temp_bits[5])
conditioned_circ.Y(qreg[4], condition_bits=[temp_bits[5]], condition_value=0)
conditioned_circ.add_c_range_predicate(
4, 18446744073709551615, registers_lists[3], temp_bits[6]
)
conditioned_circ.Z(qreg[5], condition_bits=[temp_bits[6]], condition_value=1)
conditioned_circ.add_c_range_predicate(0, 6, registers_lists[4], temp_bits[7])
conditioned_circ.S(qreg[6], condition_bits=[temp_bits[7]], condition_value=1)
conditioned_circ.add_c_range_predicate(
3, 18446744073709551615, registers_lists[5], temp_bits[8]
)
conditioned_circ.T(qreg[7], condition_bits=[temp_bits[8]], condition_value=1)

temp_reg_bits = [temp_reg(0)[i] for i in range(3)]
conditioned_circ.add_classicalexpbox_register(big_reg_exp, temp_reg_bits)
conditioned_circ.add_c_range_predicate(3, 3, temp_reg_bits, temp_bits[9])
conditioned_circ.CX(
qreg[3], qreg[4], condition_bits=[temp_bits[9]], condition_value=1
)
conditioned_circ.add_classicalexpbox_bit(
bits[4] | bits[5] & bits[3], [bits[0]], condition=bits[1]
)

assert compare_commands_box(circ, conditioned_circ)

for b in (temp_bits[i] for i in range(0, 12)):
decomposed_circ.add_bit(b)

decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_0", 3))
decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_1", 64))
decomposed_circ.add_c_register(BitRegister(f"{_TEMP_BIT_REG_BASE}_2", 64))

decomposed_circ.H(qreg[0], condition_bits=[bits[0]], condition_value=1)
decomposed_circ.X(qreg[0], condition_bits=[bits[1]], condition_value=1)
decomposed_circ.S(qreg[0])
decomposed_circ.T(qreg[1], condition_bits=[bits[2]], condition_value=0)
decomposed_circ.add_c_and(bits[2], bits[3], temp_bits[0])
decomposed_circ.Z(qreg[0], condition_bits=[temp_bits[0]], condition_value=1)
decomposed_circ.add_c_and(bits[3], bits[4], temp_bits[1])
decomposed_circ.Z(qreg[1], condition_bits=[temp_bits[1]], condition_value=0)
decomposed_circ.CX(qreg[0], qreg[1])
decomposed_circ.add_c_range_predicate(3, 3, registers_lists[0], temp_bits[3])
decomposed_circ.add_c_range_predicate(0, 5, registers_lists[1], temp_bits[4])
decomposed_circ.add_c_range_predicate(5, 5, registers_lists[2], temp_bits[5])
decomposed_circ.add_c_range_predicate(
4, 18446744073709551615, registers_lists[3], temp_bits[6]
)
decomposed_circ.add_c_range_predicate(0, 6, registers_lists[4], temp_bits[7])
decomposed_circ.add_c_range_predicate(
3, 18446744073709551615, registers_lists[5], temp_bits[8]
)

decomposed_circ.add_c_xor(bits[5], bits[6], temp_bits[10])
decomposed_circ.add_c_and(bits[7], bits[8], temp_bits[11])
decomposed_circ.add_c_or(bits[4], temp_bits[10], temp_bits[10])
decomposed_circ.add_c_or(temp_bits[10], temp_bits[11], temp_bits[2])
decomposed_circ.CX(
qreg[1], qreg[2], condition_bits=[temp_bits[2]], condition_value=1
)

decomposed_circ.add_barrier(qreg.to_list())

decomposed_circ.H(qreg[2], condition_bits=[temp_bits[3]], condition_value=1)
decomposed_circ.X(qreg[3], condition_bits=[temp_bits[4]], condition_value=1)
decomposed_circ.Y(qreg[4], condition_bits=[temp_bits[5]], condition_value=0)
decomposed_circ.Z(qreg[5], condition_bits=[temp_bits[6]], condition_value=1)
decomposed_circ.S(qreg[6], condition_bits=[temp_bits[7]], condition_value=1)
decomposed_circ.T(qreg[7], condition_bits=[temp_bits[8]], condition_value=1)

decomposed_circ.add_c_and_to_registers(registers[4], registers[3], temp_reg(1))
decomposed_circ.add_c_xor_to_registers(registers[6], registers[7], temp_reg(2))
decomposed_circ.add_c_or_to_registers(
temp_reg(1), BitRegister(temp_reg(2).name, 3), temp_reg(0)
)
decomposed_circ.add_c_range_predicate(3, 3, temp_reg(0).to_list()[:3], temp_bits[9])
decomposed_circ.CX(
qreg[3], qreg[4], condition_bits=[temp_bits[9]], condition_value=1
)
decomposed_circ.add_c_and(
bits[5], bits[3], temp_bits[10], condition_bits=[bits[1]], condition_value=1
)
decomposed_circ.add_c_or(
bits[4], temp_bits[10], bits[0], condition_bits=[bits[1]], condition_value=1
)
check_serialization_roundtrip(decomposed_circ)
circ_copy = circ.copy()

DecomposeClassicalExp().apply(circ_copy)
assert circ_copy == decomposed_circ


def test_conditional() -> None:
c = Circuit(1, 2)
c.H(0, condition_bits=[0, 1], condition_value=3)
Expand Down
2 changes: 1 addition & 1 deletion pytket/tests/compilation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def test_resize_scratch_registers() -> None:
reg_a = circ.add_c_register("a", 1)
reg_b = circ.add_c_register("b", 1)
circ.X(0, condition=reg_eq(reg_a ^ reg_b, 1))
assert circ.get_c_register(f"{_TEMP_BIT_REG_BASE}_0").size == 64
assert circ.get_c_register(f"{_TEMP_BIT_REG_BASE}_0").size == 1
c_compiled = circ.copy()
scratch_reg_resize_pass(10).apply(c_compiled)
assert circ == c_compiled
2 changes: 1 addition & 1 deletion tket/conanfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class TketConan(ConanFile):
name = "tket"
version = "1.3.37"
version = "1.3.38"
package_type = "library"
license = "Apache 2"
homepage = "https://github.com/CQCL/tket"
Expand Down
3 changes: 2 additions & 1 deletion tket/include/tket/Ops/ClExpr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* @brief Classical expressions involving bits and registers
*/

#include <cstdint>
#include <map>
#include <nlohmann/detail/macro_scope.hpp>
#include <ostream>
Expand Down Expand Up @@ -124,7 +125,7 @@ void from_json(const nlohmann::json& j, ClExprVar& var);
/**
* A term in a classical expression (either a constant or a variable)
*/
typedef std::variant<int, ClExprVar> ClExprTerm;
typedef std::variant<uint64_t, ClExprVar> ClExprTerm;

std::ostream& operator<<(std::ostream& os, const ClExprTerm& term);

Expand Down
7 changes: 4 additions & 3 deletions tket/src/Ops/ClExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "tket/Ops/ClExpr.hpp"

#include <algorithm>
#include <cstdint>
#include <set>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -131,7 +132,7 @@ void from_json(const nlohmann::json& j, ClExprVar& var) {
}

std::ostream& operator<<(std::ostream& os, const ClExprTerm& term) {
if (const int* n = std::get_if<int>(&term)) {
if (const uint64_t* n = std::get_if<uint64_t>(&term)) {
return os << *n;
} else {
ClExprVar var = std::get<ClExprVar>(term);
Expand All @@ -141,7 +142,7 @@ std::ostream& operator<<(std::ostream& os, const ClExprTerm& term) {

void to_json(nlohmann::json& j, const ClExprTerm& term) {
nlohmann::json inner_j;
if (const int* n = std::get_if<int>(&term)) {
if (const uint64_t* n = std::get_if<uint64_t>(&term)) {
j["type"] = "int";
inner_j = *n;
} else {
Expand All @@ -155,7 +156,7 @@ void to_json(nlohmann::json& j, const ClExprTerm& term) {
void from_json(const nlohmann::json& j, ClExprTerm& term) {
const std::string termtype = j.at("type").get<std::string>();
if (termtype == "int") {
term = j.at("term").get<int>();
term = j.at("term").get<uint64_t>();
} else {
TKET_ASSERT(termtype == "var");
term = j.at("term").get<ClExprVar>();
Expand Down
11 changes: 6 additions & 5 deletions tket/test/src/test_ClExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <catch2/catch_test_macros.hpp>
#include <cstdint>
#include <memory>
#include <nlohmann/json_fwd.hpp>
#include <sstream>
Expand Down Expand Up @@ -127,7 +128,7 @@ SCENARIO("Serialization and stringification") {
REQUIRE(var_reg1 == var_reg);
}
GIVEN("ClExprTerm") {
ClExprTerm term_int = 7;
ClExprTerm term_int = uint64_t{7};
ClExprTerm term_var = ClRegVar{5};
std::stringstream ss;
ss << term_int << ", " << term_var;
Expand All @@ -140,14 +141,14 @@ SCENARIO("Serialization and stringification") {
REQUIRE(term_var1 == term_var);
}
GIVEN("Vector of ClExprArg (1)") {
std::vector<ClExprArg> args{ClRegVar{2}, int{3}};
std::vector<ClExprArg> args{ClRegVar{2}, uint64_t{3}};
nlohmann::json j = args;
std::vector<ClExprArg> args1 = j.get<std::vector<ClExprArg>>();
REQUIRE(args == args1);
}
GIVEN("ClExpr (1)") {
// r0 + 7
ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, int{7}});
ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, uint64_t{7}});
std::stringstream ss;
ss << expr;
REQUIRE(ss.str() == "add(r0, 7)");
Expand All @@ -156,7 +157,7 @@ SCENARIO("Serialization and stringification") {
REQUIRE(expr1 == expr);
}
GIVEN("Vector of ClExprArg (2)") {
ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, int{8}});
ClExpr expr(ClOp::RegAdd, {ClRegVar{0}, uint64_t{8}});
std::vector<ClExprArg> args{expr};
nlohmann::json j = args;
std::vector<ClExprArg> args1 = j.get<std::vector<ClExprArg>>();
Expand All @@ -165,7 +166,7 @@ SCENARIO("Serialization and stringification") {
GIVEN("ClExpr (2)") {
// (r0 + r1) / (r2 * 3)
ClExpr numer(ClOp::RegAdd, {ClRegVar{0}, ClRegVar{1}});
ClExpr denom(ClOp::RegMul, {ClRegVar{2}, int{3}});
ClExpr denom(ClOp::RegMul, {ClRegVar{2}, uint64_t{3}});
ClExpr expr(ClOp::RegDiv, {numer, denom});
std::stringstream ss;
ss << expr;
Expand Down
Loading