-
Notifications
You must be signed in to change notification settings - Fork 38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
QuTiPv5 Paper Notebook: JAX #116
base: main
Are you sure you want to change the base?
Conversation
- added mesolve example - added count stat. example - added gradient example - added basic structure for notebook with introduction and explanations
- added qutip5 paper and qutip-jax as reference - added basics tests for current and shot noise example - small code readbility improvement
Although it is ready for review, we should still wait until an official arXiv link is available to the QuTiP 5 paper as this reference link is blank at the moment.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe good to tag @rochisha0 as a co-author, she provided the final mcsolve example
plt.legend() | ||
plt.show() | ||
``` | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could add some funky animation plot!
%matplotlib notebook
from IPython.display import HTML
ax, ani = qt.anim_matrix_histogram(result_ising)
HTML(ani.to_jshtml(fps=15))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it normal that this takes >20min to compute (on a normal-ish laptop)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah ouch. i tried a different example the other day and it was fairly slow too. we could reduce the number of points in tlist, but perhaps just better not to add here!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From my experience, a long tlist
usually causes a long compilation time (not computation time!), because we have a for loop in the qutip base solver. This for loop will be flattened by jax compilation. So each step in the for loop is compiled separately. This can be speeded up by rewriting it with a JAX loop. But I don't know how hard it is because in each step we add something to the result class.
Seems like JAX 0.4.36 breaks something in qutip-jax, so tests will fail for a while |
This adds the JAX examples from the paper to the tutorials. I combined the mesolve, count stat. and gradient calculation into one notebook as they are closely related.
TODO: