diff --git a/src/qibo/quantum_info/metrics.py b/src/qibo/quantum_info/metrics.py index c320d99ffb..6a4cd9b970 100644 --- a/src/qibo/quantum_info/metrics.py +++ b/src/qibo/quantum_info/metrics.py @@ -891,26 +891,26 @@ def quantum_fisher_information_matrix( if parameters is None: parameters = circuit.get_parameters() - parameters = backend.cast(parameters, dtype=float) + parameters = backend.cast(parameters, dtype=float).flatten() jacobian = backend.calculate_jacobian_matrix( circuit, parameters, initial_state, return_complex ) - jacobian = ( - jacobian[0] + 1j * jacobian[1] - if return_complex - else backend.cast(jacobian, dtype=np.complex128) - ) + + if return_complex: + jacobian = jacobian[0] + 1j * jacobian[1] + + jacobian = backend.cast(jacobian, dtype=np.complex128) copied = circuit.copy(deep=True) copied.set_parameters(parameters) state = backend.execute_circuit(copied, initial_state=initial_state).state() - - overlaps = backend.np.conj(state.T) @ jacobian + print(state.dtype, jacobian.dtype) + overlaps = jacobian.T @ state qfim = jacobian.T @ jacobian - qfim = qfim - backend.np.outer(backend.np.conj(overlaps.T), overlaps) + qfim = qfim - backend.np.outer(overlaps, backend.np.conj(overlaps.T)) return 4 * backend.np.real(qfim) diff --git a/tests/test_quantum_info_metrics.py b/tests/test_quantum_info_metrics.py index 93d5027cbe..3187f1ea6f 100644 --- a/tests/test_quantum_info_metrics.py +++ b/tests/test_quantum_info_metrics.py @@ -390,9 +390,10 @@ def test_frame_potential(backend, nqubits, power_t, samples): backend.assert_allclose(potential, potential_haar, rtol=1e-2, atol=1e-2) +@pytest.mark.parametrize("params_flag", [None, True]) @pytest.mark.parametrize("return_complex", [False, True]) @pytest.mark.parametrize("nqubits", [4, 8]) -def test_qfim(backend, nqubits, return_complex): +def test_qfim(backend, nqubits, return_complex, params_flag): if backend.name not in ["pytorch", "tensorflow"]: circuit = Circuit(nqubits) params = np.random.rand(3) @@ -411,14 +412,19 @@ def test_qfim(backend, nqubits, return_complex): for param in params[:-1]: elem = float(target[-1] * backend.np.sin(param) ** 2) target.append(elem) - target = 4 * backend.np.diag(backend.cast(target)) + target = 4 * backend.np.diag(backend.cast(target, dtype=np.float64)) # numerical qfim from quantum_info circuit = unary_encoder(data, "diagonal") - circuit.set_parameters(params) + + if params_flag is not None: + circuit.set_parameters(params) + else: + params = params_flag qfim = quantum_fisher_information_matrix( circuit, params, return_complex=return_complex, backend=backend ) + # qfim = backend.cast(qfim, dtype=np.float64) - backend.assert_allclose(qfim, target, atol=1e-10) + backend.assert_allclose(qfim, target, atol=1e-6)