Skip to content

Commit

Permalink
implement dpnp.prod and dpnp.nanprod (#1613)
Browse files Browse the repository at this point in the history
* implement dpnp.prod and dpnp.nanprod

* address comments

* updates for nanprod input array

* allow fall back on numpy - needed for Win tests
  • Loading branch information
vtavana authored Nov 7, 2023
1 parent de562ee commit db127d4
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 263 deletions.
4 changes: 1 addition & 3 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,7 @@ enum class DPNPFuncName : size_t
DPNP_FN_PLACE, /**< Used in numpy.place() impl */
DPNP_FN_POWER, /**< Used in numpy.power() impl */
DPNP_FN_PROD, /**< Used in numpy.prod() impl */
DPNP_FN_PROD_EXT, /**< Used in numpy.prod() impl, requires extra parameters
*/
DPNP_FN_PTP, /**< Used in numpy.ptp() impl */
DPNP_FN_PTP, /**< Used in numpy.ptp() impl */
DPNP_FN_PTP_EXT, /**< Used in numpy.ptp() impl, requires extra parameters */
DPNP_FN_PUT, /**< Used in numpy.put() impl */
DPNP_FN_PUT_ALONG_AXIS, /**< Used in numpy.put_along_axis() impl */
Expand Down
49 changes: 0 additions & 49 deletions dpnp/backend/kernels/dpnp_krnl_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,19 +294,6 @@ void (*dpnp_prod_default_c)(void *,
const long *) =
dpnp_prod_c<_DataType_output, _DataType_input>;

template <typename _DataType_output, typename _DataType_input>
DPCTLSyclEventRef (*dpnp_prod_ext_c)(DPCTLSyclQueueRef,
void *,
const void *,
const shape_elem_type *,
const size_t,
const shape_elem_type *,
const size_t,
const void *,
const long *,
const DPCTLEventVectorRef) =
dpnp_prod_c<_DataType_output, _DataType_input>;

void func_map_init_reduction(func_map_t &fmap)
{
// WARNING. The meaning of the fmap is changed. Second argument represents
Expand Down Expand Up @@ -349,42 +336,6 @@ void func_map_init_reduction(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_PROD][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_prod_default_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_INT][eft_INT] = {
eft_LNG, (void *)dpnp_prod_ext_c<int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_INT][eft_LNG] = {
eft_LNG, (void *)dpnp_prod_ext_c<int64_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_INT][eft_FLT] = {
eft_FLT, (void *)dpnp_prod_ext_c<float, int32_t>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_INT][eft_DBL] = {
eft_DBL, (void *)dpnp_prod_ext_c<double, int32_t>};

fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_LNG][eft_INT] = {
eft_INT, (void *)dpnp_prod_ext_c<int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_prod_ext_c<int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_LNG][eft_FLT] = {
eft_FLT, (void *)dpnp_prod_ext_c<float, int64_t>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_LNG][eft_DBL] = {
eft_DBL, (void *)dpnp_prod_ext_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_FLT][eft_INT] = {
eft_INT, (void *)dpnp_prod_ext_c<int32_t, float>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_FLT][eft_LNG] = {
eft_LNG, (void *)dpnp_prod_ext_c<int64_t, float>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_prod_ext_c<float, float>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_FLT][eft_DBL] = {
eft_DBL, (void *)dpnp_prod_ext_c<double, float>};

fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_DBL][eft_INT] = {
eft_INT, (void *)dpnp_prod_ext_c<int32_t, double>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_DBL][eft_LNG] = {
eft_LNG, (void *)dpnp_prod_ext_c<int64_t, double>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_DBL][eft_FLT] = {
eft_FLT, (void *)dpnp_prod_ext_c<float, double>};
fmap[DPNPFuncName::DPNP_FN_PROD_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_prod_ext_c<double, double>};

fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_INT] = {
eft_LNG, (void *)dpnp_sum_default_c<int32_t, int32_t>};
fmap[DPNPFuncName::DPNP_FN_SUM][eft_INT][eft_LNG] = {
Expand Down
2 changes: 0 additions & 2 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_PARTITION
DPNP_FN_PARTITION_EXT
DPNP_FN_PLACE
DPNP_FN_PROD
DPNP_FN_PROD_EXT
DPNP_FN_PTP
DPNP_FN_PTP_EXT
DPNP_FN_QR
Expand Down
84 changes: 0 additions & 84 deletions dpnp/dpnp_algo/dpnp_algo_mathematical.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ __all__ += [
"dpnp_modf",
"dpnp_nancumprod",
"dpnp_nancumsum",
"dpnp_nanprod",
"dpnp_nansum",
"dpnp_prod",
"dpnp_sum",
"dpnp_trapz",
]
Expand Down Expand Up @@ -319,26 +317,6 @@ cpdef utils.dpnp_descriptor dpnp_nancumsum(utils.dpnp_descriptor x1):
return dpnp_cumsum(x1_desc)


cpdef utils.dpnp_descriptor dpnp_nanprod(utils.dpnp_descriptor x1):
x1_obj = x1.get_array()
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(x1.shape,
x1.dtype,
None,
device=x1_obj.sycl_device,
usm_type=x1_obj.usm_type,
sycl_queue=x1_obj.sycl_queue)

for i in range(result.size):
input_elem = x1.get_pyobj().flat[i]

if dpnp.isnan(input_elem):
result.get_pyobj().flat[i] = 1
else:
result.get_pyobj().flat[i] = input_elem

return dpnp_prod(result)


cpdef utils.dpnp_descriptor dpnp_nansum(utils.dpnp_descriptor x1):
x1_obj = x1.get_array()
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(x1.shape,
Expand All @@ -359,68 +337,6 @@ cpdef utils.dpnp_descriptor dpnp_nansum(utils.dpnp_descriptor x1):
return dpnp_sum(result)


cpdef utils.dpnp_descriptor dpnp_prod(utils.dpnp_descriptor x1,
object axis=None,
object dtype=None,
utils.dpnp_descriptor out=None,
cpp_bool keepdims=False,
object initial=None,
object where=True):
"""
input:float64 : output:float64 : name:prod
input:float32 : output:float32 : name:prod
input:int64 : output:int64 : name:prod
input:int32 : output:int64 : name:prod
input:bool : output:int64 : name:prod
input:complex64 : output:complex64 : name:prod
input:complex128: output:complex128: name:prod
"""

cdef shape_type_c x1_shape = x1.shape
cdef DPNPFuncType x1_c_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)

cdef shape_type_c axis_shape = utils._object_to_tuple(axis)

cdef shape_type_c result_shape = utils.get_reduction_output_shape(x1_shape, axis, keepdims)
cdef DPNPFuncType result_c_type = utils.get_output_c_type(DPNP_FN_PROD_EXT, x1_c_type, out, dtype)

""" select kernel """
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PROD_EXT, x1_c_type, result_c_type)

x1_obj = x1.get_array()

""" Create result array """
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape,
result_c_type,
out,
device=x1_obj.sycl_device,
usm_type=x1_obj.usm_type,
sycl_queue=x1_obj.sycl_queue)
cdef dpnp_reduction_c_t func = <dpnp_reduction_c_t > kernel_data.ptr

result_sycl_queue = result.get_array().sycl_queue

cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

""" Call FPTR interface function """
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
result.get_data(),
x1.get_data(),
x1_shape.data(),
x1_shape.size(),
axis_shape.data(),
axis_shape.size(),
NULL,
NULL,
NULL) # dep_events_ref

with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
c_dpctl.DPCTLEvent_Delete(event_ref)

return result


cpdef utils.dpnp_descriptor dpnp_sum(utils.dpnp_descriptor x1,
object axis=None,
object dtype=None,
Expand Down
4 changes: 1 addition & 3 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,9 +1026,7 @@ def prod(
"""
Returns the prod along a given axis.
.. seealso::
:obj:`dpnp.prod` for full documentation,
:meth:`dpnp.dparray.sum`
For full documentation refer to :obj:`dpnp.prod`.
"""

Expand Down
Loading

0 comments on commit db127d4

Please sign in to comment.