Skip to content

Commit

Permalink
Use QDType bit interfaces for classical simulation (#1215)
Browse files Browse the repository at this point in the history
* Use QDType bit interfaces for classical simulation

* link to issue

---------

Co-authored-by: Matthew Harrigan <[email protected]>
  • Loading branch information
anurudhp and mpharrigan authored Jul 31, 2024
1 parent 2183db4 commit 0af54fc
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 116 deletions.
3 changes: 3 additions & 0 deletions qualtran/_infra/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,9 @@ class QFxp(QDType):
int type is QUInt(6). So a true classical value of `10.0011` will have a raw
integer representation of `100011`.
See https://github.com/quantumlib/Qualtran/issues/1219 for discussion on alternatives
and future upgrades.
Attributes:
bitsize: The total number of qubits used to represent the integer and
Expand Down
8 changes: 1 addition & 7 deletions qualtran/bloqs/bookkeeping/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
ConnectionT,
DecomposeTypeError,
QDType,
QFxp,
Register,
Side,
Signature,
Expand Down Expand Up @@ -95,12 +94,7 @@ def my_tensors(
]

def on_classical_vals(self, reg: int) -> Dict[str, 'ClassicalValT']:
if isinstance(self.out_dtype, QFxp):
res = reg
elif isinstance(self.inp_dtype, QFxp):
res = int(reg)
else:
res = self.out_dtype.from_bits(self.inp_dtype.to_bits(reg))
res = self.out_dtype.from_bits(self.inp_dtype.to_bits(reg))
return {'reg': res}

def as_cirq_op(self, qubit_manager, reg: 'CirqQuregT') -> Tuple[None, Dict[str, 'CirqQuregT']]:
Expand Down
21 changes: 14 additions & 7 deletions qualtran/bloqs/bookkeeping/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,24 @@ def test_cast_tensor_contraction():


def test_cast_classical_sim():
c = Cast(QInt(8), QFxp(8, 8))
qint = QUInt(8)
qfxp = QFxp(8, 8)

c = Cast(qint, qfxp)
(y,) = c.call_classically(reg=7)
assert y == 7
bloq = TestCastToFrom()
(a, b) = bloq.call_classically(a=7, b=2)
assert y == int(y)
assert qfxp.float_from_fixed_width_int(int(y)) == 7 / 2**8

bloq = TestCastToFrom(bitsize=8)
b_float = 2 / 2**8
(a, b) = bloq.call_classically(a=7, b=qfxp.to_fixed_width_int(b_float))
assert a == 7
assert b == 9
assert b == int(b)
assert qfxp.float_from_fixed_width_int(int(b)) == 9 / 2**8

c = Cast(QFxp(8, 8), QUInt(8))
c = Cast(qfxp, qint)
val = 1.2
val_as_int = QFxp(8, 8).to_fixed_width_int(val)
val_as_int = qfxp.to_fixed_width_int(val)
assert c.call_classically(reg=val_as_int) == (val_as_int,) # type: ignore


Expand Down
5 changes: 0 additions & 5 deletions qualtran/bloqs/bookkeeping/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,13 @@
DecomposeTypeError,
QBit,
QDType,
QFxp,
QUInt,
Register,
Side,
Signature,
)
from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq
from qualtran.drawing import directional_text_box, Text, WireSymbol
from qualtran.simulation.classical_sim import bits_to_ints

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand Down Expand Up @@ -96,9 +94,6 @@ def my_tensors(
]

def on_classical_vals(self, reg: 'NDArray[np.uint]') -> Dict[str, int]:
if isinstance(self.dtype, QFxp):
# TODO(#1095): support QFxp in classical simulation
return {'reg': bits_to_ints(reg)[0]}
return {'reg': self.dtype.from_bits(reg.tolist())}

