Skip to content

Commit

Permalink
[Numpy] FFI: tril_indices (apache#18546)
Browse files Browse the repository at this point in the history
* add numpy tril_indices ffi

* Update src/api/operator/numpy/np_matrix_op.cc

Co-authored-by: Haozheng Fan <[email protected]>

Co-authored-by: Haozheng Fan <[email protected]>
  • Loading branch information
2 people authored and AntiZpvoh committed Jul 6, 2020
1 parent 4ffe1d9 commit cd4b222
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 3 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5231,7 +5231,7 @@ def tril_indices(n, k=0, m=None):
"""
if m is None:
m = n
return tuple(_npi.tril_indices(n, k, m))
return tuple(_api_internal.tril_indices(n, k, m))


@set_module('mxnet.ndarray.numpy')
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6002,7 +6002,7 @@ def tril_indices(n, k=0, m=None):
"""
if m is None:
m = n
return tuple(_mx_nd_np.tril_indices(n, k, m))
return _mx_nd_np.tril_indices(n, k, m)


# pylint: disable=redefined-outer-name
Expand Down
24 changes: 24 additions & 0 deletions src/api/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,4 +510,28 @@ MXNET_REGISTER_API("_npi.squeeze")
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.tril_indices")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_tril_indices");
nnvm::NodeAttrs attrs;
op::NumpyTrilindicesParam param;
param.n = args[0].operator int();
param.k = args[1].operator int();
param.m = args[2].operator int();

attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::NumpyTrilindicesParam>(&attrs);

int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr);
std::vector<NDArrayHandle> ndarray_handles;
ndarray_handles.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
ndarray_handles.emplace_back(ndoutputs[i]);
}
*ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
});

} // namespace mxnet
2 changes: 1 addition & 1 deletion tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -7842,7 +7842,7 @@ def hybrid_forward(self, F, x, *args, **kwargs):
for hybridize in [True, False]:
# dummy nparray for hybridize
x = np.ones((1,1))
test_trilindices = TestTrilindices(n, k, m)
test_trilindices = TestTrilindices(int(n), int(k), int(m))
if hybridize:
test_trilindices.hybridize()
mx_out = test_trilindices(x)[1]
Expand Down

0 comments on commit cd4b222

Please sign in to comment.