[Bugfix] Adjust shift with hamevo via a prefactor (#259)
Fixes #258: for HamiltonianEvolution, not only must the spectral gap be multiplied by 2, the shift must also be adjusted accordingly in the single-gap case.
chMoussa authored Aug 9, 2024
1 parent 8ec5df1 commit b356f4a
Showing 1 changed file with 22 additions and 10 deletions.
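
For context, the single-gap GPSR estimator this shift feeds into has the standard form df/dx = gap * [f(x + shift) - f(x - shift)] / (4 sin(gap * shift / 2)). Under the Parametric convention exp(-i theta G / 2) the stored gap is used directly, while HamiltonianEvolution exp(-i H t) oscillates at twice that rate, so its effective gap is 2x the stored one; keeping the old shift of pi/2 can then put the sine in the denominator at (or numerically near) zero. A minimal sketch of that failure mode, assuming the standard estimator form and toy functions (none of this is pyqtorch code):

import math

# Standard single-gap GPSR estimator (assumed form, not copied from gpsr.py):
# df/dx = gap * (f(x + shift) - f(x - shift)) / (4 * sin(gap * shift / 2))
def single_gap_psr(f, x, gap, shift):
    return gap * (f(x + shift) - f(x - shift)) / (4.0 * math.sin(gap * shift / 2.0))

def f(t):
    # hamevo-like curve: exp(-iHt) with a gap-2 H oscillates at frequency 2
    return math.cos(2.0 * t)

x = 0.3

# Effective gap is 2 * 2 = 4. The unadjusted shift pi/2 makes the denominator
# sin(pi) ~ 0, so the estimate blows up numerically:
# single_gap_psr(f, x, 4.0, math.pi / 2)

# Halving the shift (shift_prefac = 0.5) keeps the denominator at its maximum:
print(single_gap_psr(f, x, 4.0, math.pi / 4))  # -1.1293..., matches the exact derivative
print(-2.0 * math.sin(2.0 * x))                # exact: -2 sin(0.6) = -1.1293...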
32 changes: 22 additions & 10 deletions pyqtorch/differentiation/gpsr.py
@@ -131,7 +131,6 @@ def backward(ctx: Any, grad_out: Tensor) -> Tuple[None, ...]:
             pass
 
         shift_pi2 = torch.tensor(torch.pi, dtype=dtype_values) / 2.0
-        shift_multi = 0.5
 
         def expectation_fn(values: dict[str, Tensor]) -> Tensor:
             """Use the PSRExpectation for nested grad calls.
@@ -156,23 +155,24 @@ def single_gap_shift(
             param_name: str,
             values: dict[str, Tensor],
             spectral_gap: Tensor,
-            shift: Tensor = shift_pi2,
+            shift_prefac: float = 1.0,
         ) -> Tensor:
             """Implements single gap PSR rule.
 
             Args:
                 param_name: Name of the parameter to apply PSR.
                 values: Dictionary with parameter values.
                 spectral_gap: Spectral gap value for PSR.
-                shift: Shift value. Defaults to torch.tensor(torch.pi)/2.0.
+                shift_prefac: Shift prefactor value to multiply pi/2.
+                    Defaults to 1.
 
             Returns:
                 Gradient evaluation for param_name.
             """
 
             # device conversions
             spectral_gap = spectral_gap.to(device=device)
-            shift = shift.to(device=device)
+            shift = shift_pi2.to(device=device) * shift_prefac
 
             # apply shift rule
             shifted_values = values.copy()
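
A quick numeric check of the new shift = shift_pi2.to(device=device) * shift_prefac line above (hypothetical values): the quantity gap * shift / 2 entering the sine denominator stays at pi/2, its well-conditioned optimum, when the gap is doubled and the prefactor halved.

import math

# Assumed invariant behind the fix: gap * shift / 2 stays at pi/2.
gap, prefac = 2.0, 1.0  # Parametric Pauli gate: shift stays pi/2
print(math.sin(gap * (math.pi / 2) * prefac / 2.0))  # 1.0

gap, prefac = 4.0, 0.5  # HamiltonianEvolution: doubled gap, halved shift
print(math.sin(gap * (math.pi / 2) * prefac / 2.0))  # 1.0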
@@ -190,7 +190,7 @@ def multi_gap_shift(
             param_name: str,
             values: dict[str, Tensor],
             spectral_gaps: Tensor,
-            shift_prefac: float = shift_multi,
+            shift_prefac: float = 0.5,
         ) -> Tensor:
             """Implement multi gap PSR rule.
@@ -248,14 +248,18 @@ def multi_gap_shift(
             return dfdx
 
         def vjp(
-            param_name: str, spectral_gap: Tensor, values: dict[str, Tensor]
+            param_name: str,
+            spectral_gap: Tensor,
+            values: dict[str, Tensor],
+            shift_prefac: float,
         ) -> Tensor:
             """Vector-jacobian product between `grad_out` and jacobians of parameters.
 
             Args:
                 param_name: Parameter name to compute gradient over.
                 spectral_gap: Spectral gap of the corresponding operation.
                 values: Dictionary with parameter values.
+                shift_prefac: Shift prefactor value for PSR shifts.
 
             Returns:
                 Updated jacobian by PSR.
@@ -266,24 +270,32 @@ def vjp(
                 param_name,  # type: ignore
                 values,
                 spectral_gap,
+                shift_prefac=shift_prefac,
             )
 
         grads = {p: None for p in ctx.param_names}
 
-        def update_gradient(param_name: str, spectral_gap: Tensor):
+        def update_gradient(param_name: str, spectral_gap: Tensor, shift_prefac: float):
             if values[param_name].requires_grad:
                 if grads[param_name] is not None:
-                    grads[param_name] += vjp(param_name, spectral_gap, values)
+                    grads[param_name] += vjp(
+                        param_name, spectral_gap, values, shift_prefac
+                    )
                 else:
-                    grads[param_name] = vjp(param_name, spectral_gap, values)
+                    grads[param_name] = vjp(
+                        param_name, spectral_gap, values, shift_prefac
+                    )
 
         for op in ctx.circuit.flatten():
 
            if isinstance(op, (Parametric, HamiltonianEvolution)) and isinstance(
                op.param_name, str
            ):
                factor = 1.0 if isinstance(op, Parametric) else 2.0
-                update_gradient(op.param_name, factor * op.spectral_gap)
+                shift_prefac = 1.0 / factor
+                if len(op.spectral_gap) > 1:
+                    shift_prefac = 0.5
+                update_gradient(op.param_name, factor * op.spectral_gap, shift_prefac)
 
         return (
             None,
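
Taken together, the loop above now derives both scalings per operation. A condensed sketch of just that selection logic (a hypothetical standalone helper mirroring the diff, not an API in pyqtorch):

def psr_scaling(is_hamevo: bool, n_gaps: int) -> tuple[float, float]:
    # Parametric gates keep their stored gap; HamiltonianEvolution doubles it.
    factor = 2.0 if is_hamevo else 1.0
    # The shift prefactor compensates inversely; multi-gap generators use 0.5.
    shift_prefac = 1.0 / factor
    if n_gaps > 1:
        shift_prefac = 0.5
    return factor, shift_prefac

print(psr_scaling(is_hamevo=False, n_gaps=1))  # (1.0, 1.0): shift stays pi/2
print(psr_scaling(is_hamevo=True, n_gaps=1))   # (2.0, 0.5): shift becomes pi/4
print(psr_scaling(is_hamevo=True, n_gaps=3))   # (2.0, 0.5): multi-gap path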
