Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Density Matrix and State Vector Simulators to work when an operation allocates new qubits as part of its decomposition #6108

Merged
merged 25 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c2ff8bd
WIP add factoring and kron methods to sim state for adding and removi…
senecameeks May 24, 2023
e673105
add test cases
senecameeks May 26, 2023
80fbbbc
add delegating gate test case
senecameeks May 26, 2023
5f15d97
update test
senecameeks May 26, 2023
cd7a573
all tests pass
senecameeks May 26, 2023
1cd81dd
add test case for unitary Y
senecameeks May 26, 2023
d9a46cc
nit
senecameeks May 26, 2023
908ee77
addresses PR comments by adding empty checks. Applys formatter. Subse…
senecameeks Jun 1, 2023
c903d32
nit formatting changes, add docustring with input/output for remove_q…
senecameeks Jun 1, 2023
b92d6d8
Merge branch 'master' of https://github.com/quantumlib/cirq
senecameeks Jun 2, 2023
70162d7
Merge branch 'master' into master
tanujkhattar Jun 5, 2023
530a69e
merge this branch and tanujkhattar@ccde689
senecameeks Jun 6, 2023
1be80c8
merging branches, adding test coverage in next push
senecameeks Jun 6, 2023
2b06182
Merge branch 'master' of https://github.com/quantumlib/cirq
senecameeks Jun 6, 2023
4adf75b
Merge branch 'master' of github.com:senecameeks/Cirq
senecameeks Jun 6, 2023
f70753b
format files
senecameeks Jun 6, 2023
f74f760
add coverage tests
senecameeks Jun 6, 2023
5d31ce3
change assert
senecameeks Jun 7, 2023
096fc14
coverage and type check tests should pass
senecameeks Jun 7, 2023
964dc69
incorporate tanujkhattar@1db8ac5
senecameeks Jun 7, 2023
40d5b33
nit
senecameeks Jun 7, 2023
0ab07c2
Merge branch 'master' into master
tanujkhattar Jun 7, 2023
40369ee
remove block comment
senecameeks Jun 7, 2023
ddd6fd9
Merge branch 'master' of github.com:senecameeks/Cirq
senecameeks Jun 7, 2023
774d715
add coverage
senecameeks Jun 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/act_on_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def act_on(

arg_fallback = getattr(sim_state, '_act_on_fallback_', None)
if arg_fallback is not None:
qubits = action.qubits if isinstance(action, ops.Operation) else qubits
qubits = action.qubits if is_op else qubits
result = arg_fallback(action, qubits=qubits, allow_decompose=allow_decompose)
if result is True:
return
Expand Down
16 changes: 16 additions & 0 deletions cirq-core/cirq/sim/density_matrix_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,22 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def remove_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().remove_qubits(qubits)
if ret is not NotImplemented:
return ret
extracted, remainder = self.factor(qubits)
remainder._state._density_matrix *= extracted._state._density_matrix.reshape(-1)[0]
return remainder

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down
12 changes: 12 additions & 0 deletions cirq-core/cirq/sim/density_matrix_simulation_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,15 @@ def test_initial_state_bad_shape():
cirq.DensityMatrixSimulationState(
qubits=qubits, initial_state=np.full((2, 2, 2, 2), 1 / 4), dtype=np.complex64
)


def test_remove_qubits():
"""Test the remove_qubits method."""
q1 = cirq.LineQubit(0)
q2 = cirq.LineQubit(1)
state = cirq.DensityMatrixSimulationState(qubits=[q1, q2])

new_state = state.remove_qubits([q1])

assert len(new_state.qubits) == 1
assert q1 not in new_state.qubits
47 changes: 45 additions & 2 deletions cirq-core/cirq/sim/simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,38 @@ def create_merged_state(self) -> Self:
"""Creates a final merged state."""
return self

def add_qubits(self: Self, qubits: Sequence['cirq.Qid']):
"""Add qubits to a new state space and take the kron product.

Note that only Density Matrix and State Vector simulators
override this function.

Args:
qubits: Sequence of qubits to be added.

Returns:
NotImplemented: If the subclass does not implement this method.

Raises:
ValueError: If a qubit being added is already tracked.
"""
if any(q in self.qubits for q in qubits):
raise ValueError(f"Qubit to add {qubits} should not already be tracked.")
return NotImplemented

def remove_qubits(self: Self, qubits: Sequence['cirq.Qid']) -> Self:
"""Remove qubits from the state space.

Args:
qubits: Sequence of qubits to be added.

Returns:
A new Simulation State with qubits removed. Or
`self` if there are no qubits to remove."""
if qubits is None or not qubits:
return self
return NotImplemented

def kronecker_product(self, other: Self, *, inplace=False) -> Self:
"""Joins two state spaces together."""
args = self if inplace else copy.copy(self)
Expand Down Expand Up @@ -294,13 +326,24 @@ def strat_act_on_from_apply_decompose(
val: Any, args: 'cirq.SimulationState', qubits: Sequence['cirq.Qid']
) -> bool:
operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val)
assert len(qubits1) == len(qubits)
qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)}
if operations is None:
return NotImplemented
assert len(qubits1) == len(qubits)
all_qubits = frozenset([q for op in operations for q in op.qubits])
qubit_map = dict(zip(all_qubits, all_qubits))
qubit_map.update(dict(zip(qubits1, qubits)))
new_ancilla = tuple(q for q in sorted(all_qubits.difference(qubits)) if q not in args.qubits)
args = args.add_qubits(new_ancilla)
if args is NotImplemented:
return NotImplemented
for operation in operations:
operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
protocols.act_on(operation, args)
args = args.remove_qubits(new_ancilla)
if args is NotImplemented: # coverage: ignore
raise TypeError( # coverage: ignore
f"{type(args)} implements `add_qubits` but not `remove_qubits`." # coverage: ignore
) # coverage: ignore
return True


Expand Down
176 changes: 165 additions & 11 deletions cirq-core/cirq/sim/simulation_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import cirq
from cirq.sim import simulation_state
from cirq.testing import PhaseUsingCleanAncilla, PhaseUsingDirtyAncilla


class DummyQuantumState(cirq.QuantumStateRepresentation):
Expand All @@ -33,32 +34,90 @@ def reindex(self, axes):


class DummySimulationState(cirq.SimulationState):
def __init__(self):
super().__init__(state=DummyQuantumState(), qubits=cirq.LineQubit.range(2))
def __init__(self, qubits=cirq.LineQubit.range(2)):
super().__init__(state=DummyQuantumState(), qubits=qubits)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
return True


class AncillaZ(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.CX(qubits[0], ancilla)
yield cirq.Z(ancilla) ** self._exponent
yield cirq.CX(qubits[0], ancilla)


class AncillaH(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.H(ancilla) ** self._exponent
yield cirq.CX(ancilla, qubits[0])
yield cirq.H(ancilla) ** self._exponent


class AncillaY(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
ancilla = cirq.NamedQubit('Ancilla')
yield cirq.Y(ancilla) ** self._exponent
yield cirq.CX(ancilla, qubits[0])
yield cirq.Y(ancilla) ** self._exponent


class DelegatingAncillaZ(cirq.Gate):
def __init__(self, exponent=1):
self._exponent = exponent

def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
a = cirq.NamedQubit('a')
yield cirq.CX(qubits[0], a)
yield AncillaZ(self._exponent).on(a)
yield cirq.CX(qubits[0], a)


class Composite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
yield cirq.X(*qubits)


def test_measurements():
args = DummySimulationState()
args.measure([cirq.LineQubit(0)], "test", [False], {})
assert args.log_of_measurement_results["test"] == [5]


def test_decompose():
class Composite(cirq.Gate):
def num_qubits(self) -> int:
return 1

def _decompose_(self, qubits):
yield cirq.X(*qubits)

args = DummySimulationState()
assert simulation_state.strat_act_on_from_apply_decompose(
Composite(), args, [cirq.LineQubit(0)]
assert (
simulation_state.strat_act_on_from_apply_decompose(Composite(), args, [cirq.LineQubit(0)])
is NotImplemented
)


Expand Down Expand Up @@ -101,3 +160,98 @@ def test_field_getters():
args = DummySimulationState()
assert args.prng is np.random
assert args.qubit_map == {q: i for i, q in enumerate(cirq.LineQubit.range(2))}


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_ancilla_z(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit(AncillaZ(exp).on(q))

control_circuit = cirq.Circuit(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_ancilla_y(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit(AncillaY(exp).on(q))

control_circuit = cirq.Circuit(cirq.Y(q))
control_circuit.append(cirq.Y(q))
control_circuit.append(cirq.XPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_borrowable_qubit(exp):
q = cirq.LineQubit(0)
test_circuit = cirq.Circuit()
test_circuit.append(cirq.H(q))
test_circuit.append(cirq.X(q))
test_circuit.append(AncillaH(exp).on(q))

control_circuit = cirq.Circuit(cirq.H(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('exp', [-3, -2, -1, 0, 1, 2, 3])
def test_delegating_gate_qubit(exp):
q = cirq.LineQubit(0)

test_circuit = cirq.Circuit()
test_circuit.append(cirq.H(q))
test_circuit.append(DelegatingAncillaZ(exp).on(q))

control_circuit = cirq.Circuit(cirq.H(q))
control_circuit.append(cirq.ZPowGate(exponent=exp).on(q))

assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
def test_phase_using_dirty_ancilla(num_ancilla: int):
q = cirq.LineQubit(0)
anc = cirq.NamedQubit.range(num_ancilla, prefix='anc')

u = cirq.MatrixGate(cirq.testing.random_unitary(2 ** (num_ancilla + 1)))
test_circuit = cirq.Circuit(
u.on(q, *anc), PhaseUsingDirtyAncilla(ancilla_bitsize=num_ancilla).on(q)
)
control_circuit = cirq.Circuit(u.on(q, *anc), cirq.Z(q))
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


@pytest.mark.parametrize('num_ancilla', [1, 2, 3])
@pytest.mark.parametrize('theta', np.linspace(0, 2 * np.pi, 10))
def test_phase_using_clean_ancilla(num_ancilla: int, theta: float):
q = cirq.LineQubit(0)
u = cirq.MatrixGate(cirq.testing.random_unitary(2))
test_circuit = cirq.Circuit(
u.on(q), PhaseUsingCleanAncilla(theta=theta, ancilla_bitsize=num_ancilla).on(q)
)
control_circuit = cirq.Circuit(u.on(q), cirq.ZPowGate(exponent=theta).on(q))
assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit)


def test_add_qubits_raise_value_error(num_ancilla=1):
q = cirq.LineQubit(0)
args = cirq.StateVectorSimulationState(qubits=[q])

with pytest.raises(ValueError, match='should not already be tracked.'):
args.add_qubits([q])


def test_remove_qubits_not_implemented(num_ancilla=1):
args = DummySimulationState()

assert args.remove_qubits([cirq.LineQubit(0)]) is NotImplemented


def assert_test_circuit_for_sv_dm_simulators(test_circuit, control_circuit) -> None:
for test_simulator in ['cirq.final_state_vector', 'cirq.final_density_matrix']:
test_sim = eval(test_simulator)(test_circuit)
control_sim = eval(test_simulator)(control_circuit)
assert np.allclose(test_sim, control_sim)
16 changes: 16 additions & 0 deletions cirq-core/cirq/sim/state_vector_simulation_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,22 @@ def __init__(
)
super().__init__(state=state, prng=prng, qubits=qubits, classical_data=classical_data)

def add_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().add_qubits(qubits)
return (
self.kronecker_product(type(self)(qubits=qubits), inplace=True)
if ret is NotImplemented
else ret
)

def remove_qubits(self, qubits: Sequence['cirq.Qid']):
ret = super().remove_qubits(qubits)
if ret is not NotImplemented:
return ret
extracted, remainder = self.factor(qubits, inplace=True)
remainder._state._state_vector *= extracted._state._state_vector.reshape((-1,))[0]
return remainder

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
) -> bool:
Expand Down