Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: conventional results post processing #593

Merged
merged 12 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions guppylang/hresult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
r"""
Quantinuum system results and utilities.

Includes conversions to traditional distributions over bitstrings if a tagging
convention is used, including conversion to a pytket BackendResult.

Under this convention, tags are assumed to be a name of a bit register unless they fit
the regex pattern `^([a-z][\w_]*)\[(\d+)\]$` (like `my_Reg[12]`) in which case they
are assumed to refer to the nth element of a bit register.

For results of the form ``` result("<register>", value) ``` `value` can be `{0, 1}`,
wherein the register is assumed to be length 1, or lists over those values,
wherein the list is taken to be the value of the entire register.

For results of the form ``` result("<register>[n]", value) ``` `value` can only be
`{0,1}`.
The register is assumed to be at least `n+1` in size and unset
elements are assumed to be `0`.

Subsequent writes to the same register/element in the same shot will overwrite.

To convert to a `BackendResult` all registers must be present in all shots, and register
sizes cannot change between shots.

"""

from __future__ import annotations

import re
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Literal

if TYPE_CHECKING:
from collections.abc import Iterable

from pytket.backends.backendresult import BackendResult

#: Primitive data types that can be returned by a result
DataPrimitive = int | float | bool
Copy link
Member

@qartik qartik Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused by the spec description above:

result("<register>", value) value can be {0, 1}

and the code here.

Are floats supported or not? I'd imagine, if guppy says result("tag_for_pi", pi), we can return an approx value of pi as part of the results.

Or is this PR limited to bitstrings? If so, why include float?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

float is allowed as a result type, it is not supported under the "register bitstring convention" adopted as part of this PR. The HResult and Shots classes faithfully capture all types, it is only the post-processing/conversions that are limited.

#: Data value that can be returned by a result: a primitive or a list of primitives
qartik marked this conversation as resolved.
Show resolved Hide resolved
DataValue = DataPrimitive | list[DataPrimitive]
TaggedResult = tuple[str, DataValue]
# Pattern to match register index in tag, e.g. "reg[0]"
REG_INDEX_PATTERN = re.compile(r"^([a-z][\w_]*)\[(\d+)\]$")

BitChar = Literal["0", "1"]


@dataclass
class HResult:
"""Results from a single shot execution."""

entries: list[TaggedResult] = field(default_factory=list)

def __init__(self, entries: Iterable[TaggedResult] | None = None):
self.entries = list(entries or [])

def append(self, tag: str, data: DataValue) -> None:
self.entries.append((tag, data))

def as_dict(self) -> dict[str, DataValue]:
"""Convert results to a dictionary.

For duplicate tags, the last value is used.

Returns:
dict: A dictionary where the keys are the tags and the
values are the data.

Example:
>>> results = Results()
>>> results.append("tag1", 1)
>>> results.append("tag2", 2)
>>> results.append("tag2", 3)
>>> results.as_dict()
{"tag1": 1, "tag2": 3}
"""
return dict(self.entries)

def to_register_bits(self) -> dict[str, str]:
"""Convert results to a dictionary of register bit values."""
Comment on lines +81 to +82
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's say I do

results = Results()
results.append("qs[0]", 1)
results.append("qs", 0)
results.append("qs[1]", 1)
results.to_register_bits()

The expected output should be {"qs": "01"}, right? I think this works since dicts are iterated in insertion order and HResult.as_dict keeps the order of entries. You could consider using OrderedDict to make this invariant explicit?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes it relies on insertion order preservation guaranteed from python 3.7 onwards.
OrderedDict is more useful for things that require reordering these days so I will
just add a comment.

reg_bits: dict[str, list[BitChar]] = {}

res_dict = self.as_dict()
for tag, data in res_dict.items():
match = re.match(REG_INDEX_PATTERN, tag)
if match is not None:
reg_name, reg_index_str = match.groups()
reg_index = int(reg_index_str)

