From 558c7a8d6d64297cd619226a33bc6d98ced203ee Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Tue, 3 Sep 2024 12:41:13 -0700 Subject: [PATCH 1/3] Add support for sampling from Quokka devices - This sampler converts circuits to QASM and then sends them to a quokka endpoint for simulation. - Any parameterized circuits are resolved and sent point by point to the device. --- unitary/alpha/quokka_sampler.py | 161 +++++++++++++++++++++++++++ unitary/alpha/quokka_sampler_test.py | 108 ++++++++++++++++++ 2 files changed, 269 insertions(+) create mode 100644 unitary/alpha/quokka_sampler.py create mode 100644 unitary/alpha/quokka_sampler_test.py diff --git a/unitary/alpha/quokka_sampler.py b/unitary/alpha/quokka_sampler.py new file mode 100644 index 00000000..f0fe9aa1 --- /dev/null +++ b/unitary/alpha/quokka_sampler.py @@ -0,0 +1,161 @@ +# Copyright 2024 The Unitary Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Simulation using a "Quokka" device.""" + +from typing import Any, Callable, Dict, Optional, Sequence + +import cirq +import numpy as np +import json + +_REQUEST_ENDPOINT = "http://{}.quokkacomputing.com/qsim/qasm" +_DEFAULT_QUOKKA_NAME = "quokka1" + +JSON_TYPE = Dict[str, Any] +_RESULT_KEY = "result" +_ERROR_CODE_KEY = "error_code" +_RESULT_KEY = "result" +_SCRIPT_KEY = "script" +_REPETITION_KEY = "count" + + +class QuokkaSampler(cirq.Sampler): + """Sampler for querying a Quokka quantum simulation device. + + See https://www.quokkacomputing.com/ for more information.a + + Args: + name: name of your quokka device + endpoint: HTTP url endpoint to post queries to. + post_function: used only for testing to override default + behavior to connect to internet URLs. + """ + + def __init__( + self, + name: str = _DEFAULT_QUOKKA_NAME, + endpoint: Optional[str] = None, + post_function: Optional[Callable[[JSON_TYPE], JSON_TYPE]] = None, + ): + self.quokka_name = name + self.endpoint = endpoint + self.post_function = post_function + + if self.endpoint is None: + self.endpoint = _REQUEST_ENDPOINT.format(self.quokka_name) + if self.post_function is None: + self.post_function = self._post + + def _post(self, json_request: JSON_TYPE) -> JSON_TYPE: + """Sends POST queries to quokka endpoint.""" + try: + import requests + except ImportError as e: + print( + "Please install requests library to use Quokka" + "(e.g. pip install requests)" + ) + raise e + result = requests.post(self.endpoint, json=json_request) + return json.loads(result.content) + + def run_sweep( + self, + program: "cirq.AbstractCircuit", + params: "cirq.Sweepable", + repetitions: int = 1, + ) -> Sequence["cirq.Result"]: + """Samples from the given Circuit. + + This allows for sweeping over different parameter values, + unlike the `run` method. The `params` argument will provide a + mapping from `sympy.Symbol`s used within the circuit to a set of + values. Unlike the `run` method, which specifies a single + mapping from symbol to value, this method allows a "sweep" of + values. This allows a user to specify execution of a family of + related circuits efficiently. + + Args: + program: The circuit to sample from. + params: Parameters to run with the program. + repetitions: The number of times to sample. + + Returns: + Result list for this run; one for each possible parameter resolver. + """ + rtn_results = [] + qubits = sorted(program.all_qubits()) + measure_keys = {} + register_names = {} + meas_i = 0 + + # Find all measurements in the circuit and record keys + # so that we can later translate between circuit and QASM registers. + for op in program.all_operations(): + if isinstance(op.gate, cirq.MeasurementGate): + key = cirq.measurement_key_name(op) + if key in measure_keys: + print( + "Warning! Keys can only be measured once in Quokka simulator" + ) + print("Key {key} will only contain the last measured value") + measure_keys[key] = op.qubits + if cirq.QasmOutput.valid_id_re.match(key): + register_names[key] = f"m_{key}" + else: + register_names[key] = f"m{meas_i}" + meas_i += 1 + + # QASM 2.0 does not support parameter sweeps, + # so resolve any symbolic functions to a concrete circuit. + for param_resolver in cirq.to_resolvers(params): + circuit = cirq.resolve_parameters(program, param_resolver) + qasm = cirq.qasm(circuit) + + # Hack to change sqrt-X gates into rx 0.5 gates: + # Since quokka does not support sx or sxdg gates + qasm = qasm.replace("\nsx ", "\nrx(pi*0.5) ").replace( + "\nsxdg ", "\nrx(pi*-0.5) " + ) + + # Send data to quokka endpoint + data = {_SCRIPT_KEY: qasm, _REPETITION_KEY: repetitions} + json_results = self.post_function(data) + + if _ERROR_CODE_KEY in json_results and json_results[_ERROR_CODE_KEY] != 0: + raise RuntimeError(f"Quokka returned an error: {json_results}") + + if _RESULT_KEY not in json_results: + raise RuntimeError(f"Quokka did not return any results: {json_results}") + + # Associate results from json response to measurement keys. + result_measurements = {} + for key in measure_keys: + register_name = register_names[key] + if register_name not in json_results[_RESULT_KEY]: + raise RuntimeError( + f"Quokka did not measure key {key}: {json_results}" + ) + result_measurements[key] = np.asarray( + json_results[_RESULT_KEY][register_name], dtype=np.dtype("int8") + ) + + # Append measurements to eventual result. + rtn_results.append( + cirq.ResultDict( + params=param_resolver, + measurements=result_measurements, + ) + ) + return rtn_results diff --git a/unitary/alpha/quokka_sampler_test.py b/unitary/alpha/quokka_sampler_test.py new file mode 100644 index 00000000..05674c2d --- /dev/null +++ b/unitary/alpha/quokka_sampler_test.py @@ -0,0 +1,108 @@ +# Copyright 2024 The Unitary Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable +import pytest +import cirq +import sympy + +import unitary.alpha.quokka_sampler as quokka_sampler + +# Qubits for testing +_Q = cirq.LineQubit.range(10) + + +class FakeQuokkaEndpoint: + def __init__(self, responses: Iterable[quokka_sampler.JSON_TYPE]): + self.responses = list(responses) + self.requests = [] + + def _post(self, json_request: quokka_sampler.JSON_TYPE) -> quokka_sampler.JSON_TYPE: + self.requests.append(json_request) + return self.responses.pop() + + +@pytest.mark.parametrize( + "circuit,json_result", + [ + ( + cirq.Circuit(cirq.X(_Q[0]), cirq.measure(_Q[0], key="mmm")), + {"m_mmm": [[1], [1], [1], [1], [1]]}, + ), + ( + cirq.Circuit(cirq.X(_Q[0]), cirq.measure(_Q[0])), + {"m0": [[1], [1], [1], [1], [1]]}, + ), + ( + cirq.Circuit( + cirq.X(_Q[0]), cirq.X(_Q[1]), cirq.measure(_Q[0]), cirq.measure(_Q[1]) + ), + {"m0": [[1], [1], [1], [1], [1]], "m1": [[1], [1], [1], [1], [1]]}, + ), + ( + cirq.Circuit( + cirq.X(_Q[0]), + cirq.CNOT(_Q[0], _Q[1]), + cirq.measure(_Q[0]), + cirq.measure(_Q[1]), + ), + {"m0": [[1], [1], [1], [1], [1]], "m1": [[1], [1], [1], [1], [1]]}, + ), + ( + cirq.Circuit( + cirq.X(_Q[0]), + cirq.CNOT(_Q[0], _Q[1]), + cirq.measure(_Q[0], _Q[1], key="m2"), + ), + {"m_m2": [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]}, + ), + ], +) +def test_quokka_deterministic_examples(circuit, json_result): + sim = cirq.Simulator() + expected_results = sim.run(circuit, repetitions=5) + json_response = {"error": "no error", "error_code": 0, "result": json_result} + endpoint = FakeQuokkaEndpoint([json_response]) + quokka = quokka_sampler.QuokkaSampler( + name="test_mctesterface", post_function=endpoint._post + ) + quokka_results = quokka.run(circuit, repetitions=5) + assert quokka_results == expected_results + + +def test_quokka_run_sweep(): + sim = cirq.Simulator() + circuit = cirq.Circuit( + cirq.X(_Q[0]), + cirq.X(_Q[1]) ** sympy.Symbol("X_1"), + cirq.measure(_Q[0], _Q[1], key="m2"), + ) + sweep = cirq.Points("X_1", [0, 1]) + expected_results = sim.run_sweep(circuit, sweep, repetitions=5) + json_response = { + "error": "no error", + "error_code": 0, + "result": {"m_m2": [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]}, + } + json_response2 = { + "error": "no error", + "error_code": 0, + "result": {"m_m2": [[1, 0], [1, 0], [1, 0], [1, 0], [1, 0]]}, + } + endpoint = FakeQuokkaEndpoint([json_response, json_response2]) + quokka = quokka_sampler.QuokkaSampler( + name="test_mctesterface", post_function=endpoint._post + ) + quokka_results = quokka.run_sweep(circuit, sweep, repetitions=5) + assert quokka_results[0] == expected_results[0] From edf87c822084e9f5d7885d31d93eabfba1b12916 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Mon, 9 Sep 2024 12:05:45 -0700 Subject: [PATCH 2/3] Address comments --- unitary/alpha/quokka_sampler.py | 61 ++++++++++++---------------- unitary/alpha/quokka_sampler_test.py | 12 +++--- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/unitary/alpha/quokka_sampler.py b/unitary/alpha/quokka_sampler.py index f0fe9aa1..0222e66c 100644 --- a/unitary/alpha/quokka_sampler.py +++ b/unitary/alpha/quokka_sampler.py @@ -14,6 +14,7 @@ """Simulation using a "Quokka" device.""" from typing import Any, Callable, Dict, Optional, Sequence +import warnings import cirq import numpy as np @@ -30,6 +31,22 @@ _REPETITION_KEY = "count" +class QuokkaPostEndpoint: + def __init__(self, name=_DEFAULT_QUOKKA_NAME): + self._endpoint = _REQUEST_ENDPOINT.format(name) + + def __call__(self, json_request: JSON_TYPE) -> JSON_TYPE: + try: + import requests + except ImportError as e: + raise ImportError( + "Please install requests library to use Quokka" + "(e.g. pip install requests)" + ) from e + result = requests.post(self._endpoint, json=json_request) + return json.loads(result.content) + + class QuokkaSampler(cirq.Sampler): """Sampler for querying a Quokka quantum simulation device. @@ -37,45 +54,23 @@ class QuokkaSampler(cirq.Sampler): Args: name: name of your quokka device - endpoint: HTTP url endpoint to post queries to. - post_function: used only for testing to override default + post: used only for testing to override default behavior to connect to internet URLs. """ def __init__( self, name: str = _DEFAULT_QUOKKA_NAME, - endpoint: Optional[str] = None, - post_function: Optional[Callable[[JSON_TYPE], JSON_TYPE]] = None, + post: Optional[Callable[[JSON_TYPE], JSON_TYPE]] = None, ): - self.quokka_name = name - self.endpoint = endpoint - self.post_function = post_function - - if self.endpoint is None: - self.endpoint = _REQUEST_ENDPOINT.format(self.quokka_name) - if self.post_function is None: - self.post_function = self._post - - def _post(self, json_request: JSON_TYPE) -> JSON_TYPE: - """Sends POST queries to quokka endpoint.""" - try: - import requests - except ImportError as e: - print( - "Please install requests library to use Quokka" - "(e.g. pip install requests)" - ) - raise e - result = requests.post(self.endpoint, json=json_request) - return json.loads(result.content) + self._post = post or QuokkaPostEndpoint(name) def run_sweep( self, - program: "cirq.AbstractCircuit", - params: "cirq.Sweepable", + program: cirq.AbstractCircuit, + params: cirq.Sweepable, repetitions: int = 1, - ) -> Sequence["cirq.Result"]: + ) -> Sequence[cirq.Result]: """Samples from the given Circuit. This allows for sweeping over different parameter values, @@ -106,10 +101,10 @@ def run_sweep( if isinstance(op.gate, cirq.MeasurementGate): key = cirq.measurement_key_name(op) if key in measure_keys: - print( + warnings.warn( "Warning! Keys can only be measured once in Quokka simulator" + f"Key {key} will only contain the last measured value" ) - print("Key {key} will only contain the last measured value") measure_keys[key] = op.qubits if cirq.QasmOutput.valid_id_re.match(key): register_names[key] = f"m_{key}" @@ -131,7 +126,7 @@ def run_sweep( # Send data to quokka endpoint data = {_SCRIPT_KEY: qasm, _REPETITION_KEY: repetitions} - json_results = self.post_function(data) + json_results = self._post(data) if _ERROR_CODE_KEY in json_results and json_results[_ERROR_CODE_KEY] != 0: raise RuntimeError(f"Quokka returned an error: {json_results}") @@ -144,9 +139,7 @@ def run_sweep( for key in measure_keys: register_name = register_names[key] if register_name not in json_results[_RESULT_KEY]: - raise RuntimeError( - f"Quokka did not measure key {key}: {json_results}" - ) + raise KeyError(f"Quokka did not measure key {key}: {json_results}") result_measurements[key] = np.asarray( json_results[_RESULT_KEY][register_name], dtype=np.dtype("int8") ) diff --git a/unitary/alpha/quokka_sampler_test.py b/unitary/alpha/quokka_sampler_test.py index 05674c2d..93a6a0a2 100644 --- a/unitary/alpha/quokka_sampler_test.py +++ b/unitary/alpha/quokka_sampler_test.py @@ -24,11 +24,13 @@ class FakeQuokkaEndpoint: - def __init__(self, responses: Iterable[quokka_sampler.JSON_TYPE]): + def __init__(self, *responses: quokka_sampler.JSON_TYPE): self.responses = list(responses) self.requests = [] - def _post(self, json_request: quokka_sampler.JSON_TYPE) -> quokka_sampler.JSON_TYPE: + def __call__( + self, json_request: quokka_sampler.JSON_TYPE + ) -> quokka_sampler.JSON_TYPE: self.requests.append(json_request) return self.responses.pop() @@ -73,9 +75,8 @@ def test_quokka_deterministic_examples(circuit, json_result): sim = cirq.Simulator() expected_results = sim.run(circuit, repetitions=5) json_response = {"error": "no error", "error_code": 0, "result": json_result} - endpoint = FakeQuokkaEndpoint([json_response]) quokka = quokka_sampler.QuokkaSampler( - name="test_mctesterface", post_function=endpoint._post + name="test_mctesterface", post=FakeQuokkaEndpoint(json_response) ) quokka_results = quokka.run(circuit, repetitions=5) assert quokka_results == expected_results @@ -100,9 +101,8 @@ def test_quokka_run_sweep(): "error_code": 0, "result": {"m_m2": [[1, 0], [1, 0], [1, 0], [1, 0], [1, 0]]}, } - endpoint = FakeQuokkaEndpoint([json_response, json_response2]) quokka = quokka_sampler.QuokkaSampler( - name="test_mctesterface", post_function=endpoint._post + name="test_mctesterface", post=FakeQuokkaEndpoint(json_response, json_response2) ) quokka_results = quokka.run_sweep(circuit, sweep, repetitions=5) assert quokka_results[0] == expected_results[0] From f90606ee05b8477b0c28e85b1b056b4984421c1a Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Mon, 9 Sep 2024 13:04:02 -0700 Subject: [PATCH 3/3] Remove unneeded import --- unitary/alpha/quokka_sampler_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unitary/alpha/quokka_sampler_test.py b/unitary/alpha/quokka_sampler_test.py index 93a6a0a2..f4f81763 100644 --- a/unitary/alpha/quokka_sampler_test.py +++ b/unitary/alpha/quokka_sampler_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable import pytest import cirq import sympy