From a167b6defa66b8a399ad8761e671a0131839a160 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Fri, 17 May 2024 19:54:47 +0200 Subject: [PATCH] deprecation fixes --- .../S2_dynamics_with_measurements.py | 6 +++--- netket_fidelity/driver/ptvmc.py | 4 ++-- netket_fidelity/infidelity/logic.py | 10 +++++----- netket_fidelity/infidelity/overlap/exact.py | 4 ++-- netket_fidelity/infidelity/overlap/expect.py | 4 ++-- netket_fidelity/infidelity/overlap_U/expect.py | 4 ++-- test/_infidelity_exact.py | 6 +++--- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/examples/dynamics_with_measurements/S2_dynamics_with_measurements.py b/examples/dynamics_with_measurements/S2_dynamics_with_measurements.py index ff9f656..ed544b8 100644 --- a/examples/dynamics_with_measurements/S2_dynamics_with_measurements.py +++ b/examples/dynamics_with_measurements/S2_dynamics_with_measurements.py @@ -96,7 +96,7 @@ def projective_measurement(phi, psi, p, key_meas, key_spin): key_spin, subkey_spin = jax.random.split(key_spin) params = flax.core.unfreeze(psi.parameters) - params = jax.tree_map(lambda x: jnp.array(x), params) + params = jax.tree_util.tree_map(lambda x: jnp.array(x), params) if jax.random.uniform(subkey_spin) < prob_up.real: params["orbital_down"] = params["orbital_down"].at[i].set(1e-12) else: @@ -128,7 +128,7 @@ def dynamics_with_measurements( # ZZ diagonal term params = flax.core.unfreeze(psi.parameters) - params = jax.tree_map(lambda x: jnp.array(x), params) + params = jax.tree_util.tree_map(lambda x: jnp.array(x), params) for l, m in g.edges(): params["theta_zz"] = ( params["theta_zz"] @@ -154,7 +154,7 @@ def dynamics_with_measurements( # ZZ diagonal term params = flax.core.unfreeze(psi.parameters) - params = jax.tree_map(lambda x: jnp.array(x), params) + params = jax.tree_util.tree_map(lambda x: jnp.array(x), params) for l, m in g.edges(): params["theta_zz"] = ( params["theta_zz"] diff --git a/netket_fidelity/driver/ptvmc.py b/netket_fidelity/driver/ptvmc.py index b8a6ee4..6d8c5a5 100644 --- a/netket_fidelity/driver/ptvmc.py +++ b/netket_fidelity/driver/ptvmc.py @@ -22,7 +22,7 @@ def __init__( U_dagger=None, preconditioner: PreconditionerT = identity_preconditioner, is_unitary=False, - sample_Upsi=False, + sample_Upsi=False, cv_coeff=None, ): self._dt = dt @@ -41,7 +41,7 @@ def __init__( U_dagger=U_dagger, preconditioner=preconditioner, is_unitary=is_unitary, - sample_Upsi=sample_Upsi, + sample_Upsi=sample_Upsi, cv_coeff=cv_coeff, ) diff --git a/netket_fidelity/infidelity/logic.py b/netket_fidelity/infidelity/logic.py index 08bc819..aa0da51 100644 --- a/netket_fidelity/infidelity/logic.py +++ b/netket_fidelity/infidelity/logic.py @@ -59,7 +59,7 @@ def InfidelityOperator( the function :class:`netket_fidelity.infidelity.InfidelityUPsi` . This works only with the operators provdided in the package. - We remark that sampling from :math:`U|\phi\rangle` requires to compute connected elements of + We remark that sampling from :math:`U|\phi\rangle` requires to compute connected elements of :math:`U` and so is more expensive than sampling from an autonomous state. The choice of this estimator is specified by passing :code:`sample_Upsi=True`, while the flag argument :code:`is_unitary` indicates whether :math:`U` is unitary or not. @@ -82,7 +82,7 @@ def InfidelityOperator( This estimator is more efficient since it does not require to sample from :math:`U|\phi\rangle`, but only from :math:`|\phi\rangle`. This choice of the estimator is the default and it works only - with `is_unitary==True` (besides :code:`sample_Upsi=False` ). + with `is_unitary==True` (besides :code:`sample_Upsi=False` ). When :math:`|\Phi⟩ = |\phi⟩` the two estimators coincides. To reduce the variance of the estimator, the Control Variates (CV) method can be applied. This consists @@ -100,8 +100,8 @@ def InfidelityOperator( c* = \frac{\rm{Cov}_{χ}\left[ |1-I_{loc}|^2, \rm{Re}\left[1-I_{loc}\right]\right]}{ \rm{Var}_{χ}\left[ |1-I_{loc}|^2\right] }, - where :math:`\rm{Cov}\left\cdot, \cdot\right]` indicates the covariance and :math:`\rm{Var}\left[\cdot\right]` the variance. - In the relevant limit :math:`|\Psi⟩ \rightarrow|\Phi⟩`, we have :math:`c^\star \rightarrow -1/2`. The value :math:`-1/2` is + where :math:`\rm{Cov}\left\cdot, \cdot\right]` indicates the covariance and :math:`\rm{Var}\left[\cdot\right]` the variance. + In the relevant limit :math:`|\Psi⟩ \rightarrow|\Phi⟩`, we have :math:`c^\star \rightarrow -1/2`. The value :math:`-1/2` is adopted as default value for c in the infidelity estimator. To not apply CV, set c=0. @@ -110,7 +110,7 @@ def InfidelityOperator( U: operator :math:`\hat{U}`. U_dagger: dagger operator :math:`\hat{U^\dagger}`. cv_coeff: Control Variates coefficient c. - is_unitary: flag specifiying the unitarity of :math:`\hat{U}`. If True with + is_unitary: flag specifiying the unitarity of :math:`\hat{U}`. If True with :code:`sample_Upsi=False`, the second estimator is used. dtype: The dtype of the output of expectation value and gradient. sample_Upsi: flag specifiying whether to sample from |ϕ⟩ or from U|ϕ⟩. If False with `is_unitary=False` , an error occurs. diff --git a/netket_fidelity/infidelity/overlap/exact.py b/netket_fidelity/infidelity/overlap/exact.py index 0eb14d7..aefe9bc 100644 --- a/netket_fidelity/infidelity/overlap/exact.py +++ b/netket_fidelity/infidelity/overlap/exact.py @@ -71,8 +71,8 @@ def expect_fun(params): F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True) F_grad = F_vjp_fun(jnp.ones_like(F))[0] - F_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) - I_grad = jax.tree_map(lambda x: -x, F_grad) + F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) + I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) return I_stats, I_grad diff --git a/netket_fidelity/infidelity/overlap/expect.py b/netket_fidelity/infidelity/overlap/expect.py index 576214a..a4bbe61 100644 --- a/netket_fidelity/infidelity/overlap/expect.py +++ b/netket_fidelity/infidelity/overlap/expect.py @@ -111,8 +111,8 @@ def kernel_fun(params, params_t, σ, σ_t): ) F_grad = F_vjp_fun(jnp.ones_like(F))[0] - F_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) - I_grad = jax.tree_map(lambda x: -x, F_grad) + F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) + I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) I_stats = F_stats.replace(mean=1 - F) return I_stats, I_grad diff --git a/netket_fidelity/infidelity/overlap_U/expect.py b/netket_fidelity/infidelity/overlap_U/expect.py index d8c07af..f0f9e12 100644 --- a/netket_fidelity/infidelity/overlap_U/expect.py +++ b/netket_fidelity/infidelity/overlap_U/expect.py @@ -159,8 +159,8 @@ def kernel_fun(params, params_t, σ, σ_t): ) F_grad = F_vjp_fun(jnp.ones_like(F))[0] - F_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) - I_grad = jax.tree_map(lambda x: -x, F_grad) + F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) + I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) I_stats = F_stats.replace(mean=1 - F) return I_stats, I_grad diff --git a/test/_infidelity_exact.py b/test/_infidelity_exact.py index 643e358..f6dec2b 100644 --- a/test/_infidelity_exact.py +++ b/test/_infidelity_exact.py @@ -15,6 +15,6 @@ def _infidelity_exact(params_new, vstate, U): ) else: - return 1 - jnp.absolute(state_new.conj().T @ U.to_sparse() @ state_old) ** 2 / ( - (state_new.conj().T @ state_new) * (state_old.conj().T @ state_old) - ) + return 1 - jnp.absolute( + state_new.conj().T @ (U.to_sparse() @ state_old) + ) ** 2 / ((state_new.conj().T @ state_new) * (state_old.conj().T @ state_old))