Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QuTiPv5 Paper Notebook: JAX #116

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

Langhaarzombie
Copy link
Contributor

@Langhaarzombie Langhaarzombie commented Nov 27, 2024

This adds the JAX examples from the paper to the tutorials. I combined the mesolve, count stat. and gradient calculation into one notebook as they are closely related.

TODO:

  • References
  • Tests
  • Maybe some nice plots rather than just numbers? I am open for ideas

- added mesolve example
- added count stat. example
- added gradient example
- added basic structure for notebook with introduction and explanations
- added qutip5 paper and qutip-jax as reference
- added basics tests for current and shot noise example
- small code readbility improvement
@Langhaarzombie Langhaarzombie marked this pull request as ready for review November 28, 2024 07:21
@Langhaarzombie
Copy link
Contributor Author

Langhaarzombie commented Nov 28, 2024

Although it is ready for review, we should still wait until an official arXiv link is available to the QuTiP 5 paper as this reference link is blank at the moment.

  • Include official url to QuTiP 5 paper

Copy link
Member

@nwlambert nwlambert left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe good to tag @rochisha0 as a co-author, she provided the final mcsolve example

tutorials-v5/miscellaneous/v5_paper-jax.md Outdated Show resolved Hide resolved
tutorials-v5/miscellaneous/v5_paper-jax.md Outdated Show resolved Hide resolved
tutorials-v5/miscellaneous/v5_paper-jax.md Outdated Show resolved Hide resolved
tutorials-v5/miscellaneous/v5_paper-jax.md Outdated Show resolved Hide resolved
tutorials-v5/miscellaneous/v5_paper-jax.md Outdated Show resolved Hide resolved
tutorials-v5/miscellaneous/v5_paper-jax.md Show resolved Hide resolved
plt.legend()
plt.show()
```

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could add some funky animation plot!

%matplotlib notebook
from IPython.display import HTML
ax, ani = qt.anim_matrix_histogram(result_ising)
HTML(ani.to_jshtml(fps=15))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it normal that this takes >20min to compute (on a normal-ish laptop)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ouch. i tried a different example the other day and it was fairly slow too. we could reduce the number of points in tlist, but perhaps just better not to add here!

Copy link
Member

@BoxiLi BoxiLi Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my experience, a long tlist usually causes a long compilation time (not computation time!), because we have a for loop in the qutip base solver. This for loop will be flattened by jax compilation. So each step in the for loop is compiled separately. This can be speeded up by rewriting it with a JAX loop. But I don't know how hard it is because in each step we add something to the result class.

@nwlambert
Copy link
Member

Seems like JAX 0.4.36 breaks something in qutip-jax, so tests will fail for a while

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants