Skip to content

Commit

Permalink
Implemented feedback on conns test:
Browse files Browse the repository at this point in the history
-better naming convention
-np.testing.assert_allclose instead of assert

Addition
-matrix element check.
  • Loading branch information
Mohammed Boky committed Dec 5, 2023
1 parent 14d0c5e commit 8d0044c
Showing 1 changed file with 62 additions and 20 deletions.
82 changes: 62 additions & 20 deletions test/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
from netket_fidelity.operator import singlequbit_gates as sg
from jax import numpy as jnp

import netket_fidelity as nkf


Expand All @@ -29,40 +28,83 @@ def test_operator_dense_and_conversion(operator):
assert operator.hilbert == operator.to_local_operator().hilbert


def test_get_conns():
def test_get_conns_and_mels():
hi_spin = nk.hilbert.Spin(s=0.5, N=3)
hi_qubit = nk.hilbert.Qubit(N=3)

local_state_spin = hi_spin.local_states
local_state_qubit = hi_qubit.local_states

sigma_4_qubit = hi_qubit.numbers_to_states(2)
sigma_2_qubit = hi_qubit.numbers_to_states(2)
sigma_7_qubit = hi_qubit.numbers_to_states(7)
sigma_4_spin = hi_spin.numbers_to_states(2)
sigma_2_spin = hi_spin.numbers_to_states(2)
sigma_7_spin = hi_spin.numbers_to_states(7)

sigma_qubit = jnp.array([sigma_4_qubit, sigma_7_qubit])
sigma_spin = jnp.array([sigma_4_spin, sigma_7_spin])
sigma_qubit = jnp.array([sigma_2_qubit, sigma_7_qubit])
sigma_spin = jnp.array([sigma_2_spin, sigma_7_spin])

conns_rx_qubit, _ = sg.get_conns_and_mels_Rx(sigma_qubit, 0, 0, local_state_qubit)
conns_ry_qubit, _ = sg.get_conns_and_mels_Ry(sigma_qubit, 0, 0, local_state_qubit)
conns_h_qubit, _ = sg.get_conns_and_mels_Hadamard(sigma_qubit, 0, local_state_qubit)
conns_rx_qubit, mels_rx_qubit = sg.get_conns_and_mels_Rx(
sigma_qubit, 0, np.pi / 2, local_state_qubit
)
conns_ry_qubit, mels_ry_qubit = sg.get_conns_and_mels_Ry(
sigma_qubit, 0, np.pi / 2, local_state_qubit
)
conns_h_qubit, mels_h_qubit = sg.get_conns_and_mels_Hadamard(
sigma_qubit, 0, local_state_qubit
)

conns_rx_spin, _ = sg.get_conns_and_mels_Rx(sigma_spin, 0, 0, local_state_spin)
conns_ry_spin, _ = sg.get_conns_and_mels_Ry(sigma_spin, 0, 0, local_state_spin)
conns_h_spin, _ = sg.get_conns_and_mels_Hadamard(sigma_spin, 0, local_state_spin)
conns_rx_spin, mels_rx_spin = sg.get_conns_and_mels_Rx(
sigma_spin, 0, np.pi / 2, local_state_spin
)
conns_ry_spin, mels_ry_spin = sg.get_conns_and_mels_Ry(
sigma_spin, 0, np.pi / 2, local_state_spin
)
conns_h_spin, mels_h_spin = sg.get_conns_and_mels_Hadamard(
sigma_spin, 0, local_state_spin
)

values_check_qubit = jnp.array(
conns_check_qubit = jnp.array(
[[[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]], [[1.0, 1.0, 1.0], [0.0, 1.0, 1.0]]]
)

values_check_spin = jnp.array(
conns_check_spin = jnp.array(
[[[-1.0, 1.0, -1.0], [1.0, 1.0, -1.0]], [[1.0, 1.0, 1.0], [-1.0, 1.0, 1.0]]]
)

assert (conns_rx_qubit == values_check_qubit).all()
assert (conns_ry_qubit == values_check_qubit).all()
assert (conns_h_qubit == values_check_qubit).all()
assert (conns_rx_spin == values_check_spin).all()
assert (conns_ry_spin == values_check_spin).all()
assert (conns_h_spin == values_check_spin).all()
mels_check_qubit_rx = jnp.array(
[[0.70710678 + 0.0j, 0.0 - 0.70710678j], [0.70710678 + 0.0j, 0.0 - 0.70710678j]]
)
mels_check_qubit_ry = jnp.array(
[
[0.70710678 + 0.0j, 0.70710678 + 0.0j],
[0.70710678 + 0.0j, -0.70710678 + 0.0j],
]
)
mels_check_qubit_h = jnp.array(
[[0.70710678, 0.70710678], [-0.70710678, 0.70710678]]
)

mels_check_spin_rx = jnp.array(
[[0.70710678 + 0.0j, 0.0 - 0.70710678j], [0.70710678 + 0.0j, 0.0 - 0.70710678j]]
)
mels_check_spin_ry = jnp.array(
[
[0.70710678 + 0.0j, 0.70710678 + 0.0j],
[0.70710678 + 0.0j, -0.70710678 + 0.0j],
]
)
mels_check_spin_h = jnp.array([[0.70710678, 0.70710678], [-0.70710678, 0.70710678]])

np.testing.assert_allclose(conns_rx_qubit, conns_check_qubit)
np.testing.assert_allclose(conns_ry_qubit, conns_check_qubit)
np.testing.assert_allclose(conns_h_qubit, conns_check_qubit)
np.testing.assert_allclose(conns_rx_spin, conns_check_spin)
np.testing.assert_allclose(conns_ry_spin, conns_check_spin)
np.testing.assert_allclose(conns_h_spin, conns_check_spin)

np.testing.assert_allclose(mels_rx_qubit, mels_check_qubit_rx)
np.testing.assert_allclose(mels_ry_qubit, mels_check_qubit_ry)
np.testing.assert_allclose(mels_h_qubit, mels_check_qubit_h)
np.testing.assert_allclose(mels_rx_spin, mels_check_spin_rx)
np.testing.assert_allclose(mels_ry_spin, mels_check_spin_ry)
np.testing.assert_allclose(mels_h_spin, mels_check_spin_h)

0 comments on commit 8d0044c

Please sign in to comment.