
Benchmarks #4912

Conversation


@patrickvonplaten patrickvonplaten commented Jun 10, 2020

Benchmarks

This PR adds functionality to benchmark the following setups for TF and PT:

Tensorflow:

  • Inference: CPU, GPU, GPU + XLA, GPU + eager mode, CPU + eager mode, TPU

PyTorch:

  • Inference: CPU, CPU + torchscript, GPU, GPU + torchscript, GPU + mixed precision, Torch/XLA TPU
  • Training: CPU, GPU, GPU + mixed precision, Torch/XLA TPU

How is memory measured?

CPU

We are always interested in the peak memory usage of the process. For CPU, the psutil library is leveraged in combination with multiprocessing: the measured function runs in its own spawned process while its peak memory consumption is tracked.
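
A minimal sketch of that idea, assuming the workload is a picklable top-level function (the helper name measure_cpu_peak_memory is illustrative, not the exact utility in benchmark_utils.py):

import multiprocessing
import time

import psutil


def measure_cpu_peak_memory(func, poll_interval=0.01):
    """Run func in its own spawned process and track its peak RSS from the parent."""
    ctx = multiprocessing.get_context("spawn")
    child = ctx.Process(target=func)
    child.start()

    child_proc = psutil.Process(child.pid)
    peak_bytes = 0
    while child.is_alive():
        try:
            peak_bytes = max(peak_bytes, child_proc.memory_info().rss)
        except psutil.NoSuchProcess:  # child exited between checks
            break
        time.sleep(poll_interval)
    child.join()
    return peak_bytes / 1024 ** 2  # peak resident set size in MB

Spawning a separate process also means that all memory allocated by the workload (including library imports) is attributed to it and fully released once the measurement is done.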

GPU

It is difficult to get exact memory measurements on GPU. TensorFlow allocates the full GPU memory by default. This can be disabled with tf.config.experimental.set_memory_growth, but as far as I know TensorFlow still allocates more memory than it strictly needs, for efficiency.
=> Memory is therefore always measured so that it gives the same maximal result as shown by nvidia-smi. This means that the memory needed for loading PyTorch / TensorFlow onto the GPU is also taken into account, which is, for example, not the case when measuring via torch.cuda.max_memory_allocated().
TensorFlow also does not release GPU memory before the process is finished. Therefore, all measurement functions are wrapped in their own spawned process via Python's multiprocessing tools.

Also note that TF does not provide an official memory-monitoring function, so the value that nvidia-smi would show for the TF process is reported instead.
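
A minimal sketch of that approach, assuming nvidia-smi is on the PATH (the wrapper names are illustrative; the actual utilities in benchmark_utils.py may differ):

import subprocess
from multiprocessing import Process, Queue


def nvidia_smi_used_mb(gpu_index=0):
    # ask nvidia-smi for the memory currently used on one GPU, in MB
    out = subprocess.check_output(
        [
            "nvidia-smi",
            "--query-gpu=memory.used",
            "--format=csv,nounits,noheader",
            "--id={}".format(gpu_index),
        ]
    )
    return int(out.decode().strip())


def _worker(queue, func, gpu_index):
    func()  # run the inference / training step
    queue.put(nvidia_smi_used_mb(gpu_index))  # memory is still held by this process


def measure_gpu_memory(func, gpu_index=0):
    # run func in its own process so that TF / PyTorch release the GPU afterwards
    queue = Queue()
    p = Process(target=_worker, args=(queue, func, gpu_index))
    p.start()
    result_mb = queue.get()
    p.join()
    return result_mb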

TPU

Memory measurement is currently not supported

How is speed measured?

For all functionality that requires compilation (TPU, XLA, Torchscript), 5 warmup calls of the function are done beforehand.

Afterwards, the reported time is the minimum over self.args.repeat measurements, where each measurement averages the time over 10 function calls.
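
In code this corresponds roughly to the following (simplified; repeat stands in for self.args.repeat):

import timeit


def measure_speed(func, repeat=3, needs_compilation=False):
    if needs_compilation:
        # TPU / XLA / torchscript functions are compiled on the first calls,
        # so warm up with 5 calls that are not timed
        timeit.repeat(func, repeat=1, number=5)

    # each of the repeat measurements times 10 consecutive calls; following the
    # timeit documentation, the minimum rather than the mean is reported
    runtimes = timeit.repeat(func, repeat=repeat, number=10)
    return min(runtimes) / 10.0  # average time of a single call, in seconds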

Example Colabs:

The colabs give quick examples for each functionality with little explanation for the moment:

Pytorch TPU: https://colab.research.google.com/drive/1GJFOdcBe1pW_FKWpA0jK_AOsIQ5epcvE?usp=sharing
Tensorflow TPU:
https://colab.research.google.com/drive/1t8DW1NxA4b1BsWSZ1ehFG9oT69l0h7os?usp=sharing

GPU: https://colab.research.google.com/drive/15XTPT_GPp42Zj7_f1W9X_T3NNXE9_1Te?usp=sharing
CPU: https://colab.research.google.com/drive/1OG2rZgo18KvliS-ratybld9pHD06-v5S?usp=sharing

Future PR:

  • Make nicer examples and explanations
  • Update docs and think about automatic measuring on website
  • Training in TF. Because the LM head models currently do not accept a labels parameter as an input, adding measurements for training is left to a future PR
  • GPU fp16 in TF. We currently have a bug in the lib that does not allow running TF models in fp16 on GPU: TF BERT not FP16 compatible? #3320
  • PyTorch's amp package has memory leaks, so we simply call model.half() to measure fp16 in PyTorch (see the sketch after this list). See the issue here: GPU memory issues (leak?) NVIDIA/apex#439. Wait until amp is supported in upstream torch 1.6
  • Currently memory is not measured on TPU. Wait for more functionality for TPU
  • Allow multi-GPU measurements
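
A minimal sketch of the model.half() approach for fp16 inference in PyTorch (model name and shapes are only illustrative):

import torch
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-cased").cuda().eval()
model = model.half()  # cast the weights to fp16 instead of relying on apex.amp

# token ids stay int64; only weights and activations run in half precision
input_ids = torch.randint(0, model.config.vocab_size, (8, 128)).cuda()
with torch.no_grad():
    outputs = model(input_ids)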

@patrickvonplaten patrickvonplaten force-pushed the first_version_tf_benchmark branch 2 times, most recently from b95c536 to 3c3fc09 Compare June 16, 2020 14:35
@patrickvonplaten patrickvonplaten changed the title [WIP] Adding TF Benchmarks Add TF Benchmarks Jun 16, 2020
@@ -91,22 +91,6 @@ def test_train_with_configs(self):
self.check_results_dict_not_empty(results.time_train_result)
self.check_results_dict_not_empty(results.memory_train_result)

def test_train_with_configs_torchscript(self):
patrickvonplaten (Contributor, Author):

no torchscript training possible atm

@patrickvonplaten patrickvonplaten force-pushed the first_version_tf_benchmark branch 2 times, most recently from b7664d3 to c5fa0d5 Compare June 17, 2020 13:36
@codecov

codecov bot commented Jun 18, 2020

Codecov Report

Merging #4912 into master will decrease coverage by 0.94%.
The diff coverage is 78.29%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #4912      +/-   ##
==========================================
- Coverage   77.28%   76.34%   -0.95%     
==========================================
  Files         133      134       +1     
  Lines       22134    22369     +235     
==========================================
- Hits        17107    17078      -29     
- Misses       5027     5291     +264     
Impacted Files Coverage Δ
src/transformers/benchmark/benchmark_utils.py 69.77% <68.51%> (-3.20%) ⬇️
src/transformers/benchmark/benchmark_args_utils.py 89.13% <71.42%> (-7.75%) ⬇️
src/transformers/benchmark/benchmark.py 79.13% <76.00%> (+9.70%) ⬆️
src/transformers/benchmark/benchmark_tf.py 82.69% <82.69%> (ø)
src/transformers/file_utils.py 76.80% <86.66%> (+1.40%) ⬆️
src/transformers/benchmark/benchmark_args_tf.py 87.50% <87.50%> (ø)
src/transformers/__init__.py 99.18% <100.00%> (+0.02%) ⬆️
src/transformers/benchmark/benchmark_args.py 86.04% <100.00%> (+0.68%) ⬆️
src/transformers/trainer.py 39.57% <100.00%> (+0.18%) ⬆️
src/transformers/modeling_tf_t5.py 50.10% <0.00%> (-43.61%) ⬇️
... and 7 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 355954f...8b71041.

@julien-c julien-c added the model card Related to pretrained model cards label Jun 18, 2020
@patrickvonplaten patrickvonplaten removed the model card Related to pretrained model cards label Jun 18, 2020
@patrickvonplaten patrickvonplaten changed the title Add TF Benchmarks Benchmarks Jun 18, 2020
@@ -5,12 +5,15 @@ include_trailing_comma = True
known_first_party = transformers
known_third_party =
absl
elasticsearch
patrickvonplaten (Contributor, Author):

@yjernite New packages for the example folder have to be added here to avoid problems with isort (learned from @sshleifer)

@sshleifer (Contributor) left a comment:

Love this, LGTM!
Great test coverage as well.
Crazy that there isn't a better way to do this in tf.
Left some nits, but feel free to ignore.

In src/transformers/benchmark/benchmark.py:
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
runtimes = timeit.repeat(_train, repeat=self.args.repeat, number=10,)
# as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
runtimes = timeit.repeat(func, repeat=self.args.repeat, number=10,)
Contributor:

What is number? Is this why the benchmarking is slow?

patrickvonplaten (Contributor, Author):

number defines how many times the function is executed per measurement, and those execution times are summed up. repeat says how often that measurement is taken. So timeit.repeat returns a list of length self.args.repeat, and each element in the list is the total time of number executions of the function.
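
For reference, a small standalone illustration of how number and repeat interact (standard library behaviour, not the PR's code):

import timeit

# number=10 -> each measurement is the total time of 10 calls to the function
# repeat=3  -> that measurement is taken 3 times, giving a list of 3 totals
runtimes = timeit.repeat(lambda: sum(range(1000)), repeat=3, number=10)
time_per_call = min(runtimes) / 10  # minimum total divided by number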

@@ -81,6 +81,31 @@
_torch_tpu_available = False


try:
import psutil # noqa: F401
Contributor:

This is a fairly innocuous dependency afaict, and you could add it to requirements.txt/setup.py

In tests/test_benchmark.py:
no_inference=True,
sequence_lengths=[8],
batch_sizes=[1],
no_multi_process=True,
Contributor:

multiprocess and pytest are not friends it seems?

patrickvonplaten (Contributor, Author):

Nope :D pytest also starts a new process for each test, so multiprocessing inside multiprocessing breaks the CUDA init.

@LysandreJik (Member) left a comment:

This is great. I really like that you can import the benchmarks if you want to use them at runtime, rather than the only option being to run a script.
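
For illustration, a minimal sketch of that runtime usage with the classes this PR adds (the argument values are just examples, and the exact signature should be treated as an assumption):

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

args = PyTorchBenchmarkArguments(
    models=["bert-base-cased"],
    batch_sizes=[8],
    sequence_lengths=[8, 32, 128],
)
benchmark = PyTorchBenchmark(args)
results = benchmark.run()  # speed / memory results are returned, not only printed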

Some remarks after playing with it:

  • Maybe you should raise an error when no model_names are specified. Right now it crashes with UnboundLocalError: local variable 'inference_summary' referenced before assignment (PyTorch version at least).
  • There seems to be an error in the way the runtimes are computed. PyTorch using the GPU is slower than TensorFlow on CPU (10x slower), while PyTorch on CPU is 150x slower than TensorFlow on CPU.

Here are the results from my runs so far. The following is on CPU with TensorFlow (2 ms per inference with bert-base-cased, batch size 8 and seq len 512 on a CPU??). I didn't test the memory usage, so it's not in the results:

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
       bert-base-cased               8               8             0.001     
       bert-base-cased               8               32            0.001     
       bert-base-cased               8              128            0.001     
       bert-base-cased               8              512            0.002     
--------------------------------------------------------------------------------

====================        ENVIRONMENT INFORMATION         ====================
- transformers_version: 2.11.0
- framework: Tensorflow
- eager_mode: False
- use_xla: False
- framework_version: 2.2.0
- python_version: 3.6.10
- system: Linux
- cpu: 
- architecture: 64bit
- date: 2020-06-18
- time: 11:57:18.595804
- fp16: False
- use_multiprocessing: True
- cpu_ram_mb: 64333
- use_gpu: False
- use_tpu: False

Here's the test with PyTorch on GPU:

====================       INFERENCE - SPEED - RESULT       ====================
--------------------------------------------------------------------------------
          Model Name             Batch Size     Seq Length     Time in s   
--------------------------------------------------------------------------------
       bert-base-cased               8               8             0.007     
       bert-base-cased               8               32            0.007     
       bert-base-cased               8              128            0.019     
       bert-base-cased               8              512            0.074     
--------------------------------------------------------------------------------

====================        ENVIRONMENT INFORMATION         ====================
- transformers_version: 2.11.0
- framework: PyTorch
- use_torchscript: False
- framework_version: 1.5.0
- python_version: 3.6.10
- system: Linux
- cpu: 
- architecture: 64bit
- date: 2020-06-18
- time: 11:56:31.041360
- fp16: False
- use_multiprocessing: True
- cpu_ram_mb: 64333
- use_gpu: True
- num_gpus: 1
- gpu: N/A
- gpu_ram_mb: N/A
- gpu_power_watts: N/A
- gpu_performance_state: N/A
- use_tpu: False

I'm not sure that PyTorch on GPU is ~37x slower than TensorFlow on CPU 😄. I tried to debug, but it's not easy to debug tf functions, unfortunately.

Comment on lines 1 to 16
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Benchmarking the library on inference and training in Tensorflow"""
@LysandreJik (Member), Jun 18, 2020:

Cool, nice addition to the existing run_benchmark.py. You preferred to split the files into two because the arguments are too different?

patrickvonplaten (Contributor, Author):

Hmm, yeah, I thought it's a nicer standard to always have one file for TF and one for PT. Also, I didn't want to add a dataclass with a use_tf argument.

@patrickvonplaten (Contributor, Author):

(quoting @LysandreJik's comment above)

Thanks a lot for checking everything! Found the error :-) One just has to return a tensor out of the tf.function context so that it is actually computed. I guess TF's graph optimization prunes variables that are not used outside of the @tf.function scope, so they are never computed.

Will update the notebooks and should then get more reasonable results :-)
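
A minimal sketch of that fix (model name and shapes are only illustrative):

import tensorflow as tf
from transformers import TFBertModel

model = TFBertModel.from_pretrained("bert-base-cased")
input_ids = tf.random.uniform((8, 8), maxval=model.config.vocab_size, dtype=tf.int32)

@tf.function
def encode():
    # returning the output tensor forces TF to actually run the forward pass;
    # without the return, graph optimization can prune the computation and the
    # measured times become unrealistically small
    return model(input_ids, training=False)[0]

_ = encode()  # this call now reflects the real cost of the forward pass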

@patrickvonplaten (Contributor, Author):

And will definitely add a better error message

@patrickvonplaten (Contributor, Author):

The speed tests seem much more reasonable now, if you check the notebooks :-) @LysandreJik
There seems to be a problem with GPU memory in TF now :-/ Will check again tomorrow.


patrickvonplaten commented Jun 19, 2020

Locally, the GPU gives reasonable results for TF vs. PT.

All tests were run in this environment:

- transformers_version: 2.11.0
- python_version: 3.6.10
- system: Linux
- cpu: x86_64
- architecture: 64bit
- date: 2020-06-19
- time: 13:49:57.455208
- use_multiprocessing: True
- cpu_ram_mb: 32088
- use_gpu: True
- num_gpus: 1
- gpu: TITAN RTX
- gpu_ram_mb: 24217
- gpu_power_watts: 280.0
- gpu_performance_state: 2

for TF 2.2 and PyTorch 1.4.0

PyTorch

python run_benchmark.py --models gpt2 bert-base-cased --no_env_print --no_memory gives:


      Model Name             Batch Size     Seq Length     Time in s   

         gpt2                    8               8             0.006     
         gpt2                    8               32            0.007     
         gpt2                    8              128            0.026     
         gpt2                    8              512            0.104     
   bert-base-cased               8               8             0.006     
   bert-base-cased               8               32            0.006     
   bert-base-cased               8              128            0.021     
   bert-base-cased               8              512            0.094     


patrickvonplaten commented Jun 19, 2020

PyTorch FP16

python run_benchmark.py --models gpt2 bert-base-cased --no_env_print --no_memory --fp16


      Model Name             Batch Size     Seq Length     Time in s   

         gpt2                    8               8             0.006     
         gpt2                    8               32            0.007     
         gpt2                    8              128            0.009     
         gpt2                    8              512            0.043     
   bert-base-cased               8               8             0.006     
   bert-base-cased               8               32            0.006     
   bert-base-cased               8              128            0.006     
   bert-base-cased               8              512             0.03   


patrickvonplaten commented Jun 19, 2020

TF, no eager mode

python run_benchmark_tf.py --models gpt2 bert-base-cased --no_env_print --no_memory


      Model Name             Batch Size     Seq Length     Time in s   

         gpt2                    8               8             0.005     
         gpt2                    8               32            0.007     
         gpt2                    8              128            0.029     
         gpt2                    8              512            0.125     
   bert-base-cased               8               8             0.005     
   bert-base-cased               8               32            0.006     
   bert-base-cased               8              128            0.024     
   bert-base-cased               8              512            0.114     

@patrickvonplaten (Contributor, Author):

TF XLA

python run_benchmark_tf.py --models gpt2 bert-base-cased --no_env_print --no_memory --use_xla


      Model Name             Batch Size     Seq Length     Time in s   

         gpt2                    8               8             0.002     
         gpt2                    8               32            0.006     
         gpt2                    8              128            0.021     
         gpt2                    8              512            0.095     
   bert-base-cased               8               8             0.003     
   bert-base-cased               8               32            0.005     
   bert-base-cased               8              128            0.019     
   bert-base-cased               8              512            0.087     


patrickvonplaten commented Jun 19, 2020

Memory measurements

They also seem reasonable for the forward pass.

TF no eager mode (keeping in mind that nvidia-smi is not accurate here and TF always allocates more than it needs):

python run_benchmark_tf.py --models gpt2 bert-base-cased --no_env_print --no_speed


      Model Name             Batch Size     Seq Length    Memory in MB 

         gpt2                    64              8              1704     
         gpt2                    64              32             1704     
         gpt2                    64             128             2728     
         gpt2                    64             512             8872     
   bert-base-cased               64              8              1192     
   bert-base-cased               64              32             1192     
   bert-base-cased               64             128             1704     
   bert-base-cased               64             512             4776     

PyTorch

python run_benchmark.py --models gpt2 bert-base-cased --no_env_print --no_speed


      Model Name             Batch Size     Seq Length    Memory in MB 

         gpt2                    64              8              1150     
         gpt2                    64              32             1384     
         gpt2                    64             128             2290     
         gpt2                    64             512             5890     
   bert-base-cased               64              8              1016     
   bert-base-cased               64              32             1104     
   bert-base-cased               64             128             1448     
   bert-base-cased               64             512             3224     

PyTorch FP16

python run_benchmark.py --models gpt2 bert-base-cased --no_env_print --no_speed --fp16


      Model Name             Batch Size     Seq Length    Memory in MB 

         gpt2                    64              8              1170     
         gpt2                    64              32             1164     
         gpt2                    64             128             1596     
         gpt2                    64             512             3420     
   bert-base-cased               64              8              1066     
   bert-base-cased               64              32             1060     
   bert-base-cased               64             128             1108     
   bert-base-cased               64             512             2118     

@LysandreJik (Member) left a comment:

Great!! Can't wait for the future PRs, this is an exciting subject!

@patrickvonplaten patrickvonplaten merged commit fa0be6d into huggingface:master Jun 22, 2020