-
Notifications
You must be signed in to change notification settings - Fork 117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Export for JAX Backend #819
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
@@ -432,6 +432,9 @@ def standardize_shape(shape): | |||
for e in shape: | |||
if e is None: | |||
continue | |||
if config.backend() == "jax" and str(e) == "b": | |||
# JAX2TF tracing represents `None` dimensions as `b` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are they literally the string "b"
? Is there a more reliable way to make this check than str(e) == "b"
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, in JAX the dimensions unknown at tracing time are represented by the literal string "b", so this is the easiest check:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about just e is "b"
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -626,6 +627,7 @@ def compute_mask(self, inputs, previous_mask): | |||
@traceback_utils.filter_traceback | |||
def __call__(self, *args, **kwargs): | |||
self._check_super_called() | |||
self._called = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not use self.built
? Should be equivalent
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If a user calls build
by themselves, for example, we should not be allowing export. There is an existing test for this exact flow in TF Keras:
https://github.com/keras-team/keras/blob/master/keras/export/export_lib_test.py#L364
It's a niche case, but let me know what you think.
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #819 +/- ##
=======================================
Coverage ? 76.01%
=======================================
Files ? 328
Lines ? 31134
Branches ? 6060
=======================================
Hits ? 23668
Misses ? 5868
Partials ? 1598
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 📢 Have feedback on the report? Share it here. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
This PR adds initial support for the export of JAX-backend models to TF SavedModel.
Additionally, a test has been added for subclassed model export support along with logic in both TF and JAX backends.
There are a few caveats and TODOs to note here:
jax2tf function conversion is lacking argument name preservation, causing the names of arguments passed to become
args_tf_0
,args_tf_1
, etc. I have opened a bug internally in JAX2TF Users regarding this, and a workaround should be included in a subsequent PR.ReloadedLayer support will be included in a subsequent PR, further discussion will be needed on this.
All other tests pass!