Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
* impl - FFI for np_argmax and np_argmin
Browse files Browse the repository at this point in the history
* impl - FFI for np_indices

* fix - use MXNetTypeWithBool2String
  • Loading branch information
Ubuntu committed Mar 18, 2020
1 parent 2fae7e4 commit 4ed11d1
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 4 deletions.
3 changes: 3 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def prepare_workloads():
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)
OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1'])
OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1])
OpArgMngr.add_workload("argmax", pool['3x2'], axis=-1)
OpArgMngr.add_workload("argmin", pool['3x2'], axis=-1)
OpArgMngr.add_workload("indices", dimensions=(1, 2, 3))


def benchmark_helper(f, *args, **kwargs):
Expand Down
10 changes: 6 additions & 4 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4529,7 +4529,7 @@ def argmax(a, axis=None, out=None):
>>> b
array([2., 2.])
"""
return _npi.argmax(a, axis=axis, keepdims=False, out=out)
return _api_internal.argmax(a, axis, False, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -4597,7 +4597,7 @@ def argmin(a, axis=None, out=None):
>>> b
array([0., 0.])
"""
return _npi.argmin(a, axis=axis, keepdims=False, out=out)
return _api_internal.argmin(a, axis, False, out)


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -4945,8 +4945,10 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
"""
if isinstance(dimensions, (tuple, list)):
if ctx is None:
ctx = current_context()
return _npi.indices(dimensions=dimensions, dtype=dtype, ctx=ctx)
ctx = str(current_context())
else:
ctx = str(ctx)
return _api_internal.indices(dimensions, dtype, ctx)
else:
raise ValueError("The dimensions must be sequence of ints")
# pylint: enable=redefined-outer-name
Expand Down
98 changes: 98 additions & 0 deletions src/api/operator/numpy/np_broadcast_reduce_op_index.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_broadcast_reduce_op_index.cc
* \brief Implementation of the API of functions in
src/operator/numpy/np_broadcast_reduce_op_index.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/tensor/broadcast_reduce_op.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.argmax")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_argmax");
nnvm::NodeAttrs attrs;
op::ReduceAxisParam param;
// param.axis
if (args[1].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else {
param.axis = args[1].operator int();
}
// param.keepdims
param.keepdims = args[2].operator bool();

attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::ReduceAxisParam>(&attrs);
// inputs
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
// outputs
NDArray* out = args[3].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(3);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

MXNET_REGISTER_API("_npi.argmin")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_argmin");
nnvm::NodeAttrs attrs;
op::ReduceAxisParam param;
// param.axis
if (args[1].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else {
param.axis = args[1].operator int();
}
// param.keepdims
param.keepdims = args[2].operator bool();

attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::ReduceAxisParam>(&attrs);
// inputs
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
// outputs
NDArray* out = args[3].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
if (out) {
*ret = PythonArg(3);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

} // namespace mxnet
32 changes: 32 additions & 0 deletions src/api/operator/numpy/np_init_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/tensor/init_op.h"
#include "../../../operator/numpy/np_init_op.h"

namespace mxnet {

Expand Down Expand Up @@ -88,4 +89,35 @@ MXNET_REGISTER_API("_npi.full_like")
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.indices")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_indices");
nnvm::NodeAttrs attrs;
op::IndicesOpParam param;
// param.dimensions
if (args[0].type_code() == kDLInt) {
param.dimensions = TShape(1, args[0].operator int64_t());
} else {
param.dimensions = TShape(args[0].operator ObjectRef());
}
// param.dtype
if (args[1].type_code() == kNull) {
param.dtype = mshadow::kInt32;
} else {
param.dtype = String2MXNetTypeWithBool(args[1].operator std::string());
}
attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::IndicesOpParam>(&attrs);
// param.ctx
if (args[2].type_code() != kNull) {
attrs.dict["ctx"] = args[2].operator std::string();
}
int num_inputs = 0;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, nullptr, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

} // namespace mxnet
8 changes: 8 additions & 0 deletions src/operator/numpy/np_init_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <string>
#include "../tensor/init_op.h"
#include "../tensor/elemwise_unary_op.h"
#include "../../api/operator/op_utils.h"


namespace mxnet {
Expand Down Expand Up @@ -79,6 +80,13 @@ struct IndicesOpParam : public dmlc::Parameter<IndicesOpParam> {
.describe("Context of output, in format [cpu|gpu|cpu_pinned](n)."
"Only used for imperative calls.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream dimensions_s, dtype_s;
dimensions_s << dimensions;
dtype_s << dtype;
(*dict)["dimensions"] = dimensions_s.str();
(*dict)["dtype"] = MXNetTypeWithBool2String(dtype);
}
};

inline bool NumpyRangeShape(const nnvm::NodeAttrs& attrs,
Expand Down
7 changes: 7 additions & 0 deletions src/operator/tensor/broadcast_reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ struct ReduceAxisParam : public dmlc::Parameter<ReduceAxisParam> {
.describe("If this is set to `True`, the reduced axis is left "
"in the result as dimension with size one.");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream axis_s, keepdims_s;
axis_s << axis;
keepdims_s << keepdims;
(*dict)["axis"] = axis_s.str();
(*dict)["keepdims"] = keepdims_s.str();
}
};

enum PickOpMode {kWrap, kClip};
Expand Down

0 comments on commit 4ed11d1

Please sign in to comment.