def wire_symbol(self, reg: Optional[Register], idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
Expand Down
40 changes: 20 additions & 20 deletions qualtran/bloqs/bookkeeping/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
from attrs import evolve, field, frozen, validators
from numpy.typing import NDArray

from qualtran import (
bloq_example,
Expand All @@ -24,13 +25,13 @@
ConnectionT,
DecomposeTypeError,
QAny,
QDType,
Register,
Side,
Signature,
)
from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq
from qualtran.drawing import directional_text_box, Text, WireSymbol
from qualtran.simulation.classical_sim import bits_to_ints, ints_to_bits

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand Down Expand Up @@ -65,13 +66,17 @@ def __attrs_post_init__(self):
if len(set(r.name for r in self.regs)) != len(self.regs):
raise ValueError("Duplicate register names")

@cached_property
def lumped_dtype(self) -> QDType:
return QAny(bitsize=self.n)

@cached_property
def signature(self) -> 'Signature':
lumped = Side.LEFT if self.partition else Side.RIGHT
partitioned = Side.RIGHT if self.partition else Side.LEFT

return Signature(
[Register('x', QAny(bitsize=self.n), side=lumped)]
[Register('x', self.lumped_dtype, side=lumped)]
+ [evolve(reg, side=partitioned) for reg in self.regs]
)

Expand Down Expand Up @@ -119,40 +124,35 @@ def my_tensors(

def _classical_partition(self, x: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
out_vals = {}
xbits = ints_to_bits(x, self.n)[0]
xbits = self.lumped_dtype.to_bits(x)
start = 0
for reg in self.regs:
size = int(np.prod(reg.shape + (reg.bitsize,)))
bits_reg = xbits[start : start + size]
if reg.shape == ():
out_vals[reg.name] = bits_to_ints(bits_reg)[0]
out_vals[reg.name] = reg.dtype.from_bits(bits_reg)
else:
ints_reg = bits_to_ints(
[
bits_reg[i * reg.bitsize : (i + 1) * reg.bitsize]
for i in range(np.prod(reg.shape))
]
out_vals[reg.name] = reg.dtype.from_bits_array(
np.asarray(bits_reg).reshape(reg.shape + (reg.bitsize,))
)
out_vals[reg.name] = np.array(ints_reg).reshape(reg.shape)
start += size
return out_vals

def _classical_unpartition(self, **vals: 'ClassicalValT'):
out_vals = []
def _classical_unpartition_to_bits(self, **vals: 'ClassicalValT') -> NDArray[np.uint8]:
out_vals: list[NDArray[np.uint8]] = []
for reg in self.regs:
reg_val = vals[reg.name]
if isinstance(reg_val, np.ndarray):
out_vals.append(ints_to_bits(reg_val.ravel(), reg.bitsize).ravel())
else:
out_vals.append(ints_to_bits(reg_val, reg.bitsize)[0])
big_int = np.concatenate(out_vals)
return {'x': bits_to_ints(big_int)[0]}
reg_val = np.asarray(vals[reg.name])
bitstrings = reg.dtype.to_bits_array(reg_val.ravel())
out_vals.append(bitstrings.ravel())
return np.concatenate(out_vals)

def on_classical_vals(self, **vals: 'ClassicalValT') -> Dict[str, 'ClassicalValT']:
if self.partition:
return self._classical_partition(vals['x'])
else:
return self._classical_unpartition(**vals)
big_int_bits = self._classical_unpartition_to_bits(**vals)
big_int = self.lumped_dtype.from_bits(big_int_bits.tolist())
return {'x': big_int}

def wire_symbol(self, reg: Register, idx: Tuple[int, ...] = tuple()) -> 'WireSymbol':
if reg is None:
Expand Down
3 changes: 1 addition & 2 deletions qualtran/bloqs/bookkeeping/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
)
from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq
from qualtran.drawing import directional_text_box, Text, WireSymbol
from qualtran.simulation.classical_sim import ints_to_bits

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand Down Expand Up @@ -88,7 +87,7 @@ def as_cirq_op(self, qubit_manager, reg: 'CirqQuregT') -> Tuple[None, Dict[str,
return None, {'reg': reg.reshape((self.dtype.num_qubits, 1))}

def on_classical_vals(self, reg: int) -> Dict[str, 'ClassicalValT']:
return {'reg': ints_to_bits(np.array([reg]), self.dtype.num_qubits)[0]}
return {'reg': np.asarray(self.dtype.to_bits(reg))}

def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
Expand Down
6 changes: 5 additions & 1 deletion qualtran/bloqs/for_testing/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,13 @@

@frozen
class TestCastToFrom(Bloq):
bitsize: int = 4

@cached_property
def signature(self) -> Signature:
return Signature([Register('a', QUInt(4)), Register('b', QFxp(4, 4))])
return Signature(
[Register('a', QUInt(self.bitsize)), Register('b', QFxp(self.bitsize, self.bitsize))]
)

def build_composite_bloq(
self, bb: 'BloqBuilder', *, a: 'Soquet', b: 'Soquet'
Expand Down
32 changes: 0 additions & 32 deletions qualtran/simulation/classical_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,38 +37,6 @@
ClassicalValT = Union[int, np.integer, NDArray[np.integer]]


def bits_to_ints(bitstrings: Union[Sequence[int], NDArray[np.uint]]) -> NDArray[np.integer]:
"""Returns the integer specified by the given big-endian bitstrings.
Args:
bitstrings: A bitstring or array of bitstrings, each of which has the 1s bit (LSB) at the end.
Returns:
An array of integers; one for each bitstring.
"""
from qualtran import QUInt

bitstrings = np.atleast_2d(bitstrings)
return QUInt(bitstrings.shape[1]).from_bits_array(bitstrings)


def ints_to_bits(
x: Union[int, np.integer, Sequence[int], NDArray[np.integer]], w: int
) -> NDArray[np.uint8]:
"""Returns the big-endian bitstrings specified by the given integers.
Args:
x: An integer or array of unsigned integers.
w: The bit width of the returned bitstrings.
"""
from qualtran import QInt, QUInt

x = np.atleast_1d(x)
if np.all(x >= 0):
return QUInt(w).to_bits_array(x)
else:
return QInt(w).to_bits_array(x)


def _get_in_vals(
binst: Union[DanglingT, BloqInstance], reg: Register, soq_assign: Dict[Soquet, ClassicalValT]
) -> ClassicalValT:
Expand Down
42 changes: 0 additions & 42 deletions qualtran/simulation/classical_sim_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import itertools
from typing import Dict

import cirq
import numpy as np
import pytest
from attrs import frozen
Expand All @@ -26,53 +25,12 @@
from qualtran.simulation.classical_sim import (
_update_assign_from_vals,
add_ints,
bits_to_ints,
call_cbloq_classically,
ClassicalValT,
ints_to_bits,
)
from qualtran.testing import execute_notebook


def test_bits_to_int():
rs = np.random.RandomState(52)
bitstrings = rs.choice([0, 1], size=(100, 23))

nums = bits_to_ints(bitstrings)
assert nums.dtype == np.uint64
assert nums.shape == (100,)

for num, bs in zip(nums, bitstrings):
ref_num = cirq.big_endian_bits_to_int(bs.tolist())
assert num == ref_num

# check one input bitstring instead of array of input bitstrings.
(num,) = bits_to_ints([1, 0])
assert num == 2


def test_int_to_bits():
rs = np.random.RandomState(52)
nums = rs.randint(0, 2**23 - 1, size=(100,), dtype=np.uint64)
bitstrings = ints_to_bits(nums, w=23)
assert bitstrings.shape == (100, 23)

nums = rs.randint(-(2**22), 2**22, size=(100,), dtype=np.int64)
bitstrings = ints_to_bits(nums, w=23)
assert bitstrings.shape == (100, 23)

for num, bs in zip(nums, bitstrings):
ref_bs = cirq.big_endian_int_to_bits(int(num), bit_count=23)
np.testing.assert_array_equal(ref_bs, bs)

# check one input int
(bitstring,) = ints_to_bits(2, w=8)
assert bitstring.tolist() == [0, 0, 0, 0, 0, 0, 1, 0]

bitstring = ints_to_bits([31, -1], w=6)
assert bitstring.tolist() == [[0, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]


def test_dtype_validation():
# set up mocks for `_update_assign_from_vals`
soq_assign: Dict[Soquet, ClassicalValT] = {} # gets assigned to; we discard in this test.
Expand Down

0 comments on commit 0af54fc

Please sign in to comment.