feat: Getting Rid of TensorFlow #69

Closed
Weizhe-Chen opened this issue Jun 2, 2022 · 12 comments
Labels
enhancement New feature or request

Comments

@Weizhe-Chen

Feature Request

Describe the Feature Request

GPJax is a Gaussian process library based on Jax, but it currently still depends on TensorFlow. TensorFlow is annoying because (1) its installation is slow, and (2) it produces many unnecessary warnings when running on CPU-only machines. It would be great if we could implement all of the functionality without TensorFlow. Is that possible?

Describe Preferred Solution

Replace every piece of TensorFlow functionality that GPJax uses with its Jax counterpart.

If the feature request is approved, would you be willing to submit a PR?
Yes

Weizhe-Chen added the enhancement label Jun 2, 2022
@daniel-dodd
Member

Hi, thanks for raising this issue. We are aware of this - it would be nice to drop this dependency! We currently rely on tensorflow.data to provide data-loading capabilities (prefetching, batching, shuffling, ...) on the gpjax.Dataset object for stochastic optimisation. We wouldn't want to fully implement Jax-based versions of these operations in GPJax from scratch (for readability and maintainability). Do you have any suggestions for Jax libraries that provide well-developed data-loading functionality?
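
For readers following along, here is a rough sketch of what the batching and shuffling part could look like without TensorFlow. It is not GPJax's or tf.data's implementation; the function name `batch_indices` and all of its parameters are made up for illustration. It only produces row indices, so the caller gathers the actual rows from whatever array type they use.

```python
import random
from typing import Iterator, List, Optional


def batch_indices(
    n: int,
    batch_size: int,
    shuffle: bool = True,
    seed: Optional[int] = None,
    drop_last: bool = False,
) -> Iterator[List[int]]:
    """Yield lists of row indices covering one pass (epoch) over n points.

    Hypothetical helper for illustration only; not part of GPJax.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    order = list(range(n))
    if shuffle:
        # A seeded generator makes the permutation reproducible.
        random.Random(seed).shuffle(order)
    for start in range(0, n, batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            return  # skip the final, smaller batch
        yield batch


if __name__ == "__main__":
    for idx in batch_indices(10, 4, seed=0):
        print(idx)
```

With NumPy or Jax arrays `X` and `y`, each index list would typically be converted to an integer array before gathering, for example `X[np.asarray(idx)]`. Passing a different seed per epoch gives a fresh permutation each time.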

@Weizhe-Chen
Author

Hi Daniel,
Thanks for the prompt reply! I agree that we should avoid reinventing wheels while minimizing unnecessary dependencies. I found that the Jax-based Deep Learning library elegy provides a Jax implementation of Dataset and Dataloader which supports prefetching, batching, and shuffling.
Here is an example usage: https://github.com/poets-ai/elegy/blob/master/examples/elegy/mnist_dataloader.py
Implementation is here: https://github.com/poets-ai/elegy/blob/4709ce8dc9dde3925ce717e2358ce49112e36398/elegy/data/dataset.py#L34
I cannot say their implementation is well-developed since elegy is still in its infancy just like GPJax, but the Dataset and DataLoader classes can be easily understood and maintained. I can help in replacing tf.data with this implementation if you think it is a good idea.
Best,
Weizhe (Wesley)
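
To make the prefetching idea concrete, below is a minimal sketch using only the Python standard library. This is not Elegy's code (see the links above for that); the name `prefetch` and the `buffer_size` parameter are assumptions chosen for this example. A worker thread fills a bounded queue while the consumer is busy with the previous batch.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking the end of the stream


def prefetch(batches: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Yield items from `batches`, produced ahead of time by a worker thread.

    Illustrative sketch only. Limitations: an exception in the producer
    ends the stream silently, and if the consumer stops early the daemon
    worker stays blocked until the process exits.
    """
    buf: "queue.Queue[object]" = queue.Queue(maxsize=buffer_size)

    def worker() -> None:
        try:
            for item in batches:
                buf.put(item)  # blocks while the buffer is full
        finally:
            buf.put(_DONE)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = buf.get()
        if item is _DONE:
            return
        yield item


if __name__ == "__main__":
    for batch in prefetch(iter(range(5))):
        print(batch)
```

Any batch iterator can be wrapped this way. How much it helps depends on whether producing a batch releases the GIL (I/O, NumPy copies) or is pure-Python work.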

@thomaspinder
Collaborator

Hi Weizhe,

Thanks for your suggestion. As Dan has mentioned, removing the TensorFlow dependency is something we'd like to do. Your elegy suggestion seems sensible and I'd be very happy to review a PR on this. If you'd like to move forward with this, then there are two things I'd specifically like to confirm:

  1. Installing elegy is no larger or more convoluted than installing TensorFlow.
  2. Replacing the TensorFlow dataset object with an elegy implementation does not impair speed.

Please let me know what, if any, support you will need to get started with this; I'm very happy to assist.

@Weizhe-Chen
Author

Hi Thomas,
Good questions. I had the same concerns before suggesting this.

  1. Since we only need the Dataset and DataLoader classes, it would be better to put these two classes directly into GPJax rather than installing elegy. That way we can remove the TF dependency without introducing another heavy one.
  2. It is straightforward to complete the above step. I can compare the speed of tf.data and elegy.data today and let you know.
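
A speed comparison like the one in point 2 could be set up roughly as follows. This is only a sketch of a timing harness, not the benchmark that was actually run; `time_epochs` and its arguments are hypothetical names. It measures how long one full pass over a loader takes, with no model work in the loop, and keeps the best of several runs.

```python
import time
from typing import Callable, Iterable


def time_epochs(make_loader: Callable[[], Iterable], epochs: int = 3) -> float:
    """Return the best wall-clock time (seconds) for one full pass over a loader.

    `make_loader` is a zero-argument factory so that every pass starts
    from a fresh iterator. Hypothetical helper for illustration only.
    """
    best = float("inf")
    for _ in range(epochs):
        start = time.perf_counter()
        for _batch in make_loader():
            pass  # consume every batch; only data loading is timed
        best = min(best, time.perf_counter() - start)
    return best


if __name__ == "__main__":
    # Stand-in loader; a real comparison would pass factories that build
    # the tf.data pipeline and the elegy DataLoader on the same data.
    print(f"{time_epochs(lambda: range(100_000)):.4f} s per pass")
```

If the batches are Jax device arrays, calling `.block_until_ready()` on each one inside the loop keeps asynchronous dispatch from hiding part of the cost. Both loaders should use the same batch size and data for the numbers to be comparable.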

@thomaspinder
Collaborator

thomaspinder commented Jun 3, 2022

Responding to 1., this speaks to a broader question I've had as to whether a second library should be formed that provides functionality for just applying GP methods. In that case, GPJax could simply provide the tools that one needs to define a GP, whilst the second library (let's call it GPFit for now) has implementations of dataloaders, fitting routines, and post-hoc model diagnostics.

I'm keen to hear your thoughts on this @Weizhe-Chen and also @daniel-dodd .

@Weizhe-Chen
Author

Great idea! More and more Jax-based machine learning libraries are emerging, but Jax itself does not provide these indispensable pre-processing, inference, and evaluation tools, so a standalone lightweight library dedicated to them will soon be a pressing need. Actually, I think the pre-processing, data-loading, and evaluation tools could form one library that "non-Bayesian" libraries can also depend on, while the Bayesian inference methods could form another library that can also be applied to other Bayesian models.

@daniel-dodd
Member

@thomaspinder @Weizhe-Chen This is an excellent idea! I, too, would be interested in setting up a general toolkit for data preprocessing and model training in Jax. As they stand, the functions in abstractions.py could already be applied to general supervised learning objectives.

@Weizhe-Chen
Author

Maybe we are talking about a Jax counterpart of PyTorch Lightning?

@thomaspinder
Collaborator

Perhaps. Elegy does a pretty good job imo of abstracting away a lot of the gritty details of model fitting. I think of Elegy as quite similar to Keras/Lightning. In the interest of not having n+1 standards, we should probably constrain ourselves to dealing with inference problems that are specific to Gaussian processes or Bayesian methods.

@thomaspinder
Collaborator

thomaspinder commented Jun 3, 2022

I think this issue has drifted a little off topic. We should continue the discussion in #70.

Regarding the original issue of removing TensorFlow: @Weizhe-Chen, would you be willing to open a PR that uses Elegy's dataloader?

@Weizhe-Chen
Author

Sounds like a plan!

@thomaspinder
Collaborator

This has been resolved by #99


3 participants