diff --git a/cirq/ops/raw_types.py b/cirq/ops/raw_types.py index e5600fec6c6..17af8bf991f 100644 --- a/cirq/ops/raw_types.py +++ b/cirq/ops/raw_types.py @@ -825,8 +825,8 @@ def _trace_distance_bound_(self) -> float: def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'cirq.Operation': return protocols.phase_by(self.sub_operation, phase_turns, qubit_index) - def __pow__(self, exponent: Any) -> 'cirq.Operation': - return self.sub_operation ** exponent + def __pow__(self, exponent: Any) -> 'cirq.TaggedOperation': + return TaggedOperation(self.sub_operation ** exponent, *self.tags) def __mul__(self, other: Any) -> Any: return self.sub_operation * other diff --git a/cirq/ops/raw_types_test.py b/cirq/ops/raw_types_test.py index d2be7c029e1..866bcf4f26e 100644 --- a/cirq/ops/raw_types_test.py +++ b/cirq/ops/raw_types_test.py @@ -422,6 +422,7 @@ def test_tagged_operation(): op = cirq.X(q1).with_tags('tag1') op_repr = "cirq.X(cirq.GridQubit(1, 1))" assert repr(op) == f"cirq.TaggedOperation({op_repr}, 'tag1')" + assert op == op ** 1 assert op.qubits == (q1,) assert op.tags == ('tag1',) @@ -616,7 +617,7 @@ def test_tagged_operation_forwards_protocols(): y = cirq.Y(q1) tagged_y = cirq.Y(q1).with_tags(tag) - assert tagged_y ** 0.5 == cirq.YPowGate(exponent=0.5)(q1) + assert tagged_y ** 0.5 == cirq.YPowGate(exponent=0.5)(q1).with_tags(tag) assert tagged_y * 2 == (y * 2) assert 3 * tagged_y == (3 * y) assert cirq.phase_by(y, 0.125, 0) == cirq.phase_by(tagged_y, 0.125, 0)