Skip to content

Commit

Permalink
Fix dtypes in IntVector and PlusEqualsProduct (#1197)
Browse files Browse the repository at this point in the history
* fix dtypes in signatures

* code cleanup

* use `is_symbolic`

* assert result reg is large enough

* docstring
  • Loading branch information
anurudhp authored Jul 27, 2024
1 parent 174b549 commit 70b6dff
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 28 deletions.
13 changes: 12 additions & 1 deletion qualtran/bloqs/arithmetic/multiplication.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,18 @@
},
"source": [
"## `PlusEqualProduct`\n",
"Performs result += a * b"
"Performs result += a * b.\n",
"\n",
"#### Parameters\n",
" - `a_bitsize`: bitsize of input `a`.\n",
" - `b_bitsize`: bitsize of input `b`.\n",
" - `result_bitsize`: bitsize of the output register.\n",
" - `is_adjoint`: If true, performs `result -= a * b` instead. Defaults to False. \n",
"\n",
"#### Registers\n",
" - `a`: QUInt of `a_bitsize` bits.\n",
" - `b`: QUInt of `b_bitsize` bits.\n",
" - `result`: QUInt of `result_bitsize` bits.\n"
]
},
{
Expand Down
63 changes: 41 additions & 22 deletions qualtran/bloqs/arithmetic/multiplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from qualtran.bloqs.arithmetic.subtraction import Subtract
from qualtran.bloqs.basic_gates import CNOT, TGate, Toffoli, XGate
from qualtran.bloqs.mcmt import MultiControlPauli
from qualtran.symbolics import HasLength, smax, SymbolicInt
from qualtran.symbolics import HasLength, is_symbolic, smax, SymbolicInt

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand All @@ -44,30 +44,58 @@

@frozen
class PlusEqualProduct(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[misc]
"""Performs result += a * b"""
"""Performs result += a * b.
Args:
a_bitsize: bitsize of input `a`.
b_bitsize: bitsize of input `b`.
result_bitsize: bitsize of the output register.
is_adjoint: If true, performs `result -= a * b` instead. Defaults to False.
Registers:
a: QUInt of `a_bitsize` bits.
b: QUInt of `b_bitsize` bits.
result: QUInt of `result_bitsize` bits.
"""

a_bitsize: SymbolicInt
b_bitsize: SymbolicInt
result_bitsize: SymbolicInt
is_adjoint: bool = False

def __attrs_post_init__(self):
res_has_enough = self.a_bitsize + self.b_bitsize <= self.result_bitsize
if not is_symbolic(res_has_enough) and not res_has_enough:
raise ValueError(
f"{self.result_bitsize=} must be at least the sum of input "
f"bitsizes {self.a_bitsize} + {self.b_bitsize}"
)

def pretty_name(self) -> str:
return "result -= a*b" if self.is_adjoint else "result += a*b"

@property
def signature(self) -> 'Signature':
return Signature.build_from_dtypes(
a=QUInt(self.a_bitsize),
b=QUInt(self.b_bitsize),
result=QFxp(self.result_bitsize, self.result_bitsize),
)
return Signature.build_from_dtypes(a=self.a_dtype, b=self.b_dtype, result=self.result_dtype)

@property
def a_dtype(self):
return QUInt(self.a_bitsize)

@property
def b_dtype(self):
return QUInt(self.b_bitsize)

@property
def result_dtype(self):
return QUInt(self.result_bitsize)

def registers(self) -> Sequence[Union[int, Sequence[int]]]:
if not isinstance(self.a_bitsize, int):
if is_symbolic(self.a_bitsize):
raise ValueError(f'Symbolic bitsize {self.a_bitsize} not supported')
if not isinstance(self.b_bitsize, int):
if is_symbolic(self.b_bitsize):
raise ValueError(f'Symbolic bitsize {self.b_bitsize} not supported')
if not isinstance(self.result_bitsize, int):
if is_symbolic(self.result_bitsize):
raise ValueError(f'Symbolic bitsize {self.result_bitsize} not supported')
return [2] * self.a_bitsize, [2] * self.b_bitsize, [2] * self.result_bitsize

Expand All @@ -85,25 +113,16 @@ def on_classical_vals(self, a: int, b: int, result: int) -> Dict[str, 'Classical
return {'a': a, 'b': b, 'result': result_out}

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
if not isinstance(self.a_bitsize, int):
if is_symbolic(self.a_bitsize):
raise ValueError(f'Symbolic bitsize {self.a_bitsize} not supported')
if not isinstance(self.b_bitsize, int):
if is_symbolic(self.b_bitsize):
raise ValueError(f'Symbolic bitsize {self.b_bitsize} not supported')
if not isinstance(self.result_bitsize, int):
if is_symbolic(self.result_bitsize):
raise ValueError(f'Symbolic bitsize {self.result_bitsize} not supported')
wire_symbols = ['a'] * self.a_bitsize + ['b'] * self.b_bitsize
wire_symbols += ['c-=a*b' if self.is_adjoint else 'c+=a*b'] * self.result_bitsize
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def __pow__(self, power):
if power == 1:
return self
if power == -1:
return PlusEqualProduct(
self.a_bitsize, self.b_bitsize, self.result_bitsize, not self.is_adjoint
)
raise NotImplementedError("PlusEqualProduct.__pow__ defined only for powers +1/-1.")

def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
) -> List['qtn.Tensor']:
Expand Down
14 changes: 9 additions & 5 deletions qualtran/bloqs/basic_gates/z_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
DecomposeTypeError,
QAny,
QBit,
QDType,
Register,
Side,
Signature,
Expand All @@ -42,7 +43,6 @@
from qualtran.bloqs.bookkeeping import ArbitraryClifford
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import Circle, directional_text_box, Text, TextBox, WireSymbol
from qualtran.simulation.classical_sim import ints_to_bits

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -384,12 +384,16 @@ def check(self, attribute, val):
if val >= 2**self.bitsize:
raise ValueError(f"`val` is too big for bitsize {self.bitsize}")

@cached_property
def dtype(self) -> QDType:
if self.bitsize == 1:
return QBit()
return QAny(self.bitsize)

@cached_property
def signature(self) -> Signature:
side = Side.RIGHT if self.state else Side.LEFT
if self.bitsize == 1:
return Signature([Register('val', QBit(), side=side)])
return Signature([Register('val', QAny(self.bitsize), side=side)])
return Signature([Register('val', self.dtype, side=side)])

@staticmethod
def _build_composite_state(bb: 'BloqBuilder', bits: NDArray[np.uint8]) -> Dict[str, 'SoquetT']:
Expand All @@ -415,7 +419,7 @@ def _build_composite_effect(
def build_composite_bloq(self, bb: 'BloqBuilder', **val: 'SoquetT') -> Dict[str, 'SoquetT']:
if isinstance(self.bitsize, sympy.Expr):
raise DecomposeTypeError(f'Symbolic bitsize {self.bitsize} not supported')
bits = ints_to_bits(np.array([self.val]), w=self.bitsize)[0]
bits = np.asarray(self.dtype.to_bits(self.val))
if self.state:
assert not val
return self._build_composite_state(bb, bits)
Expand Down

0 comments on commit 70b6dff

Please sign in to comment.