diff --git a/pyqtorch/differentiation/gpsr.py b/pyqtorch/differentiation/gpsr.py
index 8009401d..0ecf038d 100644
--- a/pyqtorch/differentiation/gpsr.py
+++ b/pyqtorch/differentiation/gpsr.py
@@ -131,7 +131,6 @@ def backward(ctx: Any, grad_out: Tensor) -> Tuple[None, ...]:
             pass
 
         shift_pi2 = torch.tensor(torch.pi, dtype=dtype_values) / 2.0
-        shift_multi = 0.5
 
         def expectation_fn(values: dict[str, Tensor]) -> Tensor:
             """Use the PSRExpectation for nested grad calls.
@@ -156,7 +155,7 @@ def single_gap_shift(
             param_name: str,
             values: dict[str, Tensor],
             spectral_gap: Tensor,
-            shift: Tensor = shift_pi2,
+            shift_prefac: float = 1.0,
         ) -> Tensor:
             """Implements single gap PSR rule.
 
@@ -164,7 +163,8 @@ def single_gap_shift(
                 param_name: Name of the parameter to apply PSR.
                 values: Dictionary with parameter values.
                 spectral_gap: Spectral gap value for PSR.
-                shift: Shift value. Defaults to torch.tensor(torch.pi)/2.0.
+                shift_prefac: Shift prefactor value to multiply pi/2.
+                    Defaults to 1.
 
             Returns:
                 Gradient evaluation for param_name.
@@ -172,7 +172,7 @@ def single_gap_shift(
             """
             # device conversions
             spectral_gap = spectral_gap.to(device=device)
-            shift = shift.to(device=device)
+            shift = shift_pi2.to(device=device) * shift_prefac
 
             # apply shift rule
             shifted_values = values.copy()
@@ -190,7 +190,7 @@ def multi_gap_shift(
             param_name: str,
             values: dict[str, Tensor],
             spectral_gaps: Tensor,
-            shift_prefac: float = shift_multi,
+            shift_prefac: float = 0.5,
         ) -> Tensor:
             """Implement multi gap PSR rule.
 
@@ -248,7 +248,10 @@ def multi_gap_shift(
             return dfdx
 
         def vjp(
-            param_name: str, spectral_gap: Tensor, values: dict[str, Tensor]
+            param_name: str,
+            spectral_gap: Tensor,
+            values: dict[str, Tensor],
+            shift_prefac: float,
         ) -> Tensor:
             """Vector-jacobian product between `grad_out` and jacobians of parameters.
 
@@ -256,6 +259,7 @@ def vjp(
                 param_name: Parameter name to compute gradient over.
                 spectral_gap: Spectral gap of the corresponding operation.
                 values: Dictionary with parameter values.
+                shift_prefac: Shift prefactor value for PSR shifts.
 
             Returns:
                 Updated jacobian by PSR.
@@ -266,16 +270,21 @@ def vjp(
                 param_name,  # type: ignore
                 values,
                 spectral_gap,
+                shift_prefac=shift_prefac,
             )
 
         grads = {p: None for p in ctx.param_names}
 
-        def update_gradient(param_name: str, spectral_gap: Tensor):
+        def update_gradient(param_name: str, spectral_gap: Tensor, shift_prefac: float):
             if values[param_name].requires_grad:
                 if grads[param_name] is not None:
-                    grads[param_name] += vjp(param_name, spectral_gap, values)
+                    grads[param_name] += vjp(
+                        param_name, spectral_gap, values, shift_prefac
+                    )
                 else:
-                    grads[param_name] = vjp(param_name, spectral_gap, values)
+                    grads[param_name] = vjp(
+                        param_name, spectral_gap, values, shift_prefac
+                    )
 
         for op in ctx.circuit.flatten():
@@ -283,7 +292,10 @@ def update_gradient(param_name: str, spectral_gap: Tensor):
                 op.param_name, str
             ):
                 factor = 1.0 if isinstance(op, Parametric) else 2.0
-                update_gradient(op.param_name, factor * op.spectral_gap)
+                shift_prefac = 1.0 / factor
+                if len(op.spectral_gap) > 1:
+                    shift_prefac = 0.5
+                update_gradient(op.param_name, factor * op.spectral_gap, shift_prefac)
 
         return (
             None,
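For context, a minimal numerical sketch (not part of the patch) of the single-gap rule that the new `shift_prefac` argument scales. The toy expectation `f`, its constants, and the example gap are assumptions: for a spectral gap `DELTA`, `f` is modeled with frequency `DELTA / 2`, matching the `4 * sin(DELTA * shift / 2)` denominator used in `single_gap_shift`:

import torch

# Toy single-gap expectation: an assumed stand-in for expectation_fn,
# with frequency DELTA / 2 so the PSR denominator below is exact.
DELTA = torch.tensor(2.0)
A, PHI, C = 0.7, 0.3, 0.1  # arbitrary made-up constants

def f(x: torch.Tensor) -> torch.Tensor:
    return A * torch.cos(DELTA / 2 * x + PHI) + C

def single_gap_psr(x: torch.Tensor, shift_prefac: float = 1.0) -> torch.Tensor:
    # shift = shift_prefac * pi/2, mirroring the patched single_gap_shift
    shift = shift_prefac * torch.pi / 2.0
    f_plus, f_minus = f(x + shift), f(x - shift)
    return DELTA * (f_plus - f_minus) / (4.0 * torch.sin(DELTA * shift / 2.0))

x = torch.tensor(0.4)
exact = -A * DELTA / 2 * torch.sin(DELTA / 2 * x + PHI)  # analytic derivative
for prefac in (1.0, 0.5, 0.25):
    psr = single_gap_psr(x, prefac)
    print(f"prefac={prefac}: psr={psr.item():.6f}, exact={exact.item():.6f}")

On this ideal model every non-degenerate prefactor recovers the same exact gradient, which is presumably why the patch is free to pick `shift_prefac = 1.0 / factor` (and 0.5 in the multi-gap case) purely to keep the shifted angles away from points where the `sin` denominator vanishes.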