Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] FFI: random.choice, take and clip (#17854)
Browse files Browse the repository at this point in the history
* change the header file of np.random.choice

* add np_choice_op.cc file

* add including header file

* implement the basic function of random.choice

* try to use take op in backend

* try to use take op in backend

* add take invoking function

* fix some syntax problems

* fix some problems

* complete numpy.random.choice ffi

* first commit of ffi indexing_op.cc

* add random.choice ffi benchmark

* complete take ffi

* change the implementation of random.choice

* add take op benchmark

* complete clip op ffi and fix a problem

* add clip op benchmark

* fix some sanity problems

* add space before ( and fix reimport

* fix a typo

* remove dead code and remove new operator

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
AntiZpvoh and Ubuntu authored Apr 13, 2020
1 parent 7a59239 commit e3d7866
Show file tree
Hide file tree
Showing 10 changed files with 292 additions and 23 deletions.
3 changes: 3 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def prepare_workloads():
OpArgMngr.add_workload("diff", pool['2x2'], n=1, axis=-1)
OpArgMngr.add_workload("nonzero", pool['2x2'])
OpArgMngr.add_workload("tril", pool['2x2'], k=0)
OpArgMngr.add_workload("random.choice", pool['2'], size=(2, 2))
OpArgMngr.add_workload("take", pool['2'], dnp.array([1,0], dtype='int64'))
OpArgMngr.add_workload("clip", pool['2x2'], 0, 1)
OpArgMngr.add_workload("expand_dims", pool['2x2'], axis=0)
OpArgMngr.add_workload("broadcast_to", pool['2x2'], (2, 2, 2))
OpArgMngr.add_workload("full_like", pool['2x2'], 2)
Expand Down
10 changes: 3 additions & 7 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,9 +690,9 @@ def take(a, indices, axis=None, mode='raise', out=None):
raise NotImplementedError(
"function take does not support mode '{}'".format(mode))
if axis is None:
return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out)
return _api_internal.take(_npi.reshape(a, -1), indices, 0, mode, out)
else:
return _npi.take(a, indices, axis, mode, out)
return _api_internal.take(a, indices, axis, mode, out)
# pylint: enable=redefined-outer-name


Expand Down Expand Up @@ -4551,11 +4551,7 @@ def clip(a, a_min, a_max, out=None):
"""
if a_min is None and a_max is None:
raise ValueError('array_clip: must set either max or min')
if a_min is None:
a_min = float('-inf')
if a_max is None:
a_max = float('inf')
return _npi.clip(a, a_min, a_max, out=out)
return _api_internal.clip(a, a_min, a_max, out)


@set_module('mxnet.ndarray.numpy')
Expand Down
22 changes: 7 additions & 15 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,24 +535,16 @@ def choice(a, size=None, replace=True, p=None, ctx=None, out=None):
"""
from ...numpy import ndarray as np_ndarray
if ctx is None:
ctx = current_context()
ctx = str(current_context())
else:
ctx = str(ctx)
if size == ():
size = None
if isinstance(a, np_ndarray):
ctx = None
if p is None:
indices = _npi.choice(a, a=None, size=size,
replace=replace, ctx=ctx, weighted=False)
return _npi.take(a, indices)
else:
indices = _npi.choice(a, p, a=None, size=size,
replace=replace, ctx=ctx, weighted=True)
return _npi.take(a, indices)
indices = _api_internal.choice(a, size, replace, p, ctx, out)
return _api_internal.take(a, indices, 0, 'raise', out)
else:
if p is None:
return _npi.choice(a=a, size=size, replace=replace, ctx=ctx, weighted=False, out=out)
else:
return _npi.choice(p, a=a, size=size, replace=replace, ctx=ctx, weighted=True, out=out)
return _api_internal.choice(a, size, replace, p, ctx, out)


def exponential(scale=1.0, size=None, ctx=None, out=None):
Expand Down Expand Up @@ -834,7 +826,7 @@ def beta(a, b, size=None, dtype=None, ctx=None):
# use fp64 to prevent precision loss
X = gamma(a, 1, size=size, dtype='float64', ctx=ctx)
Y = gamma(b, 1, size=size, dtype='float64', ctx=ctx)
out = X/(X + Y)
out = X / (X + Y)
return out.astype(dtype)


Expand Down
89 changes: 89 additions & 0 deletions src/api/operator/numpy/random/np_choice_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_choice_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_choice_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"
#include "../../../../operator/numpy/random/np_choice_op.h"

namespace mxnet {

// Imperative FFI handler for mxnet.numpy.random.choice.
// Argument layout: [0] a (int population size, or NDArray of candidates),
// [1] size, [2] replace, [3] p (weights NDArray or None), [4] ctx string,
// [5] out (optional NDArray).
MXNET_REGISTER_API("_npi.choice")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
  using namespace runtime;
  const nnvm::Op* op = Op::Get("_npi_choice");
  nnvm::NodeAttrs attrs;
  op::NumpyChoiceParam param;

  // At most two tensor inputs: the candidate array (args[0]) and the
  // weight array (args[3]); scalar arguments travel in the op parameter.
  NDArray* inputs[2];
  int num_inputs = 0;

  if (args[0].type_code() == kDLInt) {
    param.a = args[0].operator int();
  } else if (args[0].type_code() == kNDArrayHandle) {
    param.a = dmlc::nullopt;
    inputs[num_inputs] = args[0].operator mxnet::NDArray*();
    num_inputs++;
  }

  if (args[1].type_code() == kNull) {
    param.size = dmlc::nullopt;
  } else {
    if (args[1].type_code() == kDLInt) {
      // A bare int means a 1-D output of that length.
      param.size = mxnet::Tuple<int64_t>(1, args[1].operator int64_t());
    } else {
      param.size = mxnet::Tuple<int64_t>(args[1].operator ObjectRef());
    }
  }

  if (args[2].type_code() == kNull) {
    param.replace = true;
  } else {
    param.replace = args[2].operator bool();
  }

  if (args[3].type_code() == kNull) {
    param.weighted = false;
  } else if (args[3].type_code() == kNDArrayHandle) {
    // Bug fix: test args[3] (the weights) before reading it as an NDArray;
    // the old code re-tested args[0] here, so a non-NDArray weights argument
    // would have been converted unchecked.
    param.weighted = true;
    inputs[num_inputs] = args[3].operator mxnet::NDArray*();
    num_inputs++;
  }

  attrs.parsed = std::move(param);
  attrs.op = op;
  // Populate the string attribute dict, matching the _npi.take/_npi.clip
  // handlers and the SetAttrDict method declared on NumpyChoiceParam.
  SetAttrDict<op::NumpyChoiceParam>(&attrs);
  if (args[4].type_code() != kNull) {
    attrs.dict["ctx"] = args[4].operator std::string();
  }
  NDArray* out = args[5].operator mxnet::NDArray*();
  NDArray** outputs = out == nullptr ? nullptr : &out;
  // Number of outputs provided by the `out` argument (0 or 1).
  int num_outputs = out != nullptr;
  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
  if (out) {
    // `out` was supplied: hand the Python-side out argument straight back.
    *ret = PythonArg(5);
  } else {
    *ret = ndoutputs[0];
  }
});

} // namespace mxnet
2 changes: 1 addition & 1 deletion src/api/operator/numpy/random/np_laplace_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file np_laplace_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_laplace_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/random/np_laplace_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
Expand Down
78 changes: 78 additions & 0 deletions src/api/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file indexing_op.cc
* \brief Implementation of the API of functions in src/operator/tensor/indexing_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/tensor/indexing_op.h"

