From 6bcfdf9cd24444f7dc0c9216a7b1a673a08b1f88 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Fri, 18 Aug 2023 19:40:21 -0400 Subject: [PATCH] [BUGFIX] Fix `math.get_dtype_name` for `ArrayBox` (#4494) * get_dtype_name with array boxes * Update doc/releases/changelog-dev.md * Update tests/math/test_functions.py * try and fix docstring * Update tests/math/test_functions.py * just remove uncovered line --------- Co-authored-by: Matthew Silverman Co-authored-by: Mudit Pandey --- doc/releases/changelog-dev.md | 3 +++ pennylane/math/__init__.py | 12 +++++++++++- pennylane/math/single_dispatch.py | 9 +++++++++ tests/math/test_functions.py | 25 ++++++++++++++++++++++++- 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index ef7b0df39a1..729812f61fa 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -373,6 +373,9 @@ array([False, False])

Bug fixes 🐛

+* `qml.math.get_dtype_name` now works with autograd array boxes. + [(#4494)](https://github.com/PennyLaneAI/pennylane/pull/4494) + * `_copy_and_shift_params` does not cast or convert integral types, just relying on `+` and `*`'s casting rules in this case. [(#4477)](https://github.com/PennyLaneAI/pennylane/pull/4477) diff --git a/pennylane/math/__init__.py b/pennylane/math/__init__.py index cc3fe0ec0af..e6ea3aec4f7 100644 --- a/pennylane/math/__init__.py +++ b/pennylane/math/__init__.py @@ -95,7 +95,16 @@ sum = ar.numpy.sum toarray = ar.numpy.to_numpy T = ar.numpy.transpose -get_dtype_name = ar.get_dtype_name + + +def get_dtype_name(x) -> str: + """An interface independent way of getting the name of the datatype. + + >>> x = tf.Variable(0.1) + >>> qml.math.get_dtype_name(tf.Variable(0.1)) + 'float32' + """ + return ar.get_dtype_name(x) class NumpyMimic(ar.autoray.NumpyMimic): @@ -142,6 +151,7 @@ def __getattr__(name): "fidelity", "fidelity_statevector", "frobenius_inner_product", + "get_dtype_name", "get_interface", "get_trainable_indices", "in_backprop", diff --git a/pennylane/math/single_dispatch.py b/pennylane/math/single_dispatch.py index 6fe4fc8e684..1ca2a96dfde 100644 --- a/pennylane/math/single_dispatch.py +++ b/pennylane/math/single_dispatch.py @@ -116,6 +116,15 @@ def _cond(pred, true_fn, false_fn, args): ar.register_function("autograd", "unstack", list) +def autograd_get_dtype_name(x): + """A autograd version of get_dtype_name that can handle array boxes.""" + # this function seems to only get called with x is an arraybox. + return ar.get_dtype_name(x._value) + + +ar.register_function("autograd", "get_dtype_name", autograd_get_dtype_name) + + def _block_diag_autograd(tensors): """Autograd implementation of scipy.linalg.block_diag""" _np = _i("qml").numpy diff --git a/tests/math/test_functions.py b/tests/math/test_functions.py index 90713998423..913ad85a407 100644 --- a/tests/math/test_functions.py +++ b/tests/math/test_functions.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for the TensorBox functional API in pennylane.fn.fn +"""Unit tests for pennylane.math.single_dispatch """ # pylint: disable=import-outside-toplevel import itertools @@ -1327,6 +1327,29 @@ def test_shape(shape, interface, create_array): assert fn.shape(t) == shape +@pytest.mark.parametrize( + "x, expected", + ( + (1.0, "float64"), + (1, "int64"), + (onp.array(0.5), "float64"), + (onp.array(1.0, dtype="float32"), "float32"), + (ArrayBox(1, "a", "b"), "int64"), + (np.array(0.5), "float64"), + (np.array(0.5, dtype="complex64"), "complex64"), + # skip jax as output is dependent on global configuration + (tf.Variable(0.1, dtype="float32"), "float32"), + (tf.Variable(0.1, dtype="float64"), "float64"), + (torch.tensor(0.1, dtype=torch.float32), "float32"), + (torch.tensor(0.5, dtype=torch.float64), "float64"), + (torch.tensor(0.1, dtype=torch.complex128), "complex128"), + ), +) +def test_get_dtype_name(x, expected): + """Test that get_dtype_name returns the a string for the datatype.""" + assert fn.get_dtype_name(x) == expected + + @pytest.mark.parametrize("t", test_data) def test_sqrt(t): """Test that the square root function works for a variety