[Numpy] FFI: sort, argsort, vstack etc (apache#17857)
* sort FFI

* argsort FFI

* vstack, row_stack FFI

* greater FFI

* inner FFI

* multinomial FFI

* rand FFI

* randn FFI

* Fix input out-of-index and the rscalar case of greater

* Fix ndarray handling

* Fix sanity-check failures

* Fix lint

* Fix bugs

* Remove duplicate operator (greater)

* Fix Tuple downcast error (Integer only)
* Fix segmentation fault (pointer)

Co-authored-by: Sheng Zha <[email protected]>
hanke580 and szha authored Aug 10, 2020
1 parent 5c50475 commit d0e17e5
Showing 9 changed files with 232 additions and 17 deletions.
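
For orientation: every operator touched by this commit keeps its public mxnet.numpy signature; only the Python-to-C++ dispatch changes, from the legacy ctypes-based _npi path to the PackedFunc-based _api_internal FFI. A quick sketch of the front-end calls involved (assuming a default CPU context):

import mxnet.numpy as np

a = np.array([[3., 1.], [2., 4.]])
np.sort(a, axis=-1)        # now dispatched through _api_internal.sort
np.argsort(a, axis=-1)     # _api_internal.argsort
np.vstack([a, a])          # _api_internal.vstack
np.inner(a, a)             # inner FFI
np.random.multinomial(2, [1 / 6.] * 6, size=(2, 2))
np.random.rand(3, 2)
np.random.randn(2, 2)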
11 changes: 7 additions & 4 deletions benchmark/python/ffi/benchmark_ffi.py
@@ -62,7 +62,6 @@ def prepare_workloads():
OpArgMngr.add_workload("nan_to_num", pool['2x2'])
OpArgMngr.add_workload("tri", 2, 3, 4)
OpArgMngr.add_workload("tensordot", pool['2x2'], pool['2x2'], ((1, 0), (0, 1)))
-OpArgMngr.add_workload("kron", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
OpArgMngr.add_workload("random.shuffle", pool['3'])
OpArgMngr.add_workload("equal", pool['2x2'], pool['2x2'])
@@ -100,11 +99,14 @@ def prepare_workloads():
OpArgMngr.add_workload("trace", pool['2x2'])
OpArgMngr.add_workload("transpose", pool['2x2'])
OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1)
+OpArgMngr.add_workload("vstack", (pool['3x3'], pool['3x3'], pool['3x3']))
OpArgMngr.add_workload("argmax", pool['3x2'], axis=-1)
OpArgMngr.add_workload("argmin", pool['3x2'], axis=-1)
OpArgMngr.add_workload("atleast_1d", pool['2'], pool['2x2'])
OpArgMngr.add_workload("atleast_2d", pool['2'], pool['2x2'])
OpArgMngr.add_workload("atleast_3d", pool['2'], pool['2x2'])
+OpArgMngr.add_workload("argsort", pool['3x2'], axis=-1)
+OpArgMngr.add_workload("sort", pool['3x2'], axis=-1)
OpArgMngr.add_workload("indices", dimensions=(1, 2, 3))
OpArgMngr.add_workload("subtract", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("multiply", pool['2x2'], pool['2x2'])
@@ -115,6 +117,10 @@ def prepare_workloads():
OpArgMngr.add_workload("power", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("lcm", pool['2x2'].astype('int32'), pool['2x2'].astype('int32'))
OpArgMngr.add_workload("diff", pool['2x2'], n=1, axis=-1)
+OpArgMngr.add_workload("inner", pool['2x2'], pool['2x2'])
+OpArgMngr.add_workload("random.multinomial", n=2, pvals=[1/6.]*6, size=(2,2))
+OpArgMngr.add_workload("random.rand", 3, 2)
+OpArgMngr.add_workload("random.randn", 2, 2)
OpArgMngr.add_workload("nonzero", pool['2x2'])
OpArgMngr.add_workload("tril", pool['2x2'], k=0)
OpArgMngr.add_workload("random.choice", pool['2'], size=(2, 2))
@@ -144,9 +150,6 @@ def prepare_workloads():
OpArgMngr.add_workload("random.logistic", loc=2, scale=2, size=(2,2))
OpArgMngr.add_workload("random.gumbel", loc=2, scale=2, size=(2,2))
OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1'])
-OpArgMngr.add_workload("fmax", pool['2x2'], pool['2x2'])
-OpArgMngr.add_workload("fmin", pool['2x2'], pool['2x2'])
-OpArgMngr.add_workload("fmod", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1])
OpArgMngr.add_workload('squeeze', pool['2x2'], axis=None)
OpArgMngr.add_workload("pad", pool['2x2'], pad_width=((1,2),(1,2)), mode="constant")
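
Each add_workload call above registers an operator name plus sample arguments; the harness later resolves the (possibly dotted) name against the mxnet.numpy namespace and times the call, so every entry exercises one FFI round-trip. A minimal sketch of that pattern — WORKLOADS and run_all are illustrative names, not the benchmark's actual internals:

import mxnet.numpy as np

WORKLOADS = []  # (op_name, args, kwargs) records, as add_workload collects

def add_workload(name, *args, **kwargs):
    WORKLOADS.append((name, args, kwargs))

def run_all():
    for name, args, kwargs in WORKLOADS:
        func = np
        for part in name.split('.'):  # handles e.g. "random.multinomial"
            func = getattr(func, part)
        func(*args, **kwargs)         # one timed FFI dispatch per workload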
8 changes: 4 additions & 4 deletions python/mxnet/ndarray/numpy/_op.py
@@ -1620,7 +1620,7 @@ def argsort(a, axis=-1, kind=None, order=None):
if order is not None:
raise NotImplementedError("order not supported here")

-return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64')
+return _api_internal.argsort(a, axis, True, 'int64')


@set_module('mxnet.ndarray.numpy')
@@ -1664,7 +1664,7 @@ def sort(a, axis=-1, kind=None, order=None):
"""
if order is not None:
raise NotImplementedError("order not supported here")
-return _npi.sort(data=a, axis=axis, is_ascend=True)
+return _api_internal.sort(a, axis, True)


@set_module('mxnet.ndarray.numpy')
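
Note the calling-convention change: the legacy _npi entry points took keyword arguments that were stringified and re-parsed, while the _api_internal entry points are positional and parsed directly in C++ (argument order must match the handler's args[0], args[1], ... indexing shown later in this diff). User-facing behavior is unchanged:

import mxnet.numpy as np

a = np.array([[3., 1., 2.]])
print(np.sort(a, axis=-1))     # [[1. 2. 3.]]
print(np.argsort(a, axis=-1))  # [[1 2 0]], int64 indices per the default dtype
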
@@ -4581,7 +4581,7 @@ def get_list(arrays):
return [arr for arr in arrays]

arrays = get_list(arrays)
-return _npi.vstack(*arrays)
+return _api_internal.vstack(*arrays)


@set_module('mxnet.ndarray.numpy')
@@ -4626,7 +4626,7 @@ def get_list(arrays):
return [arr for arr in arrays]

arrays = get_list(arrays)
-return _npi.vstack(*arrays)
+return _api_internal.vstack(*arrays)


@set_module('mxnet.ndarray.numpy')
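
The two changed functions here are vstack and its alias row_stack — both now call the same _api_internal.vstack, registered against the _npi_vstack operator below. For example:

import mxnet.numpy as np

a = np.array([1., 2., 3.])
b = np.array([4., 5., 6.])
print(np.vstack([a, b]).shape)     # (2, 3); 1-D inputs are promoted to rows
print(np.row_stack([a, b]).shape)  # (2, 3), identical result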
14 changes: 5 additions & 9 deletions python/mxnet/ndarray/numpy/random.py
@@ -21,7 +21,6 @@
from ...context import current_context
from . import _internal as _npi
from . import _api_internal
-from ..ndarray import NDArray


__all__ = ['randint', 'uniform', 'normal', "choice", "rand", "multinomial", "multivariate_normal",
@@ -331,14 +330,11 @@ def multinomial(n, pvals, size=None):
>>> np.random.multinomial(100, [1.0 / 3, 2.0 / 3])
array([32, 68])
"""
-if isinstance(pvals, NDArray):
-    return _npi.multinomial(pvals, pvals=None, n=n, size=size)
-else:
-    if isinstance(pvals, np.ndarray):
-        raise ValueError('numpy ndarray is not supported!')
-    if any(isinstance(i, list) for i in pvals):
-        raise ValueError('object too deep for desired array')
-    return _npi.multinomial(n=n, pvals=pvals, size=size)
+if isinstance(pvals, np.ndarray):
+    raise ValueError('numpy ndarray is not supported!')
+if any(isinstance(i, list) for i in pvals):
+    raise ValueError('object too deep for desired array')
+return _api_internal.multinomial(n, pvals, size)


def rayleigh(scale=1.0, size=None, ctx=None, out=None):
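
The rewrite drops the explicit NDArray branch: an NDArray pvals now falls through to _api_internal.multinomial, whose C++ handler (later in this diff) detects the kNDArrayHandle argument and treats it as a tensor input. The two explicit rejections are kept:

import numpy as onp          # classic NumPy, for the rejected case
import mxnet.numpy as np

np.random.multinomial(100, [1.0 / 3, 2.0 / 3])     # list pvals: OK

try:
    np.random.multinomial(100, onp.array([0.5, 0.5]))
except ValueError as e:
    print(e)  # numpy ndarray is not supported!

try:
    np.random.multinomial(100, [[0.5], [0.5]])     # nested list
except ValueError as e:
    print(e)  # object too deep for desired array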
20 changes: 20 additions & 0 deletions src/api/operator/numpy/np_matrix_op.cc
@@ -615,4 +615,24 @@ MXNET_REGISTER_API("_npi.tril_indices")
*ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
});

MXNET_REGISTER_API("_npi.vstack")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_vstack");
nnvm::NodeAttrs attrs;
op::NumpyVstackParam param;
param.num_args = args.size();

attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyVstackParam>(&attrs);
int num_outputs = 0;
std::vector<NDArray*> inputs_vec(args.size(), nullptr);
for (int i = 0; i < args.size(); ++i) {
inputs_vec[i] = args[i].operator mxnet::NDArray*();
}
NDArray** inputs = inputs_vec.data();
auto ndoutputs = Invoke(op, &attrs, param.num_args, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});
} // namespace mxnet
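
Because the Python side expands the sequence (_api_internal.vstack(*arrays)), args.size() equals the number of arrays and serves both as param.num_args and as the input count. From the user's perspective:

import mxnet.numpy as np

parts = [np.zeros((3, 3)) for _ in range(3)]
print(np.vstack(parts).shape)  # (9, 3); the handler saw args.size() == 3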
88 changes: 88 additions & 0 deletions src/api/operator/numpy/np_ordering_op.cc
@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_ordering_op.cc
* \brief Implementation of the API of functions in src/operator/tensor/ordering_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/tensor/ordering_op-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.sort")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_sort");
nnvm::NodeAttrs attrs;
op::SortParam param;

if (args[1].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else {
param.axis = args[1].operator int();
}
param.is_ascend = true;

attrs.parsed = std::move(param);
attrs.op = op;

int num_inputs = 1;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};

int num_outputs = 0;
SetAttrDict<op::SortParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

MXNET_REGISTER_API("_npi.argsort")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_argsort");
nnvm::NodeAttrs attrs;
op::ArgSortParam param;

if (args[1].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else {
param.axis = args[1].operator int();
}
param.is_ascend = true;
if (args[3].type_code() == kNull) {
param.dtype = mshadow::kFloat32;
} else {
param.dtype = String2MXNetTypeWithBool(args[3].operator std::string());
}

attrs.parsed = std::move(param);
attrs.op = op;

int num_inputs = 1;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};

int num_outputs = 0;
SetAttrDict<op::ArgSortParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

} // namespace mxnet
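
Both handlers map a null axis argument to dmlc::nullopt — following the usual MXNet convention for optional axes, that sorts the flattened array (an assumption here, not shown in this diff) — while an integer axis is read via operator int():

import mxnet.numpy as np

a = np.array([[3., 1.], [2., 4.]])
print(np.sort(a, axis=0))     # sort each column
print(np.sort(a, axis=None))  # flattened sort: [1. 2. 3. 4.]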
75 changes: 75 additions & 0 deletions src/api/operator/numpy/random/np_multinomial_op.cc
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_multinomial_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/random/np_multinomial_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include <vector>
#include "../../utils.h"
#include "../../../../operator/numpy/random/np_multinomial_op.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.multinomial")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_multinomial");
nnvm::NodeAttrs attrs;
op::NumpyMultinomialParam param;
NDArray** inputs = new NDArray*[1]();
int num_inputs = 0;

// parse int
param.n = args[0].operator int();

// parse pvals
if (args[1].type_code() == kNull) {
param.pvals = dmlc::nullopt;
} else if (args[1].type_code() == kNDArrayHandle) {
param.pvals = dmlc::nullopt;
inputs[0] = args[1].operator mxnet::NDArray*();
num_inputs = 1;
} else {
param.pvals = Obj2Tuple<double, Float>(args[1].operator ObjectRef());
}

// parse size
if (args[2].type_code() == kNull) {
param.size = dmlc::nullopt;
} else {
if (args[2].type_code() == kDLInt) {
param.size = mxnet::Tuple<int>(1, args[2].operator int64_t());
} else {
param.size = mxnet::Tuple<int>(args[2].operator ObjectRef());
}
}

attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::NumpyMultinomialParam>(&attrs);
inputs = num_inputs == 0 ? nullptr : inputs;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

} // namespace mxnet
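
size may arrive as null, a single integer (the kDLInt branch, normalized to a 1-element mxnet::Tuple), or a tuple object; all three reach the same operator. The resulting shapes follow NumPy, with len(pvals) appended:

import mxnet.numpy as np

pvals = [1 / 6.] * 6
print(np.random.multinomial(2, pvals).shape)               # (6,)
print(np.random.multinomial(2, pvals, size=3).shape)       # (3, 6)
print(np.random.multinomial(2, pvals, size=(2, 2)).shape)  # (2, 2, 6)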
5 changes: 5 additions & 0 deletions src/operator/numpy/np_matrix_op-inl.h
@@ -61,6 +61,11 @@ struct NumpyVstackParam : public dmlc::Parameter<NumpyVstackParam> {
DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
.describe("Number of inputs to be vstacked.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream num_args_s;
num_args_s << num_args;
(*dict)["num_args"] = num_args_s.str();
}
};

struct NumpyColumnStackParam : public dmlc::Parameter<NumpyColumnStackParam> {
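
SetAttrDict back-fills the operator's string attribute dictionary, which the positional FFI path no longer produces as a side effect but which is still needed when a call is recorded (for example by autograd or deferred compute). A sketch of what this implementation writes, assuming three stacked inputs:

# NumpyVstackParam::SetAttrDict result for num_args == 3
attrs = {"num_args": "3"}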
10 changes: 10 additions & 0 deletions src/operator/numpy/random/np_multinomial_op.h
@@ -27,6 +27,7 @@

#include <mxnet/operator_util.h>
#include <vector>
#include <string>
#include "../../mshadow_op.h"
#include "../../mxnet_op.h"
#include "../../operator_common.h"
@@ -55,6 +56,15 @@ struct NumpyMultinomialParam : public dmlc::Parameter<NumpyMultinomialParam> {
"e.g., (m, n, k), then m * n * k samples are drawn. "
"Default is None, in which case a single value is returned.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream n_s, pvals_s, size_s;
n_s << n;
pvals_s << pvals;
size_s << size;
(*dict)["n"] = n_s.str();
(*dict)["pvals"] = pvals_s.str();
(*dict)["size"] = size_s.str();
}
};

inline bool NumpyMultinomialOpShape(const nnvm::NodeAttrs& attrs,
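
Here the dmlc::optional fields stream their value, or the string "None" when unset — e.g. when pvals was passed as an NDArray input rather than a tuple. Illustrative output:

# attrs recorded for np.random.multinomial(2, <NDArray pvals>):
attrs = {"n": "2", "pvals": "None", "size": "None"}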
18 changes: 18 additions & 0 deletions src/operator/tensor/ordering_op-inl.h
@@ -30,11 +30,13 @@
#include <mshadow/tensor.h>
#include <algorithm>
#include <vector>
#include <string>
#include <type_traits>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "./sort_op.h"
#include "./indexing_op.h"
#include "../../api/operator/op_utils.h"

namespace mshadow {
template<typename xpu, int src_dim, typename DType, int dst_dim>
@@ -105,6 +107,13 @@ struct SortParam : public dmlc::Parameter<SortParam> {
DMLC_DECLARE_FIELD(is_ascend).set_default(true)
.describe("Whether to sort in ascending or descending order.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream axis_s, is_ascend_s;
axis_s << axis;
is_ascend_s << is_ascend;
(*dict)["axis"] = axis_s.str();
(*dict)["is_ascend"] = is_ascend_s.str();
}
};

struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
@@ -130,6 +139,15 @@ struct ArgSortParam : public dmlc::Parameter<ArgSortParam> {
" \"both\". An error will be raised if the selected data type cannot precisely "
"represent the indices.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream axis_s, is_ascend_s, dtype_s;
axis_s << axis;
is_ascend_s << is_ascend;
dtype_s << dtype;
(*dict)["axis"] = axis_s.str();
(*dict)["is_ascend"] = is_ascend_s.str();
(*dict)["dtype"] = MXNetTypeWithBool2String(dtype);
}
};

inline void ParseTopKParam(const TShape& src_shape,
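
For ArgSortParam, dtype is rendered through MXNetTypeWithBool2String so the recorded attribute round-trips as a type name rather than a raw enum value. Illustrative output for the defaults used by np.argsort:

# attrs for axis=-1, is_ascend=True, dtype=int64
attrs = {"axis": "-1", "is_ascend": "1", "dtype": "int64"}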
