diff --git a/cirq-core/cirq/study/result.py b/cirq-core/cirq/study/result.py index 646202e2f65..3eacc6c07be 100644 --- a/cirq-core/cirq/study/result.py +++ b/cirq-core/cirq/study/result.py @@ -97,7 +97,8 @@ def __init__( self, *, # Forces keyword args. params: resolver.ParamResolver, - measurements: Dict[str, np.ndarray], + measurements: Optional[Dict[str, np.ndarray]] = None, + measurements2: Optional[Dict[str, np.ndarray]] = None, ) -> None: """Inits Result. @@ -108,10 +109,53 @@ def __init__( with the first index running over the repetitions, and the second index running over the qubits for the corresponding measurements. + measurements2: A dictionary from measurement gate key to measurement + results. The value for each key is a 3D array of booleans, + with the first index running over the repetitions, the second + index running over "instances" of that key in the circuit, and + the last index running over the qubits for the corresponding + measurements. """ + if measurements is None and measurements2 is None: + measurements2 = {} # for backwards compatibility, allow constructing with None. self.params = params - self._data: Optional[pd.DataFrame] = None self._measurements = measurements + self._measurements2 = measurements2 + self._data: Optional[pd.DataFrame] = None + + @property + def measurements(self) -> Dict[str, np.ndarray]: + if self._measurements is None: + self._measurements = {} + for key, data in self._measurements2.items(): + reps, instances, qubits = data.shape + if instances != 1: + raise ValueError('Cannot extract 2D measurements for repeated keys') + self._measurements[key] = data.reshape((reps, qubits)) + return self._measurements + + @property + def measurements2(self) -> Dict[str, np.ndarray]: + """Returns mapping from measurement key to 3D data array.""" + if self._measurements2 is None: + self._measurements2 = { + key: data[:, np.newaxis, :] for key, data in self._measurements.items() + } + return self._measurements2 + + @property + def repetitions(self) -> int: + if self._measurements2 is not None: + if not self._measurements2: + return 0 + # Get the length quickly from one of the keyed results. + return len(next(iter(self._measurements2.values()))) + else: + if not self._measurements: + return 0 + # Get the length quickly from one of the keyed results. + return len(next(iter(self._measurements.values()))) + @property def data(self) -> pd.DataFrame: @@ -119,7 +163,7 @@ def data(self) -> pd.DataFrame: # Convert to a DataFrame with columns as measurement keys, rows as # repetitions and a big endian integer for individual measurements. converted_dict = {} - for key, val in self._measurements.items(): + for key, val in self.measurements.items(): converted_dict[key] = [value.big_endian_bits_to_int(m_vals) for m_vals in val] # Note that when a numpy array is produced from this data frame, # Pandas will try to use np.int64 as dtype, but will upgrade to @@ -145,17 +189,6 @@ def from_single_parameter_set( """ return Result(params=params, measurements=measurements) - @property - def measurements(self) -> Dict[str, np.ndarray]: - return self._measurements - - @property - def repetitions(self) -> int: - if not self.measurements: - return 0 - # Get the length quickly from one of the keyed results. - return len(next(iter(self.measurements.values()))) - # Reason for 'type: ignore': https://github.com/python/mypy/issues/5273 def multi_measurement_histogram( # type: ignore self, @@ -208,12 +241,12 @@ def multi_measurement_histogram( # type: ignore results. """ fixed_keys = tuple(_key_to_str(key) for key in keys) - samples = zip( + samples: Iterable[Any] = zip( *(self.measurements[sub_key] for sub_key in fixed_keys) - ) # type: Iterable[Any] + ) if len(fixed_keys) == 0: samples = [()] * self.repetitions - c = collections.Counter() # type: collections.Counter + c = collections.Counter() for sample in samples: c[fold_func(sample)] += 1 return c @@ -350,7 +383,7 @@ def _pack_digits(digits: np.ndarray, pack_bits: str = 'auto') -> Tuple[str, bool if pack_bits == 'force': return _pack_bits(digits), True if pack_bits not in ['auto', 'never']: - raise ValueError("Please set `pack_bits` to 'auto', " "'force', or 'never'.") + raise ValueError("Please set `pack_bits` to 'auto', 'force', or 'never'.") # Do error checking here, otherwise the following logic will work # for both "auto" and "never".