Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add VJP/JVP capabilities to DefaultQubit2 #4374

Merged
merged 22 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 267 additions & 18 deletions pennylane/devices/experimental/default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

from functools import partial
from numbers import Number
from typing import Union, Callable, Tuple, Optional, Sequence
import concurrent.futures
import os
Expand All @@ -29,9 +30,9 @@

from . import Device
from .execution_config import ExecutionConfig, DefaultExecutionConfig
from ..qubit.simulate import simulate
from ..qubit.simulate import simulate, get_final_state, measure_final_state
from ..qubit.preprocess import preprocess, validate_and_expand_adjoint
from ..qubit.adjoint_jacobian import adjoint_jacobian
from ..qubit.adjoint_jacobian import adjoint_jacobian, adjoint_vjp, adjoint_jvp

Result_or_ResultBatch = Union[Result, ResultBatch]
QuantumTapeBatch = Sequence[QuantumTape]
Expand Down Expand Up @@ -169,9 +170,13 @@ def supports_derivatives(
):
return True

if execution_config.gradient_method == "adjoint":
if (
execution_config.gradient_method == "adjoint"
and execution_config.use_device_gradient in [None, True]
):
if circuit is None:
return True

return isinstance(validate_and_expand_adjoint(circuit), QuantumScript)

return False
Expand Down Expand Up @@ -264,24 +269,247 @@ def compute_derivatives(
self.tracker.update(derivative_batches=1, derivatives=len(circuits))
self.tracker.record()

if execution_config.gradient_method == "adjoint":
max_workers = self._get_max_workers(execution_config)
if max_workers is None:
res = tuple(adjoint_jacobian(circuit) for circuit in circuits)
else:
vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
exec_map = executor.map(adjoint_jacobian, vanilla_circuits)
res = tuple(circuit for circuit in exec_map)
max_workers = self._get_max_workers(execution_config)
if max_workers is None:
res = tuple(adjoint_jacobian(circuit) for circuit in circuits)
else:
vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
exec_map = executor.map(adjoint_jacobian, vanilla_circuits)
res = tuple(circuit for circuit in exec_map)

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))
# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

return res[0] if is_single_circuit else res
return res[0] if is_single_circuit else res

raise NotImplementedError(
f"{self.name} cannot compute derivatives via {execution_config.gradient_method}"
)
def execute_and_compute_derivatives(
self,
circuits: QuantumTape_or_Batch,
execution_config: ExecutionConfig = DefaultExecutionConfig,
):
is_single_circuit = False
if isinstance(circuits, QuantumScript):
is_single_circuit = True
circuits = [circuits]

if self.tracker.active:
for c in circuits:
self.tracker.update(resources=c.specs["resources"])
self.tracker.update(
execute_and_derivative_batches=1,
executions=len(circuits),
derivatives=len(circuits),
)
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
if max_workers is None:
results = tuple(
_adjoint_jac_wrapper(c, rng=self._rng, debugger=self._debugger) for c in circuits
)
results, jacs = tuple(zip(*results))
else:
self._validate_multiprocessing_circuits(circuits)

vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
results = tuple(executor.map(_adjoint_jac_wrapper, vanilla_circuits, seeds))

results, jacs = tuple(zip(*results))

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

return (results[0], jacs[0]) if is_single_circuit else (results, jacs)

def supports_jvp(
self,
execution_config: Optional[ExecutionConfig] = None,
circuit: Optional[QuantumTape] = None,
) -> bool:
"""Whether or not this device defines a custom jacobian vector product.

``DefaultQubit2`` supports backpropagation derivatives with analytic results, as well as
adjoint differentiation.

Args:
execution_config (ExecutionConfig): The configuration of the desired derivative calculation
circuit (QuantumTape): An optional circuit to check derivatives support for.

Returns:
bool: Whether or not a derivative can be calculated provided the given information
"""
return self.supports_derivatives(execution_config, circuit)

def compute_jvp(
self,
circuits: QuantumTape_or_Batch,
tangents: Tuple[Number],
execution_config: ExecutionConfig = DefaultExecutionConfig,
):
is_single_circuit = False
if isinstance(circuits, QuantumScript):
is_single_circuit = True
circuits = [circuits]
tangents = [tangents]

if self.tracker.active:
self.tracker.update(jvp_batches=1, jvps=len(circuits))
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
if max_workers is None:
res = tuple(adjoint_jvp(circuit, tans) for circuit, tans in zip(circuits, tangents))
else:
vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
res = tuple(executor.map(adjoint_jvp, vanilla_circuits, tangents))

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

return res[0] if is_single_circuit else res

def execute_and_compute_jvp(
self,
circuits: QuantumTape_or_Batch,
tangents: Tuple[Number],
execution_config: ExecutionConfig = DefaultExecutionConfig,
):
is_single_circuit = False
if isinstance(circuits, QuantumScript):
is_single_circuit = True
circuits = [circuits]
tangents = [tangents]

if self.tracker.active:
for c in circuits:
self.tracker.update(resources=c.specs["resources"])
self.tracker.update(
execute_and_jvp_batches=1, executions=len(circuits), jvps=len(circuits)
)
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
if max_workers is None:
results = tuple(
_adjoint_jvp_wrapper(c, t, rng=self._rng, debugger=self._debugger)
for c, t in zip(circuits, tangents)
)
results, jvps = tuple(zip(*results))
else:
self._validate_multiprocessing_circuits(circuits)

vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
results = tuple(
executor.map(_adjoint_jvp_wrapper, vanilla_circuits, tangents, seeds)
)

results, jvps = tuple(zip(*results))

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

return (results[0], jvps[0]) if is_single_circuit else (results, jvps)

def supports_vjp(
self,
execution_config: Optional[ExecutionConfig] = None,
circuit: Optional[QuantumTape] = None,
) -> bool:
"""Whether or not this device defines a custom vector jacobian product.

``DefaultQubit2`` supports backpropagation derivatives with analytic results, as well as
adjoint differentiation.

Args:
execution_config (ExecutionConfig): A description of the hyperparameters for the desired computation.
circuit (None, QuantumTape): A specific circuit to check differentation for.

Returns:
bool: Whether or not a derivative can be calculated provided the given information
"""
return self.supports_derivatives(execution_config, circuit)

def compute_vjp(
self,
circuits: QuantumTape_or_Batch,
cotangents: Tuple[Number],
execution_config: ExecutionConfig = DefaultExecutionConfig,
):
is_single_circuit = False
if isinstance(circuits, QuantumScript):
is_single_circuit = True
circuits = [circuits]
cotangents = [cotangents]

if self.tracker.active:
self.tracker.update(vjp_batches=1, vjps=len(circuits))
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
if max_workers is None:
res = tuple(adjoint_vjp(circuit, cots) for circuit, cots in zip(circuits, cotangents))
else:
vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
res = tuple(executor.map(adjoint_vjp, vanilla_circuits, cotangents))

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

return res[0] if is_single_circuit else res

def execute_and_compute_vjp(
self,
circuits: QuantumTape_or_Batch,
cotangents: Tuple[Number],
execution_config: ExecutionConfig = DefaultExecutionConfig,
):
is_single_circuit = False
if isinstance(circuits, QuantumScript):
is_single_circuit = True
circuits = [circuits]
cotangents = [cotangents]

if self.tracker.active:
for c in circuits:
self.tracker.update(resources=c.specs["resources"])
self.tracker.update(
execute_and_vjp_batches=1, executions=len(circuits), vjps=len(circuits)
)
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
if max_workers is None:
results = tuple(
_adjoint_vjp_wrapper(c, t, rng=self._rng, debugger=self._debugger)
for c, t in zip(circuits, cotangents)
)
results, vjps = tuple(zip(*results))
else:
self._validate_multiprocessing_circuits(circuits)

vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
results = tuple(
executor.map(_adjoint_vjp_wrapper, vanilla_circuits, cotangents, seeds)
)

results, vjps = tuple(zip(*results))

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

return (results[0], vjps[0]) if is_single_circuit else (results, vjps)

# pylint: disable=missing-function-docstring
def _get_max_workers(self, execution_config=None):
Expand Down Expand Up @@ -352,3 +580,24 @@ def _validate_multiprocessing_workers(max_workers):
environment variable `{varname}={num_threads_suggest}`.""",
UserWarning,
)


def _adjoint_jac_wrapper(c, rng=None, debugger=None):
state, is_state_batched = get_final_state(c, debugger=debugger)
jac = adjoint_jacobian(c, state=state)
res = measure_final_state(c, state, is_state_batched, rng=rng)
return res, jac


def _adjoint_jvp_wrapper(c, t, rng=None, debugger=None):
state, is_state_batched = get_final_state(c, debugger=debugger)
jvp = adjoint_jvp(c, t, state=state)
res = measure_final_state(c, state, is_state_batched, rng=rng)
return res, jvp


def _adjoint_vjp_wrapper(c, t, rng=None, debugger=None):
state, is_state_batched = get_final_state(c, debugger=debugger)
vjp = adjoint_vjp(c, t, state=state)
res = measure_final_state(c, state, is_state_batched, rng=rng)
return res, vjp
2 changes: 1 addition & 1 deletion pennylane/devices/qubit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@
from .measure import measure
from .preprocess import preprocess
from .sampling import sample_state, measure_with_samples
from .simulate import simulate
from .simulate import simulate, get_final_state, measure_final_state
Loading
Loading