From 3f9d5e1d98f1b46779c048543777ff9789373f54 Mon Sep 17 00:00:00 2001 From: Anurudh Peduri Date: Fri, 19 Jul 2024 03:26:58 -0400 Subject: [PATCH] Add `SignExtend` for two's complement sign extension --- dev_tools/autogenerate-bloqs-notebooks-v2.py | 5 + docs/bloqs/index.rst | 1 + qualtran/bloqs/arithmetic/__init__.py | 1 + .../bloqs/arithmetic/sign_extension.ipynb | 165 ++++++++++++++++++ qualtran/bloqs/arithmetic/sign_extension.py | 123 +++++++++++++ .../bloqs/arithmetic/sign_extension_test.py | 42 +++++ qualtran/serialization/resolver_dict.py | 2 + 7 files changed, 339 insertions(+) create mode 100644 qualtran/bloqs/arithmetic/sign_extension.ipynb create mode 100644 qualtran/bloqs/arithmetic/sign_extension.py create mode 100644 qualtran/bloqs/arithmetic/sign_extension_test.py diff --git a/dev_tools/autogenerate-bloqs-notebooks-v2.py b/dev_tools/autogenerate-bloqs-notebooks-v2.py index f436cb4479..8f230550bc 100644 --- a/dev_tools/autogenerate-bloqs-notebooks-v2.py +++ b/dev_tools/autogenerate-bloqs-notebooks-v2.py @@ -374,6 +374,11 @@ qualtran.bloqs.arithmetic.comparison._LEQ_DOC, ], ), + NotebookSpecV2( + title='Sign Extension', + module=qualtran.bloqs.arithmetic.sign_extension, + bloq_specs=[qualtran.bloqs.arithmetic.sign_extension._SIGN_EXTEND_DOC], + ), NotebookSpecV2( title='Sorting', module=qualtran.bloqs.arithmetic.sorting, diff --git a/docs/bloqs/index.rst b/docs/bloqs/index.rst index 2a5af961cb..30e606253f 100644 --- a/docs/bloqs/index.rst +++ b/docs/bloqs/index.rst @@ -64,6 +64,7 @@ Bloqs Library arithmetic/subtraction.ipynb arithmetic/multiplication.ipynb arithmetic/comparison.ipynb + arithmetic/sign_extension.ipynb arithmetic/sorting.ipynb arithmetic/conversions.ipynb arithmetic/permutation.ipynb diff --git a/qualtran/bloqs/arithmetic/__init__.py b/qualtran/bloqs/arithmetic/__init__.py index 00b016ed55..c5fdd8da44 100644 --- a/qualtran/bloqs/arithmetic/__init__.py +++ b/qualtran/bloqs/arithmetic/__init__.py @@ -35,6 +35,7 @@ SumOfSquares, ) from qualtran.bloqs.arithmetic.negate import Negate +from qualtran.bloqs.arithmetic.sign_extension import SignExtend from qualtran.bloqs.arithmetic.sorting import BitonicSort, Comparator from qualtran.bloqs.arithmetic.subtraction import Subtract diff --git a/qualtran/bloqs/arithmetic/sign_extension.ipynb b/qualtran/bloqs/arithmetic/sign_extension.ipynb new file mode 100644 index 0000000000..ddb41ef397 --- /dev/null +++ b/qualtran/bloqs/arithmetic/sign_extension.ipynb @@ -0,0 +1,165 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8a2ef164", + "metadata": { + "cq.autogen": "title_cell" + }, + "source": [ + "# Sign Extension" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76be8def", + "metadata": { + "cq.autogen": "top_imports" + }, + "outputs": [], + "source": [ + "from qualtran import Bloq, CompositeBloq, BloqBuilder, Signature, Register\n", + "from qualtran import QBit, QInt, QUInt, QAny\n", + "from qualtran.drawing import show_bloq, show_call_graph, show_counts_sigma\n", + "from typing import *\n", + "import numpy as np\n", + "import sympy\n", + "import cirq" + ] + }, + { + "cell_type": "markdown", + "id": "283c1b7e", + "metadata": { + "cq.autogen": "SignExtend.bloq_doc.md" + }, + "source": [ + "## `SignExtend`\n", + "Sign-Extend a value to a value of larger bitsize.\n", + "\n", + "Useful to implement arithmetic operations with differing operand bitsizes.\n", + "A sign extension copies the MSB into the new bits of the wider value. For\n", + "example: a 4-bit to 6-bit sign-extension of `1010` gives `111010`.\n", + "\n", + "\n", + "#### Parameters\n", + " - `inp_dtype`: input data type.\n", + " - `out_dtype`: output data type. must be same class as `inp_dtype`, and have larger bitsize. \n", + "\n", + "#### Registers\n", + " - `x`: the input register of type `inp_dtype`\n", + " - `y`: the output register of type `out_dtype`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6504d380", + "metadata": { + "cq.autogen": "SignExtend.bloq_doc.py" + }, + "outputs": [], + "source": [ + "from qualtran.bloqs.arithmetic import SignExtend" + ] + }, + { + "cell_type": "markdown", + "id": "9354fcdc", + "metadata": { + "cq.autogen": "SignExtend.example_instances.md" + }, + "source": [ + "### Example Instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30a40ff3", + "metadata": { + "cq.autogen": "SignExtend.sign_extend" + }, + "outputs": [], + "source": [ + "from qualtran import QInt\n", + "\n", + "sign_extend = SignExtend(QInt(8), QInt(16))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e43e994", + "metadata": { + "cq.autogen": "SignExtend.sign_extend_fxp" + }, + "outputs": [], + "source": [ + "sign_extend = SignExtend(QFxp(8, 4, signed=True), QFxp(16, 4, signed=True))" + ] + }, + { + "cell_type": "markdown", + "id": "ae58afdf", + "metadata": { + "cq.autogen": "SignExtend.graphical_signature.md" + }, + "source": [ + "#### Graphical Signature" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c96ebb1", + "metadata": { + "cq.autogen": "SignExtend.graphical_signature.py" + }, + "outputs": [], + "source": [ + "from qualtran.drawing import show_bloqs\n", + "show_bloqs([sign_extend, sign_extend_fxp],\n", + " ['`sign_extend`', '`sign_extend_fxp`'])" + ] + }, + { + "cell_type": "markdown", + "id": "38f234ca", + "metadata": { + "cq.autogen": "SignExtend.call_graph.md" + }, + "source": [ + "### Call Graph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "094ebe4f", + "metadata": { + "cq.autogen": "SignExtend.call_graph.py" + }, + "outputs": [], + "source": [ + "from qualtran.resource_counting.generalizers import ignore_split_join\n", + "sign_extend_g, sign_extend_sigma = sign_extend.call_graph(max_depth=1, generalizer=ignore_split_join)\n", + "show_call_graph(sign_extend_g)\n", + "show_counts_sigma(sign_extend_sigma)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/qualtran/bloqs/arithmetic/sign_extension.py b/qualtran/bloqs/arithmetic/sign_extension.py new file mode 100644 index 0000000000..4819f74af5 --- /dev/null +++ b/qualtran/bloqs/arithmetic/sign_extension.py @@ -0,0 +1,123 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import cached_property + +import numpy as np +from attrs import frozen + +from qualtran import ( + Bloq, + bloq_example, + BloqBuilder, + BloqDocSpec, + QDType, + QFxp, + Register, + Side, + Signature, + Soquet, + SoquetT, +) +from qualtran.bloqs.mcmt import MultiTargetCNOT +from qualtran.resource_counting import BloqCountT, SympySymbolAllocator +from qualtran.symbolics import is_symbolic + + +@frozen +class SignExtend(Bloq): + """Sign-Extend a value to a value of larger bitsize. + + Useful to implement arithmetic operations with differing operand bitsizes. + A sign extension copies the MSB into the new bits of the wider value. For + example: a 4-bit to 6-bit sign-extension of `1010` gives `111010`. + + + Args: + inp_dtype: input data type. + out_dtype: output data type. must be same class as `inp_dtype`, + and have larger bitsize. + + Registers: + x (LEFT): the input register of type `inp_dtype` + y (RIGHT): the output register of type `out_dtype` + """ + + inp_dtype: QDType + out_dtype: QDType + + def __attrs_post_init__(self): + if not isinstance(self.inp_dtype, type(self.out_dtype)): + raise ValueError( + f"Expected same input and output base types, got: {self.inp_dtype}, {self.out_dtype}" + ) + + if isinstance(self.out_dtype, QFxp): + assert isinstance(self.inp_dtype, QFxp) # checked above, but mypy does not realize + + if self.out_dtype.num_frac != self.inp_dtype.num_frac: + raise ValueError( + f"Expected same fractional sizes for QFxp, got: {self.inp_dtype.num_frac}, {self.out_dtype.num_frac}" + ) + + if not is_symbolic(self.extend_bitsize) and self.extend_bitsize <= 0: + raise ValueError( + f"input bitsize {self.inp_dtype.num_qubits} must be smaller than " + f"output bitsize {self.out_dtype.num_qubits}" + ) + + @cached_property + def signature(self) -> 'Signature': + return Signature( + [ + Register('x', self.inp_dtype, side=Side.LEFT), + Register('y', self.out_dtype, side=Side.RIGHT), + ] + ) + + @cached_property + def extend_bitsize(self): + return self.out_dtype.num_qubits - self.inp_dtype.num_qubits + + def build_composite_bloq(self, bb: 'BloqBuilder', x: 'Soquet') -> dict[str, 'SoquetT']: + extend_ys = bb.allocate(self.extend_bitsize) + xs = bb.split(x) + + xs[0], extend_ys = bb.add( + MultiTargetCNOT(self.extend_bitsize), control=xs[0], targets=extend_ys + ) + + extend_ys = bb.split(extend_ys) + y = bb.join(np.concatenate([extend_ys, xs]), dtype=self.out_dtype) + + return {'y': y} + + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> set['BloqCountT']: + return {(MultiTargetCNOT(self.extend_bitsize), 1)} + + +@bloq_example +def _sign_extend() -> SignExtend: + from qualtran import QInt + + sign_extend = SignExtend(QInt(8), QInt(16)) + return sign_extend + + +@bloq_example +def _sign_extend_fxp() -> SignExtend: + sign_extend = SignExtend(QFxp(8, 4, signed=True), QFxp(16, 4, signed=True)) + return sign_extend + + +_SIGN_EXTEND_DOC = BloqDocSpec(bloq_cls=SignExtend, examples=[_sign_extend, _sign_extend_fxp]) diff --git a/qualtran/bloqs/arithmetic/sign_extension_test.py b/qualtran/bloqs/arithmetic/sign_extension_test.py new file mode 100644 index 0000000000..09124da4f3 --- /dev/null +++ b/qualtran/bloqs/arithmetic/sign_extension_test.py @@ -0,0 +1,42 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from qualtran import BloqBuilder, QInt, QUInt +from qualtran.bloqs.arithmetic.sign_extension import _sign_extend, _sign_extend_fxp, SignExtend +from qualtran.bloqs.basic_gates import IntEffect, IntState + + +def test_examples(bloq_autotester): + bloq_autotester(_sign_extend) + bloq_autotester(_sign_extend_fxp) + + +@pytest.mark.parametrize("l, r", [(2, 4)]) +def test_sign_extend_tensor(l: int, r: int): + bloq = SignExtend(QInt(l), QInt(r)) + + def _as_unsigned(num: int, bitsize: int): + # TODO remove this once IntState supports signed values + return QUInt(bitsize).from_bits(QInt(bitsize).to_bits(num)) + + for x in QInt(l).get_classical_domain(): + bb = BloqBuilder() + qx = bb.add(IntState(_as_unsigned(x, l), l)) + qx = bb.add(bloq, x=qx) + bb.add(IntEffect(_as_unsigned(x, r), r), val=qx) + cbloq = bb.finalize() + + np.testing.assert_allclose(cbloq.tensor_contract(), 1) diff --git a/qualtran/serialization/resolver_dict.py b/qualtran/serialization/resolver_dict.py index b7829e3c60..a4d31490f2 100644 --- a/qualtran/serialization/resolver_dict.py +++ b/qualtran/serialization/resolver_dict.py @@ -22,6 +22,7 @@ import qualtran.bloqs.arithmetic.multiplication import qualtran.bloqs.arithmetic.negate import qualtran.bloqs.arithmetic.permutation +import qualtran.bloqs.arithmetic.sign_extension import qualtran.bloqs.arithmetic.sorting import qualtran.bloqs.basic_gates.cnot import qualtran.bloqs.basic_gates.hadamard @@ -170,6 +171,7 @@ "qualtran.bloqs.arithmetic.negate.Negate": qualtran.bloqs.arithmetic.negate.Negate, "qualtran.bloqs.arithmetic.permutation.Permutation": qualtran.bloqs.arithmetic.permutation.Permutation, "qualtran.bloqs.arithmetic.permutation.PermutationCycle": qualtran.bloqs.arithmetic.permutation.PermutationCycle, + "qualtran.bloqs.arithmetic.sign_extension.SignExtend": qualtran.bloqs.arithmetic.sign_extension.SignExtend, "qualtran.bloqs.arithmetic.sorting.BitonicMerge": qualtran.bloqs.arithmetic.sorting.BitonicMerge, "qualtran.bloqs.arithmetic.sorting.BitonicSort": qualtran.bloqs.arithmetic.sorting.BitonicSort, "qualtran.bloqs.arithmetic.sorting.Comparator": qualtran.bloqs.arithmetic.sorting.Comparator,