Skip to content

Commit

Permalink
Refactor phase gradient bloq dtypes (#1191)
Browse files Browse the repository at this point in the history
* add dtype properties, cleanup signatures

* split unit tests

* dedup tests

* split more tests

* use `signature.n_qubits`

---------

Co-authored-by: Matthew Harrigan <[email protected]>
  • Loading branch information
anurudhp and mpharrigan authored Jul 26, 2024
1 parent 8b94c3f commit 16b5378
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 88 deletions.
53 changes: 20 additions & 33 deletions qualtran/bloqs/rotations/phase_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
if TYPE_CHECKING:
import quimb.tensor as qtn

from qualtran import Bloq
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import SymbolicFloat, SymbolicInt
Expand Down Expand Up @@ -99,11 +98,15 @@ class PhaseGradientUnitary(GateWithRegisters):
@cached_property
def signature(self) -> 'Signature':
return (
Signature.build_from_dtypes(ctrl=QBit(), phase_grad=QFxp(self.bitsize, self.bitsize))
Signature.build_from_dtypes(ctrl=QBit(), phase_grad=self.phase_dtype)
if self.is_controlled
else Signature.build_from_dtypes(phase_grad=QFxp(self.bitsize, self.bitsize))
else Signature.build_from_dtypes(phase_grad=self.phase_dtype)
)

@property
def phase_dtype(self) -> QFxp:
return QFxp(self.bitsize, self.bitsize)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid] # type: ignore[type-var]
) -> Iterator[cirq.OP_TREE]:
Expand All @@ -123,9 +126,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs) -> cirq.Circ
def __pow__(self, power):
if power == 1:
return self
return PhaseGradientUnitary(
self.bitsize, self.exponent * power, self.is_controlled, self.eps
)
return attrs.evolve(self, exponent=self.exponent * power)

def build_call_graph(self, ssa: SympySymbolAllocator) -> Set['BloqCountT']:
gate = CZPowGate if self.is_controlled else ZPowGate
Expand Down Expand Up @@ -191,9 +192,11 @@ class PhaseGradientState(GateWithRegisters):

@cached_property
def signature(self) -> 'Signature':
return Signature(
[Register('phase_grad', QFxp(self.bitsize, self.bitsize), side=Side.RIGHT)]
)
return Signature([Register('phase_grad', self.phase_dtype, side=Side.RIGHT)])

@property
def phase_dtype(self) -> QFxp:
return QFxp(self.bitsize, self.bitsize)

def decompose_from_registers(
self, *, context: cirq.DecompositionContext, **quregs: NDArray[cirq.Qid]
Expand Down Expand Up @@ -264,18 +267,15 @@ def pretty_name(self) -> str:
@cached_property
def signature(self) -> 'Signature':
return (
Signature.build_from_dtypes(
ctrl=QBit(),
x=QFxp(self.x_bitsize, self.x_bitsize, signed=False),
phase_grad=QFxp(self.phase_bitsize, self.phase_bitsize, signed=False),
)
Signature.build_from_dtypes(ctrl=QBit(), x=self.x_dtype, phase_grad=self.phase_dtype)
if self.controlled_by is not None
else Signature.build_from_dtypes(
x=QFxp(self.x_bitsize, self.x_bitsize, signed=False),
phase_grad=QFxp(self.phase_bitsize, self.phase_bitsize, signed=False),
)
else Signature.build_from_dtypes(x=self.x_dtype, phase_grad=self.phase_dtype)
)

@cached_property
def x_dtype(self) -> QFxp:
return QFxp(self.x_bitsize, self.x_bitsize, signed=False)

@cached_property
def phase_dtype(self) -> QFxp:
return QFxp(self.phase_bitsize, self.phase_bitsize, signed=False)
Expand Down Expand Up @@ -331,21 +331,8 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:

return {(Toffoli(), num_toffoli)}

def adjoint(self) -> 'Bloq':
return AddIntoPhaseGrad(
self.x_bitsize,
self.phase_bitsize,
self.right_shift,
sign=-1 * self.sign,
controlled_by=self.controlled_by,
)

def __pow__(self, power):
if power == 1:
return self
if power == -1:
return self.adjoint()
raise NotImplementedError("AddIntoPhaseGrad.__pow__ defined only for powers +1/-1.")
def adjoint(self) -> 'AddIntoPhaseGrad':
return attrs.evolve(self, sign=-self.sign)

def my_tensors(
self, incoming: Dict[str, 'ConnectionT'], outgoing: Dict[str, 'ConnectionT']
Expand Down
137 changes: 82 additions & 55 deletions qualtran/bloqs/rotations/phase_gradient_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import Optional

import cirq
import numpy as np
Expand Down Expand Up @@ -77,95 +77,122 @@ def test_phase_gradient_gate(n: int, exponent, controlled):
assert np.allclose(cirq.unitary(bloq), cirq.unitary(cirq_gate), atol=eps)


def test_add_into_phase_grad():
@pytest.mark.parametrize("controlled_by", [None, 0, 1])
def test_add_into_phase_grad_classical_sim(controlled_by: Optional[int]):
from qualtran.bloqs.rotations.phase_gradient import _fxp

x_bit, phase_bit = 4, 7
bloq = AddIntoPhaseGrad(x_bit, phase_bit)
basis_map: Dict[int, int] = {}
bloq = AddIntoPhaseGrad(x_bit, phase_bit, controlled_by=controlled_by)

for x in range(2**x_bit):
for phase_grad in range(2**phase_bit):
phase_fxp = _fxp(phase_grad / 2**phase_bit, phase_bit)
x_fxp = _fxp(x / 2**x_bit, x_bit).like(phase_fxp)
phase_grad_out = int((phase_fxp + x_fxp).astype(float) * 2**phase_bit)
# Test Bloq style classical simulation.
assert bloq.call_classically(x=x, phase_grad=phase_grad) == (x, phase_grad_out)
# Prepare basis states mapping for cirq-style simulation.
input_state = int(f'{x:0{x_bit}b}' + f'{phase_grad:0{phase_bit}b}', 2)
output_state = int(f'{x:0{x_bit}b}' + f'{phase_grad_out:0{phase_bit}b}', 2)
basis_map[input_state] = output_state
# Test cirq style simulation.
num_bits = x_bit + phase_bit
assert len(basis_map) == len(set(basis_map.values()))
circuit = cirq.Circuit(bloq.on(*cirq.LineQubit.range(num_bits)))
cirq.testing.assert_equivalent_computational_basis_map(basis_map, circuit)
((toffoli, n),) = bloq.bloq_counts().items()
assert bloq.t_complexity() == n * toffoli.t_complexity()


@pytest.mark.parametrize('controlled', [0, 1])
def test_add_into_phase_grad_controlled(controlled: int):
if controlled_by is None:
assert bloq.call_classically(x=x, phase_grad=phase_grad) == (x, phase_grad_out)
else:
for control in [0, 1]:
phase_grad_out_ctrld = (
phase_grad_out if controlled_by == control else phase_grad
)
assert bloq.call_classically(ctrl=control, x=x, phase_grad=phase_grad) == (
control,
x,
phase_grad_out_ctrld,
)


@pytest.mark.parametrize("controlled_by", [None, 0, 1])
def test_add_into_phase_grad_unitary(controlled_by: Optional[int]):
from qualtran.bloqs.rotations.phase_gradient import _fxp

x_bit, phase_bit = 4, 7
bloq = AddIntoPhaseGrad(x_bit, phase_bit, controlled_by=controlled)
basis_map: Dict[int, int] = {}
num_bits = 1 + x_bit + phase_bit
bloq = AddIntoPhaseGrad(x_bit, phase_bit, controlled_by=controlled_by)

# compute expected unitary manually
num_bits = x_bit + phase_bit + (1 if controlled_by is not None else 0)
expected_unitary = np.zeros((2**num_bits, 2**num_bits))
for control in range(2):
for x in range(2**x_bit):
for phase_grad in range(2**phase_bit):
phase_fxp = _fxp(phase_grad / 2**phase_bit, phase_bit)
x_fxp = _fxp(x / 2**x_bit, x_bit).like(phase_fxp)
if control == controlled:
phase_grad_out = int((phase_fxp + x_fxp).astype(float) * 2**phase_bit)
else:
phase_grad_out = phase_grad
# Test Bloq style classical simulation.
assert bloq.call_classically(ctrl=control, x=x, phase_grad=phase_grad) == (
control,
x,
phase_grad_out,
)
# Prepare basis states mapping for cirq-style simulation.
input_state = int(
f'{control}' + f'{x:0{x_bit}b}' + f'{phase_grad:0{phase_bit}b}', 2
)
output_state = int(
f'{control}' + f'{x:0{x_bit}b}' + f'{phase_grad_out:0{phase_bit}b}', 2
)
basis_map[input_state] = output_state

for x in range(2**x_bit):
for phase_grad in range(2**phase_bit):
phase_fxp = _fxp(phase_grad / 2**phase_bit, phase_bit)
x_fxp = _fxp(x / 2**x_bit, x_bit).like(phase_fxp)
phase_grad_out = int((phase_fxp + x_fxp).astype(float) * 2**phase_bit)

if controlled_by is None:
input_state = int(f'{x:0{x_bit}b}' + f'{phase_grad:0{phase_bit}b}', 2)
output_state = int(f'{x:0{x_bit}b}' + f'{phase_grad_out:0{phase_bit}b}', 2)
expected_unitary[output_state, input_state] = 1
# Test cirq style simulation.
assert len(basis_map) == len(set(basis_map.values()))
else:
for control in [0, 1]:
phase_grad_out_ctrld = (
phase_grad_out if controlled_by == control else phase_grad
)
input_state = int(
f'{control}' + f'{x:0{x_bit}b}' + f'{phase_grad:0{phase_bit}b}', 2
)
output_state = int(
f'{control}' + f'{x:0{x_bit}b}' + f'{phase_grad_out_ctrld:0{phase_bit}b}', 2
)
expected_unitary[output_state, input_state] = 1

# verify bloq unitary
circuit = cirq.Circuit(bloq.on(*cirq.LineQubit.range(num_bits)))
np.testing.assert_allclose(circuit.unitary(), expected_unitary, atol=1e-8)


@pytest.mark.slow
def test_add_into_phase_grad_t_complexity():
x_bit, phase_bit = 4, 7
bloq = AddIntoPhaseGrad(x_bit, phase_bit)
((toffoli, n),) = bloq.bloq_counts().items()
assert bloq.t_complexity() == n * toffoli.t_complexity()


_ADD_SCALED_VAL_INTO_PHASE_REG_EXAMPLES: list[AddScaledValIntoPhaseReg] = [
AddScaledValIntoPhaseReg.from_bitsize(4, 7, 0.123, 6),
AddScaledValIntoPhaseReg.from_bitsize(2, 8, 1.3868682, 8),
AddScaledValIntoPhaseReg.from_bitsize(4, 9, -9.0949456, 5),
AddScaledValIntoPhaseReg.from_bitsize(6, 4, 2.5, 2),
AddScaledValIntoPhaseReg(QFxp(4, 0, signed=False), 4, 1.3868682, QFxp(8, 7, signed=False)),
]


@pytest.mark.parametrize(
'bloq',
[
AddScaledValIntoPhaseReg.from_bitsize(4, 7, 0.123, 6),
AddScaledValIntoPhaseReg.from_bitsize(2, 8, 1.3868682, 8),
AddScaledValIntoPhaseReg.from_bitsize(4, 9, -19.0949456, 5),
AddScaledValIntoPhaseReg.from_bitsize(6, 4, 2.5, 2),
AddScaledValIntoPhaseReg(QFxp(4, 0, signed=False), 4, 1.3868682, QFxp(8, 7, signed=False)),
pytest.param(bloq, marks=pytest.mark.slow if bloq.signature.n_qubits() > 10 else ())
for bloq in _ADD_SCALED_VAL_INTO_PHASE_REG_EXAMPLES
],
)
def test_add_scaled_val_into_phase_reg(bloq):
def test_add_scaled_val_into_phase_reg_classical_sim(bloq: AddScaledValIntoPhaseReg):
cbloq = bloq.decompose_bloq()
for x in range(2**bloq.x_dtype.bitsize):
for phase_grad in range(2**bloq.phase_bitsize):
d = {'x': x, 'phase_grad': phase_grad}
c1 = bloq.on_classical_vals(**d)
c2 = cbloq.on_classical_vals(**d)
assert c1 == c2, f'{d=}, {c1=}, {c2=}'


@pytest.mark.parametrize(
'bloq',
[
pytest.param(bloq, marks=pytest.mark.slow if bloq.signature.n_qubits() > 12 else ())
for bloq in _ADD_SCALED_VAL_INTO_PHASE_REG_EXAMPLES
],
)
def test_add_scaled_val_into_phase_reg_unitary(bloq: AddScaledValIntoPhaseReg):
bloq_unitary = cirq.unitary(bloq)
op = GateHelper(bloq).operation
circuit = cirq.Circuit(cirq.I.on_each(*op.qubits), cirq.decompose_once(op))
decomposed_unitary = circuit.unitary(qubit_order=op.qubits)
np.testing.assert_allclose(bloq_unitary, decomposed_unitary)


@pytest.mark.parametrize('bloq', _ADD_SCALED_VAL_INTO_PHASE_REG_EXAMPLES)
def test_add_scaled_val_into_phase_reg_t_complexity(bloq):
((add_into_phase, n),) = bloq.bloq_counts().items()
assert bloq.t_complexity() == n * add_into_phase.t_complexity()

Expand Down

0 comments on commit 16b5378

Please sign in to comment.