Skip to content

Commit

Permalink
[hexagon][testing] refactor benchmark-table code
Browse files Browse the repository at this point in the history
Generalize the benchmark-table code to support arbitrary
independent values. This supports future changes to the benchmark
code.
  • Loading branch information
Christian Convey committed May 25, 2022
1 parent 92cc5b0 commit f523fac
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 113 deletions.
154 changes: 41 additions & 113 deletions tests/python/contrib/test_hexagon/benchmark_hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@

import os
import os.path
import pathlib
import sys
import pytest
import numpy as np
import logging
import tempfile
import csv

import tvm.testing
from tvm import te
from tvm.contrib.hexagon.build import HexagonLauncherRPC
from .benchmark_util import BenchmarksTable

RPC_SERVER_PORT = 7070

Expand All @@ -40,7 +39,6 @@
# triggering TIME_WAIT state on the server socket. This prevents another
# server to bind to the same port until the wait time elapses.


@tvm.testing.requires_hexagon
def test_elemwise_add(hexagon_launcher: HexagonLauncherRPC):
"""
Expand All @@ -58,112 +56,22 @@ def test_elemwise_add(hexagon_launcher: HexagonLauncherRPC):
print("-" * 80)
print()

# TODO: We should move this into a separate test fixture, to make it easier to write
# additional benchmarking functions. We'd just need to generalize the assumptions regarding
# the particular fields being tracked as independent variables.
class benchmark_results_collection:
def __init__(self):
self.row_dicts_ = []

def num_failures(self):
num = 0
for d in self.row_dicts_:
if d["status"] == "FAIL":
num += 1
return num

def num_skips(self):
num = 0
for d in self.row_dicts_:
if d["status"] == "SKIP":
num += 1
return num

def record_success(
self, dtype, sched_type, mem_scope, num_vecs_per_tensor, benchmark_result
):
median_usec = benchmark_result.median * 1000000
min_usec = benchmark_result.min * 1000000
max_usec = benchmark_result.max * 1000000

self.row_dicts_.append(
{
"dtype": dtype,
"sched_type": sched_type,
"mem_scope": mem_scope,
"num_vecs_per_tensor": num_vecs_per_tensor,
"status": "OK",
"median(µsec)": f"{median_usec:.3}",
"min(µsec)": f"{min_usec:.3}",
"max(µsec)": f"{max_usec:.3}",
}
)

def record_failure(self, dtype, sched_type, mem_scope, num_vecs_per_tensor, error_text):
self.row_dicts_.append(
{
"dtype": dtype,
"sched_type": sched_type,
"mem_scope": mem_scope,
"num_vecs_per_tensor": num_vecs_per_tensor,
"status": "FAIL",
"comment": error_text,
}
)

def record_skip(self, dtype, sched_type, mem_scope, num_vecs_per_tensor, comment_text):
self.row_dicts_.append(
{
"dtype": dtype,
"sched_type": sched_type,
"mem_scope": mem_scope,
"num_vecs_per_tensor": num_vecs_per_tensor,
"status": "SKIP",
"comment": comment_text,
}
)

def dump(self, f):
csv.register_dialect(
"benchmarks",
delimiter="\t",
quotechar='"',
quoting=csv.QUOTE_MINIMAL,
)

fieldnames = [
"dtype",
"sched_type",
"mem_scope",
"num_vecs_per_tensor",
"status",
"median(µsec)",
"min(µsec)",
"max(µsec)",
"comment",
]

writer = csv.DictWriter(f, fieldnames, dialect="benchmarks", restval="")

writer.writeheader()
for d in self.row_dicts_:
writer.writerow(d)

br = benchmark_results_collection()
bt = BenchmarksTable()

# Create and benchmark a single primfunc.
# If an unexpected problem occurs, raise an exception. Otherwise add a row of output to 'br'.
# If an unexpected problem occurs, raise an exception. Otherwise add a row of output to 'bt'.
def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):
version_name = f"dtype:{dtype}-schedtype:{sched_type}-memscope:{mem_scope}-numvecs:{num_vectors_per_tensor}"
print()
print(f"CONFIGURATION: {version_name}")

if num_vectors_per_tensor == 2048 and mem_scope == "global.vtcm":
br.record_skip(
dtype,
sched_type,
mem_scope,
num_vectors_per_tensor,
f"Expect to exceed VTCM budget.",
bt.record_skip(
dtype=dtype,
sched_type=sched_type,
mem_scope=mem_scope,
num_vectors_per_tensor=num_vectors_per_tensor,
comments="Expect to exceed VTCM budget.",
)
return

Expand Down Expand Up @@ -255,25 +163,45 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):
timer = mod.time_evaluator("elemwise_add", sess.device, number=10, repeat=1)
timing_result = timer(A_data, B_data, C_data)

print("TIMING RESULT: {}".format(timing_result))

# Verify that the computation actually happened, and produced the correct result.
result = C_data.numpy()
tvm.testing.assert_allclose(host_numpy_C_data_expected, result)

br.record_success(
dtype, sched_type, mem_scope, num_vectors_per_tensor, timing_result
)
bt.record_success(
timing_result,
dtype=dtype,
sched_type=sched_type,
mem_scope=mem_scope,
num_vectors_per_tensor=num_vectors_per_tensor,
)

except Exception as err:
f.write("ERROR:\n")
f.write("{}\n".format(err))
br.record_failure(
dtype, sched_type, mem_scope, num_vectors_per_tensor, f"See {report_path}"
)
bt.record_fail(
dtype=dtype,
sched_type=sched_type,
mem_scope=mem_scope,
num_vectors_per_tensor=num_vectors_per_tensor,
comments=f"See {report_path}"
)

# -----------------------------------------------------------------------------------------------

csv_column_order = [
'dtype',
'sched_type',
'mem_scope',
'num_vectors_per_tensor',
'row_status',
"timings_min_usecs",
"timings_max_usecs",
"timings_median_usecs",
"timings_mean_usecs",
"timings_stddev_usecs",
'comments',
]

# Hexagon v69 allows more dtypes, but we're sticking with v68 for now.
for dtype in [
"int8",
Expand All @@ -300,7 +228,7 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):
test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor)

# Report our progress.
br.dump(sys.stdout)
bt.print_csv(sys.stdout, csv_column_order)

print("-" * 80)
print(f"OUTPUT DIRECTORY: {host_output_dir}")
Expand All @@ -309,8 +237,8 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):

tabular_output_filename = os.path.join(host_output_dir, "benchmark-results.csv")
with open(tabular_output_filename, "w") as csv_file:
br.dump(csv_file)
bt.print_csv(csv_file, csv_column_order)
print(f"BENCHMARK RESULTS FILE: {tabular_output_filename}")

if br.num_failures() > 0:
if bt.has_fail() > 0:
pytest.fail("At least one benchmark configuration failed", pytrace=False)
135 changes: 135 additions & 0 deletions tests/python/contrib/test_hexagon/benchmark_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import csv

class BenchmarksTable:
"""
Stores/reports the result of benchmark runs.
Each line item has a status: success, fail, or skip.
Each 'success' line item must include benchmark data,
in the form provided by TVM's `time_evaluator` mechanism.
Each line item may also specify values for any subset of
the columns provided to the table's construstor.
"""
BUILTIN_COLUMN_NAMES = set([
"row_status",
"timings_min_usecs",
"timings_max_usecs",
"timings_median_usecs",
"timings_mean_usecs",
"timings_stddev_usecs",
])

def __init__(self):
self._line_items = []

def validate_user_supplied_kwargs(self, kwarg_dict):
name_conflicts = set(kwarg_dict).intersection(self.BUILTIN_COLUMN_NAMES)

if name_conflicts:
name_list = ', '.join(name_conflicts)
raise Exception(f'Attempting to supply values for built-in column names: {name_list}')

def record_success(self, timings, **kwargs):
"""
`timings` : Assumed to have the structure and meaning of
the timing results provided by TVM's `time_evaluator`
mechanism.
`kwargs` : Optional values for any of the other columns
defined for this benchmark table.
"""
self.validate_user_supplied_kwargs(kwargs)

line_item = dict(kwargs)

line_item['row_status'] = 'SUCCESS'

line_item['timings_min_usecs'] = timings.min * 1000000
line_item['timings_max_usecs'] = timings.max * 1000000
line_item['timings_median_usecs'] = timings.median * 1000000
line_item['timings_stddev_usecs'] = timings.std * 1000000
line_item['timings_mean_usecs'] = timings.mean * 1000000

self._line_items.append(line_item)

def record_skip(self, **kwargs):
self.validate_user_supplied_kwargs(kwargs)

line_item = dict(kwargs)
line_item['row_status'] = 'SKIP'
self._line_items.append(line_item)

def record_fail(self, **kwargs):
self.validate_user_supplied_kwargs(kwargs)

line_item = dict(kwargs)
line_item['row_status'] = 'FAIL'
self._line_items.append(line_item)

def has_fail(self):
"""
Returns True if the table contains at least one 'fail' line item,
otherwise returns False.
"""
return any(item['row_status']=='FAIL' for item in self._line_items)

def print_csv(self, f, column_name_order, timing_decimal_places=3):
"""
Print the benchmark results as a csv.
`f` : The output stream.
`column_name_order`: an iterable sequence of column names, indicating the
left-to-right ordering of columns in the CSV output.
The CSV output will contain only those columns that are mentioned in
this list.
`timing_decimal_places`: for the numeric timing values, this is the
number of decimal places to provide in the printed output.
For example, a value of 3 is equivalent to the Python formatting string
`'{:.3f}'`
"""
writer = csv.DictWriter(f, column_name_order, dialect="excel-tab", restval="", extrasaction='ignore')

writer.writeheader()

for line_item_dict in self._line_items:
# Use a copy of the line-item dictionary, because we might do some modifications
# for the sake of rendering...
csv_line_dict = dict(line_item_dict)

for col_name in [
"timings_min_usecs",
"timings_max_usecs",
"timings_median_usecs",
"timings_stddev_usecs",
"timings_mean_usecs",
]:
if col_name in csv_line_dict:
old_value = csv_line_dict[col_name]
assert isinstance(old_value, float), \
f"Formatting code assumes that column {col_name} is some col_nameind of float, but its actual type is {type(old_value)}"
str_value = f"{old_value:>0.{timing_decimal_places}f}"
csv_line_dict[col_name] = str_value

writer.writerow(csv_line_dict)

0 comments on commit f523fac

Please sign in to comment.