namespace mxnet {

// Imperative FFI handler for mxnet.numpy.take.
// Argument layout: [0] a (data NDArray), [1] indices (NDArray),
// [2] axis (int), [3] mode ("raise"/"clip"/"wrap"), [4] out (optional).
MXNET_REGISTER_API("_npi.take")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
  using namespace runtime;
  const nnvm::Op* op = Op::Get("_npi_take");
  nnvm::NodeAttrs attrs;
  op::TakeParam param;
  // Exactly two tensor inputs: data and indices. Brace-init so the pointers
  // are never read uninitialized should an argument arrive as null.
  NDArray* inputs[2] = {nullptr, nullptr};

  if (args[0].type_code() != kNull) {
    inputs[0] = args[0].operator mxnet::NDArray *();
  }

  if (args[1].type_code() != kNull) {
    inputs[1] = args[1].operator mxnet::NDArray *();
  }

  if (args[2].type_code() == kDLInt) {
    param.axis = args[2].operator int();
  }

  if (args[3].type_code() != kNull) {
    // Translate the user-facing mode keyword to the take_:: enum; an
    // unrecognized string leaves the parameter default untouched (the
    // frontend has already validated the mode).
    std::string mode = args[3].operator std::string();
    if (mode == "raise") {
      param.mode = op::take_::kRaise;
    } else if (mode == "clip") {
      param.mode = op::take_::kClip;
    } else if (mode == "wrap") {
      param.mode = op::take_::kWrap;
    }
  }

