VIDEO: https://youtu.be/Uwn3ngzXD0Y
"JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. Basically, JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research."
This video shows how to use JAX (https://jax.readthedocs.io/en/latest/) on CoCalc with or without a GPU.
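The short sketch below shows the ideas from the quote in practice: NumPy-style arrays via jax.numpy, XLA compilation via jax.jit, and automatic differentiation via jax.grad. It also reports which backend JAX selected, which is how you can tell whether a CoCalc project is running on the CPU or on a GPU. It assumes JAX is already installed in the project. The video covers setup, and the installation steps are not reproduced here. The loss function and the toy data are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Which backend did JAX pick? Typically "cpu" in a CPU-only project,
# "gpu" when a CUDA-capable GPU and a GPU-enabled jaxlib are available.
print("backend:", jax.default_backend())
print("devices:", jax.devices())

# Hypothetical example: mean squared error of a linear model, written with
# NumPy-style calls and compiled with XLA through jax.jit.
@jax.jit
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Automatic differentiation: gradient of loss with respect to w (the first argument).
grad_loss = jax.grad(loss)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

print("loss:", loss(w, x, y))       # 2.5 for these inputs
print("grad:", grad_loss(w, x, y))  # [-7. -10.] for these inputs
```

The same code runs unchanged with or without a GPU. JAX places arrays on its default device, so switching a project to GPU hardware should only change the reported backend and the speed, not the results.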