Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing putmask logic for scalar inputs #980

Merged
merged 3 commits into from
Jul 7, 2023

Conversation

ipdemes
Copy link
Contributor

@ipdemes ipdemes commented Jul 1, 2023

No description provided.

@ipdemes ipdemes added the category:bug-fix PR is a bug fix and will be classified as such in release notes label Jul 1, 2023
Copy link
Contributor

@shriram-jagan shriram-jagan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! I can confirm that this fixes my use case. You could also optionally use auto dim = std::max(1, args.in.dim()); like in ScalarUnaryRedDispatch to be consistent.

x = num.random.rand(3, 3)
s = x.sum()
num.putmask(s, True, 1.0)
print("IRINA DEBUG DDD", s)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The print statement somehow made it in....

@@ -4014,7 +4014,7 @@ def find_common_type(*args: Any) -> np.dtype[Any]:
scalar_types.append(array.dtype)
else:
array_types.append(array.dtype)
return np.find_common_type(array_types, scalar_types) # type: ignore
return np.find_common_type(array_types, scalar_types)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd make sure this change doesn't make mypy unhappy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, mypy was complaining about this type: ignore so I removed it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ha, mypy in CI complaining now. I will put it back

@@ -105,7 +105,8 @@ static void putmask_template(TaskContext& context)
{
auto& inputs = context.inputs();
PutmaskArgs args{context.outputs()[0], inputs[1], inputs[2]};
double_dispatch(args.input.dim(), args.input.code(), PutmaskImpl<KIND>{}, args);
int dim = args.input.dim() == 0 ? 1 : args.input.dim();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int dim = args.input.dim() == 0 ? 1 : args.input.dim();
int dim = std::max(1, args.input.dim());

@ipdemes ipdemes merged commit cf0154b into nv-legate:branch-23.07 Jul 7, 2023
@ipdemes ipdemes deleted the putmask_fix2 branch August 2, 2023 21:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
category:bug-fix PR is a bug fix and will be classified as such in release notes
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants