Skip to content

Commit

Permalink
preserve register correlations
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Dec 23, 2024
1 parent 9d4fc25 commit 54c2582
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
21 changes: 9 additions & 12 deletions guppylang/hresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def collated_shots(self) -> list[dict[str, list[DataValue]]]:
"""For each shot generate a dictionary of tags to collated data."""
return list(self._collated_shots_iter())

def collated_counts(self) -> dict[str, Counter[str]]:
def collated_counts(self) -> Counter[tuple[tuple[str, str], ...]]:
"""Calculate counts of bit strings for each tag by collating across shots using
`HShots.tag_collated_shots`. Each `result` entry per shot is seen to be
appending to the bitstring for that tag.
Expand All @@ -235,23 +235,20 @@ def collated_counts(self) -> dict[str, Counter[str]]:
Example:
>>> res = HShots([HResult([("a", 1), ("a", 0)]), HResult([("a", [0, 1])])])
>>> res.tag_collated_counts()
{'a': Counter({'10': 1, '01': 1})}
>>> res.collated_counts()
Counter({(("a", "10"),): 1, (("a", "01"),): 1})
Raises:
ValueError: If any value is a float.
"""
counts: dict[str, Counter[str]] = defaultdict(Counter)
return Counter(
tuple((tag, _flat_bitstring(data)) for tag, data in d.items())
for d in self._collated_shots_iter()
)

for d in self._collated_shots_iter():
bit_chars = {
tag: "".join(_cast_primitive_bit(prim) for prim in _flatten(data))
for tag, data in d.items()
}
for tag, bit_st in bit_chars.items():
counts[tag][bit_st] += 1

return dict(counts)
def _flat_bitstring(data: Iterable[DataValue]) -> str:
return "".join(_cast_primitive_bit(prim) for prim in _flatten(data))


def _flatten(itr: Iterable[DataValue]) -> Iterable[DataPrimitive]:
Expand Down
12 changes: 5 additions & 7 deletions tests/test_hresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,11 @@ def test_collate_tag():
shots = HShots([*shotlist, weird_shot, lst_shot])

counter = shots.collated_counts()

assert counter == {
"c": Counter({"111": 10, "1": 1}),
"d": Counter({"11111": 10, "10": 1}),
"e": Counter({"1": 1}),
"lst": Counter({"101101": 1}),
}
assert counter == Counter({
(("c", "111"), ("d", "11111")): 10,
(("c", "1"), ("d", "10"), ("e", "1")): 1,
(("lst", "101101"),): 1,
})

float_shots = HShots(
[HResult([("f", 1.0), ("f", 0.1)]), HResult([("f", [2.0]), ("g", 2.0)])]
Expand Down

0 comments on commit 54c2582

Please sign in to comment.