From 54c25823bda4a4422d8e33f918038292db2327b6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 23 Dec 2024 11:23:56 +0000 Subject: [PATCH] preserve register correlations --- guppylang/hresult.py | 21 +++++++++------------ tests/test_hresult.py | 12 +++++------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/guppylang/hresult.py b/guppylang/hresult.py index 3d2d29c3..62334b5a 100644 --- a/guppylang/hresult.py +++ b/guppylang/hresult.py @@ -226,7 +226,7 @@ def collated_shots(self) -> list[dict[str, list[DataValue]]]: """For each shot generate a dictionary of tags to collated data.""" return list(self._collated_shots_iter()) - def collated_counts(self) -> dict[str, Counter[str]]: + def collated_counts(self) -> Counter[tuple[tuple[str, str], ...]]: """Calculate counts of bit strings for each tag by collating across shots using `HShots.tag_collated_shots`. Each `result` entry per shot is seen to be appending to the bitstring for that tag. @@ -235,23 +235,20 @@ def collated_counts(self) -> dict[str, Counter[str]]: Example: >>> res = HShots([HResult([("a", 1), ("a", 0)]), HResult([("a", [0, 1])])]) - >>> res.tag_collated_counts() - {'a': Counter({'10': 1, '01': 1})} + >>> res.collated_counts() + Counter({(("a", "10"),): 1, (("a", "01"),): 1}) Raises: ValueError: If any value is a float. """ - counts: dict[str, Counter[str]] = defaultdict(Counter) + return Counter( + tuple((tag, _flat_bitstring(data)) for tag, data in d.items()) + for d in self._collated_shots_iter() + ) - for d in self._collated_shots_iter(): - bit_chars = { - tag: "".join(_cast_primitive_bit(prim) for prim in _flatten(data)) - for tag, data in d.items() - } - for tag, bit_st in bit_chars.items(): - counts[tag][bit_st] += 1 - return dict(counts) +def _flat_bitstring(data: Iterable[DataValue]) -> str: + return "".join(_cast_primitive_bit(prim) for prim in _flatten(data)) def _flatten(itr: Iterable[DataValue]) -> Iterable[DataPrimitive]: diff --git a/tests/test_hresult.py b/tests/test_hresult.py index 4b81c5fc..b104a64b 100644 --- a/tests/test_hresult.py +++ b/tests/test_hresult.py @@ -135,13 +135,11 @@ def test_collate_tag(): shots = HShots([*shotlist, weird_shot, lst_shot]) counter = shots.collated_counts() - - assert counter == { - "c": Counter({"111": 10, "1": 1}), - "d": Counter({"11111": 10, "10": 1}), - "e": Counter({"1": 1}), - "lst": Counter({"101101": 1}), - } + assert counter == Counter({ + (("c", "111"), ("d", "11111")): 10, + (("c", "1"), ("d", "10"), ("e", "1")): 1, + (("lst", "101101"),): 1, + }) float_shots = HShots( [HResult([("f", 1.0), ("f", 0.1)]), HResult([("f", [2.0]), ("g", 2.0)])]