-
Notifications
You must be signed in to change notification settings - Fork 6.8k
make array.reshape compatible with numpy #9790
Conversation
@@ -926,12 +926,12 @@ def _at(self, idx): | |||
self.handle, mx_uint(idx), ctypes.byref(handle))) | |||
return NDArray(handle=handle, writable=self.writable) | |||
|
|||
def reshape(self, shape): | |||
def reshape(self, *shape, **kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(self, *args, shape=None)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't work in py2
python/mxnet/ndarray/ndarray.py
Outdated
@@ -968,6 +973,14 @@ def reshape(self, shape): | |||
array([[-1., -1., -1.], | |||
[-1., -1., -1.]], dtype=float32) | |||
""" | |||
if len(shape) == 1 and isinstance(shape[0], (list, tuple)): | |||
shape = shape[0] | |||
elif not len(shape): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what happens for reshape(1, 2, shape=1)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keyword argument is ignored.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would that be a problem?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
an error should be raised
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Numpy throws an error too, though the error is confusing.
In [1]: import numpy as np
In [2]: a = np.ones((3,5))
In [3]: a.reshape(3,5,shape=(7,))
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-cac6c36bb3d2> in <module>()
----> 1 a.reshape(3,5,shape=(7,))
TypeError: 'shape' is an invalid keyword argument for this function
python/mxnet/ndarray/ndarray.py
Outdated
if len(shape) == 1 and isinstance(shape[0], (list, tuple)): | ||
shape = shape[0] | ||
elif not len(shape): | ||
for key, value in kwargs.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to iterate. test for len(kwargs) == 1 and kwargs.get('shape', None) directly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The loop also helps throw exception with the invalid argument name, which is consistent with numpy's behavior. The check would be lost in your proposal.
python/mxnet/ndarray/ndarray.py
Outdated
if len(shape) == 1 and isinstance(shape[0], (list, tuple)): | ||
shape = shape[0] | ||
elif not shape: | ||
for key, value in kwargs.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is slow. Use test for len(kwargs) == 1, then kwargs.get('shape', None)
* make array.reshape compatible with numpy * update * add exception when both *args and **kwargs are specified * update
* make array.reshape compatible with numpy * update * add exception when both *args and **kwargs are specified * update
Description
Allow n ints as input to array.reshape.
Checklist
Essentials
make lint
)Changes