Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dtypes in IntVector and PlusEqualsProduct #1197

Merged
merged 6 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion qualtran/bloqs/arithmetic/multiplication.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,18 @@
},
"source": [
"## `PlusEqualProduct`\n",
"Performs result += a * b"
"Performs result += a * b.\n",
"\n",
"#### Parameters\n",
" - `a_bitsize`: bitsize of input `a`.\n",
" - `b_bitsize`: bitsize of input `b`.\n",
" - `result_bitsize`: bitsize of the output register.\n",
" - `is_adjoint`: If true, performs `result -= a * b` instead. Defaults to False. \n",
"\n",
"#### Registers\n",
" - `a`: QUInt of `a_bitsize` bits.\n",
" - `b`: QUInt of `b_bitsize` bits.\n",
" - `result`: QUInt of `result_bitsize` bits.\n"
]
},
{
Expand Down
63 changes: 41 additions & 22 deletions qualtran/bloqs/arithmetic/multiplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from qualtran.bloqs.arithmetic.subtraction import Subtract
from qualtran.bloqs.basic_gates import CNOT, TGate, Toffoli, XGate
from qualtran.bloqs.mcmt import MultiControlPauli
from qualtran.symbolics import HasLength, smax, SymbolicInt
from qualtran.symbolics import HasLength, is_symbolic, smax, SymbolicInt

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand All @@ -44,30 +44,58 @@

@frozen
class PlusEqualProduct(GateWithRegisters, cirq.ArithmeticGate): # type: ignore[misc]
"""Performs result += a * b"""
"""Performs result += a * b.

Args:
a_bitsize: bitsize of input `a`.
b_bitsize: bitsize of input `b`.
result_bitsize: bitsize of the output register.
is_adjoint: If true, performs `result -= a * b` instead. Defaults to False.

Registers:
a: QUInt of `a_bitsize` bits.
b: QUInt of `b_bitsize` bits.
result: QUInt of `result_bitsize` bits.
"""

a_bitsize: SymbolicInt
b_bitsize: SymbolicInt
result_bitsize: SymbolicInt
is_adjoint: bool = False

def __attrs_post_init__(self):
res_has_enough = self.a_bitsize + self.b_bitsize <= self.result_bitsize
if not is_symbolic(res_has_enough) and not res_has_enough:
raise ValueError(
f"{self.result_bitsize=} must be at least the sum of input "
f"bitsizes {self.a_bitsize} + {self.b_bitsize}"
)

def pretty_name(self) -> str:
return "result -= a*b" if self.is_adjoint else "result += a*b"

@property
def signature(self) -> 'Signature':
return Signature.build_from_dtypes(
a=QUInt(self.a_bitsize),
b=QUInt(self.b_bitsize),
result=QFxp(self.result_bitsize, self.result_bitsize),
)
return Signature.build_from_dtypes(a=self.a_dtype, b=self.b_dtype, result=self.result_dtype)

@property
def a_dtype(self):
return QUInt(self.a_bitsize)

@property
def b_dtype(self):
return QUInt(self.b_bitsize)

@property
def result_dtype(self):
return QUInt(self.result_bitsize)

def registers(self) -> Sequence[Union[int, Sequence[int]]]:
if not isinstance(self.a_bitsize, int):
if is_symbolic(self.a_bitsize):
raise ValueError(f'Symbolic bitsize {self.a_bitsize} not supported')
if not isinstance(self.b_bitsize, int):
if is_symbolic(self.b_bitsize):
raise ValueError(f'Symbolic bitsize {self.b_bitsize} not supported')
if not isinstance(self.result_bitsize, int):
if is_symbolic(self.result_bitsize):
raise ValueError(f'Symbolic bitsize {self.result_bitsize} not supported')
return [2] * self.a_bitsize, [2] * self.b_bitsize, [2] * self.result_bitsize

Expand All @@ -85,25 +113,16 @@ def on_classical_vals(self, a: int, b: int, result: int) -> Dict[str, 'Classical
return {'a': a, 'b': b, 'result': result_out}

def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.CircuitDiagramInfo:
if not isinstance(self.a_bitsize, int):
if is_symbolic(self.a_bitsize):
raise ValueError(f'Symbolic bitsize {self.a_bitsize} not supported')
if not isinstance(self.b_bitsize, int):
if is_symbolic(self.b_bitsize):
raise ValueError(f'Symbolic bitsize {self.b_bitsize} not supported')
if not isinstance(self.result_bitsize, int):
if is_symbolic(self.result_bitsize):
raise ValueError(f'Symbolic bitsize {self.result_bitsize} not supported')
wire_symbols = ['a'] * self.a_bitsize + ['b'] * self.b_bitsize
wire_symbols += ['c-=a*b' if self.is_adjoint else 'c+=a*b'] * self.result_bitsize
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols)

def __pow__(self, power):
if power == 1:
return self
if power == -1:
return PlusEqualProduct(
self.a_bitsize, self.b_bitsize, self.result_bitsize, not self.is_adjoint
)
raise NotImplementedError("PlusEqualProduct.__pow__ defined only for powers +1/-1.")

def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
) -> List['qtn.Tensor']:
Expand Down
14 changes: 9 additions & 5 deletions qualtran/bloqs/basic_gates/z_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
DecomposeTypeError,
QAny,
QBit,
QDType,
Register,
Side,
Signature,
Expand All @@ -42,7 +43,6 @@
from qualtran.bloqs.bookkeeping import ArbitraryClifford
from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.drawing import Circle, directional_text_box, Text, TextBox, WireSymbol
from qualtran.simulation.classical_sim import ints_to_bits

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -384,12 +384,16 @@ def check(self, attribute, val):
if val >= 2**self.bitsize:
raise ValueError(f"`val` is too big for bitsize {self.bitsize}")

@cached_property
def dtype(self) -> QDType:
if self.bitsize == 1:
return QBit()
return QAny(self.bitsize)
anurudhp marked this conversation as resolved.
Show resolved Hide resolved

@cached_property
def signature(self) -> Signature:
side = Side.RIGHT if self.state else Side.LEFT
if self.bitsize == 1:
return Signature([Register('val', QBit(), side=side)])
return Signature([Register('val', QAny(self.bitsize), side=side)])
return Signature([Register('val', self.dtype, side=side)])

@staticmethod
def _build_composite_state(bb: 'BloqBuilder', bits: NDArray[np.uint8]) -> Dict[str, 'SoquetT']:
Expand All @@ -415,7 +419,7 @@ def _build_composite_effect(
def build_composite_bloq(self, bb: 'BloqBuilder', **val: 'SoquetT') -> Dict[str, 'SoquetT']:
if isinstance(self.bitsize, sympy.Expr):
raise DecomposeTypeError(f'Symbolic bitsize {self.bitsize} not supported')
bits = ints_to_bits(np.array([self.val]), w=self.bitsize)[0]
bits = np.asarray(self.dtype.to_bits(self.val))
if self.state:
assert not val
return self._build_composite_state(bb, bits)
Expand Down
Loading