CLI: add stricter automatic checks to pt-to-tf
#17588
Thanks for working on this!
Co-authored-by: Sylvain Gugger <[email protected]>
return max_difference, max_difference_source

return compate_pt_tf_values(pt_outputs, tf_outputs)
compare_pt_tf_values 😄
Will rename to _compare_pt_tf_models (to avoid a name clash, as Matt mentioned)
raise ValueError("The model outputs have different attributes, aborting.")

# 2. For each key, ALL values must be the same
def compate_pt_tf_values(pt_out, tf_out, attr_name=""):
compare_pt_tf_values ..?
for i, pt_item in enumerate(pt_out):
    # If it is a named attribute, we keep the name. Otherwise, just its index.
    if isinstance(pt_item, str):
        branch_name = root_name + pt_item
I feel that we will need to have something like f"{root_name}.{pt_item}", i.e. to include some kind of separator, so the result names will be more readable.
There is no need, the names are not nested (at the moment). As it is structured, it prints the variable as we would write it in a Python terminal to get it, so we can copy-paste it for further inspection -- e.g. past_key_values[0][2]
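The copy-paste property described above can be illustrated with a small sketch. This is not the PR's code: `leaf_names` is a hypothetical helper, and plain dicts, tuples and floats stand in for model outputs and tensors. Named attributes are appended to the root as-is, and positional entries as `[i]`.

```python
def leaf_names(value, root_name=""):
    """Yield (name, leaf) pairs for a nested output structure.

    Hypothetical helper for illustration only: dicts stand in for model
    outputs, tuples/lists for nested values, and anything else is a leaf.
    """
    if isinstance(value, dict):
        for key, item in value.items():
            # Named attribute: keep the name, with no separator.
            yield from leaf_names(item, root_name + key)
    elif isinstance(value, (list, tuple)):
        for i, item in enumerate(value):
            # Unnamed entry: use its index, exactly as one would type it.
            yield from leaf_names(item, root_name + f"[{i}]")
    else:
        yield root_name, value


# Toy stand-in for a model output (assumed values, not real data).
outputs = {
    "logits": 0.5,
    "past_key_values": ((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)),
}
names = [name for name, _ in leaf_names(outputs)]

# Each generated name is a valid expression against the output's attributes.
value_at_name = eval("past_key_values[0][2]", dict(outputs))
```

If named attributes were ever nested inside one another, the names would run together without a separator, which is the case the separator suggestion is meant to cover.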
    else:
        branch_name = root_name + f"[{i}]"
    tf_item = tf_out[i]
    difference, difference_source = compate_pt_tf_values(pt_item, tf_item, branch_name)
compare_pt_tf_values ..?
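For readers following the hunks above, here is a rough, self-contained sketch of the kind of recursion being reviewed: walk both outputs in parallel and return the largest absolute difference together with the name of the value where it occurs. It is an approximation under stated assumptions, not the merged implementation: `compare_values` is a hypothetical name, and plain floats, tuples and dicts stand in for tensors and model outputs.

```python
def compare_values(pt_out, tf_out, attr_name=""):
    """Return (max absolute difference, name of the value where it occurs).

    Illustrative only: floats stand in for tensors, dicts for model outputs.
    """
    if isinstance(pt_out, dict):
        # Both outputs must expose the same attributes.
        if set(pt_out) != set(tf_out):
            raise ValueError("The model outputs have different attributes, aborting.")
        pairs = [(attr_name + key, pt_out[key], tf_out[key]) for key in pt_out]
    elif isinstance(pt_out, (list, tuple)):
        if len(pt_out) != len(tf_out):
            raise ValueError(f"Length mismatch at {attr_name!r}.")
        pairs = [
            (f"{attr_name}[{i}]", pt_item, tf_item)
            for i, (pt_item, tf_item) in enumerate(zip(pt_out, tf_out))
        ]
    else:
        # Leaf value: the difference is attributed to this name.
        return abs(pt_out - tf_out), attr_name

    # Keep the largest difference seen across all branches, and its source.
    max_difference, max_difference_source = 0.0, ""
    for name, pt_item, tf_item in pairs:
        difference, difference_source = compare_values(pt_item, tf_item, name)
        if difference > max_difference:
            max_difference, max_difference_source = difference, difference_source
    return max_difference, max_difference_source


# Toy outputs (assumed values) that differ most in one hidden state entry.
pt = {"logits": (0.10, 0.20), "hidden_states": ((1.0, 2.0), (3.0, 4.0))}
tf = {"logits": (0.10, 0.25), "hidden_states": ((1.0, 2.0), (3.0, 4.5))}
diff, source = compare_values(pt, tf)
```

Returning the source alongside the difference means a failed check can point at the exact value to inspect, rather than only reporting a number.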
LGTM, just a few nits if they make sense.
@staticmethod
def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input):
    """
    Compares the TensorFload and PyTorch models, given their inputs, returning a tuple with the maximum observed
Suggested change:
- Compares the TensorFload and PyTorch models, given their inputs, returning a tuple with the maximum observed
+ Compares the TensorFlow and PyTorch models, given their inputs, returning a tuple with the maximum observed
converted_diff = np.max(np.abs(pt_last_hidden_state - tf_last_hidden_state))
del tf_from_pt_model  # will no longer be used, and may have a large memory footprint
tf_model = tf_class.from_pretrained(self._local_dir)
converted_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input)
The function is called as compare_pt_tf_models here, but as @ydshieh mentioned it's defined as compate_pt_tf_models, so this bit will probably crash.
It was my bad, my comment should have been compate_pt_tf_values --> compare_pt_tf_values. There is nothing wrong with compare_pt_tf_models.
* Stricter pt-to-tf checks; Update docker image for related tests
* check all attributes in the output
Co-authored-by: Sylvain Gugger <[email protected]>
What does this PR do?
Last week I introduced the pt-to-tf CLI (#17497), enabling automatic weight conversion followed by PR opening. This PR makes four changes related to that CLI:
… output_hidden_states=True) are verified;
… git lfs -- I did it for the circleci workflows in the original PR, but forgot to do it for the scheduled tests.
🚨 This also means I will double-check previously open Hub PRs (about 10), to confirm that the model head is present in the TF weights (I suspect it isn't in some cases 😢) and that the outputs pass the stricter tests.
For context, if the conversion fails because of a difference in the model outputs, we get a message like this one: