diff --git a/unitary/alpha/quantum_world.py b/unitary/alpha/quantum_world.py index 6235b908..631efa94 100644 --- a/unitary/alpha/quantum_world.py +++ b/unitary/alpha/quantum_world.py @@ -20,7 +20,12 @@ from unitary.alpha.sparse_vector_simulator import PostSelectOperation, SparseSimulator from unitary.alpha.qudit_state_transform import qudit_to_qubit_unitary, num_bits import numpy as np -import itertools +from itertools import combinations +import pandas as pd +# from tabulate import tabulate +# from texttable import Texttable +# from prettytable import PrettyTable +# from terminaltables import AsciiTable class QuantumWorld: @@ -689,25 +694,37 @@ def density_matrix( 2**num_shown_qubits, 2**num_shown_qubits ) - def measure_entanglement(self, obj1: QuantumObject, obj2: QuantumObject) -> float: - """Measures the entanglement (i.e. quantum mutual information) of the two given objects. + def measure_entanglement(self, objects: Optional[Sequence[QuantumObject]] = None) -> float: + """Measures the entanglement (i.e. quantum mutual information) of the given objects. See https://en.wikipedia.org/wiki/Quantum_mutual_information for the formula. Parameters: - obj1, obj2: two quantum objects (currently only qubits are supported) + objects: quantum objects among which the entanglement will be calculated + (currently only qubits are supported). If not specified, all current + quantum objects will be used. If specified, at least two quantum + objects are expected. Returns: - The quantum mutual information defined as S_1 + S_2 - S_12, where S denotes (reduced) - von Neumann entropy. + The quantum mutual information. For 2 qubits it's defined as S_1 + S_2 - S_12, + where S denotes (reduced) von Neumann entropy. """ - density_matrix_12 = self.density_matrix([obj1, obj2]).reshape(2, 2, 2, 2) - density_matrix_1 = cirq.partial_trace(density_matrix_12, [0]) - density_matrix_2 = cirq.partial_trace(density_matrix_12, [1]) - return ( - cirq.von_neumann_entropy(density_matrix_1, validate=False) - + cirq.von_neumann_entropy(density_matrix_2, validate=False) - - cirq.von_neumann_entropy(density_matrix_12.reshape(4, 4), validate=False) - ) + num_involved_objects = len(objects) if objects is not None else len(self.object_name_dict.values()) + + if num_involved_objects < 2: + raise ValueError(f"Could not calculate entanglement for {num_involved_objects} qubit. " + "At least 2 qubits are required.") + + involved_objects = objects if objects is not None else list(self.object_name_dict.values()) + + density_matrix = self.density_matrix(involved_objects) + reshaped_density_matrix = density_matrix.reshape(tuple([2, 2] * num_involved_objects)) + result = 0.0 + for comb in combinations(range(num_involved_objects), num_involved_objects - 1): + reshaped_partial_density_matrix = cirq.partial_trace(reshaped_density_matrix, list(comb)) + partial_density_matrix = reshaped_partial_density_matrix.reshape(2 ** (num_involved_objects - 1), 2 ** (num_involved_objects - 1)) + result += cirq.von_neumann_entropy(partial_density_matrix, validate=False) + result -= cirq.von_neumann_entropy(density_matrix, validate=False) + return result def __getitem__(self, name: str) -> QuantumObject: quantum_object = self.object_name_dict.get(name, None) diff --git a/unitary/alpha/quantum_world_test.py b/unitary/alpha/quantum_world_test.py index f8880c9d..441efe51 100644 --- a/unitary/alpha/quantum_world_test.py +++ b/unitary/alpha/quantum_world_test.py @@ -206,7 +206,6 @@ def test_unhook(simulator, compile_to_qubits): alpha.Split()(light, light2, light3) board.unhook(light2) results = board.peek([light2, light3], count=200, convert_to_enum=False) - print(results) assert all(result[0] == 0 for result in results) assert not all(result[1] == 0 for result in results) assert not all(result[1] == 1 for result in results) @@ -662,8 +661,6 @@ def test_combine_worlds(simulator, compile_to_qubits): results = world2.peek(count=100) expected = [StopLight.YELLOW] + result - print(results) - print(expected) assert all(actual == expected for actual in results) @@ -958,9 +955,13 @@ def test_measure_entanglement(simulator, compile_to_qubits): ) # S_1 + S_2 - S_12 = 0 + 0 - 0 = 0 for all three cases. - assert round(board.measure_entanglement(light1, light2)) == 0.0 - assert round(board.measure_entanglement(light1, light3)) == 0.0 - assert round(board.measure_entanglement(light2, light3)) == 0.0 + assert round(board.measure_entanglement([light1, light2]), 1) == 0.0 + assert round(board.measure_entanglement([light1, light3]), 1) == 0.0 + assert round(board.measure_entanglement([light2, light3]), 1) == 0.0 + # S_12 + S_13 + S_23 - S_123 = 0 + 0 + 0 - 0 = 0 + assert round(board.measure_entanglement([light1, light2, light3]), 1) == 0.0 + # Test with objects=None. + assert round(board.measure_entanglement(), 1) == 0.0 alpha.Superposition()(light2) alpha.quantum_if(light2).apply(alpha.Flip())(light3) @@ -968,8 +969,15 @@ def test_measure_entanglement(simulator, compile_to_qubits): assert not all(result[0] == 0 for result in results) assert (result[0] == result[1] for result in results) # S_1 + S_2 - S_12 = 0 + 1 - 1 = 0 - assert round(board.measure_entanglement(light1, light2), 1) == 0.0 - # S_1 + S_2 - S_12 = 0 + 1 - 1 = 0 - assert round(board.measure_entanglement(light1, light3), 1) == 0.0 - # S_1 + S_2 - S_12 = 1 + 1 - 0 = 2 - assert round(board.measure_entanglement(light2, light3), 1) == 2.0 + assert round(board.measure_entanglement([light1, light2]), 1) == 0.0 + # S_1 + S_3 - S_13 = 0 + 1 - 1 = 0 + assert round(board.measure_entanglement([light1, light3]), 1) == 0.0 + # S_2 + S_3 - S_23 = 1 + 1 - 0 = 2 + assert round(board.measure_entanglement([light2, light3]), 1) == 2.0 + # S_12 + S_13 + S_23 - S_123 = 1 + 1 + 0 - 0 + assert round(board.measure_entanglement([light1, light2, light3]), 1) == 2.0 + # Test with objects=None. + assert round(board.measure_entanglement(), 1) == 2.0 + # Supplying one object would return a value error. + with pytest.raises(ValueError, match="Could not calculate entanglement for 1 qubit."): + board.measure_entanglement([light1])