diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py
index bffe881b626..7baca14c93b 100644
--- a/dpnp/dpnp_iface_linearalgebra.py
+++ b/dpnp/dpnp_iface_linearalgebra.py
@@ -39,6 +39,7 @@
 import numpy
+from numpy.core.numeric import normalize_axis_tuple
 
 import dpnp
 from dpnp.dpnp_algo import *
@@ -66,9 +67,9 @@ def dot(a, b, out=None):
 
     Parameters
     ----------
-    a : {dpnp_array, usm_ndarray, scalar}
+    a : {dpnp.ndarray, usm_ndarray, scalar}
         First input array. Both inputs `a` and `b` can not be scalars at the same time.
-    b : {dpnp_array, usm_ndarray, scalar}
+    b : {dpnp.ndarray, usm_ndarray, scalar}
         Second input array. Both inputs `a` and `b` can not be scalars at the same time.
     out : {dpnp.ndarray, usm_ndarray}, optional
         Alternative output array in which to place the result. It must have
@@ -404,42 +405,152 @@ def outer(x1, x2, out=None):
     return call_origin(numpy.outer, x1, x2, out=out)
 
 
-def tensordot(x1, x2, axes=2):
-    """
+def tensordot(a, b, axes=2):
+    r"""
     Compute tensor dot product along specified axes.
 
     For full documentation refer to :obj:`numpy.tensordot`.
 
-    Limitations
-    -----------
-    Parameters `x1` and `x2` are supported as :obj:`dpnp.ndarray`.
-    Keyword argument `kwargs` is currently unsupported.
-    Parameter `axes` is supported only with value ``1``.
-    Otherwise the functions will be executed sequentially on CPU.
-    Input array data types are limited by supported DPNP :ref:`Data types`.
+    Parameters
+    ----------
+    a : {dpnp.ndarray, usm_ndarray, scalar}
+        First input array. Both inputs `a` and `b` can not be scalars at the same time.
+    b : {dpnp.ndarray, usm_ndarray, scalar}
+        Second input array. Both inputs `a` and `b` can not be scalars at the same time.
+    axes : int or (2,) array_like
+        * integer_like
+          If an int `N`, sum over the last `N` axes of `a` and the first `N` axes
+          of `b` in order. The sizes of the corresponding axes must match.
+        * (2,) array_like
+          Or, a list of axes to be summed over, first sequence applying to `a`,
+          second to `b`. Both sequences must be of the same length.
+
+    Returns
+    -------
+    out : dpnp.ndarray
+        Returns the tensordot product of `a` and `b`.
 
     See Also
     --------
     :obj:`dpnp.dot` : Returns the dot product.
     :obj:`dpnp.einsum` : Evaluates the Einstein summation convention on the operands.
 
+    Notes
+    -----
+    Three common use cases are:
+        * ``axes = 0`` : tensor product :math:`a \otimes b`
+        * ``axes = 1`` : tensor dot product :math:`a \cdot b`
+        * ``axes = 2`` : (default) tensor double contraction :math:`a:b`
+
+    When `axes` is an integer ``N``, the axes are paired in order: first
+    the -Nth axis in `a` with the 0th axis in `b`, and last the -1st axis
+    in `a` with the (N-1)th axis in `b`.
+
+    When there is more than one axis to sum over - and they are not the last
+    (first) axes of `a` (`b`) - the argument `axes` should consist of
+    two sequences of the same length, with the first axis to sum over given
+    first in both sequences, the second axis second, and so forth.
+
+    The shape of the result consists of the non-contracted axes of the
+    first tensor, followed by the non-contracted axes of the second.
+
     Examples
     --------
     >>> import dpnp as np
     >>> a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
     >>> b = np.array([1, 2, 3])
-    >>> result = np.tensordot(a, b, 1)
-    >>> [x for x in result]
-    [14, 32, 50]
+    >>> np.tensordot(a, b, 1)
+    array([14, 32, 50])
+
+    >>> a = np.arange(60.).reshape(3,4,5)
+    >>> b = np.arange(24.).reshape(4,3,2)
+    >>> c = np.tensordot(a,b, axes=([1,0],[0,1]))
+    >>> c.shape
+    (5, 2)
+    >>> c
+    array([[4400., 4730.],
+           [4532., 4874.],
+           [4664., 5018.],
+           [4796., 5162.],
+           [4928., 5306.]])
+
+    A slower but equivalent way of computing the same...
+
+    >>> d = np.zeros((5,2))
+    >>> for i in range(5):
+    ...     for j in range(2):
+    ...         for k in range(3):
+    ...             for n in range(4):
+    ...                 d[i,j] += a[k,n,i] * b[n,k,j]
+    >>> c == d
+    array([[ True,  True],
+           [ True,  True],
+           [ True,  True],
+           [ True,  True],
+           [ True,  True]])
 
     """
-    x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
-    x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
-    if x1_desc and x2_desc and (axes == 1):
-        return dpnp_tensordot_not_implemented(x1_desc, x2_desc)  # dpnp_matmul
+    dpnp.check_supported_arrays_type(a, b, scalar_type=True)
 
-    return call_origin(numpy.tensordot, x1, x2, axes)
+    if dpnp.isscalar(a):
+        a = dpnp.array(a, sycl_queue=b.sycl_queue, usm_type=b.usm_type)
+    elif dpnp.isscalar(b):
+        b = dpnp.array(b, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
+
+    try:
+        iter(axes)
+    except Exception:
+        if not isinstance(axes, int):
+            raise TypeError("Axes must be an integer.")
+        axes_a = tuple(range(-axes, 0))
+        axes_b = tuple(range(0, axes))
+    else:
+        if len(axes) != 2:
+            raise ValueError("Axes must consist of two sequences.")
+
+        axes_a, axes_b = axes
+        axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
+        axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b
+
+        if len(axes_a) != len(axes_b):
+            raise ValueError("Axes length mismatch.")
+
+    a_shape = a.shape
+    b_shape = b.shape
+    for axis_a, axis_b in zip(axes_a, axes_b):
+        if a_shape[axis_a] != b_shape[axis_b]:
+            raise ValueError(
+                "shape of input arrays is not similar at requested axes."
+            )
+
+    # Make the axes non-negative
+    a_ndim = a.ndim
+    b_ndim = b.ndim
+    axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis")
+    axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis")
+
+    # Move the axes to sum over, to the end of "a"
+    notin = tuple(k for k in range(a_ndim) if k not in axes_a)
+    newaxes_a = notin + axes_a
+    N1 = int(numpy.prod([a_shape[ax] for ax in notin]))
+    N2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
+    newshape_a = (N1, N2)
+    olda = [a_shape[axis] for axis in notin]
+
+    # Move the axes to sum over, to the front of "b"
+    notin = tuple(k for k in range(b_ndim) if k not in axes_b)
+    newaxes_b = tuple(axes_b + notin)
+    N1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
+    N2 = int(numpy.prod([b_shape[ax] for ax in notin]))
+    newshape_b = (N1, N2)
+    oldb = [b_shape[axis] for axis in notin]
+
+    at = a.transpose(newaxes_a).reshape(newshape_a)
+    bt = b.transpose(newaxes_b).reshape(newshape_b)
+    res = dpnp.matmul(at, bt)
+
+    return res.reshape(olda + oldb)
 
 
 def vdot(a, b):
@@ -450,11 +561,11 @@
 
     Parameters
     ----------
-    a : {dpnp_array, usm_ndarray, scalar}
+    a : {dpnp.ndarray, usm_ndarray, scalar}
         First input array. Both inputs `a` and `b` can not be scalars
         at the same time. If `a` is complex, the complex conjugate is taken
         before the calculation of the dot product.
-    b : {dpnp_array, usm_ndarray, scalar}
+    b : {dpnp.ndarray, usm_ndarray, scalar}
         Second input array.
         Both inputs `a` and `b` can not be scalars at the same time.
diff --git a/dpnp/dpnp_iface_sorting.py b/dpnp/dpnp_iface_sorting.py
index 6a3db20e74c..93e8db2172b 100644
--- a/dpnp/dpnp_iface_sorting.py
+++ b/dpnp/dpnp_iface_sorting.py
@@ -1,5 +1,3 @@
-# cython: language_level=3
-# distutils: language = c++
 # -*- coding: utf-8 -*-
 # *****************************************************************************
 # Copyright (c) 2016-2024, Intel Corporation
diff --git a/tests/helper.py b/tests/helper.py
index aac6b51a1c6..2a2873afdce 100644
--- a/tests/helper.py
+++ b/tests/helper.py
@@ -8,7 +8,11 @@
 
 
 def assert_dtype_allclose(
-    dpnp_arr, numpy_arr, check_type=True, check_only_type_kind=False
+    dpnp_arr,
+    numpy_arr,
+    check_type=True,
+    check_only_type_kind=False,
+    factor=8,
 ):
     """
     Assert DPNP and NumPy array based on maximum dtype resolution of input arrays
@@ -28,6 +32,7 @@ def assert_dtype_allclose(
     The 'check_only_type_kind' parameter (False by default) asserts only equal type kinds
     for all data types supported by DPNP when set to True.
     It is effective only when 'check_type' is also set to True.
+    The parameter `factor` scales the resolution used for comparing the arrays.
 
     """
 
@@ -44,7 +49,7 @@ def assert_dtype_allclose(
         if is_inexact(numpy_arr)
         else -dpnp.inf
     )
-    tol = 8 * max(tol_dpnp, tol_numpy)
+    tol = factor * max(tol_dpnp, tol_numpy)
     assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol)
     if check_type:
         numpy_arr_dtype = numpy_arr.dtype
diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl
index a38624e3757..182eaf8877a 100644
--- a/tests/skipped_tests.tbl
+++ b/tests/skipped_tests.tbl
@@ -335,10 +335,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes
 
 tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
 tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl
index ce6f6aef984..d6fd43e1887 100644
--- a/tests/skipped_tests_gpu.tbl
+++ b/tests/skipped_tests_gpu.tbl
@@ -437,10 +437,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
 tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes
-tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
 
 tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_broadcast_not_allowed
 tests/third_party/cupy/logic_tests/test_comparison.py::TestArrayEqual::test_array_equal_diff_dtypes_is_equal
diff --git a/tests/test_dot.py b/tests/test_dot.py
index 42478db9634..03045f002a8 100644
--- a/tests/test_dot.py
+++ b/tests/test_dot.py
@@ -44,9 +44,6 @@ def test_dot_scalar(self, dtype):
         expected = numpy.dot(a, b)
         assert_allclose(result, expected)
 
-    # TODO: get rid of falls back on NumPy when tensordot
-    # is implemented using OneMKL
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
     @pytest.mark.parametrize(
         "array_info",
@@ -88,9 +85,6 @@ def test_dot(self, dtype, array_info):
         expected = numpy.dot(a, b)
         assert_dtype_allclose(result, expected)
 
-    # TODO: get rid of falls back on NumPy when tensordot
-    # is implemented using OneMKL
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @pytest.mark.parametrize("dtype", get_complex_dtypes())
     @pytest.mark.parametrize(
         "array_info",
@@ -132,9 +126,6 @@ def test_dot_complex(self, dtype, array_info):
         expected = numpy.dot(a, b)
         assert_dtype_allclose(result, expected)
 
-    # TODO: get rid of falls back on NumPy when tensordot
-    # is implemented using OneMKL
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @pytest.mark.parametrize("dtype", get_all_dtypes())
     @pytest.mark.parametrize(
         "array_info",
@@ -214,9 +205,6 @@ def test_dot_out_scalar(self, dtype):
         assert result is dp_out
         assert_allclose(result, expected)
 
-    # TODO: get rid of falls back on NumPy when tensordot
-    # is implemented using OneMKL
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @pytest.mark.parametrize("dtype", get_all_dtypes())
     @pytest.mark.parametrize(
         "array_info",
@@ -294,21 +282,14 @@ def test_dot_out_error_scalar(self, ia):
         # output data type is incorrect
         dp_out = dpnp.empty((10,), dtype=dpnp.int64)
-        # TODO: change it to ValueError, when updated
-        # dpctl is being used in internal CI
-        with pytest.raises((ValueError, TypeError)):
+        with pytest.raises(ValueError):
             dpnp.dot(ia, ib, out=dp_out)
 
         # output shape is incorrect
         dp_out = dpnp.empty((2,), dtype=dpnp.int32)
-        # TODO: change it to ValueError, when updated
-        # dpctl is being used in internal CI
-        with pytest.raises((ValueError, TypeError)):
+        with pytest.raises(ValueError):
             dpnp.dot(ia, ib, out=dp_out)
 
-    # TODO: get rid of falls back on NumPy when tensordot
-    # is implemented using OneMKL
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @pytest.mark.parametrize(
         "shape_pair",
         [
@@ -373,6 +354,151 @@ def test_multi_dot(type):
     assert_array_equal(expected, result)
 
 
+class TestTensordot:
+    @pytest.mark.parametrize("dtype", get_all_dtypes())
+    def test_tensordot_scalar(self, dtype):
+        a = 2
+        b = numpy.array(numpy.random.uniform(-5, 5, 10), dtype=dtype)
+        ib = dpnp.array(b)
+
+        result = dpnp.tensordot(a, ib, axes=0)
+        expected = numpy.tensordot(a, b, axes=0)
+        assert_allclose(result, expected)
+
+        result = dpnp.tensordot(ib, a, axes=0)
+        expected = numpy.tensordot(b, a, axes=0)
+        assert_allclose(result, expected)
+
+    @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True))
+    @pytest.mark.parametrize("axes", [-3, -2, -1, 0, 1, 2])
+    def test_tensordot(self, dtype, axes):
+        a = numpy.array(numpy.random.uniform(-10, 10, 64), dtype=dtype).reshape(
+            4, 4, 4
+        )
+        b = numpy.array(numpy.random.uniform(-10, 10, 64), dtype=dtype).reshape(
+            4, 4, 4
+        )
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = dpnp.tensordot(ia, ib, axes=axes)
+        expected = numpy.tensordot(a, b, axes=axes)
+        # TODO: investigate the effect of factor, see SAT-6700
+        assert_dtype_allclose(result, expected, factor=24)
+
+    @pytest.mark.parametrize("dtype", get_complex_dtypes())
+    @pytest.mark.parametrize("axes", [-3, -2, -1, 0, 1, 2])
+    def test_tensordot_complex(self, dtype, axes):
+        x11 = numpy.random.uniform(-10, 10, 64)
+        x12 = numpy.random.uniform(-10, 10, 64)
+        x21 = numpy.random.uniform(-10, 10, 64)
+        x22 = numpy.random.uniform(-10, 10, 64)
+        a = numpy.array(x11 + 1j * x12, dtype=dtype).reshape(4, 4, 4)
+        b = numpy.array(x21 + 1j * x22, dtype=dtype).reshape(4, 4, 4)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = dpnp.tensordot(ia, ib, axes=axes)
+        expected = numpy.tensordot(a, b, axes=axes)
+        # TODO: investigate the effect of factor, see SAT-6700
+        assert_dtype_allclose(result, expected, factor=24)
+
+    @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
+    @pytest.mark.parametrize(
+        "axes",
+        [
+            ([0, 1]),
+            ([0, 1], [1, 2]),
+            (2, 3),
+            ([-2, -3], [3, 2]),
+            ((3, 1), (0, 2)),
+        ],
+    )
+    def test_tensordot_axes(self, dtype, axes):
+        a = numpy.array(
+            numpy.random.uniform(-10, 10, 120), dtype=dtype
+        ).reshape(2, 5, 3, 4)
+        b = numpy.array(
+            numpy.random.uniform(-10, 10, 120), dtype=dtype
+        ).reshape(4, 2, 5, 3)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = dpnp.tensordot(ia, ib, axes=axes)
+        expected = numpy.tensordot(a, b, axes=axes)
+        # TODO: investigate the effect of factor, see SAT-6700
+        assert_dtype_allclose(result, expected, factor=24)
+
+    @pytest.mark.parametrize("dtype1", get_all_dtypes())
+    @pytest.mark.parametrize("dtype2", get_all_dtypes())
+    def test_tensordot_input_dtype_matrix(self, dtype1, dtype2):
+        a = numpy.array(
+            numpy.random.uniform(-10, 10, 60), dtype=dtype1
+        ).reshape(3, 4, 5)
+        b = numpy.array(
+            numpy.random.uniform(-10, 10, 40), dtype=dtype2
+        ).reshape(4, 5, 2)
+        ia = dpnp.array(a)
+        ib = dpnp.array(b)
+
+        result = dpnp.tensordot(ia, ib)
+        expected = numpy.tensordot(a, b)
+        # TODO: investigate the effect of factor, see SAT-6700
+        assert_dtype_allclose(result, expected, factor=24)
+
+    def test_tensordot_strided(self):
+        for dim in [1, 2, 3, 4]:
+            axes = 1 if dim == 1 else 2
+            A = numpy.random.rand(*([10] * dim))
+            B = dpnp.asarray(A)
+            # positive stride
+            slices = tuple(slice(None, None, 2) for _ in range(dim))
+            a = A[slices]
+            b = B[slices]
+
+            result = dpnp.tensordot(b, b, axes=axes)
+            expected = numpy.tensordot(a, a, axes=axes)
+            assert_dtype_allclose(result, expected)
+
+            # negative stride
+            slices = tuple(slice(None, None, -2) for _ in range(dim))
+            a = A[slices]
+            b = B[slices]
+
+            result = dpnp.tensordot(b, b, axes=axes)
+            expected = numpy.tensordot(a, a, axes=axes)
+            assert_dtype_allclose(result, expected)
+
+    def test_tensordot_error(self):
+        a = 5
+        b = 2
+        # both inputs are scalar
+        with pytest.raises(TypeError):
+            dpnp.tensordot(a, b, axes=0)
+
+        a = dpnp.arange(24).reshape(2, 3, 4)
+        b = dpnp.arange(24).reshape(3, 4, 2)
+        # axes should be an integer
+        with pytest.raises(TypeError):
+            dpnp.tensordot(a, b, axes=2.0)
+
+        # Axes must consist of two sequences
+        with pytest.raises(ValueError):
+            dpnp.tensordot(a, b, axes=([0, 2],))
+
+        # Axes length mismatch
+        with pytest.raises(ValueError):
+            dpnp.tensordot(a, b, axes=([0, 2], [2]))
+
+        # shape of input arrays is not similar at requested axes
+        with pytest.raises(ValueError):
+            dpnp.tensordot(a, b, axes=([0, 2], [2, 0]))
+
+        # out of range index
+        with pytest.raises(IndexError):
+            dpnp.tensordot(a, b, axes=([0, 3], [2, 0]))
+
+
 class TestVdot:
     @pytest.mark.parametrize("dtype", get_all_dtypes())
     def test_vdot_scalar(self, dtype):
diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py
index 80fe09c61b8..12115b5256c 100644
--- a/tests/test_mathematical.py
+++ b/tests/test_mathematical.py
@@ -2726,7 +2726,7 @@ def test_matmul_strided(self):
         for dim in [1, 2, 3, 4]:
             A = numpy.random.rand(*([20] * dim))
             B = dpnp.asarray(A)
-            # positive strides
+            # positive stride
             slices = tuple(slice(None, None, 2) for _ in range(dim))
             a = A[slices]
             b = B[slices]
@@ -2735,7 +2735,7 @@ def test_matmul_strided(self):
             expected = numpy.matmul(a, a)
             assert_dtype_allclose(result, expected)
 
-            # negative strides
+            # negative stride
             slices = tuple(slice(None, None, -2) for _ in range(dim))
             a = A[slices]
             b = B[slices]
diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py
index 6bc24af6c7d..479e96e0229 100644
--- a/tests/test_sycl_queue.py
+++ b/tests/test_sycl_queue.py
@@ -579,6 +579,11 @@ def test_reduce_hypot(device):
             [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
             [0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0],
         ),
+        pytest.param(
+            "tensordot",
+            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
+            [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]],
+        ),
         # dpnp.vdot has 3 different implementations based on input arrays dtype
         # checking all of them
         pytest.param("vdot", [3.0, 4.0, 5.0], [1.0, 2.0, 3.0]),
diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py
index e188cdb1c47..21dfb3cde67 100644
--- a/tests/test_usm_type.py
+++ b/tests/test_usm_type.py
@@ -505,6 +505,11 @@ def test_1in_1out(func, data, usm_type):
     pytest.param("logaddexp", [[-1, 2, 5, 9]], [[4, -3, 2, -8]]),
     pytest.param("maximum", [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]),
     pytest.param("minimum", [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]),
+    pytest.param(
+        "tensordot",
+        [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
+        [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]],
+    ),
     # dpnp.vdot has 3 different implementations based on input arrays dtype
     # checking all of them
     pytest.param("vdot", [3.0, 4.0, 5.0], [1.0, 2.0, 3.0]),
diff --git a/tests/third_party/cupy/linalg_tests/test_product.py b/tests/third_party/cupy/linalg_tests/test_product.py
index 1fd048356b4..e59b30dcd6e 100644
--- a/tests/third_party/cupy/linalg_tests/test_product.py
+++ b/tests/third_party/cupy/linalg_tests/test_product.py
@@ -36,9 +36,6 @@
         }
     )
 )
-# TODO: get rid of falls back on NumPy when tensordot
-# is implemented using OneMKL
-@pytest.mark.usefixtures("allow_fall_back_on_numpy")
 class TestDot(unittest.TestCase):
     @testing.for_all_dtypes_combination(["dtype_a", "dtype_b"])
     @testing.numpy_cupy_allclose(type_check=has_support_aspect64())
@@ -161,9 +158,6 @@ def test_dot_vec1(self, xp, dtype):
         b = testing.shaped_arange((2,), xp, dtype)
         return xp.dot(a, b)
 
-    # TODO: get rid of falls back on NumPy when tensordot
-    # is implemented using OneMKL
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
     def test_dot_vec2(self, xp, dtype):
@@ -178,9 +172,6 @@ def test_dot_vec3(self, xp, dtype):
         b = testing.shaped_arange((2,), xp, dtype)
         return xp.dot(a, b)
 
-    # TODO: get rid of falls back on NumPy when tensordot
-    # is implemented using OneMKL
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
     def test_transposed_dot(self, xp, dtype):
@@ -188,9 +179,6 @@ def test_transposed_dot(self, xp, dtype):
         b = testing.shaped_arange((2, 3, 4), xp, dtype).transpose(0, 2, 1)
         return xp.dot(a, b)
 
-    # TODO: get rid of falls back on NumPy when tensordot
-    # is implemented using OneMKL
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
     def test_transposed_dot_with_out(self, xp, dtype):
@@ -200,9 +188,6 @@ def test_transposed_dot_with_out(self, xp, dtype):
         c = xp.zeros((3, 2, 3, 2), dtype=dtype)
         xp.dot(a, b, out=c)
         return c
 
-    # TODO: get rid of falls back on NumPy when tensordot
-    # is implemented using OneMKL
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     def test_transposed_dot_with_out_f_contiguous(self, dtype):
         for xp in (numpy, cupy):
@@ -307,7 +292,6 @@ def test_multidim_outer(self, xp, dtype):
         a = testing.shaped_arange((2, 3), xp, dtype)
         b = testing.shaped_arange((4, 5), xp, dtype)
         return xp.outer(a, b)
 
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
     def test_tensordot(self, xp, dtype):
@@ -322,7 +306,6 @@ def test_transposed_tensordot(self, xp, dtype):
         b = testing.shaped_arange((4, 3, 2), xp, dtype).transpose(2, 0, 1)
         return xp.tensordot(a, b)
 
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
     def test_tensordot_with_int_axes(self, xp, dtype):
@@ -352,7 +335,6 @@ def test_transposed_tensordot_with_int_axes(self, xp, dtype):
         )
         return xp.tensordot(a, b, axes=3)
 
-    @pytest.mark.usefixtures("allow_fall_back_on_numpy")
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
     def test_tensordot_with_list_axes(self, xp, dtype):
@@ -433,8 +415,6 @@ def test_zerodim_kron(self, xp, dtype):
         }
     )
 )
-@pytest.mark.usefixtures("allow_fall_back_on_numpy")
-@testing.gpu
 class TestProductZeroLength(unittest.TestCase):
     @testing.for_all_dtypes()
     @testing.numpy_cupy_allclose()
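
The core of this patch is the classic NumPy reduction of a tensor contraction to a single 2-D matrix product: the summed axes of `a` move to the end, the summed axes of `b` move to the front, both operands collapse to matrices, one product performs the whole contraction, and the result is reshaped back to the free axes. Below is a minimal self-contained sketch of that reduction in plain NumPy (not part of the patch; the helper name `tensordot_via_matmul` is illustrative, and the axes are assumed already non-negative, which the patch guarantees via `normalize_axis_tuple`):

    import numpy

    def tensordot_via_matmul(a, b, axes_a, axes_b):
        # Free (non-contracted) axes keep their original order in the result.
        free_a = tuple(k for k in range(a.ndim) if k not in axes_a)
        free_b = tuple(k for k in range(b.ndim) if k not in axes_b)

        m = int(numpy.prod([a.shape[k] for k in free_a]))
        n = int(numpy.prod([b.shape[k] for k in free_b]))
        contracted = int(numpy.prod([a.shape[k] for k in axes_a]))

        # Summed axes go to the end of `a` and the front of `b`,
        # then each operand collapses to a 2-D matrix.
        at = a.transpose(free_a + axes_a).reshape(m, contracted)
        bt = b.transpose(axes_b + free_b).reshape(contracted, n)

        # A single matrix product performs the whole contraction.
        res = at @ bt
        return res.reshape(
            [a.shape[k] for k in free_a] + [b.shape[k] for k in free_b]
        )

    # Reproduces the (5, 2) result from the docstring example.
    a = numpy.arange(60.0).reshape(3, 4, 5)
    b = numpy.arange(24.0).reshape(4, 3, 2)
    expected = numpy.tensordot(a, b, axes=([1, 0], [0, 1]))
    result = tensordot_via_matmul(a, b, (1, 0), (0, 1))
    assert numpy.allclose(result, expected)

The patched `dpnp.tensordot` performs the same steps but issues the 2-D product through `dpnp.matmul`, which is why the `allow_fall_back_on_numpy` fixtures and the skipped-test entries above can be removed.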