
Commit

Fix timing computations on C/GPU for entropy
EtienneCmb committed Oct 31, 2024
1 parent d59fab5 commit fdb1301
Showing 2 changed files with 34 additions and 25 deletions.
Binary file modified docs/_static/jax_cgpu_entropy.png
59 changes: 34 additions & 25 deletions docs/jax.rst
@@ -22,6 +22,7 @@ In the first cell, install hoi and import some modules:
import numpy as np
import jax
import jax.numpy as jnp
+import timeit
from time import time
from hoi.metrics import Oinfo
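The new cell below relies on ``timeit.timeit`` with a statement string and a ``globals=`` namespace, so that the timed expression can see names like ``entropy`` and ``x``. A minimal sketch of that pattern, with a hypothetical ``f`` standing in for the jitted entropy call:

.. code-block:: python

    import timeit

    def f():
        # hypothetical stand-in for the jitted entropy call
        return sum(i * i for i in range(1000))

    n_repeat = 5
    total = timeit.timeit("f()", number=n_repeat, globals=globals())
    print(f"mean over {n_repeat} runs: {total / n_repeat:.6f} s")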
@@ -35,36 +36,44 @@ In a new cell, paste the following code. This code computes the Gaussian Copula entropy

.. code-block:: python

-    def compute_timings(n=15):
-        n_samples = np.linspace(10, 10e2, n).astype(int)
-        n_features = np.linspace(1, 10, n).astype(int)
-        n_variables = np.linspace(1, 10e2, n).astype(int)
+    # number of repetitions
+    n_repeat = 5

-        entropy = jax.vmap(get_entropy(method="gc"), in_axes=(0,))
+    # get the entropy function
+    entropy = jax.jit(jax.vmap(get_entropy(method="gc"), in_axes=(0,)))

-        # dry run
-        entropy(np.random.rand(2, 2, 10))
+    # dry run
+    entropy(np.random.rand(2, 2, 10))

-        timings_cpu = []
-        data_size = []
-        for n_s, n_f, n_v in zip(n_samples, n_features, n_variables):
-            # generate random data
-            x = np.random.rand(n_v, n_f, n_s)
-            x = jnp.asarray(x)
+    # define the number of samples, features and variables
+    n_samples = np.linspace(10, 10e2, 5).astype(int)
+    n_features = np.linspace(1, 10, 5).astype(int)
+    n_variables = np.linspace(1, 10e2, 5).astype(int)

-            # compute entropy
-            start = time()
-            entropy(x)
-            timings_cpu.append(time() - start)
-            data_size.append(n_s * n_f * n_v)
+    data_size, timings_gpu, timings_cpu = [], [], []
+    for n_s, n_f, n_v in zip(n_samples, n_features, n_variables):
+        x = np.random.rand(n_v, n_f, n_s)
+        x = jnp.asarray(x)

-        return data_size, timings_cpu
+        # compute the entropy on cpu
+        with jax.default_device(jax.devices("cpu")[0]):
+            result_cpu = timeit.timeit(
+                'entropy(x).block_until_ready()',
+                number=n_repeat,
+                globals=globals()
+            )
+        timings_cpu.append(result_cpu / n_repeat)

-    with jax.default_device(jax.devices("gpu")[0]):
-        data_size, timings_gpu = compute_timings()
+        # compute the entropy on gpu
+        with jax.default_device(jax.devices("gpu")[0]):
+            result_gpu = timeit.timeit(
+                'entropy(x).block_until_ready()',
+                number=n_repeat,
+                globals=globals()
+            )
+        timings_gpu.append(result_gpu / n_repeat)

-    with jax.default_device(jax.devices("cpu")[0]):
-        data_size, timings_cpu = compute_timings()
+        data_size.append(n_s * n_f * n_v)
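The move from wall-clock ``time()`` to ``timeit`` with ``.block_until_ready()`` is the substance of this fix: JAX dispatches work asynchronously, so without blocking you mostly measure the dispatch, not the computation. A minimal sketch of the difference (the matrix size and the use of matrix multiplication are arbitrary illustration choices, not part of the original code):

.. code-block:: python

    from time import time

    import jax.numpy as jnp

    x = jnp.ones((2000, 2000))

    # warm-up, so compilation is not included in the timings
    (x @ x).block_until_ready()

    # naive timing: returns almost immediately, the work may still be running
    start = time()
    y = x @ x
    t_async = time() - start

    # blocking timing: waits until the result is actually materialized
    start = time()
    y = (x @ x).block_until_ready()
    t_blocked = time() - start

    print(f"async dispatch: {t_async:.6f} s, blocked: {t_blocked:.6f} s")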
Finally, plot the timing comparison:
@@ -81,7 +90,7 @@ Finally, plot the timing comparison:

 .. image:: _static/jax_cgpu_entropy.png

-On CPU, the computing time increase linearly as the array gets larger. However, on GPU, it doesn't scale as fast.
+As the data size increases, computations on CPU (in red) increase linearly while they remain relatively stable on GPU (in blue).
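The plotting cell itself is collapsed in this diff. A minimal sketch of what such a comparison plot could look like, assuming ``data_size``, ``timings_cpu`` and ``timings_gpu`` from the cell above (labels and colors are illustrative, chosen to match the figure description):

.. code-block:: python

    import matplotlib.pyplot as plt

    plt.plot(data_size, timings_cpu, color="red", label="CPU")
    plt.plot(data_size, timings_gpu, color="blue", label="GPU")
    plt.xlabel("Data size (n_variables * n_features * n_samples)")
    plt.ylabel("Mean computing time (s)")
    plt.legend()
    plt.show()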
Computing Higher-Order Interactions on large multiplets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -105,7 +114,7 @@ In the next example, we are going to compute Higher-Order Interactions on a large

            start = time()
            model.fit(minsize=3, maxsize=o)
            timings.append(time() - start)

        return order, timings

    with jax.default_device(jax.devices("gpu")[0]):
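Most of this last cell is collapsed in the diff. A hedged reconstruction of the timing loop around ``model.fit``, assuming simulated data and an ``Oinfo`` model as elsewhere in the hoi documentation (the function name, array shape, and order range are assumptions, not the file's actual content; ``np``, ``time``, ``jax`` and ``Oinfo`` come from the first cell):

.. code-block:: python

    def compute_timings_oinfo():
        # simulated data of shape (n_samples, n_features)
        x = np.random.rand(200, 10)
        model = Oinfo(x)

        order, timings = [], []
        for o in range(3, 11):
            # time the estimation of all multiplets up to order o
            start = time()
            model.fit(minsize=3, maxsize=o)
            timings.append(time() - start)
            order.append(o)

        return order, timings

    with jax.default_device(jax.devices("gpu")[0]):
        order, timings = compute_timings_oinfo()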
