Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make maximum classical register width a parameter of QASM converters #1083

Merged
merged 8 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pytket/docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Changelog
=========

Unreleased
----------

Minor new features:

* Add optional parameter to QASM conversion methods to set the maximum allowed
width of classical registers (default 32).

1.21.0 (October 2023)
---------------------

Expand Down
170 changes: 133 additions & 37 deletions pytket/pytket/qasm/qasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,14 +290,15 @@ def __iter__(self) -> Iterable[str]:


class CircuitTransformer(Transformer):
def __init__(self, return_gate_dict: bool = False) -> None:
def __init__(self, return_gate_dict: bool = False, maxwidth: int = 32) -> None:
super().__init__()
self.q_registers: Dict[str, int] = {}
self.c_registers: Dict[str, int] = {}
self.gate_dict: Dict[str, Dict] = {}
self.wasm: Optional[WasmFileHandler] = None
self.include = ""
self.return_gate_dict = return_gate_dict
self.maxwidth = maxwidth

def _fresh_temp_bit(self) -> List:
if _TEMP_BIT_NAME in self.c_registers:
Expand Down Expand Up @@ -352,6 +353,12 @@ def args(self, tree: Iterable[Token]) -> Iterator[List]:

def creg(self, tree: List[Token]) -> None:
name, size = _extract_reg(tree[0])
if size > self.maxwidth:
raise QASMUnsupportedError(
f"Circuit contains classical register {name} of size {size} > "
f"{self.maxwidth}: try setting the `maxwidth` parameter to a larger "
"value."
)
self.c_registers[Reg(name)] = size

def qreg(self, tree: List[Token]) -> None:
Expand Down Expand Up @@ -612,7 +619,7 @@ def ifc(self, tree: Sequence) -> Iterable[CommandDict]:
else:
pred_val = cast(int, val)
minval = 0
maxval = (1 << 32) - 1
maxval = (1 << self.maxwidth) - 1
if condition.op == RegWiseOp.LT:
maxval = pred_val - 1
elif condition.op == RegWiseOp.GT:
Expand Down Expand Up @@ -834,7 +841,7 @@ def gdef(self, tree: List) -> None:
symbol_map = {sym: sym * pi for sym in map(Symbol, symbols)}
rename_map = {Qubit.from_list(qb): Qubit("q", i) for i, qb in enumerate(args)}

new = CircuitTransformer()
new = CircuitTransformer(maxwidth=self.maxwidth)
circ_dict = new.prog(child_iter)

