
Commit

add torch fp16
patrickvonplaten committed Jun 17, 2020
1 parent 6fb6fe6 commit c5fa0d5
Showing 2 changed files with 29 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/benchmarking/run_benchmark.py
@@ -16,9 +16,13 @@
""" Benchmarking the library on inference and training """

from transformers import HfArgumentParser, PyTorchBenchmark, PyTorchBenchmarkArguments
import logging

logger = logging.getLogger(__name__)


def main():
    logger.setLevel(logging.WARN)
    parser = HfArgumentParser(PyTorchBenchmarkArguments)
    benchmark_args = parser.parse_args_into_dataclasses()[0]
    benchmark = PyTorchBenchmark(args=benchmark_args)
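
For context, a minimal sketch of how the fp16 path exercised by this commit could be driven directly in Python, without going through run_benchmark.py's command-line parsing. The argument fields mirror the ones used in tests/test_benchmark.py below; the final print is only for illustration, and depending on the installed transformers version fp16 benchmarking may require a CUDA device.

import logging

from transformers import PyTorchBenchmark, PyTorchBenchmarkArguments

logging.getLogger(__name__).setLevel(logging.WARN)

# Build the arguments directly instead of parsing them from the command line.
benchmark_args = PyTorchBenchmarkArguments(
    models=["sshleifer/tiny-gpt2"],
    training=False,
    no_inference=False,
    fp16=True,  # run inference in half precision
    sequence_lengths=[8],
    batch_sizes=[1],
)
benchmark = PyTorchBenchmark(args=benchmark_args)
results = benchmark.run()
print(results.time_inference_result)  # per-model timing results, as checked in the tests
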
25 changes: 25 additions & 0 deletions tests/test_benchmark.py
@@ -48,6 +48,21 @@ def test_inference_torchscript(self):
        self.check_results_dict_not_empty(results.time_inference_result)
        self.check_results_dict_not_empty(results.memory_inference_result)

    def test_inference_fp16(self):
        MODEL_ID = "sshleifer/tiny-gpt2"
        benchmark_args = PyTorchBenchmarkArguments(
            models=[MODEL_ID],
            training=False,
            no_inference=False,
            fp16=True,
            sequence_lengths=[8],
            batch_sizes=[1],
        )
        benchmark = PyTorchBenchmark(benchmark_args)
        results = benchmark.run()
        self.check_results_dict_not_empty(results.time_inference_result)
        self.check_results_dict_not_empty(results.memory_inference_result)

    def test_train_no_configs(self):
        MODEL_ID = "sshleifer/tiny-gpt2"
        benchmark_args = PyTorchBenchmarkArguments(

@@ -58,6 +73,16 @@ def test_train_no_configs(self):
        self.check_results_dict_not_empty(results.time_train_result)
        self.check_results_dict_not_empty(results.memory_train_result)

    def test_train_no_configs_fp16(self):
        MODEL_ID = "sshleifer/tiny-gpt2"
        benchmark_args = PyTorchBenchmarkArguments(
            models=[MODEL_ID], training=True, no_inference=True, sequence_lengths=[8], batch_sizes=[1], fp16=True
        )
        benchmark = PyTorchBenchmark(benchmark_args)
        results = benchmark.run()
        self.check_results_dict_not_empty(results.time_train_result)
        self.check_results_dict_not_empty(results.memory_train_result)

    def test_inference_with_configs(self):
        MODEL_ID = "sshleifer/tiny-gpt2"
        config = AutoConfig.from_pretrained(MODEL_ID)
