Skip to content

Commit

Permalink
jitting fullsumstate with U
Browse files Browse the repository at this point in the history
  • Loading branch information
alleSini99 committed Nov 2, 2023
1 parent 133d3f7 commit 1f4e998
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions netket_fidelity/infidelity/overlap_U/exact.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax.numpy as jnp
import jax
from functools import partial

from netket import jax as nkjax
from netket.utils.dispatch import TrueT
Expand All @@ -21,13 +22,15 @@ def infidelity(vstate: FullSumState, op: InfidelityOperatorUPsi):
if not isinstance(op.target, FullSumState):
raise TypeError("Can only compute infidelity of exact states.")

U_sp = sparsify(op._U)
Ustate_t = U_sp @ op.target.to_array(normalize=False)

return infidelity_sampling_FullSumState(
vstate._apply_fun,
vstate.parameters,
vstate.model_state,
vstate._all_states,
op.target.to_array(),
op._U.to_sparse(),
Ustate_t,
return_grad=False,
)

Expand All @@ -45,30 +48,34 @@ def infidelity( # noqa: F811
if not isinstance(op.target, FullSumState):
raise TypeError("Can only compute infidelity of exact states.")

U_sp = sparsify(op._U)
Ustate_t = U_sp @ op.target.to_array(normalize=False)

return infidelity_sampling_FullSumState(
vstate._apply_fun,
vstate.parameters,
vstate.model_state,
vstate._all_states,
op.target.to_array(),
op._U.to_sparse(),
Ustate_t,
return_grad=True,
)


@partial(jax.jit, static_argnames=("afun", "return_grad"))
def infidelity_sampling_FullSumState(
afun,
params,
model_state,
sigma,
state_t,
U_sp,
Ustate_t,
return_grad,
):
def expect_fun(params):
state = jnp.exp(afun({"params": params, **model_state}, sigma))
state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2))
return jnp.abs(state.T.conj() @ (U_sp @ state_t)) ** 2
return jnp.abs(state.T.conj().T @ Ustate_t) ** 2 / (
Ustate_t.conj().T @ Ustate_t
)

if not return_grad:
F = expect_fun(params)
Expand Down

0 comments on commit 1f4e998

Please sign in to comment.