From f214a45a35fbe77d1f89c162f17ee8a488770325 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 18 Jun 2020 16:16:36 +0200
Subject: [PATCH 1/8] finish benchmark

---
 examples/benchmarking/run_benchmark_tf.py     |  29 ++
 examples/longform-qa/eli5_app.py              |   2 +-
 examples/longform-qa/eli5_utils.py            |   2 +-
 src/transformers/__init__.py                  |  10 +-
 src/transformers/benchmark/__init__.py        |  10 -
 src/transformers/benchmark/benchmark.py       | 347 ++++++++----------
 src/transformers/benchmark/benchmark_args.py  |  23 +-
 .../benchmark/benchmark_args_tf.py            | 106 ++++++
 .../benchmark/benchmark_args_utils.py         |  23 ++
 src/transformers/benchmark/benchmark_tf.py    | 216 +++++++++++
 src/transformers/benchmark/benchmark_utils.py | 279 ++++++++------
 src/transformers/file_utils.py                |  37 ++
 src/transformers/trainer.py                   |  13 +-
 tests/test_benchmark.py                       |  87 ++++-
 tests/test_benchmark_tf.py                    | 165 +++++++++
 15 files changed, 988 insertions(+), 361 deletions(-)
 create mode 100644 examples/benchmarking/run_benchmark_tf.py
 create mode 100644 src/transformers/benchmark/benchmark_args_tf.py
 create mode 100644 src/transformers/benchmark/benchmark_tf.py
 create mode 100644 tests/test_benchmark_tf.py

diff --git a/examples/benchmarking/run_benchmark_tf.py b/examples/benchmarking/run_benchmark_tf.py
new file mode 100644
index 00000000000000..5d578e0f8c7ffc
--- /dev/null
+++ b/examples/benchmarking/run_benchmark_tf.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Benchmarking the library on inference and training in Tensorflow""" + +from transformers import HfArgumentParser, TensorflowBenchmark, TensorflowBenchmarkArguments + + +def main(): + parser = HfArgumentParser(TensorflowBenchmarkArguments) + benchmark_args = parser.parse_args_into_dataclasses()[0] + benchmark = TensorflowBenchmark(args=benchmark_args) + benchmark.run() + + +if __name__ == "__main__": + main() diff --git a/examples/longform-qa/eli5_app.py b/examples/longform-qa/eli5_app.py index e79f1d6ed14fae..8826cd2fbae48e 100644 --- a/examples/longform-qa/eli5_app.py +++ b/examples/longform-qa/eli5_app.py @@ -1,8 +1,8 @@ +import nlp import numpy as np import torch import faiss -import nlp import streamlit as st import transformers from elasticsearch import Elasticsearch diff --git a/examples/longform-qa/eli5_utils.py b/examples/longform-qa/eli5_utils.py index 0298625cdc792b..c3a1da26d96432 100644 --- a/examples/longform-qa/eli5_utils.py +++ b/examples/longform-qa/eli5_utils.py @@ -4,6 +4,7 @@ from random import choice, randint from time import time +import nlp # noqa: F401 import numpy as np import pandas as pd import torch @@ -12,7 +13,6 @@ from tqdm import tqdm import faiss # noqa: F401 -import nlp # noqa: F401 from elasticsearch import Elasticsearch # noqa: F401 from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401 from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f0181a0860f887..730f099e81551f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -77,6 +77,9 @@ add_end_docstrings, add_start_docstrings, cached_path, + is_apex_available, + is_psutil_available, + is_py3nvml_available, is_tf_available, is_torch_available, is_torch_tpu_available, @@ -380,7 +383,8 @@ from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments # Benchmarks - from .benchmark import PyTorchBenchmark, PyTorchBenchmarkArguments + from .benchmark.benchmark import PyTorchBenchmark + from .benchmark.benchmark_args import PyTorchBenchmarkArguments # TensorFlow if is_tf_available(): @@ -576,6 +580,10 @@ # Trainer from .trainer_tf import TFTrainer + # Benchmarks + from .benchmark.benchmark_tf import TensorflowBenchmark + from .benchmark.benchmark_args_tf import TensorflowBenchmarkArguments + if not is_tf_available() and not is_torch_available(): logger.warning( diff --git a/src/transformers/benchmark/__init__.py b/src/transformers/benchmark/__init__.py index 5eae4b2cb36783..e69de29bb2d1d6 100644 --- a/src/transformers/benchmark/__init__.py +++ b/src/transformers/benchmark/__init__.py @@ -1,10 +0,0 @@ -# flake8: noqa -# There's no way to ignore "F401 '...' imported but unused" warnings in this -# module, but to preserve other warnings. So, don't check this module at all. 
- -from ..file_utils import is_torch_available - - -if is_torch_available(): - from .benchmark_args import PyTorchBenchmarkArguments - from .benchmark import PyTorchBenchmark diff --git a/src/transformers/benchmark/benchmark.py b/src/transformers/benchmark/benchmark.py index 63db8272364815..2136cdf7ab1a38 100644 --- a/src/transformers/benchmark/benchmark.py +++ b/src/transformers/benchmark/benchmark.py @@ -25,8 +25,8 @@ MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, PretrainedConfig, + is_py3nvml_available, is_torch_available, - is_torch_tpu_available, ) from .benchmark_utils import Benchmark, Memory, measure_peak_memory_cpu, start_memory_tracing, stop_memory_tracing @@ -37,6 +37,10 @@ from .benchmark_args import PyTorchBenchmarkArguments +if is_py3nvml_available(): + import py3nvml.py3nvml as nvml + + logger = logging.getLogger(__name__) @@ -50,220 +54,167 @@ class PyTorchBenchmark(Benchmark): def framework_version(self): return torch.__version__ - def train(self, model_name, batch_size, sequence_length, trace_memory=False): - try: - config = self.config_dict[model_name] + def _inference_speed(self, model_name, batch_size, sequence_length): + _inference = self._prepare_inference_func(model_name, batch_size, sequence_length) + return self._measure_speed(_inference) - if self.args.torchscript: - config.torchscript = True + def _inference_memory(self, model_name, batch_size, sequence_length): + _inference = self._prepare_inference_func(model_name, batch_size, sequence_length) + return self._measure_memory(_inference) - model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config) - model.to(self.args.device) - model.train() + def _train_speed(self, model_name, batch_size, sequence_length): + _train = self._prepare_train_func(model_name, batch_size, sequence_length) + return self._measure_speed(_train) - # encoder-decoder has vocab size saved differently - vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size - input_ids = torch.randint( - vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device - ) + def _train_memory(self, model_name, batch_size, sequence_length): + _train = self._prepare_train_func(model_name, batch_size, sequence_length) + return self._measure_memory(_train) - if self.args.torchscript: - raise NotImplementedError("Training for torchscript is currently not implemented") - else: - train_model = model - - def compute_loss_and_backprob_encoder(): - loss = train_model(input_ids, labels=input_ids)[0] - loss.backward() - train_model.zero_grad() - - def compute_loss_and_backprob_encoder_decoder(): - loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0] - loss.backward() - train_model.zero_grad() - - _train = ( - compute_loss_and_backprob_encoder_decoder - if config.is_encoder_decoder - else compute_loss_and_backprob_encoder - ) - - if trace_memory is True: - if self.args.trace_memory_line_by_line: - trace = start_memory_tracing("transformers") - - if self.args.n_gpu > 0: - # gpu - # clear gpu cache - torch.cuda.empty_cache() - if hasattr(torch.cuda, "max_memory_reserved"): - torch.cuda.reset_peak_memory_stats() - else: - logger.info( - "Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage" - ) - torch.cuda.reset_max_memory_cached() - - # calculate loss and do backpropagation - _train() - elif not self.args.no_tpu and is_torch_tpu_available(): - # tpu - raise NotImplementedError( - "Memory Benchmarking is currently not implemented for TPU. 
Please disable memory benchmarking with `args.no_memory=True`" - ) - else: - # cpu - memory_bytes = measure_peak_memory_cpu(_train) - memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes + def _prepare_inference_func(self, model_name, batch_size, sequence_length): + config = self.config_dict[model_name] - if self.args.trace_memory_line_by_line: - summary = stop_memory_tracing(trace) - else: - summary = None - - if self.args.n_gpu > 0: - # gpu - if hasattr(torch.cuda, "max_memory_reserved"): - memory = Memory(torch.cuda.max_memory_reserved()) - else: - logger.info( - "Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage" - ) - memory = Memory(torch.cuda.max_memory_reserved()) - - return memory, summary - else: - if (not self.args.no_tpu and is_torch_tpu_available()) or self.args.torchscript: - # run additional 10 times to stabilize compilation for tpu and torchscript - logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation") - timeit.repeat( - _train, repeat=1, number=5, - ) + if self.args.torchscript: + config.torchscript = True + if self.args.with_lm_head: + model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config) + else: + model = MODEL_MAPPING[config.__class__](config) + + model.eval() + model.to(self.args.device) + + # encoder-decoder has vocab size saved differently + vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size + input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device) + + if self.args.fp16: + logger.info("Running training in Mixed Precision...") + assert self.args.is_gpu, "Mixed precision is possible only for GPU." + # amp seems to have memory leaks so that memory usage + # is measured using .half() for now https://github.com/NVIDIA/apex/issues/439 + model.half() + + if self.args.torchscript: + with torch.no_grad(): + inference_model = torch.jit.trace(model, input_ids) + else: + inference_model = model + + def encoder_decoder_forward(): + with torch.no_grad(): + inference_model(input_ids, decoder_input_ids=input_ids) + + def encoder_forward(): + with torch.no_grad(): + inference_model(input_ids) + + _forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward + return _forward + + def _prepare_train_func(self, model_name, batch_size, sequence_length): + config = self.config_dict[model_name] + model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config) + + if self.args.torchscript: + raise NotImplementedError("Training for torchscript is currently not implemented") + else: + train_model = model + + model.eval() + model.to(self.args.device) + + # encoder-decoder has vocab size saved differently + vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size + input_ids = torch.randint(vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device) + + if self.args.fp16: + logger.info("Running training in Mixed Precision...") + assert self.args.is_gpu, "Mixed precision is possible only for GPU." 
+ + # amp seems to have memory leaks so that memory usage + # is measured using .half() for now https://github.com/NVIDIA/apex/issues/439 + model.half() + + def compute_loss_and_backprob_encoder(): + loss = train_model(input_ids, labels=input_ids)[0] + loss.backward() + train_model.zero_grad() + + def compute_loss_and_backprob_encoder_decoder(): + loss = train_model(input_ids, decoder_input_ids=input_ids, labels=input_ids)[0] + loss.backward() + train_model.zero_grad() + + _train = ( + compute_loss_and_backprob_encoder_decoder + if config.is_encoder_decoder + else compute_loss_and_backprob_encoder + ) + return _train + + def _measure_speed(self, func): + try: + if self.args.is_tpu or self.args.torchscript: + # run additional 10 times to stabilize compilation for tpu and torchscript + logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation") + timeit.repeat( + func, repeat=1, number=5, + ) - # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average - runtimes = timeit.repeat(_train, repeat=self.args.repeat, number=10,) + # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average + runtimes = timeit.repeat(func, repeat=self.args.repeat, number=10,) - if not self.args.no_tpu and is_torch_tpu_available() and self.args.tpu_print_metrics: - import torch_xla.debug.metrics as met + if self.args.is_tpu and self.args.torch_xla_tpu_print_metrics: + import torch_xla.debug.metrics as met - self.print_fn(met.metrics_report()) + self.print_fn(met.metrics_report()) - return min(runtimes) / 10.0 + return min(runtimes) / 10.0 except RuntimeError as e: self.print_fn("Doesn't fit on GPU. {}".format(e)) - if trace_memory: - return "N/A", None - else: - return "N/A" + return "N/A" - def inference(self, model_name, batch_size, sequence_length, trace_memory=False): + def _measure_memory(self, func): try: - config = self.config_dict[model_name] - model = None - - if self.args.torchscript: - config.torchscript = True - - if self.args.with_lm_head: - model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config) - else: - model = MODEL_MAPPING[config.__class__](config) - - model.eval() - model.to(self.args.device) - - # encoder-decoder has vocab size saved differently - vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size - - input_ids = torch.randint( - vocab_size, (batch_size, sequence_length), dtype=torch.long, device=self.args.device - ) - - if self.args.torchscript: - with torch.no_grad(): - if config.is_encoder_decoder: - raise NotImplementedError("Torchscript is currently not supported for EncoderDecoder models") - else: - inference_model = torch.jit.trace(model, input_ids) - else: - inference_model = model - - def encoder_decoder_forward(): - with torch.no_grad(): - inference_model(input_ids, decoder_input_ids=input_ids) - - def encoder_forward(): - with torch.no_grad(): - inference_model(input_ids) - - _forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward - - if trace_memory is True: - if self.args.trace_memory_line_by_line: - trace = start_memory_tracing("transformers") - - if self.args.n_gpu > 0: - # gpu - # clear gpu cache - torch.cuda.empty_cache() - if hasattr(torch.cuda, "max_memory_reserved"): - torch.cuda.reset_peak_memory_stats() - else: - logger.info( - "Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage" - ) - 
torch.cuda.reset_max_memory_cached() - - # run forward - _forward() - elif not self.args.no_tpu and is_torch_tpu_available(): - # tpu - raise NotImplementedError( - "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`" + if self.args.trace_memory_line_by_line: + trace = start_memory_tracing("transformers") + + if self.args.is_tpu: + # tpu + raise NotImplementedError( + "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`" + ) + elif self.args.is_gpu: + if not is_py3nvml_available(): + logger.warning( + "py3nvml not installed, we won't log GPU memory usage. " + "Install py3nvml (pip install py3nvml) to log information about GPU." ) + memory = "N/A" else: - # cpu - memory_bytes = measure_peak_memory_cpu(_forward) - memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes - - if self.args.trace_memory_line_by_line: - summary = stop_memory_tracing(trace) - else: - summary = None - - if self.args.n_gpu > 0: - # gpu - if hasattr(torch.cuda, "max_memory_reserved"): - memory = Memory(torch.cuda.max_memory_reserved()) - else: - logger.info( - "Please consider updating PyTorch to version 1.4 to get more accuracy on GPU memory usage" - ) - memory = Memory(torch.cuda.max_memory_cached()) - - return memory, summary - else: - - if (not self.args.no_tpu and is_torch_tpu_available()) or self.args.torchscript: - # run additional 10 times to stabilize compilation for tpu and torchscript - logger.info("Do inference on TPU or torchscript. Running model 5 times to stabilize compilation") - timeit.repeat( - _forward, repeat=1, number=5, + logger.info( + "Measuring total GPU usage on GPU device. Make sure to not have additional processes running on the same GPU." ) + # init nvml + nvml.nvmlInit() + func() + handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx) + meminfo = nvml.nvmlDeviceGetMemoryInfo(handle) + max_bytes_in_use = meminfo.used + memory = Memory(max_bytes_in_use) + # shutdown nvml + nvml.nvmlShutdown() + else: + # cpu + memory_bytes = measure_peak_memory_cpu(func) + memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes - # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average - runtimes = timeit.repeat(_forward, repeat=self.args.repeat, number=10,) - - if not self.args.no_tpu and is_torch_tpu_available() and self.args.tpu_print_metrics: - import torch_xla.debug.metrics as met - - self.print_fn(met.metrics_report()) - - return min(runtimes) / 10.0 + if self.args.trace_memory_line_by_line: + summary = stop_memory_tracing(trace) + else: + summary = None + return memory, summary except RuntimeError as e: self.print_fn("Doesn't fit on GPU. 
{}".format(e)) - if trace_memory: - return "N/A", None - else: - return "N/A" + return "N/A", None diff --git a/src/transformers/benchmark/benchmark_args.py b/src/transformers/benchmark/benchmark_args.py index 0cc043537b5ceb..0ecac83adf946d 100644 --- a/src/transformers/benchmark/benchmark_args.py +++ b/src/transformers/benchmark/benchmark_args.py @@ -34,11 +34,17 @@ @dataclass class PyTorchBenchmarkArguments(BenchmarkArguments): - no_cuda: bool = field(default=False, metadata={"help": "Whether to run on available cuda devices"}) torchscript: bool = field(default=False, metadata={"help": "Trace the models using torchscript"}) - no_tpu: bool = field(default=False, metadata={"help": "Whether to run on available tpu devices"}) - fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."}) - tpu_print_metrics: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."}) + torch_xla_tpu_print_metrics: bool = field(default=False, metadata={"help": "Print Xla/PyTorch tpu metrics"}) + fp16_opt_level: str = field( + default="O1", + metadata={ + "help": ( + "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." + "See details at https://nvidia.github.io/apex/amp.html" + ) + }, + ) @cached_property @torch_required @@ -55,9 +61,14 @@ def _setup_devices(self) -> Tuple["torch.device", int]: n_gpu = torch.cuda.device_count() return device, n_gpu + @property + def is_tpu(self): + return is_torch_tpu_available() and not self.no_tpu + @property @torch_required def device_idx(self) -> int: + # TODO(PVP): currently only single GPU is supported return torch.cuda.current_device() @property @@ -69,3 +80,7 @@ def device(self) -> "torch.device": @torch_required def n_gpu(self): return self._setup_devices[1] + + @property + def is_gpu(self): + return self.n_gpu > 0 diff --git a/src/transformers/benchmark/benchmark_args_tf.py b/src/transformers/benchmark/benchmark_args_tf.py new file mode 100644 index 00000000000000..881b581ee85d8d --- /dev/null +++ b/src/transformers/benchmark/benchmark_args_tf.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass, field +from typing import Tuple + +from ..file_utils import cached_property, is_tf_available, tf_required +from .benchmark_args_utils import BenchmarkArguments + + +if is_tf_available(): + import tensorflow as tf + + +logger = logging.getLogger(__name__) + + +@dataclass +class TensorflowBenchmarkArguments(BenchmarkArguments): + tpu_name: str = field( + default=None, metadata={"help": "Name of TPU"}, + ) + device_idx: int = field( + default=0, metadata={"help": "CPU / GPU device index. 
Defaults to 0."}, + ) + eager_mode: bool = field(default=False, metadata={"help": "Benchmark models in eager model."}) + use_xla: bool = field( + default=False, + metadata={ + "help": "Benchmark models using XLA JIT compilation. Note that `eager_model` has to be set to `False`." + }, + ) + + @cached_property + @tf_required + def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]: + logger.info("Tensorflow: setting up strategy") + + if not self.no_tpu: + try: + if self.tpu_name: + tpu = tf.distribute.cluster_resolver.TPUClusterResolver(self.tpu_name) + else: + tpu = tf.distribute.cluster_resolver.TPUClusterResolver() + except ValueError: + tpu = None + else: + tpu = None + + if tpu is not None: + tf.config.experimental_connect_to_cluster(tpu) + tf.tpu.experimental.initialize_tpu_system(tpu) + + strategy = tf.distribute.experimental.TPUStrategy(tpu) + else: + # currently no multi gpu is allowed + if self.is_gpu: + # TODO: Currently only single GPU is supported + tf.config.experimental.set_visible_devices(self.gpu_list[self.device_idx], "GPU") + strategy = tf.distribute.OneDeviceStrategy(device=f"/gpu:{self.device_idx}") + else: + tf.config.experimental.set_visible_devices([], "GPU") # disable GPU + strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}") + + return strategy, tpu + + @property + @tf_required + def is_tpu(self) -> bool: + tpu = self._setup_strategy[1] + return tpu is not None and not self.no_tpu + + @property + @tf_required + def strategy(self) -> "tf.distribute.Strategy": + return self._setup_strategy[0] + + @property + @tf_required + def gpu_list(self): + return tf.config.list_physical_devices("GPU") + + @property + @tf_required + def n_gpu(self) -> int: + if not self.no_cuda: + return len(self.gpu_list) + return 0 + + @property + def is_gpu(self) -> bool: + return self.n_gpu > 0 diff --git a/src/transformers/benchmark/benchmark_args_utils.py b/src/transformers/benchmark/benchmark_args_utils.py index ac76c37eb1f8c8..7962c68288cc5f 100644 --- a/src/transformers/benchmark/benchmark_args_utils.py +++ b/src/transformers/benchmark/benchmark_args_utils.py @@ -16,11 +16,15 @@ import dataclasses import json +import logging from dataclasses import dataclass, field from time import time from typing import List +logger = logging.getLogger(__name__) + + def list_field(default=None, metadata=None): return field(default_factory=lambda: default, metadata=metadata) @@ -53,6 +57,9 @@ class BenchmarkArguments: ) no_inference: bool = field(default=False, metadata={"help": "Don't benchmark inference of model"}) + no_cuda: bool = field(default=False, metadata={"help": "Whether to run on available cuda devices"}) + no_tpu: bool = field(default=False, metadata={"help": "Whether to run on available tpu devices"}) + fp16: bool = field(default=False, metadata={"help": "Use FP16 to accelerate inference."}) training: bool = field(default=False, metadata={"help": "Benchmark training of model"}) verbose: bool = field(default=False, metadata={"help": "Verbose memory tracing"}) no_speed: bool = field(default=False, metadata={"help": "Don't perform speed measurments"}) @@ -61,6 +68,12 @@ class BenchmarkArguments: save_to_csv: bool = field(default=False, metadata={"help": "Save result to a CSV file"}) log_print: bool = field(default=False, metadata={"help": "Save all print statements in a log file"}) no_env_print: bool = field(default=False, metadata={"help": "Don't print environment information"}) + no_multi_process: bool = 
field( + default=False, + metadata={ + "help": "Don't use multiprocessing for memory and speed measurement. It is highly recommended to use multiprocessing for accurate CPU and GPU memory measurements. This option should only be used for debugging / testing and on TPU." + }, + ) with_lm_head: bool = field( default=False, metadata={ @@ -102,3 +115,13 @@ def to_json_string(self): @property def model_names(self): return self.models + + @property + def do_multi_processing(self): + if self.no_multi_process: + return False + elif self.is_tpu: + logger.info("Multiprocessing is currently not possible on TPU.") + return False + else: + return True diff --git a/src/transformers/benchmark/benchmark_tf.py b/src/transformers/benchmark/benchmark_tf.py new file mode 100644 index 00000000000000..5bb34c25f0fe7a --- /dev/null +++ b/src/transformers/benchmark/benchmark_tf.py @@ -0,0 +1,216 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + Benchmarking the library on inference and training in PyTorch. +""" + + +import logging +import random +import timeit +from functools import wraps + +from transformers import ( + TF_MODEL_MAPPING, + TF_MODEL_WITH_LM_HEAD_MAPPING, + PretrainedConfig, + is_py3nvml_available, + is_tf_available, +) + +from .benchmark_utils import Benchmark, Memory, measure_peak_memory_cpu, start_memory_tracing, stop_memory_tracing + + +if is_tf_available(): + import tensorflow as tf + from .benchmark_args_tf import TensorflowBenchmarkArguments + from tensorflow.python.framework.errors_impl import ResourceExhaustedError + +if is_py3nvml_available(): + import py3nvml.py3nvml as nvml + +logger = logging.getLogger(__name__) + + +def run_with_tf_optimizations(do_eager_mode, do_xla): + def run_func(func): + @wraps(func) + def run_in_eager_mode(*args, **kwargs): + return func(*args, **kwargs) + + @wraps(func) + @tf.function(experimental_compile=do_xla) + def run_in_graph_mode(*args, **kwargs): + return func(*args, **kwargs) + + if do_eager_mode is True: + assert ( + do_xla is False + ), "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`." + return run_in_eager_mode + else: + return run_in_graph_mode + + return run_func + + +def random_input_ids(batch_size, sequence_length, vocab_size): + rng = random.Random() + values = [rng.randint(0, vocab_size - 1) for i in range(batch_size * sequence_length)] + return tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32) + + +class TensorflowBenchmark(Benchmark): + + args: TensorflowBenchmarkArguments + configs: PretrainedConfig + framework: str = "Tensorflow" + + @property + def framework_version(self): + return tf.__version__ + + def _inference_speed(self, model_name, batch_size, sequence_length): + # initialize GPU on separate process + strategy = self.args.strategy + assert strategy is not None, "A device strategy has to be initialized before using Tensorflow." 
+ _inference = self._prepare_inference_func(model_name, batch_size, sequence_length) + return self._measure_speed(_inference) + + def _train_speed(self, model_name, batch_size, sequence_length): + raise NotImplementedError( + "Training is currently not really implemented." "Wait for TFTrainer to support CLM and MLM." + ) + + def _inference_memory(self, model_name, batch_size, sequence_length): + # initialize GPU on separate process + if self.args.is_gpu: + tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True) + strategy = self.args.strategy + assert strategy is not None, "A device strategy has to be initialized before using Tensorflow." + _inference = self._prepare_inference_func(model_name, batch_size, sequence_length) + return self._measure_memory(_inference) + + def _train_memory(self, model_name, batch_size, sequence_length): + raise NotImplementedError( + "Training is currently not really implemented. Wait for TFTrainer to support CLM and MLM." + ) + + def _prepare_inference_func(self, model_name, batch_size, sequence_length): + config = self.config_dict[model_name] + + if self.args.fp16: + raise NotImplementedError("Mixed precision is currently not supported.") + + if self.args.with_lm_head: + model = TF_MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config) + else: + model = TF_MODEL_MAPPING[config.__class__](config) + + # encoder-decoder has vocab size saved differently + vocab_size = config.vocab_size if hasattr(config, "vocab_size") else config.encoder.vocab_size + input_ids = random_input_ids(batch_size, sequence_length, vocab_size) + + @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla) + def encoder_decoder_forward(): + model(input_ids, decoder_input_ids=input_ids, training=False) + + @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla) + def encoder_forward(): + model(input_ids, training=False) + + _inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward + + return _inference + + def _measure_speed(self, func): + with self.args.strategy.scope(): + try: + if self.args.is_tpu or self.args.use_xla: + # run additional 10 times to stabilize compilation for tpu + logger.info("Do inference on TPU. Running model 5 times to stabilize compilation") + timeit.repeat(func, repeat=1, number=5) + + # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average + runtimes = timeit.repeat(func, repeat=self.args.repeat, number=10,) + + return min(runtimes) / 10.0 + except ResourceExhaustedError as e: + self.print_fn("Doesn't fit on GPU. {}".format(e)) + + def _measure_memory(self, func): + logger.info( + "Note that Tensorflow allocates more memory than" + "it might need to speed up computation." + "The memory reported here corresponds to the memory" + "reported by `nvidia-smi`, which can vary depending" + "on total available memory on the GPU that is used." + ) + with self.args.strategy.scope(): + try: + if self.args.trace_memory_line_by_line: + assert ( + self.args.eager_mode + ), "`args.eager_mode` is set to `False`. Make sure to run model in eager mode to measure memory consumption line by line." + trace = start_memory_tracing("transformers") + + if self.args.is_tpu: + # tpu + raise NotImplementedError( + "Memory Benchmarking is currently not implemented for TPU. 
Please disable memory benchmarking with `args.no_memory=True`" + ) + elif self.args.is_gpu: + # gpu + if not is_py3nvml_available(): + logger.warning( + "py3nvml not installed, we won't log GPU memory usage. " + "Install py3nvml (pip install py3nvml) to log information about GPU." + ) + memory = "N/A" + else: + logger.info( + "Measuring total GPU usage on GPU device. Make sure to not have additional processes running on the same GPU." + ) + # init nvml + nvml.nvmlInit() + func() + handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx) + meminfo = nvml.nvmlDeviceGetMemoryInfo(handle) + max_bytes_in_use = meminfo.used + memory = Memory(max_bytes_in_use) + # shutdown nvml + nvml.nvmlShutdown() + else: + # cpu + if self.args.trace_memory_line_by_line: + logger.info( + "When enabling line by line tracing, the max peak memory for CPU is inaccurate in Tensorflow." + ) + memory = None + else: + memory_bytes = measure_peak_memory_cpu(func) + memory = Memory(memory_bytes) if isinstance(memory_bytes, int) else memory_bytes + if self.args.trace_memory_line_by_line: + summary = stop_memory_tracing(trace) + if memory is None: + memory = summary.total + else: + summary = None + + return memory, summary + except ResourceExhaustedError as e: + self.print_fn("Doesn't fit on GPU. {}".format(e)) + return "N/A", None diff --git a/src/transformers/benchmark/benchmark_utils.py b/src/transformers/benchmark/benchmark_utils.py index 5b7cb438532a07..166adeef39208a 100644 --- a/src/transformers/benchmark/benchmark_utils.py +++ b/src/transformers/benchmark/benchmark_utils.py @@ -14,14 +14,14 @@ from abc import ABC, abstractmethod from collections import defaultdict, namedtuple from datetime import datetime -from multiprocessing import Pipe, Process +from multiprocessing import Pipe, Process, Queue from multiprocessing.connection import Connection from typing import Callable, Iterable, List, NamedTuple, Optional, Union from transformers import AutoConfig, PretrainedConfig from transformers import __version__ as version -from ..file_utils import is_tf_available, is_torch_available, is_torch_tpu_available +from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available from .benchmark_args_utils import BenchmarkArguments @@ -31,6 +31,11 @@ if is_tf_available(): from tensorflow.python.eager import context as tf_context +if is_psutil_available(): + import psutil + +if is_py3nvml_available(): + import py3nvml.py3nvml as nvml if platform.system() == "Windows": from signal import CTRL_C_EVENT as SIGKILL @@ -56,6 +61,33 @@ ) +def separate_process_wrapper_fn(func, do_multi_processing): + def multi_process_func(*args, **kwargs): + # run function in an individual + # process to get correct memory + def wrapper_func(queue, *args): + try: + result = func(*args) + except Exception as e: + logger.error(e) + print(e) + result = "N/A" + queue.put(result) + + queue = Queue() + p = Process(target=wrapper_func, args=[queue] + list(args)) + p.start() + result = queue.get() + p.join() + return result + + if do_multi_processing: + logging.info("Use multiprocessing...") + return multi_process_func + else: + return func + + def is_memory_tracing_enabled(): global _is_memory_tracing_enabled return _is_memory_tracing_enabled @@ -136,7 +168,7 @@ class MemorySummary(NamedTuple): MemoryTrace = List[UsedMemoryState] -def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5) -> int: +def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5, device_idx=None) -> int: """ 
measures peak cpu memory consumption of a given `function` running the function for at least interval seconds @@ -148,16 +180,38 @@ def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5) -> int: - `function`: (`callable`): function() -> ... function without any arguments to measure for which to measure the peak memory - - `interval`: (`float`) + - `interval`: (`float`, `optional`, defaults to `0.5`) interval in second for which to measure the memory usage + - `device_idx`: (`int`, `optional`, defaults to `None`) + device id for which to measure gpu usage + Returns: - `max_memory`: (`int`) cosumed memory peak in Bytes """ - try: - import psutil - except (ImportError): + + def get_cpu_memory(process_id: int) -> int: + """ + measures current cpu memory usage of a given `process_id` + + Args: + - `process_id`: (`int`) + process_id for which to measure memory + + Returns + - `memory`: (`int`) + cosumed memory in Bytes + """ + process = psutil.Process(process_id) + try: + meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info" + memory = getattr(process, meminfo_attr)()[0] + except psutil.AccessDenied: + raise ValueError("Error with Psutil.") + return memory + + if not is_psutil_available(): logger.warning( "Psutil not installed, we won't log CPU memory usage. " "Install Psutil (pip install psutil) to use CPU memory tracing." @@ -165,26 +219,6 @@ def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5) -> int: max_memory = "N/A" else: - def _get_memory(process_id: int) -> int: - """ - measures current cpu memory usage of a given `process_id` - - Args: - - `process_id`: (`int`) - process_id for which to measure memory - - Returns - - `memory`: (`int`) - cosumed memory in Bytes - """ - process = psutil.Process(process_id) - try: - meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info" - memory = getattr(process, meminfo_attr)()[0] - except psutil.AccessDenied: - raise ValueError("Error with Psutil.") - return memory - class MemoryMeasureProcess(Process): """ @@ -198,13 +232,13 @@ def __init__(self, process_id: int, child_connection: Connection, interval: floa self.interval = interval self.connection = child_connection self.num_measurements = 1 - self.mem_usage = _get_memory(process_id) + self.mem_usage = get_cpu_memory(self.process_id) def run(self): self.connection.send(0) stop = False while True: - self.mem_usage = max(self.mem_usage, _get_memory(self.process_id)) + self.mem_usage = max(self.mem_usage, get_cpu_memory(self.process_id)) self.num_measurements += 1 if stop: @@ -296,34 +330,31 @@ def start_memory_tracing( - 'line_text' (string): Text of the line in the python script """ - try: - import psutil - except (ImportError): + if is_psutil_available(): + process = psutil.Process(os.getpid()) + else: logger.warning( "Psutil not installed, we won't log CPU memory usage. " "Install psutil (pip install psutil) to use CPU memory tracing." ) process = None - else: - process = psutil.Process(os.getpid()) - try: - from py3nvml import py3nvml - - py3nvml.nvmlInit() - devices = list(range(py3nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace - py3nvml.nvmlShutdown() - except ImportError: + if is_py3nvml_available(): + try: + nvml.nvmlInit() + devices = list(range(nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace + nvml.nvmlShutdown() + except (OSError, nvml.NVMLError): + logger.warning("Error while initializing comunication with GPU. 
" "We won't perform GPU memory tracing.") + log_gpu = False + else: + log_gpu = is_torch_available() or is_tf_available() + else: logger.warning( "py3nvml not installed, we won't log GPU memory usage. " "Install py3nvml (pip install py3nvml) to use GPU memory tracing." ) log_gpu = False - except (OSError, py3nvml.NVMLError): - logger.warning("Error while initializing comunication with GPU. " "We won't perform GPU memory tracing.") - log_gpu = False - else: - log_gpu = is_torch_available() or is_tf_available() memory_trace = [] @@ -385,14 +416,14 @@ def traceit(frame, event, args): tf_context.context()._clear_caches() # See https://github.com/tensorflow/tensorflow/issues/20218#issuecomment-416771802 # Sum used memory for all GPUs - py3nvml.nvmlInit() + nvml.nvmlInit() for i in devices: - handle = py3nvml.nvmlDeviceGetHandleByIndex(i) - meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle) + handle = nvml.nvmlDeviceGetHandleByIndex(i) + meminfo = nvml.nvmlDeviceGetMemoryInfo(handle) gpu_mem += meminfo.used - py3nvml.nvmlShutdown() + nvml.nvmlShutdown() mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem) memory_trace.append(mem_state) @@ -522,7 +553,6 @@ class Benchmark(ABC): def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig = None): self.args = args - if configs is None: self.config_dict = { model_name: AutoConfig.from_pretrained(model_name) for model_name in self.args.model_names @@ -530,6 +560,11 @@ def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig = else: self.config_dict = {model_name: config for model_name, config in zip(self.args.model_names, configs)} + if not self.args.no_memory and os.getenv("TRANSFORMERS_USE_MULTIPROCESSING") == 0: + logger.warning( + "Memory consumption will not be measured accurately if `args.no_multi_process` is set to `True.` The flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing." 
+ ) + self._print_fn = None self._framework_version = None self._environment_info = None @@ -541,7 +576,7 @@ def print_fn(self): def print_and_log(*args): with open(self.args.log_filename, "a") as log_file: - log_file.write(str(*args) + "\n") + log_file.write("".join(args) + "\n") print(*args) self._print_fn = print_and_log @@ -550,26 +585,38 @@ def print_and_log(*args): return self._print_fn @property - def is_gpu(self): - return self.args.n_gpu > 0 + @abstractmethod + def framework_version(self): + pass - @property - def is_tpu(self): - return is_torch_tpu_available() and not self.args.no_tpu + @abstractmethod + def _inference_speed(self, model_name, batch_size, sequence_length): + pass - @property @abstractmethod - def framework_version(self): + def _train_speed(self, model_name, batch_size, sequence_length): pass @abstractmethod - def train(self, model_name, batch_size, sequence_length): + def _inference_memory(self, model_name, batch_size, sequence_length): pass @abstractmethod - def inference(self, model_name, batch_size, sequence_length): + def _train_memory(self, model_name, batch_size, sequence_length): pass + def inference_speed(self, *args, **kwargs): + return separate_process_wrapper_fn(self._inference_speed, self.args.do_multi_processing)(*args, **kwargs) + + def train_speed(self, *args, **kwargs): + return separate_process_wrapper_fn(self._train_speed, self.args.do_multi_processing)(*args, **kwargs) + + def inference_memory(self, *args, **kwargs): + return separate_process_wrapper_fn(self._inference_memory, self.args.do_multi_processing)(*args, **kwargs) + + def train_memory(self, *args, **kwargs): + return separate_process_wrapper_fn(self._train_memory, self.args.do_multi_processing)(*args, **kwargs) + def run(self): result_dict = {model_name: {} for model_name in self.args.model_names} inference_result_time = copy.deepcopy(result_dict) @@ -596,64 +643,60 @@ def run(self): for sequence_length in self.args.sequence_lengths: if not self.args.no_inference: if not self.args.no_memory: - memory, inference_summary = self.inference( - model_name, batch_size, sequence_length, trace_memory=True - ) + memory, inference_summary = self.inference_memory(model_name, batch_size, sequence_length) inference_result_memory[model_name]["result"][batch_size][sequence_length] = memory if not self.args.no_speed: - time = self.inference(model_name, batch_size, sequence_length, trace_memory=False) + time = self.inference_speed(model_name, batch_size, sequence_length) inference_result_time[model_name]["result"][batch_size][sequence_length] = time if self.args.training: if not self.args.no_memory: - memory, train_summary = self.train( - model_name, batch_size, sequence_length, trace_memory=True - ) + memory, train_summary = self.train_memory(model_name, batch_size, sequence_length) train_result_memory[model_name]["result"][batch_size][sequence_length] = memory if not self.args.no_speed: - time = self.inference(model_name, batch_size, sequence_length, trace_memory=False) + time = self.train_speed(model_name, batch_size, sequence_length) train_result_time[model_name]["result"][batch_size][sequence_length] = time if not self.args.no_inference: if not self.args.no_speed: - self.print_fn("======= INFERENCE - SPEED - RESULT =======") - self.print_results(inference_result_time) + self.print_fn("\n" + 20 * "=" + ("INFERENCE - SPEED - RESULT").center(40) + 20 * "=") + self.print_results(inference_result_time, type_label="Time in s") self.save_to_csv(inference_result_time, self.args.inference_time_csv_file) - 
if self.is_tpu: + if self.args.is_tpu: self.print_fn( "TPU was used for inference. Note that the time after compilation stabilized (after ~10 inferences model.forward(..) calls) was measured." ) if not self.args.no_memory: - self.print_fn("======= INFERENCE - MEMORY - RESULT =======") - self.print_results(inference_result_memory) + self.print_fn("\n" + 20 * "=" + ("INFERENCE - MEMORY - RESULT").center(40) + 20 * "=") + self.print_results(inference_result_memory, type_label="Memory in MB") self.save_to_csv(inference_result_memory, self.args.inference_memory_csv_file) if self.args.trace_memory_line_by_line: - self.print_fn("======= INFERENCE - MEMORY LINE BY LINE TRACE - SUMMARY =======") + self.print_fn("\n" + 20 * "=" + ("INFERENCE - MEMOMRY - LINE BY LINE - SUMMARY").center(40) + 20 * "=") self.print_memory_trace_statistics(inference_summary) if self.args.training: if not self.args.no_speed: - self.print_fn("======= TRAIN - SPEED - RESULT =======") - self.print_results(train_result_time) + self.print_fn("\n" + 20 * "=" + ("TRAIN - SPEED - RESULTS").center(40) + 20 * "=") + self.print_results(train_result_time, "Time in s") self.save_to_csv(train_result_time, self.args.train_time_csv_file) - if self.is_tpu: + if self.args.is_tpu: self.print_fn( "TPU was used for training. Note that the time after compilation stabilized (after ~10 train loss=model.forward(...) + loss.backward() calls) was measured." ) if not self.args.no_memory: - self.print_fn("======= TRAIN - MEMORY - RESULT =======") - self.print_results(train_result_memory) + self.print_fn("\n" + 20 * "=" + ("TRAIN - MEMORY - RESULTS").center(40) + 20 * "=") + self.print_results(train_result_memory, type_label="Memory in MB") self.save_to_csv(train_result_memory, self.args.train_memory_csv_file) if self.args.trace_memory_line_by_line: - self.print_fn("======= TRAIN - MEMORY LINE BY LINE TRACE - SUMMARY =======") + self.print_fn("\n" + 20 * "=" + ("TRAIN - MEMOMRY - LINE BY LINE - SUMMARY").center(40) + 20 * "=") self.print_memory_trace_statistics(train_summary) if not self.args.no_env_print: - self.print_fn("\n======== ENVIRONMENT - INFORMATION ========") + self.print_fn("\n" + 20 * "=" + ("ENVIRONMENT INFORMATION").center(40) + 20 * "=") self.print_fn( "\n".join(["- {}: {}".format(prop, val) for prop, val in self.environment_info.items()]) + "\n" ) @@ -681,6 +724,9 @@ def environment_info(self): info["framework"] = self.framework if self.framework == "PyTorch": info["use_torchscript"] = self.args.torchscript + if self.framework == "Tensorflow": + info["eager_mode"] = self.args.eager_mode + info["use_xla"] = self.args.use_xla info["framework_version"] = self.framework_version info["python_version"] = platform.python_version() info["system"] = platform.system() @@ -688,27 +734,30 @@ def environment_info(self): info["architecture"] = platform.architecture()[0] info["date"] = datetime.date(datetime.now()) info["time"] = datetime.time(datetime.now()) + info["fp16"] = self.args.fp16 + info["use_multiprocessing"] = self.args.do_multi_processing - try: - import psutil - except (ImportError): + if is_psutil_available(): + info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total) + else: logger.warning( "Psutil not installed, we won't log available CPU memory." "Install psutil (pip install psutil) to log available CPU memory." 
) info["cpu_ram_mb"] = "N/A" - else: - info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total) - - info["use_gpu"] = self.is_gpu - if self.is_gpu: - info["num_gpus"] = self.args.n_gpu - try: - from py3nvml import py3nvml - py3nvml.nvmlInit() - handle = py3nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx) - except ImportError: + info["use_gpu"] = self.args.is_gpu + if self.args.is_gpu: + info["num_gpus"] = 1 # TODO(PVP) Currently only single GPU is supported + if is_py3nvml_available(): + nvml.nvmlInit() + handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx) + info["gpu"] = nvml.nvmlDeviceGetName(handle) + info["gpu_ram_mb"] = bytes_to_mega_bytes(nvml.nvmlDeviceGetMemoryInfo(handle).total) + info["gpu_power_watts"] = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000 + info["gpu_performance_state"] = nvml.nvmlDeviceGetPerformanceState(handle) + nvml.nvmlShutdown() + else: logger.warning( "py3nvml not installed, we won't log GPU memory usage. " "Install py3nvml (pip install py3nvml) to log information about GPU." @@ -717,41 +766,35 @@ def environment_info(self): info["gpu_ram_mb"] = "N/A" info["gpu_power_watts"] = "N/A" info["gpu_performance_state"] = "N/A" - except (OSError, py3nvml.NVMLError): - logger.warning( - "Error while initializing comunication with GPU. " "We won't log information about GPU." - ) - info["gpu"] = "N/A" - info["gpu_ram_mb"] = "N/A" - info["gpu_power_watts"] = "N/A" - info["gpu_performance_state"] = "N/A" - py3nvml.nvmlShutdown() - else: - info["gpu"] = py3nvml.nvmlDeviceGetName(handle) - info["gpu_ram_mb"] = bytes_to_mega_bytes(py3nvml.nvmlDeviceGetMemoryInfo(handle).total) - info["gpu_power_watts"] = py3nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000 - info["gpu_performance_state"] = py3nvml.nvmlDeviceGetPerformanceState(handle) - py3nvml.nvmlShutdown() - info["use_tpu"] = self.is_tpu + info["use_tpu"] = self.args.is_tpu # TODO(PVP): See if we can add more information about TPU # see: https://github.com/pytorch/xla/issues/2180 self._environment_info = info return self._environment_info - def print_results(self, result_dict): + def print_results(self, result_dict, type_label): + self.print_fn(80 * "-") + self.print_fn( + "Model Name".center(30) + "Batch Size".center(15) + "Seq Length".center(15) + type_label.center(15) + ) + self.print_fn(80 * "-") for model_name in self.args.model_names: - self.print_fn("\t" + f"======= MODEL CHECKPOINT: {model_name} =======") for batch_size in result_dict[model_name]["bs"]: for sequence_length in result_dict[model_name]["ss"]: result = result_dict[model_name]["result"][batch_size][sequence_length] if isinstance(result, float): - self.print_fn( - f"\t\t{model_name}/{batch_size}/{sequence_length}: " f"{(round(1000 * result) / 1000)}s" - ) + result = round(1000 * result) / 1000 + result = "< 0.001" if result == 0.0 else str(result) else: - self.print_fn(f"\t\t{model_name}/{batch_size}/{sequence_length}: " f"{result} MB") + result = str(result) + self.print_fn( + model_name.center(30) + str(batch_size).center(15), + str(sequence_length).center(15), + result.center(15), + ) + self.print_fn(80 * "-") def print_memory_trace_statistics(self, summary: MemorySummary): self.print_fn( diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 433c77ae5addce..a2af66955632b4 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -81,6 +81,31 @@ _torch_tpu_available = False +try: + import psutil # noqa: F401 + + _psutil_available = True + +except 
ImportError: + _psutil_available = False + + +try: + import py3nvml # noqa: F401 + + _py3nvml_available = True + +except ImportError: + _py3nvml_available = False + + +try: + from apex import amp # noqa: F401 + + _has_apex = True +except ImportError: + _has_apex = False + default_cache_path = os.path.join(torch_cache_home, "transformers") @@ -115,6 +140,18 @@ def is_torch_tpu_available(): return _torch_tpu_available +def is_psutil_available(): + return _psutil_available + + +def is_py3nvml_available(): + return _py3nvml_available + + +def is_apex_available(): + return _has_apex + + def add_start_docstrings(*docstr): def docstring_decorator(fn): fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7777f9fa7b233f..13381dbfe71957 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -20,23 +20,16 @@ from tqdm.auto import tqdm, trange from .data.data_collator import DataCollator, default_data_collator +from .file_utils import is_apex_available, is_torch_tpu_available from .modeling_utils import PreTrainedModel from .optimization import AdamW, get_linear_schedule_with_warmup from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, TrainOutput, is_wandb_available -from .training_args import TrainingArguments, is_torch_tpu_available +from .training_args import TrainingArguments -try: +if is_apex_available(): from apex import amp - _has_apex = True -except ImportError: - _has_apex = False - - -def is_apex_available(): - return _has_apex - if is_torch_tpu_available(): import torch_xla.core.xla_model as xm diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index b891582c500ab4..bb20af47493add 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -5,7 +5,7 @@ from transformers import AutoConfig, is_torch_available -from .utils import require_torch +from .utils import require_torch, torch_device if is_torch_available(): @@ -26,7 +26,12 @@ def check_results_dict_not_empty(self, results): def test_inference_no_configs(self): MODEL_ID = "sshleifer/tiny-gpt2" benchmark_args = PyTorchBenchmarkArguments( - models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1] + models=[MODEL_ID], + training=False, + no_inference=False, + sequence_lengths=[8], + batch_sizes=[1], + no_multi_process=True, ) benchmark = PyTorchBenchmark(benchmark_args) results = benchmark.run() @@ -42,6 +47,24 @@ def test_inference_torchscript(self): torchscript=True, sequence_lengths=[8], batch_sizes=[1], + no_multi_process=True, + ) + benchmark = PyTorchBenchmark(benchmark_args) + results = benchmark.run() + self.check_results_dict_not_empty(results.time_inference_result) + self.check_results_dict_not_empty(results.memory_inference_result) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_inference_fp16(self): + MODEL_ID = "sshleifer/tiny-gpt2" + benchmark_args = PyTorchBenchmarkArguments( + models=[MODEL_ID], + training=False, + no_inference=False, + fp16=True, + sequence_lengths=[8], + batch_sizes=[1], + no_multi_process=True, ) benchmark = PyTorchBenchmark(benchmark_args) results = benchmark.run() @@ -51,7 +74,29 @@ def test_inference_torchscript(self): def test_train_no_configs(self): MODEL_ID = "sshleifer/tiny-gpt2" benchmark_args = PyTorchBenchmarkArguments( - models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1] + models=[MODEL_ID], + training=True, + 
no_inference=True, + sequence_lengths=[8], + batch_sizes=[1], + no_multi_process=True, + ) + benchmark = PyTorchBenchmark(benchmark_args) + results = benchmark.run() + self.check_results_dict_not_empty(results.time_train_result) + self.check_results_dict_not_empty(results.memory_train_result) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_train_no_configs_fp16(self): + MODEL_ID = "sshleifer/tiny-gpt2" + benchmark_args = PyTorchBenchmarkArguments( + models=[MODEL_ID], + training=True, + no_inference=True, + sequence_lengths=[8], + batch_sizes=[1], + fp16=True, + no_multi_process=True, ) benchmark = PyTorchBenchmark(benchmark_args) results = benchmark.run() @@ -62,7 +107,12 @@ def test_inference_with_configs(self): MODEL_ID = "sshleifer/tiny-gpt2" config = AutoConfig.from_pretrained(MODEL_ID) benchmark_args = PyTorchBenchmarkArguments( - models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1] + models=[MODEL_ID], + training=False, + no_inference=False, + sequence_lengths=[8], + batch_sizes=[1], + no_multi_process=True, ) benchmark = PyTorchBenchmark(benchmark_args, configs=[config]) results = benchmark.run() @@ -73,7 +123,12 @@ def test_inference_encoder_decoder_with_configs(self): MODEL_ID = "sshleifer/tinier_bart" config = AutoConfig.from_pretrained(MODEL_ID) benchmark_args = PyTorchBenchmarkArguments( - models=[MODEL_ID], training=False, no_inference=False, sequence_lengths=[8], batch_sizes=[1] + models=[MODEL_ID], + training=False, + no_inference=False, + sequence_lengths=[8], + batch_sizes=[1], + no_multi_process=True, ) benchmark = PyTorchBenchmark(benchmark_args, configs=[config]) results = benchmark.run() @@ -81,26 +136,15 @@ def test_inference_encoder_decoder_with_configs(self): self.check_results_dict_not_empty(results.memory_inference_result) def test_train_with_configs(self): - MODEL_ID = "sshleifer/tiny-gpt2" - config = AutoConfig.from_pretrained(MODEL_ID) - benchmark_args = PyTorchBenchmarkArguments( - models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1] - ) - benchmark = PyTorchBenchmark(benchmark_args, configs=[config]) - results = benchmark.run() - self.check_results_dict_not_empty(results.time_train_result) - self.check_results_dict_not_empty(results.memory_train_result) - - def test_train_with_configs_torchscript(self): MODEL_ID = "sshleifer/tiny-gpt2" config = AutoConfig.from_pretrained(MODEL_ID) benchmark_args = PyTorchBenchmarkArguments( models=[MODEL_ID], training=True, no_inference=True, - torchscript=True, sequence_lengths=[8], batch_sizes=[1], + no_multi_process=True, ) benchmark = PyTorchBenchmark(benchmark_args, configs=[config]) results = benchmark.run() @@ -111,7 +155,12 @@ def test_train_encoder_decoder_with_configs(self): MODEL_ID = "sshleifer/tinier_bart" config = AutoConfig.from_pretrained(MODEL_ID) benchmark_args = PyTorchBenchmarkArguments( - models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1] + models=[MODEL_ID], + training=True, + no_inference=True, + sequence_lengths=[8], + batch_sizes=[1], + no_multi_process=True, ) benchmark = PyTorchBenchmark(benchmark_args, configs=[config]) results = benchmark.run() @@ -133,6 +182,7 @@ def test_save_csv_files(self): inference_memory_csv_file=os.path.join(tmp_dir, "inf_mem.csv"), train_time_csv_file=os.path.join(tmp_dir, "train_time.csv"), env_info_csv_file=os.path.join(tmp_dir, "env.csv"), + no_multi_process=True, ) benchmark = PyTorchBenchmark(benchmark_args) 
benchmark.run() @@ -161,6 +211,7 @@ def _check_summary_is_not_empty(summary): log_filename=os.path.join(tmp_dir, "log.txt"), log_print=True, trace_memory_line_by_line=True, + no_multi_process=True, ) benchmark = PyTorchBenchmark(benchmark_args) result = benchmark.run() diff --git a/tests/test_benchmark_tf.py b/tests/test_benchmark_tf.py new file mode 100644 index 00000000000000..b23ff51e509849 --- /dev/null +++ b/tests/test_benchmark_tf.py @@ -0,0 +1,165 @@ +import os +import tempfile +import unittest +from pathlib import Path + +from transformers import AutoConfig, is_tf_available + +from .utils import require_tf + + +if is_tf_available(): + import tensorflow as tf + from transformers import TensorflowBenchmark, TensorflowBenchmarkArguments + + +@require_tf +class TFBenchmarkTest(unittest.TestCase): + def check_results_dict_not_empty(self, results): + for model_result in results.values(): + for batch_size, sequence_length in zip(model_result["bs"], model_result["ss"]): + result = model_result["result"][batch_size][sequence_length] + self.assertIsNotNone(result) + + def test_inference_no_configs_eager(self): + MODEL_ID = "sshleifer/tiny-gpt2" + benchmark_args = TensorflowBenchmarkArguments( + models=[MODEL_ID], + training=False, + no_inference=False, + sequence_lengths=[8], + batch_sizes=[1], + eager_mode=True, + no_multi_process=True, + ) + benchmark = TensorflowBenchmark(benchmark_args) + results = benchmark.run() + self.check_results_dict_not_empty(results.time_inference_result) + self.check_results_dict_not_empty(results.memory_inference_result) + + def test_inference_no_configs_graph(self): + MODEL_ID = "sshleifer/tiny-gpt2" + benchmark_args = TensorflowBenchmarkArguments( + models=[MODEL_ID], + training=False, + no_inference=False, + sequence_lengths=[8], + batch_sizes=[1], + no_multi_process=True, + ) + benchmark = TensorflowBenchmark(benchmark_args) + results = benchmark.run() + self.check_results_dict_not_empty(results.time_inference_result) + self.check_results_dict_not_empty(results.memory_inference_result) + + def test_inference_with_configs_eager(self): + MODEL_ID = "sshleifer/tiny-gpt2" + config = AutoConfig.from_pretrained(MODEL_ID) + benchmark_args = TensorflowBenchmarkArguments( + models=[MODEL_ID], + training=False, + no_inference=False, + sequence_lengths=[8], + batch_sizes=[1], + eager_mode=True, + no_multi_process=True, + ) + benchmark = TensorflowBenchmark(benchmark_args, [config]) + results = benchmark.run() + self.check_results_dict_not_empty(results.time_inference_result) + self.check_results_dict_not_empty(results.memory_inference_result) + + def test_inference_with_configs_graph(self): + MODEL_ID = "sshleifer/tiny-gpt2" + config = AutoConfig.from_pretrained(MODEL_ID) + benchmark_args = TensorflowBenchmarkArguments( + models=[MODEL_ID], + training=False, + no_inference=False, + sequence_lengths=[8], + batch_sizes=[1], + no_multi_process=True, + ) + benchmark = TensorflowBenchmark(benchmark_args, [config]) + results = benchmark.run() + self.check_results_dict_not_empty(results.time_inference_result) + self.check_results_dict_not_empty(results.memory_inference_result) + + def test_inference_encoder_decoder_with_configs(self): + MODEL_ID = "patrickvonplaten/t5-tiny-random" + config = AutoConfig.from_pretrained(MODEL_ID) + benchmark_args = TensorflowBenchmarkArguments( + models=[MODEL_ID], + training=False, + no_inference=False, + sequence_lengths=[8], + batch_sizes=[1], + no_multi_process=True, + ) + benchmark = TensorflowBenchmark(benchmark_args, configs=[config]) + 
results = benchmark.run() + self.check_results_dict_not_empty(results.time_inference_result) + self.check_results_dict_not_empty(results.memory_inference_result) + + @unittest.skipIf(is_tf_available() and len(tf.config.list_physical_devices("GPU")) == 0, "Cannot do xla on CPU.") + def test_inference_no_configs_xla(self): + MODEL_ID = "sshleifer/tiny-gpt2" + benchmark_args = TensorflowBenchmarkArguments( + models=[MODEL_ID], + training=False, + no_inference=False, + sequence_lengths=[8], + batch_sizes=[1], + use_xla=True, + no_multi_process=True, + ) + benchmark = TensorflowBenchmark(benchmark_args) + results = benchmark.run() + self.check_results_dict_not_empty(results.time_inference_result) + self.check_results_dict_not_empty(results.memory_inference_result) + + def test_save_csv_files(self): + MODEL_ID = "sshleifer/tiny-gpt2" + with tempfile.TemporaryDirectory() as tmp_dir: + benchmark_args = TensorflowBenchmarkArguments( + models=[MODEL_ID], + no_inference=False, + save_to_csv=True, + sequence_lengths=[8], + batch_sizes=[1], + inference_time_csv_file=os.path.join(tmp_dir, "inf_time.csv"), + inference_memory_csv_file=os.path.join(tmp_dir, "inf_mem.csv"), + env_info_csv_file=os.path.join(tmp_dir, "env.csv"), + no_multi_process=True, + ) + benchmark = TensorflowBenchmark(benchmark_args) + benchmark.run() + self.assertTrue(Path(os.path.join(tmp_dir, "inf_time.csv")).exists()) + self.assertTrue(Path(os.path.join(tmp_dir, "inf_mem.csv")).exists()) + self.assertTrue(Path(os.path.join(tmp_dir, "env.csv")).exists()) + + def test_trace_memory(self): + MODEL_ID = "sshleifer/tiny-gpt2" + + def _check_summary_is_not_empty(summary): + self.assertTrue(hasattr(summary, "sequential")) + self.assertTrue(hasattr(summary, "cumulative")) + self.assertTrue(hasattr(summary, "current")) + self.assertTrue(hasattr(summary, "total")) + + with tempfile.TemporaryDirectory() as tmp_dir: + benchmark_args = TensorflowBenchmarkArguments( + models=[MODEL_ID], + no_inference=False, + sequence_lengths=[8], + batch_sizes=[1], + log_filename=os.path.join(tmp_dir, "log.txt"), + log_print=True, + trace_memory_line_by_line=True, + eager_mode=True, + no_multi_process=True, + ) + benchmark = TensorflowBenchmark(benchmark_args) + result = benchmark.run() + _check_summary_is_not_empty(result.inference_summary) + self.assertTrue(Path(os.path.join(tmp_dir, "log.txt")).exists()) From ece6b456c21df48726d8a6939131831682d627e7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 18 Jun 2020 16:21:14 +0200 Subject: [PATCH 2/8] fix isort --- examples/longform-qa/eli5_app.py | 4 ++-- examples/longform-qa/eli5_utils.py | 6 +++--- examples/requirements.txt | 5 +++++ setup.cfg | 3 +++ 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/longform-qa/eli5_app.py b/examples/longform-qa/eli5_app.py index 8826cd2fbae48e..a7d75565ae1631 100644 --- a/examples/longform-qa/eli5_app.py +++ b/examples/longform-qa/eli5_app.py @@ -1,11 +1,11 @@ +import faiss import nlp import numpy as np import torch +from elasticsearch import Elasticsearch -import faiss import streamlit as st import transformers -from elasticsearch import Elasticsearch from eli5_utils import ( embed_questions_for_retrieval, make_qa_s2s_model, diff --git a/examples/longform-qa/eli5_utils.py b/examples/longform-qa/eli5_utils.py index c3a1da26d96432..4f7d7a9d46d037 100644 --- a/examples/longform-qa/eli5_utils.py +++ b/examples/longform-qa/eli5_utils.py @@ -4,17 +4,17 @@ from random import choice, randint from time import time +import faiss # noqa: F401 import nlp 
# noqa: F401 import numpy as np import pandas as pd import torch import torch.utils.checkpoint as checkpoint +from elasticsearch import Elasticsearch # noqa: F401 +from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401 from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler from tqdm import tqdm -import faiss # noqa: F401 -from elasticsearch import Elasticsearch # noqa: F401 -from elasticsearch.helpers import bulk, streaming_bulk # noqa: F401 from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup diff --git a/examples/requirements.txt b/examples/requirements.txt index 05d716bdc0790a..daf2081fe94295 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -8,3 +8,8 @@ tensorflow_datasets pytorch-lightning==0.7.6 matplotlib git-python==1.0.3 +faiss +streamlit +elasticsearch +pandas +nlp diff --git a/setup.cfg b/setup.cfg index 5badc1ae760a13..cf83f37c6f96da 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,12 +5,15 @@ include_trailing_comma = True known_first_party = transformers known_third_party = absl + elasticsearch fairseq + faiss fastprogress git h5py matplotlib MeCab + nlp nltk numpy packaging From 148a9b85d257876b0dabffc24eaaca954ad4d05b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 18 Jun 2020 16:22:19 +0200 Subject: [PATCH 3/8] fix setup cfg --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index cf83f37c6f96da..93897a5ec1aaac 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ known_third_party = absl elasticsearch fairseq - faiss + faiss fastprogress git h5py From 85a80750e223a5628c16df0bd4712f9112ff1929 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 18 Jun 2020 16:22:55 +0200 Subject: [PATCH 4/8] retab --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 93897a5ec1aaac..0b4c1af0714ce0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ known_third_party = absl elasticsearch fairseq - faiss + faiss fastprogress git h5py From 62bf420c1b92f21c212c5ed5dd145f74ffe87041 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Thu, 18 Jun 2020 22:08:09 +0200 Subject: [PATCH 5/8] fix time measuring of tf graph mode --- src/transformers/benchmark/benchmark_tf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/benchmark/benchmark_tf.py b/src/transformers/benchmark/benchmark_tf.py index 5bb34c25f0fe7a..594ac53fc02969 100644 --- a/src/transformers/benchmark/benchmark_tf.py +++ b/src/transformers/benchmark/benchmark_tf.py @@ -126,11 +126,11 @@ def _prepare_inference_func(self, model_name, batch_size, sequence_length): @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla) def encoder_decoder_forward(): - model(input_ids, decoder_input_ids=input_ids, training=False) + return model(input_ids, decoder_input_ids=input_ids, training=False) @run_with_tf_optimizations(self.args.eager_mode, self.args.use_xla) def encoder_forward(): - model(input_ids, training=False) + return model(input_ids, training=False) _inference = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward From 3df1cfcaf214f5fe7a2ba2b1d62905983ddd69f9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 19 Jun 2020 11:29:21 +0000 Subject: [PATCH 6/8] fix tf cuda --- src/transformers/benchmark/benchmark.py | 6 +++-- .../benchmark/benchmark_args_tf.py | 25 +++++++++---------- 2 files changed, 16 insertions(+), 
15 deletions(-) diff --git a/src/transformers/benchmark/benchmark.py b/src/transformers/benchmark/benchmark.py index 2136cdf7ab1a38..f54ad229e50bcd 100644 --- a/src/transformers/benchmark/benchmark.py +++ b/src/transformers/benchmark/benchmark.py @@ -102,11 +102,13 @@ def _prepare_inference_func(self, model_name, batch_size, sequence_length): def encoder_decoder_forward(): with torch.no_grad(): - inference_model(input_ids, decoder_input_ids=input_ids) + outputs = inference_model(input_ids, decoder_input_ids=input_ids) + return outputs def encoder_forward(): with torch.no_grad(): - inference_model(input_ids) + outputs = inference_model(input_ids) + return outputs _forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward return _forward diff --git a/src/transformers/benchmark/benchmark_args_tf.py b/src/transformers/benchmark/benchmark_args_tf.py index 881b581ee85d8d..0f2b243c3838db 100644 --- a/src/transformers/benchmark/benchmark_args_tf.py +++ b/src/transformers/benchmark/benchmark_args_tf.py @@ -47,9 +47,7 @@ class TensorflowBenchmarkArguments(BenchmarkArguments): @cached_property @tf_required - def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]: - logger.info("Tensorflow: setting up strategy") - + def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]: if not self.no_tpu: try: if self.tpu_name: @@ -58,14 +56,16 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.clus tpu = tf.distribute.cluster_resolver.TPUClusterResolver() except ValueError: tpu = None - else: - tpu = None + return tpu - if tpu is not None: - tf.config.experimental_connect_to_cluster(tpu) - tf.tpu.experimental.initialize_tpu_system(tpu) + @cached_property + @tf_required + def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]: + if self.is_tpu: + tf.config.experimental_connect_to_cluster(self._setup_tpu) + tf.tpu.experimental.initialize_tpu_system(self._setup_tpu) - strategy = tf.distribute.experimental.TPUStrategy(tpu) + strategy = tf.distribute.experimental.TPUStrategy(self._setup_tpu) else: # currently no multi gpu is allowed if self.is_gpu: @@ -76,18 +76,17 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.clus tf.config.experimental.set_visible_devices([], "GPU") # disable GPU strategy = tf.distribute.OneDeviceStrategy(device=f"/cpu:{self.device_idx}") - return strategy, tpu + return strategy @property @tf_required def is_tpu(self) -> bool: - tpu = self._setup_strategy[1] - return tpu is not None and not self.no_tpu + return self._setup_tpu is not None @property @tf_required def strategy(self) -> "tf.distribute.Strategy": - return self._setup_strategy[0] + return self._setup_strategy @property @tf_required From 9ec8adc124ac21e8353378116a748878db83a558 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 19 Jun 2020 12:52:25 +0000 Subject: [PATCH 7/8] clean code --- examples/benchmarking/run_benchmark.py | 2 +- examples/benchmarking/run_benchmark_tf.py | 2 +- src/transformers/benchmark/benchmark.py | 30 ++++++++++----- src/transformers/benchmark/benchmark_tf.py | 30 ++++++++++----- src/transformers/benchmark/benchmark_utils.py | 38 +++++++++++++------ 5 files changed, 70 insertions(+), 32 deletions(-) diff --git a/examples/benchmarking/run_benchmark.py b/examples/benchmarking/run_benchmark.py index 163bcfb6fc2501..f995b8212ab4b0 100644 --- a/examples/benchmarking/run_benchmark.py 
+++ b/examples/benchmarking/run_benchmark.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 The HuggingFace Inc. team. +# Copyright 2020 The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/examples/benchmarking/run_benchmark_tf.py b/examples/benchmarking/run_benchmark_tf.py index 5d578e0f8c7ffc..c37c833c9678b5 100644 --- a/examples/benchmarking/run_benchmark_tf.py +++ b/examples/benchmarking/run_benchmark_tf.py @@ -1,6 +1,6 @@ # coding=utf-8 # Copyright 2018 The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/benchmark/benchmark.py b/src/transformers/benchmark/benchmark.py index f54ad229e50bcd..16c08a244d2b8d 100644 --- a/src/transformers/benchmark/benchmark.py +++ b/src/transformers/benchmark/benchmark.py @@ -20,6 +20,7 @@ import logging import timeit +from typing import Callable, Optional from transformers import ( MODEL_MAPPING, @@ -29,7 +30,14 @@ is_torch_available, ) -from .benchmark_utils import Benchmark, Memory, measure_peak_memory_cpu, start_memory_tracing, stop_memory_tracing +from .benchmark_utils import ( + Benchmark, + Memory, + MemorySummary, + measure_peak_memory_cpu, + start_memory_tracing, + stop_memory_tracing, +) if is_torch_available(): @@ -54,23 +62,27 @@ class PyTorchBenchmark(Benchmark): def framework_version(self): return torch.__version__ - def _inference_speed(self, model_name, batch_size, sequence_length): + def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float: _inference = self._prepare_inference_func(model_name, batch_size, sequence_length) return self._measure_speed(_inference) - def _inference_memory(self, model_name, batch_size, sequence_length): + def _inference_memory( + self, model_name: str, batch_size: int, sequence_length: int + ) -> [Memory, Optional[MemorySummary]]: _inference = self._prepare_inference_func(model_name, batch_size, sequence_length) return self._measure_memory(_inference) - def _train_speed(self, model_name, batch_size, sequence_length): + def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float: _train = self._prepare_train_func(model_name, batch_size, sequence_length) return self._measure_speed(_train) - def _train_memory(self, model_name, batch_size, sequence_length): + def _train_memory( + self, model_name: str, batch_size: int, sequence_length: int + ) -> [Memory, Optional[MemorySummary]]: _train = self._prepare_train_func(model_name, batch_size, sequence_length) return self._measure_memory(_train) - def _prepare_inference_func(self, model_name, batch_size, sequence_length): + def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]: config = self.config_dict[model_name] if self.args.torchscript: @@ -113,7 +125,7 @@ def encoder_forward(): _forward = encoder_decoder_forward if config.is_encoder_decoder else encoder_forward return _forward - def _prepare_train_func(self, model_name, batch_size, sequence_length): + def _prepare_train_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]: config = self.config_dict[model_name] model = MODEL_WITH_LM_HEAD_MAPPING[config.__class__](config) @@ -154,7 
+166,7 @@ def compute_loss_and_backprob_encoder_decoder(): ) return _train - def _measure_speed(self, func): + def _measure_speed(self, func) -> float: try: if self.args.is_tpu or self.args.torchscript: # run additional 10 times to stabilize compilation for tpu and torchscript @@ -176,7 +188,7 @@ def _measure_speed(self, func): self.print_fn("Doesn't fit on GPU. {}".format(e)) return "N/A" - def _measure_memory(self, func): + def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]: try: if self.args.trace_memory_line_by_line: trace = start_memory_tracing("transformers") diff --git a/src/transformers/benchmark/benchmark_tf.py b/src/transformers/benchmark/benchmark_tf.py index 594ac53fc02969..4a92e863a136ef 100644 --- a/src/transformers/benchmark/benchmark_tf.py +++ b/src/transformers/benchmark/benchmark_tf.py @@ -22,6 +22,7 @@ import random import timeit from functools import wraps +from typing import Callable, Optional from transformers import ( TF_MODEL_MAPPING, @@ -31,7 +32,14 @@ is_tf_available, ) -from .benchmark_utils import Benchmark, Memory, measure_peak_memory_cpu, start_memory_tracing, stop_memory_tracing +from .benchmark_utils import ( + Benchmark, + Memory, + MemorySummary, + measure_peak_memory_cpu, + start_memory_tracing, + stop_memory_tracing, +) if is_tf_available(): @@ -45,20 +53,20 @@ logger = logging.getLogger(__name__) -def run_with_tf_optimizations(do_eager_mode, do_xla): +def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool): def run_func(func): @wraps(func) def run_in_eager_mode(*args, **kwargs): return func(*args, **kwargs) @wraps(func) - @tf.function(experimental_compile=do_xla) + @tf.function(experimental_compile=use_xla) def run_in_graph_mode(*args, **kwargs): return func(*args, **kwargs) if do_eager_mode is True: assert ( - do_xla is False + use_xla is False ), "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`." return run_in_eager_mode else: @@ -67,7 +75,7 @@ def run_in_graph_mode(*args, **kwargs): return run_func -def random_input_ids(batch_size, sequence_length, vocab_size): +def random_input_ids(batch_size: int, sequence_length: int, vocab_size: int) -> ["tf.Tensor"]: rng = random.Random() values = [rng.randint(0, vocab_size - 1) for i in range(batch_size * sequence_length)] return tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32) @@ -83,7 +91,7 @@ class TensorflowBenchmark(Benchmark): def framework_version(self): return tf.__version__ - def _inference_speed(self, model_name, batch_size, sequence_length): + def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float: # initialize GPU on separate process strategy = self.args.strategy assert strategy is not None, "A device strategy has to be initialized before using Tensorflow." @@ -95,7 +103,9 @@ def _train_speed(self, model_name, batch_size, sequence_length): "Training is currently not really implemented." "Wait for TFTrainer to support CLM and MLM." ) - def _inference_memory(self, model_name, batch_size, sequence_length): + def _inference_memory( + self, model_name: str, batch_size: int, sequence_length: int + ) -> [Memory, Optional[MemorySummary]]: # initialize GPU on separate process if self.args.is_gpu: tf.config.experimental.set_memory_growth(self.args.gpu_list[self.args.device_idx], True) @@ -109,7 +119,7 @@ def _train_memory(self, model_name, batch_size, sequence_length): "Training is currently not really implemented. Wait for TFTrainer to support CLM and MLM." 
) - def _prepare_inference_func(self, model_name, batch_size, sequence_length): + def _prepare_inference_func(self, model_name: str, batch_size: int, sequence_length: int) -> Callable[[], None]: config = self.config_dict[model_name] if self.args.fp16: @@ -136,7 +146,7 @@ def encoder_forward(): return _inference - def _measure_speed(self, func): + def _measure_speed(self, func) -> float: with self.args.strategy.scope(): try: if self.args.is_tpu or self.args.use_xla: @@ -151,7 +161,7 @@ def _measure_speed(self, func): except ResourceExhaustedError as e: self.print_fn("Doesn't fit on GPU. {}".format(e)) - def _measure_memory(self, func): + def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]: logger.info( "Note that Tensorflow allocates more memory than" "it might need to speed up computation." diff --git a/src/transformers/benchmark/benchmark_utils.py b/src/transformers/benchmark/benchmark_utils.py index 166adeef39208a..682887d0d5dc8c 100644 --- a/src/transformers/benchmark/benchmark_utils.py +++ b/src/transformers/benchmark/benchmark_utils.py @@ -61,11 +61,23 @@ ) -def separate_process_wrapper_fn(func, do_multi_processing): +def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: bool) -> Callable[[], None]: + """ + This function wraps another function into its own separated process. + In order to ensure accurate memory measurements it is important that the function + is executed in a separate process + + Args: + - `func`: (`callable`): function() -> ... + generic function which will be executed in its own separate process + - `do_multi_processing`: (`bool`) + Whether to run function on separate process or not + """ + def multi_process_func(*args, **kwargs): # run function in an individual # process to get correct memory - def wrapper_func(queue, *args): + def wrapper_func(queue: Queue, *args): try: result = func(*args) except Exception as e: @@ -82,7 +94,7 @@ def wrapper_func(queue, *args): return result if do_multi_processing: - logging.info("Use multiprocessing...") + logging.info("fFunction {func} is executed in its own process...") return multi_process_func else: return func @@ -590,31 +602,35 @@ def framework_version(self): pass @abstractmethod - def _inference_speed(self, model_name, batch_size, sequence_length): + def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float: pass @abstractmethod - def _train_speed(self, model_name, batch_size, sequence_length): + def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float: pass @abstractmethod - def _inference_memory(self, model_name, batch_size, sequence_length): + def _inference_memory( + self, model_name: str, batch_size: int, sequence_length: int + ) -> [Memory, Optional[MemorySummary]]: pass @abstractmethod - def _train_memory(self, model_name, batch_size, sequence_length): + def _train_memory( + self, model_name: str, batch_size: int, sequence_length: int + ) -> [Memory, Optional[MemorySummary]]: pass - def inference_speed(self, *args, **kwargs): + def inference_speed(self, *args, **kwargs) -> float: return separate_process_wrapper_fn(self._inference_speed, self.args.do_multi_processing)(*args, **kwargs) - def train_speed(self, *args, **kwargs): + def train_speed(self, *args, **kwargs) -> float: return separate_process_wrapper_fn(self._train_speed, self.args.do_multi_processing)(*args, **kwargs) - def inference_memory(self, *args, **kwargs): + def inference_memory(self, *args, **kwargs) -> [Memory, 
Optional[MemorySummary]]: return separate_process_wrapper_fn(self._inference_memory, self.args.do_multi_processing)(*args, **kwargs) - def train_memory(self, *args, **kwargs): + def train_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]: return separate_process_wrapper_fn(self._train_memory, self.args.do_multi_processing)(*args, **kwargs) def run(self): From 8b710419a456c4b6b56f06463277ad673c41aced Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 19 Jun 2020 13:36:00 +0000 Subject: [PATCH 8/8] better error message --- src/transformers/benchmark/benchmark.py | 2 +- src/transformers/benchmark/benchmark_args_utils.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/benchmark/benchmark.py b/src/transformers/benchmark/benchmark.py index 16c08a244d2b8d..a24c5028e999e9 100644 --- a/src/transformers/benchmark/benchmark.py +++ b/src/transformers/benchmark/benchmark.py @@ -196,7 +196,7 @@ def _measure_memory(self, func: Callable[[], None]) -> [Memory, MemorySummary]: if self.args.is_tpu: # tpu raise NotImplementedError( - "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `args.no_memory=True`" + "Memory Benchmarking is currently not implemented for TPU. Please disable memory benchmarking with `--no_memory` or `args.no_memory=True`" ) elif self.args.is_gpu: if not is_py3nvml_available(): diff --git a/src/transformers/benchmark/benchmark_args_utils.py b/src/transformers/benchmark/benchmark_args_utils.py index 7962c68288cc5f..5f7dbff672e620 100644 --- a/src/transformers/benchmark/benchmark_args_utils.py +++ b/src/transformers/benchmark/benchmark_args_utils.py @@ -114,6 +114,9 @@ def to_json_string(self): @property def model_names(self): + assert ( + len(self.models) > 0 + ), "Please make sure you provide at least one model name / model identifier, *e.g.* `--models bert-base-cased` or `args.models = ['bert-base-cased']." return self.models @property
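
For context, the sketch below pieces together, from the test files in this series, how the new TensorFlow benchmark classes can be driven programmatically once the patches are applied. It is an illustrative example rather than part of any commit: the tiny model identifier and the argument values are borrowed from the tests above and can be swapped for any other model identifier or measurement grid.

from transformers import TensorflowBenchmark, TensorflowBenchmarkArguments

# Benchmark inference speed and memory for a small model (graph mode by default).
benchmark_args = TensorflowBenchmarkArguments(
    models=["sshleifer/tiny-gpt2"],  # any model identifier; values here mirror the tests
    training=False,                  # TF training benchmarks raise NotImplementedError for now
    no_inference=False,              # do measure inference
    sequence_lengths=[8, 32],
    batch_sizes=[1, 8],
    no_multi_process=True,           # run in the current process, as the tests above do
)

benchmark = TensorflowBenchmark(benchmark_args)
results = benchmark.run()

# Results are keyed by model name, then batch size and sequence length,
# matching the structure the tests above assert on.
for model_name, model_result in results.time_inference_result.items():
    for bs, ss in zip(model_result["bs"], model_result["ss"]):
        print(model_name, bs, ss, model_result["result"][bs][ss])

Without no_multi_process=True, each measurement would be executed in its own subprocess (see separate_process_wrapper_fn, documented in patch 7) so that memory readings are not distorted by the parent process; the tests set the flag throughout and so does this sketch.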