Skip to content

Commit

Permalink
Handle qubits in the __str__ of StateVectorTrialResult (#6180)
Browse files Browse the repository at this point in the history
  • Loading branch information
vtomole authored Jul 4, 2023
1 parent 9849695 commit ad6e649
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cirq-core/cirq/sim/state_vector_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from cirq import _compat, ops, value, qis
from cirq.sim import simulator, state_vector, simulator_base
from cirq.protocols import qid_shape

if TYPE_CHECKING:
import cirq
Expand All @@ -31,7 +32,7 @@
class SimulatesIntermediateStateVector(
Generic[TStateVectorStepResult],
simulator_base.SimulatorBase[
TStateVectorStepResult, 'cirq.StateVectorTrialResult', 'cirq.StateVectorSimulationState',
TStateVectorStepResult, 'cirq.StateVectorTrialResult', 'cirq.StateVectorSimulationState'
],
simulator.SimulatesAmplitudes,
metaclass=abc.ABCMeta,
Expand Down Expand Up @@ -172,7 +173,7 @@ def __str__(self) -> str:
size = np.prod(shape, dtype=np.int64)
final = final.reshape(size)
if len([1 for e in final if abs(e) > 0.001]) < 16:
state_vector = qis.dirac_notation(final, 3)
state_vector = qis.dirac_notation(final, 3, qid_shape(substate.qubits))
else:
state_vector = str(final)
label = f'qubits: {substate.qubits}' if substate.qubits else 'phase:'
Expand Down
22 changes: 22 additions & 0 deletions cirq-core/cirq/sim/state_vector_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,28 @@ def test_str_big():
assert 'output vector: [0.03125+0.j 0.03125+0.j 0.03125+0.j ..' in str(result)


def test_str_qudit():
qutrit = cirq.LineQid(0, dimension=3)
final_simulator_state = cirq.StateVectorSimulationState(
prng=np.random.RandomState(0),
qubits=[qutrit],
initial_state=np.array([0, 0, 1]),
dtype=np.complex64,
)
result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state)
assert "|2⟩" in str(result)

ququart = cirq.LineQid(0, dimension=4)
final_simulator_state = cirq.StateVectorSimulationState(
prng=np.random.RandomState(0),
qubits=[ququart],
initial_state=np.array([0, 1, 0, 0]),
dtype=np.complex64,
)
result = cirq.StateVectorTrialResult(cirq.ParamResolver(), {}, final_simulator_state)
assert "|1⟩" in str(result)


def test_pretty_print():
final_simulator_state = cirq.StateVectorSimulationState(
available_buffer=np.array([1]),
Expand Down

0 comments on commit ad6e649

Please sign in to comment.