Skip to content

Commit

Permalink
also track time in benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
tybug committed Aug 14, 2024
1 parent 7557e14 commit cb8aa7e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
8 changes: 7 additions & 1 deletion hypothesis-python/benchmark/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,18 @@

import inspect
import json
import time
from collections import defaultdict

import pytest
from _pytest.monkeypatch import MonkeyPatch

mode = "calls"
# we'd like to support xdist here for parallelism, but a session-scope fixture won't
# be enough: https://github.com/pytest-dev/pytest-xdist/issues/271. need a lockfile
# or equivalent.
shrink_calls = defaultdict(list)
timer = time.process_time


def pytest_collection_modifyitems(config, items):
Expand Down Expand Up @@ -51,8 +54,11 @@ def record_shrink_calls(calls):
old_shrink = Shrinker.shrink

def shrink(self, *args, **kwargs):
t = timer()
v = old_shrink(self, *args, **kwargs)
record_shrink_calls(self.engine.call_count - self.initial_calls)
time = timer() - t
calls = self.engine.call_count - self.initial_calls
record_shrink_calls({"calls": calls, "time": time})
return v

monkeypatch.setattr(Shrinker, "shrink", shrink)
Expand Down
35 changes: 21 additions & 14 deletions hypothesis-python/benchmark/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,17 @@
new_names.append(name)
names = new_names


# either "time" or "calls"
statistic = "calls"
# name : average calls
old_values = {}
new_values = {}
for name in names:

# mean across the different minimal() calls in a single test function, then
# median across the n iterations we ran that for to reduce error
old_vals = [statistics.mean(run[name]) for run in old_runs]
new_vals = [statistics.mean(run[name]) for run in new_runs]
old_vals = [statistics.mean(r[statistic] for r in run[name]) for run in old_runs]
new_vals = [statistics.mean(r[statistic] for r in run[name]) for run in new_runs]
old_values[name] = statistics.median(old_vals)
new_values[name] = statistics.median(new_vals)

Expand All @@ -70,20 +71,21 @@
old = old_values[name]
new = new_values[name]
diff = old - new
diff_times = (old - new) / old
if old == 0:
diff_times = 0
else:
diff_times = (old - new) / old
if 0 < diff_times < 1:
diff_times = (1 / (1 - diff_times)) - 1
diffs[name] = (diff, diff_times)

print(f"{name} {int(diff)} ({int(old)} -> {int(new)}, {round(diff_times, 1)}✕)")
print(f"{name} {diff} ({old} -> {new}, {round(diff_times, 1)}✕)")

diffs = dict(sorted(diffs.items(), key=lambda kv: kv[1][0]))
diffs_value = [v[0] for v in diffs.values()]
diffs_percentage = [v[1] for v in diffs.values()]

print(
f"mean: {int(statistics.mean(diffs_value))}, median: {int(statistics.median(diffs_value))}"
)
print(f"mean: {statistics.mean(diffs_value)}, median: {statistics.median(diffs_value)}")


# https://stackoverflow.com/a/65824524
Expand All @@ -100,15 +102,20 @@ def align_axes(ax1, ax2):
ax1.set_ylim(bottom=ax1_ylims[1] * ax2_yratio)


ax1 = sns.barplot(diffs_value, color="b", alpha=0.7, label="shrink call change")
ax1 = sns.barplot(diffs_value, color="b", alpha=0.7, label="absolute change")
ax2 = plt.twinx()
sns.barplot(diffs_percentage, color="r", alpha=0.7, label=r"n✕ change", ax=ax2)
sns.barplot(diffs_percentage, color="r", alpha=0.7, ax=ax2, label="n✕ change")

ax1.set_title("old shrinks - new shrinks (aka shrinks saved, higher is better)")
ax1.set_title(
"old shrinks - new shrinks (aka shrinks saved, higher is better)"
if statistic == "calls"
else "old time - new time in seconds (aka time saved, higher is better)"
)
ax1.set_xticks([])
align_axes(ax1, ax2)
legend = ax1.legend(labels=["shrink call change", "n✕ change"])
legend.legend_handles[0].set_color("b")
legend.legend_handles[1].set_color("r")
legend1 = ax1.legend(loc="upper left")
legend1.legend_handles[0].set_color("b")
legend2 = ax2.legend(loc="lower right")
legend2.legend_handles[0].set_color("r")

plt.show()

0 comments on commit cb8aa7e

Please sign in to comment.