CLI: add stricter automatic checks to pt-to-tf (#17588)
* Stricter pt-to-tf checks; Update docker image for related tests

* check all attributes in the output

Co-authored-by: Sylvain Gugger <[email protected]>
gante and sgugger authored Jun 8, 2022
1 parent c6cea5a commit 78c695e
Showing 2 changed files with 93 additions and 21 deletions.
3 changes: 2 additions & 1 deletion docker/transformers-all-latest-gpu/Dockerfile
@@ -4,7 +4,8 @@ LABEL maintainer="Hugging Face"
 ARG DEBIAN_FRONTEND=noninteractive
 
 RUN apt update
-RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg
+RUN apt install -y git libsndfile1-dev tesseract-ocr espeak-ng python3 python3-pip ffmpeg git-lfs
+RUN git lfs install
 RUN python3 -m pip install --no-cache-dir --upgrade pip
 
 ARG REF=main
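The git-lfs layer is needed because the conversion tests clone model repositories from the Hub, whose weight files are stored in Git LFS. A minimal sketch of the clone the CLI performs (see run() in the diff below; the repo id here is a placeholder):

    from huggingface_hub import Repository

    # Cloning a Hub model repo pulls LFS-tracked weight files, which fails unless
    # `git lfs install` has been run in the image. "user/my-model" is a placeholder.
    repo = Repository(local_dir="/tmp/my-model", clone_from="user/my-model")
    repo.git_pull()  # refresh the clone if it already existed locally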
111 changes: 91 additions & 20 deletions src/transformers/commands/pt_to_tf.py
@@ -14,13 +14,14 @@
 
 import os
 from argparse import ArgumentParser, Namespace
+from importlib import import_module
 
 import numpy as np
 from datasets import load_dataset
 
 from huggingface_hub import Repository, upload_file
 
-from .. import AutoFeatureExtractor, AutoModel, AutoTokenizer, TFAutoModel, is_tf_available, is_torch_available
+from .. import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
 from ..utils import logging
 from . import BaseTransformersCLICommand
 
@@ -44,7 +45,7 @@ def convert_command_factory(args: Namespace):
     Returns: ServeCommand
     """
-    return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr)
+    return PTtoTFCommand(args.model_name, args.local_dir, args.no_pr, args.new_weights)
 
 
 class PTtoTFCommand(BaseTransformersCLICommand):
@@ -78,13 +79,69 @@ def register_subcommand(parser: ArgumentParser):
         train_parser.add_argument(
             "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
         )
+        train_parser.add_argument(
+            "--new-weights",
+            action="store_true",
+            help="Optional flag to create new TensorFlow weights, even if they already exist.",
+        )
         train_parser.set_defaults(func=convert_command_factory)
 
-    def __init__(self, model_name: str, local_dir: str, no_pr: bool, *args):
+    @staticmethod
+    def compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input):
+        """
+        Compares the TensorFlow and PyTorch models, given their inputs, returning a tuple with the maximum observed
+        difference and its source.
+        """
+        pt_outputs = pt_model(**pt_input, output_hidden_states=True)
+        tf_outputs = tf_model(**tf_input, output_hidden_states=True)
+
+        # 1. All output attributes must be the same
+        pt_out_attrs = set(pt_outputs.keys())
+        tf_out_attrs = set(tf_outputs.keys())
+        if pt_out_attrs != tf_out_attrs:
+            raise ValueError(
+                f"The model outputs have different attributes, aborting. (PyTorch: {pt_out_attrs}, TensorFlow:"
+                f" {tf_out_attrs})"
+            )
+
+        # 2. For each output attribute, ALL values must be the same
+        def _compare_pt_tf_models(pt_out, tf_out, attr_name=""):
+            max_difference = 0
+            max_difference_source = ""
+
+            # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig
+            # in recursively, keeping the name of the attribute.
+            if isinstance(pt_out, torch.Tensor):
+                difference = np.max(np.abs(pt_out.detach().numpy() - tf_out.numpy()))
+                if difference > max_difference:
+                    max_difference = difference
+                    max_difference_source = attr_name
+            else:
+                root_name = attr_name
+                for i, pt_item in enumerate(pt_out):
+                    # If it is a named attribute, we keep the name. Otherwise, just its index.
+                    if isinstance(pt_item, str):
+                        branch_name = root_name + pt_item
+                        tf_item = tf_out[pt_item]
+                        pt_item = pt_out[pt_item]
+                    else:
+                        branch_name = root_name + f"[{i}]"
+                        tf_item = tf_out[i]
+                    difference, difference_source = _compare_pt_tf_models(pt_item, tf_item, branch_name)
+                    if difference > max_difference:
+                        max_difference = difference
+                        max_difference_source = difference_source
+
+            return max_difference, max_difference_source
+
+        return _compare_pt_tf_models(pt_outputs, tf_outputs)
+
+    def __init__(self, model_name: str, local_dir: str, no_pr: bool, new_weights: bool, *args):
         self._logger = logging.get_logger("transformers-cli/pt_to_tf")
         self._model_name = model_name
         self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name)
         self._no_pr = no_pr
+        self._new_weights = new_weights
 
     def get_text_inputs(self):
         tokenizer = AutoTokenizer.from_pretrained(self._local_dir)
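The recursion in compare_pt_tf_models relies on the fact that iterating over a ModelOutput yields its keys (strings), while iterating over a nested tuple such as hidden_states yields tensors; that is what the isinstance(pt_item, str) branch distinguishes. A framework-free sketch of the same walk, with numpy arrays standing in for tensors and a plain dict for the model output:

    import numpy as np

    def max_abs_diff(pt_out, tf_out, name=""):
        # Leaf: compare the arrays directly
        if isinstance(pt_out, np.ndarray):
            return np.max(np.abs(pt_out - tf_out)), name
        worst, source = 0, ""
        for i, item in enumerate(pt_out):
            if isinstance(item, str):  # dict-like container: iteration yields keys
                diff, src = max_abs_diff(pt_out[item], tf_out[item], name + item)
            else:  # tuple-like container: iteration yields values
                diff, src = max_abs_diff(item, tf_out[i], name + f"[{i}]")
            if diff > worst:
                worst, source = diff, src
        return worst, source

    pt_outputs = {"last_hidden_state": np.zeros(3), "hidden_states": (np.zeros(2), np.ones(2))}
    tf_outputs = {"last_hidden_state": np.zeros(3), "hidden_states": (np.zeros(2), np.ones(2) + 1e-6)}
    print(max_abs_diff(pt_outputs, tf_outputs))  # -> (~1e-06, 'hidden_states[1]')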
Expand Down Expand Up @@ -119,8 +176,25 @@ def run(self):
repo = Repository(local_dir=self._local_dir, clone_from=self._model_name)
repo.git_pull() # in case the repo already exists locally, but with an older commit

# Load config and get the appropriate architecture -- the latter is needed to convert the head's weights
config = AutoConfig.from_pretrained(self._local_dir)
architectures = config.architectures
if architectures is None: # No architecture defined -- use auto classes
pt_class = getattr(import_module("transformers"), "AutoModel")
tf_class = getattr(import_module("transformers"), "TFAutoModel")
self._logger.warn("No detected architecture, using AutoModel/TFAutoModel")
else: # Architecture defined -- use it
if len(architectures) > 1:
raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})")
self._logger.warn(f"Detected architecture: {architectures[0]}")
pt_class = getattr(import_module("transformers"), architectures[0])
try:
tf_class = getattr(import_module("transformers"), "TF" + architectures[0])
except AttributeError:
raise AttributeError(f"The TensorFlow equivalent of {architectures[0]} doesn't exist in transformers.")

# Load models and acquire a basic input for its modality.
pt_model = AutoModel.from_pretrained(self._local_dir)
pt_model = pt_class.from_pretrained(self._local_dir)
main_input_name = pt_model.main_input_name
if main_input_name == "input_ids":
pt_input, tf_input = self.get_text_inputs()
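Resolving the class from config.architectures matters because AutoModel/TFAutoModel load only the base transformer, so task-head weights would neither be validated nor saved. A sketch of the name-to-class resolution used above (the architecture string is a placeholder):

    from importlib import import_module

    arch = "BertForSequenceClassification"  # placeholder for config.architectures[0]
    pt_class = getattr(import_module("transformers"), arch)
    tf_class = getattr(import_module("transformers"), "TF" + arch)  # AttributeError if no TF port exists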
@@ -130,7 +204,7 @@ def run(self):
             pt_input, tf_input = self.get_audio_inputs()
         else:
             raise ValueError(f"Can't detect the model modality (`main_input_name` = {main_input_name})")
-        tf_from_pt_model = TFAutoModel.from_pretrained(self._local_dir, from_pt=True)
+        tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True)
 
         # Extra input requirements, in addition to the input modality
         if hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder"):
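from_pt=True is where the cross-framework load happens: the PyTorch checkpoint in the local directory is read and its tensors are mapped onto the TF model's variables. A hedged sketch with a concrete class (the path is a placeholder and is assumed to contain a PyTorch checkpoint):

    from transformers import TFBertModel

    # Builds the TF model and fills its variables from the PyTorch weights on disk.
    tf_from_pt = TFBertModel.from_pretrained("/tmp/my-model", from_pt=True)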
@@ -139,27 +213,24 @@ def run(self):
             tf_input.update({"decoder_input_ids": tf.convert_to_tensor(decoder_input_ids)})
 
         # Confirms that cross loading PT weights into TF worked.
-        pt_last_hidden_state = pt_model(**pt_input).last_hidden_state.detach().numpy()
-        tf_from_pt_last_hidden_state = tf_from_pt_model(**tf_input).last_hidden_state.numpy()
-        crossload_diff = np.max(np.abs(pt_last_hidden_state - tf_from_pt_last_hidden_state))
+        crossload_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_from_pt_model, tf_input)
         if crossload_diff >= MAX_ERROR:
             raise ValueError(
-                "The cross-loaded TF model has different last hidden states, something went wrong! (max difference ="
-                f" {crossload_diff})"
+                "The cross-loaded TF model has different outputs, something went wrong! (max difference ="
+                f" {crossload_diff:.3e}, observed in {diff_source})"
             )
 
-        # Save the weights in a TF format (if they don't exist) and confirms that the results are still good
+        # Save the weights in a TF format (if needed) and confirm that the results are still good
         tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME)
-        if not os.path.exists(tf_weights_path):
+        if not os.path.exists(tf_weights_path) or self._new_weights:
             tf_from_pt_model.save_weights(tf_weights_path)
-        del tf_from_pt_model, pt_model  # will no longer be used, and may have a large memory footprint
-        tf_model = TFAutoModel.from_pretrained(self._local_dir)
-        tf_last_hidden_state = tf_model(**tf_input).last_hidden_state.numpy()
-        converted_diff = np.max(np.abs(pt_last_hidden_state - tf_last_hidden_state))
+        del tf_from_pt_model  # will no longer be used, and may have a large memory footprint
+        tf_model = tf_class.from_pretrained(self._local_dir)
+        converted_diff, diff_source = self.compare_pt_tf_models(pt_model, pt_input, tf_model, tf_input)
         if converted_diff >= MAX_ERROR:
             raise ValueError(
-                "The converted TF model has different last hidden states, something went wrong! (max difference ="
-                f" {converted_diff})"
+                "The converted TF model has different outputs, something went wrong! (max difference ="
+                f" {converted_diff:.3e}, observed in {diff_source})"
             )
 
         if not self._no_pr:
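Both validations compare the maximum elementwise difference against MAX_ERROR, a module-level tolerance defined earlier in the file (its value is outside this diff). A toy illustration of the check and its :.3e report format, with a placeholder tolerance:

    import numpy as np

    MAX_ERROR = 5e-5  # placeholder only; the real tolerance is defined earlier in pt_to_tf.py
    pt_out = np.zeros(4, dtype=np.float32)
    tf_out = pt_out + 1e-3  # simulate a conversion that drifted

    diff = np.max(np.abs(pt_out - tf_out))
    if diff >= MAX_ERROR:
        raise ValueError(f"The converted TF model has different outputs! (max difference = {diff:.3e})")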
@@ -174,8 +245,8 @@ def run(self):
                 create_pr=True,
                 pr_commit_summary="Add TF weights",
                 pr_commit_description=(
-                    f"Validated by the `pt_to_tf` CLI. Max crossload hidden state difference={crossload_diff:.3e};"
-                    f" Max converted hidden state difference={converted_diff:.3e}."
+                    f"Validated by the `pt_to_tf` CLI. Max crossload output difference={crossload_diff:.3e};"
+                    f" Max converted output difference={converted_diff:.3e}."
                ),
             )
             self._logger.warn(f"PR open in {hub_pr_url}")
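With --new-weights, the command can be re-run to overwrite TF weights that already exist on the Hub, e.g. something like `transformers-cli pt-to-tf --model-name <model-id> --new-weights` (flag names per register_subcommand above). A hedged sketch of driving the converter programmatically instead, with a placeholder model id:

    from transformers.commands.pt_to_tf import PTtoTFCommand

    # Mirrors the CLI call; an empty local_dir falls back to /tmp/<model_name>.
    cmd = PTtoTFCommand("user/my-model", local_dir="", no_pr=True, new_weights=True)
    cmd.run()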
