Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ClExpr #176

Merged
merged 5 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions pytket/extensions/cutensornet/structured_state/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
SetBitsOp,
CopyBitsOp,
RangePredicateOp,
ClExprOp,
ClassicalExpBox,
LogicExp,
BitWiseOp,
RegWiseOp,
)
from pytket._tket.circuit import ClExpr, ClOp, ClBitVar, ClRegVar


ExtendedLogicExp = Union[LogicExp, Bit, BitRegister, int]
Expand Down Expand Up @@ -56,6 +58,28 @@ def apply_classical_command(
# Check that the value is in the range
bits_dict[res_bit] = val >= op.lower and val <= op.upper

elif isinstance(op, ClExprOp):
# Convert bit_posn to dictionary of `ClBitVar` index to its value
bitvar_val = {
var_id: int(bits_dict[args[bit_pos]])
for var_id, bit_pos in op.expr.bit_posn.items()
}
# Convert reg_posn to dictionary of `ClRegVar` index to its value
regvar_val = {
var_id: from_little_endian(
[bits_dict[args[bit_pos]] for bit_pos in reg_pos_list]
)
for var_id, reg_pos_list in op.expr.reg_posn.items()
}
result = evaluate_clexpr(op.expr.expr, bitvar_val, regvar_val)

# The result is an int in little-endian encoding. We update the
# output register accordingly.
for bit_pos in op.expr.output_posn:
bits_dict[args[bit_pos]] = (result % 2) == 1
result = result >> 1
assert result == 0 # All bits consumed

elif isinstance(op, ClassicalExpBox):
the_exp = op.get_exp()
result = evaluate_logic_exp(the_exp, bits_dict)
Expand All @@ -74,6 +98,81 @@ def apply_classical_command(
raise NotImplementedError(f"Commands of type {op.type} are not supported.")


def evaluate_clexpr(
expr: ClExpr, bitvar_val: dict[int, int], regvar_val: dict[int, int]
) -> int:
"""Recursive evaluation of a ClExpr."""

# Evaluate arguments to operation
args_val = []
for arg in expr.args:
if isinstance(arg, int):
value = arg
elif isinstance(arg, ClBitVar):
value = bitvar_val[arg.index]
elif isinstance(arg, ClRegVar):
value = regvar_val[arg.index]
elif isinstance(arg, ClExpr):
value = evaluate_clexpr(arg, bitvar_val, regvar_val)
else:
raise Exception(f"Unrecognised argument type of ClExpr: {type(arg)}.")

args_val.append(value)

# Apply the operation at the root of this ClExpr
if expr.op in [ClOp.BitAnd, ClOp.RegAnd]:
result = args_val[0] & args_val[1]
elif expr.op in [ClOp.BitOr, ClOp.RegOr]:
result = args_val[0] | args_val[1]
elif expr.op in [ClOp.BitXor, ClOp.RegXor]:
result = args_val[0] ^ args_val[1]
elif expr.op in [ClOp.BitEq, ClOp.RegEq]:
result = int(args_val[0] == args_val[1])
elif expr.op in [ClOp.BitNeq, ClOp.RegNeq]:
result = int(args_val[0] != args_val[1])
elif expr.op == ClOp.RegGeq:
result = int(args_val[0] >= args_val[1])
elif expr.op == ClOp.RegGt:
result = int(args_val[0] > args_val[1])
elif expr.op == ClOp.RegLeq:
result = int(args_val[0] <= args_val[1])
elif expr.op == ClOp.RegLt:
result = int(args_val[0] < args_val[1])
elif expr.op == ClOp.BitNot:
result = 1 - args_val[0]
# elif expr.op == ClOp.RegNot:
# result = int(args_val[0] == 0)
elif expr.op in [ClOp.BitZero, ClOp.RegZero]:
result = 0
elif expr.op in [ClOp.BitOne, ClOp.RegOne]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intended semantics of RegOne is that every bit is set to 1.

result = 1
# elif expr.op == ClOp.RegAdd:
# result = args_val[0] + args_val[1]
# elif expr.op == ClOp.RegSub:
# result = args_val[0] - args_val[1]
# elif expr.op == ClOp.RegMul:
# result = args_val[0] * args_val[1]
# elif expr.op == ClOp.RegPow:
# result = int(args_val[0] ** args_val[1])
elif expr.op == ClOp.RegRsh:
result = args_val[0] >> args_val[1]
# elif expr.op == ClOp.RegNeg:
# result = -args_val[0]
else:
# TODO: Currently not supporting ClOp's RegDiv since it does not return int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should return the integer quotient, i.e. floor(a/b) where a and b are unsigned integers. (If this doesn't fit in the result register, perhaps error is the kindest response.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if we are leaving ADD etc unsupported for now I don't see a need to support DIV.

# so I am unsure what the semantic is meant to be.
# TODO: I don't now what to do with RegNot, since input
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intended semantics of RegNot is bitwise-NOT, i.e. every bit of the register is flipped.

# is not guaranteed to be 0 or 1.
# TODO: It is not clear what to do with overflow of ADD, etc.
# so I have decided to not support them for now.
raise NotImplementedError(
f"Evaluation of {expr.op} not supported in ClExpr ",
"by pytket-cutensornet.",
)

return result


def evaluate_logic_exp(exp: ExtendedLogicExp, bits_dict: dict[Bit, bool]) -> int:
"""Recursive evaluation of a LogicExp."""

Expand Down
196 changes: 195 additions & 1 deletion tests/test_structured_state_conditionals.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the new tests in this file are just copy-pastes of existing tests using ClassicalExpBox, but now using ClExpr. It would be nice to add more tests, because these are minimal and only one of them is checking for correctness (and only on bit operations, rather than register ones).

Do you know of any tests from other repositories that could be used here? Any suggestions of where to get circuits to test on, for which the intended behaviour is known (and, hence, can be checked against)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. There's a test here for the quantinuum local emulator, but it uses ClassicalExpBox with operations that are not supported here, so isn't much use.

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
reg_eq,
)
from pytket.circuit.logic_exp import BitWiseOp, create_bit_logic_exp
from pytket.circuit.clexpr import wired_clexpr_from_logic_exp

from pytket.extensions.cutensornet.structured_state import (
CuTensorNetHandle,
Expand All @@ -26,6 +27,36 @@
# Further down, there are tests to check that the simulation works correctly.


def test_circuit_with_clexpr_i() -> None:
# test conditional handling

circ = Circuit(3)
a = circ.add_c_register("a", 5)
b = circ.add_c_register("b", 5)
c = circ.add_c_register("c", 5)
d = circ.add_c_register("d", 5)
circ.H(0)
wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to add # type: ignore here because it complains that the second argument c is a BitRegister, rather than list[Bit]. Is there a way around it? Is this a valid implicit cast, or should it be avoided?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could write c.to_list().

circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore
circ.add_clexpr(wexpr, args, condition=a[4])
circ.H(0)
circ.Measure(Qubit(0), d[4])
circ.H(1)
circ.Measure(Qubit(1), d[3])
circ.H(2)
circ.Measure(Qubit(2), d[2])

with CuTensorNetHandle() as libhandle:
cfg = Config()
state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0


def test_circuit_with_classicalexpbox_i() -> None:
# test conditional handling

Expand Down Expand Up @@ -53,6 +84,36 @@ def test_circuit_with_classicalexpbox_i() -> None:
assert state.get_fidelity() == 1.0


def test_circuit_with_clexpr_ii() -> None:
# test conditional handling with else case

circ = Circuit(3)
a = circ.add_c_register("a", 5)
b = circ.add_c_register("b", 5)
c = circ.add_c_register("c", 5)
d = circ.add_c_register("d", 5)
circ.H(0)
wexpr, args = wired_clexpr_from_logic_exp(a | b, c) # type: ignore
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(c | b, d) # type: ignore
circ.add_clexpr(wexpr, args, condition=if_not_bit(a[4]))
circ.H(0)
circ.Measure(Qubit(0), d[4])
circ.H(1)
circ.Measure(Qubit(1), d[3])
circ.H(2)
circ.Measure(Qubit(2), d[2])

with CuTensorNetHandle() as libhandle:
cfg = Config()
state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0


def test_circuit_with_classicalexpbox_ii() -> None:
# test conditional handling with else case

Expand Down Expand Up @@ -82,6 +143,36 @@ def test_circuit_with_classicalexpbox_ii() -> None:
assert state.get_fidelity() == 1.0


@pytest.mark.skip(reason="Currently not supporting arithmetic operations in ClExpr")
def test_circuit_with_clexpr_iii() -> None:
# test complicated conditions and recursive classical op

circ = Circuit(2)

a = circ.add_c_register("a", 15)
b = circ.add_c_register("b", 15)
c = circ.add_c_register("c", 15)
d = circ.add_c_register("d", 15)
e = circ.add_c_register("e", 15)

circ.H(0)
bits = [Bit(i) for i in range(10)]
big_exp = bits[4] | bits[5] ^ bits[6] | bits[7] & bits[8]
circ.H(0, condition=big_exp)

wexpr, args = wired_clexpr_from_logic_exp(a + b - d, c) # type: ignore
circ.add_clexpr(wexpr, args)
wexpr, args = wired_clexpr_from_logic_exp(a * b * d * c, e) # type: ignore
circ.add_clexpr(wexpr, args)

with CuTensorNetHandle() as libhandle:
cfg = Config()
state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0


@pytest.mark.skip(reason="Currently not supporting arithmetic operations in LogicExp")
def test_circuit_with_classicalexpbox_iii() -> None:
# test complicated conditions and recursive classical op
Expand Down Expand Up @@ -239,6 +330,32 @@ def test_pytket_qir_conditional_10() -> None:
assert state.get_fidelity() == 1.0


def test_pytket_qir_conditional_11() -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this name?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I grabbed these tests from the pytket_qir repository (here) and did not notice that the qir in the name was no longer applicable here. Happy to rename it them test_basic_conditional_x.

box_circ = Circuit(4)
box_circ.X(0)
box_circ.Y(1)
box_circ.Z(2)
box_circ.H(3)
box_c = box_circ.add_c_register("c", 5)

box_circ.H(0)

wexpr, args = wired_clexpr_from_logic_exp(box_c | box_c, box_c) # type: ignore
box_circ.add_clexpr(wexpr, args)

cbox = CircBox(box_circ)
d = Circuit(4, 5)
a = d.add_c_register("a", 4)
d.add_circbox(cbox, [0, 2, 1, 3, 0, 1, 2, 3, 4], condition=a[0])

with CuTensorNetHandle() as libhandle:
cfg = Config()
state = simulate(libhandle, d, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0


def test_circuit_with_conditional_gate_v() -> None:
# test conditional with no register

Expand Down Expand Up @@ -430,7 +547,84 @@ def test_repeat_until_success_i() -> None:
assert np.allclose(target_state, output_state)


def test_repeat_until_success_ii() -> None:
def test_repeat_until_success_ii_clexpr() -> None:
# From Figure 1(c) of https://arxiv.org/pdf/1311.1074

attempts = 100

circ = Circuit()
qin = circ.add_q_register("qin", 1)
qaux = circ.add_q_register("aux", 2)
flag = circ.add_c_register("flag", 3)
circ.add_c_setbits([True, True], [flag[0], flag[1]]) # Set flag bits to 11
circ.H(qin[0]) # Use to convert gate to sqrt(1/5)*I + i*sqrt(4/5)*X (i.e. Z -> X)

for _ in range(attempts):
wexpr, args = wired_clexpr_from_logic_exp(
flag[0] | flag[1], [flag[2]] # Success if both are zero
)
circ.add_clexpr(wexpr, args)

circ.add_gate(
OpType.Reset, [qaux[0]], condition_bits=[flag[2]], condition_value=1
)
circ.add_gate(
OpType.Reset, [qaux[1]], condition_bits=[flag[2]], condition_value=1
)
circ.add_gate(OpType.H, [qaux[0]], condition_bits=[flag[2]], condition_value=1)
circ.add_gate(OpType.H, [qaux[1]], condition_bits=[flag[2]], condition_value=1)

circ.add_gate(OpType.T, [qin[0]], condition_bits=[flag[2]], condition_value=1)
circ.add_gate(OpType.Z, [qin[0]], condition_bits=[flag[2]], condition_value=1)
circ.add_gate(
OpType.Tdg, [qaux[0]], condition_bits=[flag[2]], condition_value=1
)
circ.add_gate(
OpType.CX, [qaux[1], qaux[0]], condition_bits=[flag[2]], condition_value=1
)
circ.add_gate(OpType.T, [qaux[0]], condition_bits=[flag[2]], condition_value=1)
circ.add_gate(
OpType.CX, [qin[0], qaux[1]], condition_bits=[flag[2]], condition_value=1
)
circ.add_gate(OpType.T, [qaux[1]], condition_bits=[flag[2]], condition_value=1)

circ.add_gate(OpType.H, [qaux[0]], condition_bits=[flag[2]], condition_value=1)
circ.add_gate(OpType.H, [qaux[1]], condition_bits=[flag[2]], condition_value=1)
circ.Measure(qaux[0], flag[0], condition_bits=[flag[2]], condition_value=1)
circ.Measure(qaux[1], flag[1], condition_bits=[flag[2]], condition_value=1)

# From chat with Silas and exploring the RUS as a block matrix, we have noticed
# that the circuit is missing an X correction when this condition is satisfied
Comment on lines +597 to +598
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this issue tracked somewhere?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This RUS circuit comes from a paper where they acknowledge that a "recovery operation" (i.e. correction) is required in some cases for RUS. As far as I can tell, they don't explicitly indicate what is the recovery operation required for this particular circuit (appearing in Fig 1c), but we figured it was an X correction.

AFAIK this is not tracked anywhere, but is known by people with experience on RUS.

wexpr, args = wired_clexpr_from_logic_exp(flag[0] ^ flag[1], [flag[2]])
circ.add_clexpr(wexpr, args)
Comment on lines +599 to +600
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reading these tests makes me think we should add a method Circuit.add_clexpr_from_logic_exp() to combine these two steps.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, could be a handy addition, but not strictly necessary.

circ.add_gate(OpType.Z, [qin[0]], condition_bits=[flag[2]], condition_value=1)

circ.H(qin[0]) # Use to convert gate to sqrt(1/5)*I + i*sqrt(4/5)*X (i.e. Z -> X)

with CuTensorNetHandle() as libhandle:
cfg = Config()

state = simulate(libhandle, circ, SimulationAlgorithm.MPSxGate, cfg)
assert state.is_valid()
assert np.isclose(state.vdot(state), 1.0, atol=cfg._atol)
assert state.get_fidelity() == 1.0

# All of the flag bits should have turned False
assert all(not state.get_bits()[bit] for bit in flag)
# The auxiliary qubits should be in state |0>
prob = state.postselect({qaux[0]: 0, qaux[1]: 0})
assert np.isclose(prob, 1.0)

target_state = [np.sqrt(1 / 5), np.sqrt(4 / 5) * 1j]
output_state = state.get_statevector()
# As indicated in the paper, the gate is implemented up to global phase
global_phase = target_state[0] / output_state[0]
assert np.isclose(abs(global_phase), 1.0)
output_state *= global_phase
assert np.allclose(target_state, output_state)


def test_repeat_until_success_ii_classicalexpblox() -> None:
# From Figure 1(c) of https://arxiv.org/pdf/1311.1074

attempts = 100
Expand Down
Loading