diff --git a/cirq-core/cirq/circuits/qasm_output.py b/cirq-core/cirq/circuits/qasm_output.py index 72d5235f759..89cc1d5bf7e 100644 --- a/cirq-core/cirq/circuits/qasm_output.py +++ b/cirq-core/cirq/circuits/qasm_output.py @@ -366,10 +366,10 @@ def on_stuck(bad_op): if should_annotate: output_line_gap(1) if isinstance(main_op, ops.GateOperation): - x = str(main_op.gate).replace('\n', '\n //') + x = str(main_op.gate).replace('\n', '\n// ') output(f'// Gate: {x!s}\n') else: - x = str(main_op).replace('\n', '\n //') + x = str(main_op).replace('\n', '\n// ') output(f'// Operation: {x!s}\n') for qasm in qasms: diff --git a/cirq-core/cirq/contrib/qasm_import/_lexer.py b/cirq-core/cirq/contrib/qasm_import/_lexer.py index c2f3df7ec53..86392aaa7f5 100644 --- a/cirq-core/cirq/contrib/qasm_import/_lexer.py +++ b/cirq-core/cirq/contrib/qasm_import/_lexer.py @@ -33,6 +33,7 @@ def __init__(self): 'creg': 'CREG', 'measure': 'MEASURE', 'reset': 'RESET', + 'gate': 'GATE', 'if': 'IF', '->': 'ARROW', '==': 'EQ', @@ -120,6 +121,10 @@ def t_RESET(self, t): r"""reset""" return t + def t_GATE(self, t): + r"""gate""" + return t + def t_IF(self, t): r"""if""" return t diff --git a/cirq-core/cirq/contrib/qasm_import/_lexer_test.py b/cirq-core/cirq/contrib/qasm_import/_lexer_test.py index ca5cc5803e1..a13aa65f235 100644 --- a/cirq-core/cirq/contrib/qasm_import/_lexer_test.py +++ b/cirq-core/cirq/contrib/qasm_import/_lexer_test.py @@ -159,6 +159,74 @@ def test_creg(): assert token.value == ";" +def test_custom_gate(): + lexer = QasmLexer() + lexer.input('gate name(param1,param2) q1, q2 {X(q1)}') + token = lexer.token() + assert token.type == "GATE" + assert token.value == "gate" + + token = lexer.token() + assert token.type == "ID" + assert token.value == "name" + + token = lexer.token() + assert token.type == "(" + assert token.value == "(" + + token = lexer.token() + assert token.type == "ID" + assert token.value == "param1" + + token = lexer.token() + assert token.type == "," + assert token.value == "," + + token = lexer.token() + assert token.type == "ID" + assert token.value == "param2" + + token = lexer.token() + assert token.type == ")" + assert token.value == ")" + + token = lexer.token() + assert token.type == "ID" + assert token.value == "q1" + + token = lexer.token() + assert token.type == "," + assert token.value == "," + + token = lexer.token() + assert token.type == "ID" + assert token.value == "q2" + + token = lexer.token() + assert token.type == "{" + assert token.value == "{" + + token = lexer.token() + assert token.type == "ID" + assert token.value == "X" + + token = lexer.token() + assert token.type == "(" + assert token.value == "(" + + token = lexer.token() + assert token.type == "ID" + assert token.value == "q1" + + token = lexer.token() + assert token.type == ")" + assert token.value == ")" + + token = lexer.token() + assert token.type == "}" + assert token.value == "}" + + def test_error(): lexer = QasmLexer() lexer.input('θ') diff --git a/cirq-core/cirq/contrib/qasm_import/_parser.py b/cirq-core/cirq/contrib/qasm_import/_parser.py index 3e13f011778..5bac1cdac15 100644 --- a/cirq-core/cirq/contrib/qasm_import/_parser.py +++ b/cirq-core/cirq/contrib/qasm_import/_parser.py @@ -12,15 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import functools import operator -from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Union, TYPE_CHECKING +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, + TYPE_CHECKING, +) import numpy as np import sympy from ply import yacc -from cirq import ops, Circuit, NamedQubit, CX +from cirq import ops, value, Circuit, CircuitOperation, CX, FrozenCircuit, NamedQubit from cirq.circuits.qasm_output import QasmUGate from cirq.contrib.qasm_import._lexer import QasmLexer from cirq.contrib.qasm_import.exception import QasmException @@ -47,6 +60,31 @@ def __init__( self.circuit = c +def _generate_op_qubits(args: List[List[ops.Qid]], lineno: int) -> List[List[ops.Qid]]: + """Generates the Cirq qubits for an operation from the OpenQASM qregs. + + OpenQASM gates can be applied on single qubits and qubit registers. + We represent single qubits as registers of size 1. + Based on the OpenQASM spec (https://arxiv.org/abs/1707.03429), + single qubit arguments can be mixed with qubit registers. + Given quantum registers of length reg_size and single qubits are both + used as arguments, we generate reg_size GateOperations via iterating + through each qubit of the registers 0 to n-1 and use the same one + qubit from the "single-qubit registers" for each operation.""" + reg_sizes = np.unique([len(reg) for reg in args]) + if len(reg_sizes) > 2 or (len(reg_sizes) > 1 and reg_sizes[0] != 1): + raise QasmException( + f"Non matching quantum registers of length {reg_sizes} at line {lineno}" + ) + op_qubits_gen = functools.reduce( + cast(Callable[[List['cirq.Qid'], List['cirq.Qid']], List['cirq.Qid']], np.broadcast), args + ) + op_qubits = [[q] if isinstance(q, ops.Qid) else q for q in op_qubits_gen] + if any(len(set(q)) < len(q) for q in op_qubits): + raise QasmException(f"Overlapping qubits in arguments at line {lineno}") + return op_qubits + + class QasmGateStatement: """Specifies how to convert a call to an OpenQASM gate to a list of `cirq.GateOperation`s. @@ -87,7 +125,7 @@ def _validate_args(self, args: List[List[ops.Qid]], lineno: int): f"got: {len(args)}, at line {lineno}" ) - def _validate_params(self, params: List[float], lineno: int): + def _validate_params(self, params: List[value.TParamVal], lineno: int): if len(params) != self.num_params: raise QasmException( f"{self.qasm_gate} takes {self.num_params} parameter(s), " @@ -95,41 +133,47 @@ def _validate_params(self, params: List[float], lineno: int): ) def on( - self, params: List[float], args: List[List[ops.Qid]], lineno: int + self, params: List[value.TParamVal], args: List[List[ops.Qid]], lineno: int ) -> Iterable[ops.Operation]: self._validate_args(args, lineno) self._validate_params(params, lineno) - reg_sizes = np.unique([len(reg) for reg in args]) - if len(reg_sizes) > 2 or (len(reg_sizes) > 1 and reg_sizes[0] != 1): - raise QasmException( - f"Non matching quantum registers of length {reg_sizes} at line {lineno}" - ) - # the actual gate we'll apply the arguments to might be a parameterized # or non-parameterized gate final_gate: ops.Gate = ( self.cirq_gate if isinstance(self.cirq_gate, ops.Gate) else self.cirq_gate(params) ) - # OpenQASM gates can be applied on single qubits and qubit registers. - # We represent single qubits as registers of size 1. - # Based on the OpenQASM spec (https://arxiv.org/abs/1707.03429), - # single qubit arguments can be mixed with qubit registers. - # Given quantum registers of length reg_size and single qubits are both - # used as arguments, we generate reg_size GateOperations via iterating - # through each qubit of the registers 0 to n-1 and use the same one - # qubit from the "single-qubit registers" for each operation. - op_qubits = functools.reduce( - cast(Callable[[List['cirq.Qid'], List['cirq.Qid']], List['cirq.Qid']], np.broadcast), - args, - ) - for qubits in op_qubits: - if isinstance(qubits, ops.Qid): - yield final_gate.on(qubits) - elif len(np.unique(qubits)) < len(qubits): - raise QasmException(f"Overlapping qubits in arguments at line {lineno}") - else: - yield final_gate.on(*qubits) + for qubits in _generate_op_qubits(args, lineno): + yield final_gate.on(*qubits) + + +@dataclasses.dataclass +class CustomGate: + """Represents an invocation of a user-defined gate. + + The custom gate definition is encoded here as a `FrozenCircuit`, and the + arguments (params and qubits) of the specific invocation of that gate are + stored here too. When `on` is called, we create a CircuitOperation, mapping + the qubits and params to the values provided.""" + + name: str + circuit: FrozenCircuit + params: Tuple[str, ...] + qubits: Tuple[ops.Qid, ...] + + def on( + self, params: List[value.TParamVal], args: List[List[ops.Qid]], lineno: int + ) -> Iterable[ops.Operation]: + if len(params) != len(self.params): + raise QasmException(f"Wrong number of params for '{self.name}' at line {lineno}") + if len(args) != len(self.qubits): + raise QasmException(f"Wrong number of qregs for '{self.name}' at line {lineno}") + for qubits in _generate_op_qubits(args, lineno): + yield CircuitOperation( + self.circuit, + param_resolver={k: v for k, v in zip(self.params, params)}, + qubit_map={k: v for k, v in zip(self.qubits, qubits)}, + ) class QasmParser: @@ -146,6 +190,18 @@ def __init__(self) -> None: self.circuit = Circuit() self.qregs: Dict[str, int] = {} self.cregs: Dict[str, int] = {} + self.gate_set: Dict[str, Union[CustomGate, QasmGateStatement]] = {**self.basic_gates} + """The gates available to use in the circuit, including those from libraries, and + user-defined ones.""" + self.in_custom_gate_scope = False + """This is set to True when the parser is in the middle of parsing a custom gate + definition.""" + self.custom_gate_scoped_params: Set[str] = set() + """The params declared within the current custom gate definition. Empty if not in + custom gate scope.""" + self.custom_gate_scoped_qubits: Dict[str, ops.Qid] = {} + """The qubits declared within the current custom gate definition. Empty if not in + custom gate scope.""" self.qelibinc = False self.lexer = QasmLexer() self.supported_format = False @@ -270,8 +326,6 @@ def __init__(self) -> None: 'tdg': QasmGateStatement(qasm_gate='tdg', num_params=0, num_args=1, cirq_gate=ops.T**-1), } - all_gates = {**basic_gates, **qelib_gates} - tokens = QasmLexer.tokens start = 'start' @@ -296,11 +350,13 @@ def p_qasm_no_format_specified_error(self, p): def p_qasm_include(self, p): """qasm : qasm QELIBINC""" self.qelibinc = True + self.gate_set |= self.qelib_gates p[0] = Qasm(self.supported_format, self.qelibinc, self.qregs, self.cregs, self.circuit) def p_qasm_include_stdgates(self, p): """qasm : qasm STDGATESINC""" self.qelibinc = True + self.gate_set |= self.qelib_gates p[0] = Qasm(self.supported_format, self.qelibinc, self.qregs, self.cregs, self.circuit) def p_qasm_circuit(self, p): @@ -338,6 +394,10 @@ def p_circuit_empty(self, p): """circuit : empty""" p[0] = self.circuit + def p_circuit_gate_def(self, p): + """circuit : gate_def""" + p[0] = self.circuit + # qreg and creg def p_new_reg(self, p): @@ -382,14 +442,13 @@ def p_gate_op_with_params(self, p): self._resolve_gate_operation(args=p[5], gate=p[1], p=p, params=p[3]) def _resolve_gate_operation( - self, args: List[List[ops.Qid]], gate: str, p: Any, params: List[float] + self, args: List[List[ops.Qid]], gate: str, p: Any, params: List[value.TParamVal] ): - gate_set = self.basic_gates if not self.qelibinc else self.all_gates - if gate not in gate_set.keys(): + if gate not in self.gate_set: tip = ", did you forget to include qelib1.inc?" if not self.qelibinc else "" msg = f'Unknown gate "{gate}" at line {p.lineno(1)}{tip}' raise QasmException(msg) - p[0] = gate_set[gate].on(args=args, params=params, lineno=p.lineno(1)) + p[0] = self.gate_set[gate].on(args=args, params=params, lineno=p.lineno(1)) # params : parameter ',' params # | parameter @@ -404,7 +463,8 @@ def p_params_single(self, p): p[0] = [p[1]] # expr : term - # | func '(' expression ')' """ + # | ID + # | func '(' expression ')' # | binary_op # | unary_op @@ -412,6 +472,14 @@ def p_expr_term(self, p): """expr : term""" p[0] = p[1] + def p_expr_identifier(self, p): + """expr : ID""" + if not self.in_custom_gate_scope: + raise QasmException(f"Parameter '{p[1]}' in line {p.lineno(1)} not supported") + if p[1] not in self.custom_gate_scoped_params: + raise QasmException(f"Undefined parameter '{p[1]}' in line {p.lineno(1)}'") + p[0] = sympy.Symbol(p[1]) + def p_expr_parens(self, p): """expr : '(' expr ')'""" p[0] = p[2] @@ -464,6 +532,15 @@ def p_args_single(self, p): def p_quantum_arg_register(self, p): """qarg : ID""" reg = p[1] + if self.in_custom_gate_scope: + if reg not in self.custom_gate_scoped_qubits: + if reg not in self.qregs: + msg = f"Undefined qubit '{reg}'" + else: + msg = f"'{reg}' is a register, not a qubit" + raise QasmException(f"{msg} at line {p.lineno(1)}") + p[0] = [self.custom_gate_scoped_qubits[reg]] + return if reg not in self.qregs.keys(): raise QasmException(f'Undefined quantum register "{reg}" at line {p.lineno(1)}') qubits = [] @@ -492,6 +569,8 @@ def p_quantum_arg_bit(self, p): """qarg : ID '[' NATURAL_NUMBER ']'""" reg = p[1] idx = p[3] + if self.in_custom_gate_scope: + raise QasmException(f"Unsupported indexed qreg '{reg}[{idx}]' at line {p.lineno(1)}") arg_name = self.make_name(idx, reg) if reg not in self.qregs.keys(): raise QasmException(f'Undefined quantum register "{reg}" at line {p.lineno(1)}') @@ -570,6 +649,60 @@ def p_if(self, p): ops.ClassicallyControlledOperation(conditions=conditions, sub_operation=tuple(p[7])[0]) ] + def p_gate_params_multiple(self, p): + """gate_params : ID ',' gate_params""" + self.p_gate_params_single(p) + p[0] += p[3] + + def p_gate_params_single(self, p): + """gate_params : ID""" + self.in_custom_gate_scope = True + self.custom_gate_scoped_params.add(p[1]) + p[0] = [p[1]] + + def p_gate_qubits_multiple(self, p): + """gate_qubits : ID ',' gate_qubits""" + self.p_gate_qubits_single(p) + p[0] += p[3] + + def p_gate_qubits_single(self, p): + """gate_qubits : ID""" + self.in_custom_gate_scope = True + q = NamedQubit(p[1]) + self.custom_gate_scoped_qubits[p[1]] = q + p[0] = [q] + + def p_gate_ops(self, p): + """gate_ops : gate_op gate_ops""" + p[0] = [p[1]] + p[2] + + def p_gate_ops_empty(self, p): + """gate_ops : empty""" + self.in_custom_gate_scope = True + p[0] = [] + + def p_gate_def_parameterized(self, p): + """gate_def : GATE ID '(' gate_params ')' gate_qubits '{' gate_ops '}'""" + self._gate_def(p, has_params=True) + + def p_gate_def(self, p): + """gate_def : GATE ID gate_qubits '{' gate_ops '}'""" + self._gate_def(p, has_params=False) + + def _gate_def(self, p: List[Any], *, has_params: bool): + name = p[2] + gate_params = tuple(p[4]) if has_params else () + offset = 3 if has_params else 0 + gate_qubits = tuple(p[3 + offset]) + gate_ops = p[5 + offset] + circuit = Circuit(gate_ops).freeze() + gate_def = CustomGate(name, circuit, gate_params, gate_qubits) + self.gate_set[name] = gate_def + self.custom_gate_scoped_params.clear() + self.custom_gate_scoped_qubits.clear() + self.in_custom_gate_scope = False + p[0] = gate_def + def p_error(self, p): if p is None: raise QasmException('Unexpected end of file') diff --git a/cirq-core/cirq/contrib/qasm_import/_parser_test.py b/cirq-core/cirq/contrib/qasm_import/_parser_test.py index 3325bcd225a..37d7ce57a88 100644 --- a/cirq-core/cirq/contrib/qasm_import/_parser_test.py +++ b/cirq-core/cirq/contrib/qasm_import/_parser_test.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re import textwrap +import warnings from typing import Callable import numpy as np import pytest import sympy - import cirq import cirq.testing as ct +from cirq.testing import consistent_qasm as cq from cirq import Circuit from cirq.circuits.qasm_output import QasmUGate from cirq.contrib.qasm_import import QasmException @@ -186,7 +188,7 @@ def test_CX_gate(): qreg q2[2]; CX q1[0], q1[1]; CX q1, q2[0]; - CX q2, q1; + CX q2, q1; """ parser = QasmParser() @@ -390,7 +392,7 @@ def test_U_angles(): def test_U_gate_zero_params_error(): qasm = """OPENQASM 2.0; - qreg q[2]; + qreg q[2]; U q[1];""" parser = QasmParser() @@ -401,7 +403,7 @@ def test_U_gate_zero_params_error(): def test_U_gate_too_much_params_error(): qasm = """OPENQASM 2.0; - qreg q[2]; + qreg q[2]; U(pi, pi, pi, pi) q[1];""" parser = QasmParser() @@ -520,9 +522,9 @@ def test_rotation_gates(qasm_gate: str, cirq_gate: Callable[[float], cirq.Gate]) def test_rotation_gates_wrong_number_of_args(qasm_gate: str): qasm = f""" OPENQASM 2.0; - include "qelib1.inc"; - qreg q[2]; - {qasm_gate}(pi) q[0], q[1]; + include "qelib1.inc"; + qreg q[2]; + {qasm_gate}(pi) q[0], q[1]; """ parser = QasmParser() @@ -534,9 +536,9 @@ def test_rotation_gates_wrong_number_of_args(qasm_gate: str): @pytest.mark.parametrize('qasm_gate', [g[0] for g in rotation_gates]) def test_rotation_gates_zero_params_error(qasm_gate: str): qasm = f"""OPENQASM 2.0; - include "qelib1.inc"; - qreg q[2]; - {qasm_gate} q[1]; + include "qelib1.inc"; + qreg q[2]; + {qasm_gate} q[1]; """ parser = QasmParser() @@ -584,7 +586,7 @@ def test_measure_individual_bits(): OPENQASM 2.0; include "qelib1.inc"; qreg q1[2]; - creg c1[2]; + creg c1[2]; measure q1[0] -> c1[0]; measure q1[1] -> c1[1]; """ @@ -612,8 +614,8 @@ def test_measure_registers(): qasm = """OPENQASM 2.0; include "qelib1.inc"; qreg q1[3]; - creg c1[3]; - measure q1 -> c1; + creg c1[3]; + measure q1 -> c1; """ parser = QasmParser() @@ -639,10 +641,10 @@ def test_measure_registers(): def test_measure_mismatched_register_size(): qasm = """OPENQASM 2.0; - include "qelib1.inc"; + include "qelib1.inc"; qreg q1[2]; - creg c1[3]; - measure q1 -> c1; + creg c1[3]; + measure q1 -> c1; """ parser = QasmParser() @@ -653,11 +655,11 @@ def test_measure_mismatched_register_size(): def test_measure_to_quantum_register(): qasm = """OPENQASM 2.0; - include "qelib1.inc"; + include "qelib1.inc"; qreg q1[3]; qreg q2[3]; - creg c1[3]; - measure q2 -> q1; + creg c1[3]; + measure q2 -> q1; """ parser = QasmParser() @@ -668,10 +670,10 @@ def test_measure_to_quantum_register(): def test_measure_undefined_classical_bit(): qasm = """OPENQASM 2.0; - include "qelib1.inc"; - qreg q1[3]; - creg c1[3]; - measure q1[1] -> c2[1]; + include "qelib1.inc"; + qreg q1[3]; + creg c1[3]; + measure q1[1] -> c2[1]; """ parser = QasmParser() @@ -682,11 +684,11 @@ def test_measure_undefined_classical_bit(): def test_measure_from_classical_register(): qasm = """OPENQASM 2.0; - include "qelib1.inc"; + include "qelib1.inc"; qreg q1[2]; - creg c1[3]; - creg c2[3]; - measure c1 -> c2; + creg c1[3]; + creg c2[3]; + measure c1 -> c2; """ parser = QasmParser() @@ -698,8 +700,8 @@ def test_measure_from_classical_register(): def test_measurement_bounds(): qasm = """OPENQASM 2.0; qreg q1[3]; - creg c1[3]; - measure q1[0] -> c1[4]; + creg c1[3]; + measure q1[0] -> c1[4]; """ parser = QasmParser() @@ -741,7 +743,7 @@ def test_u1_gate(): OPENQASM 2.0; include "qelib1.inc"; qreg q[1]; - u1(pi / 3.0) q[0]; + u1(pi / 3.0) q[0]; """ parser = QasmParser() @@ -764,7 +766,7 @@ def test_u2_gate(): OPENQASM 2.0; include "qelib1.inc"; qreg q[1]; - u2(2 * pi, pi / 3.0) q[0]; + u2(2 * pi, pi / 3.0) q[0]; """ parser = QasmParser() @@ -787,7 +789,7 @@ def test_id_gate(): OPENQASM 2.0; include "qelib1.inc"; qreg q[2]; - id q; + id q; """ parser = QasmParser() @@ -846,7 +848,7 @@ def test_r_gate(): OPENQASM 2.0; include "qelib1.inc"; qreg q[1]; - r(pi, pi / 2.0) q[0]; + r(pi, pi / 2.0) q[0]; """ parser = QasmParser() @@ -871,9 +873,9 @@ def test_r_gate(): def test_standard_single_qubit_gates_wrong_number_of_args(qasm_gate): qasm = f""" OPENQASM 2.0; - include "qelib1.inc"; - qreg q[2]; - {qasm_gate} q[0], q[1]; + include "qelib1.inc"; + qreg q[2]; + {qasm_gate} q[0], q[1]; """ parser = QasmParser() @@ -889,9 +891,9 @@ def test_standard_single_qubit_gates_wrong_number_of_args(qasm_gate): ) def test_standard_gates_wrong_params_error(qasm_gate: str, num_params: int): qasm = f"""OPENQASM 2.0; - include "qelib1.inc"; - qreg q[2]; - {qasm_gate}(pi, 2*pi, 3*pi, 4*pi, 5*pi) q[1]; + include "qelib1.inc"; + qreg q[2]; + {qasm_gate}(pi, 2*pi, 3*pi, 4*pi, 5*pi) q[1]; """ parser = QasmParser() @@ -903,9 +905,9 @@ def test_standard_gates_wrong_params_error(qasm_gate: str, num_params: int): return qasm = f"""OPENQASM 2.0; - include "qelib1.inc"; - qreg q[2]; - {qasm_gate} q[1]; + include "qelib1.inc"; + qreg q[2]; + {qasm_gate} q[1]; """ parser = QasmParser() @@ -1227,3 +1229,273 @@ def test_openqasm_3_0_scalar_qubit(): ct.assert_same_circuits(parsed_qasm.circuit, expected_circuit) assert parsed_qasm.qregs == {'q': 1} + + +def test_custom_gate(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[2]; + gate g q0, q1 { + x q0; + y q0; + z q1; + } + g q[0], q[1]; + g q[1], q[0]; + """ + + # The gate definition should translate to this + q0, q1 = cirq.NamedQubit.range(2, prefix='q') + g = cirq.FrozenCircuit(cirq.X(q0), cirq.Y(q0), cirq.Z(q1)) + + # The outer circuit should then translate to this + q_0, q_1 = cirq.NamedQubit.range(2, prefix='q_') # The outer qreg array + expected = cirq.Circuit( + cirq.CircuitOperation(g, qubit_map={q0: q_0, q1: q_1}), + cirq.CircuitOperation(g, qubit_map={q0: q_1, q1: q_0}), + ) + + # Verify + parser = QasmParser() + parsed_qasm = parser.parse(qasm) + assert parsed_qasm.circuit == expected + + # Sanity check that this unrolls to a valid circuit + unrolled_expected = cirq.Circuit( + cirq.X(q_0), cirq.Y(q_0), cirq.Z(q_1), cirq.X(q_1), cirq.Y(q_1), cirq.Z(q_0) + ) + unrolled = cirq.align_left(cirq.unroll_circuit_op(parsed_qasm.circuit, tags_to_check=None)) + assert unrolled == unrolled_expected + + # Sanity check that these have the same unitaries as the QASM. + cq.assert_qiskit_parsed_qasm_consistent_with_unitary(qasm, cirq.unitary(parsed_qasm.circuit)) + cq.assert_qiskit_parsed_qasm_consistent_with_unitary(qasm, cirq.unitary(unrolled)) + + +def test_custom_gate_parameterized(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[2]; + gate g(p0, p1) q0, q1 { + rx(p0) q0; + ry(p0+p1+3) q0; + rz(p1) q1; + } + g(1,2) q[0], q[1]; + g(0,4) q[1], q[0]; + """ + + # The gate definition should translate to this + p0, p1 = sympy.symbols('p0, p1') + q0, q1 = cirq.NamedQubit.range(2, prefix='q') + g = cirq.FrozenCircuit( + cirq.Rx(rads=p0).on(q0), cirq.Ry(rads=p0 + p1 + 3).on(q0), cirq.Rz(rads=p1).on(q1) + ) + + # The outer circuit should then translate to this + q_0, q_1 = cirq.NamedQubit.range(2, prefix='q_') # The outer qreg array + expected = cirq.Circuit( + cirq.CircuitOperation(g, qubit_map={q0: q_0, q1: q_1}, param_resolver={'p0': 1, 'p1': 2}), + cirq.CircuitOperation(g, qubit_map={q0: q_1, q1: q_0}, param_resolver={'p0': 0, 'p1': 4}), + ) + + # Verify + parser = QasmParser() + parsed_qasm = parser.parse(qasm) + assert parsed_qasm.circuit == expected + + # Sanity check that this unrolls to a valid circuit + unrolled_expected = cirq.Circuit( + cirq.Rx(rads=1).on(q_0), + cirq.Ry(rads=6).on(q_0), + cirq.Rz(rads=2).on(q_1), + cirq.Rx(rads=0).on(q_1), + cirq.Ry(rads=7).on(q_1), + cirq.Rz(rads=4).on(q_0), + ) + unrolled = cirq.align_left(cirq.unroll_circuit_op(parsed_qasm.circuit, tags_to_check=None)) + assert unrolled == unrolled_expected + + # Sanity check that these have the same unitaries as the QASM. + cq.assert_qiskit_parsed_qasm_consistent_with_unitary(qasm, cirq.unitary(parsed_qasm.circuit)) + cq.assert_qiskit_parsed_qasm_consistent_with_unitary(qasm, cirq.unitary(unrolled)) + + +def test_custom_gate_broadcast(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[3]; + gate g q0 { + x q0; + y q0; + z q0; + } + g q; // broadcast to all qubits in register + """ + + # The gate definition should translate to this + q0 = cirq.NamedQubit('q0') + g = cirq.FrozenCircuit(cirq.X(q0), cirq.Y(q0), cirq.Z(q0)) + + # The outer circuit should then translate to this + q_0, q_1, q_2 = cirq.NamedQubit.range(3, prefix='q_') # The outer qreg array + expected = cirq.Circuit( + # It is broadcast to all qubits in the qreg + cirq.CircuitOperation(g, qubit_map={q0: q_0}), + cirq.CircuitOperation(g, qubit_map={q0: q_1}), + cirq.CircuitOperation(g, qubit_map={q0: q_2}), + ) + + # Verify + parser = QasmParser() + parsed_qasm = parser.parse(qasm) + assert parsed_qasm.circuit == expected + + # Sanity check that this unrolls to a valid circuit + unrolled_expected = cirq.Circuit( + cirq.X(q_0), + cirq.Y(q_0), + cirq.Z(q_0), + cirq.X(q_1), + cirq.Y(q_1), + cirq.Z(q_1), + cirq.X(q_2), + cirq.Y(q_2), + cirq.Z(q_2), + ) + unrolled = cirq.align_left(cirq.unroll_circuit_op(parsed_qasm.circuit, tags_to_check=None)) + assert unrolled == unrolled_expected + + # Sanity check that these have the same unitaries as the QASM. + cq.assert_qiskit_parsed_qasm_consistent_with_unitary(qasm, cirq.unitary(parsed_qasm.circuit)) + cq.assert_qiskit_parsed_qasm_consistent_with_unitary(qasm, cirq.unitary(unrolled)) + + +def test_custom_gate_undefined_qubit_error(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[1]; + gate g q0 { x q1; } + g q + """ + _test_parse_exception( + qasm, + cirq_err="Undefined qubit 'q1' at line 4", + qiskit_err="4,19: 'q1' is not defined in this scope", + ) + + +def test_custom_gate_qubit_scope_closure_error(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[1]; + gate g q0 { x q; } + g q + """ + _test_parse_exception( + qasm, + cirq_err="'q' is a register, not a qubit at line 4", + qiskit_err="4,19: 'q' is a quantum register, not a qubit", + ) + + +def test_custom_gate_qubit_index_error(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[1]; + gate g q0 { x q0[0]; } + g q + """ + _test_parse_exception( + qasm, + cirq_err="Unsupported indexed qreg 'q0[0]' at line 4", + qiskit_err="4,21: needed ';', but instead saw [", + ) + + +def test_custom_gate_qreg_count_error(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[2]; + gate g q0 { x q0; } + g q[0], q[1]; + """ + _test_parse_exception( + qasm, + cirq_err="Wrong number of qregs for 'g' at line 5", + qiskit_err="5,5: 'g' takes 1 quantum argument, but got 2", + ) + + +def test_custom_gate_missing_param_error(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[1]; + gate g(p) q0 { rx(p) q0; } + g q; + """ + _test_parse_exception( + qasm, + cirq_err="Wrong number of params for 'g' at line 5", + qiskit_err=None, # Qiskit bug? It's an invalid circuit that won't simulate. + ) + + +def test_custom_gate_extra_param_error(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[1]; + gate g q0 { x q0; } + g(3) q; + """ + _test_parse_exception( + qasm, + cirq_err="Wrong number of params for 'g' at line 5", + qiskit_err="5,5: 'g' takes 0 parameters, but got 1", + ) + + +def test_custom_gate_undefined_param_error(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[1]; + gate g q0 { rx(p) q0; } + g q; + """ + _test_parse_exception( + qasm, + cirq_err="Undefined parameter 'p' in line 4", + qiskit_err="4,20: 'p' is not a parameter", + ) + + +def test_top_level_param_error(): + qasm = """OPENQASM 2.0; + include "qelib1.inc"; + qreg q[1]; + rx(p) q; + """ + _test_parse_exception( + qasm, + cirq_err="Parameter 'p' in line 4 not supported", + qiskit_err="4,8: 'p' is not a parameter", + ) + + +def _test_parse_exception(qasm: str, cirq_err: str, qiskit_err: str | None): + parser = QasmParser() + with pytest.raises(QasmException, match=re.escape(cirq_err)): + parser.parse(qasm) + try: + import qiskit + + if qiskit_err is None: + qiskit.QuantumCircuit.from_qasm_str(qasm) + return + with pytest.raises(qiskit.qasm2.exceptions.QASM2ParseError, match=re.escape(qiskit_err)): + qiskit.QuantumCircuit.from_qasm_str(qasm) + except ImportError: # pragma: no cover + warnings.warn( + "Skipped _test_qiskit_parse_exception because " + "qiskit isn't installed to verify against." + )