Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify: remove expect_2distr #15

Merged
merged 2 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions netket_fidelity/infidelity/overlap/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from netket import jax as nkjax
from netket.utils import mpi

from netket_fidelity.utils import expect_2distr

from .operator import InfidelityOperatorStandard

Expand Down Expand Up @@ -76,7 +75,10 @@ def infidelity_sampling_MCState(
σ_t = sigma_t.reshape(-1, N)

def expect_kernel(params):
def kernel_fun(params, params_t, σ, σ_t):
def kernel_fun(params_all, samples_all):
params, params_t = params_all
σ, σ_t = samples_all

W = {"params": params, **model_state}
W_t = {"params": params_t, **model_state_t}

Expand All @@ -91,14 +93,24 @@ def kernel_fun(params, params_t, σ, σ_t):
lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real
)

return expect_2distr(
log_pdf,
log_pdf_t,
def log_pdf_joint(params_all, samples_all):
params, params_t = params_all
σ, σ_t = samples_all
log_pdf_vals = log_pdf(params, σ)
log_pdf_t_vals = log_pdf_t(params_t, σ_t)
return log_pdf_vals + log_pdf_t_vals

return nkjax.expect(
log_pdf_joint,
kernel_fun,
params,
params_t,
σ,
σ_t,
(
params,
params_t,
),
(
σ,
σ_t,
),
n_chains=n_chains_t,
)

Expand Down
4 changes: 2 additions & 2 deletions netket_fidelity/infidelity/overlap_U/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def expect_fun(params):
F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True)

F_grad = F_vjp_fun(jnp.ones_like(F))[0]
F_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_map(lambda x: -x, F_grad)
F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad)
I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0)

return I_stats, I_grad
30 changes: 21 additions & 9 deletions netket_fidelity/infidelity/overlap_U/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from netket.vqs import MCState, expect, expect_and_grad, get_local_kernel_arguments
from netket.utils import mpi

from netket_fidelity.utils import expect_2distr

from .operator import InfidelityOperatorUPsi

Expand Down Expand Up @@ -113,7 +112,10 @@ def infidelity_sampling_MCState(
xp_t_ravel = jnp.vstack(xp_t_splitted)

def expect_kernel(params):
def kernel_fun(params, params_t, σ, σ_t):
def kernel_fun(params_all, samples_all):
params, params_t = params_all
σ, σ_t = samples_all

W = {"params": params, **model_state}
W_t = {"params": params_t, **model_state_t}

Expand All @@ -139,14 +141,24 @@ def kernel_fun(params, params_t, σ, σ_t):
lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real
)

return expect_2distr(
log_pdf,
log_pdf_t,
def log_pdf_joint(params_all, samples_all):
params, params_t = params_all
σ, σ_t = samples_all
log_pdf_vals = log_pdf(params, σ)
log_pdf_t_vals = log_pdf_t(params_t, σ_t)
return log_pdf_vals + log_pdf_t_vals

return nkjax.expect(
log_pdf_joint,
kernel_fun,
params,
params_t,
σ,
σ_t,
(
params,
params_t,
),
(
σ,
σ_t,
),
n_chains=n_chains_t,
)

Expand Down
1 change: 0 additions & 1 deletion netket_fidelity/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .expect import expect_2distr
from .sampling_Ustate import make_logpsi_U_afun, _logpsi_U_fun

from netket.utils import _hide_submodules
Expand Down
182 changes: 0 additions & 182 deletions netket_fidelity/utils/expect.py

This file was deleted.

Loading