
Cannot install a version of tensorflow that allows me to use GPU #611

Open
isaaceckert opened this issue Nov 25, 2024 · 4 comments

Comments

@isaaceckert

isaaceckert commented Nov 25, 2024

I'm working on an Apple M1 Pro Mac in RStudio. Everything installs fine, but TensorFlow can't find any local GPUs, so models take a long time to fit.

For example,

library(tensorflow)
tf$config$list_physical_devices("GPU")

gives me:

list()

Session info and other details below:

> reticulate::py_config()
python:         /Users/isaaceckert/.virtualenvs/r-tensorflow/bin/python
libpython:      /Users/isaaceckert/.pyenv/versions/3.10.15/lib/libpython3.10.dylib
pythonhome:     /Users/isaaceckert/.virtualenvs/r-tensorflow:/Users/isaaceckert/.virtualenvs/r-tensorflow
version:        3.10.15 (main, Nov 20 2024, 16:40:35) [Clang 16.0.0 (clang-1600.0.26.4)]
numpy:          /Users/isaaceckert/.virtualenvs/r-tensorflow/lib/python3.10/site-packages/numpy
numpy_version:  1.26.4
tensorflow:     /Users/isaaceckert/.virtualenvs/r-tensorflow/lib/python3.10/site-packages/tensorflow

NOTE: Python version was forced by VIRTUAL_ENV
> tensorflow::tf_config()
TensorFlow v2.15.1 (~/.virtualenvs/r-tensorflow/lib/python3.10/site-packages/tensorflow)
Python vError in cat("Python v", x$python_version, " (", aliased(x$python), ")\n",  : 
  argument 2 (type 'list') cannot be handled by 'cat'
> reticulate::import("tensorflow")
Module(tensorflow)
> reticulate::py_last_error()

── Python Exception Message ─────────────────────────────────────────────────────────────────────────────────────
AttributeError: partially initialized module 'tensorflow' has no attribute 'constant' (most likely due to a circular import)

── R Traceback ──────────────────────────────────────────────────────────────────────────────────────────────────
     ▆
  1. ├─tf$constant
  2. ├─reticulate:::`$.python.builtin.module`(tf, constant)
  3. │ └─reticulate::py_get_attr(x, name, TRUE)
  4. ├─reticulate (local) `<fn>`()
  5. │ └─reticulate::configure_environment()
  6. │   └─reticulate:::python_package_requirements(package)
  7. │     └─base::lapply(...)
  8. │       └─reticulate (local) FUN(X[[i]], ...)
  9. │         ├─base::tryCatch(...)
 10. │         │ └─base (local) tryCatchList(expr, classes, parentenv, handlers)
 11. │         │   └─base (local) tryCatchOne(expr, names, parentenv, handlers[[1L]])
 12. │         │     └─base (local) doTryCatch(return(expr), name, parentenv, handler)
 13. │         └─reticulate:::python_package_requirements_find(package)
 14. │           └─base::system.file("DESCRIPTION", package = package)
 15. │             └─base::find.package(package, lib.loc, quiet = TRUE)
 16. │               └─base::apply(!is.na(db), 1L, all)
 17. ├─base::suppressWarnings(base::try(tf$constant, silent = TRUE))
 18. │ └─base::withCallingHandlers(...)
 19. ├─base::try(tf$constant, silent = TRUE)
 20. │ └─base::tryCatch(...)
 21. │   └─base (local) tryCatchList(expr, classes, parentenv, handlers)
 22. │     └─base (local) tryCatchOne(expr, names, parentenv, handlers[[1L]])
 23. │       └─base (local) doTryCatch(return(expr), name, parentenv, handler)
 24. ├─tf$constant
 25. ├─reticulate:::`$.python.builtin.module`(tf, constant)
 26. │ └─reticulate::py_get_attr(x, name, TRUE)
 27. ├─reticulate (local) `<fn>`(`<python.builtin.module>`)
 28. │ ├─base::tryCatch(import(module), error = clear_error_handler())
 29. │ │ └─base (local) tryCatchList(expr, classes, parentenv, handlers)
 30. │ │   └─base (local) tryCatchOne(expr, names, parentenv, handlers[[1L]])
 31. │ │     └─base (local) doTryCatch(return(expr), name, parentenv, handler)
 32. │ └─reticulate::import(module)
 33. │   └─reticulate:::py_module_import(module, convert = convert)
 34. ├─base::suppressWarnings(base::try(tf$constant, silent = TRUE))
 35. │ └─base::withCallingHandlers(...)
 36. ├─base::try(tf$constant, silent = TRUE)
 37. │ └─base::tryCatch(...)
 38. │   └─base (local) tryCatchList(expr, classes, parentenv, handlers)
 39. │     └─base (local) tryCatchOne(expr, names, parentenv, handlers[[1L]])
 40. │       └─base (local) doTryCatch(return(expr), name, parentenv, handler)
 41. ├─tf$constant
 42. └─reticulate:::`$.python.builtin.module`(tf, constant)
 43.   └─reticulate::py_get_attr(x, name, FALSE)
See `reticulate::py_last_error()$r_trace$full_call` for more details.
> sessionInfo()
R version 4.3.0 (2023-04-21)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS 14.6.1

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Toronto
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] tensorflow_2.16.0.9000 keras_2.15.0          

loaded via a namespace (and not attached):
 [1] utf8_1.2.4        R6_2.5.1          base64enc_0.1-3   Matrix_1.6-1      lattice_0.21-8    reticulate_1.40.0
 [7] magrittr_2.0.3    rappdirs_0.3.3    glue_1.8.0        png_0.1-8         generics_0.1.3    lifecycle_1.0.4  
[13] cli_3.6.3         fansi_1.0.6       vctrs_0.6.5       grid_4.3.0        withr_3.0.2       zeallot_0.1.0    
[19] tfruns_1.5.3      compiler_4.3.0    rstudioapi_0.17.1 tools_4.3.0       whisker_0.4.1     pillar_1.9.0     
[25] Rcpp_1.0.13-1     rlang_1.1.4       jsonlite_1.8.9   

Any help is very much appreciated! Thanks!

@t-kalinowski
Member

If you’re writing new code today, I recommend using keras3 with the JAX backend. The keras3::op_* family covers most of the TensorFlow API but provides better R support (e.g., integrated R help, consistent 1-based indexing, and better argument handling). For any gaps, you can easily use raw JAX via reticulate::import("jax").

That said, the packages enabling GPU usage with TensorFlow on M-series Macs still work, though they currently lag a few versions behind the latest release. You can install them as follows:

system("brew install openssl readline sqlite3 xz zlib tcl-tk")
python <- reticulate::install_python("3.11") # will take a few minutes; have patience
reticulate::virtualenv_create(
  envname = "r-tensorflow", 
  python = python, 
  packages = c("tensorflow-macos", "tensorflow-metal"), 
  force = TRUE
)
# rstudioapi::restartSession()

After this, restart the R session (rstudioapi::restartSession() or Ctrl+Shift+F10), and tf$config$list_physical_devices() should show both a CPU and a GPU device.
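A listed GPU does not by itself show that computations run there. One way to check is to run a small operation and inspect the `device` attribute of the resulting tensor. The following is a minimal sketch. It assumes the `r-tensorflow` virtualenv created above is active after the restart. The helper name `tf_matmul_device()` is hypothetical and is not part of the tensorflow R package.

```r
# Hypothetical helper (not part of the tensorflow package):
# runs a small matrix multiply and returns the device string of the result.
tf_matmul_device <- function(n = 1000L) {
  tf <- tensorflow::tf                   # the Python TensorFlow module, via reticulate
  a <- tf$random$uniform(c(n, n))        # random n x n float32 tensor
  b <- tf$matmul(a, a)                   # placed on the GPU by default if one is visible
  b$device                               # e.g. ".../device:GPU:0" or ".../device:CPU:0"
}
```

If the Metal plugin is handling the computation, `tf_matmul_device()` should return a device string ending in `GPU:0`. Otherwise it should end in `CPU:0`.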

@isaaceckert
Author

Thanks so much! Just so I understand correctly: if I use keras3 to fit my model, will it use GPU acceleration automatically? Is there any way I can verify that this is actually happening?

@t-kalinowski
Member

Currently, the default installer, install_keras(), does not enable GPU usage on macOS for the JAX backend, because the JAX homepage displays this prominent warning:

Apple provides an experimental Metal plugin. For details, refer to Apple’s JAX on Metal documentation.

Note: There are several caveats with the Metal plugin:

  • The Metal plugin is new and experimental and has a number of known issues. Please report any issues on the JAX issue tracker.
  • The Metal plugin currently requires very specific versions of jax and jaxlib. This restriction will be relaxed over time as the plugin API matures.

However, on closer inspection, that warning may be stale. A new jax-metal release came out recently, and in a quick test it appears to be compatible with the latest jax release. Testing locally on an M1 Mac, this worked fine:

remotes::install_github("rstudio/keras3")
reticulate::virtualenv_create("r-keras", "3.11", force = TRUE, packages = c(
  "jax", "jax-metal", "tensorflow", "pydot", "keras>=3", "numpy"
))
# restart R session
library(keras3)
use_backend("jax")

op_arange(10) # force jax to initialize

jax <- reticulate::import("jax")
jax$devices() # METAL(id=0)
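Listing devices shows what JAX can see. To check where a computation actually ran, you can run a small computation and ask the result for its devices. The following is a minimal sketch. It assumes the `r-keras` environment above with `jax-metal` installed, and a jax version recent enough that arrays have a `devices()` method. `jax_device_check()` is a hypothetical helper name.

```r
# Hypothetical helper: runs a small matmul in JAX and reports where it ran.
jax_device_check <- function(n = 1000L) {
  # convert = FALSE keeps results as Python objects (no automatic conversion to R)
  jax <- reticulate::import("jax", convert = FALSE)
  x <- jax$numpy$ones(c(n, n))           # n x n array on JAX's default device
  y <- jax$numpy$matmul(x, x)
  y$block_until_ready()                  # JAX dispatches asynchronously; wait for completion
  list(
    default_backend = reticulate::py_to_r(jax$default_backend()),  # platform JAX picked
    result_devices  = reticulate::py_str(y$devices())              # device(s) holding the result
  )
}
```

If `result_devices` names a METAL device, the computation ran on the GPU. keras3 models fit with the JAX backend should use that same default device.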

@isaaceckert
Copy link
Author

Thanks so much!! I'll give it a shot!
