Skip to content

Commit

Permalink
coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
renatomello committed Oct 15, 2024
1 parent 238fa6d commit e813f56
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
18 changes: 9 additions & 9 deletions src/qibo/quantum_info/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,26 +891,26 @@ def quantum_fisher_information_matrix(

if parameters is None:
parameters = circuit.get_parameters()
parameters = backend.cast(parameters, dtype=float)
parameters = backend.cast(parameters, dtype=float).flatten()

jacobian = backend.calculate_jacobian_matrix(
circuit, parameters, initial_state, return_complex
)
jacobian = (
jacobian[0] + 1j * jacobian[1]
if return_complex
else backend.cast(jacobian, dtype=np.complex128)
)

if return_complex:
jacobian = jacobian[0] + 1j * jacobian[1]

jacobian = backend.cast(jacobian, dtype=np.complex128)

copied = circuit.copy(deep=True)
copied.set_parameters(parameters)

state = backend.execute_circuit(copied, initial_state=initial_state).state()

overlaps = backend.np.conj(state.T) @ jacobian
print(state.dtype, jacobian.dtype)
overlaps = jacobian.T @ state

qfim = jacobian.T @ jacobian
qfim = qfim - backend.np.outer(backend.np.conj(overlaps.T), overlaps)
qfim = qfim - backend.np.outer(overlaps, backend.np.conj(overlaps.T))

return 4 * backend.np.real(qfim)

Expand Down
14 changes: 10 additions & 4 deletions tests/test_quantum_info_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,10 @@ def test_frame_potential(backend, nqubits, power_t, samples):
backend.assert_allclose(potential, potential_haar, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("params_flag", [None, True])
@pytest.mark.parametrize("return_complex", [False, True])
@pytest.mark.parametrize("nqubits", [4, 8])
def test_qfim(backend, nqubits, return_complex):
def test_qfim(backend, nqubits, return_complex, params_flag):
if backend.name not in ["pytorch", "tensorflow"]:
circuit = Circuit(nqubits)
params = np.random.rand(3)
Expand All @@ -411,14 +412,19 @@ def test_qfim(backend, nqubits, return_complex):
for param in params[:-1]:
elem = float(target[-1] * backend.np.sin(param) ** 2)
target.append(elem)
target = 4 * backend.np.diag(backend.cast(target))
target = 4 * backend.np.diag(backend.cast(target, dtype=np.float64))

# numerical qfim from quantum_info
circuit = unary_encoder(data, "diagonal")
circuit.set_parameters(params)

if params_flag is not None:
circuit.set_parameters(params)
else:
params = params_flag

qfim = quantum_fisher_information_matrix(
circuit, params, return_complex=return_complex, backend=backend
)
# qfim = backend.cast(qfim, dtype=np.float64)

backend.assert_allclose(qfim, target, atol=1e-10)
backend.assert_allclose(qfim, target, atol=1e-6)

0 comments on commit e813f56

Please sign in to comment.