diff --git a/cirq/circuits/circuit_operation.py b/cirq/circuits/circuit_operation.py index 25b252344b4..c6cfca27cb0 100644 --- a/cirq/circuits/circuit_operation.py +++ b/cirq/circuits/circuit_operation.py @@ -148,6 +148,12 @@ def __post_init__(self): 'in a CircuitOperation. Consider remapping the key using ' '`measurement_key_map` in the CircuitOperation constructor.' ) + + # Disallow qid mapping dimension conflicts. + for q, q_new in self.qubit_map.items(): + if q_new.dimension != q.dimension: + raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}') + # Ensure that param_resolver is converted to an actual ParamResolver. object.__setattr__(self, 'param_resolver', study.ParamResolver(self.param_resolver)) @@ -445,6 +451,8 @@ def with_qubit_mapping( for q in self.circuit.all_qubits(): q_new = transform(self.qubit_map.get(q, q)) if q_new != q: + if q_new.dimension != q.dimension: + raise ValueError(f'Qid dimension conflict.\nFrom qid: {q}\nTo qid: {q_new}') new_map[q] = q_new new_op = self.replace(qubit_map=new_map) if len(set(new_op.qubits)) != len(set(self.qubits)): diff --git a/cirq/circuits/circuit_operation_test.py b/cirq/circuits/circuit_operation_test.py index 136c691b0f5..4cc868b165c 100644 --- a/cirq/circuits/circuit_operation_test.py +++ b/cirq/circuits/circuit_operation_test.py @@ -102,6 +102,24 @@ def test_invalid_measurement_keys(): _ = cirq.CircuitOperation(circuit, measurement_key_map={'m:a': 'ma'}) +def test_invalid_qubit_mapping(): + q = cirq.LineQubit(0) + q3 = cirq.LineQid(1, dimension=3) + + # Invalid qid remapping dict in constructor + with pytest.raises(ValueError, match='Qid dimension conflict'): + _ = cirq.CircuitOperation(cirq.FrozenCircuit(), qubit_map={q: q3}) + + # Invalid qid remapping dict in with_qubit_mapping call + c_op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q))) + with pytest.raises(ValueError, match='Qid dimension conflict'): + _ = c_op.with_qubit_mapping({q: q3}) + + # Invalid qid remapping function in with_qubit_mapping call + with pytest.raises(ValueError, match='Qid dimension conflict'): + _ = c_op.with_qubit_mapping(lambda q: q3) + + def test_circuit_sharing(): a, b, c = cirq.LineQubit.range(3) circuit = cirq.FrozenCircuit(