Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix einsum #174

Merged
merged 2 commits into from
Sep 22, 2024
Merged

fix einsum #174

merged 2 commits into from
Sep 22, 2024

Conversation

tip3x
Copy link
Collaborator

@tip3x tip3x commented Sep 19, 2024

No description provided.

@tip3x tip3x force-pushed the fix-einsum branch 3 times, most recently from 355f25b to dbe6329 Compare September 19, 2024 12:29
is_input_0_constant = isinstance(input_0, tf.Tensor)
is_input_1_constant = isinstance(input_1, tf.Tensor)
if is_input_0_constant and is_input_1_constant:
layers[node_name] = tf.einsum(equation, *[input_0, input_1], name=keras_name)
Copy link
Collaborator

@tomkoren21 tomkoren21 Sep 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. The name param in TFOP does not propagate to the actual model due to a bug in keras.
  2. We use a different name than keras_name for clarity and debugging now, I suggest using "{params['cleaned_name']}_einsum" to be consistent

We've created a wrapper that makes sure the name propagates (under tfops_funcs.py)
I suggest wrapping tf.einsum with a named_tfop (tf_einsum), importing it, and using it as:
tf_einsum(whatever-you-need, tf_name="{params['cleaned_name']}_einsum")

Copy link
Collaborator Author

@tip3x tip3x Sep 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not a problem because in this case the output is constant and its not part of the model

equation = params['equation'].decode('utf-8')

is_input_0_constant = isinstance(input_0, tf.Tensor)
is_input_1_constant = isinstance(input_1, tf.Tensor)
Copy link
Collaborator

@tomkoren21 tomkoren21 Sep 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tip3x
For both tests above, this could be a numpy as well, depending on the specifics of the model. Does it make sense to test for that as well?

@tomkoren21
Copy link
Collaborator

tomkoren21 commented Sep 21, 2024

If you can, please also add a test to the CI so we won't have regressions on this model in future fixes.
It should have a similar format to this

@tip3x tip3x merged commit 01ed1a5 into master Sep 22, 2024
3 checks passed
@tip3x tip3x deleted the fix-einsum branch September 22, 2024 10:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants