diff --git a/docs/_static/jax_cgpu_entropy.png b/docs/_static/jax_cgpu_entropy.png
index 5c808114..1e5ea9fd 100644
Binary files a/docs/_static/jax_cgpu_entropy.png and b/docs/_static/jax_cgpu_entropy.png differ
diff --git a/docs/jax.rst b/docs/jax.rst
index db79b081..9a7b42ae 100644
--- a/docs/jax.rst
+++ b/docs/jax.rst
@@ -22,6 +22,7 @@ In the first cell, install hoi and import some modules:
     import numpy as np
     import jax
     import jax.numpy as jnp
+    import timeit
     from time import time
 
     from hoi.metrics import Oinfo
@@ -35,36 +36,44 @@ In a new cell, past the following code. This code compute the Gaussian Copula en
 
 .. code-block:: shell
 
-    def compute_timings(n=15):
-        n_samples = np.linspace(10, 10e2, n).astype(int)
-        n_features = np.linspace(1, 10, n).astype(int)
-        n_variables = np.linspace(1, 10e2, n).astype(int)
+    # number of repetitions
+    n_repeat = 5
 
-        entropy = jax.vmap(get_entropy(method="gc"), in_axes=(0,))
+    # get the entropy function
+    entropy = jax.jit(jax.vmap(get_entropy(method="gc"), in_axes=(0,)))
 
-        # dry run
-        entropy(np.random.rand(2, 2, 10))
+    # dry run
+    entropy(np.random.rand(2, 2, 10))
 
-        timings_cpu = []
-        data_size = []
-        for n_s, n_f, n_v in zip(n_samples, n_features, n_variables):
-            # generate random data
-            x = np.random.rand(n_v, n_f, n_s)
-            x = jnp.asarray(x)
+    # define the number of samples, features and variables
+    n_samples = np.linspace(10, 10e2, 5).astype(int)
+    n_features = np.linspace(1, 10, 5).astype(int)
+    n_variables = np.linspace(1, 10e2, 5).astype(int)
 
-            # compute entropy
-            start = time()
-            entropy(x)
-            timings_cpu.append(time() - start)
-            data_size.append(n_s * n_f * n_v)
+    data_size, timings_gpu, timings_cpu = [], [], []
+    for n_s, n_f, n_v in zip(n_samples, n_features, n_variables):
+        x = np.random.rand(n_v, n_f, n_s)
+        x = jnp.asarray(x)
 
-        return data_size, timings_cpu
+        # compute the entropy on cpu
+        with jax.default_device(jax.devices("cpu")[0]):
+            result_cpu = timeit.timeit(
+                'entropy(x).block_until_ready()',
+                number=n_repeat,
+                globals=globals()
+            )
+        timings_cpu.append(result_cpu / n_repeat)
 
-    with jax.default_device(jax.devices("gpu")[0]):
-        data_size, timings_gpu = compute_timings()
+        # compute the entropy on gpu
+        with jax.default_device(jax.devices("gpu")[0]):
+            result_gpu = timeit.timeit(
+                'entropy(x).block_until_ready()',
+                number=n_repeat,
+                globals=globals()
+            )
+        timings_gpu.append(result_gpu / n_repeat)
 
-    with jax.default_device(jax.devices("cpu")[0]):
-        data_size, timings_cpu = compute_timings()
+        data_size.append(n_s * n_f * n_v)
 
 Finally, plot the timing comparison :
 
@@ -81,7 +90,7 @@ Finally, plot the timing comparison :
 
 .. image:: _static/jax_cgpu_entropy.png
 
-On CPU, the computing time increase linearly as the array gets larger. However, on GPU, it doesn't scale as fast.
+As the data size increases, the computation time on CPU (in red) increases linearly while it remains relatively stable on GPU (in blue).
 
 Computing Higher-Order Interactions on large multiplets
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -105,7 +114,7 @@ In the next example, we are going to compute Higher-Order Interactions on a larg
             start = time()
             model.fit(minsize=3, maxsize=o)
             timings.append(time() - start)
-
+        return order, timings
 
     with jax.default_device(jax.devices("gpu")[0]):
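
A note on the timing pattern introduced by this patch: JAX dispatches work asynchronously, so without ``.block_until_ready()`` the ``timeit`` call would only measure how long it takes to enqueue the computation, not to run it. Below is a minimal, self-contained sketch of the same pattern; the toy function ``f`` is an illustrative stand-in for the jitted entropy estimator and is not part of hoi:

.. code-block:: python

    import timeit

    import jax
    import jax.numpy as jnp
    import numpy as np

    # toy jitted function standing in for the entropy estimator
    f = jax.jit(lambda x: jnp.tanh(x @ x.T).sum())

    x = jnp.asarray(np.random.rand(100, 1000))

    # dry run: trigger compilation outside of the timed region
    f(x).block_until_ready()

    n_repeat = 5
    t = timeit.timeit(
        "f(x).block_until_ready()",  # block so the computation itself is timed
        number=n_repeat,
        globals=globals(),
    )
    print(t / n_repeat)  # mean time per call, in seconds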
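The ``@@ -81,7 +90,7 @@`` hunk only swaps the rendered image and its caption; the plotting cell itself lies outside this diff. For context, a plausible sketch of such a cell, assuming the ``data_size``, ``timings_cpu`` and ``timings_gpu`` lists from the benchmark above and matplotlib, with the red/blue convention of the updated caption:

.. code-block:: python

    import matplotlib.pyplot as plt

    # timings collected by the CPU/GPU benchmark above
    plt.plot(data_size, timings_cpu, color="red", label="CPU")
    plt.plot(data_size, timings_gpu, color="blue", label="GPU")
    plt.xlabel("Data size (n_samples * n_features * n_variables)")
    plt.ylabel("Mean time per call (s)")
    plt.legend()
    plt.show()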
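The last hunk shows only fragments of the multiplets benchmark: the ``model.fit(minsize=3, maxsize=o)`` call and the added ``return order, timings``. A hedged reconstruction of what such a ``compute_timings`` function could look like is given below; the data shape, the range of orders, and the ``order`` list are assumptions for illustration, not the file's actual content:

.. code-block:: python

    import numpy as np
    from time import time

    from hoi.metrics import Oinfo

    def compute_timings():
        # illustrative data of shape (n_samples, n_features); see hoi's
        # documentation for the shapes the metrics actually accept
        x = np.random.rand(200, 10)
        model = Oinfo(x)

        order, timings = [], []
        for o in range(3, 8):  # maximum multiplet sizes (illustrative)
            start = time()
            model.fit(minsize=3, maxsize=o)
            timings.append(time() - start)
            order.append(o)
        return order, timings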