diff --git a/netket_fidelity/infidelity/overlap/expect.py b/netket_fidelity/infidelity/overlap/expect.py index a4bbe61..8d342ce 100644 --- a/netket_fidelity/infidelity/overlap/expect.py +++ b/netket_fidelity/infidelity/overlap/expect.py @@ -7,7 +7,6 @@ from netket import jax as nkjax from netket.utils import mpi -from netket_fidelity.utils import expect_2distr from .operator import InfidelityOperatorStandard @@ -76,7 +75,10 @@ def infidelity_sampling_MCState( σ_t = sigma_t.reshape(-1, N) def expect_kernel(params): - def kernel_fun(params, params_t, σ, σ_t): + def kernel_fun(params_all, samples_all): + params, params_t = params_all + σ, σ_t = samples_all + W = {"params": params, **model_state} W_t = {"params": params_t, **model_state_t} @@ -91,14 +93,24 @@ def kernel_fun(params, params_t, σ, σ_t): lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real ) - return expect_2distr( - log_pdf, - log_pdf_t, + def log_pdf_joint(params_all, samples_all): + params, params_t = params_all + σ, σ_t = samples_all + log_pdf_vals = log_pdf(params, σ) + log_pdf_t_vals = log_pdf_t(params_t, σ_t) + return log_pdf_vals + log_pdf_t_vals + + return nkjax.expect( + log_pdf_joint, kernel_fun, - params, - params_t, - σ, - σ_t, + ( + params, + params_t, + ), + ( + σ, + σ_t, + ), n_chains=n_chains_t, ) diff --git a/netket_fidelity/infidelity/overlap_U/exact.py b/netket_fidelity/infidelity/overlap_U/exact.py index 88d0471..2819ea5 100644 --- a/netket_fidelity/infidelity/overlap_U/exact.py +++ b/netket_fidelity/infidelity/overlap_U/exact.py @@ -82,8 +82,8 @@ def expect_fun(params): F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True) F_grad = F_vjp_fun(jnp.ones_like(F))[0] - F_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) - I_grad = jax.tree_map(lambda x: -x, F_grad) + F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) + I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) return I_stats, I_grad diff --git a/netket_fidelity/infidelity/overlap_U/expect.py b/netket_fidelity/infidelity/overlap_U/expect.py index f0f9e12..89623ca 100644 --- a/netket_fidelity/infidelity/overlap_U/expect.py +++ b/netket_fidelity/infidelity/overlap_U/expect.py @@ -9,7 +9,6 @@ from netket.vqs import MCState, expect, expect_and_grad, get_local_kernel_arguments from netket.utils import mpi -from netket_fidelity.utils import expect_2distr from .operator import InfidelityOperatorUPsi @@ -113,7 +112,10 @@ def infidelity_sampling_MCState( xp_t_ravel = jnp.vstack(xp_t_splitted) def expect_kernel(params): - def kernel_fun(params, params_t, σ, σ_t): + def kernel_fun(params_all, samples_all): + params, params_t = params_all + σ, σ_t = samples_all + W = {"params": params, **model_state} W_t = {"params": params_t, **model_state_t} @@ -139,14 +141,24 @@ def kernel_fun(params, params_t, σ, σ_t): lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real ) - return expect_2distr( - log_pdf, - log_pdf_t, + def log_pdf_joint(params_all, samples_all): + params, params_t = params_all + σ, σ_t = samples_all + log_pdf_vals = log_pdf(params, σ) + log_pdf_t_vals = log_pdf_t(params_t, σ_t) + return log_pdf_vals + log_pdf_t_vals + + return nkjax.expect( + log_pdf_joint, kernel_fun, - params, - params_t, - σ, - σ_t, + ( + params, + params_t, + ), + ( + σ, + σ_t, + ), n_chains=n_chains_t, ) diff --git a/netket_fidelity/utils/__init__.py b/netket_fidelity/utils/__init__.py index ac817ba..fcd0305 100644 --- a/netket_fidelity/utils/__init__.py +++ b/netket_fidelity/utils/__init__.py @@ -1,4 +1,3 @@ -from .expect import expect_2distr from .sampling_Ustate import make_logpsi_U_afun, _logpsi_U_fun from netket.utils import _hide_submodules diff --git a/netket_fidelity/utils/expect.py b/netket_fidelity/utils/expect.py deleted file mode 100644 index 05e4bc5..0000000 --- a/netket_fidelity/utils/expect.py +++ /dev/null @@ -1,182 +0,0 @@ -from typing import Callable, Tuple -from functools import partial -import jax.numpy as jnp -import jax -from netket.utils.types import PyTree -from netket.jax import vjp as nkvjp -from netket.stats import statistics as mpi_statistics, Stats - - -def expect_2distr( - log_pdf_new: Callable[[PyTree, jnp.ndarray], jnp.ndarray], - log_pdf_old: Callable[[PyTree, jnp.ndarray], jnp.ndarray], - expected_fun: Callable[[PyTree, jnp.ndarray], jnp.ndarray], - pars_new: PyTree, - pars_old: PyTree, - σ_new: jnp.ndarray, - σ_old: jnp.ndarray, - *expected_fun_args, - n_chains: int = None, -) -> Tuple[jnp.ndarray, Stats]: - """ - Computes the expectation value over a log-pdf. - - Args: - log_pdf: - expected_ffun - """ - - return _expect_2distr( - n_chains, - log_pdf_new, - log_pdf_old, - expected_fun, - pars_new, - pars_old, - σ_new, - σ_old, - *expected_fun_args, - ) - - -@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3)) -def _expect_2distr( - n_chains, - log_pdf_new, - log_pdf_old, - expected_fun, - pars_new, - pars_old, - σ_new, - σ_old, - *expected_fun_args, -): - L_σ = expected_fun(pars_new, pars_old, σ_new, σ_old, *expected_fun_args) - if n_chains is not None: - L_σ = L_σ.reshape((n_chains, -1)) - - L̄_σ = mpi_statistics(L_σ.T) - - return L̄_σ.mean, L̄_σ - - -def _expect_fwd_fid( - n_chains, - log_pdf_new, - log_pdf_old, - expected_fun, - pars_new, - pars_old, - σ_new, - σ_old, - *expected_fun_args, -): - L_σ = expected_fun(pars_new, pars_old, σ_new, σ_old, *expected_fun_args) - if n_chains is not None: - L_σ_r = L_σ.reshape((n_chains, -1)) - else: - L_σ_r = L_σ - - L̄_stat = mpi_statistics(L_σ_r.T) - - L̄_σ = L̄_stat.mean - - # Use the baseline trick to reduce the variance - ΔL_σ = L_σ - L̄_σ - - return (L̄_σ, L̄_stat), (pars_new, pars_old, σ_new, σ_old, expected_fun_args, ΔL_σ) - - -def _expect_bwd_fid(n_chains, log_pdf_new, log_pdf_old, expected_fun, residuals, dout): - pars_new, pars_old, σ_new, σ_old, cost_args, ΔL_σ = residuals - dL̄, dL̄_stats = dout - log_p_old = log_pdf_old(pars_old, σ_old) - - def f(pars_new, pars_old, σ_new, σ_old, *cost_args): - log_p = log_pdf_new(pars_new, σ_new) + log_p_old - term1 = jax.vmap(jnp.multiply)(ΔL_σ, log_p) - term2 = expected_fun(pars_new, pars_old, σ_new, σ_old, *cost_args) - out = term1 + term2 - out = out.mean() - return out - - _, pb = nkvjp(f, pars_new, pars_old, σ_new, σ_old, *cost_args) - - grad_f = pb(dL̄) - - return grad_f - - -_expect_2distr.defvjp(_expect_fwd_fid, _expect_bwd_fid) - - -def expect_onedistr( - log_pdf: Callable[[PyTree, jnp.ndarray], jnp.ndarray], - expected_fun: Callable[[PyTree, jnp.ndarray], jnp.ndarray], - pars: PyTree, - σ: jnp.ndarray, - *expected_fun_args, - n_chains: int = None, -) -> Tuple[jnp.ndarray, Stats]: - """ - Computes the expectation value over a log-pdf. - - Args: - log_pdf: - expected_ffun - """ - return _expect_onedistr( - n_chains, log_pdf, expected_fun, pars, σ, *expected_fun_args - ) - - -@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2)) -def _expect_onedistr(n_chains, log_pdf, expected_fun, pars, σ, *expected_fun_args): - L_σ = expected_fun(pars, σ, *expected_fun_args) - if n_chains is not None: - L_σ = L_σ.reshape((n_chains, -1)) - - L̄_σ = mpi_statistics(L_σ.T) - # L̄_σ = L_σ.mean(axis=0) - - return L̄_σ.mean, L̄_σ - - -def _expect_onedistr_fwd(n_chains, log_pdf, expected_fun, pars, σ, *expected_fun_args): - L_σ = expected_fun(pars, σ, *expected_fun_args) - if n_chains is not None: - L_σ_r = L_σ.reshape((n_chains, -1)) - else: - L_σ_r = L_σ - - L̄_stat = mpi_statistics(L_σ_r.T) - - L̄_σ = L̄_stat.mean - # L̄_σ = L_σ.mean(axis=0) - - # Use the baseline trick to reduce the variance - ΔL_σ = L_σ - L̄_σ - - return (L̄_σ, L̄_stat), (pars, σ, expected_fun_args, ΔL_σ) - - -def _expect_onedistr_bwd(n_chains, log_pdf, expected_fun, residuals, dout): - pars, σ, cost_args, ΔL_σ = residuals - dL̄, dL̄_stats = dout - - def f(pars, σ, *cost_args): - log_p = log_pdf(pars, σ) - term1 = jax.vmap(jnp.multiply)(ΔL_σ, log_p) - term2 = expected_fun(pars, σ, *cost_args) - out = term1 + term2 - out = out.mean() - return out - - _, pb = nkvjp(f, pars, σ, *cost_args) - - grad_f = pb(dL̄) - - return grad_f - - -_expect_onedistr.defvjp(_expect_onedistr_fwd, _expect_onedistr_bwd)