Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmarks #4912

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/benchmarking/run_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright 2020 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
29 changes: 29 additions & 0 deletions examples/benchmarking/run_benchmark_tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
patrickvonplaten marked this conversation as resolved.
Show resolved Hide resolved
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the library on inference and training in Tensorflow"""

from transformers import HfArgumentParser, TensorflowBenchmark, TensorflowBenchmarkArguments


def main():
parser = HfArgumentParser(TensorflowBenchmarkArguments)
benchmark_args = parser.parse_args_into_dataclasses()[0]
benchmark = TensorflowBenchmark(args=benchmark_args)
benchmark.run()


if __name__ == "__main__":
main()
6 changes: 3 additions & 3 deletions examples/longform-qa/eli5_app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import faiss
import nlp
import numpy as np
import torch
from elasticsearch import Elasticsearch

import faiss
import nlp
import streamlit as st
import transformers
from elasticsearch import Elasticsearch
from eli5_utils import (
embed_questions_for_retrieval,
make_qa_s2s_model,
Expand Down
8 changes: 4 additions & 4 deletions examples/longform-qa/eli5_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from random import choice, randint
from time import time

import faiss # noqa: F401
import nlp # noqa: F401
import numpy as np
import pandas as pd
import torch
import torch.utils.checkpoint as checkpoint
from elasticsearch import Elasticsearch # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm

import faiss # noqa: F401
import nlp # noqa: F401
from elasticsearch import Elasticsearch # noqa: F401
from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup


Expand Down
5 changes: 5 additions & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@ tensorflow_datasets
pytorch-lightning==0.7.6
matplotlib
git-python==1.0.3
faiss
streamlit
elasticsearch
pandas
nlp
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ include_trailing_comma = True
known_first_party = transformers
known_third_party =
absl
elasticsearch
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yjernite New packages for the example folder have to be added here to avoid problems with isort (learned from @sshleifer)

fairseq
faiss
fastprogress
git
h5py
matplotlib
MeCab
nlp
nltk
numpy
packaging
Expand Down
10 changes: 9 additions & 1 deletion src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@
add_end_docstrings,
add_start_docstrings,
cached_path,
is_apex_available,
is_psutil_available,
is_py3nvml_available,
is_tf_available,
is_torch_available,
is_torch_tpu_available,
Expand Down Expand Up @@ -380,7 +383,8 @@
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments

# Benchmarks
from .benchmark import PyTorchBenchmark, PyTorchBenchmarkArguments
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments

# TensorFlow
if is_tf_available():
Expand Down Expand Up @@ -576,6 +580,10 @@
# Trainer
from .trainer_tf import TFTrainer

# Benchmarks
from .benchmark.benchmark_tf import TensorflowBenchmark
from .benchmark.benchmark_args_tf import TensorflowBenchmarkArguments


if not is_tf_available() and not is_torch_available():
logger.warning(
Expand Down
10 changes: 0 additions & 10 deletions src/transformers/benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +0,0 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ..file_utils import is_torch_available


if is_torch_available():
from .benchmark_args import PyTorchBenchmarkArguments
from .benchmark import PyTorchBenchmark
Loading