if reg_name not in reg_bits:
# Initialize register counts to False
reg_bits[reg_name] = ["0"] * (reg_index + 1)
bitlst = reg_bits[reg_name]
if reg_index >= len(bitlst):
# Extend register counts with "0"
bitlst += ["0"] * (reg_index - len(bitlst) + 1)

bitlst[reg_index] = _cast_primitive_bit(data)
continue
match data:
case list(vs):
reg_bits[tag] = [_cast_primitive_bit(v) for v in vs]
case _:
reg_bits[tag] = [_cast_primitive_bit(data)]

return {reg: "".join(bits) for reg, bits in reg_bits.items()}


def _cast_primitive_bit(data: DataValue) -> BitChar:
if isinstance(data, int) and data in {0, 1}:
return str(data) # type: ignore[return-value]
raise ValueError(f"Expected bit data for register value found {data}")


@dataclass
class HShots:
"""Results accumulated over multiple shots."""

results: list[HResult] = field(default_factory=list)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for the default_factory since you have a custom __init__


def __init__(
self, results: Iterable[HResult | Iterable[TaggedResult]] | None = None
):
self.results = [
res if isinstance(res, HResult) else HResult(res) for res in results or []
]

def register_counts(
self, strict_names: bool = False, strict_lengths: bool = False
) -> dict[str, Counter[str]]:
"""Convert results to a dictionary of register counts.

Returns:
dict: A dictionary where the keys are the register names
and the values are the counts of the register bitstrings.
"""
return {
reg: Counter(bitstrs)
for reg, bitstrs in self.register_bitstrings(
strict_lengths=strict_lengths, strict_names=strict_names
).items()
}

def register_bitstrings(
self, strict_names: bool = False, strict_lengths: bool = False
) -> dict[str, list[str]]:
"""Convert results to a dictionary from register name to list of bitstrings over
the shots.

Args:
strict_names: Whether to enforce that all shots have the same
registers.
strict_lengths: Whether to enforce that all register bitstrings have
the same length.

"""

shot_dct: dict[str, list[str]] = defaultdict(list)
for shot in self.results:
bitstrs = shot.to_register_bits()
for reg, bitstr in bitstrs.items():
if (
strict_lengths
and reg in shot_dct
and len(shot_dct[reg][0]) != len(bitstr)
):
raise ValueError(
"All register bitstrings must have the same length."
)
shot_dct[reg].append(bitstr)
if strict_names and not bitstrs.keys() == shot_dct.keys():
raise ValueError("All shots must have the same registers.")
return shot_dct

def to_pytket(self) -> BackendResult:
"""Convert results to a pytket BackendResult.

Returns:
BackendResult: A BackendResult object with the shots.

Raises:
ImportError: If pytket is not installed.
ValueError: If a register's bitstrings have different lengths or not all
registers are present in all shots.
"""
try:
from pytket._tket.unit_id import Bit
from pytket.backends.backendresult import BackendResult
from pytket.utils.outcomearray import OutcomeArray
except ImportError as e:
raise ImportError(
"Pytket is an optional dependency, install with the `pytket` extra"
) from e
counts = self.register_bitstrings(strict_lengths=True, strict_names=True)
reg_sizes: dict[str, int] = {
reg: len(next(iter(counts[reg]), "")) for reg in counts
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the "" default in next needed? Shouldn't all registers be populated with at least one value?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it is just there to avoid having to handle unexpected stopiteration errors induced by changes to other code. It is safe to assume unpopulated registers are size 0

}
registers = list(counts.keys())
bits = [Bit(reg, i) for reg in registers for i in range(reg_sizes[reg])]
int_shots = [
[ord(bitval) - 48 for reg in registers for bitval in counts[reg][i]]
for i in range(len(self.results))
]
return BackendResult(shots=OutcomeArray.from_readouts(int_shots), c_bits=bits)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
# pytket = ["pytket >=1.30.0,<2", "tket2 >=0.4.1,<0.5"]
docs = ["sphinx >=7.2.6,<9", "sphinx-book-theme >=1.1.2,<2"]
execution = ["execute-llvm"]
pytket = ["pytket>=1.34"]

