Skip to content

Commit

Permalink
Merge branch 'main' into rot-hs-qb
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipVinc authored Nov 8, 2023
2 parents f32dc13 + 6277e57 commit c2567eb
Show file tree
Hide file tree
Showing 12 changed files with 31 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:
- name: Run tests
run: |
export NETKET_EXPERIMENTAL=1
pytest --cov=netket --cov-append test
pytest --cov=netket_fidelity --cov-append test
- name: Run docstring tests
if: ${{ matrix.doctest }}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/formatting_check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@ jobs:
- name: Set up Python 3.10
uses: chartboost/ruff-action@v1
with:
version: 0.1.3
version: 0.1.4
args: --config pyproject.toml
src: netket_fidelity test examples
src: netket_fidelity examples test
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax

from RBM_Jastrow_measurement import RBMJasMeas
from netket_fidelity.renyi2 import Renyi2EntanglementEntropy

# Set the parameters
L = 2
Expand Down Expand Up @@ -64,7 +65,7 @@

# Instantiate the Renyi2 entropy to monitor
subsys = [x for x in range(N // 2)]
S2op = nkf.Renyi2EntanglementEntropy(hi, subsys)
S2op = Renyi2EntanglementEntropy(hi, subsys)


# Compute the probabilities for the measurement outcomes of a spin
Expand Down
3 changes: 1 addition & 2 deletions netket_fidelity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from . import driver

from .infidelity import InfidelityOperator
from .renyi2 import Renyi2EntanglementEntropy

from netket.utils import _hide_submodules

_hide_submodules(__name__, hide_folder=["renyi2", "infidelity"])
_hide_submodules(__name__, hide_folder=["infidelity"])
2 changes: 0 additions & 2 deletions netket_fidelity/infidelity/overlap/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import jax

from netket import jax as nkjax
from netket.utils.dispatch import TrueT
from netket.vqs import FullSumState, expect, expect_and_grad
from netket.utils import mpi
from netket.stats import Stats
Expand Down Expand Up @@ -33,7 +32,6 @@ def infidelity(vstate: FullSumState, op: InfidelityOperatorStandard):
def infidelity( # noqa: F811
vstate: FullSumState,
op: InfidelityOperatorStandard,
use_covariance: TrueT,
*,
mutable,
):
Expand Down
5 changes: 2 additions & 3 deletions netket_fidelity/infidelity/overlap/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import jax.numpy as jnp
import jax

from netket.utils.dispatch import TrueT
from netket.vqs import MCState, expect, expect_and_grad
from netket import jax as nkjax
from netket.utils import mpi
Expand All @@ -14,7 +13,7 @@


@expect.dispatch
def infidelity(vstate: MCState, op: InfidelityOperatorStandard):
def infidelity(vstate: MCState, op: InfidelityOperatorStandard, chunk_size: None):
if op.hilbert != vstate.hilbert:
raise TypeError("Hilbert spaces should match")

Expand All @@ -36,7 +35,7 @@ def infidelity(vstate: MCState, op: InfidelityOperatorStandard):
def infidelity( # noqa: F811
vstate: MCState,
op: InfidelityOperatorStandard,
use_covariance: TrueT,
chunk_size: None,
*,
mutable,
):
Expand Down
23 changes: 14 additions & 9 deletions netket_fidelity/infidelity/overlap_U/exact.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import jax.numpy as jnp
import jax
from functools import partial

from netket import jax as nkjax
from netket.utils.dispatch import TrueT
from netket.vqs import FullSumState, expect, expect_and_grad
from netket.utils import mpi
from netket.stats import Stats
Expand All @@ -21,13 +21,15 @@ def infidelity(vstate: FullSumState, op: InfidelityOperatorUPsi):
if not isinstance(op.target, FullSumState):
raise TypeError("Can only compute infidelity of exact states.")

U_sp = sparsify(op._U)
Ustate_t = U_sp @ op.target.to_array(normalize=False)

return infidelity_sampling_FullSumState(
vstate._apply_fun,
vstate.parameters,
vstate.model_state,
vstate._all_states,
op.target.to_array(),
op._U.to_sparse(),
Ustate_t,
return_grad=False,
)

Expand All @@ -36,7 +38,6 @@ def infidelity(vstate: FullSumState, op: InfidelityOperatorUPsi):
def infidelity( # noqa: F811
vstate: FullSumState,
op: InfidelityOperatorUPsi,
use_covariance: TrueT,
*,
mutable,
):
Expand All @@ -45,30 +46,34 @@ def infidelity( # noqa: F811
if not isinstance(op.target, FullSumState):
raise TypeError("Can only compute infidelity of exact states.")

U_sp = sparsify(op._U)
Ustate_t = U_sp @ op.target.to_array(normalize=False)

return infidelity_sampling_FullSumState(
vstate._apply_fun,
vstate.parameters,
vstate.model_state,
vstate._all_states,
op.target.to_array(),
op._U.to_sparse(),
Ustate_t,
return_grad=True,
)


@partial(jax.jit, static_argnames=("afun", "return_grad"))
def infidelity_sampling_FullSumState(
afun,
params,
model_state,
sigma,
state_t,
U_sp,
Ustate_t,
return_grad,
):
def expect_fun(params):
state = jnp.exp(afun({"params": params, **model_state}, sigma))
state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2))
return jnp.abs(state.T.conj() @ (U_sp @ state_t)) ** 2
return jnp.abs(state.T.conj().T @ Ustate_t) ** 2 / (
Ustate_t.conj().T @ Ustate_t
)

if not return_grad:
F = expect_fun(params)
Expand Down
5 changes: 2 additions & 3 deletions netket_fidelity/infidelity/overlap_U/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
from netket.operator import DiscreteJaxOperator
from netket.vqs import MCState, expect, expect_and_grad, get_local_kernel_arguments
from netket.utils import mpi
from netket.utils.dispatch import TrueT

from netket_fidelity.utils import expect_2distr

from .operator import InfidelityOperatorUPsi


@expect.dispatch
def infidelity(vstate: MCState, op: InfidelityOperatorUPsi):
def infidelity(vstate: MCState, op: InfidelityOperatorUPsi, chunk_size: None):
if op.hilbert != vstate.hilbert:
raise TypeError("Hilbert spaces should match")

Expand All @@ -43,7 +42,7 @@ def infidelity(vstate: MCState, op: InfidelityOperatorUPsi):
def infidelity( # noqa: F811
vstate: MCState,
op: InfidelityOperatorUPsi,
use_covariance: TrueT,
chunk_size: None,
*,
mutable,
):
Expand Down
2 changes: 1 addition & 1 deletion netket_fidelity/renyi2/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@expect.dispatch
def Renyi2(vstate: MCState, op: Renyi2EntanglementEntropy):
def Renyi2(vstate: MCState, op: Renyi2EntanglementEntropy, chunk_size: None):
if op.hilbert != vstate.hilbert:
raise TypeError("Hilbert spaces should match")

Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ description="Infidelity operator for NetKet."
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
"netket~=3.9",
"netket~=3.10",
]
dynamic = ["version"]

