Skip to content

Commit

Permalink
change: Return observable target if absent for RT (#1026)
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 authored Sep 3, 2024
1 parent cbe7fc9 commit 2d57c5f
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/braket/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ def add_result_type(
observable = Circuit._extract_observable(result_type_to_add)
# We can skip this for now for AdjointGradient (the only subtype of this
# type) because AdjointGradient can only be used when `shots=0`, and the
# qubit_observable_mapping is used to generate basis rotation instrunctions
# and make sure the observables are simultaneously commuting for `shots>0` mode.
# qubit_observable_mapping is used to generate basis rotation instructions
# and make sure the observables mutually commute for `shots>0` mode.
supports_basis_rotation_instructions = not isinstance(
result_type_to_add, ObservableParameterResultType
)
Expand Down
12 changes: 5 additions & 7 deletions src/braket/circuits/observable.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,18 @@ def __init__(
self, qubit_count: int, ascii_symbols: Sequence[str], targets: QubitSetInput | None = None
):
super().__init__(qubit_count=qubit_count, ascii_symbols=ascii_symbols)
if targets is not None:
targets = QubitSet(targets)
targets = QubitSet(targets)
if targets:
if (num_targets := len(targets)) != qubit_count:
raise ValueError(
f"Length of target {num_targets} does not match qubit count {qubit_count}"
)
self._targets = targets
else:
self._targets = None
self._targets = targets
self._coef = 1

def _unscaled(self) -> Observable:
return Observable(
qubit_count=self.qubit_count, ascii_symbols=self.ascii_symbols, targets=self.targets
qubit_count=self.qubit_count, ascii_symbols=self.ascii_symbols, targets=self._targets
)

def to_ir(
Expand Down Expand Up @@ -207,7 +205,7 @@ def __sub__(self, other: Observable):
def __repr__(self) -> str:
return (
f"{self.name}('qubit_count': {self._qubit_count})"
if self._targets is None
if not self._targets
else f"{self.name}('qubit_count': {self._qubit_count}, 'target': {self._targets})"
)

Expand Down
12 changes: 6 additions & 6 deletions src/braket/circuits/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,9 @@ def __init__(self, observables: list[Observable]):
f"{'@'.join([obs.ascii_symbols[0] for obs in unscaled_factors])}"
)
all_targets = [factor.targets for factor in unscaled_factors]
if all(targets is None for targets in all_targets):
merged_targets = None
elif all(targets is not None for targets in all_targets):
if not any(all_targets):
merged_targets = QubitSet()
elif all(all_targets):
flat_targets = [qubit for target in all_targets for qubit in target]
merged_targets = QubitSet(flat_targets)
if len(merged_targets) != len(flat_targets):
Expand Down Expand Up @@ -508,9 +508,9 @@ def __init__(self, observables: list[Observable], display_name: str = "Hamiltoni
self._summands = tuple(flattened_observables)
qubit_count = max(flattened_observables, key=lambda obs: obs.qubit_count).qubit_count
all_targets = [observable.targets for observable in flattened_observables]
if all(targets is None for targets in all_targets):
targets = None
elif all(targets is not None for targets in all_targets):
if not any(all_targets):
targets = QubitSet()
elif all(all_targets):
targets = all_targets
else:
raise ValueError("Cannot mix terms with and without targets")
Expand Down
2 changes: 1 addition & 1 deletion src/braket/circuits/result_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def observable(self) -> Observable:

@property
def target(self) -> QubitSet:
return self._target
return self._target or self._observable.targets

@target.setter
def target(self, target: QubitSetInput) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/braket/circuits/result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def __init__(

def _to_openqasm(self, serialization_properties: OpenQASMSerializationProperties) -> str:
observable_ir = self.observable.to_ir(
target=self.target,
target=self._target,
ir_type=IRType.OPENQASM,
serialization_properties=serialization_properties,
)
Expand Down Expand Up @@ -477,7 +477,7 @@ def _to_jaqcd(self) -> ir.Expectation:

def _to_openqasm(self, serialization_properties: OpenQASMSerializationProperties) -> str:
observable_ir = self.observable.to_ir(
target=self.target,
target=self._target,
ir_type=IRType.OPENQASM,
serialization_properties=serialization_properties,
)
Expand Down Expand Up @@ -552,7 +552,7 @@ def _to_jaqcd(self) -> ir.Sample:

def _to_openqasm(self, serialization_properties: OpenQASMSerializationProperties) -> str:
observable_ir = self.observable.to_ir(
target=self.target,
target=self._target,
ir_type=IRType.OPENQASM,
serialization_properties=serialization_properties,
)
Expand Down Expand Up @@ -632,7 +632,7 @@ def _to_jaqcd(self) -> ir.Variance:

def _to_openqasm(self, serialization_properties: OpenQASMSerializationProperties) -> str:
observable_ir = self.observable.to_ir(
target=self.target,
target=self._target,
ir_type=IRType.OPENQASM,
serialization_properties=serialization_properties,
)
Expand Down
10 changes: 10 additions & 0 deletions test/unit_tests/braket/circuits/test_result_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from braket.circuits.free_parameter import FreeParameter
from braket.circuits.result_type import ObservableParameterResultType
from braket.circuits.serialization import IRType
from braket.registers import QubitSet


@pytest.fixture
Expand Down Expand Up @@ -168,6 +169,15 @@ def test_obs_rt_repr():
)


def test_obs_rt_target():
assert ObservableResultType(
ascii_symbols=["Obs"], observable=Observable.X(), target=1
).target == QubitSet(1)
assert ObservableResultType(
ascii_symbols=["Obs"], observable=Observable.X(1)
).target == QubitSet(1)


@pytest.mark.parametrize(
"ir_type, serialization_properties, expected_exception, expected_message",
[
Expand Down

0 comments on commit 2d57c5f

Please sign in to comment.