Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix concatenate and *stack APIs to support scalars(#818, #839) #866

Merged
merged 2 commits into from
Mar 31, 2023

Conversation

robinwnv
Copy link
Contributor

  1. In module.py, reshape inputs to at least 1-D array for concatenate(axis=None), hstack and column_stack
  2. In module.py, in stack, set ndim to 0 for scalar inputs
  3. In module.py, call normalize_axis_index for scalar inputs
  4. In test code, add scalar tests for each API
  5. In test code, remove xfails

@robinwnv robinwnv added the category:bug-fix PR is a bug fix and will be classified as such in release notes label Mar 28, 2023
@robinwnv robinwnv changed the title Fix concatenate and *stack APIs to support scalars Fix concatenate and *stack APIs to support scalars(#818, #839) Mar 30, 2023
@@ -1805,7 +1805,11 @@ def concatenate(

# flatten arrays if axis == None and concatenate arrays on the first axis
if axis is None:
inputs = list(inp.ravel() for inp in inputs)
# Reshape arrays in the `array_list` to handle scalars
reshaped = _atleast_nd(1, tuple(inputs))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to call tuple(inputs), you could just do _atleast_nd(1, inputs), and similar in other places in this PR where you call _atleast_nd(1, tuple(tup)). You may need to extend the type signature of _atleast_nd to take ndim: int, arys: Sequence[ndarray, ...].

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. Thanks.

Comment on lines 1878 to 1880
# handle scalar inputs
if type(common_info.ndim) is not int:
common_info.ndim = 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think is necessary. The check_shape_dtype_without_axis always returns an int for ndim as far as I can tell, and if I comment out this code and run with scalars, this works fine:

arrays = (0, 4)
print(cn.stack(arrays))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Fixed. Thanks.

@robinwnv robinwnv merged commit ee63f1b into nv-legate:branch-23.05 Mar 31, 2023
@robinwnv robinwnv deleted the fix_concatenate_stack2 branch March 31, 2023 01:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
category:bug-fix PR is a bug fix and will be classified as such in release notes
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants