Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dstrain115 committed Sep 9, 2024
1 parent 558c7a8 commit edf87c8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 40 deletions.
61 changes: 27 additions & 34 deletions unitary/alpha/quokka_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Simulation using a "Quokka" device."""

from typing import Any, Callable, Dict, Optional, Sequence
import warnings

import cirq
import numpy as np
Expand All @@ -30,52 +31,46 @@
_REPETITION_KEY = "count"


class QuokkaPostEndpoint:
def __init__(self, name=_DEFAULT_QUOKKA_NAME):
self._endpoint = _REQUEST_ENDPOINT.format(name)

def __call__(self, json_request: JSON_TYPE) -> JSON_TYPE:
try:
import requests
except ImportError as e:
raise ImportError(
"Please install requests library to use Quokka"
"(e.g. pip install requests)"
) from e
result = requests.post(self._endpoint, json=json_request)
return json.loads(result.content)


class QuokkaSampler(cirq.Sampler):
"""Sampler for querying a Quokka quantum simulation device.
See https://www.quokkacomputing.com/ for more information.a
Args:
name: name of your quokka device
endpoint: HTTP url endpoint to post queries to.
post_function: used only for testing to override default
post: used only for testing to override default
behavior to connect to internet URLs.
"""

def __init__(
self,
name: str = _DEFAULT_QUOKKA_NAME,
endpoint: Optional[str] = None,
post_function: Optional[Callable[[JSON_TYPE], JSON_TYPE]] = None,
post: Optional[Callable[[JSON_TYPE], JSON_TYPE]] = None,
):
self.quokka_name = name
self.endpoint = endpoint
self.post_function = post_function

if self.endpoint is None:
self.endpoint = _REQUEST_ENDPOINT.format(self.quokka_name)
if self.post_function is None:
self.post_function = self._post

def _post(self, json_request: JSON_TYPE) -> JSON_TYPE:
"""Sends POST queries to quokka endpoint."""
try:
import requests
except ImportError as e:
print(
"Please install requests library to use Quokka"
"(e.g. pip install requests)"
)
raise e
result = requests.post(self.endpoint, json=json_request)
return json.loads(result.content)
self._post = post or QuokkaPostEndpoint(name)

def run_sweep(
self,
program: "cirq.AbstractCircuit",
params: "cirq.Sweepable",
program: cirq.AbstractCircuit,
params: cirq.Sweepable,
repetitions: int = 1,
) -> Sequence["cirq.Result"]:
) -> Sequence[cirq.Result]:
"""Samples from the given Circuit.
This allows for sweeping over different parameter values,
Expand Down Expand Up @@ -106,10 +101,10 @@ def run_sweep(
if isinstance(op.gate, cirq.MeasurementGate):
key = cirq.measurement_key_name(op)
if key in measure_keys:
print(
warnings.warn(
"Warning! Keys can only be measured once in Quokka simulator"
f"Key {key} will only contain the last measured value"
)
print("Key {key} will only contain the last measured value")
measure_keys[key] = op.qubits
if cirq.QasmOutput.valid_id_re.match(key):
register_names[key] = f"m_{key}"
Expand All @@ -131,7 +126,7 @@ def run_sweep(

# Send data to quokka endpoint
data = {_SCRIPT_KEY: qasm, _REPETITION_KEY: repetitions}
json_results = self.post_function(data)
json_results = self._post(data)

if _ERROR_CODE_KEY in json_results and json_results[_ERROR_CODE_KEY] != 0:
raise RuntimeError(f"Quokka returned an error: {json_results}")
Expand All @@ -144,9 +139,7 @@ def run_sweep(
for key in measure_keys:
register_name = register_names[key]
if register_name not in json_results[_RESULT_KEY]:
raise RuntimeError(
f"Quokka did not measure key {key}: {json_results}"
)
raise KeyError(f"Quokka did not measure key {key}: {json_results}")
result_measurements[key] = np.asarray(
json_results[_RESULT_KEY][register_name], dtype=np.dtype("int8")
)
Expand Down
12 changes: 6 additions & 6 deletions unitary/alpha/quokka_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@


class FakeQuokkaEndpoint:
def __init__(self, responses: Iterable[quokka_sampler.JSON_TYPE]):
def __init__(self, *responses: quokka_sampler.JSON_TYPE):
self.responses = list(responses)
self.requests = []

def _post(self, json_request: quokka_sampler.JSON_TYPE) -> quokka_sampler.JSON_TYPE:
def __call__(
self, json_request: quokka_sampler.JSON_TYPE
) -> quokka_sampler.JSON_TYPE:
self.requests.append(json_request)
return self.responses.pop()

Expand Down Expand Up @@ -73,9 +75,8 @@ def test_quokka_deterministic_examples(circuit, json_result):
sim = cirq.Simulator()
expected_results = sim.run(circuit, repetitions=5)
json_response = {"error": "no error", "error_code": 0, "result": json_result}
endpoint = FakeQuokkaEndpoint([json_response])
quokka = quokka_sampler.QuokkaSampler(
name="test_mctesterface", post_function=endpoint._post
name="test_mctesterface", post=FakeQuokkaEndpoint(json_response)
)
quokka_results = quokka.run(circuit, repetitions=5)
assert quokka_results == expected_results
Expand All @@ -100,9 +101,8 @@ def test_quokka_run_sweep():
"error_code": 0,
"result": {"m_m2": [[1, 0], [1, 0], [1, 0], [1, 0], [1, 0]]},
}
endpoint = FakeQuokkaEndpoint([json_response, json_response2])
quokka = quokka_sampler.QuokkaSampler(
name="test_mctesterface", post_function=endpoint._post
name="test_mctesterface", post=FakeQuokkaEndpoint(json_response, json_response2)
)
quokka_results = quokka.run_sweep(circuit, sweep, repetitions=5)
assert quokka_results[0] == expected_results[0]

0 comments on commit edf87c8

Please sign in to comment.