diff --git a/netket_fidelity/infidelity/overlap_U/exact.py b/netket_fidelity/infidelity/overlap_U/exact.py index 4a20216..8ddd23b 100644 --- a/netket_fidelity/infidelity/overlap_U/exact.py +++ b/netket_fidelity/infidelity/overlap_U/exact.py @@ -1,5 +1,6 @@ import jax.numpy as jnp import jax +from functools import partial from netket import jax as nkjax from netket.utils.dispatch import TrueT @@ -21,13 +22,15 @@ def infidelity(vstate: FullSumState, op: InfidelityOperatorUPsi): if not isinstance(op.target, FullSumState): raise TypeError("Can only compute infidelity of exact states.") + U_sp = sparsify(op._U) + Ustate_t = U_sp @ op.target.to_array(normalize=False) + return infidelity_sampling_FullSumState( vstate._apply_fun, vstate.parameters, vstate.model_state, vstate._all_states, - op.target.to_array(), - op._U.to_sparse(), + Ustate_t, return_grad=False, ) @@ -45,30 +48,34 @@ def infidelity( # noqa: F811 if not isinstance(op.target, FullSumState): raise TypeError("Can only compute infidelity of exact states.") + U_sp = sparsify(op._U) + Ustate_t = U_sp @ op.target.to_array(normalize=False) + return infidelity_sampling_FullSumState( vstate._apply_fun, vstate.parameters, vstate.model_state, vstate._all_states, - op.target.to_array(), - op._U.to_sparse(), + Ustate_t, return_grad=True, ) +@partial(jax.jit, static_argnames=("afun", "return_grad")) def infidelity_sampling_FullSumState( afun, params, model_state, sigma, - state_t, - U_sp, + Ustate_t, return_grad, ): def expect_fun(params): state = jnp.exp(afun({"params": params, **model_state}, sigma)) state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2)) - return jnp.abs(state.T.conj() @ (U_sp @ state_t)) ** 2 + return jnp.abs(state.T.conj().T @ Ustate_t) ** 2 / ( + Ustate_t.conj().T @ Ustate_t + ) if not return_grad: F = expect_fun(params)