diff --git a/unitary/alpha/quokka_sampler.py b/unitary/alpha/quokka_sampler.py index f0fe9aa1..0222e66c 100644 --- a/unitary/alpha/quokka_sampler.py +++ b/unitary/alpha/quokka_sampler.py @@ -14,6 +14,7 @@ """Simulation using a "Quokka" device.""" from typing import Any, Callable, Dict, Optional, Sequence +import warnings import cirq import numpy as np @@ -30,6 +31,22 @@ _REPETITION_KEY = "count" +class QuokkaPostEndpoint: + def __init__(self, name=_DEFAULT_QUOKKA_NAME): + self._endpoint = _REQUEST_ENDPOINT.format(name) + + def __call__(self, json_request: JSON_TYPE) -> JSON_TYPE: + try: + import requests + except ImportError as e: + raise ImportError( + "Please install requests library to use Quokka" + "(e.g. pip install requests)" + ) from e + result = requests.post(self._endpoint, json=json_request) + return json.loads(result.content) + + class QuokkaSampler(cirq.Sampler): """Sampler for querying a Quokka quantum simulation device. @@ -37,45 +54,23 @@ class QuokkaSampler(cirq.Sampler): Args: name: name of your quokka device - endpoint: HTTP url endpoint to post queries to. - post_function: used only for testing to override default + post: used only for testing to override default behavior to connect to internet URLs. """ def __init__( self, name: str = _DEFAULT_QUOKKA_NAME, - endpoint: Optional[str] = None, - post_function: Optional[Callable[[JSON_TYPE], JSON_TYPE]] = None, + post: Optional[Callable[[JSON_TYPE], JSON_TYPE]] = None, ): - self.quokka_name = name - self.endpoint = endpoint - self.post_function = post_function - - if self.endpoint is None: - self.endpoint = _REQUEST_ENDPOINT.format(self.quokka_name) - if self.post_function is None: - self.post_function = self._post - - def _post(self, json_request: JSON_TYPE) -> JSON_TYPE: - """Sends POST queries to quokka endpoint.""" - try: - import requests - except ImportError as e: - print( - "Please install requests library to use Quokka" - "(e.g. pip install requests)" - ) - raise e - result = requests.post(self.endpoint, json=json_request) - return json.loads(result.content) + self._post = post or QuokkaPostEndpoint(name) def run_sweep( self, - program: "cirq.AbstractCircuit", - params: "cirq.Sweepable", + program: cirq.AbstractCircuit, + params: cirq.Sweepable, repetitions: int = 1, - ) -> Sequence["cirq.Result"]: + ) -> Sequence[cirq.Result]: """Samples from the given Circuit. This allows for sweeping over different parameter values, @@ -106,10 +101,10 @@ def run_sweep( if isinstance(op.gate, cirq.MeasurementGate): key = cirq.measurement_key_name(op) if key in measure_keys: - print( + warnings.warn( "Warning! Keys can only be measured once in Quokka simulator" + f"Key {key} will only contain the last measured value" ) - print("Key {key} will only contain the last measured value") measure_keys[key] = op.qubits if cirq.QasmOutput.valid_id_re.match(key): register_names[key] = f"m_{key}" @@ -131,7 +126,7 @@ def run_sweep( # Send data to quokka endpoint data = {_SCRIPT_KEY: qasm, _REPETITION_KEY: repetitions} - json_results = self.post_function(data) + json_results = self._post(data) if _ERROR_CODE_KEY in json_results and json_results[_ERROR_CODE_KEY] != 0: raise RuntimeError(f"Quokka returned an error: {json_results}") @@ -144,9 +139,7 @@ def run_sweep( for key in measure_keys: register_name = register_names[key] if register_name not in json_results[_RESULT_KEY]: - raise RuntimeError( - f"Quokka did not measure key {key}: {json_results}" - ) + raise KeyError(f"Quokka did not measure key {key}: {json_results}") result_measurements[key] = np.asarray( json_results[_RESULT_KEY][register_name], dtype=np.dtype("int8") ) diff --git a/unitary/alpha/quokka_sampler_test.py b/unitary/alpha/quokka_sampler_test.py index 05674c2d..93a6a0a2 100644 --- a/unitary/alpha/quokka_sampler_test.py +++ b/unitary/alpha/quokka_sampler_test.py @@ -24,11 +24,13 @@ class FakeQuokkaEndpoint: - def __init__(self, responses: Iterable[quokka_sampler.JSON_TYPE]): + def __init__(self, *responses: quokka_sampler.JSON_TYPE): self.responses = list(responses) self.requests = [] - def _post(self, json_request: quokka_sampler.JSON_TYPE) -> quokka_sampler.JSON_TYPE: + def __call__( + self, json_request: quokka_sampler.JSON_TYPE + ) -> quokka_sampler.JSON_TYPE: self.requests.append(json_request) return self.responses.pop() @@ -73,9 +75,8 @@ def test_quokka_deterministic_examples(circuit, json_result): sim = cirq.Simulator() expected_results = sim.run(circuit, repetitions=5) json_response = {"error": "no error", "error_code": 0, "result": json_result} - endpoint = FakeQuokkaEndpoint([json_response]) quokka = quokka_sampler.QuokkaSampler( - name="test_mctesterface", post_function=endpoint._post + name="test_mctesterface", post=FakeQuokkaEndpoint(json_response) ) quokka_results = quokka.run(circuit, repetitions=5) assert quokka_results == expected_results @@ -100,9 +101,8 @@ def test_quokka_run_sweep(): "error_code": 0, "result": {"m_m2": [[1, 0], [1, 0], [1, 0], [1, 0], [1, 0]]}, } - endpoint = FakeQuokkaEndpoint([json_response, json_response2]) quokka = quokka_sampler.QuokkaSampler( - name="test_mctesterface", post_function=endpoint._post + name="test_mctesterface", post=FakeQuokkaEndpoint(json_response, json_response2) ) quokka_results = quokka.run_sweep(circuit, sweep, repetitions=5) assert quokka_results[0] == expected_results[0]