Skip to content

Commit

Permalink
Change TaggedOperation's __pow__ from returning Operation to re…
Browse files Browse the repository at this point in the history
…turning `TaggedOperation` (#4916)

Fixes: #4914
  • Loading branch information
vtomole authored Jan 29, 2022
1 parent 275372e commit d7a3906
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,8 +825,8 @@ def _trace_distance_bound_(self) -> float:
def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'cirq.Operation':
return protocols.phase_by(self.sub_operation, phase_turns, qubit_index)

def __pow__(self, exponent: Any) -> 'cirq.Operation':
return self.sub_operation ** exponent
def __pow__(self, exponent: Any) -> 'cirq.TaggedOperation':
return TaggedOperation(self.sub_operation ** exponent, *self.tags)

def __mul__(self, other: Any) -> Any:
return self.sub_operation * other
Expand Down
3 changes: 2 additions & 1 deletion cirq-core/cirq/ops/raw_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ def test_tagged_operation():
op = cirq.X(q1).with_tags('tag1')
op_repr = "cirq.X(cirq.GridQubit(1, 1))"
assert repr(op) == f"cirq.TaggedOperation({op_repr}, 'tag1')"
assert op == op ** 1

assert op.qubits == (q1,)
assert op.tags == ('tag1',)
Expand Down Expand Up @@ -616,7 +617,7 @@ def test_tagged_operation_forwards_protocols():

y = cirq.Y(q1)
tagged_y = cirq.Y(q1).with_tags(tag)
assert tagged_y ** 0.5 == cirq.YPowGate(exponent=0.5)(q1)
assert tagged_y ** 0.5 == cirq.YPowGate(exponent=0.5)(q1).with_tags(tag)
assert tagged_y * 2 == (y * 2)
assert 3 * tagged_y == (3 * y)
assert cirq.phase_by(y, 0.125, 0) == cirq.phase_by(tagged_y, 0.125, 0)
Expand Down

0 comments on commit d7a3906

Please sign in to comment.