Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added api docs gen (copy from TFQ). #3089

Merged
merged 11 commits into from
Jun 12, 2020
29 changes: 20 additions & 9 deletions cirq/contrib/svg/svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,25 @@
import cirq

QBLUE = '#1967d2'
FONT = "Arial"


def fixup_text(text: str):
if '\n' in text:
return '?'
if '[<virtual>]' in text:
# https://github.com/quantumlib/Cirq/issues/2905
# TODO: escape angle brackets when you actually want to display tags
return text.replace('[<virtual>]', '') # coverage: ignore
if '[cirq.VirtualTag()]' in text:
# https://github.com/quantumlib/Cirq/issues/2905
return text.replace('[cirq.VirtualTag()]', '')
return text


def _get_text_width(t: str) -> float:
if '\n' in t:
return _get_text_width('?')
tp = matplotlib.textpath.TextPath((0, 0), t, size=14, prop='Arial')
t = fixup_text(t)
tp = matplotlib.textpath.TextPath((0, 0), t, size=14, prop=FONT)
bb = tp.get_extents()
return bb.width + 10

Expand All @@ -30,7 +43,8 @@ def _rect(x: float,
def _text(x: float, y: float, text: str, fontsize: int = 14):
"""Draw SVG <text> text."""
return f'<text x="{x}" y="{y}" dominant-baseline="middle" ' \
f'text-anchor="middle" font-size="{fontsize}px">{text}</text>'
f'text-anchor="middle" font-size="{fontsize}px" ' \
f'font-family="{FONT}">{text}</text>'


def _fit_horizontal(tdd: 'cirq.TextDiagramDrawer',
Expand Down Expand Up @@ -208,13 +222,10 @@ def tdd_to_svg(
if v.text == '×':
t += _text(x, y + 3, '×', fontsize=40)
continue
if '\n' in v.text:
t += _rect(boxx, boxy, boxwidth, boxheight)
t += _text(x, y, '?', fontsize=18)
continue

v_text = fixup_text(v.text)
t += _rect(boxx, boxy, boxwidth, boxheight)
t += _text(x, y, v.text, fontsize=14 if len(v.text) > 1 else 18)
t += _text(x, y, v_text, fontsize=14 if len(v_text) > 1 else 18)

t += '</svg>'
return t
Expand Down
9 changes: 9 additions & 0 deletions cirq/contrib/svg/svg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ def test_svg():
assert '</svg>' in svg_text


def test_svg_noise():
noise_model = cirq.ConstantQubitNoiseModel(cirq.DepolarizingChannel(p=1e-3))
q = cirq.LineQubit(0)
circuit = cirq.Circuit(cirq.X(q))
circuit = cirq.Circuit(noise_model.noisy_moments(circuit.moments, [q]))
svg = circuit_to_svg(circuit)
assert '>D(0.001)</text>' in svg


def test_validation():
with pytest.raises(ValueError):
circuit_to_svg(cirq.Circuit())
Expand Down
156 changes: 156 additions & 0 deletions cirq/google/api/v2/batch_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

82 changes: 82 additions & 0 deletions cirq/google/api/v2/batch_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 2 additions & 6 deletions cirq/study/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ def plot_state_histogram(result: trial_result.TrialResult) -> np.ndarray:

States is a bitstring representation of all the qubit states in a single
result.
Currently this function assumes each measurement gate applies to only
a single qubit.

Args:
result: The trial results to plot.
Expand All @@ -39,17 +37,15 @@ def plot_state_histogram(result: trial_result.TrialResult) -> np.ndarray:
# This allows cirq to be usable without python3-tk.
import matplotlib.pyplot as plt

num_qubits = len(result.measurements.keys())
num_qubits = sum([value.shape[1] for value in result.measurements.values()])
states = 2**num_qubits
values = np.zeros(states)

# measurements is a dict of {measurement gate key:
# array(repetitions, boolean result)}
# Convert this to an array of repetitions, each with an array of booleans.
# e.g. {q1: array([[True, True]]), q2: array([[False, False]])}
# --> array([[True, False], [True, False]])
measurement_by_result = np.array([
v.transpose()[0] for k, v in result.measurements.items()]).transpose()
measurement_by_result = np.hstack(list(result.measurements.values()))

for meas in measurement_by_result:
# Convert each array of booleans to a string representation.
Expand Down
27 changes: 27 additions & 0 deletions cirq/study/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,30 @@ def test_plot_state_histogram():
expected_values = [0., 0., 0., 5.]

np.testing.assert_equal(values_plotted, expected_values)


def test_plot_state_histogram_multi_1():
pl.switch_backend('PDF')
qubits = cirq.LineQubit.range(4)
c = cirq.Circuit(
cirq.X.on_each(*qubits[1:]),
cirq.measure(*qubits), # One multi-qubit measurement
)
r = cirq.sample(c, repetitions=5)
values_plotted = visualize.plot_state_histogram(r)
expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0]
np.testing.assert_equal(values_plotted, expected_values)


def test_plot_state_histogram_multi_2():
pl.switch_backend('PDF')
qubits = cirq.LineQubit.range(4)
c = cirq.Circuit(
cirq.X.on_each(*qubits[1:]),
cirq.measure(*qubits[:2]), # One multi-qubit measurement
cirq.measure_each(*qubits[2:]), # Multiple single-qubit measurement
)
r = cirq.sample(c, repetitions=5)
values_plotted = visualize.plot_state_histogram(r)
expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0]
np.testing.assert_equal(values_plotted, expected_values)
4 changes: 2 additions & 2 deletions cirq/testing/random_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def random_circuit(qubits: Union[Sequence[ops.Qid], int],
raise ValueError('At least one qubit must be specified.')
gate_domain = {k: v for k, v in gate_domain.items() if v <= n_qubits}
if not gate_domain:
raise ValueError(f'After removing gates that act on less that '
'{n_qubits}, gate_domain had no gates.')
raise ValueError(f'After removing gates that act on less than '
f'{n_qubits} qubits, gate_domain had no gates.')
max_arity = max(gate_domain.values())

prng = value.parse_random_state(random_state)
Expand Down
Loading