diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 571881aa..5a7b2ad6 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.14.7 + jupytext_version: 1.16.1 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -14,20 +14,14 @@ kernelspec: # An Introduction to JAX -```{include} _admonition/gpu.md -``` - This lecture provides a short introduction to [Google JAX](https://github.com/google/jax). -As mentioned above, the lecture was built using a GPU: - +What GPUs do we have access to? ```{code-cell} ipython3 !nvidia-smi ``` - - ## JAX as a NumPy Replacement @@ -66,6 +60,10 @@ print(jnp.mean(a)) print(jnp.dot(a, a)) ``` +```{code-cell} ipython3 +print(a @ a) # Equivalent +``` + However, the array object `a` is not a NumPy array: ```{code-cell} ipython3 @@ -106,7 +104,12 @@ linalg.inv(B) # Inverse of identity is identity ``` ```{code-cell} ipython3 -linalg.eigh(B) # Computes eigenvalues and eigenvectors +result = linalg.eigh(B) # Computes eigenvalues and eigenvectors +result.eigenvalues +``` + +```{code-cell} ipython3 +result.eigenvectors ``` ### Differences @@ -155,21 +158,7 @@ a ```{code-cell} ipython3 :tags: [raises-exception] -a[0] = 1 -``` - -In line with immutability, JAX does not support inplace operations: - -```{code-cell} ipython3 -a = np.array((2, 1)) -a.sort() -a -``` - -```{code-cell} ipython3 -a = jnp.array((2, 1)) -a_new = a.sort() -a, a_new +# a[0] = 1 # uncommenting produces a TypeError ``` The designers of JAX chose to make arrays immutable because JAX uses a @@ -198,7 +187,11 @@ id(a) ## Random Numbers -Random numbers are also a bit different in JAX, relative to NumPy. Typically, in JAX, the state of the random number generator needs to be controlled explicitly. +Random numbers are also a bit different in JAX, relative to NumPy. + +Typically, in JAX, the state of the random number generator needs to be controlled explicitly. + +(This is also related to JAX's functional programming paradigm, discussed below. JAX does not typically work with objects that maintain state, such as the state of a random number generator.) ```{code-cell} ipython3 import jax.random as random @@ -265,11 +258,9 @@ for A in matrices: One point to remember is that JAX expects tuples to describe array shapes, even for flat arrays. Hence, to get a one-dimensional array of normal random draws we use `(len, )` for the shape, as in ```{code-cell} ipython3 -random.normal(key, (5, )) +random.normal(key, (5,)) # not random.normal(key, 5) ``` - - ## JIT compilation The JAX just-in-time (JIT) compiler accelerates logic within functions by fusing linear @@ -299,13 +290,12 @@ How long does the function take to execute? %time f(x).block_until_ready() ``` -```{note} -Here, in order to measure actual speed, we use the `block_until_ready()` method +(In order to measure actual speed, we use `block_until_ready()` method to hold the interpreter until the results of the computation are returned from the device. This is necessary because JAX uses asynchronous dispatch, which -allows the Python interpreter to run ahead of GPU computations. +allows the Python interpreter to run ahead of GPU computations.) + -``` The code doesn't run as fast as we might hope, given that it's running on a GPU. @@ -315,22 +305,25 @@ But if we run it a second time it becomes much faster: %time f(x).block_until_ready() ``` +```{code-cell} ipython3 +%timeit f(x).block_until_ready() +``` + This is because the built in functions like `jnp.cos` are JIT compiled and the first run includes compile time. -Why would JAX want to JIT-compile built in functions like `jnp.cos` instead of -just providing pre-compiled versions, like NumPy? -The reason is that the JIT compiler can specialize on the *size* of the array -being used, which is helpful for parallelization. -For example, in running the code above, the JIT compiler produced a version of `jnp.cos` that is -specialized to floating point arrays of size `n = 50_000_000`. +### When does JAX recompile? + +You might remember that Numba recompiles if we change the types of variables in a function call. + +JAX recompiles more often --- in particular, it recompiles every time we change array sizes. -We can check this by calling `f` with a new array of different size. +For example, let's try ```{code-cell} ipython3 -m = 50_000_001 +m = n + 1 y = jnp.ones(m) ``` @@ -349,6 +342,10 @@ get faster execution. %time f(y).block_until_ready() ``` +Why does JAX generate fresh machine code every time we change the array size??? + + + The compiled versions for the previous array size are still available in memory too, and the following call is dispatched to the correct compiled code. @@ -356,43 +353,78 @@ too, and the following call is dispatched to the correct compiled code. %time f(x).block_until_ready() ``` +### Compiling user-built functions + +We can instruct JAX to compile entire functions that we build. + +For example, consider + +```{code-cell} ipython3 +def g(x): + y = jnp.zeros_like(x) + for i in range(10): + y += x**i + return y +``` + +```{code-cell} ipython3 +n = 1_000_000 +x = jnp.ones(n) +``` +Let's time it. -### Compiling the outer function +```{code-cell} ipython3 +%time g(x) +``` -We can do even better if we manually JIT-compile the outer function. +```{code-cell} ipython3 +%time g(x) +``` ```{code-cell} ipython3 -f_jit = jax.jit(f) # target for JIT compilation +g_jit = jax.jit(g) # target for JIT compilation ``` Let's run once to compile it: ```{code-cell} ipython3 -f_jit(x) +g_jit(x) ``` And now let's time it. ```{code-cell} ipython3 -%time f_jit(x).block_until_ready() +%time g_jit(x).block_until_ready() ``` Note the speed gain. -This is because the array operations are fused and no intermediate arrays are created. +This is because + +1. the loop is compiled and +2. the array operations are fused and no intermediate arrays are created. Incidentally, a more common syntax when targetting a function for the JIT -compiler is +compiler is ```{code-cell} ipython3 @jax.jit -def f(x): - a = 3*x + jnp.sin(x) + jnp.cos(x**2) - jnp.cos(2*x) - x**2 * 0.4 * x**1.5 - return jnp.sum(a) +def g_jit_2(x): + y = jnp.zeros_like(x) + for i in range(10): + y += x**i + return y ``` +```{code-cell} ipython3 +%time g_jit_2(x).block_until_ready() +``` + +```{code-cell} ipython3 +%time g_jit_2(x).block_until_ready() +``` ## Functional Programming @@ -413,7 +445,80 @@ In particular, a pure function has * no side effects -JAX will not usually throw errors when compiling impure functions but execution becomes unpredictable. + +### Example: Python/NumPy/Numba style code is not pure + + + +Here's an example to show that NumPy functions are not pure: + +```{code-cell} ipython3 +np.random.randn() +``` + +```{code-cell} ipython3 +np.random.randn() +``` + +This fails the test: a function returns the same result when called on the same inputs. + +The issue is that the function maintains internal state between function calls --- the state of the random number generator. + + + +Here's a function that fails to be pure because it modifies external state. + +```{code-cell} ipython3 +def double_input(x): # Not pure -- side effects + x[:] = 2 * x + return None + +x = np.ones(5) +x +``` + +```{code-cell} ipython3 +double_input(x) +x +``` + +Here's a pure version: + +```{code-cell} ipython3 +def double_input(x): + y = 2 * x + return y +``` + +The following function is also not pure, since it modifies a global variable (similar to the last example). + +```{code-cell} ipython3 +a = 1 +def f(): + global a + a += 1 + return None +``` + +```{code-cell} ipython3 +a +``` + +```{code-cell} ipython3 +f() +``` + +```{code-cell} ipython3 +a +``` + +### Compiling impure functions + +JAX does not insist on pure functions. + +For example, JAX will not usually throw errors when compiling impure functions + +However, execution becomes unpredictable! Here's an illustration of this fact, using global variables: @@ -429,6 +534,10 @@ def f(x): x = jnp.ones(2) ``` +```{code-cell} ipython3 +x +``` + ```{code-cell} ipython3 f(x) ``` @@ -445,7 +554,7 @@ a = 42 f(x) ``` -Changing the dimension of the input triggers a fresh compilation of the function, at which time the change in the value of `a` takes effect: +Notice that the change in the value of `a` takes effect in the code below: ```{code-cell} ipython3 x = jnp.ones(3) @@ -455,8 +564,9 @@ x = jnp.ones(3) f(x) ``` -Moral of the story: write pure functions when using JAX! +Can you explain why? +Moral of the story: write pure functions when using JAX! ## Gradients @@ -502,9 +612,9 @@ Writing fast JAX code requires shifting repetitive tasks from loops to array pro This procedure is called **vectorization** or **array programming**, and will be familiar to anyone who has used NumPy or MATLAB. -In most ways, vectorization is the same in JAX as it is in NumPy. +In some ways, vectorization is the same in JAX as it is in NumPy. -But there are also some differences, which we highlight here. +But there are also major differences, which we highlight here. As a running example, consider the function @@ -512,9 +622,12 @@ $$ f(x,y) = \frac{\cos(x^2 + y^2)}{1 + x^2 + y^2} $$ -Suppose that we want to evaluate this function on a square grid of $x$ and $y$ points and then plot it. +Suppose that we want to evaluate this function on a square grid of $x$ and $y$ points. + + +### A slow version with loops -To clarify, here is the slow `for` loop version. +To clarify, here is the slow `for` loop version, which we run in a setting where `len(x) = len(y)` is very small. ```{code-cell} ipython3 @jax.jit @@ -539,7 +652,7 @@ Even for this very small grid, the run time is extremely slow. (Notice that we used a NumPy array for `z_loops` because we wanted to write to it.) -+++ + OK, so how can we do the same operation in vectorized form? @@ -561,7 +674,10 @@ Here is what we actually wanted: z_loops.shape ``` -To get the right shape and the correct nested for loop calculation, we can use a `meshgrid` operation designed for this purpose: +### Vectorization attempt 1 + + +To get the right shape and the correct nested for loop calculation, we can use a `meshgrid` operation that originated in MATLAB and was replicated in NumPy and then JAX: ```{code-cell} ipython3 x_mesh, y_mesh = jnp.meshgrid(x, y) @@ -607,22 +723,59 @@ z_mesh = f(x_mesh, y_mesh) But there is one problem here: the mesh grids use a lot of memory. ```{code-cell} ipython3 -x_mesh.nbytes + y_mesh.nbytes +(x_mesh.nbytes + y_mesh.nbytes) / 1_000_000 # MB of memory ``` By comparison, the flat array `x` is just ```{code-cell} ipython3 -x.nbytes # and y is just a pointer to x +x.nbytes / 1_000_000 # and y is just a pointer to x ``` This extra memory usage can be a big problem in actual research calculations. -So let's try a different approach using [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) +### Vectorization attempt 2 + + +We can achieve a similar effect through NumPy style broadcasting rules. + +```{code-cell} ipython3 +x_reshaped = jnp.reshape(x, (n, 1)) # Give x another dimension (column) +y_reshaped = jnp.reshape(y, (1, n)) # Give y another dimension (row) +``` + +When we evaluate $f$ on these reshaped arrays, we replicate the nested for loops in the original version. + +```{code-cell} ipython3 +%time z_reshaped = f(x_reshaped, y_reshaped) +``` + +Let's check that we got the same result -+++ +```{code-cell} ipython3 +jnp.allclose(z_reshaped, z_mesh) +``` -First we vectorize `f` in `y`. +The memory usage for the inputs is much more moderate. + +```{code-cell} ipython3 +(x_reshaped.nbytes + y_reshaped.nbytes) / 1_000_000 +``` + +### Vectorization attempt 3 + + +There's another approach to vectorization we can pursue, using [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) + + + +It runs out that, when we are working with complex functions and operations, this `vmap` approach can be the easiest to implement. + +It's also very memory parsimonious. + + + +The first step is to vectorize the function `f` in `y`. ```{code-cell} ipython3 f_vec_y = jax.vmap(f, in_axes=(None, 0)) @@ -636,6 +789,12 @@ Next, we vectorize in the first argument, which is `x`. f_vec = jax.vmap(f_vec_y, in_axes=(0, None)) ``` +Finally, we JIT-compile the result: + +```{code-cell} ipython3 +f_vec = jax.jit(f_vec) +``` + With this construction, we can now call the function $f$ on flat (low memory) arrays. ```{code-cell} ipython3 @@ -648,26 +807,64 @@ z_vmap = f_vec(x, y) z_vmap = f_vec(x, y) ``` -The execution time is essentially the same as the mesh operation but we are using much less memory. - -And we produce the correct answer: +Let's check we produce the correct answer: ```{code-cell} ipython3 jnp.allclose(z_vmap, z_mesh) ``` +**Exercise** +In a previous notebook we used Monte Carlo to price a European call option and +constructed a solution using Numba. -## Exercises +The code looked like this: +```{code-cell} ipython3 +import numba +from numpy.random import randn +M = 10_000_000 + +n, β, K = 20, 0.99, 100 +μ, ρ, ν, S0, h0 = 0.0001, 0.1, 0.001, 10, 0 + +@numba.jit(parallel=True) +def compute_call_price_parallel(β=β, + μ=μ, + S0=S0, + h0=h0, + K=K, + n=n, + ρ=ρ, + ν=ν, + M=M): + current_sum = 0.0 + # For each sample path + for m in numba.prange(M): + s = np.log(S0) + h = h0 + # Simulate forward in time + for t in range(n): + s = s + μ + np.exp(h) * randn() + h = ρ * h + ν * randn() + # And add the value max{S_n - K, 0} to current_sum + current_sum += np.maximum(np.exp(s) - K, 0) + + return β**n * current_sum / M +``` + +Let's run it once to compile it: -```{exercise-start} -:label: jax_intro_ex2 +```{code-cell} ipython3 +compute_call_price_parallel() ``` -In the Exercise section of [a lecture on Numba and parallelization](https://python-programming.quantecon.org/parallelization.html), we used Monte Carlo to price a European call option. +And now let's time it: -The code was accelerated by Numba-based multithreading. +```{code-cell} ipython3 +%%time +compute_call_price_parallel() +``` Try writing a version of this operation for JAX, using all the same parameters. @@ -675,14 +872,13 @@ parameters. If you are running your code on a GPU, you should be able to achieve significantly faster execution. - -```{exercise-end} +```{code-cell} ipython3 +for i in range(12): + print("Solution below.") ``` +**Solution** -```{solution-start} jax_intro_ex2 -:class: dropdown -``` Here is one solution: ```{code-cell} ipython3 @@ -727,6 +923,3 @@ And now let's time it: %%time compute_call_price_jax().block_until_ready() ``` - -```{solution-end} -```