Conversion time for sparse matrix #19
This sounds very good and aligns perfectly with our view of the tool. Since we effectively have unlimited time for precomputation and need to generate as efficient a function as possible, it makes sense to apply optimizations that make the generated code more performant. I suggest you just create a pull request with your idea; we will be happy to accept it when we have a bit more free time. Otherwise, we will probably implement this ourselves a little later.
@mattephi Thanks for your answer, I will make the pull request when I have some time 👍 Reducing the number of zeros helped a bit, but I am still struggling with the conversion times. This is unrelated to jaxadi, though. The jax …
I think this is the main issue with jax, as well as with other tools such as mjx: compile times are horrendous, and many people complain about them. There are some approaches to the problem, like applying …
Hello again!
Thank you for your quick fixes for #16 and #17!
This is more of a discussion than an issue, so feel free to close it at some point.
As I have to convert big sparse matrices, I am facing long compilation times using
convert(foo, compile=True)
because it does not ignore the zeros and assigns them to the output, so there are a ton of assignment operations even though there are few non-zero coefficients. For instance, the following code takes 10 seconds to run for N = 100 and 76 seconds for N = 200 (and I will likely have to convert bigger matrices). I guess the compiler struggles with the high number of operations and does not realize it can ignore most of them?
Note that this is just an example; of course, for this diagonal matrix I could code it directly in jax.
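To put the "ton of assignment operations" in numbers, here is a rough count for an N × N diagonal matrix. It assumes, as the description above suggests, that the generated code emits one assignment per matrix entry, zeros included:

```math
\text{assignments}_{\text{dense}} = N^2, \qquad
\text{assignments}_{\text{non-zero}} = N, \qquad
\frac{N^2 - N}{N^2} = 1 - \frac{1}{N}
```

For N = 100 that is 10,000 assignments, of which 9,900 (99%) write a zero; for N = 200 it is 40,000, of which 39,800 (99.5%) write a zero. Doubling N quadruples the emitted assignments while only doubling the non-zeros.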
As a quick workaround, I slightly modified the codegen function to ignore the operations associated with a zero value:
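The patch itself is not reproduced here. As a rough illustration of the idea only, below is a minimal, self-contained Python sketch of a code generator that skips zero traffic. Everything in it is hypothetical: the `Const`/`BinOp`/`Output` instruction classes, the `codegen` and `diagonal_program` functions, and the emitted line format are stand-ins, not jaxadi's actual codegen or CasADi's instruction API. It assumes the output buffer starts as zeros (so writing a zero into it is a no-op) and that a skipped zero constant can be replaced by the literal `0.0` wherever a later instruction reads it.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Const:
    dst: int       # work slot to write
    value: float   # numeric constant


@dataclass
class BinOp:
    dst: int       # work slot to write
    op: str        # Python operator, e.g. "+" or "*"
    a: int         # left operand work slot
    b: int         # right operand work slot


@dataclass
class Output:
    out: int       # flat index into the output array
    src: int       # work slot to copy from


Instr = Union[Const, BinOp, Output]  # hypothetical instruction set


def codegen(instrs: List[Instr], n_out: int, skip_zeros: bool = True) -> List[str]:
    """Emit one source line per instruction, optionally dropping zero traffic."""
    # The output buffer starts as zeros, so writing a zero into it is a no-op.
    lines = [f"outputs = jnp.zeros({n_out})"]
    zero_slots = set()  # work slots currently known to hold exactly 0.0

    def ref(slot: int) -> str:
        # Read a work slot, substituting a literal for a skipped zero constant.
        return "0.0" if slot in zero_slots else f"work[{slot}]"

    for ins in instrs:
        if isinstance(ins, Const):
            if skip_zeros and ins.value == 0:
                zero_slots.add(ins.dst)   # remember it instead of emitting it
                continue
            zero_slots.discard(ins.dst)   # slot is overwritten with a non-zero
            lines.append(f"work[{ins.dst}] = {ins.value!r}")
        elif isinstance(ins, BinOp):
            # Build the expression before updating dst, in case dst == a or b.
            expr = f"{ref(ins.a)} {ins.op} {ref(ins.b)}"
            zero_slots.discard(ins.dst)
            lines.append(f"work[{ins.dst}] = {expr}")
        elif isinstance(ins, Output):
            if ins.src in zero_slots:
                continue                  # output already holds 0.0
            lines.append(f"outputs = outputs.at[{ins.out}].set(work[{ins.src}])")
    return lines


def diagonal_program(n: int) -> List[Instr]:
    # Hypothetical instruction stream for an n x n diagonal matrix with 2.0 on
    # the diagonal, reusing work slot 0 for every entry.
    instrs: List[Instr] = []
    for k in range(n * n):
        i, j = divmod(k, n)
        instrs.append(Const(dst=0, value=2.0 if i == j else 0.0))
        instrs.append(Output(out=k, src=0))
    return instrs


for n in (4, 200):
    dense = codegen(diagonal_program(n), n * n, skip_zeros=False)
    sparse = codegen(diagonal_program(n), n * n)
    print(n, len(dense), len(sparse))
# 4 33 9
# 200 80001 401
```

Substituting `0.0` for a skipped slot, rather than dropping the read, keeps later instructions correct when a zero is used as an operand. Folding expressions such as `work[3] * 0.0` away entirely is deliberately left out, because `inf * 0.0` is `nan`, not `0.0`.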
The tests all seem to pass, and it brought the conversion time down from 76 seconds to less than 2 seconds for N = 200.
I'm sure there is a better way to do this, but I just wanted to share the code in case other people are facing the same struggles with sparse matrices.
For completeness, the translation (for N = 4 just to illustrate) goes from
to
We can see it just discards the useless assignments of work[4][0] to outputs[0] without affecting the rest.
Best,