Skip to content

Commit

Permalink
Benchmarks (#4912)
Browse files Browse the repository at this point in the history
* finish benchmark

* fix isort

* fix setup cfg

* retab

* fix time measuring of tf graph mode

* fix tf cuda

* clean code

* better error message
  • Loading branch information
patrickvonplaten authored Jun 22, 2020
1 parent 18a0150 commit fa0be6d
Show file tree
Hide file tree
Showing 18 changed files with 1,045 additions and 368 deletions.
2 changes: 1 addition & 1 deletion examples/benchmarking/run_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright 2020 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
29 changes: 29 additions & 0 deletions examples/benchmarking/run_benchmark_tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the library on inference and training in Tensorflow"""

from transformers import HfArgumentParser, TensorflowBenchmark, TensorflowBenchmarkArguments


def main():
parser = HfArgumentParser(TensorflowBenchmarkArguments)
benchmark_args = parser.parse_args_into_dataclasses()[0]
benchmark = TensorflowBenchmark(args=benchmark_args)
benchmark.run()


if __name__ == "__main__":
main()
6 changes: 3 additions & 3 deletions examples/longform-qa/eli5_app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import faiss
import nlp
import numpy as np
import torch
from elasticsearch import Elasticsearch

import faiss
import nlp
import streamlit as st
import transformers
from elasticsearch import Elasticsearch
from eli5_utils import (
embed_questions_for_retrieval,
make_qa_s2s_model,
Expand Down
8 changes: 4 additions & 4 deletions examples/longform-qa/eli5_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from random import choice, randint
from time import time

import faiss # noqa: F401
import nlp # noqa: F401
import numpy as np
import pandas as pd
import torch
import torch.utils.checkpoint as checkpoint
from elasticsearch import Elasticsearch # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm

import faiss # noqa: F401
import nlp # noqa: F401
from elasticsearch import Elasticsearch # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup


Expand Down
5 changes: 5 additions & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@ tensorflow_datasets
pytorch-lightning==0.7.6
matplotlib
git-python==1.0.3
faiss
streamlit
elasticsearch
pandas
nlp
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ include_trailing_comma = True
known_first_party = transformers
known_third_party =
absl
elasticsearch
fairseq
faiss
fastprogress
git
h5py
matplotlib
MeCab
nlp
nltk
numpy
packaging
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@
add_end_docstrings,
add_start_docstrings,
cached_path,
is_apex_available,
is_psutil_available,
is_py3nvml_available,
is_tf_available,
is_torch_available,
is_torch_tpu_available,
Expand Down Expand Up @@ -398,7 +401,8 @@
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments

# Benchmarks
from .benchmark import PyTorchBenchmark, PyTorchBenchmarkArguments
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments

# TensorFlow
if is_tf_available():
Expand Down Expand Up @@ -608,6 +612,10 @@
# Trainer
from .trainer_tf import TFTrainer

# Benchmarks
from .benchmark.benchmark_tf import TensorflowBenchmark
from .benchmark.benchmark_args_tf import TensorflowBenchmarkArguments


if not is_tf_available() and not is_torch_available():
logger.warning(
Expand Down
10 changes: 0 additions & 10 deletions src/transformers/benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +0,0 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ..file_utils import is_torch_available


if is_torch_available():
from .benchmark_args import PyTorchBenchmarkArguments
from .benchmark import PyTorchBenchmark
Loading

0 comments on commit fa0be6d

Please sign in to comment.