From c2dd14a46cc9ce784efd8644634a2ff4e42588b3 Mon Sep 17 00:00:00 2001 From: PabloAndresCQ Date: Mon, 11 Nov 2024 19:10:49 +0000 Subject: [PATCH 1/5] Added support for ClExpr --- .../cutensornet/structured_state/classical.py | 89 ++++++++ tests/test_structured_state_conditionals.py | 193 +++++++++++++++++- 2 files changed, 281 insertions(+), 1 deletion(-) diff --git a/pytket/extensions/cutensornet/structured_state/classical.py b/pytket/extensions/cutensornet/structured_state/classical.py index d49ac4d9..047787eb 100644 --- a/pytket/extensions/cutensornet/structured_state/classical.py +++ b/pytket/extensions/cutensornet/structured_state/classical.py @@ -22,11 +22,13 @@ SetBitsOp, CopyBitsOp, RangePredicateOp, + ClExprOp, ClassicalExpBox, LogicExp, BitWiseOp, RegWiseOp, ) +from pytket._tket.circuit import ClExpr, ClOp, ClBitVar, ClRegVar ExtendedLogicExp = Union[LogicExp, Bit, BitRegister, int] @@ -56,6 +58,20 @@ def apply_classical_command( # Check that the value is in the range bits_dict[res_bit] = val >= op.lower and val <= op.upper + elif isinstance(op, ClExprOp): + # Convert bit_posn to dictionary of `ClBitVar` index to its value + bitvar_val = {var_id: int(bits_dict[args[bit_pos]]) for var_id, bit_pos in op.expr.bit_posn.items()} + # Convert reg_posn to dictionary of `ClRegVar` index to its value + regvar_val = {var_id: from_little_endian([bits_dict[args[bit_pos]] for bit_pos in reg_pos_list]) for var_id, reg_pos_list in op.expr.reg_posn.items()} + result = evaluate_clexpr(op.expr.expr, bitvar_val, regvar_val) + + # The result is an int in little-endian encoding. We update the + # output register accordingly. + for bit_pos in op.expr.output_posn: + bits_dict[args[bit_pos]] = (result % 2) == 1 + result = result >> 1 + assert result == 0 # All bits consumed + elif isinstance(op, ClassicalExpBox): the_exp = op.get_exp() result = evaluate_logic_exp(the_exp, bits_dict) @@ -74,6 +90,79 @@ def apply_classical_command( raise NotImplementedError(f"Commands of type {op.type} are not supported.") +def evaluate_clexpr(expr: ClExpr, bitvar_val: dict[int, int], regvar_val: dict[int, int]) -> int: + """Recursive evaluation of a ClExpr.""" + + # Evaluate arguments to operation + args_val = [] + for arg in expr.args: + if isinstance(arg, int): + value = arg + elif isinstance(arg, ClBitVar): + value = bitvar_val[arg.index] + elif isinstance(arg, ClRegVar): + value = regvar_val[arg.index] + elif isinstance(arg, ClExpr): + value = evaluate_clexpr(arg, bitvar_val, regvar_val) + else: + raise Exception(f"Unrecognised argument type of ClExpr: {type(arg)}.") + + args_val.append(value) + + # Apply the operation at the root of this ClExpr + if expr.op in [ClOp.BitAnd, ClOp.RegAnd]: + result = args_val[0] & args_val[1] + elif expr.op in [ClOp.BitOr, ClOp.RegOr]: + result = args_val[0] | args_val[1] + elif expr.op in [ClOp.BitXor, ClOp.RegXor]: + result = args_val[0] ^ args_val[1] + elif expr.op in [ClOp.BitEq, ClOp.RegEq]: + result = int(args_val[0] == args_val[1]) + elif expr.op in [ClOp.BitNeq, ClOp.RegNeq]: + result = int(args_val[0] != args_val[1]) + elif expr.op == ClOp.RegGeq: + result = int(args_val[0] >= args_val[1]) + elif expr.op == ClOp.RegGt: + result = int(args_val[0] > args_val[1]) + elif expr.op == ClOp.RegLeq: + result = int(args_val[0] <= args_val[1]) + elif expr.op == ClOp.RegLt: + result = int(args_val[0] < args_val[1]) + elif expr.op == ClOp.BitNot: + result = 1 - args_val[0] + # elif expr.op == ClOp.RegNot: + # result = int(args_val[0] == 0) + elif expr.op in [ClOp.BitZero, ClOp.RegZero]: + result = 0 + elif expr.op in [ClOp.BitOne, ClOp.RegOne]: + result = 1 + # elif expr.op == ClOp.RegAdd: + # result = args_val[0] + args_val[1] + # elif expr.op == ClOp.RegSub: + # result = args_val[0] - args_val[1] + # elif expr.op == ClOp.RegMul: + # result = args_val[0] * args_val[1] + # elif expr.op == ClOp.RegPow: + # result = int(args_val[0] ** args_val[1]) + elif expr.op == ClOp.RegRsh: + result = args_val[0] >> args_val[1] + # elif expr.op == ClOp.RegNeg: + # result = -args_val[0] + else: + # TODO: Currently not supporting ClOp's RegDiv since it does not return int, + # so I am unsure what the semantic is meant to be. + # TODO: I don't now what to do with RegNot, since input + # is not guaranteed to be 0 or 1. + # TODO: It is not clear what to do with overflow of ADD, etc. + # so I have decided to not support them for now. + raise NotImplementedError( + f"Evaluation of {expr.op} not supported in ClExpr ", + "by pytket-cutensornet.", + ) + + return result + + def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int: """Recursive evaluation of a LogicExp.""" diff --git a/tests/test_structured_state_conditionals.py b/tests/test_structured_state_conditionals.py index c5b6a116..8de0b1d3 100644 --- a/tests/test_structured_state_conditionals.py +++ b/tests/test_structured_state_conditionals.py @@ -12,6 +12,7 @@ reg_eq, ) from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp +from pytket.circuit.clexpr import wired_clexpr_from_logic_exp from pytket.extensions.cutensornet.structured_state import ( CuTensorNetHandle, @@ -25,6 +26,35 @@ # (see https://github.com/CQCL/pytket-qir/blob/main/tests/conditional_test.py) # Further down, there are tests to check that the simulation works correctly. +def test_circuit_with_clexpr_i() -> None: + # test conditional handling + + circ = Circuit(3) + a = circ.add_c_register("a", 5) + b = circ.add_c_register("b", 5) + c = circ.add_c_register("c", 5) + d = circ.add_c_register("d", 5) + circ.H(0) + wexpr, args = wired_clexpr_from_logic_exp(a | b, c) + circ.add_clexpr(wexpr, args) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d) + circ.add_clexpr(wexpr, args) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d) + circ.add_clexpr(wexpr, args, condition=a[4]) + circ.H(0) + circ.Measure(Qubit(0), d[4]) + circ.H(1) + circ.Measure(Qubit(1), d[3]) + circ.H(2) + circ.Measure(Qubit(2), d[2]) + + with CuTensorNetHandle() as libhandle: + cfg = Config() + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + def test_circuit_with_classicalexpbox_i() -> None: # test conditional handling @@ -52,6 +82,35 @@ def test_circuit_with_classicalexpbox_i() -> None: assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) assert state.get_fidelity() == 1.0 +def test_circuit_with_clexpr_ii() -> None: + # test conditional handling with else case + + circ = Circuit(3) + a = circ.add_c_register("a", 5) + b = circ.add_c_register("b", 5) + c = circ.add_c_register("c", 5) + d = circ.add_c_register("d", 5) + circ.H(0) + wexpr, args = wired_clexpr_from_logic_exp(a | b, c) + circ.add_clexpr(wexpr, args) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d) + circ.add_clexpr(wexpr, args) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d) + circ.add_clexpr(wexpr, args, condition=if_not_bit(a[4])) + circ.H(0) + circ.Measure(Qubit(0), d[4]) + circ.H(1) + circ.Measure(Qubit(1), d[3]) + circ.H(2) + circ.Measure(Qubit(2), d[2]) + + with CuTensorNetHandle() as libhandle: + cfg = Config() + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + def test_circuit_with_classicalexpbox_ii() -> None: # test conditional handling with else case @@ -81,6 +140,35 @@ def test_circuit_with_classicalexpbox_ii() -> None: assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) assert state.get_fidelity() == 1.0 +@pytest.mark.skip(reason="Currently not supporting arithmetic operations in ClExpr") +def test_circuit_with_clexpr_iii() -> None: + # test complicated conditions and recursive classical op + + circ = Circuit(2) + + a = circ.add_c_register("a", 15) + b = circ.add_c_register("b", 15) + c = circ.add_c_register("c", 15) + d = circ.add_c_register("d", 15) + e = circ.add_c_register("e", 15) + + circ.H(0) + bits = [Bit(i) for i in range(10)] + big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8] + circ.H(0, condition=big_exp) + + wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c) + circ.add_clexpr(wexpr, args) + wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e) + circ.add_clexpr(wexpr, args) + + with CuTensorNetHandle() as libhandle: + cfg = Config() + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + @pytest.mark.skip(reason="Currently not supporting arithmetic operations in LogicExp") def test_circuit_with_classicalexpbox_iii() -> None: @@ -239,6 +327,32 @@ def test_pytket_qir_conditional_10() -> None: assert state.get_fidelity() == 1.0 +def test_pytket_qir_conditional_11() -> None: + box_circ = Circuit(4) + box_circ.X(0) + box_circ.Y(1) + box_circ.Z(2) + box_circ.H(3) + box_c = box_circ.add_c_register("c", 5) + + box_circ.H(0) + + wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c) + box_circ.add_clexpr(wexpr, args) + + cbox = CircBox(box_circ) + d = Circuit(4, 5) + a = d.add_c_register("a", 4) + d.add_circbox(cbox, [0, 2, 1, 3, 0, 1, 2, 3, 4], condition=a[0]) + + with CuTensorNetHandle() as libhandle: + cfg = Config() + state = simulate(libhandle, d, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + + def test_circuit_with_conditional_gate_v() -> None: # test conditional with no register @@ -430,7 +544,84 @@ def test_repeat_until_success_i() -> None: assert np.allclose(target_state, output_state) -def test_repeat_until_success_ii() -> None: +def test_repeat_until_success_ii_clexpr() -> None: + # From Figure 1(c) of https://arxiv.org/pdf/1311.1074 + + attempts = 100 + + circ = Circuit() + qin = circ.add_q_register("qin", 1) + qaux = circ.add_q_register("aux", 2) + flag = circ.add_c_register("flag", 3) + circ.add_c_setbits([True, True], [flag[0], flag[1]]) # Set flag bits to 11 + circ.H(qin[0]) # Use to convert gate to sqrt(1/5)*I + i*sqrt(4/5)*X (i.e. Z -> X) + + for _ in range(attempts): + wexpr, args = wired_clexpr_from_logic_exp( + flag[0] | flag[1], [flag[2]] # Success if both are zero + ) + circ.add_clexpr(wexpr, args) + + circ.add_gate( + OpType.Reset, [qaux[0]], condition_bits=[flag[2]], condition_value=1 + ) + circ.add_gate( + OpType.Reset, [qaux[1]], condition_bits=[flag[2]], condition_value=1 + ) + circ.add_gate(OpType.H, [qaux[0]], condition_bits=[flag[2]], condition_value=1) + circ.add_gate(OpType.H, [qaux[1]], condition_bits=[flag[2]], condition_value=1) + + circ.add_gate(OpType.T, [qin[0]], condition_bits=[flag[2]], condition_value=1) + circ.add_gate(OpType.Z, [qin[0]], condition_bits=[flag[2]], condition_value=1) + circ.add_gate( + OpType.Tdg, [qaux[0]], condition_bits=[flag[2]], condition_value=1 + ) + circ.add_gate( + OpType.CX, [qaux[1], qaux[0]], condition_bits=[flag[2]], condition_value=1 + ) + circ.add_gate(OpType.T, [qaux[0]], condition_bits=[flag[2]], condition_value=1) + circ.add_gate( + OpType.CX, [qin[0], qaux[1]], condition_bits=[flag[2]], condition_value=1 + ) + circ.add_gate(OpType.T, [qaux[1]], condition_bits=[flag[2]], condition_value=1) + + circ.add_gate(OpType.H, [qaux[0]], condition_bits=[flag[2]], condition_value=1) + circ.add_gate(OpType.H, [qaux[1]], condition_bits=[flag[2]], condition_value=1) + circ.Measure(qaux[0], flag[0], condition_bits=[flag[2]], condition_value=1) + circ.Measure(qaux[1], flag[1], condition_bits=[flag[2]], condition_value=1) + + # From chat with Silas and exploring the RUS as a block matrix, we have noticed + # that the circuit is missing an X correction when this condition is satisfied + wexpr, args = wired_clexpr_from_logic_exp(flag[0] ^ flag[1], [flag[2]]) + circ.add_clexpr(wexpr, args) + circ.add_gate(OpType.Z, [qin[0]], condition_bits=[flag[2]], condition_value=1) + + circ.H(qin[0]) # Use to convert gate to sqrt(1/5)*I + i*sqrt(4/5)*X (i.e. Z -> X) + + with CuTensorNetHandle() as libhandle: + cfg = Config() + + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + + # All of the flag bits should have turned False + assert all(not state.get_bits()[bit] for bit in flag) + # The auxiliary qubits should be in state |0> + prob = state.postselect({qaux[0]: 0, qaux[1]: 0}) + assert np.isclose(prob, 1.0) + + target_state = [np.sqrt(1 / 5), np.sqrt(4 / 5) * 1j] + output_state = state.get_statevector() + # As indicated in the paper, the gate is implemented up to global phase + global_phase = target_state[0] / output_state[0] + assert np.isclose(abs(global_phase), 1.0) + output_state *= global_phase + assert np.allclose(target_state, output_state) + + +def test_repeat_until_success_ii_classicalexpblox() -> None: # From Figure 1(c) of https://arxiv.org/pdf/1311.1074 attempts = 100 From 67f5378fc80f11a93b0f9be6847f3be17ada4412 Mon Sep 17 00:00:00 2001 From: PabloAndresCQ Date: Mon, 11 Nov 2024 19:22:19 +0000 Subject: [PATCH 2/5] Linting --- .../cutensornet/structured_state/classical.py | 16 +++++++++++--- tests/test_structured_state_conditionals.py | 21 +++++++++++-------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/pytket/extensions/cutensornet/structured_state/classical.py b/pytket/extensions/cutensornet/structured_state/classical.py index 047787eb..d7bc9c99 100644 --- a/pytket/extensions/cutensornet/structured_state/classical.py +++ b/pytket/extensions/cutensornet/structured_state/classical.py @@ -60,9 +60,17 @@ def apply_classical_command( elif isinstance(op, ClExprOp): # Convert bit_posn to dictionary of `ClBitVar` index to its value - bitvar_val = {var_id: int(bits_dict[args[bit_pos]]) for var_id, bit_pos in op.expr.bit_posn.items()} + bitvar_val = { + var_id: int(bits_dict[args[bit_pos]]) + for var_id, bit_pos in op.expr.bit_posn.items() + } # Convert reg_posn to dictionary of `ClRegVar` index to its value - regvar_val = {var_id: from_little_endian([bits_dict[args[bit_pos]] for bit_pos in reg_pos_list]) for var_id, reg_pos_list in op.expr.reg_posn.items()} + regvar_val = { + var_id: from_little_endian( + [bits_dict[args[bit_pos]] for bit_pos in reg_pos_list] + ) + for var_id, reg_pos_list in op.expr.reg_posn.items() + } result = evaluate_clexpr(op.expr.expr, bitvar_val, regvar_val) # The result is an int in little-endian encoding. We update the @@ -90,7 +98,9 @@ def apply_classical_command( raise NotImplementedError(f"Commands of type {op.type} are not supported.") -def evaluate_clexpr(expr: ClExpr, bitvar_val: dict[int, int], regvar_val: dict[int, int]) -> int: +def evaluate_clexpr( + expr: ClExpr, bitvar_val: dict[int, int], regvar_val: dict[int, int] +) -> int: """Recursive evaluation of a ClExpr.""" # Evaluate arguments to operation diff --git a/tests/test_structured_state_conditionals.py b/tests/test_structured_state_conditionals.py index 8de0b1d3..53e90f00 100644 --- a/tests/test_structured_state_conditionals.py +++ b/tests/test_structured_state_conditionals.py @@ -26,6 +26,7 @@ # (see https://github.com/CQCL/pytket-qir/blob/main/tests/conditional_test.py) # Further down, there are tests to check that the simulation works correctly. + def test_circuit_with_clexpr_i() -> None: # test conditional handling @@ -35,11 +36,11 @@ def test_circuit_with_clexpr_i() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - wexpr, args = wired_clexpr_from_logic_exp(a | b, c) + wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore circ.add_clexpr(wexpr, args, condition=a[4]) circ.H(0) circ.Measure(Qubit(0), d[4]) @@ -82,6 +83,7 @@ def test_circuit_with_classicalexpbox_i() -> None: assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) assert state.get_fidelity() == 1.0 + def test_circuit_with_clexpr_ii() -> None: # test conditional handling with else case @@ -91,11 +93,11 @@ def test_circuit_with_clexpr_ii() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - wexpr, args = wired_clexpr_from_logic_exp(a | b, c) + wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) + wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore circ.add_clexpr(wexpr, args, condition=if_not_bit(a[4])) circ.H(0) circ.Measure(Qubit(0), d[4]) @@ -140,6 +142,7 @@ def test_circuit_with_classicalexpbox_ii() -> None: assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) assert state.get_fidelity() == 1.0 + @pytest.mark.skip(reason="Currently not supporting arithmetic operations in ClExpr") def test_circuit_with_clexpr_iii() -> None: # test complicated conditions and recursive classical op @@ -157,9 +160,9 @@ def test_circuit_with_clexpr_iii() -> None: big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8] circ.H(0, condition=big_exp) - wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c) + wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c) # type: ignore circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e) + wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e) # type: ignore circ.add_clexpr(wexpr, args) with CuTensorNetHandle() as libhandle: @@ -337,7 +340,7 @@ def test_pytket_qir_conditional_11() -> None: box_circ.H(0) - wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c) + wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c) # type: ignore box_circ.add_clexpr(wexpr, args) cbox = CircBox(box_circ) From 1353e880715d1aa1120cb1faa5ef499c97f77b58 Mon Sep 17 00:00:00 2001 From: PabloAndresCQ Date: Tue, 12 Nov 2024 11:14:34 +0000 Subject: [PATCH 3/5] Bumped pytket version dependency --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 23e033b2..0f4a3927 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ license="Apache 2", packages=find_namespace_packages(include=["pytket.*"]), include_package_data=True, - install_requires=["pytket >= 1.33.0", "networkx >= 2.8.8"], + install_requires=["pytket >= 1.34.0", "networkx >= 2.8.8"], classifiers=[ "Environment :: Console", "Programming Language :: Python :: 3.10", From d66eba35e97e32602e19e4eab364ff852e318f36 Mon Sep 17 00:00:00 2001 From: PabloAndresCQ Date: Tue, 12 Nov 2024 11:14:54 +0000 Subject: [PATCH 4/5] Updated changelog --- docs/changelog.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index db8f665a..fee08d40 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,6 +3,12 @@ Changelog ~~~~~~~~~ +Unreleased +---------- + +* Updated pytket version requirement to 1.34. +* Now supporting ``ClExpr`` operations (the new version of tket's ``ClassicalExpBox``). + 0.10.0 (October 2024) --------------------- From 5da5d1943e98ff9fa3fd614a2d46943cfa0639cd Mon Sep 17 00:00:00 2001 From: PabloAndresCQ Date: Tue, 12 Nov 2024 16:26:40 +0000 Subject: [PATCH 5/5] Added suggestions from code review --- .../cutensornet/structured_state/classical.py | 69 ++++++++++----- tests/test_structured_state_conditionals.py | 87 ++++++++++++++----- 2 files changed, 111 insertions(+), 45 deletions(-) diff --git a/pytket/extensions/cutensornet/structured_state/classical.py b/pytket/extensions/cutensornet/structured_state/classical.py index d7bc9c99..474529ab 100644 --- a/pytket/extensions/cutensornet/structured_state/classical.py +++ b/pytket/extensions/cutensornet/structured_state/classical.py @@ -71,14 +71,26 @@ def apply_classical_command( ) for var_id, reg_pos_list in op.expr.reg_posn.items() } - result = evaluate_clexpr(op.expr.expr, bitvar_val, regvar_val) + # Identify number of bits on each register + regvar_size = { + var_id: len(reg_pos_list) + for var_id, reg_pos_list in op.expr.reg_posn.items() + } + # Identify number of bits in output register + output_size = len(op.expr.output_posn) + result = evaluate_clexpr( + op.expr.expr, bitvar_val, regvar_val, regvar_size, output_size + ) # The result is an int in little-endian encoding. We update the # output register accordingly. for bit_pos in op.expr.output_posn: bits_dict[args[bit_pos]] = (result % 2) == 1 result = result >> 1 - assert result == 0 # All bits consumed + # If there has been overflow in the operations, error out. + # This can be detected if `result != 0` + if result != 0: + raise ValueError("Evaluation of the ClExpr resulted in overflow.") elif isinstance(op, ClassicalExpBox): the_exp = op.get_exp() @@ -99,7 +111,11 @@ def apply_classical_command( def evaluate_clexpr( - expr: ClExpr, bitvar_val: dict[int, int], regvar_val: dict[int, int] + expr: ClExpr, + bitvar_val: dict[int, int], + regvar_val: dict[int, int], + regvar_size: dict[int, int], + output_size: int, ) -> int: """Recursive evaluation of a ClExpr.""" @@ -113,7 +129,9 @@ def evaluate_clexpr( elif isinstance(arg, ClRegVar): value = regvar_val[arg.index] elif isinstance(arg, ClExpr): - value = evaluate_clexpr(arg, bitvar_val, regvar_val) + value = evaluate_clexpr( + arg, bitvar_val, regvar_val, regvar_size, output_size + ) else: raise Exception(f"Unrecognised argument type of ClExpr: {type(arg)}.") @@ -140,31 +158,39 @@ def evaluate_clexpr( result = int(args_val[0] < args_val[1]) elif expr.op == ClOp.BitNot: result = 1 - args_val[0] - # elif expr.op == ClOp.RegNot: - # result = int(args_val[0] == 0) + elif expr.op == ClOp.RegNot: # Bit-wise NOT (flip all bits) + n_bits = regvar_size[expr.args[0].index] # type: ignore + result = (2**n_bits - 1) ^ args_val[0] # XOR with all 1s bitstring elif expr.op in [ClOp.BitZero, ClOp.RegZero]: result = 0 - elif expr.op in [ClOp.BitOne, ClOp.RegOne]: + elif expr.op == ClOp.BitOne: result = 1 - # elif expr.op == ClOp.RegAdd: - # result = args_val[0] + args_val[1] - # elif expr.op == ClOp.RegSub: - # result = args_val[0] - args_val[1] - # elif expr.op == ClOp.RegMul: - # result = args_val[0] * args_val[1] - # elif expr.op == ClOp.RegPow: - # result = int(args_val[0] ** args_val[1]) + elif expr.op == ClOp.RegOne: # All 1s bitstring + n_bits = output_size + result = 2**n_bits - 1 + elif expr.op == ClOp.RegAdd: + result = args_val[0] + args_val[1] + elif expr.op == ClOp.RegSub: + if args_val[0] < args_val[1]: + raise NotImplementedError( + "Currently not supporting ClOp.RegSub where the outcome is negative." + ) + result = args_val[0] - args_val[1] + elif expr.op == ClOp.RegMul: + result = args_val[0] * args_val[1] + elif expr.op == ClOp.RegDiv: # floor(a / b) + result = args_val[0] // args_val[1] + elif expr.op == ClOp.RegPow: + result = int(args_val[0] ** args_val[1]) + elif expr.op == ClOp.RegLsh: + result = args_val[0] << args_val[1] elif expr.op == ClOp.RegRsh: result = args_val[0] >> args_val[1] # elif expr.op == ClOp.RegNeg: # result = -args_val[0] else: - # TODO: Currently not supporting ClOp's RegDiv since it does not return int, - # so I am unsure what the semantic is meant to be. - # TODO: I don't now what to do with RegNot, since input - # is not guaranteed to be 0 or 1. - # TODO: It is not clear what to do with overflow of ADD, etc. - # so I have decided to not support them for now. + # TODO: Not supporting RegNeg because I do not know if we have agreed how to + # specify signed ints. raise NotImplementedError( f"Evaluation of {expr.op} not supported in ClExpr ", "by pytket-cutensornet.", @@ -231,4 +257,5 @@ def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int def from_little_endian(bitstring: list[bool]) -> int: """Obtain the integer from the little-endian encoded bitstring (i.e. bitstring [False, True] is interpreted as the integer 2).""" + # TODO: Assumes unisigned integer. What are the specs for signed integers? return sum(1 << i for i, b in enumerate(bitstring) if b) diff --git a/tests/test_structured_state_conditionals.py b/tests/test_structured_state_conditionals.py index 53e90f00..3fdaa177 100644 --- a/tests/test_structured_state_conditionals.py +++ b/tests/test_structured_state_conditionals.py @@ -10,6 +10,9 @@ Bit, if_not_bit, reg_eq, + WiredClExpr, + ClExpr, + ClOp, ) from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp from pytket.circuit.clexpr import wired_clexpr_from_logic_exp @@ -36,11 +39,11 @@ def test_circuit_with_clexpr_i() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(a | b, c.to_list()) circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) circ.add_clexpr(wexpr, args, condition=a[4]) circ.H(0) circ.Measure(Qubit(0), d[4]) @@ -66,9 +69,9 @@ def test_circuit_with_classicalexpbox_i() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - circ.add_classicalexpbox_register(a | b, c) # type: ignore - circ.add_classicalexpbox_register(c | b, d) # type: ignore - circ.add_classicalexpbox_register(c | b, d, condition=a[4]) # type: ignore + circ.add_classicalexpbox_register(a | b, c.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list(), condition=a[4]) circ.H(0) circ.Measure(Qubit(0), d[4]) circ.H(1) @@ -93,11 +96,11 @@ def test_circuit_with_clexpr_ii() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(a | b, c.to_list()) circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(c | b, d.to_list()) circ.add_clexpr(wexpr, args, condition=if_not_bit(a[4])) circ.H(0) circ.Measure(Qubit(0), d[4]) @@ -123,11 +126,9 @@ def test_circuit_with_classicalexpbox_ii() -> None: c = circ.add_c_register("c", 5) d = circ.add_c_register("d", 5) circ.H(0) - circ.add_classicalexpbox_register(a | b, c) # type: ignore - circ.add_classicalexpbox_register(c | b, d) # type: ignore - circ.add_classicalexpbox_register( - c | b, d, condition=if_not_bit(a[4]) # type: ignore - ) + circ.add_classicalexpbox_register(a | b, c.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list()) + circ.add_classicalexpbox_register(c | b, d.to_list(), condition=if_not_bit(a[4])) circ.H(0) circ.Measure(Qubit(0), d[4]) circ.H(1) @@ -160,9 +161,9 @@ def test_circuit_with_clexpr_iii() -> None: big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8] circ.H(0, condition=big_exp) - wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c.to_list()) circ.add_clexpr(wexpr, args) - wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e.to_list()) circ.add_clexpr(wexpr, args) with CuTensorNetHandle() as libhandle: @@ -190,8 +191,8 @@ def test_circuit_with_classicalexpbox_iii() -> None: big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8] circ.H(0, condition=big_exp) - circ.add_classicalexpbox_register(a + b - d, c) # type: ignore - circ.add_classicalexpbox_register(a * b * d * c, e) # type: ignore + circ.add_classicalexpbox_register(a + b - d, c.to_list()) + circ.add_classicalexpbox_register(a * b * d * c, e.to_list()) with CuTensorNetHandle() as libhandle: cfg = Config() @@ -268,7 +269,7 @@ def test_circuit_with_conditional_gate_iv() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_8() -> None: +def test_pytket_basic_conditional_i() -> None: c = Circuit(4) c.H(0) c.H(1) @@ -287,7 +288,7 @@ def test_pytket_qir_conditional_8() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_9() -> None: +def test_pytket_basic_conditional_ii() -> None: c = Circuit(4) c.X(0) c.Y(1) @@ -306,7 +307,7 @@ def test_pytket_qir_conditional_9() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_10() -> None: +def test_pytket_basic_conditional_iii_classicalexpbox() -> None: box_circ = Circuit(4) box_circ.X(0) box_circ.Y(1) @@ -315,7 +316,7 @@ def test_pytket_qir_conditional_10() -> None: box_c = box_circ.add_c_register("c", 5) box_circ.H(0) - box_circ.add_classicalexpbox_register(box_c | box_c, box_c) # type: ignore + box_circ.add_classicalexpbox_register(box_c | box_c, box_c.to_list()) cbox = CircBox(box_circ) d = Circuit(4, 5) @@ -330,7 +331,7 @@ def test_pytket_qir_conditional_10() -> None: assert state.get_fidelity() == 1.0 -def test_pytket_qir_conditional_11() -> None: +def test_pytket_basic_conditional_iii_clexpr() -> None: box_circ = Circuit(4) box_circ.X(0) box_circ.Y(1) @@ -340,7 +341,7 @@ def test_pytket_qir_conditional_11() -> None: box_circ.H(0) - wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c) # type: ignore + wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c.to_list()) box_circ.add_clexpr(wexpr, args) cbox = CircBox(box_circ) @@ -697,3 +698,41 @@ def test_repeat_until_success_ii_classicalexpblox() -> None: assert np.isclose(abs(global_phase), 1.0) output_state *= global_phase assert np.allclose(target_state, output_state) + + +def test_clexpr_on_regs() -> None: + """Non-exhaustive test on some ClOp on registers.""" + circ = Circuit(2) + a = circ.add_c_register("a", 5) + b = circ.add_c_register("b", 5) + c = circ.add_c_register("c", 5) + d = circ.add_c_register("d", 5) + e = circ.add_c_register("e", 5) + + w_expr_regone = WiredClExpr(ClExpr(ClOp.RegOne, []), output_posn=list(range(5))) + circ.add_clexpr(w_expr_regone, a.to_list()) # a = 0b11111 = 31 + circ.add_c_setbits([True, True, False, False, False], b.to_list()) # b = 3 + circ.add_c_setbits([False, True, False, True, False], c.to_list()) # c = 10 + circ.add_clexpr(*wired_clexpr_from_logic_exp(b | c, d.to_list())) # d = 11 + circ.add_clexpr(*wired_clexpr_from_logic_exp(a - d, e.to_list())) # e = 20 + + with CuTensorNetHandle() as libhandle: + cfg = Config() + + state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg) + assert state.is_valid() + assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol) + assert state.get_fidelity() == 1.0 + + # Check the bits + bits_dict = state.get_bits() + a_bitstring = list(bits_dict[bit] for bit in a) + assert all(a_bitstring) # a = 0b11111 + b_bitstring = list(bits_dict[bit] for bit in b) + assert b_bitstring == [True, True, False, False, False] # b = 0b11000 + c_bitstring = list(bits_dict[bit] for bit in c) + assert c_bitstring == [False, True, False, True, False] # c = 0b01010 + d_bitstring = list(bits_dict[bit] for bit in d) + assert d_bitstring == [True, True, False, True, False] # d = 0b11010 + e_bitstring = list(bits_dict[bit] for bit in e) + assert e_bitstring == [False, False, True, False, True] # e = 0b00101