Skip to content

Commit

Permalink
Initialize DefaultQubit2 local rng from global rng if no seed provi…
Browse files Browse the repository at this point in the history
…ded (#4394)

* initialize dev rng from global state

* change default value to global str

* Apply suggestions from code review

Co-authored-by: Frederik Wilde <[email protected]>

* Update tests/devices/experimental/test_default_qubit_2.py

* fix indentation

---------

Co-authored-by: Frederik Wilde <[email protected]>
  • Loading branch information
2 people authored and mudit2812 committed Aug 2, 2023
1 parent d72c2e2 commit 8265e8b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 9 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@
* When given a callable, `qml.ctrl` now does its custom pre-processing on all queued operators from the callable.
[(#4370)](https://github.com/PennyLaneAI/pennylane/pull/4370)

* If no seed is specified on initialization with `DefaultQubit2`, the local random number generator will be
seeded from on the NumPy's global random number generator.
[(#4394)](https://github.com/PennyLaneAI/pennylane/pull/4394)

* The experimental `DefaultQubit2` device now supports computing VJPs and JVPs using the adjoint method.
[(#4374)](https://github.com/PennyLaneAI/pennylane/pull/4374)

Expand Down
17 changes: 10 additions & 7 deletions pennylane/devices/experimental/default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@
class DefaultQubit2(Device):
"""A PennyLane device written in Python and capable of backpropagation derivatives.
Args:
seed (Union[None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng``.
If no value is provided, a default RNG will be used.
max_workers (int): A ``ProcessPoolExecutor`` executes tapes asynchronously
Keyword Args:
seed="global" (Union[str, None, int, array_like[int], SeedSequence, BitGenerator, Generator]): A
seed-like parameter matching that of ``seed`` for ``numpy.random.default_rng`` or
a request to seed from numpy's global random number generator.
The default, ``seed="global"`` pulls a seed from NumPy's global generator. ``seed=None``
will pull a seed from the OS entropy.
max_workers=None (int): A ``ProcessPoolExecutor`` executes tapes asynchronously
using a pool of at most ``max_workers`` processes. If ``max_workers`` is ``None``,
only the current process executes tapes. If you experience any
issue, say using JAX, TensorFlow, Torch, try setting ``max_workers`` to ``None``.
Expand Down Expand Up @@ -135,9 +137,10 @@ def name(self):
"""The name of the device."""
return "default.qubit.2"

def __init__(self, seed=None, max_workers=None) -> None:
def __init__(self, seed="global", max_workers=None) -> None:
super().__init__()
self._max_workers = max_workers
seed = np.random.randint(0, high=10000000) if seed == "global" else seed
self._rng = np.random.default_rng(seed)
self._debugger = None

Expand Down
45 changes: 43 additions & 2 deletions tests/devices/experimental/test_default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for default qubit 2."""
# pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel, no-member

import pytest

Expand Down Expand Up @@ -1456,7 +1456,7 @@ def test_different_executions(self, measurements, max_workers):
[qml.sample(wires=0), qml.expval(qml.PauliZ(0)), qml.probs(wires=0)],
],
)
def test_global_seed(self, measurements, max_workers):
def test_global_seed_and_device_seed(self, measurements, max_workers):
"""Test that a global seed does not affect the result of devices
provided with a seed"""
qs = qml.tape.QuantumScript([qml.Hadamard(0)], measurements, shots=1000)
Expand All @@ -1477,6 +1477,47 @@ def test_global_seed(self, measurements, max_workers):

assert all(np.all(res1 == res2) for res1, res2 in zip(result1, result2))

def test_global_seed_no_device_seed_by_default(self):
"""Test that the global numpy seed initializes the rng if device seed is none."""
np.random.seed(42)
dev = DefaultQubit2()
first_num = dev._rng.random() # pylint: disable=protected-access

np.random.seed(42)
dev2 = DefaultQubit2()
second_num = dev2._rng.random() # pylint: disable=protected-access

assert qml.math.allclose(first_num, second_num)

np.random.seed(42)
dev2 = DefaultQubit2(seed="global")
third_num = dev2._rng.random() # pylint: disable=protected-access

assert qml.math.allclose(third_num, first_num)

def test_None_seed_not_using_global_rng(self):
"""Test that if the seed is None, it is uncorrelated with the global rng."""
np.random.seed(42)
dev = DefaultQubit2(seed=None)
first_nums = dev._rng.random(10) # pylint: disable=protected-access

np.random.seed(42)
dev2 = DefaultQubit2(seed=None)
second_nums = dev2._rng.random(10) # pylint: disable=protected-access

assert not qml.math.allclose(first_nums, second_nums)

def test_rng_as_seed(self):
"""Test that a PRNG can be passed as a seed."""
rng1 = np.random.default_rng(42)
first_num = rng1.random()

rng = np.random.default_rng(42)
dev = DefaultQubit2(seed=rng)
second_num = dev._rng.random() # pylint: disable=protected-access

assert qml.math.allclose(first_num, second_num)


class TestHamiltonianSamples:
"""Test that the measure_with_samples function works as expected for
Expand Down

0 comments on commit 8265e8b

Please sign in to comment.