Skip to content

Commit

Permalink
simplify table class
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian Convey committed May 25, 2022
1 parent 8603cde commit a821f36
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 128 deletions.
24 changes: 8 additions & 16 deletions tests/python/contrib/test_hexagon/benchmark_hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
# triggering TIME_WAIT state on the server socket. This prevents another
# server to bind to the same port until the wait time elapses.


@tvm.testing.requires_hexagon
def test_elemwise_add(hexagon_launcher: HexagonLauncherRPC):
"""
Expand All @@ -58,23 +57,16 @@ def test_elemwise_add(hexagon_launcher: HexagonLauncherRPC):
print("-" * 80)
print()

br = benchmarks_table({
'dtype':'dtype',
'sched_type':'sched_type',
'mem_scope':'mem_scope',
'num_vectors_per_tensor':'# 2KB vectors per tensor',
'comments':'comments',
})

bt = benchmarks_table()

# Create and benchmark a single primfunc.
# If an unexpected problem occurs, raise an exception. Otherwise add a row of output to 'br'.
# If an unexpected problem occurs, raise an exception. Otherwise add a row of output to 'bt'.
def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):
version_name = f"dtype:{dtype}-schedtype:{sched_type}-memscope:{mem_scope}-numvecs:{num_vectors_per_tensor}"
print(f"CONFIGURATION: {version_name}")

if num_vectors_per_tensor == 2048 and mem_scope == "global.vtcm":
br.record_skip(
bt.record_skip(
dtype=dtype,
sched_type=sched_type,
mem_scope=mem_scope,
Expand Down Expand Up @@ -177,7 +169,7 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):
result = C_data.numpy()
tvm.testing.assert_allclose(host_numpy_C_data_expected, result)

br.record_success(
bt.record_success(
timing_result,
dtype=dtype,
sched_type=sched_type,
Expand All @@ -188,7 +180,7 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):
except Exception as err:
f.write("ERROR:\n")
f.write("{}\n".format(err))
br.record_failure(
bt.record_fail(
dtype=dtype,
sched_type=sched_type,
mem_scope=mem_scope,
Expand Down Expand Up @@ -236,7 +228,7 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):
test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor)

# Report our progress.
br.print_csv(sys.stdout, csv_column_order)
bt.print_csv(sys.stdout, csv_column_order)

print("-" * 80)
print(f"OUTPUT DIRECTORY: {host_output_dir}")
Expand All @@ -245,8 +237,8 @@ def test_one_config(dtype, sched_type, mem_scope, num_vectors_per_tensor):

tabular_output_filename = os.path.join(host_output_dir, "benchmark-results.csv")
with open(tabular_output_filename, "w") as csv_file:
br.print_csv(csv_file, csv_column_order)
bt.print_csv(csv_file, csv_column_order)
print(f"BENCHMARK RESULTS FILE: {tabular_output_filename}")

if br.has_fail() > 0:
if bt.has_fail() > 0:
pytest.fail("At least one benchmark configuration failed", pytrace=False)
130 changes: 18 additions & 112 deletions tests/python/contrib/test_hexagon/benchmarks_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,108 +30,22 @@ class benchmarks_table:
Each line item may also specify values for any subset of
the columns provided to the table's constructor.
"""
BUILTIN_COLUMN_NAMES_TO_DESCS = {
"row_status":"status",
"timings_median_usecs":"median(µsec)",
"timings_min_usecs":"min(µsec)",
"timings_max_usecs":"max(µsec)",
}

class column_metadata_:
    # Describes one column of the benchmarks table.
    def __init__(self, name, is_reserved, header_text):
        """
        `name` : internal identifier used to reference this column.
        `is_reserved` : True when the column is built-in, i.e. populated by
            the table itself rather than supplied by callers.
        `header_text` : human-readable text emitted in the CSV header row.
        """
        self.name = name
        self.is_reserved = is_reserved
        self.header_text = header_text

class column_collection_metadata_:
    """
    A collection of column_metadata_ objects, indexed both by column
    name and by header text.  Both indices must be unique.
    """

    def __init__(self):
        self.by_name = {}
        self.by_header_text = {}

    def add(self, cm):
        """Register `cm`, rejecting duplicate names or header texts."""
        if cm.name in self.by_name:
            raise Exception(f"A column already exists with name '{cm.name}'")

        if cm.header_text in self.by_header_text:
            raise Exception(f"A column already exists with header_text '{cm.header_text}'")

        self.by_name[cm.name] = cm
        self.by_header_text[cm.header_text] = cm

    def get_column_names(self):
        """The set of all registered column names."""
        return {name for name in self.by_name}

    def get_unreserved_column_names(self):
        """The set of user-defined (non-reserved) column names."""
        return {name for name, meta in self.by_name.items() if not meta.is_reserved}

    def get_reserved_column_names(self):
        """The set of built-in (reserved) column names."""
        return {name for name, meta in self.by_name.items() if meta.is_reserved}

    def get_ordered_by_name_sequence(self, name_sequence):
        """
        Returns a list of column_metadata_ objects, ordered according to
        `name_sequence`.  Every name in `name_sequence` must be registered.
        """
        ordered = []
        for name in name_sequence:
            assert name in self.by_name
            ordered.append(self.by_name[name])
        return ordered

    def convert_dict_key_from_column_name_to_header_text(self, d_in):
        """
        Returns a copy of `d_in` re-keyed from column name to the
        corresponding header text.  Useful for things like csv.DictWriter.
        """
        return {self.by_name[key].header_text: value for key, value in d_in.items()}

