Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[FFI] Randint (#20083)
Browse files Browse the repository at this point in the history
  • Loading branch information
barry-jin authored Mar 26, 2021
1 parent c7a8ccc commit 03e7cc2
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,19 @@ def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
[3, 2, 2, 0]])
"""
if dtype is None:
dtype = 'int'
dtype = 'int64'
elif not isinstance(dtype, str):
dtype = np.dtype(dtype).name
if ctx is None:
ctx = current_context()
ctx = str(current_context())
else:
ctx = str(ctx)
if size is None:
size = ()
if high is None:
high = low
low = 0
return _npi.random_randint(low, high, shape=size, dtype=dtype, ctx=ctx, out=out)
return _api_internal.randint(low, high, size, dtype, ctx, out)


def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
Expand Down
68 changes: 68 additions & 0 deletions src/api/operator/random/np_randint_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_randint_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/random/np_randint_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include <vector>
#include "../utils.h"
#include "../../../operator/random/sample_op.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.randint")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_random_randint");
nnvm::NodeAttrs attrs;
op::SampleRandIntParam param;
int num_inputs = 0;
param.low = args[0].operator int();
param.high = args[1].operator int();
if (args[2].type_code() == kDLInt) {
param.shape = TShape(1, args[2].operator int64_t());
} else {
param.shape = TShape(args[2].operator ObjectRef());
}
if (args[3].type_code() == kNull) {
param.dtype = mxnet::common::GetDefaultDtype();
} else {
param.dtype = String2MXNetTypeWithBool(args[3].operator std::string());
}
attrs.parsed = param;
attrs.op = op;
if (args[4].type_code() != kNull) {
attrs.dict["ctx"] = args[4].operator std::string();
}
NDArray* out = args[5].operator mxnet::NDArray*();
NDArray** outputs = out == nullptr ? nullptr : &out;
int num_outputs = out != nullptr;
SetAttrDict<op::SampleRandIntParam>(&attrs);
auto ndoutputs = Invoke(op, &attrs, num_inputs, nullptr, &num_outputs, outputs);
if (out) {
*ret = PythonArg(5);
} else {
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}
});

} // namespace mxnet
11 changes: 11 additions & 0 deletions src/operator/random/sample_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,17 @@ struct SampleRandIntParam : public dmlc::Parameter<SampleRandIntParam>,
.describe("DType of the output in case this can't be inferred. "
"Defaults to int32 if not defined (dtype=None).");
}
void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
std::ostringstream low_s, high_s, dtype_s, shape_s;
low_s << low;
high_s << high;
dtype_s << dtype;
shape_s << shape;
(*dict)["low"] = low_s.str();
(*dict)["high"] = high_s.str();
(*dict)["dtype"] = MXNetTypeWithBool2String(dtype);
(*dict)["shape"] = shape_s.str();
}
};

struct SampleUniformLikeParam : public dmlc::Parameter<SampleUniformLikeParam>,
Expand Down

0 comments on commit 03e7cc2

Please sign in to comment.