CLI: add stricter automatic checks to pt-to-tf
#17588
Changes from 8 commits
@@ -14,13 +14,14 @@
 import os
 from argparse import ArgumentParser, Namespace
+from importlib import import_module

 import numpy as np
 from datasets import load_dataset

 from huggingface_hub import Repository, upload_file

-from .. import AutoFeatureExtractor, AutoModel, AutoTokenizer, TFAutoModel, is_tf_available, is_torch_available
+from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
 from ..utils import logging
 from . import BaseTransformersCLICommand

@@ -44,7 +45,7 @@ def convert_command_factory(args: Namespace):

     Returns: ServeCommand
     """
-    return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr)
+    return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr, args.new_weights)


 class PTtoTFCommand(BaseTransformersCLICommand):

@@ -78,13 +79,69 @@ def register_subcommand(parser: ArgumentParser):
         train_parser.add_argument(
             "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
         )
+        train_parser.add_argument(
+            "--new-weights",
+            action="store_true",
+            help="Optional flag to create new TensorFlow weights, even if they already exist.",
+        )
         train_parser.set_defaults(func=convert_command_factory)

-    def __init__(self, model_name: str, local_dir: str, no_pr: bool, *args):
+    @staticmethod
+    def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input):
+        """
+        Compares the TensorFlow and PyTorch models, given their inputs, returning a tuple with the maximum observed
+        difference and its source.
+        """
+        pt_outputs = pt_model(**pt_input, output_hidden_states=True)
+        tf_outputs = tf_model(**tf_input, output_hidden_states=True)
+
+        # 1. All output attributes must be the same
+        pt_out_attrs = set(pt_outputs.keys())
+        tf_out_attrs = set(tf_outputs.keys())
+        if pt_out_attrs != tf_out_attrs:
+            raise ValueError(
+                f"The model outputs have different attributes, aborting. (PyTorch: {pt_out_attrs}, TensorFlow:"
+                f" {tf_out_attrs})"
+            )
+
+        # 2. For each output attribute, ALL values must be the same
+        def compate_pt_tf_values(pt_out, tf_out, attr_name=""):
+            max_difference = 0
+            max_difference_source = ""
+
+            # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we dig in
+            # recursively, keeping the name of the attribute.
+            if isinstance(pt_out, torch.Tensor):
+                difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
+                if difference > max_difference:
+                    max_difference = difference
+                    max_difference_source = attr_name
+            else:
+                root_name = attr_name
+                for i, pt_item in enumerate(pt_out):
+                    # If it is a named attribute, we keep the name. Otherwise, just its index.
+                    if isinstance(pt_item, str):
+                        branch_name = root_name + pt_item
[Review thread]
Reviewer: I feel that we will need to have something like…
Author: There is no need, the names are not nested (at the moment). As it is structured, it will print the variable as we would write it on a Python terminal to get it, so we can copy-paste it for further inspection.
+                        tf_item = tf_out[pt_item]
+                        pt_item = pt_out[pt_item]
+                    else:
+                        branch_name = root_name + f"[{i}]"
+                        tf_item = tf_out[i]
+                    difference, difference_source = compate_pt_tf_values(pt_item, tf_item, branch_name)
+                    if difference > max_difference:
+                        max_difference = difference
+                        max_difference_source = difference_source
+
+            return max_difference, max_difference_source
+
+        return compate_pt_tf_values(pt_outputs, tf_outputs)
[Review thread]
Author: Will rename to `compare_pt_tf_values`.
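To see what the new check computes, here is a minimal, self-contained sketch of the same recursive walk, with plain NumPy arrays standing in for torch/TF tensors (all names here are illustrative, not the PR's code):

```python
import numpy as np

def max_abs_diff(pt_out, tf_out, name=""):
    # Leaf: a raw array, so compare directly and report the path taken to get here
    if isinstance(pt_out, np.ndarray):
        return float(np.max(np.abs(pt_out - tf_out))), name
    # Branch: dict-like outputs keep attribute names, tuple-like outputs keep indices
    worst, source = 0.0, ""
    if isinstance(pt_out, dict):
        pairs = [(name + key, pt_out[key], tf_out[key]) for key in pt_out]
    else:
        pairs = [(f"{name}[{i}]", item, tf_out[i]) for i, item in enumerate(pt_out)]
    for branch_name, pt_item, tf_item in pairs:
        diff, src = max_abs_diff(pt_item, tf_item, branch_name)
        if diff > worst:
            worst, source = diff, src
    return worst, source

# Mirrors a ModelOutput-like structure: named tensors plus tuples of per-layer states
pt = {"last_hidden_state": np.zeros((1, 3)), "hidden_states": (np.zeros(2), np.ones(2))}
tf = {"last_hidden_state": np.zeros((1, 3)), "hidden_states": (np.zeros(2), np.ones(2) + 1e-6)}
print(max_abs_diff(pt, tf))  # ~(1e-06, 'hidden_states[1]')
```

The second element of the result reads like a Python expression (`hidden_states[1]` above), which is exactly the copy-paste property the author describes in the thread.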
+
+    def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
         self._logger = logging.get_logger("transformers-cli/pt_to_tf")
         self._model_name = model_name
         self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
         self._no_pr = no_pr
+        self._new_weights = new_weights

     def get_text_inputs(self):
         tokenizer = AutoTokenizer.from_pretrained(self._local_dir)

@@ -119,8 +176,25 @@ def run(self):
         repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
         repo.git_pull()  # in case the repo already exists locally, but with an older commit

+        # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
+        config = AutoConfig.from_pretrained(self._local_dir)
+        architectures = config.architectures
+        if architectures is None:  # No architecture defined -- use auto classes
+            pt_class = getattr(import_module("transformers"), "AutoModel")
+            tf_class = getattr(import_module("transformers"), "TFAutoModel")
+            self._logger.warn("No detected architecture, using AutoModel/TFAutoModel")
+        else:  # Architecture defined -- use it
+            if len(architectures) > 1:
+                raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})")
+            self._logger.warn(f"Detected architecture: {architectures[0]}")
+            pt_class = getattr(import_module("transformers"), architectures[0])
+            try:
+                tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
+            except AttributeError:
+                raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")
+
         # Load models and acquire a basic input for its modality.
-        pt_model = AutoModel.from_pretrained(self._local_dir)
+        pt_model = pt_class.from_pretrained(self._local_dir)
         main_input_name = pt_model.main_input_name
         if main_input_name == "input_ids":
             pt_input, tf_input = self.get_text_inputs()

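To make the architecture lookup above concrete, here is a small sketch assuming a standard BERT checkpoint; any repo whose config lists `architectures` resolves the same way:

```python
from importlib import import_module

from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
print(config.architectures)  # ['BertForMaskedLM']

arch = config.architectures[0]
pt_class = getattr(import_module("transformers"), arch)         # BertForMaskedLM
tf_class = getattr(import_module("transformers"), "TF" + arch)  # TFBertForMaskedLM
```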
@@ -130,7 +204,7 @@ def run(self):
             pt_input, tf_input = self.get_audio_inputs()
         else:
             raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
-        tf_from_pt_model = TFAutoModel.from_pretrained(self._local_dir, from_pt=True)
+        tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)

         # Extra input requirements, in addition to the input modality
         if hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"):

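The modality dispatch keys off `main_input_name`, an attribute every transformers model exposes; the text and audio branches are visible in these hunks, and an elided branch presumably covers vision. A tiny sketch (the checkpoint is an arbitrary example):

```python
from transformers import AutoModel

# Typical values: "input_ids" (text), "pixel_values" (vision), "input_values" (audio)
model = AutoModel.from_pretrained("bert-base-uncased")
print(model.main_input_name)  # 'input_ids'
```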
@@ -139,27 +213,24 @@ def run(self):
             tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})

         # Confirms that cross loading PT weights into TF worked.
-        pt_last_hidden_state = pt_model(**pt_input).last_hidden_state.detach().numpy()
-        tf_from_pt_last_hidden_state = tf_from_pt_model(**tf_input).last_hidden_state.numpy()
-        crossload_diff = np.max(np.abs(pt_last_hidden_state - tf_from_pt_last_hidden_state))
+        crossload_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_from_pt_model, tf_input)
         if crossload_diff >= MAX_ERROR:
             raise ValueError(
-                "The cross-loaded TF model has different last hidden states, something went wrong! (max difference ="
-                f" {crossload_diff})"
+                "The cross-loaded TF model has different outputs, something went wrong! (max difference ="
+                f" {crossload_diff:.3e}, observed in {diff_source})"
             )

-        # Save the weights in a TF format (if they don't exist) and confirms that the results are still good
+        # Save the weights in a TF format (if needed) and confirms that the results are still good
         tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
-        if not os.path.exists(tf_weights_path):
+        if not os.path.exists(tf_weights_path) or self._new_weights:
             tf_from_pt_model.save_weights(tf_weights_path)
-        del tf_from_pt_model, pt_model  # will no longer be used, and may have a large memory footprint
-        tf_model = TFAutoModel.from_pretrained(self._local_dir)
-        tf_last_hidden_state = tf_model(**tf_input).last_hidden_state.numpy()
-        converted_diff = np.max(np.abs(pt_last_hidden_state - tf_last_hidden_state))
+        del tf_from_pt_model  # will no longer be used, and may have a large memory footprint
+        tf_model = tf_class.from_pretrained(self._local_dir)
+        converted_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input)
[Review thread]
Reviewer: The function is called as `compare_pt_tf_models` here.
Reviewer: It was my bad, my comment should be on `compate_pt_tf_values`. Nothing wrong about `compare_pt_tf_models`.
         if converted_diff >= MAX_ERROR:
             raise ValueError(
-                "The converted TF model has different last hidden states, something went wrong! (max difference ="
-                f" {converted_diff})"
+                "The converted TF model has different outputs, something went wrong! (max difference ="
+                f" {converted_diff:.3e}, observed in {diff_source})"
             )

         if not self._no_pr:

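Note the round trip here: the weights are written to disk and reloaded before the second comparison, so the check validates the exact `tf_model.h5` artifact that would be uploaded, not just the in-memory cross-loaded model (forcing a rewrite is what `--new-weights` is for). A rough sketch of that round trip, assuming a tiny test checkpoint and a scratch directory:

```python
from transformers import BertModel, TFBertModel

local_dir = "/tmp/tiny-bert"  # assumed scratch directory
BertModel.from_pretrained("hf-internal-testing/tiny-random-bert").save_pretrained(local_dir)

tf_from_pt = TFBertModel.from_pretrained(local_dir, from_pt=True)  # cross-load PT weights
tf_from_pt.save_weights(f"{local_dir}/tf_model.h5")                # write the TF artifact

tf_model = TFBertModel.from_pretrained(local_dir)  # now loads the h5 just written
```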
@@ -174,8 +245,8 @@ def run(self):
                 create_pr=True,
                 pr_commit_summary="Add TF weights",
                 pr_commit_description=(
-                    f"Validated by the `pt_to_tf` CLI. Max crossload hidden state difference={crossload_diff:.3e};"
-                    f" Max converted hidden state difference={converted_diff:.3e}."
+                    f"Validated by the `pt_to_tf` CLI. Max crossload output difference={crossload_diff:.3e};"
+                    f" Max converted output difference={converted_diff:.3e}."
                 ),
             )
             self._logger.warn(f"PR open in {hub_pr_url}")

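For reference, the PR-opening step hinges on `huggingface_hub.upload_file` with `create_pr=True`. A sketch mirroring the keyword arguments visible in this diff; `repo_id` and the paths are placeholders, and `pr_commit_summary`/`pr_commit_description` follow this PR's call site rather than any guaranteed stable signature:

```python
from huggingface_hub import upload_file

hub_pr_url = upload_file(
    path_or_fileobj="/tmp/some-model/tf_model.h5",  # placeholder local path
    path_in_repo="tf_model.h5",
    repo_id="some-user/some-model",                 # placeholder repo
    create_pr=True,
    pr_commit_summary="Add TF weights",
    pr_commit_description="Validated by the `pt_to_tf` CLI.",
)
print(hub_pr_url)  # URL of the opened PR
```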