-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix concatenate and *stack APIs to support scalars(#818, #839) #866
Fix concatenate and *stack APIs to support scalars(#818, #839) #866
Conversation
cunumeric/module.py
Outdated
@@ -1805,7 +1805,11 @@ def concatenate( | |||
|
|||
# flatten arrays if axis == None and concatenate arrays on the first axis | |||
if axis is None: | |||
inputs = list(inp.ravel() for inp in inputs) | |||
# Reshape arrays in the `array_list` to handle scalars | |||
reshaped = _atleast_nd(1, tuple(inputs)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need to call tuple(inputs)
, you could just do _atleast_nd(1, inputs)
, and similar in other places in this PR where you call _atleast_nd(1, tuple(tup))
. You may need to extend the type signature of _atleast_nd
to take ndim: int, arys: Sequence[ndarray, ...]
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. Thanks.
cunumeric/module.py
Outdated
# handle scalar inputs | ||
if type(common_info.ndim) is not int: | ||
common_info.ndim = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think is necessary. The check_shape_dtype_without_axis
always returns an int for ndim
as far as I can tell, and if I comment out this code and run with scalars, this works fine:
arrays = (0, 4)
print(cn.stack(arrays))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. Fixed. Thanks.
normalize_axis_index
for scalar inputs