Skip to content

Commit

Permalink
[BUGFIX] Fix math.get_dtype_name for ArrayBox (#4494)
Browse files Browse the repository at this point in the history
* get_dtype_name with array boxes

* Update doc/releases/changelog-dev.md

* Update tests/math/test_functions.py

* try and fix docstring

* Update tests/math/test_functions.py

* just remove uncovered line

---------

Co-authored-by: Matthew Silverman <[email protected]>
Co-authored-by: Mudit Pandey <[email protected]>
  • Loading branch information
3 people authored Aug 18, 2023
1 parent db58c54 commit 6bcfdf9
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 2 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@ array([False, False])

<h3>Bug fixes 🐛</h3>

* `qml.math.get_dtype_name` now works with autograd array boxes.
[(#4494)](https://github.com/PennyLaneAI/pennylane/pull/4494)

* `_copy_and_shift_params` does not cast or convert integral types, just relying on `+` and `*`'s casting rules in this case.
[(#4477)](https://github.com/PennyLaneAI/pennylane/pull/4477)

Expand Down
12 changes: 11 additions & 1 deletion pennylane/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,16 @@
sum = ar.numpy.sum
toarray = ar.numpy.to_numpy
T = ar.numpy.transpose
get_dtype_name = ar.get_dtype_name


def get_dtype_name(x) -> str:
    """An interface independent way of getting the name of the datatype.

    Args:
        x (tensor_like): the tensor whose datatype name is requested; may be
            a numpy array, autograd ``ArrayBox``, tensorflow, or torch tensor

    Returns:
        str: the name of the datatype, e.g. ``"float32"`` or ``"int64"``

    **Example**

    >>> qml.math.get_dtype_name(tf.Variable(0.1))
    'float32'
    """
    # Delegate to autoray, which dispatches on the tensor's interface.
    return ar.get_dtype_name(x)


class NumpyMimic(ar.autoray.NumpyMimic):
Expand Down Expand Up @@ -142,6 +151,7 @@ def __getattr__(name):
"fidelity",
"fidelity_statevector",
"frobenius_inner_product",
"get_dtype_name",
"get_interface",
"get_trainable_indices",
"in_backprop",
Expand Down
9 changes: 9 additions & 0 deletions pennylane/math/single_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ def _cond(pred, true_fn, false_fn, args):
ar.register_function("autograd", "unstack", list)


def autograd_get_dtype_name(x):
    """An autograd version of ``get_dtype_name`` that can handle ``ArrayBox`` objects.

    During autograd differentiation, tensors are wrapped in ``ArrayBox``
    containers; the dtype lives on the wrapped value, so we unwrap one level
    via ``_value`` before dispatching back to autoray.
    """
    # NOTE(review): this dispatch appears to be reached only when x is an
    # ArrayBox — confirm; a plain ndarray input would fail on ``._value``.
    return ar.get_dtype_name(x._value)


# Register the override so autoray routes autograd-interface calls here.
ar.register_function("autograd", "get_dtype_name", autograd_get_dtype_name)


def _block_diag_autograd(tensors):
"""Autograd implementation of scipy.linalg.block_diag"""
_np = _i("qml").numpy
Expand Down
25 changes: 24 additions & 1 deletion tests/math/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the TensorBox functional API in pennylane.fn.fn
"""Unit tests for pennylane.math.single_dispatch
"""
# pylint: disable=import-outside-toplevel
import itertools
Expand Down Expand Up @@ -1327,6 +1327,29 @@ def test_shape(shape, interface, create_array):
assert fn.shape(t) == shape


@pytest.mark.parametrize(
    "x, expected",
    (
        (1.0, "float64"),
        (1, "int64"),
        (onp.array(0.5), "float64"),
        (onp.array(1.0, dtype="float32"), "float32"),
        # autograd ArrayBox: dtype must be read from the wrapped value
        (ArrayBox(1, "a", "b"), "int64"),
        (np.array(0.5), "float64"),
        (np.array(0.5, dtype="complex64"), "complex64"),
        # skip jax as output is dependent on global configuration
        (tf.Variable(0.1, dtype="float32"), "float32"),
        (tf.Variable(0.1, dtype="float64"), "float64"),
        (torch.tensor(0.1, dtype=torch.float32), "float32"),
        (torch.tensor(0.5, dtype=torch.float64), "float64"),
        (torch.tensor(0.1, dtype=torch.complex128), "complex128"),
    ),
)
def test_get_dtype_name(x, expected):
    """Test that ``get_dtype_name`` returns the expected dtype-name string
    across interfaces (python scalars, numpy, autograd ArrayBox, tf, torch)."""
    assert fn.get_dtype_name(x) == expected


@pytest.mark.parametrize("t", test_data)
def test_sqrt(t):
"""Test that the square root function works for a variety
Expand Down

0 comments on commit 6bcfdf9

Please sign in to comment.