Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Export for JAX Backend #819

Merged
merged 27 commits into from
Sep 8, 2023
Merged

Add Export for JAX Backend #819

merged 27 commits into from
Sep 8, 2023

Conversation

nkovela1
Copy link
Collaborator

This PR adds initial support for the export of JAX-backend models to TF SavedModel.
Additionally, a test has been added for subclassed model export support along with logic in both TF and JAX backends.

There are a few caveats and TODOs to note here:

  • jax2tf function conversion is lacking argument name preservation, causing the names of arguments passed to become args_tf_0, args_tf_1, etc. I have opened a bug internally in JAX2TF Users regarding this, and a workaround should be included in a subsequent PR.

  • ReloadedLayer support will be included in a subsequent PR, further discussion will be needed on this.

  • All other tests pass!

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

@@ -432,6 +432,9 @@ def standardize_shape(shape):
for e in shape:
if e is None:
continue
if config.backend() == "jax" and str(e) == "b":
# JAX2TF tracing represents `None` dimensions as `b`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are they literally the string "b"? Is there a more reliable way to make this check than str(e) == "b"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in JAX the dimensions unknown at tracing time are represented by the literal string "b", so this is the easiest check:

https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about just e is "b"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That results in a SyntaxError with Python:

image

I think it used to be a SyntaxWarning but more recently upgraded to Error.

keras_core/layers/layer.py Outdated Show resolved Hide resolved
keras_core/export/export_lib.py Outdated Show resolved Hide resolved
keras_core/export/export_lib.py Outdated Show resolved Hide resolved
@@ -626,6 +627,7 @@ def compute_mask(self, inputs, previous_mask):
@traceback_utils.filter_traceback
def __call__(self, *args, **kwargs):
self._check_super_called()
self._called = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use self.built? Should be equivalent

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a user calls build by themselves, for example, we should not be allowing export. There is an existing test for this exact flow in TF Keras:

https://github.com/keras-team/keras/blob/master/keras/export/export_lib_test.py#L364

It's a niche case, but let me know what you think.

@codecov
Copy link

codecov bot commented Sep 7, 2023

Codecov Report

❗ No coverage uploaded for pull request base (main@1a7720e). Click here to learn what that means.
Patch has no changes to coverable lines.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #819   +/-   ##
=======================================
  Coverage        ?   76.01%           
=======================================
  Files           ?      328           
  Lines           ?    31134           
  Branches        ?     6060           
=======================================
  Hits            ?    23668           
  Misses          ?     5868           
  Partials        ?     1598           
Flag Coverage Δ
keras_core 75.92% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

📢 Have feedback on the report? Share it here.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update!

keras_core/export/export_lib.py Outdated Show resolved Hide resolved
keras_core/layers/layer.py Outdated Show resolved Hide resolved
Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@fchollet fchollet merged commit ac9be33 into keras-team:main Sep 8, 2023
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants