forked from Qiskit/qiskit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add SamplerPubResult * add join_data * Update qiskit/primitives/containers/sampler_pub_result.py Co-authored-by: Ian Hincks <[email protected]> * adding tests * add join_data tests * add reno * fix linting --------- Co-authored-by: Ian Hincks <[email protected]> Co-authored-by: Ian Hincks <[email protected]>
- Loading branch information
1 parent
43c065f
commit 849fa00
Showing
8 changed files
with
215 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2024. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
""" | ||
Sampler Pub result class | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Iterable | ||
|
||
import numpy as np | ||
|
||
from .bit_array import BitArray | ||
from .pub_result import PubResult | ||
|
||
|
||
class SamplerPubResult(PubResult): | ||
"""Result of Sampler Pub.""" | ||
|
||
def join_data(self, names: Iterable[str] | None = None) -> BitArray | np.ndarray: | ||
"""Join data from many registers into one data container. | ||
Data is joined along the bits axis. For example, for :class:`~.BitArray` data, this corresponds | ||
to bitstring concatenation. | ||
Args: | ||
names: Which registers to join. Their order is maintained, for example, given | ||
``["alpha", "beta"]``, the data from register ``alpha`` is placed to the left of the | ||
data from register ``beta``. When ``None`` is given, this value is set to the | ||
ordered list of register names, which will have been preserved from the input circuit | ||
order. | ||
Returns: | ||
Joint data. | ||
Raises: | ||
ValueError: If specified names are empty. | ||
ValueError: If specified name does not exist. | ||
TypeError: If specified data comes from incompatible types. | ||
""" | ||
if names is None: | ||
names = list(self.data) | ||
if not names: | ||
raise ValueError("No entry exists in the data bin.") | ||
else: | ||
names = list(names) | ||
if not names: | ||
raise ValueError("An empty name list is given.") | ||
for name in names: | ||
if name not in self.data: | ||
raise ValueError(f"Name '{name}' does not exist.") | ||
|
||
data = [self.data[name] for name in names] | ||
if isinstance(data[0], BitArray): | ||
if not all(isinstance(datum, BitArray) for datum in data): | ||
raise TypeError("Data comes from incompatible types.") | ||
joint_data = BitArray.concatenate_bits(data) | ||
elif isinstance(data[0], np.ndarray): | ||
if not all(isinstance(datum, np.ndarray) for datum in data): | ||
raise TypeError("Data comes from incompatible types.") | ||
joint_data = np.concatenate(data, axis=-1) | ||
else: | ||
raise TypeError("Data comes from incompatible types.") | ||
return joint_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
17 changes: 17 additions & 0 deletions
17
releasenotes/notes/sampler-pub-result-e64e7de1bae2d35e.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
--- | ||
features_primitives: | ||
- | | ||
The subclass :class:`~.SamplerPubResult` of :class:`~.PubResult` was added, | ||
which :class:`~.BaseSamplerV2` implementations can return. The main feature | ||
added in this new subclass is :meth:`~.SamplerPubResult.join_data`, which | ||
joins together (a subset of) the contents of :attr:`~.PubResult.data` into | ||
a single object. This enables the following patterns: | ||
.. code:: python | ||
job_result = sampler.run([pub1, pub2, pub3], shots=123).result() | ||
# assuming all returned data entries are BitArrays | ||
counts1 = job_result[0].join_data().get_counts() | ||
bistrings2 = job_result[1].join_data().get_bitstrings() | ||
array3 = job_result[2].join_data().array |
104 changes: 104 additions & 0 deletions
104
test/python/primitives/containers/test_sampler_pub_result.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# This code is part of Qiskit. | ||
# | ||
# (C) Copyright IBM 2024. | ||
# | ||
# This code is licensed under the Apache License, Version 2.0. You may | ||
# obtain a copy of this license in the LICENSE.txt file in the root directory | ||
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# Any modifications or derivative works of this code must retain this | ||
# copyright notice, and modified files need to carry a notice indicating | ||
# that they have been altered from the originals. | ||
|
||
|
||
"""Unit tests for SamplerPubResult.""" | ||
|
||
from test import QiskitTestCase | ||
|
||
import numpy as np | ||
|
||
from qiskit.primitives.containers import BitArray, DataBin, SamplerPubResult | ||
|
||
|
||
class SamplerPubResultCase(QiskitTestCase): | ||
"""Test the SamplerPubResult class.""" | ||
|
||
def test_construction(self): | ||
"""Test that the constructor works.""" | ||
ba = BitArray.from_samples(["00", "11"], 2) | ||
counts = {"00": 1, "11": 1} | ||
data_bin = DataBin(a=ba, b=ba) | ||
pub_result = SamplerPubResult(data_bin) | ||
self.assertEqual(pub_result.data.a.get_counts(), counts) | ||
self.assertEqual(pub_result.data.b.get_counts(), counts) | ||
self.assertEqual(pub_result.metadata, {}) | ||
|
||
pub_result = SamplerPubResult(data_bin, {"x": 1}) | ||
self.assertEqual(pub_result.data.a.get_counts(), counts) | ||
self.assertEqual(pub_result.data.b.get_counts(), counts) | ||
self.assertEqual(pub_result.metadata, {"x": 1}) | ||
|
||
def test_repr(self): | ||
"""Test that the repr doesn't fail""" | ||
# we are primarily interested in making sure some future change doesn't cause the repr to | ||
# raise an error. it is more sensible for humans to detect a deficiency in the formatting | ||
# itself, should one be uncovered | ||
ba = BitArray.from_samples(["00", "11"], 2) | ||
data_bin = DataBin(a=ba, b=ba) | ||
self.assertTrue(repr(SamplerPubResult(data_bin)).startswith("SamplerPubResult")) | ||
self.assertTrue(repr(SamplerPubResult(data_bin, {"x": 1})).startswith("SamplerPubResult")) | ||
|
||
def test_join_data_failures(self): | ||
"""Test the join_data() failure mechanisms work.""" | ||
|
||
result = SamplerPubResult(DataBin()) | ||
with self.assertRaisesRegex(ValueError, "No entry exists in the data bin"): | ||
result.join_data() | ||
|
||
alpha = BitArray.from_samples(["00", "11"], 2) | ||
beta = BitArray.from_samples(["010", "101"], 3) | ||
result = SamplerPubResult(DataBin(alpha=alpha, beta=beta)) | ||
with self.assertRaisesRegex(ValueError, "An empty name list is given"): | ||
result.join_data([]) | ||
|
||
alpha = BitArray.from_samples(["00", "11"], 2) | ||
beta = BitArray.from_samples(["010", "101"], 3) | ||
result = SamplerPubResult(DataBin(alpha=alpha, beta=beta)) | ||
with self.assertRaisesRegex(ValueError, "Name 'foo' does not exist"): | ||
result.join_data(["alpha", "foo"]) | ||
|
||
alpha = BitArray.from_samples(["00", "11"], 2) | ||
beta = np.empty((2,)) | ||
result = SamplerPubResult(DataBin(alpha=alpha, beta=beta)) | ||
with self.assertRaisesRegex(TypeError, "Data comes from incompatible types"): | ||
result.join_data() | ||
|
||
alpha = np.empty((2,)) | ||
beta = BitArray.from_samples(["00", "11"], 2) | ||
result = SamplerPubResult(DataBin(alpha=alpha, beta=beta)) | ||
with self.assertRaisesRegex(TypeError, "Data comes from incompatible types"): | ||
result.join_data() | ||
|
||
result = SamplerPubResult(DataBin(alpha=1, beta={})) | ||
with self.assertRaisesRegex(TypeError, "Data comes from incompatible types"): | ||
result.join_data() | ||
|
||
def test_join_data_bit_array_default(self): | ||
"""Test the join_data() method with no arguments and bit arrays.""" | ||
alpha = BitArray.from_samples(["00", "11"], 2) | ||
beta = BitArray.from_samples(["010", "101"], 3) | ||
data_bin = DataBin(alpha=alpha, beta=beta) | ||
result = SamplerPubResult(data_bin) | ||
|
||
gamma = result.join_data() | ||
self.assertEqual(list(gamma.get_bitstrings()), ["01000", "10111"]) | ||
|
||
def test_join_data_ndarray_default(self): | ||
"""Test the join_data() method with no arguments and ndarrays.""" | ||
alpha = np.linspace(0, 1, 30).reshape((2, 3, 5)) | ||
beta = np.linspace(0, 1, 12).reshape((2, 3, 2)) | ||
data_bin = DataBin(alpha=alpha, beta=beta, shape=(2, 3)) | ||
result = SamplerPubResult(data_bin) | ||
|
||
gamma = result.join_data() | ||
np.testing.assert_allclose(gamma, np.concatenate([alpha, beta], axis=2)) |