Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid copying unnecessary buffers between simulation iterations #4789

Merged
merged 26 commits into from
Jan 14, 2022
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
993cf0c
Add a with_buffer parameter to ActOnArgs.copy
yjt98765 Dec 30, 2021
6abb2b9
Fix mypy error
yjt98765 Dec 31, 2021
9fa2bed
Merge branch 'master' into actonarg
yjt98765 Dec 31, 2021
a865b7e
Change copy's parameter to reuse_buffer
yjt98765 Jan 2, 2022
3af3ec4
Change the semantics of reuse_buffer parameter
yjt98765 Jan 3, 2022
2a556e4
Merge branch 'master' into actonarg
yjt98765 Jan 5, 2022
df67918
Add docstring and deprecation warning
yjt98765 Jan 5, 2022
2d53c76
Support default buffer parameters in ActOnArgs
yjt98765 Jan 5, 2022
4f10146
Fix CI errors
yjt98765 Jan 5, 2022
280ad3e
Fix test_state_vector_trial_result_repr
yjt98765 Jan 5, 2022
5f821e4
Add test for deprecation warnings
yjt98765 Jan 5, 2022
5edf97f
Fix CI errors
yjt98765 Jan 5, 2022
7bc4908
Merge branch 'master' into actonarg
yjt98765 Jan 6, 2022
412c1bc
Use assert_deprecated for deprecation test
yjt98765 Jan 6, 2022
0480241
Add a test case for the deprecation warning in _run
yjt98765 Jan 6, 2022
b2fda13
Fix coverage and type errors
yjt98765 Jan 6, 2022
a006a39
Fix a coverage error
yjt98765 Jan 6, 2022
49cb93d
Merge branch 'master' into actonarg
yjt98765 Jan 7, 2022
67141cb
Merge branch 'master' into actonarg
yjt98765 Jan 12, 2022
8e8076b
Raise a ValueError when qid_shape cannot be inferred
yjt98765 Jan 12, 2022
3369294
Fix type hint and deprecation deadline problems
yjt98765 Jan 13, 2022
7f7ff17
Rename reuse_buffer to deep_copy_buffers
yjt98765 Jan 13, 2022
195b802
Merge branch 'master' into actonarg
yjt98765 Jan 14, 2022
62addd1
Add shallow copy logic to copy method
yjt98765 Jan 14, 2022
2b948a9
Merge branch 'master' into actonarg
yjt98765 Jan 14, 2022
5d260b1
Merge branch 'master' into actonarg
CirqBot Jan 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cirq-core/cirq/contrib/quimb/mps_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def __str__(self) -> str:
def _value_equality_values_(self) -> Any:
return self.qubit_map, self.M, self.simulation_options, self.grouping

def _on_copy(self, target: 'MPSState'):
def _on_copy(self, target: 'MPSState', deep_copy_buffers: bool = True):
target.simulation_options = self.simulation_options
target.grouping = self.grouping
target.M = [x.copy() for x in self.M]
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/protocols/act_on_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, fallback_result: Any = NotImplemented, measurements=None):
def _perform_measurement(self, qubits):
return self.measurements # coverage: ignore

def copy(self):
def copy(self, deep_copy_buffers: bool = True):
return DummyActOnArgs(self.fallback_result, self.measurements.copy()) # coverage: ignore

