JAX backend fails for a simple pymc linear regression model #157

Open
trendelkampschroer opened this issue Oct 21, 2024 · 10 comments

trendelkampschroer commented Oct 21, 2024

Minimal example

import time

import arviz as az
import numpy as np
import nutpie
import pandas as pd
import pymc as pm

BETA = [1.0, -1.0, 2.0, -2.0]
SIGMA = 10.0


def generate_data(num_samples: int = 1000) -> pd.DataFrame:
    rng = np.random.default_rng(42)
    dims = len(BETA)
    X = rng.normal(size=(num_samples, dims))
    y = X.dot(BETA) + SIGMA * rng.normal(size=num_samples)
    frame = pd.DataFrame(data=X, columns=[f"x_{i+1}" for i in range(dims)])
    frame["y"] = y
    return frame


def make_model(frame: pd.DataFrame) -> pm.Model:
    predictors = [col for col in frame.columns if col.startswith("x")]
    observation_idx = list(range(len(frame)))
    coords = {"observation_idx": observation_idx, "predictor": predictors}

    with pm.Model(coords=coords) as model:
        # Data
        x = pm.Data("x", frame[predictors], dims=["observation_idx", "predictor"])
        y = pm.Data("y", frame["y"], dims="observation_idx")

        # Population level
        beta = pm.Normal("beta", mu=0.0, sigma=1.0, dims="predictor")
        sigma = pm.HalfNormal("sigma", sigma=10.0)

        # Linear model
        mu = (beta * x).sum(axis=-1)

        # Likelihood
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, shape=mu.shape)

    return model


if __name__ == "__main__":
    frame = generate_data(num_samples=10_000)
    model = make_model(frame)

    kwargs = dict(backend="jax", gradient_backend="jax")
    t0 = time.time()
    trace = nutpie.sample(nutpie.compile_pymc_model(model, **kwargs))
    t = time.time() - t0
    print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
    summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
    print(summary)

Error message

thread 'nutpie-worker-3' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
Traceback (most recent call last):
  File ".../linear_regression.py", line 52, in <module>
thread 'nutpie-worker-0' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
    trace = nutpie.sample(nutpie.compile_pymc_model(model, **kwargs))
            ^^thread 'nutpie-worker-6' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^thread 'nutpie-worker-1' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^^^^^^^^^thread 'nutpie-worker-2' panicked at /Users/runner/miniforge3/conda-bld/nutpie_1722020254444/_build_env/.cargo/registry/src/index.crates.io-6f17d22bba15001f/nuts-rs-0.12.1/src/sampler.rs:576:18:
Could not send sampling results to main thread.: SendError { .. }
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/nutpie/sample.py", line 636, in sample
    result = sampler.wait()
             ^^^^^^^^^^^^^^
  File ".../lib/python3.12/site-packages/nutpie/sample.py", line 388, in wait
    self._sampler.wait(timeout)
RuntimeError: All initialization points failed

Caused by:
    Logp function returned error: Python error: TypeError: _compile_pymc_model_jax.<locals>.make_logp_func.<locals>.logp() got multiple values for argument 'x'

Sampling with backend="numba" and gradient_backend="pytensor" runs successfully.

Version

# packages in environment at .../.miniconda3/envs/pymc:
#
# Name                    Version                   Build  Channel
absl-py                   2.1.0              pyhd8ed1ab_0    conda-forge
accelerate                1.0.0              pyhd8ed1ab_0    conda-forge
arviz                     0.20.0             pyhd8ed1ab_0    conda-forge
atk-1.0                   2.38.0               hd03087b_2    conda-forge
aws-c-auth                0.7.31               hc27b277_0    conda-forge
aws-c-cal                 0.7.4                h41dd001_1    conda-forge
aws-c-common              0.9.28               hd74edd7_0    conda-forge
aws-c-compression         0.2.19               h41dd001_1    conda-forge
aws-c-event-stream        0.4.3                h40a8fc1_2    conda-forge
aws-c-http                0.8.10               hf5a2c8c_0    conda-forge
aws-c-io                  0.14.18             hc3cb426_12    conda-forge
aws-c-mqtt                0.10.7               h3acc7b9_0    conda-forge
aws-c-s3                  0.6.6                hd16c091_0    conda-forge
aws-c-sdkutils            0.1.19               h41dd001_3    conda-forge
aws-checksums             0.1.20               h41dd001_0    conda-forge
aws-crt-cpp               0.28.3               h433f80b_6    conda-forge
aws-sdk-cpp               1.11.407             h0455a66_0    conda-forge
azure-core-cpp            1.13.0               hd01fc5c_0    conda-forge
azure-identity-cpp        1.8.0                h13ea094_2    conda-forge
azure-storage-blobs-cpp   12.12.0              hfde595f_0    conda-forge
azure-storage-common-cpp  12.7.0               hcf3b6fd_1    conda-forge
azure-storage-files-datalake-cpp 12.11.0              h082e32e_1    conda-forge
blackjax                  1.2.4              pyhd8ed1ab_0    conda-forge
blas                      2.124                  openblas    conda-forge
blas-devel                3.9.0           24_osxarm64_openblas    conda-forge
brotli                    1.1.0                hd74edd7_2    conda-forge
brotli-bin                1.1.0                hd74edd7_2    conda-forge
brotli-python             1.1.0           py312hde4cb15_2    conda-forge
bzip2                     1.0.8                h99b78c6_7    conda-forge
c-ares                    1.34.1               hd74edd7_0    conda-forge
c-compiler                1.8.0                h2664225_0    conda-forge
ca-certificates           2024.8.30            hf0a4a13_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
cachetools                5.5.0              pyhd8ed1ab_0    conda-forge
cairo                     1.18.0               hb4a6bf7_3    conda-forge
cctools                   1010.6               hf67d63f_1    conda-forge
cctools_osx-arm64         1010.6               h4208deb_1    conda-forge
certifi                   2024.8.30          pyhd8ed1ab_0    conda-forge
cffi                      1.17.1          py312h0fad829_0    conda-forge
charset-normalizer        3.4.0              pyhd8ed1ab_0    conda-forge
chex                      0.1.87             pyhd8ed1ab_0    conda-forge
clang                     17.0.6          default_h360f5da_7    conda-forge
clang-17                  17.0.6          default_h146c034_7    conda-forge
clang_impl_osx-arm64      17.0.6              he47c785_21    conda-forge
clang_osx-arm64           17.0.6              h54d7cd3_21    conda-forge
clangxx                   17.0.6          default_h360f5da_7    conda-forge
clangxx_impl_osx-arm64    17.0.6              h50f59cd_21    conda-forge
clangxx_osx-arm64         17.0.6              h54d7cd3_21    conda-forge
cloudpickle               3.0.0              pyhd8ed1ab_0    conda-forge
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
compiler-rt               17.0.6               h856b3c1_2    conda-forge
compiler-rt_osx-arm64     17.0.6               h832e737_2    conda-forge
cons                      0.4.6              pyhd8ed1ab_0    conda-forge
contourpy                 1.3.0           py312h6142ec9_2    conda-forge
cpython                   3.12.7          py312hd8ed1ab_0    conda-forge
cxx-compiler              1.8.0                he8d86c4_0    conda-forge
cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
etils                     1.9.4              pyhd8ed1ab_0    conda-forge
etuples                   0.3.9              pyhd8ed1ab_0    conda-forge
expat                     2.6.3                hf9b8971_0    conda-forge
fastprogress              1.0.3              pyhd8ed1ab_0    conda-forge
filelock                  3.16.1             pyhd8ed1ab_0    conda-forge
font-ttf-dejavu-sans-mono 2.37                 hab24e00_0    conda-forge
font-ttf-inconsolata      3.000                h77eed37_0    conda-forge
font-ttf-source-code-pro  2.038                h77eed37_0    conda-forge
font-ttf-ubuntu           0.83                 h77eed37_3    conda-forge
fontconfig                2.14.2               h82840c6_0    conda-forge
fonts-conda-ecosystem     1                             0    conda-forge
fonts-conda-forge         1                             0    conda-forge
fonttools                 4.54.1          py312h024a12e_0    conda-forge
freetype                  2.12.1               hadb7bae_2    conda-forge
fribidi                   1.0.10               h27ca646_0    conda-forge
fsspec                    2024.9.0           pyhff2d567_0    conda-forge
gdk-pixbuf                2.42.12              h7ddc832_0    conda-forge
gflags                    2.2.2             hf9b8971_1005    conda-forge
glog                      0.7.1                heb240a5_0    conda-forge
gmp                       6.3.0                h7bae524_2    conda-forge
gmpy2                     2.1.5           py312h87fada9_2    conda-forge
graphite2                 1.3.13            hebf3989_1003    conda-forge
graphviz                  12.0.0               hbf8cc41_0    conda-forge
gtk2                      2.24.33              h91d5085_5    conda-forge
gts                       0.7.6                he42f4ea_4    conda-forge
h2                        4.1.0              pyhd8ed1ab_0    conda-forge
h5netcdf                  1.4.0              pyhd8ed1ab_0    conda-forge
h5py                      3.11.0          nompi_py312h903599c_102    conda-forge
harfbuzz                  9.0.0                h997cde5_1    conda-forge
hdf5                      1.14.3          nompi_hec07895_105    conda-forge
hpack                     4.0.0              pyh9f0ad1d_0    conda-forge
huggingface_hub           0.25.2             pyh0610db2_0    conda-forge
hyperframe                6.0.1              pyhd8ed1ab_0    conda-forge
icu                       75.1                 hfee45f7_0    conda-forge
idna                      3.10               pyhd8ed1ab_0    conda-forge
importlib-metadata        8.5.0              pyha770c72_0    conda-forge
jax                       0.4.31             pyhd8ed1ab_1    conda-forge
jaxlib                    0.4.31          cpu_py312h47007b3_1    conda-forge
jaxopt                    0.8.3              pyhd8ed1ab_0    conda-forge
jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
joblib                    1.4.2              pyhd8ed1ab_0    conda-forge
kiwisolver                1.4.7           py312h6142ec9_0    conda-forge
krb5                      1.21.3               h237132a_0    conda-forge
lcms2                     2.16                 ha0e7c42_0    conda-forge
ld64                      951.9                h39a299f_1    conda-forge
ld64_osx-arm64            951.9                hc81425b_1    conda-forge
lerc                      4.0.0                h9a09cb3_0    conda-forge
libabseil                 20240116.2      cxx17_h00cdb27_1    conda-forge
libaec                    1.1.3                hebf3989_0    conda-forge
libarrow                  17.0.0          hc6a7651_16_cpu    conda-forge
libblas                   3.9.0           24_osxarm64_openblas    conda-forge
libbrotlicommon           1.1.0                hd74edd7_2    conda-forge
libbrotlidec              1.1.0                hd74edd7_2    conda-forge
libbrotlienc              1.1.0                hd74edd7_2    conda-forge
libcblas                  3.9.0           24_osxarm64_openblas    conda-forge
libclang-cpp17            17.0.6          default_h146c034_7    conda-forge
libcrc32c                 1.1.2                hbdafb3b_0    conda-forge
libcurl                   8.10.1               h13a7ad3_0    conda-forge
libcxx                    19.1.1               ha82da77_0    conda-forge
libcxx-devel              17.0.6               h86353a2_6    conda-forge
libdeflate                1.22                 hd74edd7_0    conda-forge
libedit                   3.1.20191231         hc8eb9b7_2    conda-forge
libev                     4.33                 h93a5062_2    conda-forge
libexpat                  2.6.3                hf9b8971_0    conda-forge
libffi                    3.4.2                h3422bc3_5    conda-forge
libgd                     2.3.3               hac1b3a8_10    conda-forge
libgfortran               5.0.0           13_2_0_hd922786_3    conda-forge
libgfortran5              13.2.0               hf226fd6_3    conda-forge
libglib                   2.82.1               h4821c08_0    conda-forge
libgoogle-cloud           2.29.0               hfa33a2f_0    conda-forge
libgoogle-cloud-storage   2.29.0               h90fd6fa_0    conda-forge
libgrpc                   1.62.2               h9c18a4f_0    conda-forge
libiconv                  1.17                 h0d3ecfb_2    conda-forge
libintl                   0.22.5               h8414b35_3    conda-forge
libjpeg-turbo             3.0.0                hb547adb_1    conda-forge
liblapack                 3.9.0           24_osxarm64_openblas    conda-forge
liblapacke                3.9.0           24_osxarm64_openblas    conda-forge
libllvm14                 14.0.6               hd1a9a77_4    conda-forge
libllvm17                 17.0.6               h5090b49_2    conda-forge
libnghttp2                1.58.0               ha4dd798_1    conda-forge
libopenblas               0.3.27          openmp_h517c56d_1    conda-forge
libpng                    1.6.44               hc14010f_0    conda-forge
libprotobuf               4.25.3               hc39d83c_1    conda-forge
libre2-11                 2023.09.01           h7b2c953_2    conda-forge
librsvg                   2.58.4               h40956f1_0    conda-forge
libsqlite                 3.46.1               hc14010f_0    conda-forge
libssh2                   1.11.0               h7a5bd25_0    conda-forge
libtiff                   4.7.0                hfce79cd_1    conda-forge
libtorch                  2.4.1           cpu_generic_h123b01e_0    conda-forge
libutf8proc               2.8.0                h1a8c8d9_0    conda-forge
libuv                     1.49.0               hd74edd7_0    conda-forge
libwebp-base              1.4.0                h93a5062_0    conda-forge
libxcb                    1.17.0               hdb1d25a_0    conda-forge
libxml2                   2.12.7               h01dff8b_4    conda-forge
libzlib                   1.3.1                h8359307_2    conda-forge
llvm-openmp               19.1.1               h6cdba0f_0    conda-forge
llvm-tools                17.0.6               h5090b49_2    conda-forge
llvmlite                  0.43.0          py312ha9ca408_1    conda-forge
logical-unification       0.4.6              pyhd8ed1ab_0    conda-forge
lz4-c                     1.9.4                hb7217d7_0    conda-forge
macosx_deployment_target_osx-arm64 11.0                 h6553868_1    conda-forge
markdown-it-py            3.0.0              pyhd8ed1ab_0    conda-forge
markupsafe                3.0.1           py312h906988d_1    conda-forge
matplotlib                3.9.2           py312h1f38498_1    conda-forge
matplotlib-base           3.9.2           py312h9bd0bc6_1    conda-forge
mdurl                     0.1.2              pyhd8ed1ab_0    conda-forge
minikanren                1.0.3              pyhd8ed1ab_0    conda-forge
ml_dtypes                 0.5.0           py312hcd31e36_0    conda-forge
mpc                       1.3.1                h8f1351a_1    conda-forge
mpfr                      4.2.1                hb693164_3    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
multipledispatch          0.6.0              pyhd8ed1ab_1    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
ncurses                   6.5                  h7bae524_1    conda-forge
networkx                  3.4                pyhd8ed1ab_0    conda-forge
nomkl                     1.0                  h5ca1d4c_0    conda-forge
numba                     0.60.0          py312h41cea2d_0    conda-forge
numpy                     1.26.4          py312h8442bc7_0    conda-forge
numpyro                   0.15.3             pyhd8ed1ab_0    conda-forge
nutpie                    0.13.2          py312headafe2_0    conda-forge
openblas                  0.3.27          openmp_h560b219_1    conda-forge
openjpeg                  2.5.2                h9f1df11_0    conda-forge
openssl                   3.3.2                h8359307_0    conda-forge
opt-einsum                3.4.0                hd8ed1ab_0    conda-forge
opt_einsum                3.4.0              pyhd8ed1ab_0    conda-forge
optax                     0.2.3              pyhd8ed1ab_0    conda-forge
orc                       2.0.2                h75dedd0_0    conda-forge
packaging                 24.1               pyhd8ed1ab_0    conda-forge
pandas                    2.2.3           py312hcd31e36_1    conda-forge
pango                     1.54.0               h9ee27a3_2    conda-forge
pcre2                     10.44                h297a79d_2    conda-forge
pillow                    10.4.0          py312h8609ca0_1    conda-forge
pip                       24.2               pyh8b19718_1    conda-forge
pixman                    0.43.4               hebf3989_0    conda-forge
psutil                    6.0.0           py312h024a12e_1    conda-forge
pthread-stubs             0.4               hd74edd7_1002    conda-forge
pyarrow-core              17.0.0          py312he20ac61_1_cpu    conda-forge
pycparser                 2.22               pyhd8ed1ab_0    conda-forge
pygments                  2.18.0             pyhd8ed1ab_0    conda-forge
pymc                      5.17.0               hd8ed1ab_0    conda-forge
pymc-base                 5.17.0             pyhd8ed1ab_0    conda-forge
pyparsing                 3.1.4              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
pytensor                  2.25.5          py312h3f593ad_0    conda-forge
pytensor-base             2.25.5          py312h02baea5_0    conda-forge
python                    3.12.7          h739c21a_0_cpython    conda-forge
python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
python-graphviz           0.20.3             pyhe28f650_1    conda-forge
python-tzdata             2024.2             pyhd8ed1ab_0    conda-forge
python_abi                3.12                    5_cp312    conda-forge
pytorch                   2.4.1           cpu_generic_py312h40771f0_0    conda-forge
pytz                      2024.1             pyhd8ed1ab_0    conda-forge
pyyaml                    6.0.2           py312h024a12e_1    conda-forge
qhull                     2020.2               h420ef59_5    conda-forge
re2                       2023.09.01           h4cba328_2    conda-forge
readline                  8.2                  h92ec313_1    conda-forge
requests                  2.32.3             pyhd8ed1ab_0    conda-forge
rich                      13.9.2             pyhd8ed1ab_0    conda-forge
safetensors               0.4.5           py312he431725_0    conda-forge
scikit-learn              1.5.2           py312h387f99c_1    conda-forge
scipy                     1.14.1          py312heb3a901_0    conda-forge
setuptools                75.1.0             pyhd8ed1ab_0    conda-forge
sigtool                   0.1.3                h44b9a77_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.7                  h7783ee8_0    conda-forge
snappy                    1.2.1                hd02b534_0    conda-forge
sympy                     1.13.3           pyh2585a3b_104    conda-forge
tabulate                  0.9.0              pyhd8ed1ab_1    conda-forge
tapi                      1300.6.5             h03f4b80_0    conda-forge
threadpoolctl             3.5.0              pyhc1e730c_0    conda-forge
tk                        8.6.13               h5083fa2_1    conda-forge
toolz                     1.0.0              pyhd8ed1ab_0    conda-forge
tornado                   6.4.1           py312h024a12e_1    conda-forge
tqdm                      4.66.5             pyhd8ed1ab_0    conda-forge
typing-extensions         4.12.2               hd8ed1ab_0    conda-forge
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2024b                hc8b5060_0    conda-forge
urllib3                   2.2.3              pyhd8ed1ab_0    conda-forge
wheel                     0.44.0             pyhd8ed1ab_0    conda-forge
xarray                    2024.9.0           pyhd8ed1ab_1    conda-forge
xarray-einstats           0.8.0              pyhd8ed1ab_0    conda-forge
xorg-libxau               1.0.11               hd74edd7_1    conda-forge
xorg-libxdmcp             1.1.5                hd74edd7_0    conda-forge
xz                        5.2.6                h57fd34a_0    conda-forge
yaml                      0.2.5                h3422bc3_2    conda-forge
zipp                      3.20.2             pyhd8ed1ab_0    conda-forge
zlib                      1.3.1                h8359307_2    conda-forge
zstandard                 0.23.0          py312h15fbf35_1    conda-forge
zstd                      1.5.6                hb46c0d2_0    conda-forge
@aseyboldt (Member)

Thank you for the bug report. I had seen something possibly related recently, but didn't manage to reproduce it in a smaller model. This example should make it much easier to find the problem.

Right now you can work around the issue by freezing the pymc model:

from pymc.model.transform.optimization import freeze_dims_and_data
trace = nutpie.sample(nutpie.compile_pymc_model(freeze_dims_and_data(model), **kwargs))

@aseyboldt (Member)

Well, that was easier than I thought, and it won't be hard to fix. The problem is that the argument name for the point in parameter space is 'x', so if any shared variable (like the data) is also called 'x', this gives an argument name collision.
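
A minimal sketch of the collision (the function and variable names here are illustrative, not nutpie's actual internals): the generated logp wrapper takes the point in parameter space as a positional argument named x, and the model's shared variables are passed as keyword arguments under their model names.

# Hypothetical sketch of the name clash, not nutpie's real code.
def logp(x, **shared_vars):
    # x is the point in parameter space, shared_vars holds the model's data
    ...

position = [0.0, 0.0]
shared = {"x": [[1.0, 2.0]], "y": [3.0]}  # a shared variable also named "x"

logp(position, **shared)
# TypeError: logp() got multiple values for argument 'x'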


trendelkampschroer commented Oct 21, 2024

@aseyboldt : Thanks a lot for the very helpful reply. I renamed 'x' -> 'X' and now things are working.

This is really good to know, but it is probably something that should be fixed, either by ensuring that a unique name is used for the point in parameter space, or by forbidding 'x' as a name in the model (which would be a bit cumbersome, since predictors are often denoted by 'x').

On the upside using JAX gives a very nice speedup on my machine (Apple M1) :-).

@aseyboldt (Member)

Yes, definitely needs a fix, I'll push one soon.

Out of curiosity (I don't have an Apple machine), could you do me a small favor and run this with jax and numba, and tell me what the compile time and the runtime are in each case?

jax

frame = generate_data(num_samples=10_000)
model = make_model(frame)

kwargs = dict(backend="jax", gradient_backend="jax")

t0 = time.time()
compiled = nutpie.compile_pymc_model(model, **kwargs)
print(f"compile time: {time.time() - t0}")

t0 = time.time()
trace = nutpie.sample(compiled)
t = time.time() - t0
print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
print(summary)

numba

frame = generate_data(num_samples=10_000)
model = make_model(frame)

kwargs = dict(backend="numba", gradient_backend="jax")

t0 = time.time()
compiled = nutpie.compile_pymc_model(model, **kwargs)
print(f"compile time: {time.time() - t0}")

t0 = time.time()
trace = nutpie.sample(compiled)
t = time.time() - t0
print(f"Time for nutpie (compiled, {kwargs=}) sampling is {t=:.3f}s.")
summary = az.summary(trace, var_names=["beta", "sigma"], round_to=4).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
print(summary)


trendelkampschroer commented Oct 21, 2024

jax

compile time for kwargs={'backend': 'jax', 'gradient_backend': 'jax'}: 1.016s
Time for nutpie (compiled, kwargs={'backend': 'jax', 'gradient_backend': 'jax'}) sampling is t=2.094s.

numba

compile time for kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}: 3.835s
Time for nutpie (compiled, kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}) sampling is t=0.564s.


trendelkampschroer commented Oct 21, 2024

I hope that helps. I have another follow-up question: while I observe a great speed-up with the JAX backend on my Apple M1 machine, I see significantly slower sampling with the JAX backend compared to Numba/PyTensor when running on a Google Cloud VM with many more cores (32) and more memory. This is for a hierarchical linear regression with thousands of groups and a couple of predictors.

On the VM, sampling with the "jax" backend is about 30% slower than with the "numba" backend. Specifically, with the "numba" backend I see a few (4-8) thread/CPU bars in htop at 100%, while with the JAX backend all 32 bars show some occupancy at less than 50%.

If you have any ideas/insights what could cause this and also how to ensure best performance, then I'd be glad for any suggestions.

@aseyboldt (Member)

Thanks for the numbers :-)

First, I think it is important to distinguish compile time from sampling time. The numbers you just gave show that the numba backend samples faster on the Mac as well; only the compile time is much larger. If the model gets bigger, the compile time will play less of a role, because it doesn't depend much on the data size.

I think what you observe with the jax backend is an issue with how it currently works:
The chains run in different threads that are controlled by the rust code. With the numba backend, the python interpreter is only used to generate the logp function; all sampling happens without any involvement of python. Doing this with jax is currently much harder (I hope this will change in the not too distant future though).
jax compiles the logp function, but I can't easily access this compiled function from rust. So instead I have to call a python function that then calls the compiled jax function. While a bit silly, that wouldn't be too bad if python didn't have the GIL. But the GIL ensures that only one thread (i.e. one chain) can use the python interpreter at a time. So each logp function evaluation does something like the following:

  • Acquire the GIL (i.e. wait until no other thread is holding it)
  • start the jax logp function
  • release the GIL
  • do the actual logp function evaluation (while not holding the GIL)
  • acquire the GIL
  • get access to the results
  • release the GIL

If the computation of the logp function takes a long time and there aren't that many threads, then most of the time only one thread (or even none) will hold the GIL, because each thread spends most of its time in the "do the actual logp function evaluation" phase, and all is good.
But if the logp function evaluation is relatively quick, then more than one thread will try to acquire the GIL at the same time, which means the threads sit around waiting, i.e. "low occupancy".

There are two things that might make this situation better in the future:

  • python 3.13 has a new build without the GIL. That isn't supported by jax and many other libraries at the moment, but I hope this will change quickly.
  • Jax might make it easier to call the compiled functions from other languages, so that I don't have to go through python to call the compiled function, which would avoid the whole problem to begin with.

In the meantime: if the cores of your machine aren't used well, you can at least try to limit the number of threads that run at the same time by setting the cores argument of sample to something smaller. This can reduce the lock contention and give you a modest speedup. It won't really fix the problem though...
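
For example, a small sketch based on the suggestion above (reusing the compiled model from the benchmark snippet; the value 4 is just an example to experiment with):

compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
# Run the chains, but only 4 at a time, to reduce contention on the GIL
trace = nutpie.sample(compiled, cores=4)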

If you are willing to go to some extra lengths: you can start multiple separate processes that each sample your model (with different seeds!) and then combine the traces. This is much more annoying, but should completely avoid the lock contention. In that case you can run into other issues, however, for instance if each process tries to use all available cores on the machine. Fixing that would then require using threadpoolctl and/or taskset or some jax environment variable flags.

I hope that helps to clear it up a bit :-)


trendelkampschroer commented Oct 21, 2024

Thanks @aseyboldt, this is really helpful. Do you know of a minimal example for the "start multiple separate processes" approach? I have seen https://discourse.pymc.io/t/harnessing-multiple-cores-to-speed-up-fits-with-small-number-of-chains/7669 where the idea is to concatenate multiple smaller chains to more efficiently harness the CPUs on a machine.
I'd e.g. try to use joblib for that, but I am not sure how much that interferes with the PyMC and nutpie internals. If you have any pointers I'd be very glad to look into it.

Btw: with num_samples = 100_000 the numbers look like this on Apple M1

jax

compile time for kwargs={'backend': 'jax', 'gradient_backend': 'jax'}: 0.864s
Time for nutpie (compiled, kwargs={'backend': 'jax', 'gradient_backend': 'jax'}) sampling is t=5.842s.

numba

compile time for kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}: 3.874s
Time for nutpie (compiled, kwargs={'backend': 'numba', 'gradient_backend': 'pytensor'}) sampling is t=19.923s.

So JAX is a lot faster for sampling - which also matches my observation for a hierarchical linear model.

@aseyboldt (Member)

For sampling in separate processes:

# At the very start...
import os
os.environ["JOBLIB_START_METHOD"] = "forkserver"

import arviz
import numpy as np
import nutpie
from joblib import parallel_config, Parallel, delayed

# make_model and frame are the ones defined in the example script above.

def run_chain(data, idx, seed):
    model = make_model(data)
    # Derive an independent per-chain seed from the shared base seed
    seeds = np.random.SeedSequence(seed)
    seed = np.random.default_rng(seeds.spawn(idx + 1)[-1]).integers(2 ** 63)
    compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax")
    trace = nutpie.sample(compiled, seed=seed, chains=1, progress_bar=False)
    return trace.assign_coords(chain=[idx])

with parallel_config(n_jobs=10, prefer="processes"):
    traces = Parallel()(delayed(run_chain)(frame, i, 123) for i in range(10))

trace = arviz.concat(traces, dim="chain")

This comes with quite a bit of overhead (mostly constant though), so probably not worth it for smaller models.

Funnily enough, I see big differences between

# Option 1
mu = (beta * x).sum(axis=-1)

# Option 2
mu = x @ beta

And jax and numba react quite differently. Maybe an issue with the blas config? What blas implementation are you using?

(On conda-forge you can choose it as explained here: https://conda-forge.org/docs/maintainer/knowledge_base/#switching-blas-implementation. I think on M1, Accelerate is usually the fastest.)
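
As a quick cross-check of which BLAS the Python stack is actually linked against (a small sketch using NumPy's own introspection; the conda listing below works just as well):

import numpy as np

# Prints the BLAS/LAPACK libraries NumPy was built against
np.show_config()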

@trendelkampschroer (Author)

Thanks a lot for the once again very helpful suggestions. I will benchmark the two versions of the dot product to see whether I observe a performance difference.

Regarding BLAS

On Apple-M1 I have

blas                      2.124                  openblas    conda-forge
blas-devel                3.9.0           24_osxarm64_openblas    conda-forge
libblas                   3.9.0           24_osxarm64_openblas    conda-forge
libcblas                  3.9.0           24_osxarm64_openblas    conda-forge
liblapack                 3.9.0           24_osxarm64_openblas    conda-forge
liblapacke                3.9.0           24_osxarm64_openblas    conda-forge
libopenblas               0.3.27          openmp_h517c56d_1    conda-forge
openblas                  0.3.27          openmp_h560b219_1    conda-forge

and on the VM

blas                      2.120                       mkl    conda-forge
blas-devel                3.9.0            20_linux64_mkl    conda-forge
libblas                   3.9.0            20_linux64_mkl    conda-forge
libcblas                  3.9.0            20_linux64_mkl    conda-forge

I can try the Accelerate BLAS, but right now I am more interested in speeding things up on the VM.
