Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify logic of logpsi_U wrapper #11

Merged
merged 1 commit into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions netket_fidelity/infidelity/overlap/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

import jax.numpy as jnp

import flax

from netket import jax as nkjax
from netket.operator import AbstractOperator, DiscreteJaxOperator
from netket.utils.types import DType
from netket.utils.numbers import is_scalar
from netket.vqs import VariationalState, MCState, FullSumState

from netket_fidelity.utils.sampling_Ustate import _logpsi_U
from netket_fidelity.utils.sampling_Ustate import make_logpsi_U_afun


class InfidelityOperatorStandard(AbstractOperator):
Expand Down Expand Up @@ -72,12 +70,12 @@ def InfidelityUPsi(
"an instance of DiscreteJaxOperator."
)

logpsiU = nkjax.HashablePartial(_logpsi_U, state._apply_fun)
logpsiU, variables_U = make_logpsi_U_afun(state._apply_fun, U, state.variables)
target = MCState(
sampler=state.sampler,
apply_fun=logpsiU,
n_samples=state.n_samples,
variables=flax.core.copy(state.variables, {"unitary": U}),
variables=variables_U,
)

return InfidelityOperatorStandard(target, cv_coeff=cv_coeff, dtype=dtype)
2 changes: 1 addition & 1 deletion netket_fidelity/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .expect import expect_2distr
from .sampling_Ustate import _logpsi_U
from .sampling_Ustate import make_logpsi_U_afun, _logpsi_U_fun

from netket.utils import _hide_submodules

Expand Down
32 changes: 30 additions & 2 deletions netket_fidelity/utils/sampling_Ustate.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,37 @@
import jax

import flax

from netket import jax as nkjax


def make_logpsi_U_afun(logpsi_fun, U, variables):
"""Wraps an apply_fun into another one that multiplies it by an
Unitary transformation U.

This wrapper is made such that the Unitary is passed as the model_state
of the new wrapped function, and therefore changes to the angles/coefficients
of the Unitary should not trigger recompilation.

Args:
logpsi_fun: a function that takes as input variables and samples
U: a {class}`nk.operator.JaxDiscreteOperator`
variables: The variables used to call *logpsi_fun*

Returns:
A tuple, where the first element is a new function with the same signature as
the original **logpsi_fun** and a set of new variables to be used to call it.
"""
# wrap apply_fun into logpsi logpsi_U
logpsiU_fun = nkjax.HashablePartial(_logpsi_U_fun, logpsi_fun)

# Insert a new 'model_state' key to store the Unitary. This only works
# if U is a pytree that can be flattened/unflattened.
new_variables = flax.core.copy(variables, {"unitary": U})

return logpsiU_fun, new_variables


def _logpsi_U(apply_fun, variables, x, *args):
def _logpsi_U_fun(apply_fun, variables, x, *args):
"""
This should be used as a wrapper to the original apply function, adding
to the `variables` dictionary (in model_state) a new key `unitary` with
Expand Down
Loading