Expand All @@ -21,7 +21,7 @@ dev = [
"coverage>=5",
"pre-commit>=2.7",
"black==23.10.1",
"ruff==0.1.3",
"ruff==0.1.4",
"wheel",
"build",
"qutip",
Expand Down Expand Up @@ -91,4 +91,4 @@ exclude = ["Examples/Legacy"]
[tool.ruff.per-file-ignores]
"__init__.py" = ["E402","F401"]
"netket/nn/activation.py" = ["F401"]
"Examples/" = ["F401"]
"Examples/" = ["F401"]
4 changes: 2 additions & 2 deletions test/test_Renyi2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import netket as nk
import numpy as np

import netket_fidelity as nkf

from ._Renyi2_exact import _Renyi2_exact
from netket_fidelity.renyi2 import Renyi2EntanglementEntropy

N = 3
hi = nk.hilbert.Spin(0.5, N)
Expand All @@ -29,7 +29,7 @@ def _setup():
)

subsys = [0, 1]
S2 = nkf.Renyi2EntanglementEntropy(hi, subsys)
S2 = Renyi2EntanglementEntropy(hi, subsys)

return vs, vs_exact, S2, subsys

Expand Down
2 changes: 1 addition & 1 deletion test/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

operators["Rx"] = nkf.operator.Rx(hi, 1, 0.23)
operators["Ry"] = nkf.operator.Ry(hi, 1, 0.43)
operators["Ry"] = nkf.operator.Hadamard(hi, 0)
operators["Hadamard"] = nkf.operator.Hadamard(hi, 0)


@pytest.mark.parametrize(
Expand Down

0 comments on commit c2567eb

Please sign in to comment.