[GRAPH EXECUTOR,VM] Add benchmarking function to graph executor and vm #8807

Merged: 6 commits, Aug 26, 2021

Showing changes from 1 commit.
2 changes: 1 addition & 1 deletion python/tvm/contrib/graph_executor.py
@@ -369,7 +369,7 @@ def benchmark(self, device, func_name="run", repeat=5, number=5, min_repeat_ms=N

     Returns
     -------
-    timing_results : ProfileResult
+    timing_results : BenchmarkResult
         Runtimes of the function. Use `.mean` to access the mean runtime, use `.results` to
         access the individual runtimes.
     """
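For orientation, the sketch below shows roughly how the documented method is called from a built graph executor and what the returned BenchmarkResult offers. It is not taken from the PR: the tiny Relay model, the input name "data", and the repeat/number values are made up for illustration; the keyword-argument input follows the pattern used by the tests in this PR.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    # Build a trivial Relay model so there is something to benchmark.
    x = relay.var("data", shape=(1, 4), dtype="float32")
    func = relay.Function([x], relay.add(x, relay.const(1.0)))
    lib = relay.build(tvm.IRModule.from_expr(func), target="llvm")

    dev = tvm.cpu()
    module = graph_executor.GraphModule(lib["default"](dev))
    data = np.random.uniform(size=(1, 4)).astype("float32")

    # As in this PR's tests, the named model input is passed as a keyword argument.
    result = module.benchmark(dev, data=data, func_name="run", repeat=2, number=1)
    print(result)          # formatted "Execution time summary" table
    print(result.mean)     # mean runtime in seconds
    print(result.results)  # the individual per-repeat runtimes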
2 changes: 1 addition & 1 deletion python/tvm/driver/tvmc/model.py
@@ -399,7 +399,7 @@ def format_times(self):
         str
             A formatted string containing the statistics.
         """
-        return str(times)
+        return str(self.times)

     def get_output(self, name: str):
         """A helper function to grab one of the outputs by name.
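A hedged note on this fix, inferred from the test changes later in this diff: since TVMCResult.times is now a BenchmarkResult rather than a plain tuple, the corrected str(self.times) call renders BenchmarkResult's "Execution time summary" table. A minimal sketch with made-up runtimes:

    from tvm.runtime.module import BenchmarkResult

    # Made-up per-repeat runtimes in seconds; in TVMC these come from the benchmark run.
    times = BenchmarkResult([0.0010, 0.0012, 0.0011])
    print(str(times))  # the summary string that format_times() now returns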
26 changes: 15 additions & 11 deletions python/tvm/runtime/module.py
@@ -47,20 +47,24 @@ def __init__(self, results: Sequence[float]):
     min : float
         Minimum runtime in seconds of all results.
     mean : float
-        Mean runtime in seconds of all results. Note that this mean is not
-        necessarily statistically correct as it is the mean of mean
-        runtimes.
+        Mean runtime in seconds of all results. If py:meth:`Module.time_evaluator` or
+        `benchmark` is called with `number` > 0, then each result is already the mean of a
+        `number` of runtimes, so this becomes the mean of means.
     median : float
-        Median runtime in seconds of all results. Note that this is not necessarily
-        statistically correct as it is the median of mean runtimes.
+        Median runtime in seconds of all results. If py:meth:`Module.time_evaluator` is called
+        with `number` > 0, then each result is already the mean of a `number` of runtimes, so
+        this becomes the median of means.
     max : float
-        Maximum runtime in seconds of all results.
+        Maximum runtime in seconds of all results. If py:meth:`Module.time_evaluator` is called
+        with `number` > 0, then each result is already the mean of a `number` of runtimes, so
+        this becomes the maximum of those means.
     std : float
-        Standard deviation in seconds of runtimes. Note that this is not necessarily
-        correct as it is the std of mean runtimes.
+        Standard deviation in seconds of runtimes. If py:meth:`Module.time_evaluator` is called
+        with `number` > 0, then each result is already the mean of a `number` of runtimes, so
+        this becomes the standard deviation of means.
     results : Sequence[float]
         The collected runtimes (in seconds). This may be a series of mean runtimes if
-        the benchmark was run with `number` > 1.
+        py:meth:`Module.time_evaluator` or `benchmark` was run with `number` > 1.
     """
     self.results = results
     self.mean = np.mean(self.results)

Review comment (Contributor), on the `results` description: I think this needs more explanation. Currently a BenchmarkResult object contains no information on the benchmark parameters that were used, and it would be best (IMHO) to avoid there being surprises in terms of the interpretation of the results based on how the object was created. My recommendation would be to either fully document the behavior of what it means to benchmark with 'number > 1' or ensure that the BenchmarkResult object itself contains the benchmark parameters used.
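To make the mean-of-means wording (and the reviewer's concern) concrete, here is an illustrative sketch that is not part of the PR; the raw timings are invented. With repeat=2 and number=3, the evaluator hands BenchmarkResult only the two per-repeat averages, so .std describes the spread of those averages rather than of the six underlying runs:

    import numpy as np
    from tvm.runtime.module import BenchmarkResult

    # Hypothetical raw timings (seconds): repeat=2 groups of number=3 runs each.
    raw_runs = [[0.010, 0.012, 0.011], [0.009, 0.010, 0.011]]

    # The evaluator reports one averaged time per repeat, so only these means are recorded.
    per_repeat_means = [float(np.mean(group)) for group in raw_runs]
    result = BenchmarkResult(per_repeat_means)

    print(result.results)  # the 2 per-repeat means
    print(result.mean)     # mean of those means
    print(result.std)      # spread of the 2 means, not of all 6 raw measurements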
@@ -77,7 +81,7 @@ def __repr__(self):
     def __str__(self):
         return """Execution time summary:
 {:^12} {:^12} {:^12} {:^12} {:^12}
-{:^12.2f} {:^12.2f} {:^12.2f} {:^12.2f} {:^12.2f}
+{:^12.4f} {:^12.4f} {:^12.4f} {:^12.4f} {:^12.4f}
 """.format(
             "mean (ms)",
             "median (ms)",
@@ -292,7 +296,7 @@ def evaluator(*args):

             return evaluator
         except NameError:
-            raise NameError("time_evaluate is only supported when RPC is enabled")
+            raise NameError("time_evaluator is only supported when RPC is enabled")

     def _collect_from_import_tree(self, filter_func):
         """Helper function to collect modules from the tree matching a filter_func, then return it.
2 changes: 1 addition & 1 deletion tests/python/driver/tvmc/test_model.py
@@ -35,7 +35,7 @@ def test_tvmc_workflow(keras_simple):
     assert type(result) is TVMCResult
     assert path.exists(tuning_records)
     assert type(result.outputs) is dict
-    assert type(result.times) is tuple
+    assert type(result.times) is tvm.runtime.module.BenchmarkResult
     assert "output_0" in result.outputs.keys()
2 changes: 1 addition & 1 deletion tests/python/driver/tvmc/test_runner.py
@@ -101,5 +101,5 @@ def test_run_tflite_module__with_profile__valid_input(
         tiger_cat_mobilenet_id in top_5_ids
     ), "tiger cat is expected in the top-5 for mobilenet v1"
     assert type(result.outputs) is dict
-    assert type(result.times) is tuple
+    assert type(result.times) is BenchmarkResult
     assert "output_0" in result.outputs.keys()
13 changes: 13 additions & 0 deletions tests/python/relay/test_backend_graph_executor.py
@@ -16,6 +16,7 @@
 # under the License.
 import numpy as np
 import pytest
+from unittest.mock import patch

 import tvm
 import json
@@ -334,6 +335,18 @@ def test_benchmark():
     assert result.mean > 0
     assert len(result.results) == 2

+    with patch.object(
+        tvm.runtime.module.Module,
+        "time_evaluator",
+        return_value=lambda: tvm.runtime.module.BenchmarkResult([1, 2, 2, 5]),
+    ) as method:
+        result = exe.benchmark(tvm.cpu(), data=data, func_name="run", repeat=2, number=1)
+        assert result.mean == 2.5
+        assert result.median == 2.0
+        assert result.max == 5
+        assert result.min == 1
+        assert result.std == 1.5


 if __name__ == "__main__":
     pytest.main([__file__])
13 changes: 13 additions & 0 deletions tests/python/relay/test_vm.py
@@ -17,6 +17,7 @@
 import numpy as np
 import pytest
 import time
+from unittest.mock import patch

 import tvm
 from tvm import runtime
@@ -967,6 +968,18 @@ def test_benchmark():
     assert result.mean > 0
     assert len(result.results) == 2

+    with patch.object(
+        tvm.runtime.module.Module,
+        "time_evaluator",
+        return_value=lambda x: tvm.runtime.module.BenchmarkResult([1, 2, 2, 5]),
+    ) as method:
+        result = exe.benchmark(tvm.cpu(), data, func_name="main", repeat=2, number=1)
+        assert result.mean == 2.5
+        assert result.median == 2.0
+        assert result.max == 5
+        assert result.min == 1
+        assert result.std == 1.5


 if __name__ == "__main__":
     pytest.main([__file__])
11 changes: 11 additions & 0 deletions tests/python/unittest/test_runtime_measure.py
@@ -20,6 +20,7 @@
 import tvm
 from tvm import te
 from tvm.contrib.utils import tempdir
+from tvm.runtime.module import BenchmarkResult


 def test_min_repeat_ms():
@@ -56,5 +57,15 @@ def my_debug(filename):
     assert ct > 10 + 2


+def test_benchmark_result():
+    r = BenchmarkResult([1, 2, 2, 5])
+    assert r.mean == 2.5
+    assert r.median == 2.0
+    assert r.min == 1
+    assert r.max == 5
+    assert r.std == 1.5


 if __name__ == "__main__":
     test_min_repeat_ms()
+    test_benchmark_result()
Review comment (Contributor): Oh, this is odd. Are we manually calling each test case in each test file? pytest should do this for you.
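As a sanity check on the values asserted above (both in test_benchmark_result and in the mocked benchmark tests), the statistics of [1, 2, 2, 5] work out as follows, assuming BenchmarkResult uses NumPy's default population standard deviation, which is what matches the asserted 1.5:

    import numpy as np

    xs = [1, 2, 2, 5]
    print(np.mean(xs))             # (1 + 2 + 2 + 5) / 4 = 2.5
    print(np.median(xs))           # middle two values (2, 2) average to 2.0
    print(np.min(xs), np.max(xs))  # 1 and 5
    # Population std: sqrt(((1-2.5)**2 + (2-2.5)**2 + (2-2.5)**2 + (5-2.5)**2) / 4)
    #               = sqrt((2.25 + 0.25 + 0.25 + 6.25) / 4) = sqrt(2.25) = 1.5
    print(np.std(xs))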