Add Numba implementation of Blockwise #1015
if nout == 1:
    tuple_core_shapes = (to_fixed_tuple(core_shapes[0], core_shape_0),)
elif nout == 2:
    tuple_core_shapes = (
        to_fixed_tuple(core_shapes[0], core_shape_0),
        to_fixed_tuple(core_shapes[1], core_shape_1),
    )
else:
    tuple_core_shapes = (
        to_fixed_tuple(core_shapes[0], core_shape_0),
        to_fixed_tuple(core_shapes[1], core_shape_1),
        to_fixed_tuple(core_shapes[2], core_shape_2),
    )
If anybody has an idea on how to do this dynamically, that would be great. Do we have to do string generation 😭?
Are you opposed to cheesing it?
tuple(to_fixed_tuple(core_shapes[i], core_shape_lens[i]) for i in range(nout))
(I don't have full context)
Numba doesn't support that in this context.
Ewww. Maybe you could try a bunch of eval statements?
We would need to go down the string-generation route, as we do for some other Ops (like Scan). But I didn't want to :)
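For readers unfamiliar with the string-generation approach mentioned in this thread, here is a minimal sketch of the idea, not the code used in this PR. It writes out the unrolled tuple as source text for an arbitrary number of outputs and compiles it with `exec`. The helper name `make_tuple_core_shapes_fn` is hypothetical, and `to_fixed_tuple` here is a pure-Python stand-in for the Numba helper so the sketch runs without Numba; in a real Numba version the generated function would presumably also need to be jitted.

```python
def to_fixed_tuple(seq, length):
    # Pure-Python stand-in (an assumption for this sketch) for the Numba
    # helper of the same name: build a tuple of a statically known length.
    return tuple(seq[i] for i in range(length))


def make_tuple_core_shapes_fn(core_shape_lens):
    """Generate and compile a function unrolled over len(core_shape_lens) outputs.

    Hypothetical helper: the lengths are written into the source as literals,
    mirroring the constants core_shape_0, core_shape_1, ... in the diff above.
    """
    items = "".join(
        f"        to_fixed_tuple(core_shapes[{i}], {int(length)}),\n"
        for i, length in enumerate(core_shape_lens)
    )
    src = (
        "def tuple_core_shapes_fn(core_shapes):\n"
        "    return (\n"
        f"{items}"
        "    )\n"
    )
    namespace = {"to_fixed_tuple": to_fixed_tuple}
    exec(src, namespace)  # compile the generated source text
    return namespace["tuple_core_shapes_fn"]


# Four outputs, one more than the hand-written if/elif/else handles.
fn = make_tuple_core_shapes_fn((2, 1, 3, 1))
result = fn([[2, 2], [3], [4, 5, 6], [7]])
```

The trade-off is the one raised in the thread: generated source lifts the fixed limit on the number of outputs, at the cost of code that is harder to read and debug than the explicit branches.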
This can only be done when the output of infer_shape of the core_op depends only on the input shapes, and not their values.
Restricted to 3 outputs, due to limitations in jitting of Numba functions
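To illustrate the condition in the commit message above, here is a small pure-Python sketch contrasting a core shape rule that needs only input shapes with one that needs an input value. Both functions are hypothetical illustrations and do not use PyTensor's actual `infer_shape` signature.

```python
def cholesky_like_core_shape(input_shapes):
    # The output matrix has the same shape as the input matrix,
    # so the input shape alone is enough.
    (a_shape,) = input_shapes
    return [a_shape]


def arange_like_core_shape(input_shapes, input_values):
    # The output length equals the *value* of the scalar input, so it
    # cannot be derived from input shapes alone (the scalar's shape is ()).
    (n,) = input_values
    return [(n,)]


shape_only = cholesky_like_core_shape([(3, 3)])
value_dependent = arange_like_core_shape([()], [5])
```

Only the first kind of rule satisfies the condition stated above, since its output shapes can be computed before any input values are known.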
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1015      +/-   ##
==========================================
- Coverage   81.75%   81.74%   -0.01%
==========================================
  Files         183      185       +2
  Lines       47756    47816      +60
  Branches    11620    11632      +12
==========================================
+ Hits        39044    39089      +45
- Misses       6519     6529      +10
- Partials     2193     2198       +5
Description
Implement Numba Blockwise for Ops with up to 3 outputs (the limit is due to Numba not supporting tuple generators in the inner functions).
It uses the machinery developed for RVs and Elemwise. The hard part has to do with handling multiple inputs and Numba's fussiness.
It also improves Blockwise shape inference based on the infer_shape of the core Ops.
The small Cholesky benchmark test I added here runs 10x faster after this PR on my local machine.
Related Issue