Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement dpnp.tensordot #1699

Merged
merged 10 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 128 additions & 17 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@


import numpy
from numpy.core.numeric import normalize_axis_tuple

import dpnp
from dpnp.dpnp_algo import *
Expand Down Expand Up @@ -404,42 +405,152 @@ def outer(x1, x2, out=None):
return call_origin(numpy.outer, x1, x2, out=out)


def tensordot(a, b, axes=2):
    r"""
    Compute tensor dot product along specified axes.

    For full documentation refer to :obj:`numpy.tensordot`.

    Parameters
    ----------
    a : {dpnp_array, usm_ndarray, scalar}
        First input array. Both inputs `a` and `b` can not be scalars
        at the same time.
    b : {dpnp_array, usm_ndarray, scalar}
        Second input array. Both inputs `a` and `b` can not be scalars
        at the same time.
    axes : int or (2,) array_like
        * integer_like
          If an int `N`, sum over the last `N` axes of `a` and the first `N`
          axes of `b` in order. The sizes of the corresponding axes must
          match. `N` must be non-negative.
        * (2,) array_like
          Or, a list of axes to be summed over, first sequence applying to
          `a`, second to `b`. Both elements array_like must be of the same
          length.

    Returns
    -------
    out : dpnp.ndarray
        Returns the tensordot product of `a` and `b`.

    Raises
    ------
    ValueError
        If `axes` is neither an integer nor a pair of sequences, if it is a
        negative integer, if the two axis sequences differ in length, if an
        axis is out of bounds or repeated, or if the sizes of `a` and `b`
        differ along a pair of contracted axes.

    See Also
    --------
    :obj:`dpnp.dot` : Returns the dot product.
    :obj:`dpnp.einsum` : Evaluates the Einstein summation convention
                         on the operands.

    Notes
    -----
    Three common use cases are:

    * ``axes = 0`` : tensor product :math:`a \otimes b`
    * ``axes = 1`` : tensor dot product :math:`a \cdot b`
    * ``axes = 2`` : (default) tensor double contraction :math:`a:b`

    When `axes` is integer, the sequence for evaluation will be: first
    the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
    Nth axis in `b` last.

    When there is more than one axis to sum over - and they are not the last
    (first) axes of `a` (`b`) - the argument `axes` should consist of
    two sequences of the same length, with the first axis to sum over given
    first in both sequences, the second axis second, and so forth.

    The shape of the result consists of the non-contracted axes of the
    first tensor, followed by the non-contracted axes of the second.

    Examples
    --------
    >>> import dpnp as np
    >>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> b = np.array([1, 2, 3])
    >>> np.tensordot(a, b, 1)
    array([14, 32, 50])

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
    >>> c.shape
    (5, 2)
    >>> c
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    A slower but equivalent way of computing the same...

    >>> d = np.zeros((5,2))
    >>> for i in range(5):
    ...   for j in range(2):
    ...     for k in range(3):
    ...       for n in range(4):
    ...         d[i,j] += a[k,n,i] * b[n,k,j]
    >>> c == d
    array([[ True,  True],
           [ True,  True],
           [ True,  True],
           [ True,  True],
           [ True,  True]])

    """

    dpnp.check_supported_arrays_type(a, b, scalar_type=True)

    # A scalar operand is turned into an array created with the queue and
    # USM type of the other operand.
    # NOTE(review): this relies on check_supported_arrays_type rejecting the
    # case where both inputs are scalars - confirm.
    if dpnp.isscalar(a):
        a = dpnp.array(a, sycl_queue=b.sycl_queue, usm_type=b.usm_type)
    elif dpnp.isscalar(b):
        b = dpnp.array(b, sycl_queue=a.sycl_queue, usm_type=a.usm_type)

    try:
        iter(axes)
    except TypeError:
        # Not iterable: `axes` has to be an integer count N.
        if not isinstance(axes, (int, numpy.integer)):
            raise ValueError("Axes must be an integer.")
        axes = int(axes)
        if axes < 0:
            # A negative count would otherwise yield empty ranges and be
            # silently treated as ``axes=0``.
            raise ValueError("Axes must be a non-negative integer.")
        axes_a = tuple(range(-axes, 0))
        axes_b = tuple(range(axes))
    else:
        if len(axes) != 2:
            raise ValueError("Axes must consist of two sequences.")

        axes_a, axes_b = axes
        axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
        axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b

        if len(axes_a) != len(axes_b):
            raise ValueError("Axes length mismatch.")

    a_shape = a.shape
    b_shape = b.shape
    a_ndim = a.ndim
    b_ndim = b.ndim

    # Make the axes non-negative. This is done before the shapes are indexed
    # so that an out-of-range axis is reported by the normalization helper
    # instead of surfacing as a bare IndexError from tuple indexing.
    axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis")
    axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis")

    for axis_a, axis_b in zip(axes_a, axes_b):
        if a_shape[axis_a] != b_shape[axis_b]:
            raise ValueError(
                "shape of input arrays is not similar at requested axes."
            )

    # Move the axes to sum over, to the end of "a"
    notin = tuple(k for k in range(a_ndim) if k not in axes_a)
    newaxes_a = notin + axes_a
    N1 = int(numpy.prod([a_shape[ax] for ax in notin]))
    N2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
    newshape_a = (N1, N2)
    olda = [a_shape[axis] for axis in notin]

    # Move the axes to sum over, to the front of "b"
    notin = tuple(k for k in range(b_ndim) if k not in axes_b)
    newaxes_b = tuple(axes_b + notin)
    N1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
    N2 = int(numpy.prod([b_shape[ax] for ax in notin]))
    newshape_b = (N1, N2)
    oldb = [b_shape[axis] for axis in notin]

    # Collapse both operands to 2-D so the contraction is one matrix product,
    # then restore the non-contracted axes of "a" followed by those of "b".
    at = a.transpose(newaxes_a).reshape(newshape_a)
    bt = b.transpose(newaxes_b).reshape(newshape_b)
    res = dpnp.matmul(at, bt)

    return res.reshape(olda + oldb)


def vdot(a, b):
Expand Down
2 changes: 0 additions & 2 deletions dpnp/dpnp_iface_sorting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# cython: language_level=3
# distutils: language = c++
# -*- coding: utf-8 -*-
# *****************************************************************************
# Copyright (c) 2016-2024, Intel Corporation
Expand Down
4 changes: 0 additions & 4 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes

tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
Expand Down
4 changes: 0 additions & 4 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim

tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
Expand Down
Loading
Loading