Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update dpnp.kron implementation #1732

Merged
merged 2 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,6 @@ enum class DPNPFuncName : size_t
DPNP_FN_INV, /**< Used in numpy.linalg.inv() impl */
DPNP_FN_INVERT, /**< Used in numpy.invert() impl */
DPNP_FN_KRON, /**< Used in numpy.kron() impl */
DPNP_FN_KRON_EXT, /**< Used in numpy.kron() impl, requires extra parameters
*/
DPNP_FN_LEFT_SHIFT, /**< Used in numpy.left_shift() impl */
DPNP_FN_LOG, /**< Used in numpy.log() impl */
DPNP_FN_LOG10, /**< Used in numpy.log10() impl */
Expand Down
73 changes: 0 additions & 73 deletions dpnp/backend/kernels/dpnp_krnl_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,18 +499,6 @@ void (*dpnp_kron_default_c)(void *,
size_t) =
dpnp_kron_c<_DataType1, _DataType2, _ResultType>;

template <typename _DataType1, typename _DataType2, typename _ResultType>
DPCTLSyclEventRef (*dpnp_kron_ext_c)(DPCTLSyclQueueRef,
void *,
void *,
void *,
shape_elem_type *,
shape_elem_type *,
shape_elem_type *,
size_t,
const DPCTLEventVectorRef) =
dpnp_kron_c<_DataType1, _DataType2, _ResultType>;

template <typename _DataType>
DPCTLSyclEventRef
dpnp_matrix_rank_c(DPCTLSyclQueueRef q_ref,
Expand Down Expand Up @@ -890,67 +878,6 @@ void func_map_init_linalg_func(func_map_t &fmap)
(void *)dpnp_kron_default_c<std::complex<double>, std::complex<double>,
std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_kron_ext_c<int32_t, int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_LNG] = {
eft_LNG, (void *)dpnp_kron_ext_c<int32_t, int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_FLT] = {
eft_FLT, (void *)dpnp_kron_ext_c<int32_t, float, float>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_DBL] = {
eft_DBL, (void *)dpnp_kron_ext_c<int32_t, double, double>};
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_INT][eft_C128] = {
// eft_C128, (void*)dpnp_kron_ext_c<int32_t, std::complex<double>,
// std::complex<double>>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_INT] = {
eft_LNG, (void *)dpnp_kron_ext_c<int64_t, int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_kron_ext_c<int64_t, int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_FLT] = {
eft_FLT, (void *)dpnp_kron_ext_c<int64_t, float, float>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_DBL] = {
eft_DBL, (void *)dpnp_kron_ext_c<int64_t, double, double>};
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_LNG][eft_C128] = {
// eft_C128, (void*)dpnp_kron_ext_c<int64_t, std::complex<double>,
// std::complex<double>>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_INT] = {
eft_FLT, (void *)dpnp_kron_ext_c<float, int32_t, float>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_LNG] = {
eft_FLT, (void *)dpnp_kron_ext_c<float, int64_t, float>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_kron_ext_c<float, float, float>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_DBL] = {
eft_DBL, (void *)dpnp_kron_ext_c<float, double, double>};
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_FLT][eft_C128] = {
// eft_C128, (void*)dpnp_kron_ext_c<float, std::complex<double>,
// std::complex<double>>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_INT] = {
eft_DBL, (void *)dpnp_kron_ext_c<double, int32_t, double>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_LNG] = {
eft_DBL, (void *)dpnp_kron_ext_c<double, int64_t, double>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_FLT] = {
eft_DBL, (void *)dpnp_kron_ext_c<double, float, double>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_kron_ext_c<double, double, double>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_DBL][eft_C128] = {
eft_C128, (void *)dpnp_kron_ext_c<double, std::complex<double>,
std::complex<double>>};
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_INT] = {
// eft_C128, (void*)dpnp_kron_ext_c<std::complex<double>, int32_t,
// std::complex<double>>};
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_LNG] = {
// eft_C128, (void*)dpnp_kron_ext_c<std::complex<double>, int64_t,
// std::complex<double>>};
// fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_FLT] = {
// eft_C128, (void*)dpnp_kron_ext_c<std::complex<double>, float,
// std::complex<double>>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_DBL] = {
eft_C128, (void *)dpnp_kron_ext_c<std::complex<double>, double,
std::complex<double>>};
fmap[DPNPFuncName::DPNP_FN_KRON_EXT][eft_C128][eft_C128] = {
eft_C128,
(void *)dpnp_kron_ext_c<std::complex<double>, std::complex<double>,
std::complex<double>>};

fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_matrix_rank_default_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_LNG][eft_LNG] = {
Expand Down
1 change: 0 additions & 1 deletion dpnp/dpnp_algo/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

set(dpnp_algo_pyx_deps
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_linearalgebra.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_statistics.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_trigonometric.pxi
${CMAKE_CURRENT_SOURCE_DIR}/dpnp_algo_sorting.pxi
Expand Down
2 changes: 0 additions & 2 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_FMOD_EXT
DPNP_FN_FULL
DPNP_FN_FULL_LIKE
DPNP_FN_KRON
DPNP_FN_KRON_EXT
DPNP_FN_MAXIMUM
DPNP_FN_MAXIMUM_EXT
DPNP_FN_MEDIAN
Expand Down
1 change: 0 additions & 1 deletion dpnp/dpnp_algo/dpnp_algo.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ __all__ = [

include "dpnp_algo_arraycreation.pxi"
include "dpnp_algo_indexing.pxi"
include "dpnp_algo_linearalgebra.pxi"
include "dpnp_algo_logic.pxi"
include "dpnp_algo_mathematical.pxi"
include "dpnp_algo_sorting.pxi"
Expand Down
106 changes: 0 additions & 106 deletions dpnp/dpnp_algo/dpnp_algo_linearalgebra.pxi

This file was deleted.

68 changes: 58 additions & 10 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,12 @@
import dpnp

# pylint: disable=no-name-in-module
from .dpnp_algo import (
dpnp_kron,
)
from .dpnp_utils import (
call_origin,
)
from .dpnp_utils.dpnp_utils_linearalgebra import (
dpnp_dot,
dpnp_kron,
dpnp_matmul,
)

Expand Down Expand Up @@ -305,22 +303,72 @@ def inner(a, b):
return dpnp.tensordot(a, b, axes=(-1, -1))


def kron(x1, x2):
def kron(a, b):
"""
Returns the kronecker product of two arrays.

For full documentation refer to :obj:`numpy.kron`.

.. seealso:: :obj:`dpnp.outer` returns the outer product of two arrays.
Parameters
----------
a : {dpnp.ndarray, usm_ndarray, scalar}
First input array. Both inputs `a` and `b` can not be scalars
at the same time.
b : {dpnp.ndarray, usm_ndarray, scalar}
Second input array. Both inputs `a` and `b` can not be scalars
at the same time.

Returns
-------
out : dpnp.ndarray
Returns the Kronecker product.

See Also
--------
:obj:`dpnp.outer` : Returns the outer product of two arrays.

Examples
--------
>>> import dpnp as np
>>> a = np.array([1, 10, 100])
>>> b = np.array([5, 6, 7])
>>> np.kron(a, b)
array([ 5, 6, 7, ..., 500, 600, 700])
>>> np.kron(b, a)
array([ 5, 50, 500, ..., 7, 70, 700])

>>> np.kron(np.eye(2), np.ones((2,2)))
array([[1., 1., 0., 0.],
[1., 1., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 1., 1.]])

>>> a = np.arange(100).reshape((2,5,2,5))
>>> b = np.arange(24).reshape((2,3,4))
>>> c = np.kron(a,b)
>>> c.shape
(2, 10, 6, 20)
>>> I = (1,3,0,2)
>>> J = (0,2,1)
>>> J1 = (0,) + J # extend to ndim=4
>>> S1 = (1,) + b.shape
>>> K = tuple(np.array(I) * np.array(S1) + np.array(J1))
>>> c[K] == a[I]*b[J]
array(True)

"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
if x1_desc and x2_desc:
return dpnp_kron(x1_desc, x2_desc).get_pyobj()
dpnp.check_supported_arrays_type(a, b, scalar_type=True)

if dpnp.isscalar(a) or dpnp.isscalar(b):
return dpnp.multiply(a, b)

a_ndim = a.ndim
b_ndim = b.ndim
if a_ndim == 0 or b_ndim == 0:
return dpnp.multiply(a, b)

return call_origin(numpy.kron, x1, x2)
return dpnp_kron(a, b, a_ndim, b_ndim)


def matmul(
Expand Down
30 changes: 29 additions & 1 deletion dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from dpnp.dpnp_array import dpnp_array
from dpnp.dpnp_utils import get_usm_allocations

__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_matmul"]
__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_kron", "dpnp_matmul"]


def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
Expand Down Expand Up @@ -476,6 +476,34 @@ def dpnp_cross(a, b, cp, exec_q):
return cp

vtavana marked this conversation as resolved.
Show resolved Hide resolved

def dpnp_kron(a, b, a_ndim, b_ndim):
"""Returns the kronecker product of two arrays."""

a_shape = a.shape
b_shape = b.shape
if not a.flags.contiguous:
a = dpnp.reshape(a, a_shape)
if not b.flags.contiguous:
b = dpnp.reshape(b, b_shape)

# Equalise the shapes by prepending smaller one with 1s
a_shape = (1,) * max(0, b_ndim - a_ndim) + a_shape
b_shape = (1,) * max(0, a_ndim - b_ndim) + b_shape

# Insert empty dimensions
a_arr = dpnp.expand_dims(a, axis=tuple(range(b_ndim - a_ndim)))
b_arr = dpnp.expand_dims(b, axis=tuple(range(a_ndim - b_ndim)))

# Compute the product
ndim = max(b_ndim, a_ndim)
a_arr = dpnp.expand_dims(a_arr, axis=tuple(range(1, 2 * ndim, 2)))
b_arr = dpnp.expand_dims(b_arr, axis=tuple(range(0, 2 * ndim, 2)))
result = dpnp.multiply(a_arr, b_arr)

# Reshape back
return result.reshape(tuple(numpy.multiply(a_shape, b_shape)))


def dpnp_dot(a, b, /, out=None, *, conjugate=False):
"""
Return the dot product of two arrays.
Expand Down
2 changes: 1 addition & 1 deletion dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def _multi_dot(arrays, order, i, j, out=None):

def _multi_dot_matrix_chain_order(n, arrays, return_costs=False):
"""
Return a dpnp.ndarray that encodes the optimal order of mutiplications.
Return a dpnp.ndarray that encodes the optimal order of multiplications.

The optimal order array is then used by `_multi_dot()` to do the
multiplication.
Expand Down
Loading
Loading