diff --git a/pyqtorch/differentiation/gpsr.py b/pyqtorch/differentiation/gpsr.py index 0ecf038d..13c76488 100644 --- a/pyqtorch/differentiation/gpsr.py +++ b/pyqtorch/differentiation/gpsr.py @@ -273,18 +273,18 @@ def vjp( shift_prefac=shift_prefac, ) - grads = {p: None for p in ctx.param_names} + grads = {p: torch.zeros_like(v) for p, v in values.items()} def update_gradient(param_name: str, spectral_gap: Tensor, shift_prefac: float): + """Update gradient of a parameter using PSR. + + Args: + param_name (str): Parameter name to compute gradient over. + spectral_gap (Tensor): Spectral gap of the corresponding operation. + shift_prefac (float): Shift prefactor value for PSR shifts. + """ if values[param_name].requires_grad: - if grads[param_name] is not None: - grads[param_name] += vjp( - param_name, spectral_gap, values, shift_prefac - ) - else: - grads[param_name] = vjp( - param_name, spectral_gap, values, shift_prefac - ) + grads[param_name] = vjp(param_name, spectral_gap, values, shift_prefac) for op in ctx.circuit.flatten(): @@ -292,10 +292,21 @@ def update_gradient(param_name: str, spectral_gap: Tensor, shift_prefac: float): op.param_name, str ): factor = 1.0 if isinstance(op, Parametric) else 2.0 - shift_prefac = 1.0 / factor if len(op.spectral_gap) > 1: - shift_prefac = 0.5 - update_gradient(op.param_name, factor * op.spectral_gap, shift_prefac) + update_gradient(op.param_name, factor * op.spectral_gap, 0.5) + else: + shift_factor = 1.0 + # note the spectral gap can be empty + # this is handled in single-gap PSR + if isinstance(op, HamiltonianEvolution): + shift_factor = ( + 1.0 / (op.spectral_gap.item() * factor) + if len(op.spectral_gap) == 1 + else 1.0 + ) + update_gradient( + op.param_name, factor * op.spectral_gap, shift_factor + ) return ( None,