[project.urls]
homepage = "https://github.com/CQCL/guppylang"
Expand Down
113 changes: 113 additions & 0 deletions tests/test_hresult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import re
from collections import Counter

import pytest

from guppylang.hresult import REG_INDEX_PATTERN, HResult, HShots


@pytest.mark.parametrize(
("identifier", "match"),
[
("sadfj", None),
("asdf_sdf", None),
("asdf3h32", None),
("dsf[3]asdf", None),
("_s34fd_fd[12]", None),
("afsd3[34]sdf", None),
("asdf[2]", ("asdf", 2)),
("as3df[21234]", ("as3df", 21234)),
("as3ABdfAB[2]", ("as3ABdfAB", 2)),
],
)
def test_reg_index_pattern_match(identifier, match: tuple[str, int] | None):
Copy link
Member

@qartik qartik Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wasn't immediately obvious to me what the purpose of this test was, best to add a docstring.

_s34fd_fd[12] should perhaps be accepted? Do we explicitly forbid register names starting with an underscore?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK this pattern matches the QASM2 allowed names

"""Test regex pattern matches tags indexing in to registers."""
mtch = re.match(REG_INDEX_PATTERN, identifier)
if mtch is None:
assert match is None
return
parsed = (mtch.group(1), int(mtch.group(2)))
assert parsed == match


def test_as_dict():
results = HResult()
results.append("tag1", 1)
results.append("tag2", 2)
results.append("tag2", 3)
assert results.as_dict() == {"tag1": 1, "tag2": 3}


def test_to_register_bits():
results = HResult()
results.append("c[0]", 1)
results.append("c[1]", 0)
results.append("c[3]", 1)
results.append("d", [1, 0, 1, 0])
results.append("x[5]", 1)
results.append("x", 0)

assert results.to_register_bits() == {"c": "1001", "d": "1010", "x": "0"}

shots = HShots([results, results])
assert shots.register_counts() == {
"c": Counter({"1001": 2}),
"d": Counter({"1010": 2}),
"x": Counter({"0": 2}),
}


@pytest.mark.parametrize(
"results",
[
HResult([("t", 1.0)]),
HResult([("t[1]", 1.0)]),
HResult([("t", [1.0])]),
HResult([("t[0]", [0])]),
HResult([("t[0]", 3)]),
],
)
def test_to_register_bits_bad(results: HResult):
with pytest.raises(ValueError, match="Expected bit"):
_ = results.to_register_bits()


def test_counter():
shot1 = HResult()
shot1.append("c", [1, 0, 1, 0])
shot1.append("d", [1, 0, 1])

shot2 = HResult()
shot2.append("c", [1, 0, 1])

shots = HShots([shot1, shot2])
assert shots.register_counts() == {
"c": Counter({"1010": 1, "101": 1}),
"d": Counter({"101": 1}),
}
with pytest.raises(ValueError, match="same length"):
_ = shots.register_counts(strict_lengths=True)

with pytest.raises(ValueError, match="All shots must have the same registers"):
_ = shots.register_counts(strict_names=True)


def test_pytket():
qartik marked this conversation as resolved.
Show resolved Hide resolved
"""Test that results observing strict tagging conventions can be converted to pytket
shot results."""
pytest.importorskip("pytket", reason="pytket not installed")

hsim_shots = HShots(([("c", [1, 0]), ("d", [1, 0])], [("c", [0, 0]), ("d", [1, 0])]))

pytket_result = hsim_shots.to_pytket()

from pytket._tket.unit_id import Bit
from pytket.backends.backendresult import BackendResult
from pytket.utils.outcomearray import OutcomeArray

bits = [Bit("c", 0), Bit("c", 1), Bit("d", 0), Bit("d", 1)]
expected = BackendResult(
c_bits=bits, shots=OutcomeArray.from_readouts([[1, 0, 1, 0], [0, 0, 1, 0]])
)

assert pytket_result == expected
Loading
Loading