diff --git a/netket_fidelity/utils/sampling_Ustate.py b/netket_fidelity/utils/sampling_Ustate.py index f9452d4..fb6c86b 100644 --- a/netket_fidelity/utils/sampling_Ustate.py +++ b/netket_fidelity/utils/sampling_Ustate.py @@ -40,7 +40,9 @@ def _logpsi_U_fun(apply_fun, variables, x, *args): variables_applyfun, U = flax.core.pop(variables, "unitary") xp, mels = U.get_conn_padded(x) + xp = xp.reshape(-1, x.shape[-1]) logpsi_xp = apply_fun(variables_applyfun, xp, *args) + logpsi_xp = logpsi_xp.reshape(x.shape[0], -1) return jax.scipy.special.logsumexp(logpsi_xp, axis=-1, b=mels)