def __init__(self, user_column_defns):
    """
    `user_column_defns` : A dictionary of the form
        (column_name : column_description).
        The combination of this dictionary and the
        BUILTIN_COLUMN_NAMES_TO_DESCS dictionary defines the set
        of columns that the benchmark table supports.
        In the combined dictionary, no two columns can have
        the same name or the same description.
    """
    # Single registry for all column metadata, indexed by name and header text.
    self.all_cols_metadata_ = self.column_collection_metadata_()

    # Built-in columns are marked reserved (is_reserved=True) so callers
    # cannot supply values for them directly.
    for col_name, col_header_text in self.BUILTIN_COLUMN_NAMES_TO_DESCS.items():
        self.all_cols_metadata_.add(self.column_metadata_(col_name, True, col_header_text))

    # User-defined columns are unreserved (is_reserved=False).
    for col_name, col_header_text in user_column_defns.items():
        self.all_cols_metadata_.add(self.column_metadata_(col_name, False, col_header_text))

# Column names populated by the table itself (via record_success /
# record_fail / record_skip); callers may not supply values for these
# through **kwargs.  A set literal replaces the redundant set([...]) wrapper.
BUILTIN_COLUMN_NAMES = {
    "row_status",
    "timings_median_usecs",
    "timings_min_usecs",
    "timings_max_usecs",
}

def __init__(self):
    # Each line item is a dict mapping column name -> value.
    self.line_items_ = []

def validate_user_supplied_kwargs(self, kwarg_dict):
    """
    Raise an exception if `kwarg_dict` attempts to supply a value for
    any built-in (table-managed) column; see BUILTIN_COLUMN_NAMES.
    Returns None when `kwarg_dict` is acceptable.
    """
    # NOTE(review): span previously interleaved the pre-commit body
    # (reserved/undefined-name checks against all_cols_metadata_) with the
    # simplified version; only the simplified check is kept here.
    name_conflicts = set(kwarg_dict.keys()).intersection(self.BUILTIN_COLUMN_NAMES)

    if len(name_conflicts) > 0:
        name_list = ', '.join(name_conflicts)
        # Fixed typo in the error message: "buil-in" -> "built-in".
        raise Exception(f'Attempting to supply values for built-in column names: {name_list}')

def record_success(self, timings, **kwargs):
"""
Expand All @@ -147,6 +61,7 @@ def record_success(self, timings, **kwargs):
line_item = dict(kwargs)

line_item['row_status'] = 'SUCCESS'

line_item['timings_min_usecs'] = timings.min * 1000000
line_item['timings_max_usecs'] = timings.max * 1000000
line_item['timings_median_usecs'] = timings.median * 1000000
Expand All @@ -169,7 +84,7 @@ def record_fail(self, **kwargs):

def has_fail(self):
"""
Returns True if the table contains at least one 'file' line item,
Returns True if the table contains at least one 'fail' line item,
otherwise returns False.
"""
for li in self.line_items_:
Expand All @@ -186,8 +101,6 @@ def print_csv(self, f, column_name_order, timing_decimal_places=3):
`column_name_order`: an iterable sequence of column names, indicating the
order of column in the CSV output.
Each string must be one of the column names provided by
BUILTIN_COLUMN_NAMES_TO_DESCS or provided to the class constructor.
The CSV output will contain only those columns that are mentioned in
this list.
Expand All @@ -197,20 +110,18 @@ def print_csv(self, f, column_name_order, timing_decimal_places=3):
For example, a value of 3 is equivalent to the Python formatting string
`'{:.3f}'`
"""

csv.register_dialect(
"benchmarks",
delimiter="\t",
quotechar='"',
quoting=csv.QUOTE_MINIMAL,
)

output_order_cm_list = self.all_cols_metadata_.get_ordered_by_name_sequence(column_name_order)

output_order_header_texts = [ cm.header_text for cm in output_order_cm_list ]

writer = csv.DictWriter(f, output_order_header_texts, dialect="benchmarks", restval="")
writer = csv.DictWriter(f, column_name_order, dialect="benchmarks", restval="", extrasaction='ignore')

writer.writeheader()

for line_item_dict in self.line_items_:
for k in [
"timings_median_usecs",
Expand All @@ -223,9 +134,4 @@ def print_csv(self, f, column_name_order, timing_decimal_places=3):
str_value = f"{old_value:>0.{timing_decimal_places}f}"
line_item_dict[k] = str_value

# self.line_items_ is a list of dictionaries, where each dictionary is indexed
# by column *name*. DictWriter requires dictionaries that are indexed by *header text*.
csv_line_item_dict = \
self.all_cols_metadata_.convert_dict_key_from_column_name_to_header_text(line_item_dict)

writer.writerow(csv_line_item_dict)
writer.writerow(line_item_dict)

0 comments on commit a821f36

Please sign in to comment.