Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

eye for dense and sparse #8225

Closed
wants to merge 27 commits into from
Closed

eye for dense and sparse #8225

wants to merge 27 commits into from

Conversation

ZiyueHuang
Copy link
Member

@ZiyueHuang ZiyueHuang commented Oct 11, 2017

Description

eye for storage_type of default, csr, row_sparse(fallback to default).
As a feature requested in #8168.

cc @eric-haibin-lin for review.

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • For user-facing API changes, API doc string has been updated.
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • eye for default storage type, add unitest
  • eye for csr storage type, add unitest
  • eye for row_sparse storage type, add unitest

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Intersting edge cases to note here


template<typename xpu>
void EyeFillEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please fix indentation

}

template<int req>
struct eye_fill_impl {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe eye_dns_fill is a better name?

--------
>>> mx.nd.eye(1, 2)

[[ 1. 0.]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let put a simple example and a complex one with k!= 0

An optional list of types of the aux data for RowSparseNDArray or CSRNDArray
(default values depends on the storage type)

Returns
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returns
-------
NDArray  
       A created array

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

if ctx is None:
ctx = Context.default_ctx
dtype = mx_real_t if dtype is None else dtype
# pylint: disable= no-member, protected-access
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these pylint disable necessary?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(default values depends on the storage type)

Returns
-------
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets put csrndarray as the return type here

Copy link
Member Author

@ZiyueHuang ZiyueHuang Oct 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

NDArray or CSRNDArray
    A created array

.add_enum("float64", mshadow::kFloat64)
.add_enum("float16", mshadow::kFloat16)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we support int64, too?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


template<typename ParamType>
inline bool InitEyeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pls fix indentation

CHECK_EQ(req[0], kWriteTo) << "kWriteTo is expected for EyeFillEx";
if (stype == kCSRStorage) {
NDArray nd(outputs[0]);
EyeFillCsr<xpu>(s, nd, param.N, param.M, param.k);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can directly pass outputs[0] to EyeFillCsr

void EyeFillCsr(mshadow::Stream<xpu> *s, const NDArray& out,
const nnvm::dim_t N, const nnvm::dim_t M, const nnvm::dim_t k) {
const nnvm::dim_t num_cols = M > 0 ? M : N;
const nnvm::dim_t nnz = k > 0 ? std::min(std::max(num_cols - std::abs(k), (nnvm::dim_t)0), N) :
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this line is too long..

Copy link
Member

@eric-haibin-lin eric-haibin-lin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Just some final comments..

An optional device context (default is the current default context)
dtype : str or numpy.dtype, optional
An optional value type (default is `float32`)
aux_types: list of numpy.dtype, optional
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aux_types will be removed by #8269
let's also not expose aux_types to users here, either. just using the default one is fine

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

An optional device context (default is the current default context)
dtype: str or numpy.dtype, optional
An optional value type (default is `float32`)
aux_types: list of numpy.dtype, optional
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here for aux-types

@@ -730,6 +730,13 @@ def test_output():
assert_almost_equal(out.asnumpy(), zeros.asnumpy())
mx.nd.full(shape, 2, out=out)
assert_almost_equal(out.asnumpy(), ones.asnumpy() * 2)
import random
Copy link
Member

@eric-haibin-lin eric-haibin-lin Oct 14, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: any reason why np.random.randint is not used here?

.add_enum("float16", mshadow::kFloat16)
.add_enum("uint8", mshadow::kUint8)
.add_enum("int32", mshadow::kInt32)
.add_enum("int8", mshadow::kInt8)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@reminisce will int8 cause any issue if we enable it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fine I think.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think frontend support int8

@ZiyueHuang
Copy link
Member Author

Why build fails here? Do you have any idea? @eric-haibin-lin @piiswrong

These codes can successfully compile and pass the unittest on my machine.

@ZiyueHuang
Copy link
Member Author

Seems that tests for eye passed, CI fails due to test_maximum_minimum_scalar.

Copy link
Member

@eric-haibin-lin eric-haibin-lin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some final minor comments. Should be good to merge once addressed (and if CI passes)

});
}

struct eye_csr_indptr_fill {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add more documentation for the new kernels

}
};

struct eye_csr_data_fill {
Copy link
Member

@eric-haibin-lin eric-haibin-lin Oct 30, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eric-haibin-lin eric-haibin-lin self-assigned this Nov 8, 2017
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants