-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Conversation
src/operator/tensor/init_op.h
Outdated
|
||
template<typename xpu> | ||
void EyeFillEx(const nnvm::NodeAttrs& attrs, | ||
const OpContext& ctx, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please fix indentation
src/operator/tensor/init_op.h
Outdated
} | ||
|
||
template<int req> | ||
struct eye_fill_impl { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: maybe eye_dns_fill
is a better name?
python/mxnet/ndarray/ndarray.py
Outdated
-------- | ||
>>> mx.nd.eye(1, 2) | ||
|
||
[[ 1. 0.]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let put a simple example and a complex one with k!= 0
python/mxnet/ndarray/ndarray.py
Outdated
An optional list of types of the aux data for RowSparseNDArray or CSRNDArray | ||
(default values depends on the storage type) | ||
|
||
Returns |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Returns
-------
NDArray
A created array
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
if ctx is None: | ||
ctx = Context.default_ctx | ||
dtype = mx_real_t if dtype is None else dtype | ||
# pylint: disable= no-member, protected-access |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are these pylint disable necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just follow the way in zeros
. https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/ndarray/ndarray.py#L2928
python/mxnet/ndarray/utils.py
Outdated
(default values depends on the storage type) | ||
|
||
Returns | ||
------- |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets put csrndarray as the return type here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
NDArray or CSRNDArray
A created array
.add_enum("float64", mshadow::kFloat64) | ||
.add_enum("float16", mshadow::kFloat16) | ||
.add_enum("uint8", mshadow::kUint8) | ||
.add_enum("int32", mshadow::kInt32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we support int64, too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
src/operator/tensor/init_op.h
Outdated
|
||
template<typename ParamType> | ||
inline bool InitEyeShape(const nnvm::NodeAttrs& attrs, | ||
std::vector<TShape> *in_attrs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pls fix indentation
src/operator/tensor/init_op.h
Outdated
CHECK_EQ(req[0], kWriteTo) << "kWriteTo is expected for EyeFillEx"; | ||
if (stype == kCSRStorage) { | ||
NDArray nd(outputs[0]); | ||
EyeFillCsr<xpu>(s, nd, param.N, param.M, param.k); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can directly pass outputs[0] to EyeFillCsr
src/operator/tensor/init_op.h
Outdated
void EyeFillCsr(mshadow::Stream<xpu> *s, const NDArray& out, | ||
const nnvm::dim_t N, const nnvm::dim_t M, const nnvm::dim_t k) { | ||
const nnvm::dim_t num_cols = M > 0 ? M : N; | ||
const nnvm::dim_t nnz = k > 0 ? std::min(std::max(num_cols - std::abs(k), (nnvm::dim_t)0), N) : |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this line is too long..
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Just some final comments..
python/mxnet/ndarray/sparse.py
Outdated
An optional device context (default is the current default context) | ||
dtype : str or numpy.dtype, optional | ||
An optional value type (default is `float32`) | ||
aux_types: list of numpy.dtype, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aux_types will be removed by #8269
let's also not expose aux_types to users here, either. just using the default one is fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
python/mxnet/ndarray/utils.py
Outdated
An optional device context (default is the current default context) | ||
dtype: str or numpy.dtype, optional | ||
An optional value type (default is `float32`) | ||
aux_types: list of numpy.dtype, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here for aux-types
@@ -730,6 +730,13 @@ def test_output(): | |||
assert_almost_equal(out.asnumpy(), zeros.asnumpy()) | |||
mx.nd.full(shape, 2, out=out) | |||
assert_almost_equal(out.asnumpy(), ones.asnumpy() * 2) | |||
import random |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: any reason why np.random.randint is not used here?
src/operator/tensor/init_op.h
Outdated
.add_enum("float16", mshadow::kFloat16) | ||
.add_enum("uint8", mshadow::kUint8) | ||
.add_enum("int32", mshadow::kInt32) | ||
.add_enum("int8", mshadow::kInt8) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@reminisce will int8 cause any issue if we enable it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be fine I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think frontend support int8
Why build fails here? Do you have any idea? @eric-haibin-lin @piiswrong These codes can successfully compile and pass the unittest on my machine. |
Seems that tests for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some final minor comments. Should be good to merge once addressed (and if CI passes)
src/operator/tensor/init_op.h
Outdated
}); | ||
} | ||
|
||
struct eye_csr_indptr_fill { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add more documentation for the new kernels
src/operator/tensor/init_op.h
Outdated
} | ||
}; | ||
|
||
struct eye_csr_data_fill { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we reuse this kernel instead?
https://github.com/apache/incubator-mxnet/blob/3552b958682d8ab9c3ea8e7d3f3f8855d29fbbe5/src/operator/mxnet_op.h#L295
Description
eye for storage_type of default, csr, row_sparse(fallback to default).
As a feature requested in #8168.
cc @eric-haibin-lin for review.
Checklist
Essentials
make lint
)Changes
Comments