  attrs.parsed = param;
  attrs.op = op;
  SetAttrDict<op::TakeParam>(&attrs);

  NDArray* out = args[4].operator mxnet::NDArray*();
  NDArray** outputs = out == nullptr ? nullptr : &out;
  // set the number of outputs provided by the `out` argument
  int num_outputs = out != nullptr;
  auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, outputs);
  if (out) {
    // `out` was supplied: hand the Python-side out argument straight back.
    *ret = PythonArg(4);
  } else {
    *ret = ndoutputs[0];
  }
});

} // namespace mxnet
71 changes: 71 additions & 0 deletions src/api/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file matrix_op.cc
* \brief Implementation of the API of functions in src/operator/tensor/matrix_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/tensor/matrix_op-inl.h"

namespace mxnet {

// Imperative FFI handler for mxnet.numpy.clip.
// Argument layout: [0] a (data NDArray), [1] a_min, [2] a_max,
// [3] out (optional NDArray). A null bound means unbounded on that side.
MXNET_REGISTER_API("_npi.clip")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
  using namespace runtime;
  const nnvm::Op* op = Op::Get("_npi_clip");
  nnvm::NodeAttrs attrs;
  op::ClipParam param;
  // Single tensor input; brace-init so the pointer is never read
  // uninitialized should the argument arrive as null.
  NDArray* inputs[1] = {nullptr};

  if (args[0].type_code() != kNull) {
    inputs[0] = args[0].operator mxnet::NDArray *();
  }

  // Missing bounds map to +/-infinity, mirroring numpy.clip semantics
  // where only one of a_min/a_max must be provided.
  if (args[1].type_code() != kNull) {
    param.a_min = args[1].operator double();
  } else {
    param.a_min = -INFINITY;
  }

  if (args[2].type_code() != kNull) {
    param.a_max = args[2].operator double();
  } else {
    param.a_max = INFINITY;
  }

  attrs.parsed = param;
  attrs.op = op;
  SetAttrDict<op::ClipParam>(&attrs);

  NDArray* out = args[3].operator mxnet::NDArray*();
  NDArray** outputs = out == nullptr ? nullptr : &out;
  // set the number of outputs provided by the `out` argument
  int num_outputs = out != nullptr;
  auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs);
  if (out) {
    // `out` was supplied: hand the Python-side out argument straight back.
    *ret = PythonArg(3);
  } else {
    *ret = ndoutputs[0];
  }
});

} // namespace mxnet
11 changes: 11 additions & 0 deletions src/operator/numpy/random/np_choice_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ struct NumpyChoiceParam : public dmlc::Parameter<NumpyChoiceParam> {
DMLC_DECLARE_FIELD(replace).set_default(true);
DMLC_DECLARE_FIELD(weighted).set_default(false);
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream a_s, size_s, replace_s, weighted_s;
a_s << a;
size_s << size;
replace_s << replace;
weighted_s << weighted;
(*dict)["a"] = a_s.str();
(*dict)["size"] = size_s.str();
(*dict)["replace"] = replace_s.str();
(*dict)["weighted"] = weighted_s.str();
}
};

inline bool NumpyChoiceOpType(const nnvm::NodeAttrs &attrs,
Expand Down
21 changes: 21 additions & 0 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,27 @@ struct TakeParam: public dmlc::Parameter<TakeParam> {
" \"wrap\" means to wrap around."
" \"raise\" means to raise an error when index out of range.");
}

void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream axis_s, mode_s;
axis_s << axis;
mode_s << mode;
(*dict)["axis"] = axis_s.str();
(*dict)["mode"] = mode_s.str();
switch (mode) {
case take_::kRaise:
(*dict)["mode"] = "raise";
break;
case take_::kClip:
(*dict)["mode"] = "clip";
break;
case take_::kWrap:
(*dict)["mode"] = "wrap";
break;
default:
(*dict)["mode"] = mode_s.str();
}
}
};

inline bool TakeOpShape(const nnvm::NodeAttrs& attrs,
Expand Down
8 changes: 8 additions & 0 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,14 @@ struct ClipParam : public dmlc::Parameter<ClipParam> {
DMLC_DECLARE_FIELD(a_max)
.describe("Maximum value");
}

void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream a_min_s, a_max_s;
a_min_s << a_min;
a_max_s << a_max;
(*dict)["a_min"] = a_min_s.str();
(*dict)["a_max"] = a_max_s.str();
}
};


Expand Down

0 comments on commit e3d7866

Please sign in to comment.