Workspaces for different backends of Keras #1242
base: develop
Conversation
Signed-off-by: yes <[email protected]>
Thanks for getting this started @tanwarsh!
My biggest comment here (and I'd like to hear your thoughts on this) is that the JAX/torch workspaces don't seem to showcase JAX or torch. OpenFL offers the ability to define custom training and evaluation loops via the train_ and validate_ methods of the runner. Keras is able to support custom training loops written in JAX and torch that can preserve a lot of each framework's native API in a similar manner. See JAX / Torch.
In my personal opinion, I think it would be more compelling to showcase OpenFL's framework flexibility in this manner. WDYT?
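As a concrete illustration of the point above, here is a minimal sketch of a Keras 3 custom training step written against the torch backend; the model class and its body are illustrative, not code from this PR:

```python
import os

# The backend must be chosen before keras is imported.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras


class TorchCustomModel(keras.Model):
    """Keras model whose train_step uses native torch autograd (illustrative)."""

    def train_step(self, data):
        x, y = data
        # Clear gradients left over from the previous step.
        self.zero_grad()
        # Forward pass and loss computation, recorded by torch autograd.
        y_pred = self(x, training=True)
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Backward pass on the torch loss tensor.
        loss.backward()
        trainable_weights = list(self.trainable_weights)
        gradients = [w.value.grad for w in trainable_weights]
        # Apply gradients without tracking the update itself.
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)
        # Update metrics, including the built-in loss tracker.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
```

A workspace could wrap a model like this so that the runner's train_ task exercises torch-native code while still going through the Keras task runner.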
@@ -0,0 +1,2 @@
keras==3.6.0
jax==0.4.38
It doesn't seem like JAX is ever actually explicitly used throughout the experiment
@@ -0,0 +1,2 @@
keras==3.6.0
torch==2.5.1
Similar comment here - it doesn't seem like torch is ever explicitly used in the experiment
@@ -17,6 +19,14 @@
from openfl.utilities import Metric, TensorKey, change_tags
from openfl.utilities.split import split_tensor_dict_for_holdouts

# Set the KERAS_BACKEND environment variable based on the available deep learning framework
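The added comment refers to backend auto-detection. A hedged sketch of that kind of logic (the PR's actual implementation may differ) could look like:

```python
import importlib.util
import os


def _detect_keras_backend() -> str:
    # Probe for installed frameworks in a fixed priority order.
    for module, backend in (("tensorflow", "tensorflow"),
                            ("torch", "torch"),
                            ("jax", "jax")):
        if importlib.util.find_spec(module) is not None:
            return backend
    raise ImportError("No supported Keras backend (tensorflow/torch/jax) found")


# setdefault preserves a backend the user already chose explicitly;
# this must run before `import keras`.
os.environ.setdefault("KERAS_BACKEND", _detect_keras_backend())
```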
I wonder if it would be better to add a "backend" argument to the Keras runner (which could maybe be configured in the plan.yaml). The rationale being that a user may have an env that contains multiple frameworks. Idk if that's a super common scenario, but the runner would end up setting the backend to the first framework it encounters rather than the intended one, resulting in an explicit error or some unintended consequences.
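A hypothetical sketch of such an argument (the class name, signature, and plan.yaml keys below are illustrative, not existing OpenFL API):

```python
import os

SUPPORTED_BACKENDS = ("tensorflow", "torch", "jax")


class KerasRunnerWithBackend:  # illustrative stand-in for the Keras runner
    def __init__(self, backend: str = "tensorflow"):
        if backend not in SUPPORTED_BACKENDS:
            raise ValueError(f"Unsupported Keras backend: {backend!r}")
        # Set the backend explicitly *before* keras is imported, so a
        # multi-framework environment uses the intended backend instead of
        # whichever framework happens to be detected first.
        os.environ["KERAS_BACKEND"] = backend
        import keras  # noqa: F401  (deferred until the env var is set)
        self.backend = backend


# In plan.yaml this might be exposed as (hypothetical):
#   task_runner:
#     settings:
#       backend: torch
```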
I agree with you @kta-intel. The changes in this PR only enable different backends for Keras, but the workspaces do not showcase JAX, Torch, or TensorFlow. It would definitely be compelling to showcase JAX and Torch with Keras. I will work on this.
But would it make sense to merge this PR, which creates the building block, and create a separate PR to bring the actual example?
Echoing @kta-intel's comment, three separate tutorials with different backends do not justify the extra lines of code we have to maintain in this repo.
My suggestions, in no particular order:
- Group tutorials by use-case under keras: keras/{mnist,nlp}. Backends should be switchable by installing the necessary packages (jax, tf or torch). This would go as a multi-backend CI test for OpenFL.
- Group tutorials by framework: {jax,torch}. These will target custom training loops in pure DL framework code, utilizing their respective taskrunner classes. This can be a separate PR.
The PR includes the following changes: