Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorized variants of to_bits and from_bits #1142

Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6637af9
Vectorized variants of `to_bits` and `from_bits`
anurudhp Jul 15, 2024
96cf992
move functions from `classical_sim` to `data_types`
anurudhp Jul 15, 2024
d6d7cc3
use new functionality everywhere
anurudhp Jul 15, 2024
685f4bb
fix `QFxp.from_bits_array`
anurudhp Jul 16, 2024
aa32062
split phase gradient tests
anurudhp Jul 16, 2024
539c4b1
remove unused classical sim functions
anurudhp Jul 16, 2024
70ad494
mypy
anurudhp Jul 16, 2024
35c81f1
fix simulation for AddIntoPhaseGrad bloqs
anurudhp Jul 16, 2024
1c9b370
Merge branch 'main' into 2024/07/15-refactor-dtype-classical-sim
anurudhp Jul 23, 2024
8f6c7e9
fix phasegrad classical simulation
anurudhp Jul 23, 2024
063e49b
Merge branch 'main' into 2024/07/15-refactor-dtype-classical-sim
anurudhp Jul 23, 2024
7b79cf1
fix Fxp default config (overflow=wrap, shifting=trunc, op_sizing=same)
anurudhp Jul 23, 2024
73681bb
split unittests
anurudhp Jul 23, 2024
013ff0f
Merge branch 'main' into 2024/07/15-refactor-dtype-classical-sim
anurudhp Jul 23, 2024
9f8dddc
fix cast, assert QFxp classical val (partial)
anurudhp Jul 23, 2024
e1fa90a
fix Fxp behavior (overflow=wrap, shifting=trunc)
anurudhp Jul 23, 2024
0a99fff
more assert_valid_classical_val
anurudhp Jul 24, 2024
88e831a
cleanup types and boilerplate
anurudhp Jul 24, 2024
86d7eee
rename
anurudhp Jul 24, 2024
3aa7e0a
cleanup classical sim
anurudhp Jul 24, 2024
15394e5
rename `gamma_fxp` to `abs_gamma_fxp`
anurudhp Jul 24, 2024
80069f8
cleanup `.apply` and old methods
anurudhp Jul 24, 2024
39f391b
make `QFxp.fxp_dtype_template` public, construct constants using it (…
anurudhp Jul 24, 2024
b86b529
fix Fxp constants in cast test
anurudhp Jul 24, 2024
74dd54c
Merge branch 'main' into 2024/07/15-refactor-dtype-classical-sim
anurudhp Jul 24, 2024
613e153
fix classical values in `PlusEqualsProduct`
anurudhp Jul 24, 2024
a7d82fa
cleanup
anurudhp Jul 24, 2024
6ad1163
mypy
anurudhp Jul 24, 2024
2c741b4
fix QFxp Fxp template
anurudhp Jul 24, 2024
6e94ca8
fix classical call args
anurudhp Jul 24, 2024
01d3511
fix _mul_via_repeated_add (Fxp shifting is buggy)
anurudhp Jul 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 71 additions & 11 deletions qualtran/_infra/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,30 @@ def get_classical_domain(self) -> Iterable[Any]:
def to_bits(self, x) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""

def to_bits_array(self, x_array: NDArray[Any]) -> NDArray[np.uint8]:
"""Yields an NDArray of bits corresponding to binary representations of the input elements.

Often, converting an array can be performed faster than converting each element individually.
This operation accepts any NDArray of values, and the output array satisfies
`output_shape = input_shape + (self.bitsize,)`.
"""
return np.vectorize(
lambda x: np.asarray(self.to_bits(x), dtype=np.uint8), signature='()->(n)'
)(x_array)
Comment on lines +87 to +89
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sure you know this, but as far as I understand it np.vectorize will use a python for-loop under-the-hood and you don't get any special performance improvements by using it. You get the correct api and broadcasting behavior, however.

why is the signature argument needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the signature, it tries to pack each output as a single entry in the array, and fails when we return a vector that needs to be treated as an additional dimension


@abc.abstractmethod
def from_bits(self, bits: Sequence[int]):
"""Combine individual bits to form x"""

def from_bits_array(self, bits_array: NDArray[np.uint8]):
"""Combine individual bits to form classical values.

Often, converting an array can be performed faster than converting each element individually.
This operation accepts any NDArray of bits such that the last dimension equals `self.bitsize`,
and the output array satisfies `output_shape = input_shape[:-1]`.
"""
return np.vectorize(self.from_bits, signature='(n)->()')(bits_array)

@abc.abstractmethod
def assert_valid_classical_val(self, val: Any, debug_str: str = 'val'):
"""Raises an exception if `val` is not a valid classical value for this type.
Expand All @@ -90,17 +110,6 @@ def assert_valid_classical_val(self, val: Any, debug_str: str = 'val'):
debug_str: Optional debugging information to use in exception messages.
"""

@abc.abstractmethod
def is_symbolic(self) -> bool:
"""Returns True if this qdtype is parameterized with symbolic objects."""

def iteration_length_or_zero(self) -> SymbolicInt:
"""Safe version of iteration length.

Returns the iteration_length if the type has it or else zero.
"""
return getattr(self, 'iteration_length', 0)

def assert_valid_classical_val_array(self, val_array: NDArray[Any], debug_str: str = 'val'):
"""Raises an exception if `val_array` is not a valid array of classical values
for this type.
Expand All @@ -116,6 +125,17 @@ def assert_valid_classical_val_array(self, val_array: NDArray[Any], debug_str: s
for val in val_array.reshape(-1):
self.assert_valid_classical_val(val)

@abc.abstractmethod
def is_symbolic(self) -> bool:
"""Returns True if this qdtype is parameterized with symbolic objects."""

def iteration_length_or_zero(self) -> SymbolicInt:
"""Safe version of iteration length.

Returns the iteration_length if the type has it or else zero.
"""
return getattr(self, 'iteration_length', 0)

def __str__(self):
return f'{self.__class__.__name__}({self.num_qubits})'

Expand Down Expand Up @@ -324,10 +344,43 @@ def to_bits(self, x: int) -> List[int]:
self.assert_valid_classical_val(x)
return [int(x) for x in f'{int(x):0{self.bitsize}b}']

def to_bits_array(self, x_array: NDArray[np.integer]) -> NDArray[np.uint8]:
"""Returns the big-endian bitstrings specified by the given integers.

Args:
x_array: An integer or array of unsigned integers.
"""
if is_symbolic(self.bitsize):
raise ValueError(f"Cannot compute bits for symbolic {self.bitsize=}")

w = int(self.bitsize)
x = np.atleast_1d(x_array)
if not np.issubdtype(x.dtype, np.uint):
assert np.all(x >= 0)
assert np.iinfo(x.dtype).bits <= 64
x = x.astype(np.uint64)
assert w <= np.iinfo(x.dtype).bits
mask = 2 ** np.arange(w - 1, 0 - 1, -1, dtype=x.dtype).reshape((w, 1))
return (x & mask).astype(bool).astype(np.uint8).T

def from_bits(self, bits: Sequence[int]) -> int:
"""Combine individual bits to form x"""
return int("".join(str(x) for x in bits), 2)

def from_bits_array(self, bits_array: NDArray[np.uint8]) -> NDArray[np.integer]:
"""Returns the integer specified by the given big-endian bitstrings.

Args:
bits_array: A bitstring or array of bitstrings, each of which has the 1s bit (LSB) at the end.
Returns:
An array of integers; one for each bitstring.
"""
bitstrings = np.atleast_2d(bits_array)
if bitstrings.shape[1] > 64:
raise NotImplementedError()
basis = 2 ** np.arange(bitstrings.shape[1] - 1, 0 - 1, -1, dtype=np.uint64)
return np.sum(basis * bitstrings, axis=1, dtype=np.uint64)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
Expand Down Expand Up @@ -528,6 +581,13 @@ def from_bits(self, bits: Sequence[int]) -> Fxp:
fxp_bin = "0b" + bits_bin[: -self.num_frac] + "." + bits_bin[-self.num_frac :]
return Fxp(fxp_bin, dtype=self.fxp_dtype_str)

def from_bits_array(self, bits_array: NDArray[np.uint8]):
assert isinstance(self.bitsize, int), "cannot convert to bits for symbolic bitsize"
# TODO figure out why `np.vectorize` is not working here
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

open an issue and link? Do you have any theories? as I understand it: np.vectorize just does a python for loop under-the-hood

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's something to do with how Fxp interacts with numpy. Fxp has some inbuilt support to operate over NDArrays, so perhaps mixing the order up causes issues. I didn't investigate more though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An Fxp object can wrap a numpy array -- so to get a ND collection of Fxp objects, you construct a Fxp(numpy_array_of_int_or_float_values) instead of np.array([Fxp(x) for x in array_of_int_or_float_values])

See https://github.com/francof2a/fxpmath?tab=readme-ov-file#arithmetic for more details

return Fxp(
[self.from_bits(bitstring) for bitstring in bits_array.reshape(-1, self.bitsize)]
)

def to_fixed_width_int(self, x: Union[float, Fxp]) -> int:
"""Returns the interpretation of the binary representation of `x` as an integer. Requires `x` to be nonnegative."""
if x < 0:
Expand Down
85 changes: 77 additions & 8 deletions qualtran/_infra/data_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import random
from typing import Any, Sequence, Union

import cirq
import numpy as np
import pytest
import sympy
from numpy.typing import NDArray

from qualtran.symbolics import is_symbolic

Expand Down Expand Up @@ -233,8 +235,20 @@ def test_single_qubit_consistency():
assert check_dtypes_consistent(QFxp(1, 1), QBit())


def test_to_and_from_bits():
# QInt
def assert_to_and_from_bits_array_consistent(qdtype: QDType, values: Union[Sequence[Any], NDArray]):
values = np.asarray(values)
bits_array = qdtype.to_bits_array(values)

# individual values
for val, bits in zip(values.reshape(-1), bits_array.reshape(-1, qdtype.num_qubits)):
assert np.all(bits == qdtype.to_bits(val))

# round trip
values_roundtrip = qdtype.from_bits_array(bits_array)
assert np.all(values_roundtrip == values)


def test_qint_to_and_from_bits():
qint4 = QInt(4)
assert [*qint4.get_classical_domain()] == [*range(-8, 8)]
for x in range(-8, 8):
Expand All @@ -246,7 +260,10 @@ def test_to_and_from_bits():
with pytest.raises(ValueError):
QInt(4).to_bits(10)

# QUInt
assert_to_and_from_bits_array_consistent(qint4, range(-8, 8))


def test_quint_to_and_from_bits():
quint4 = QUInt(4)
assert [*quint4.get_classical_domain()] == [*range(0, 16)]
assert list(quint4.to_bits(10)) == [1, 0, 1, 0]
Expand All @@ -259,23 +276,66 @@ def test_to_and_from_bits():
with pytest.raises(ValueError):
quint4.to_bits(-1)

# BoundedQUInt
assert_to_and_from_bits_array_consistent(quint4, range(0, 16))


def test_bits_to_int():
rs = np.random.RandomState(52)
bitstrings = rs.choice([0, 1], size=(100, 23))

nums = QUInt(23).from_bits_array(bitstrings)
assert nums.shape == (100,)

for num, bs in zip(nums, bitstrings):
ref_num = cirq.big_endian_bits_to_int(bs.tolist())
assert num == ref_num

# check one input bitstring instead of array of input bitstrings.
(num,) = QUInt(23).from_bits_array(np.array([1, 0]))
assert num == 2


def test_int_to_bits():
rs = np.random.RandomState(52)
nums = rs.randint(0, 2**23 - 1, size=(100,), dtype=np.uint64)
bitstrings = QUInt(23).to_bits_array(nums)
assert bitstrings.shape == (100, 23)

for num, bs in zip(nums, bitstrings):
ref_bs = cirq.big_endian_int_to_bits(int(num), bit_count=23)
np.testing.assert_array_equal(ref_bs, bs)

# check bounds
with pytest.raises(AssertionError):
QUInt(8).to_bits_array(np.array([4, -2]))


def test_bounded_quint_to_and_from_bits():
bquint4 = BoundedQUInt(4, 12)
assert [*bquint4.get_classical_domain()] == [*range(0, 12)]
assert list(bquint4.to_bits(10)) == [1, 0, 1, 0]
with pytest.raises(ValueError):
BoundedQUInt(4, 12).to_bits(13)

# QBit
assert_to_and_from_bits_array_consistent(bquint4, range(0, 12))


def test_qbit_to_and_from_bits():
assert list(QBit().to_bits(0)) == [0]
assert list(QBit().to_bits(1)) == [1]
with pytest.raises(ValueError):
QBit().to_bits(2)

# QAny
assert_to_and_from_bits_array_consistent(QBit(), [0, 1])


def test_qany_to_and_from_bits():
assert list(QAny(4).to_bits(10)) == [1, 0, 1, 0]

# QIntOnesComp
assert_to_and_from_bits_array_consistent(QAny(4), range(16))


def test_qintonescomp_to_and_from_bits():
qintones4 = QIntOnesComp(4)
assert list(qintones4.to_bits(-2)) == [1, 1, 0, 1]
assert list(qintones4.to_bits(2)) == [0, 0, 1, 0]
Expand All @@ -287,6 +347,10 @@ def test_to_and_from_bits():
with pytest.raises(ValueError):
qintones4.to_bits(-8)

assert_to_and_from_bits_array_consistent(qintones4, range(-7, 8))


def test_qfxp_to_and_from_bits():
# QFxp: Negative numbers are stored as ones complement
qfxp_4_3 = QFxp(4, 3, True)
assert list(qfxp_4_3.to_bits(0.5)) == [0, 1, 0, 0]
Expand Down Expand Up @@ -321,6 +385,11 @@ def test_to_and_from_bits():
assert list(QFxp(7, 3, True).to_bits(-4.375)) == [1] + [0, 1, 1] + [1, 0, 1]
assert list(QFxp(7, 3, True).to_bits(+4.625)) == [0] + [1, 0, 0] + [1, 0, 1]

assert_to_and_from_bits_array_consistent(QFxp(4, 3, False), [1 / 2, 1 / 4, 3 / 8])
assert_to_and_from_bits_array_consistent(
QFxp(4, 3, True), [1 / 2, -1 / 2, 1 / 4, -1 / 4, -3 / 8, 3 / 8]
)


def test_iter_bits():
assert QUInt(2).to_bits(0) == [0, 0]
Expand Down
14 changes: 9 additions & 5 deletions qualtran/bloqs/basic_gates/z_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
DecomposeTypeError,
QAny,
QBit,
QDType,
Register,
Side,
Signature,
Expand All @@ -42,7 +43,6 @@
from qualtran.bloqs.bookkeeping import ArbitraryClifford
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import Circle, directional_text_box, Text, TextBox, WireSymbol
from qualtran.simulation.classical_sim import ints_to_bits

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -384,12 +384,16 @@ def check(self, attribute, val):
if val >= 2**self.bitsize:
raise ValueError(f"`val` is too big for bitsize {self.bitsize}")

@cached_property
def dtype(self) -> QDType:
if self.bitsize == 1:
return QBit()
return QAny(self.bitsize)

@cached_property
def signature(self) -> Signature:
side = Side.RIGHT if self.state else Side.LEFT
if self.bitsize == 1:
return Signature([Register('val', QBit(), side=side)])
return Signature([Register('val', QAny(self.bitsize), side=side)])
return Signature([Register('val', self.dtype, side=side)])

@staticmethod
def _build_composite_state(bb: 'BloqBuilder', bits: NDArray[np.uint8]) -> Dict[str, 'SoquetT']:
Expand All @@ -415,7 +419,7 @@ def _build_composite_effect(
def build_composite_bloq(self, bb: 'BloqBuilder', **val: 'SoquetT') -> Dict[str, 'SoquetT']:
if isinstance(self.bitsize, sympy.Expr):
raise DecomposeTypeError(f'Symbolic bitsize {self.bitsize} not supported')
bits = ints_to_bits(np.array([self.val]), w=self.bitsize)[0]
bits = np.asarray(self.dtype.to_bits(self.val))
if self.state:
assert not val
return self._build_composite_state(bb, bits)
Expand Down
3 changes: 1 addition & 2 deletions qualtran/bloqs/bookkeeping/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
)
from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq
from qualtran.drawing import directional_text_box, Text, WireSymbol
from qualtran.simulation.classical_sim import bits_to_ints

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand Down Expand Up @@ -95,7 +94,7 @@ def my_tensors(
]

def on_classical_vals(self, reg: 'NDArray[np.uint]') -> Dict[str, int]:
return {'reg': bits_to_ints(reg)[0]}
return {'reg': self.dtype.from_bits(reg.tolist())}

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
if reg is None:
Expand Down
Loading
Loading