Skip to content

Commit

Permalink
QDTypes learn to_bits and from_bits (#808)
Browse files Browse the repository at this point in the history
* QDTypes learn to_bits and from_bits

* Fix pylint

* Address comments

* More improvements and fix failing tests
  • Loading branch information
tanujkhattar authored Mar 21, 2024
1 parent f73458c commit b8a00f5
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 12 deletions.
132 changes: 120 additions & 12 deletions qualtran/_infra/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@
"""

import abc
from typing import Any, Iterable, Union
from typing import Any, Iterable, List, Sequence, Union

import attrs
import numpy as np
import sympy
from fxpmath import Fxp
from numpy.typing import NDArray


Expand All @@ -70,6 +71,14 @@ def get_classical_domain(self) -> Iterable[Any]:
"""Yields all possible classical (computational basis state) values representable
by this type."""

@abc.abstractmethod
def to_bits(self, x) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""

@abc.abstractmethod
def from_bits(self, bits: Sequence[int]):
"""Combine individual bits to form x"""

@abc.abstractmethod
def assert_valid_classical_val(self, val: Any, debug_str: str = 'val'):
"""Raises an exception if `val` is not a valid classical value for this type.
Expand Down Expand Up @@ -110,6 +119,16 @@ def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not (val == 0 or val == 1):
raise ValueError(f"Bad {self} value {val} in {debug_str}")

def to_bits(self, x) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
self.assert_valid_classical_val(x)
return [int(x)]

def from_bits(self, bits: Sequence[int]) -> int:
"""Combine individual bits to form x"""
assert len(bits) == 1
return bits[0]

def assert_valid_classical_val_array(self, val_array: NDArray[int], debug_str: str = 'val'):
if not np.all((val_array == 0) | (val_array == 1)):
raise ValueError(f"Bad {self} value array in {debug_str}")
Expand All @@ -128,6 +147,14 @@ def num_qubits(self):
def get_classical_domain(self) -> Iterable[Any]:
raise TypeError(f"Ambiguous domain for {self}. Please use a more specific type.")

def to_bits(self, x) -> List[int]:
# TODO: Raise an error once usage of `QAny` is minimized across the library
return QUInt(self.bitsize).to_bits(x)

def from_bits(self, bits: Sequence[int]) -> int:
# TODO: Raise an error once usage of `QAny` is minimized across the library
return QUInt(self.bitsize).from_bits(bits)

def assert_valid_classical_val(self, val, debug_str: str = 'val'):
pass

Expand All @@ -152,7 +179,20 @@ def num_qubits(self):
return self.bitsize

def get_classical_domain(self) -> Iterable[int]:
return range(-(2 ** (self.bitsize - 1)), 2 ** (self.bitsize - 1))
max_val = 1 << (self.bitsize - 1)
return range(-max_val, max_val)

def to_bits(self, x: int) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
self.assert_valid_classical_val(x)
mask = (1 << self.bitsize) - 1
return QUInt(self.bitsize).to_bits(int(x) & mask)

def from_bits(self, bits: Sequence[int]) -> int:
"""Combine individual bits to form x"""
sign = bits[0]
x = QUInt(self.bitsize - 1).from_bits([1 - x if sign else x for x in bits[1:]])
return ~x if sign else x

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
Expand Down Expand Up @@ -190,11 +230,28 @@ def __attrs_post_init__(self):
def num_qubits(self):
return self.bitsize

def get_classical_domain(self) -> Iterable[Any]:
raise NotImplementedError()
def to_bits(self, x: int) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
self.assert_valid_classical_val(x)
return [int(x < 0)] + [y ^ int(x < 0) for y in QUInt(self.bitsize - 1).to_bits(abs(x))]

def from_bits(self, bits: Sequence[int]) -> int:
"""Combine individual bits to form x"""
x = QUInt(self.bitsize).from_bits([b ^ bits[0] for b in bits[1:]])
return (-1) ** bits[0] * x

def get_classical_domain(self) -> Iterable[int]:
max_val = 1 << (self.bitsize - 1)
return range(-max_val + 1, max_val)

def assert_valid_classical_val(self, val, debug_str: str = 'val'):
pass # TODO: implement
if not isinstance(val, (int, np.integer)):
raise ValueError(f"{debug_str} should be an integer, not {val!r}")
max_val = 1 << (self.bitsize - 1)
if not -max_val <= val <= max_val:
raise ValueError(
f"Classical value {val} must be in range [-{max_val}, +{max_val}] in {debug_str}"
)


@attrs.frozen
Expand All @@ -215,7 +272,16 @@ def num_qubits(self):
return self.bitsize

def get_classical_domain(self) -> Iterable[Any]:
return range(2 ** (self.bitsize))
return range(2**self.bitsize)

def to_bits(self, x: int) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
self.assert_valid_classical_val(x)
return [int(x) for x in f'{int(x):0{self.bitsize}b}']

def from_bits(self, bits: Sequence[int]) -> int:
"""Combine individual bits to form x"""
return int("".join(str(x) for x in bits), 2)

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
Expand Down Expand Up @@ -309,6 +375,15 @@ def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if val >= self.iteration_length:
raise ValueError(f"Too-large classical value encountered in {debug_str}")

def to_bits(self, x: int) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
self.assert_valid_classical_val(x, debug_str='val')
return QUInt(self.bitsize).to_bits(x)

def from_bits(self, bits: Sequence[int]) -> int:
"""Combine individual bits to form x"""
return QUInt(self.bitsize).from_bits(bits)

def assert_valid_classical_val_array(self, val_array: NDArray[int], debug_str: str = 'val'):
if np.any(val_array < 0):
raise ValueError(f"Negative classical values encountered in {debug_str}")
Expand Down Expand Up @@ -354,6 +429,22 @@ def num_int(self) -> Union[int, sympy.Expr]:
def fxp_dtype_str(self) -> str:
return f'fxp-{"us"[self.signed]}{self.bitsize}/{self.num_frac}'

@property
def _fxp_dtype(self) -> Fxp:
return Fxp(None, dtype=self.fxp_dtype_str)

def to_bits(self, x: Union[float, Fxp]) -> List[int]:
"""Yields individual bits corresponding to binary representation of x"""
self._assert_valid_classical_val(x)
fxp = x if isinstance(x, Fxp) else Fxp(x)
return [int(x) for x in fxp.like(self._fxp_dtype).bin()]

def from_bits(self, bits: Sequence[int]) -> Fxp:
"""Combine individual bits to form x"""
bits_bin = "".join(str(x) for x in bits[:])
fxp_bin = "0b" + bits_bin[: -self.num_frac] + "." + bits_bin[-self.num_frac :]
return Fxp(fxp_bin, dtype=self.fxp_dtype_str)

def __attrs_post_init__(self):
if isinstance(self.num_qubits, int):
if self.num_qubits == 1 and self.signed:
Expand All @@ -363,11 +454,22 @@ def __attrs_post_init__(self):
if self.bitsize < self.num_frac:
raise ValueError("bitsize must be >= num_frac.")

def get_classical_domain(self) -> Iterable[Any]:
raise NotImplementedError()

def assert_valid_classical_val(self, val, debug_str: str = 'val'):
pass # TODO: implement
def get_classical_domain(self) -> Iterable[Fxp]:
qint = QIntOnesComp(self.bitsize) if self.signed else QUInt(self.bitsize)
for x in qint.get_classical_domain():
yield Fxp(x / 2**self.num_frac, dtype=self.fxp_dtype_str)

def _assert_valid_classical_val(self, val: Union[float, Fxp], debug_str: str = 'val'):
fxp_val = val if isinstance(val, Fxp) else Fxp(val)
if fxp_val.get_val() != fxp_val.like(self._fxp_dtype).get_val():
raise ValueError(
f"{debug_str}={val} cannot be accurately represented using Fxp {fxp_val}"
)

def assert_valid_classical_val(self, val: Union[float, Fxp], debug_str: str = 'val'):
# TODO: Asserting a valid value here opens a can of worms because classical data, except integers,
# is currently not propagated correctly through Bloqs
pass


@attrs.frozen
Expand Down Expand Up @@ -405,7 +507,13 @@ def num_qubits(self):
return self.bitsize

def get_classical_domain(self) -> Iterable[Any]:
return range(2 ** (self.bitsize))
return range(2**self.bitsize)

def to_bits(self, x: int) -> List[int]:
raise NotImplementedError(f"to_bits not implemented for {self}")

def from_bits(self, bits: Sequence[int]) -> int:
raise NotImplementedError(f"from_bits not implemented for {self}")

def assert_valid_classical_val(self, val: int, debug_str: str = 'val'):
if not isinstance(val, (int, np.integer)):
Expand Down
85 changes: 85 additions & 0 deletions qualtran/_infra/data_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,88 @@ def test_single_qubit_consistency():
assert check_dtypes_consistent(QAny(1), QBit())
assert check_dtypes_consistent(BoundedQUInt(1), QBit())
assert check_dtypes_consistent(QFxp(1, 1), QBit())


def test_to_and_from_bits():
# QInt
qint4 = QInt(4)
assert [*qint4.get_classical_domain()] == [*range(-8, 8)]
for x in range(-8, 8):
assert qint4.from_bits(qint4.to_bits(x)) == x
assert list(qint4.to_bits(-2)) == [1, 1, 1, 0]
assert list(QInt(4).to_bits(2)) == [0, 0, 1, 0]
assert qint4.from_bits(qint4.to_bits(-2)) == -2
assert qint4.from_bits(qint4.to_bits(2)) == 2
with pytest.raises(ValueError):
QInt(4).to_bits(10)

# QUInt
quint4 = QUInt(4)
assert [*quint4.get_classical_domain()] == [*range(0, 16)]
assert list(quint4.to_bits(10)) == [1, 0, 1, 0]
assert quint4.from_bits(quint4.to_bits(10)) == 10
for x in range(16):
assert quint4.from_bits(quint4.to_bits(x)) == x
with pytest.raises(ValueError):
quint4.to_bits(16)

with pytest.raises(ValueError):
quint4.to_bits(-1)

# BoundedQUInt
bquint4 = BoundedQUInt(4, 12)
assert [*bquint4.get_classical_domain()] == [*range(0, 12)]
assert list(bquint4.to_bits(10)) == [1, 0, 1, 0]
with pytest.raises(ValueError):
BoundedQUInt(4, 12).to_bits(13)

# QBit
assert list(QBit().to_bits(0)) == [0]
assert list(QBit().to_bits(1)) == [1]
with pytest.raises(ValueError):
QBit().to_bits(2)

# QAny
assert list(QAny(4).to_bits(10)) == [1, 0, 1, 0]

# QIntOnesComp
qintones4 = QIntOnesComp(4)
assert list(qintones4.to_bits(-2)) == [1, 1, 0, 1]
assert list(qintones4.to_bits(2)) == [0, 0, 1, 0]
assert [*qintones4.get_classical_domain()] == [*range(-7, 8)]
for x in range(-7, 8):
assert qintones4.from_bits(qintones4.to_bits(x)) == x
with pytest.raises(ValueError):
qintones4.to_bits(8)
with pytest.raises(ValueError):
qintones4.to_bits(-8)

# QFxp: Negative numbers are stored as ones complement
qfxp_4_3 = QFxp(4, 3, True)
assert list(qfxp_4_3.to_bits(0.5)) == [0, 1, 0, 0]
assert qfxp_4_3.from_bits(qfxp_4_3.to_bits(0.5)).get_val() == 0.5
assert list(qfxp_4_3.to_bits(-0.5)) == [1, 1, 0, 0]
assert qfxp_4_3.from_bits(qfxp_4_3.to_bits(-0.5)).get_val() == -0.5
assert list(qfxp_4_3.to_bits(0.625)) == [0, 1, 0, 1]
assert qfxp_4_3.from_bits(qfxp_4_3.to_bits(+0.625)).get_val() == +0.625
assert qfxp_4_3.from_bits(qfxp_4_3.to_bits(-0.625)).get_val() == -0.625
assert list(QFxp(4, 3, True).to_bits(-(1 - 0.625))) == [1, 1, 0, 1]
assert qfxp_4_3.from_bits(qfxp_4_3.to_bits(0.375)).get_val() == 0.375
assert qfxp_4_3.from_bits(qfxp_4_3.to_bits(-0.375)).get_val() == -0.375
with pytest.raises(ValueError):
_ = qfxp_4_3.to_bits(0.1)

with pytest.raises(ValueError):
_ = qfxp_4_3.to_bits(1.5)

assert qfxp_4_3.from_bits(qfxp_4_3.to_bits(1 / 2 + 1 / 4 + 1 / 8)) == 1 / 2 + 1 / 4 + 1 / 8
assert qfxp_4_3.from_bits(qfxp_4_3.to_bits(-1 / 2 - 1 / 4 - 1 / 8)) == -1 / 2 - 1 / 4 - 1 / 8
with pytest.raises(ValueError):
_ = qfxp_4_3.to_bits(1 / 2 + 1 / 4 + 1 / 8 + 1 / 16)

for qfxp in [QFxp(4, 3, True), QFxp(3, 3, False), QFxp(7, 3, False), QFxp(7, 3, True)]:
for x in qfxp.get_classical_domain():
assert qfxp.from_bits(qfxp.to_bits(x)) == x

assert list(QFxp(7, 3, True).to_bits(-4.375)) == [1] + [0, 1, 1] + [1, 0, 1]
assert list(QFxp(7, 3, True).to_bits(+4.625)) == [0] + [1, 0, 0] + [1, 0, 1]

0 comments on commit b8a00f5

Please sign in to comment.