def _act_on_fallback_(
Expand Down
29 changes: 25 additions & 4 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Objects and methods for acting efficiently on a state tensor."""
import abc
import copy
import inspect
from typing import (
Any,
Dict,
Expand All @@ -26,6 +27,7 @@
Optional,
Iterator,
)
import warnings

import numpy as np

Expand Down Expand Up @@ -113,14 +115,33 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Child classes that perform measurements should implement this with
the implementation."""

def copy(self: TSelf) -> TSelf:
"""Creates a copy of the object."""
def copy(self: TSelf, deep_copy_buffers: bool = True) -> TSelf:
"""Creates a copy of the object.

Args:
deep_copy_buffers: If True, buffers will also be deep-copied.
Otherwise the copy will share a reference to the original object's
buffers.

Returns:
A copied instance.
"""
args = copy.copy(self)
self._on_copy(args)
if 'deep_copy_buffers' in inspect.signature(self._on_copy).parameters:
self._on_copy(args, deep_copy_buffers)
else:
warnings.warn(
(
'A new parameter deep_copy_buffers has been added to ActOnArgs._on_copy(). '
'The classes that inherit from ActOnArgs should support it before Cirq 0.15.'
),
DeprecationWarning,
)
self._on_copy(args)
args._log_of_measurement_results = self.log_of_measurement_results.copy()
return args

def _on_copy(self: TSelf, args: TSelf):
def _on_copy(self: TSelf, args: TSelf, deep_copy_buffers: bool = True):
"""Subclasses should implement this with any additional state copy
functionality."""

Expand Down
18 changes: 16 additions & 2 deletions cirq-core/cirq/sim/act_on_args_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from collections import abc
import inspect
from typing import (
Dict,
TYPE_CHECKING,
Expand All @@ -25,6 +26,7 @@
List,
Union,
)
import warnings

import numpy as np

Expand Down Expand Up @@ -131,9 +133,21 @@ def _act_on_fallback_(
self.args[q] = op_args
return True

def copy(self) -> 'cirq.ActOnArgsContainer[TActOnArgs]':
def copy(self, deep_copy_buffers: bool = True) -> 'cirq.ActOnArgsContainer[TActOnArgs]':
logs = self.log_of_measurement_results.copy()
copies = {a: a.copy() for a in set(self.args.values())}
copies = {}
for act_on_args in set(self.args.values()):
if 'deep_copy_buffers' in inspect.signature(act_on_args.copy).parameters:
copies[act_on_args] = act_on_args.copy(deep_copy_buffers)
else:
warnings.warn(
(
'A new parameter deep_copy_buffers has been added to ActOnArgs.copy(). The '
'classes that inherit from ActOnArgs should support it before Cirq 0.15.'
),
DeprecationWarning,
)
copies[act_on_args] = act_on_args.copy()
for copy in copies.values():
copy._log_of_measurement_results = logs
args = {q: copies[a] for q, a in self.args.items()}
Expand Down
9 changes: 8 additions & 1 deletion cirq-core/cirq/sim/act_on_args_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def __init__(self, qubits, logs):
def _perform_measurement(self, qubits: Sequence[cirq.Qid]) -> List[int]:
return [0] * len(qubits)

def copy(self) -> 'EmptyActOnArgs':
def copy(self) -> 'EmptyActOnArgs': # type: ignore
"""The deep_copy_buffers parameter is omitted to trigger a deprecation warning test."""
return EmptyActOnArgs(
qubits=self.qubits,
logs=self.log_of_measurement_results.copy(),
Expand Down Expand Up @@ -226,6 +227,12 @@ def test_copy_succeeds():
assert copied.qubits == (q0, q1)


def test_copy_deprecation_warning():
args = create_container(qs2, False)
with cirq.testing.assert_deprecated('deep_copy_buffers', deadline='0.15'):
args.copy(False)


def test_merge_succeeds():
args = create_container(qs2, False)
merged = args.create_merged_state()
Expand Down
9 changes: 9 additions & 0 deletions cirq-core/cirq/sim/act_on_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def _act_on_fallback_(
) -> bool:
return True

def _on_copy(self, args):
return super()._on_copy(args)


def test_measurements():
args = DummyArgs()
Expand Down Expand Up @@ -89,3 +92,9 @@ def test_transpose_qubits():
args.transpose_to_qubit_order((q0, q2))
with pytest.raises(ValueError, match='Qubits do not match'):
args.transpose_to_qubit_order((q0, q1, q1))


def test_on_copy_has_no_param():
args = DummyArgs()
with cirq.testing.assert_deprecated('deep_copy_buffers', deadline='0.15'):
args.copy(False)
39 changes: 29 additions & 10 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Objects and methods for acting efficiently on a density matrix."""

from typing import Any, Dict, List, Tuple, TYPE_CHECKING, Sequence, Union
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Sequence, Union

import numpy as np

Expand All @@ -36,11 +36,11 @@ class ActOnDensityMatrixArgs(ActOnArgs):
def __init__(
self,
target_tensor: np.ndarray,
available_buffer: List[np.ndarray],
qid_shape: Tuple[int, ...],
prng: np.random.RandomState = None,
log_of_measurement_results: Dict[str, Any] = None,
qubits: Sequence['cirq.Qid'] = None,
available_buffer: Optional[List[np.ndarray]] = None,
qid_shape: Optional[Tuple[int, ...]] = None,
prng: Optional[np.random.RandomState] = None,
log_of_measurement_results: Optional[Dict[str, Any]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
ignore_measurement_results: bool = False,
):
"""Inits ActOnDensityMatrixArgs.
Expand All @@ -65,11 +65,27 @@ def __init__(
will treat measurement as dephasing instead of collapsing
process. This is only applicable to simulators that can
model dephasing.

Raises:
ValueError: The dimension of `target_tensor` is not divisible by 2
and `qid_shape` is not provided.
"""
super().__init__(prng, qubits, log_of_measurement_results, ignore_measurement_results)
self.target_tensor = target_tensor
self.available_buffer = available_buffer
self.qid_shape = qid_shape
if available_buffer is None:
self.available_buffer = [np.empty_like(target_tensor) for _ in range(3)]
else:
self.available_buffer = available_buffer
if qid_shape is None:
target_shape = target_tensor.shape
if len(target_shape) % 2 != 0:
raise ValueError(
'The dimension of target_tensor is not divisible by 2.'
' Require explicit qid_shape.'
)
self.qid_shape = target_shape[: len(target_shape) // 2]
else:
self.qid_shape = qid_shape

def _act_on_fallback_(
self,
Expand Down Expand Up @@ -108,9 +124,12 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
)
return bits

def _on_copy(self, target: 'cirq.ActOnDensityMatrixArgs'):
def _on_copy(self, target: 'cirq.ActOnDensityMatrixArgs', deep_copy_buffers: bool = True):
target.target_tensor = self.target_tensor.copy()
target.available_buffer = [b.copy() for b in self.available_buffer]
if deep_copy_buffers:
target.available_buffer = [b.copy() for b in self.available_buffer]
else:
target.available_buffer = self.available_buffer

def _on_kronecker_product(
self, other: 'cirq.ActOnDensityMatrixArgs', target: 'cirq.ActOnDensityMatrixArgs'
Expand Down
29 changes: 29 additions & 0 deletions cirq-core/cirq/sim/act_on_density_matrix_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,35 @@
import cirq


def test_default_parameter():
qid_shape = (2,)
tensor = cirq.to_valid_density_matrix(
0, len(qid_shape), qid_shape=qid_shape, dtype=np.complex64
)
args = cirq.ActOnDensityMatrixArgs(target_tensor=tensor)
assert len(args.available_buffer) == 3
for buffer in args.available_buffer:
assert buffer.shape == tensor.shape
assert buffer.dtype == tensor.dtype
assert args.qid_shape == qid_shape


def test_shallow_copy_buffers():
qid_shape = (2,)
tensor = cirq.to_valid_density_matrix(
0, len(qid_shape), qid_shape=qid_shape, dtype=np.complex64
)
args = cirq.ActOnDensityMatrixArgs(target_tensor=tensor)
copy = args.copy(deep_copy_buffers=False)
assert copy.available_buffer is args.available_buffer


def test_default_parameter_error():
tensor = np.ndarray(shape=(2,))
with pytest.raises(ValueError, match='The dimension of target_tensor is not divisible by 2'):
cirq.ActOnDensityMatrixArgs(target_tensor=tensor)


def test_decomposed_fallback():
class Composite(cirq.Gate):
def num_qubits(self) -> int:
Expand Down
22 changes: 14 additions & 8 deletions cirq-core/cirq/sim/act_on_state_vector_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Objects and methods for acting efficiently on a state vector."""

from typing import Any, Tuple, TYPE_CHECKING, Union, Dict, List, Sequence
from typing import Any, Optional, Tuple, TYPE_CHECKING, Union, Dict, List, Sequence

import numpy as np

Expand All @@ -40,10 +40,10 @@ class ActOnStateVectorArgs(ActOnArgs):
def __init__(
self,
target_tensor: np.ndarray,
available_buffer: np.ndarray,
prng: np.random.RandomState = None,
log_of_measurement_results: Dict[str, Any] = None,
qubits: Sequence['cirq.Qid'] = None,
available_buffer: Optional[np.ndarray] = None,
prng: Optional[np.random.RandomState] = None,
log_of_measurement_results: Optional[Dict[str, Any]] = None,
qubits: Optional[Sequence['cirq.Qid']] = None,
):
"""Inits ActOnStateVectorArgs.

Expand All @@ -66,7 +66,10 @@ def __init__(
"""
super().__init__(prng, qubits, log_of_measurement_results)
self.target_tensor = target_tensor
self.available_buffer = available_buffer
if available_buffer is None:
self.available_buffer = np.empty_like(target_tensor)
else:
self.available_buffer = available_buffer

def swap_target_tensor_for(self, new_target_tensor: np.ndarray):
"""Gives a new state vector for the system.
Expand Down Expand Up @@ -174,9 +177,12 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
)
return bits

def _on_copy(self, target: 'cirq.ActOnStateVectorArgs'):
def _on_copy(self, target: 'cirq.ActOnStateVectorArgs', deep_copy_buffers: bool = True):
target.target_tensor = self.target_tensor.copy()
target.available_buffer = self.available_buffer.copy()
if deep_copy_buffers:
target.available_buffer = self.available_buffer.copy()
else:
target.available_buffer = self.available_buffer

def _on_kronecker_product(
self, other: 'cirq.ActOnStateVectorArgs', target: 'cirq.ActOnStateVectorArgs'
Expand Down
14 changes: 14 additions & 0 deletions cirq-core/cirq/sim/act_on_state_vector_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@
import cirq


def test_default_parameter():
target_tensor = cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64)
args = cirq.ActOnStateVectorArgs(target_tensor)
assert args.available_buffer.shape == target_tensor.shape
assert args.available_buffer.dtype == target_tensor.dtype


def test_shallow_copy_buffers():
target_tensor = cirq.one_hot(shape=(2, 2, 2), dtype=np.complex64)
args = cirq.ActOnStateVectorArgs(target_tensor)
copy = args.copy(deep_copy_buffers=False)
assert copy.available_buffer is args.available_buffer


def test_decomposed_fallback():
class Composite(cirq.Gate):
def num_qubits(self) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Returns the measurement from the tableau."""
return [self.tableau._measure(self.qubit_map[q], self.prng) for q in qubits]

def _on_copy(self, target: 'ActOnCliffordTableauArgs'):
def _on_copy(self, target: 'ActOnCliffordTableauArgs', deep_copy_buffers: bool = True):
target.tableau = self.tableau.copy()

def sample(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
"""Returns the measurement from the stabilizer state form."""
return [self.state._measure(self.qubit_map[q], self.prng) for q in qubits]

def _on_copy(self, target: 'ActOnStabilizerCHFormArgs'):
def _on_copy(self, target: 'ActOnStabilizerCHFormArgs', deep_copy_buffers: bool = True):
target.state = self.state.copy()

def sample(
Expand Down
13 changes: 11 additions & 2 deletions cirq-core/cirq/sim/operation_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,17 @@ def apply_operation(self, op: 'cirq.Operation'):
protocols.act_on(op, self)

@abc.abstractmethod
def copy(self: TSelfTarget) -> TSelfTarget:
"""Copies the object."""
def copy(self: TSelfTarget, deep_copy_buffers: bool = True) -> TSelfTarget:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the docstring here too since it's the base interface.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

"""Creates a copy of the object.

Args:
deep_copy_buffers: If True, buffers will also be deep-copied.
Otherwise the copy will share a reference to the original object's
buffers.

Returns:
A copied instance.
"""

@property
@abc.abstractmethod
Expand Down
Loading