Skip to content

Commit

Permalink
Fix jax.ad deprecation. (#4403)
Browse files Browse the repository at this point in the history
* Fix jax.ad deprecation.

* Update changelog.

* Print name of requested device in DeviceError.

* Force push.

* Fix pytest.raises in test_device.py
  • Loading branch information
vincentmr authored Jul 31, 2023
1 parent 7d64bdc commit 81c8477
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 15 deletions.
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,11 @@

<h3>Bug fixes 🐛</h3>

* Replace deprecated `jax.ad` by `jax.interpreters.ad`.
[(#4403)](https://github.com/PennyLaneAI/pennylane/pull/4403)

* Stop `metric_tensor` from accidentally catching errors that stem from
flawed wires assignments in the original circuit, leading to recursion errors
flawed wires assignments in the original circuit, leading to recursion errors.
[(#4328)](https://github.com/PennyLaneAI/pennylane/pull/4328)

* Raise a warning if control indicators are hidden when calling `qml.draw_mpl`
Expand Down
2 changes: 1 addition & 1 deletion pennylane/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def run_cnot():

return dev

raise DeviceError("Device does not exist. Make sure the required plugin is installed.")
raise DeviceError(f"Device {name} does not exist. Make sure the required plugin is installed.")


def version():
Expand Down
8 changes: 5 additions & 3 deletions pennylane/math/multi_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,11 +932,13 @@ def jax_argnums_to_tape_trainable(qnode, argnums, expand_fn, args, kwargs):
"""
import jax

with jax.core.new_main(jax.ad.JVPTrace) as main:
trace = jax.ad.JVPTrace(main, 0)
with jax.core.new_main(jax.interpreters.ad.JVPTrace) as main:
trace = jax.interpreters.ad.JVPTrace(main, 0)

args_jvp = [
jax.ad.JVPTracer(trace, arg, jax.numpy.zeros(arg.shape)) if i in argnums else arg
jax.interpreters.ad.JVPTracer(trace, arg, jax.numpy.zeros(arg.shape))
if i in argnums
else arg
for i, arg in enumerate(args)
]

Expand Down
28 changes: 18 additions & 10 deletions tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

mock_device_paulis = ["PauliX", "PauliY", "PauliZ"]

# pylint: disable=abstract-class-instantiated, no-self-use, redefined-outer-name, invalid-name
# pylint: disable=abstract-class-instantiated, no-self-use, redefined-outer-name, invalid-name, missing-function-docstring


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -253,16 +253,16 @@ def test_supports_observable_exception(self, mock_device):
class TestInternalFunctions:
"""Test the internal functions of the abstract Device class"""

# pylint: disable=unnecessary-dunder-call
def test_repr(self, mock_device_with_operations):
"""Tests the __repr__ function"""
dev = mock_device_with_operations()
repr_string = dev.__repr__()
assert "<Device device (wires=1, shots=1000) at " in repr_string
assert "<Device device (wires=1, shots=1000) at " in dev.__repr__()

def test_str(self, mock_device_with_operations):
"""Tests the __str__ function"""
dev = mock_device_with_operations()
string = dev.__str__()
string = str(dev)
assert "Short name: MockDevice" in string
assert "Package: pennylane" in string
assert "Plugin version: None" in string
Expand Down Expand Up @@ -810,7 +810,7 @@ class TestDeviceInit:
def test_no_device(self):
"""Test that an exception is raised for a device that doesn't exist"""

with pytest.raises(DeviceError, match="Device does not exist"):
with pytest.raises(DeviceError, match="Device None does not exist"):
qml.device("None", wires=0)

def test_outdated_API(self, monkeypatch):
Expand Down Expand Up @@ -963,7 +963,7 @@ def test_result_empty_tape(self, mock_device_with_paulis_and_methods, tol):
class TestGrouping:
"""Tests for the use_grouping option for devices."""

# pylint: disable=too-few-public-methods
# pylint: disable=too-few-public-methods, unused-argument, missing-function-docstring, missing-class-docstring
class SomeDevice(qml.Device):
name = ""
short_name = ""
Expand All @@ -972,10 +972,18 @@ class SomeDevice(qml.Device):
author = ""
operations = ""
observables = ""
apply = lambda *args, **kwargs: 0
expval = lambda *args, **kwargs: 0
reset = lambda *args, **kwargs: 0
supports_observable = lambda *args, **kwargs: True

def apply(self, *args, **kwargs):
return 0

def expval(self, *args, **kwargs):
return 0

def reset(self, *args, **kwargs):
return 0

def supports_observable(self, *args, **kwargs):
return True

# pylint: disable=attribute-defined-outside-init
@pytest.mark.parametrize("use_grouping", (True, False))
Expand Down

0 comments on commit 81c8477

Please sign in to comment.