diff --git a/cirq/contrib/svg/svg.py b/cirq/contrib/svg/svg.py index 72a7ecb886a..ee04833a78d 100644 --- a/cirq/contrib/svg/svg.py +++ b/cirq/contrib/svg/svg.py @@ -6,12 +6,25 @@ import cirq QBLUE = '#1967d2' +FONT = "Arial" + + +def fixup_text(text: str): + if '\n' in text: + return '?' + if '[]' in text: + # https://github.com/quantumlib/Cirq/issues/2905 + # TODO: escape angle brackets when you actually want to display tags + return text.replace('[]', '') # coverage: ignore + if '[cirq.VirtualTag()]' in text: + # https://github.com/quantumlib/Cirq/issues/2905 + return text.replace('[cirq.VirtualTag()]', '') + return text def _get_text_width(t: str) -> float: - if '\n' in t: - return _get_text_width('?') - tp = matplotlib.textpath.TextPath((0, 0), t, size=14, prop='Arial') + t = fixup_text(t) + tp = matplotlib.textpath.TextPath((0, 0), t, size=14, prop=FONT) bb = tp.get_extents() return bb.width + 10 @@ -30,7 +43,8 @@ def _rect(x: float, def _text(x: float, y: float, text: str, fontsize: int = 14): """Draw SVG text.""" return f'{text}' + f'text-anchor="middle" font-size="{fontsize}px" ' \ + f'font-family="{FONT}">{text}' def _fit_horizontal(tdd: 'cirq.TextDiagramDrawer', @@ -208,13 +222,10 @@ def tdd_to_svg( if v.text == '×': t += _text(x, y + 3, '×', fontsize=40) continue - if '\n' in v.text: - t += _rect(boxx, boxy, boxwidth, boxheight) - t += _text(x, y, '?', fontsize=18) - continue + v_text = fixup_text(v.text) t += _rect(boxx, boxy, boxwidth, boxheight) - t += _text(x, y, v.text, fontsize=14 if len(v.text) > 1 else 18) + t += _text(x, y, v_text, fontsize=14 if len(v_text) > 1 else 18) t += '' return t diff --git a/cirq/contrib/svg/svg_test.py b/cirq/contrib/svg/svg_test.py index 0cd602fe6d3..f71787b4142 100644 --- a/cirq/contrib/svg/svg_test.py +++ b/cirq/contrib/svg/svg_test.py @@ -22,6 +22,15 @@ def test_svg(): assert '' in svg_text +def test_svg_noise(): + noise_model = cirq.ConstantQubitNoiseModel(cirq.DepolarizingChannel(p=1e-3)) + q = cirq.LineQubit(0) + circuit = cirq.Circuit(cirq.X(q)) + circuit = cirq.Circuit(noise_model.noisy_moments(circuit.moments, [q])) + svg = circuit_to_svg(circuit) + assert '>D(0.001)' in svg + + def test_validation(): with pytest.raises(ValueError): circuit_to_svg(cirq.Circuit())