circ_dict["qubits"] = args
Expand All @@ -850,7 +857,9 @@ def gdef(self, tree: List) -> None:
comparison_circ = _get_gate_circuit(
NOPARAM_EXTRA_COMMANDS[gate], qubit_args
)
if circuit_to_qasm_str(comparison_circ) == circuit_to_qasm_str(gate_circ):
if circuit_to_qasm_str(
comparison_circ, maxwidth=self.maxwidth
) == circuit_to_qasm_str(gate_circ, maxwidth=self.maxwidth):
existing_op = True
elif gate in PARAM_EXTRA_COMMANDS:
qubit_args = [
Expand Down Expand Up @@ -911,61 +920,110 @@ def prog(self, tree: Iterable) -> Dict[str, Any]:
return outdict


parser = Lark(
grammar,
start="prog",
debug=False,
parser="lalr",
cache=True,
transformer=CircuitTransformer(),
)
def parser(maxwidth: int) -> Lark:
return Lark(
grammar,
start="prog",
debug=False,
parser="lalr",
cache=True,
transformer=CircuitTransformer(maxwidth=maxwidth),
)


g_parser = None
g_maxwidth = 32


def set_parser(maxwidth: int) -> None:
global g_parser, g_maxwidth
if (g_parser is None) or (g_maxwidth != maxwidth): # type: ignore
g_parser = parser(maxwidth=maxwidth)
g_maxwidth = maxwidth


def circuit_from_qasm(
input_file: Union[str, "os.PathLike[Any]"], encoding: str = "utf-8"
input_file: Union[str, "os.PathLike[Any]"],
encoding: str = "utf-8",
maxwidth: int = 32,
) -> Circuit:
"""A method to generate a tket Circuit from a qasm file"""
"""A method to generate a tket Circuit from a qasm file.

:param input_file: path to qasm file; filename must have ``.qasm`` extension
:param encoding: file encoding (default utf-8)
:param maxwidth: maximum allowed width of classical registers (default 32)
:return: pytket circuit
"""
ext = os.path.splitext(input_file)[-1]
if ext != ".qasm":
raise TypeError("Can only convert .qasm files")
with open(input_file, "r", encoding=encoding) as f:
try:
circ = circuit_from_qasm_io(f)
circ = circuit_from_qasm_io(f, maxwidth=maxwidth)
except QASMParseError as e:
raise QASMParseError(e.msg, e.line, str(input_file))
return circ


def circuit_from_qasm_str(qasm_str: str) -> Circuit:
"""A method to generate a tket Circuit from a qasm str"""
cast(CircuitTransformer, parser.options.transformer)._reset_context(
def circuit_from_qasm_str(qasm_str: str, maxwidth: int = 32) -> Circuit:
"""A method to generate a tket Circuit from a qasm string.

:param qasm_str: qasm string
:param maxwidth: maximum allowed width of classical registers (default 32)
:return: pytket circuit
"""
global g_parser
set_parser(maxwidth=maxwidth)
assert g_parser is not None
cast(CircuitTransformer, g_parser.options.transformer)._reset_context(
reset_wasm=False
)
return Circuit.from_dict(parser.parse(qasm_str)) # type: ignore[arg-type]
return Circuit.from_dict(g_parser.parse(qasm_str)) # type: ignore[arg-type]


def circuit_from_qasm_io(stream_in: TextIO) -> Circuit:
def circuit_from_qasm_io(stream_in: TextIO, maxwidth: int = 32) -> Circuit:
"""A method to generate a tket Circuit from a qasm text stream"""
return circuit_from_qasm_str(stream_in.read())
return circuit_from_qasm_str(stream_in.read(), maxwidth=maxwidth)


def circuit_from_qasm_wasm(
input_file: Union[str, "os.PathLike[Any]"],
wasm_file: Union[str, "os.PathLike[Any]"],
encoding: str = "utf-8",
maxwidth: int = 32,
) -> Circuit:
"""A method to generate a tket Circuit from a qasm str and external WASM module."""
"""A method to generate a tket Circuit from a qasm string and external WASM module.

:param input_file: path to qasm file; filename must have ``.qasm`` extension
:param wasm_file: path to WASM file containing functions used in qasm
:param encoding: encoding of qasm file (default utf-8)
:param maxwidth: maximum allowed width of classical registers (default 32)
:return: pytket circuit
"""
global g_parser
wasm_module = WasmFileHandler(str(wasm_file))
cast(CircuitTransformer, parser.options.transformer).wasm = wasm_module
return circuit_from_qasm(input_file, encoding=encoding)
set_parser(maxwidth=maxwidth)
assert g_parser is not None
cast(CircuitTransformer, g_parser.options.transformer).wasm = wasm_module
return circuit_from_qasm(input_file, encoding=encoding, maxwidth=maxwidth)


def circuit_to_qasm(circ: Circuit, output_file: str, header: str = "qelib1") -> None:
def circuit_to_qasm(
circ: Circuit, output_file: str, header: str = "qelib1", maxwidth: int = 32
) -> None:
"""Convert a Circuit to QASM and write it to a file.

Note that this will not account for implicit qubit permutations in the Circuit."""
Classical bits in the pytket circuit must be singly-indexed.

Note that this will not account for implicit qubit permutations in the Circuit.

:param circ: pytket circuit
:param output_file: path to output qasm file
:param header: qasm header (default "qelib1")
:param maxwidth: maximum allowed width of classical registers (default 32)
"""
with open(output_file, "w") as out:
circuit_to_qasm_io(circ, out, header=header)
circuit_to_qasm_io(circ, out, header=header, maxwidth=maxwidth)


def _filtered_qasm_str(qasm: str) -> str:
Expand Down Expand Up @@ -993,11 +1051,24 @@ def _filtered_qasm_str(qasm: str) -> str:


def circuit_to_qasm_str(
circ: Circuit, header: str = "qelib1", include_gate_defs: Optional[Set[str]] = None
circ: Circuit,
header: str = "qelib1",
include_gate_defs: Optional[Set[str]] = None,
maxwidth: int = 32,
) -> str:
"""Convert a Circuit to QASM and return the string.

Note that this will not account for implicit qubit permutations in the Circuit."""
Classical bits in the pytket circuit must be singly-indexed.

Note that this will not account for implicit qubit permutations in the Circuit.

:param circ: pytket circuit
:param header: qasm header (default "qelib1")
:param output_file: path to output qasm file
:param include_gate_defs: optional set of gates to include
:param maxwidth: maximum allowed width of classical registers (default 32)
:return: qasm string
"""
if any(
circ.n_gates_of_type(typ)
for typ in (
Expand All @@ -1014,7 +1085,14 @@ def circuit_to_qasm_str(
"Complex classical gates not supported with qelib1: try converting with "
"`header=hqslib1`"
)
qasm_writer = QasmWriter(circ.qubits, circ.bits, header, include_gate_defs)
if any(bit.index[0] >= maxwidth for bit in circ.bits):
raise QASMUnsupportedError(
f"Circuit contains a classical register larger than {maxwidth}: try "
"setting the `maxwidth` parameter to a higher value."
)
qasm_writer = QasmWriter(
circ.qubits, circ.bits, header, include_gate_defs, maxwidth
)
for command in circ:
assert isinstance(command, Command)
qasm_writer.add_op(command.op, command.args)
Expand All @@ -1036,8 +1114,8 @@ def _retrieve_registers(
}


def _parse_range(minval: int, maxval: int) -> Tuple[str, int]:
REGMAX = (1 << 32) - 1
def _parse_range(minval: int, maxval: int, maxwidth: int) -> Tuple[str, int]:
REGMAX = (1 << maxwidth) - 1
if minval == maxval:
return ("==", minval)
elif minval == 0:
Expand Down Expand Up @@ -1140,8 +1218,10 @@ def __init__(
bits: List[Bit],
header: str = "qelib1",
include_gate_defs: Optional[Set[str]] = None,
maxwidth: int = 32,
):
self.header = header
self.maxwidth = maxwidth
self.added_gate_definitions: Set[str] = set()
self.include_module_gates = {"measure", "reset", "barrier"}
self.include_module_gates.update(
Expand Down Expand Up @@ -1243,7 +1323,9 @@ def write_gate_definition(
gate_circ = _get_gate_circuit(optype, qubit_args, symbols)
# write circuit to qasm
self.strings.add_string(
circuit_to_qasm_str(gate_circ, self.header, self.include_gate_defs)
circuit_to_qasm_str(
gate_circ, self.header, self.include_gate_defs, self.maxwidth
)
)
self.strings.add_string("}\n")

Expand All @@ -1261,7 +1343,7 @@ def mark_as_written(self, written_variable: str) -> None:
self.range_preds.remove(hit)

def add_range_predicate(self, op: RangePredicateOp, args: List[Bit]) -> None:
comparator, value = _parse_range(op.lower, op.upper)
comparator, value = _parse_range(op.lower, op.upper, self.maxwidth)
if (not hqs_header(self.header)) and comparator != "==":
raise QASMUnsupportedError(
"OpenQASM conditions must be on a register's fixed value."
Expand Down Expand Up @@ -1440,7 +1522,9 @@ def add_custom_gate(self, op: CustomGate, args: List[UnitID]) -> None:
gate_circ.rename_units(dict(zip(gate_circ.qubits, args)))
gate_circ.symbol_substitution(dict(zip(op.gate.args, op.params)))
self.strings.add_string(
circuit_to_qasm_str(gate_circ, self.header, self.include_gate_defs)
circuit_to_qasm_str(
gate_circ, self.header, self.include_gate_defs, self.maxwidth
)
)
else:
opstr = op.gate.name
Expand Down Expand Up @@ -1575,10 +1659,22 @@ def circuit_to_qasm_io(
stream_out: TextIO,
header: str = "qelib1",
include_gate_defs: Optional[Set[str]] = None,
maxwidth: int = 32,
) -> None:
"""Convert a Circuit to QASM and write to a text stream.

Note that this will not account for implicit qubit permutations in the Circuit."""
Classical bits in the pytket circuit must be singly-indexed.

Note that this will not account for implicit qubit permutations in the Circuit.

:param circ: pytket circuit
:param stream_out: text stream to be written to
:param header: qasm header (default "qelib1")
:param include_gate_defs: optional set of gates to include
:param maxwidth: maximum allowed width of classical registers (default 32)
"""
stream_out.write(
circuit_to_qasm_str(circ, header=header, include_gate_defs=include_gate_defs)
circuit_to_qasm_str(
circ, header=header, include_gate_defs=include_gate_defs, maxwidth=maxwidth
)
)
23 changes: 21 additions & 2 deletions pytket/tests/qasm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ def test_scratch_bits_filtering() -> None:
creg a[1];
creg b[1];
creg d[2];
creg {_TEMP_BIT_NAME}[100];
creg {_TEMP_BIT_NAME}_1[100];
creg {_TEMP_BIT_NAME}[32];
creg {_TEMP_BIT_NAME}_1[32];
{_TEMP_BIT_NAME}[0] = (a[0] ^ b[0]);
if({_TEMP_BIT_NAME}[0]==1) x q[0];
"""
Expand Down Expand Up @@ -818,6 +818,25 @@ def test_classical_assignment_order_1() -> None:
assert qasm == correct_qasm


def test_max_reg_width() -> None:
circ_in = Circuit(1, 33)
circ_in.H(0).Measure(0, 32)
with pytest.raises(QASMUnsupportedError):
circuit_to_qasm_str(circ_in)
qasm_out = circuit_to_qasm_str(circ_in, maxwidth=64)
assert "measure q[0] -> c[32];" in qasm_out
qasm_in = """OPENQASM 2.0;
include "qelib1.inc";
qreg q[1];
creg c[33];
h q[0];
measure q[0] -> c[32];"""
with pytest.raises(QASMUnsupportedError):
circuit_from_qasm_str(qasm_in)
circ_out = circuit_from_qasm_str(qasm_in, maxwidth=64)
assert len(circ_out.bits) == 33


if __name__ == "__main__":
test_qasm_correct()
test_qasm_qubit()
Expand Down