-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fixing putmask logic for scalar inputs #980
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! I can confirm that this fixes my use case. You could also optionally use auto dim = std::max(1, args.in.dim()); like in ScalarUnaryRedDispatch to be consistent.
tests/integration/test_putmask.py
Outdated
x = num.random.rand(3, 3) | ||
s = x.sum() | ||
num.putmask(s, True, 1.0) | ||
print("IRINA DEBUG DDD", s) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The print statement somehow made it in....
cunumeric/array.py
Outdated
@@ -4014,7 +4014,7 @@ def find_common_type(*args: Any) -> np.dtype[Any]: | |||
scalar_types.append(array.dtype) | |||
else: | |||
array_types.append(array.dtype) | |||
return np.find_common_type(array_types, scalar_types) # type: ignore | |||
return np.find_common_type(array_types, scalar_types) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd make sure this change doesn't make mypy unhappy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, mypy was complaining about this type: ignore
so I removed it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ha, mypy in CI complaining now. I will put it back
@@ -105,7 +105,8 @@ static void putmask_template(TaskContext& context) | |||
{ | |||
auto& inputs = context.inputs(); | |||
PutmaskArgs args{context.outputs()[0], inputs[1], inputs[2]}; | |||
double_dispatch(args.input.dim(), args.input.code(), PutmaskImpl<KIND>{}, args); | |||
int dim = args.input.dim() == 0 ? 1 : args.input.dim(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int dim = args.input.dim() == 0 ? 1 : args.input.dim(); | |
int dim = std::max(1, args.input.dim()); |
No description provided.