Skip to content

Commit

Permalink
Improve mid-circuit measurement conversion abilities (#417)
Browse files Browse the repository at this point in the history
* add support for conditional ops

* happy `codefactor`

* add logic for `IfElseOp`

* add support of `SwitchCaseOp`

* happy `codefactor`

* fix `control_values`

* minor tweaks

* codefactor?

* codefactor?

* `changelog`

* minor tweaks

* address comments

* minor tweaks

* readying master merging

* apply suggestions

* happy `black`

* minor tweak

* happy `black`

* address comments

* fix tests?

* minor tweak
  • Loading branch information
obliviateandsurrender authored Feb 21, 2024
1 parent 307c6c4 commit dc8ec85
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 88 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@
[(#406)](https://github.com/PennyLaneAI/pennylane-qiskit/pull/406)
[(#428)](https://github.com/PennyLaneAI/pennylane-qiskit/pull/428)

* Measurement operations are now added to the PennyLane template when a `QuantumCircuit`
* Measurement operations are now added to the PennyLane template when a ``QuantumCircuit``
is converted using `load`. Additionally, one can override any existing terminal
measurements by providing a list of PennyLane
`measurements <https://docs.pennylane.ai/en/stable/introduction/measurements.html>`_ themselves.
[(#405)](https://github.com/PennyLaneAI/pennylane-qiskit/pull/405)

* Added support for coverting conditional operations based on mid-circuit measurements and
two of the ``ControlFlowOp`` operations - ``IfElseOp`` and ``SwitchCaseOp`` when converting
a ``QuantumCircuit`` using `load`.
[(#417)](https://github.com/PennyLaneAI/pennylane-qiskit/pull/417)

### Breaking changes 💔

### Deprecations 👋
Expand Down
279 changes: 200 additions & 79 deletions pennylane_qiskit/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
"""
from typing import Dict, Any
import warnings
from functools import partial, reduce

import numpy as np
from qiskit import QuantumCircuit
from qiskit.circuit import Parameter, ParameterExpression, ParameterVector, Measure, Barrier
from qiskit.circuit import Parameter, ParameterExpression, ParameterVector
from qiskit.circuit import Measure, Barrier, ControlFlowOp
from qiskit.circuit.controlflow.switch_case import _DefaultCaseType
from qiskit.circuit.library import GlobalPhaseGate
from qiskit.exceptions import QiskitError
from sympy import lambdify
Expand Down Expand Up @@ -54,6 +57,31 @@ def _check_parameter_bound(param: Parameter, unbound_params: Dict[Parameter, Any
raise ValueError(f"The parameter {param} was not bound correctly.".format(param))


def _process_basic_param_args(params, *args, **kwargs):
"""Process the basic conditions for parameter dictionary computation.
Returns:
params (dict): A dictionary mapping ``quantum_circuit.parameters`` to values
flag (bool): Indicating whether the returned ``params`` can be used.
"""

# if no kwargs are passed, and a dictionary has been passed as a single argument, then assume it is params
if params is None and not kwargs and (len(args) == 1 and isinstance(args[0], dict)):
return (args[0], True)

if not args and not kwargs:
return (params, True)

# make params dict if using args and/or kwargs
if params is not None:
raise RuntimeError(
"Cannot define parameters via the params kwarg when passing Parameter values "
"as individual args or kwargs."
)

return ({}, False)


def _expected_parameters(quantum_circuit):
"""Gets the expected parameters and a string of their names from the QuantumCircuit.
Primarily serves to change a list of Parameters and ParameterVectorElements into a list
Expand Down Expand Up @@ -104,22 +132,12 @@ def _format_params_dict(quantum_circuit, params, *args, **kwargs):
params (dict): A dictionary mapping ``quantum_circuit.parameters`` to values
"""

# if no kwargs are passed, and a dictionary has been passed as a single argument, then assume it is params
if params is None and not kwargs and (len(args) == 1 and isinstance(args[0], dict)):
return args[0]
params, flag = _process_basic_param_args(params, *args, **kwargs)

if not args and not kwargs:
if flag:
return params

# make params dict if using args and/or kwargs
if params is not None:
raise RuntimeError(
"Cannot define parameters via the params kwarg when passing Parameter values "
"as individual args or kwargs."
)

expected_params, param_name_string = _expected_parameters(quantum_circuit)
params = {}

# populate it with any parameters defined as kwargs
for k, v in kwargs.items():
Expand Down Expand Up @@ -228,6 +246,40 @@ def _check_circuit_and_assign_parameters(
return quantum_circuit.assign_parameters(params)


def _get_operation_params(instruction, unbound_params) -> list:
"""Extract the bound parameters from the operation.
If the bound parameters are a Qiskit ParameterExpression, then replace it with
the corresponding PennyLane variable from the unbound_params dictionary.
Args:
instruction (qiskit.circuit.Instruction): a qiskit's quantum circuit instruction
unbound_params dict[qiskit.circuit.Parameter, Any]: a dictionary mapping
qiskit parameters to trainable parameter values
Returns:
list: bound parameters of the given instruction
"""
operation_params = []
for p in instruction.params:
_check_parameter_bound(p, unbound_params)

if isinstance(p, ParameterExpression):
if p.parameters: # non-empty set = has unbound parameters
ordered_params = tuple(p.parameters)
f = lambdify(ordered_params, getattr(p, "_symbol_expr"), modules=qml.numpy)
f_args = []
for i_ordered_params in ordered_params:
f_args.append(unbound_params.get(i_ordered_params))
operation_params.append(f(*f_args))
else: # needed for qiskit<0.43.1
operation_params.append(float(p)) # pragma: no cover
else:
operation_params.append(p)

return operation_params


def map_wires(qc_wires: list, wires: list) -> dict:
"""Utility function mapping the wires specified in a quantum circuit with the wires
specified by the user for the template.
Expand All @@ -251,24 +303,7 @@ def map_wires(qc_wires: list, wires: list) -> dict:
)


def execute_supported_operation(operation_name: str, parameters: list, wires: list):
"""Utility function that executes an operation that is natively supported by PennyLane.
Args:
operation_name (str): Name of the PL operator to be executed
parameters (str): parameters of the operation that will be executed
wires (list): wires of the operation
"""
operation = getattr(pennylane_ops, operation_name)

if not parameters:
operation(wires=wires)
elif operation_name in ["QubitStateVector", "StatePrep"]:
operation(np.array(parameters), wires=wires)
else:
operation(*parameters, wires=wires)


# pylint:disable=too-many-statements, too-many-branches
def load(quantum_circuit: QuantumCircuit, measurements=None):
"""Loads a PennyLane template from a Qiskit QuantumCircuit.
Warnings are created for each of the QuantumCircuit instructions that were
Expand All @@ -284,7 +319,7 @@ def load(quantum_circuit: QuantumCircuit, measurements=None):
function: the resulting PennyLane template
"""

# pylint:disable=too-many-branches
# pylint:disable=too-many-branches, fixme, protected-access
def _function(*args, params: dict = None, wires: list = None, **kwargs):
"""Returns a PennyLane quantum function created based on the input QuantumCircuit.
Warnings are created for each of the QuantumCircuit instructions that were
Expand Down Expand Up @@ -355,7 +390,8 @@ def _function(*args, params: dict = None, wires: list = None, **kwargs):
"""

# organize parameters, format trainable parameter values correctly, and then bind the parameters to the circuit
# organize parameters, format trainable parameter values correctly,
# and then bind the parameters to the circuit
params = _format_params_dict(quantum_circuit, params, *args, **kwargs)
unbound_params = _extract_variable_refs(params)
qc = _check_circuit_and_assign_parameters(quantum_circuit, params, unbound_params)
Expand All @@ -366,56 +402,46 @@ def _function(*args, params: dict = None, wires: list = None, **kwargs):
wire_map = map_wires(qc_wires, wires)

# Stores the measurements encountered in the circuit
mid_circ_meas, terminal_meas = [], []
# terminal_meas / mid_circ_meas -> terminal / mid-circuit measurements
# mid_circ_regs -> maps the classical registers to the measurements done
terminal_meas, mid_circ_meas = [], []
mid_circ_regs = {}

# Processing the dictionary of parameters passed
for idx, (op, qargs, _) in enumerate(qc.data):
# the new Singleton classes have different names than the objects they represent, but base_class.__name__ still matches
instruction_name = getattr(op, "base_class", op.__class__).__name__

operation_wires = [wire_map[hash(qubit)] for qubit in qargs]

for idx, circuit_instruction in enumerate(qc.data):
(instruction, qargs, cargs) = circuit_instruction
# the new Singleton classes have different names than the objects they represent,
# but base_class.__name__ still matches
instruction_name = getattr(instruction, "base_class", instruction.__class__).__name__
# New Qiskit gates that are not natively supported by PL (identical
# gates exist with a different name)
# TODO: remove the following when gates have been renamed in PennyLane
instruction_name = "U3Gate" if instruction_name == "UGate" else instruction_name

# pylint:disable=protected-access
if (
instruction_name in inv_map
and inv_map[instruction_name] in pennylane_ops._qubit__ops__
):
# Extract the bound parameters from the operation. If the bound parameters are a
# Qiskit ParameterExpression, then replace it with the corresponding PennyLane
# variable from the unbound_params dictionary.

pl_parameters = []
for p in op.params:
_check_parameter_bound(p, unbound_params)

if isinstance(p, ParameterExpression):
if p.parameters: # non-empty set = has unbound parameters
ordered_params = tuple(p.parameters)

f = lambdify(ordered_params, p._symbol_expr, modules=qml.numpy)
f_args = []
for i_ordered_params in ordered_params:
f_args.append(unbound_params.get(i_ordered_params))
pl_parameters.append(f(*f_args))
else: # needed for qiskit<0.43.1
pl_parameters.append(float(p)) # pragma: no cover
else:
pl_parameters.append(p)

execute_supported_operation(
inv_map[instruction_name], pl_parameters, operation_wires
)
# Define operator builders and helpers
# operation_class -> PennyLane operation class object mapped from the Qiskit operation
# operation_args and operation_kwargs -> Parameters required for the
# instantiation of `operation_class`
operation_class = None
operation_wires = [wire_map[hash(qubit)] for qubit in qargs]
operation_kwargs = {"wires": operation_wires}
operation_args = []

# Extract the bound parameters from the operation. If the bound parameters are a
# Qiskit ParameterExpression, then replace it with the corresponding PennyLane
# variable from the unbound_params dictionary.
operation_params = _get_operation_params(instruction, unbound_params)

if instruction_name in dagger_map:
operation_class = qml.adjoint(dagger_map[instruction_name])

elif instruction_name in dagger_map:
gate = dagger_map[instruction_name]
qml.adjoint(gate)(wires=operation_wires)
elif instruction_name in inv_map:
operation_class = getattr(pennylane_ops, inv_map[instruction_name])
operation_args.extend(operation_params)
if operation_class in (qml.QubitStateVector, qml.StatePrep):
operation_args = [np.array(operation_params)]

elif isinstance(op, Measure):
elif isinstance(instruction, Measure):
# Store the current operation wires
op_wires = set(operation_wires)
# Look-ahead for more gate(s) on its wire(s)
Expand All @@ -429,22 +455,82 @@ def _function(*args, params: dict = None, wires: list = None, **kwargs):
meas_terminal = False
break

# Allows for adding terminal measurements
if meas_terminal:
terminal_meas.extend(operation_wires)

# Allows for queing the mid-circuit measurements
if not meas_terminal:
mid_circ_meas.append(qml.measure(wires=operation_wires))
else:
terminal_meas.extend(operation_wires)
operation_class = qml.measure
mid_circ_meas.append(qml.measure(wires=operation_wires))

# Allows for tracking conditional operations
for carg in cargs:
mid_circ_regs[carg] = mid_circ_meas[-1]

else:

try:
operation_matrix = op.to_matrix()
pennylane_ops.QubitUnitary(operation_matrix, wires=operation_wires)
if not isinstance(instruction, (ControlFlowOp,)):
operation_args = [instruction.to_matrix()]
operation_class = qml.QubitUnitary

except (AttributeError, QiskitError):
warnings.warn(
f"{__name__}: The {instruction_name} instruction is not supported by PennyLane,"
" and has not been added to the template.",
UserWarning,
)

# Check if it is a conditional operation or conditional instruction
instruction_cond = instruction.condition and instruction.condition[0] in mid_circ_regs
if instruction_cond or isinstance(instruction, ControlFlowOp):
# Iteratively recurse over to build different branches
with qml.QueuingManager.stop_recording():
branch_funcs = [
partial(load(branch_inst, measurements=None), params=params, wires=wires)
for branch_inst in operation_params
if isinstance(branch_inst, QuantumCircuit)
]

# Get the functions for handling condition
true_fn, false_fn, elif_fns, cond_op = _conditional_funcs(
instruction, cargs, operation_class, branch_funcs, instruction_name
)
res_reg, res_bit = cond_op

# Check for elif branches (doesn't require qjit)
if elif_fns:
m_val = sum(2**idx * mid_circ_regs[clbit] for idx, clbit in enumerate(res_reg))
for elif_bit, elif_branch in elif_fns:
qml.cond(m_val == elif_bit, elif_branch)(
*operation_args, **operation_kwargs
)

# Check if just conditional requires some extra work
if isinstance(res_bit, str):
# Handles the default case in the SwitchCaseOp
if res_bit == "SwitchDefault":
elif_bits = [elif_bit for (elif_bit, _) in elif_fns]
qml.cond(
reduce(
lambda m0, m1: m0 & m1,
[(m_val != elif_bit) for elif_bit in elif_bits],
),
true_fn,
)(*operation_args, **operation_kwargs)
# Just do the routine conditional
else:
qml.cond(
mid_circ_regs[res_reg] == res_bit,
true_fn,
false_fn,
)(*operation_args, **operation_kwargs)

# Check if it is not a mid-circuit measurement
elif operation_class and not isinstance(instruction, Measure):
operation_class(*operation_args, **operation_kwargs)

# Use the user-provided measurements
if measurements:
if qml.queuing.QueuingManager.active_context():
Expand Down Expand Up @@ -474,3 +560,38 @@ def load_qasm_from_file(file: str):
function: the new PennyLane template
"""
return load(QuantumCircuit.from_qasm_file(file))


# pylint:disable=fixme, protected-access
def _conditional_funcs(ops, cargs, operation_class, branch_funcs, ctrl_flow_type):
"""Builds the conditional functions for Controlled flows
This method returns the arguments to be used by the `qml.cond`
for creating a classically controlled flow.
These are the branches (`true_fn`, `false_fn`, `elif_fns`) and
the qiskit's classical condition, which has to be converted to
the corresponding PennyLane mid-circuit measurement.
"""
true_fn, false_fn, elif_fns = operation_class, None, ()
# Logic for using legacy c_if
if not isinstance(ops, ControlFlowOp):
return true_fn, false_fn, elif_fns, ops.condition

# Logic for handling IfElseOp
if ctrl_flow_type == "IfElseOp":
true_fn = branch_funcs[0]
if len(branch_funcs) == 2:
false_fn = branch_funcs[1]

# Logic for handling SwitchCaseOp
elif ctrl_flow_type == "SwitchCaseOp":
elif_fns = []
for case, res_bit in ops._case_map.items():
if not isinstance(case, _DefaultCaseType):
elif_fns.append((case, branch_funcs[res_bit]))
ops.condition = [tuple(cargs), "SwitchCase"]
if any((isinstance(case, _DefaultCaseType) for case in ops._case_map)):
true_fn = branch_funcs[-1]
ops.condition = [tuple(cargs), "SwitchDefault"]

return true_fn, false_fn, elif_fns, ops.condition
Loading

0 comments on commit dc8ec85

Please sign in to comment.