-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix concatenate and *stack APIs to support scalars(#818, #839) #866
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1479,8 +1479,8 @@ def check_shape_with_axis( | |
ndim = inputs[0].ndim | ||
shape = inputs[0].shape | ||
|
||
axis = normalize_axis_index(axis, ndim) | ||
if ndim >= 1: | ||
axis = normalize_axis_index(axis, ndim) | ||
if _builtin_any( | ||
shape[:axis] != inp.shape[:axis] | ||
or shape[axis + 1 :] != inp.shape[axis + 1 :] | ||
|
@@ -1805,7 +1805,11 @@ def concatenate( | |
|
||
# flatten arrays if axis == None and concatenate arrays on the first axis | ||
if axis is None: | ||
inputs = list(inp.ravel() for inp in inputs) | ||
# Reshape arrays in the `array_list` to handle scalars | ||
reshaped = _atleast_nd(1, tuple(inputs)) | ||
if not isinstance(reshaped, list): | ||
reshaped = [reshaped] | ||
inputs = list(inp.ravel() for inp in reshaped) | ||
axis = 0 | ||
|
||
# Check to see if we can build a new tuple of cuNumeric arrays | ||
|
@@ -1871,6 +1875,9 @@ def stack( | |
if len(shapes) != 1: | ||
raise ValueError("all input arrays must have the same shape for stack") | ||
|
||
# handle scalar inputs | ||
if type(common_info.ndim) is not int: | ||
common_info.ndim = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think is necessary. The
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. Fixed. Thanks. |
||
axis = normalize_axis_index(axis, common_info.ndim + 1) | ||
shape = common_info.shape[:axis] + (1,) + common_info.shape[axis:] | ||
arrays = [arr.reshape(shape) for arr in arrays] | ||
|
@@ -1960,7 +1967,14 @@ def hstack(tup: Sequence[ndarray]) -> ndarray: | |
-------- | ||
Multiple GPUs, Multiple CPUs | ||
""" | ||
tup, common_info = check_shape_dtype_without_axis(tup, hstack.__name__) | ||
# Reshape arrays in the `array_list` to handle scalars | ||
reshaped = _atleast_nd(1, tuple(tup)) | ||
if not isinstance(reshaped, list): | ||
reshaped = [reshaped] | ||
|
||
tup, common_info = check_shape_dtype_without_axis( | ||
reshaped, hstack.__name__ | ||
) | ||
check_shape_with_axis( | ||
tup, hstack.__name__, axis=(0 if common_info.ndim == 1 else 1) | ||
) | ||
|
@@ -2052,14 +2066,19 @@ def column_stack(tup: Sequence[ndarray]) -> ndarray: | |
-------- | ||
Multiple GPUs, Multiple CPUs | ||
""" | ||
# Reshape arrays in the `array_list` to handle scalars | ||
reshaped = _atleast_nd(1, tuple(tup)) | ||
if not isinstance(reshaped, list): | ||
reshaped = [reshaped] | ||
|
||
tup, common_info = check_shape_dtype_without_axis( | ||
tup, column_stack.__name__ | ||
reshaped, column_stack.__name__ | ||
) | ||
# When ndim == 1, hstack concatenates arrays along the first axis | ||
|
||
if common_info.ndim == 1: | ||
tup = list(inp.reshape((inp.shape[0], 1)) for inp in tup) | ||
common_info.shape = tup[0].shape | ||
check_shape_with_axis(tup, dstack.__name__, 1) | ||
check_shape_with_axis(tup, column_stack.__name__, 1) | ||
return _concatenate( | ||
tup, | ||
common_info, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need to call
tuple(inputs)
, you could just do_atleast_nd(1, inputs)
, and similar in other places in this PR where you call_atleast_nd(1, tuple(tup))
. You may need to extend the type signature of_atleast_nd
to takendim: int, arys: Sequence[ndarray, ...]
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. Thanks.