Skip to content

Commit

Permalink
feat: result appending post-processing (#713)
Browse files Browse the repository at this point in the history
To enable e.g. all the elements of an array to be collated in to the
same bitstrings, see tests

Flattens lists so can be used once entire arrays can be printed too.

So some guppy that looks like for a given shot:

```python
vals = (1, 0, 1, 0)
for v in vals:
    result("c", v)
```


results in post-processed output, over 100 shots
```python
Counter({(("c", 1010"),): 100})
```
  • Loading branch information
ss2165 authored Jan 3, 2025
1 parent d1ade94 commit d0ec1ce
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
51 changes: 50 additions & 1 deletion guppylang/hresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ def to_register_bits(self) -> dict[str, str]:

return {reg: "".join(bits) for reg, bits in reg_bits.items()}

def collate_tags(self) -> dict[str, list[DataValue]]:
"""Collate all the entries with the same tag in to a dictionary with a list
containing all the data for that tag."""

tags: dict[str, list[DataValue]] = defaultdict(list)
for tag, data in self.entries:
tags[tag].append(data)
return dict(tags)


def _cast_primitive_bit(data: DataValue) -> BitChar:
if isinstance(data, int) and data in {0, 1}:
Expand Down Expand Up @@ -173,7 +182,7 @@ def register_bitstrings(
shot_dct[reg].append(bitstr)
if strict_names and not bitstrs.keys() == shot_dct.keys():
raise ValueError("All shots must have the same registers.")
return shot_dct
return dict(shot_dct)

def to_pytket(self) -> BackendResult:
"""Convert results to a pytket BackendResult.
Expand Down Expand Up @@ -208,3 +217,43 @@ def to_pytket(self) -> BackendResult:
return BackendResult(
shots=OutcomeArray.from_ints(int_shots, width=len(bits)), c_bits=bits
)

def _collated_shots_iter(self) -> Iterable[dict[str, list[DataValue]]]:
for shot in self.results:
yield shot.collate_tags()

def collated_shots(self) -> list[dict[str, list[DataValue]]]:
"""For each shot generate a dictionary of tags to collated data."""
return list(self._collated_shots_iter())

def collated_counts(self) -> Counter[tuple[tuple[str, str], ...]]:
"""Calculate counts of bit strings for each tag by collating across shots using
`HShots.tag_collated_shots`. Each `result` entry per shot is seen to be
appending to the bitstring for that tag.
If the result value is a list, it is flattened and appended to the bitstring.
Example:
>>> res = HShots([HResult([("a", 1), ("a", 0)]), HResult([("a", [0, 1])])])
>>> res.collated_counts()
Counter({(("a", "10"),): 1, (("a", "01"),): 1})
Raises:
ValueError: If any value is a float.
"""
return Counter(
tuple((tag, _flat_bitstring(data)) for tag, data in d.items())
for d in self._collated_shots_iter()
)


def _flat_bitstring(data: Iterable[DataValue]) -> str:
return "".join(_cast_primitive_bit(prim) for prim in _flatten(data))


def _flatten(itr: Iterable[DataValue]) -> Iterable[DataPrimitive]:
for i in itr:
if isinstance(i, list):
yield from _flatten(i)
else:
yield i
36 changes: 36 additions & 0 deletions tests/test_hresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,39 @@ def test_pytket():
)

assert pytket_result == expected


def test_collate_tag():
# test use of same tag for all entries of array

shotlist = []
for _ in range(10):
shot = HResult()
_ = [
shot.append(reg, 1)
for reg, size in (("c", 3), ("d", 5))
for _ in range(size)
]
shotlist.append(shot)

weird_shot = HResult((("c", 1), ("d", 1), ("d", 0), ("e", 1)))
assert weird_shot.collate_tags() == {"c": [1], "d": [1, 0], "e": [1]}

lst_shot = HResult([("lst", [1, 0, 1]), ("lst", [1, 0, 1])])
shots = HShots([*shotlist, weird_shot, lst_shot])

counter = shots.collated_counts()
assert counter == Counter({
(("c", "111"), ("d", "11111")): 10,
(("c", "1"), ("d", "10"), ("e", "1")): 1,
(("lst", "101101"),): 1,
})

float_shots = HShots(
[HResult([("f", 1.0), ("f", 0.1)]), HResult([("f", [2.0]), ("g", 2.0)])]
)

assert float_shots.collated_shots() == [
{"f": [1.0, 0.1]},
{"f": [[2.0]], "g": [2.0]},
]

0 comments on commit d0ec1ce

Please sign in to comment.