-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: conventional results post processing #593
Changes from 11 commits
864f61d
4c0bf2e
329ac82
6000741
c4fe5ff
dcda153
6741c76
22469b5
33fa156
667f711
f306a74
a1ee324
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
r""" | ||
Quantinuum system results and utilities. | ||
|
||
Includes conversions to traditional distributions over bitstrings if a tagging | ||
convention is used, including conversion to a pytket BackendResult. | ||
|
||
Under this convention, tags are assumed to be a name of a bit register unless they fit | ||
the regex pattern `^([a-z][\w_]*)\[(\d+)\]$` (like `my_Reg[12]`) in which case they | ||
are assumed to refer to the nth element of a bit register. | ||
|
||
For results of the form ``` result("<register>", value) ``` `value` can be `{0, 1}`, | ||
wherein the register is assumed to be length 1, or lists over those values, | ||
wherein the list is taken to be the value of the entire register. | ||
|
||
For results of the form ``` result("<register>[n]", value) ``` `value` can only be | ||
`{0,1}`. | ||
The register is assumed to be at least `n+1` in size and unset | ||
elements are assumed to be `0`. | ||
|
||
Subsequent writes to the same register/element in the same shot will overwrite. | ||
|
||
To convert to a `BackendResult` all registers must be present in all shots, and register | ||
sizes cannot change between shots. | ||
|
||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import re | ||
from collections import Counter, defaultdict | ||
from dataclasses import dataclass, field | ||
from typing import TYPE_CHECKING, Literal | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Iterable | ||
|
||
from pytket.backends.backendresult import BackendResult | ||
|
||
#: Primitive data types that can be returned by a result | ||
DataPrimitive = int | float | bool | ||
#: Data value that can be returned by a result: a primitive or a list of primitives | ||
qartik marked this conversation as resolved.
Show resolved
Hide resolved
|
||
DataValue = DataPrimitive | list[DataPrimitive] | ||
TaggedResult = tuple[str, DataValue] | ||
# Pattern to match register index in tag, e.g. "reg[0]" | ||
REG_INDEX_PATTERN = re.compile(r"^([a-z][\w_]*)\[(\d+)\]$") | ||
|
||
BitChar = Literal["0", "1"] | ||
|
||
|
||
@dataclass | ||
class HResult: | ||
"""Results from a single shot execution.""" | ||
|
||
entries: list[TaggedResult] = field(default_factory=list) | ||
|
||
def __init__(self, entries: Iterable[TaggedResult] | None = None): | ||
self.entries = list(entries or []) | ||
|
||
def append(self, tag: str, data: DataValue) -> None: | ||
self.entries.append((tag, data)) | ||
|
||
def as_dict(self) -> dict[str, DataValue]: | ||
"""Convert results to a dictionary. | ||
|
||
For duplicate tags, the last value is used. | ||
|
||
Returns: | ||
dict: A dictionary where the keys are the tags and the | ||
values are the data. | ||
|
||
Example: | ||
>>> results = Results() | ||
>>> results.append("tag1", 1) | ||
>>> results.append("tag2", 2) | ||
>>> results.append("tag2", 3) | ||
>>> results.as_dict() | ||
{"tag1": 1, "tag2": 3} | ||
""" | ||
return dict(self.entries) | ||
|
||
def to_register_bits(self) -> dict[str, str]: | ||
"""Convert results to a dictionary of register bit values.""" | ||
Comment on lines
+81
to
+82
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's say I do results = Results()
results.append("qs[0]", 1)
results.append("qs", 0)
results.append("qs[1]", 1)
results.to_register_bits() The expected output should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes it relies on insertion order preservation guaranteed from python 3.7 onwards. |
||
reg_bits: dict[str, list[BitChar]] = {} | ||
|
||
res_dict = self.as_dict() | ||
for tag, data in res_dict.items(): | ||
match = re.match(REG_INDEX_PATTERN, tag) | ||
if match is not None: | ||
reg_name, reg_index_str = match.groups() | ||
reg_index = int(reg_index_str) | ||
|
||
if reg_name not in reg_bits: | ||
# Initialize register counts to False | ||
reg_bits[reg_name] = ["0"] * (reg_index + 1) | ||
bitlst = reg_bits[reg_name] | ||
if reg_index >= len(bitlst): | ||
# Extend register counts with "0" | ||
bitlst += ["0"] * (reg_index - len(bitlst) + 1) | ||
|
||
bitlst[reg_index] = _cast_primitive_bit(data) | ||
continue | ||
match data: | ||
case list(vs): | ||
reg_bits[tag] = [_cast_primitive_bit(v) for v in vs] | ||
case _: | ||
reg_bits[tag] = [_cast_primitive_bit(data)] | ||
|
||
return {reg: "".join(bits) for reg, bits in reg_bits.items()} | ||
|
||
|
||
def _cast_primitive_bit(data: DataValue) -> BitChar: | ||
if isinstance(data, int) and data in {0, 1}: | ||
return str(data) # type: ignore[return-value] | ||
raise ValueError(f"Expected bit data for register value found {data}") | ||
|
||
|
||
@dataclass | ||
class HShots: | ||
"""Results accumulated over multiple shots.""" | ||
|
||
results: list[HResult] = field(default_factory=list) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need for the |
||
|
||
def __init__( | ||
self, results: Iterable[HResult | Iterable[TaggedResult]] | None = None | ||
): | ||
self.results = [ | ||
res if isinstance(res, HResult) else HResult(res) for res in results or [] | ||
] | ||
|
||
def register_counts( | ||
self, strict_names: bool = False, strict_lengths: bool = False | ||
) -> dict[str, Counter[str]]: | ||
"""Convert results to a dictionary of register counts. | ||
|
||
Returns: | ||
dict: A dictionary where the keys are the register names | ||
and the values are the counts of the register bitstrings. | ||
""" | ||
return { | ||
reg: Counter(bitstrs) | ||
for reg, bitstrs in self.register_bitstrings( | ||
strict_lengths=strict_lengths, strict_names=strict_names | ||
).items() | ||
} | ||
|
||
def register_bitstrings( | ||
self, strict_names: bool = False, strict_lengths: bool = False | ||
) -> dict[str, list[str]]: | ||
"""Convert results to a dictionary from register name to list of bitstrings over | ||
the shots. | ||
|
||
Args: | ||
strict_names: Whether to enforce that all shots have the same | ||
registers. | ||
strict_lengths: Whether to enforce that all register bitstrings have | ||
the same length. | ||
|
||
""" | ||
|
||
shot_dct: dict[str, list[str]] = defaultdict(list) | ||
for shot in self.results: | ||
bitstrs = shot.to_register_bits() | ||
for reg, bitstr in bitstrs.items(): | ||
if ( | ||
strict_lengths | ||
and reg in shot_dct | ||
and len(shot_dct[reg][0]) != len(bitstr) | ||
): | ||
raise ValueError( | ||
"All register bitstrings must have the same length." | ||
) | ||
shot_dct[reg].append(bitstr) | ||
if strict_names and not bitstrs.keys() == shot_dct.keys(): | ||
raise ValueError("All shots must have the same registers.") | ||
return shot_dct | ||
|
||
def to_pytket(self) -> BackendResult: | ||
"""Convert results to a pytket BackendResult. | ||
|
||
Returns: | ||
BackendResult: A BackendResult object with the shots. | ||
|
||
Raises: | ||
ImportError: If pytket is not installed. | ||
ValueError: If a register's bitstrings have different lengths or not all | ||
registers are present in all shots. | ||
""" | ||
try: | ||
from pytket._tket.unit_id import Bit | ||
from pytket.backends.backendresult import BackendResult | ||
from pytket.utils.outcomearray import OutcomeArray | ||
except ImportError as e: | ||
raise ImportError( | ||
"Pytket is an optional dependency, install with the `pytket` extra" | ||
) from e | ||
counts = self.register_bitstrings(strict_lengths=True, strict_names=True) | ||
reg_sizes: dict[str, int] = { | ||
reg: len(next(iter(counts[reg]), "")) for reg in counts | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, it is just there to avoid having to handle unexpected stopiteration errors induced by changes to other code. It is safe to assume unpopulated registers are size 0 |
||
} | ||
registers = list(counts.keys()) | ||
bits = [Bit(reg, i) for reg in registers for i in range(reg_sizes[reg])] | ||
int_shots = [ | ||
[ord(bitval) - 48 for reg in registers for bitval in counts[reg][i]] | ||
for i in range(len(self.results)) | ||
] | ||
return BackendResult(shots=OutcomeArray.from_readouts(int_shots), c_bits=bits) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import re | ||
from collections import Counter | ||
|
||
import pytest | ||
|
||
from guppylang.hresult import REG_INDEX_PATTERN, HResult, HShots | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("identifier", "match"), | ||
[ | ||
("sadfj", None), | ||
("asdf_sdf", None), | ||
("asdf3h32", None), | ||
("dsf[3]asdf", None), | ||
("_s34fd_fd[12]", None), | ||
("afsd3[34]sdf", None), | ||
("asdf[2]", ("asdf", 2)), | ||
("as3df[21234]", ("as3df", 21234)), | ||
("as3ABdfAB[2]", ("as3ABdfAB", 2)), | ||
], | ||
) | ||
def test_reg_index_pattern_match(identifier, match: tuple[str, int] | None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It wasn't immediately obvious to me what the purpose of this test was, best to add a docstring.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AFAIK this pattern matches the QASM2 allowed names |
||
"""Test regex pattern matches tags indexing in to registers.""" | ||
mtch = re.match(REG_INDEX_PATTERN, identifier) | ||
if mtch is None: | ||
assert match is None | ||
return | ||
parsed = (mtch.group(1), int(mtch.group(2))) | ||
assert parsed == match | ||
|
||
|
||
def test_as_dict(): | ||
results = HResult() | ||
results.append("tag1", 1) | ||
results.append("tag2", 2) | ||
results.append("tag2", 3) | ||
assert results.as_dict() == {"tag1": 1, "tag2": 3} | ||
|
||
|
||
def test_to_register_bits(): | ||
results = HResult() | ||
results.append("c[0]", 1) | ||
results.append("c[1]", 0) | ||
results.append("c[3]", 1) | ||
results.append("d", [1, 0, 1, 0]) | ||
results.append("x[5]", 1) | ||
results.append("x", 0) | ||
|
||
assert results.to_register_bits() == {"c": "1001", "d": "1010", "x": "0"} | ||
|
||
shots = HShots([results, results]) | ||
assert shots.register_counts() == { | ||
"c": Counter({"1001": 2}), | ||
"d": Counter({"1010": 2}), | ||
"x": Counter({"0": 2}), | ||
} | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"results", | ||
[ | ||
HResult([("t", 1.0)]), | ||
HResult([("t[1]", 1.0)]), | ||
HResult([("t", [1.0])]), | ||
HResult([("t[0]", [0])]), | ||
HResult([("t[0]", 3)]), | ||
], | ||
) | ||
def test_to_register_bits_bad(results: HResult): | ||
with pytest.raises(ValueError, match="Expected bit"): | ||
_ = results.to_register_bits() | ||
|
||
|
||
def test_counter(): | ||
shot1 = HResult() | ||
shot1.append("c", [1, 0, 1, 0]) | ||
shot1.append("d", [1, 0, 1]) | ||
|
||
shot2 = HResult() | ||
shot2.append("c", [1, 0, 1]) | ||
|
||
shots = HShots([shot1, shot2]) | ||
assert shots.register_counts() == { | ||
"c": Counter({"1010": 1, "101": 1}), | ||
"d": Counter({"101": 1}), | ||
} | ||
with pytest.raises(ValueError, match="same length"): | ||
_ = shots.register_counts(strict_lengths=True) | ||
|
||
with pytest.raises(ValueError, match="All shots must have the same registers"): | ||
_ = shots.register_counts(strict_names=True) | ||
|
||
|
||
def test_pytket(): | ||
qartik marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Test that results observing strict tagging conventions can be converted to pytket | ||
shot results.""" | ||
pytest.importorskip("pytket", reason="pytket not installed") | ||
|
||
hsim_shots = HShots(([("c", [1, 0]), ("d", [1, 0])], [("c", [0, 0]), ("d", [1, 0])])) | ||
|
||
pytket_result = hsim_shots.to_pytket() | ||
|
||
from pytket._tket.unit_id import Bit | ||
from pytket.backends.backendresult import BackendResult | ||
from pytket.utils.outcomearray import OutcomeArray | ||
|
||
bits = [Bit("c", 0), Bit("c", 1), Bit("d", 0), Bit("d", 1)] | ||
expected = BackendResult( | ||
c_bits=bits, shots=OutcomeArray.from_readouts([[1, 0, 1, 0], [0, 0, 1, 0]]) | ||
) | ||
|
||
assert pytket_result == expected |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am confused by the spec description above:
and the code here.
Are floats supported or not? I'd imagine, if guppy says
result("tag_for_pi", pi)
, we can return an approx value of pi as part of the results.Or is this PR limited to bitstrings? If so, why include
float
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
float is allowed as a result type, it is not supported under the "register bitstring convention" adopted as part of this PR. The
HResult
andShots
classes faithfully capture all types, it is only the post-processing/conversions that are limited.