-
Notifications
You must be signed in to change notification settings - Fork 6.8k
make array.reshape compatible with numpy #9790
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -926,12 +926,12 @@ def _at(self, idx): | |
self.handle, mx_uint(idx), ctypes.byref(handle))) | ||
return NDArray(handle=handle, writable=self.writable) | ||
|
||
def reshape(self, shape): | ||
def reshape(self, *shape, **kwargs): | ||
"""Returns a **view** of this array with a new shape without altering any data. | ||
|
||
Parameters | ||
---------- | ||
shape : tuple of int | ||
shape : tuple of int, or n ints | ||
The new shape should not change the array size, namely | ||
``np.prod(new_shape)`` should be equal to ``np.prod(self.shape)``. | ||
|
||
|
@@ -960,6 +960,11 @@ def reshape(self, shape): | |
[ 4., 5.]], dtype=float32) | ||
>>> y = x.reshape((3,-1)) | ||
>>> y.asnumpy() | ||
array([[ 0., 1.], | ||
[ 2., 3.], | ||
[ 4., 5.]], dtype=float32) | ||
>>> y = x.reshape(3,2) | ||
>>> y.asnumpy() | ||
array([[ 0., 1.], | ||
[ 2., 3.], | ||
[ 4., 5.]], dtype=float32) | ||
|
@@ -968,6 +973,14 @@ def reshape(self, shape): | |
array([[-1., -1., -1.], | ||
[-1., -1., -1.]], dtype=float32) | ||
""" | ||
if len(shape) == 1 and isinstance(shape[0], (list, tuple)): | ||
shape = shape[0] | ||
elif not len(shape): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happens for reshape(1, 2, shape=1)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. keyword argument is ignored. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would that be a problem? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. an error should be raised There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Numpy throws an error too, though the error is confusing.
|
||
for key, value in kwargs.items(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to iterate. test for len(kwargs) == 1 and kwargs.get('shape', None) directly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The loop also helps throw exception with the invalid argument name, which is consistent with numpy's behavior. The check would be lost in your proposal. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is slow. Use test for len(kwargs) == 1, then kwargs.get('shape', None) |
||
if key == 'shape': | ||
shape = value | ||
else: | ||
raise TypeError("'%s' is an invalid keyword argument for this function"%key) | ||
handle = NDArrayHandle() | ||
|
||
# Actual reshape | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(self, *args, shape=None)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't work in py2