improve benchmark_runner usage efficiency #16653

Merged 1 commit on Oct 18, 2023
21 changes: 18 additions & 3 deletions tests/conftest.py
@@ -3,6 +3,8 @@

 import dataclasses
 import datetime
+import functools
+import math
 import multiprocessing
 import os
 import random
@@ -62,7 +64,7 @@
 from tests.core.data_layer.util import ChiaRoot
 from tests.core.node_height import node_height_at_least
 from tests.simulation.test_simulation import test_constants_modified
-from tests.util.misc import BenchmarkRunner
+from tests.util.misc import BenchmarkRunner, GcMode, _AssertRuntime, measure_overhead

 multiprocessing.set_start_method("spawn")

@@ -81,10 +83,23 @@ def seeded_random_fixture() -> random.Random:
     return seeded_random


+@pytest.fixture(name="benchmark_runner_overhead", scope="session")
+def benchmark_runner_overhead_fixture() -> float:
+    return measure_overhead(
+        manager_maker=functools.partial(
+            _AssertRuntime,
+            gc_mode=GcMode.nothing,
+            seconds=math.inf,
+            print=False,
+        ),
+        cycles=100,
+    )
+
+
 @pytest.fixture(name="benchmark_runner")
-def benchmark_runner_fixture(request: SubRequest) -> BenchmarkRunner:
+def benchmark_runner_fixture(request: SubRequest, benchmark_runner_overhead: float) -> BenchmarkRunner:
     label = request.node.name
-    return BenchmarkRunner(label=label)
+    return BenchmarkRunner(label=label, overhead=benchmark_runner_overhead)


 @pytest.fixture(name="node_name_for_file")
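For reference, a test consumes the updated fixture roughly as in the sketch below; the test name, the workload call, and the 5-second limit are illustrative and not part of this change. The overhead calibration now runs once per session instead of inside every assert_runtime call.

def test_expensive_operation(benchmark_runner: BenchmarkRunner) -> None:
    # The runner already carries the session-wide overhead measurement.
    with benchmark_runner.assert_runtime(seconds=5):
        do_expensive_work()  # hypothetical workload under test
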
41 changes: 15 additions & 26 deletions tests/util/misc.py
@@ -6,7 +6,6 @@
 import functools
 import gc
 import logging
-import math
 import os
 import subprocess
 import sys
@@ -39,7 +38,9 @@ class GcMode(enum.Enum):

 @contextlib.contextmanager
 def manage_gc(mode: GcMode) -> Iterator[None]:
-    if mode == GcMode.precollect:
+    if mode == GcMode.nothing:
+        yield
+    elif mode == GcMode.precollect:
         gc.collect()
         yield
     elif mode == GcMode.disable:
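GcMode.nothing is a pure pass-through; it appears to exist so that the session-level calibration (gc_mode=GcMode.nothing in conftest.py) measures only the context-manager bookkeeping itself. A minimal illustration, not taken from this diff:

with manage_gc(mode=GcMode.nothing):
    pass  # no gc.collect() and no disabling/re-enabling of the collector
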
@@ -71,7 +72,7 @@ class RuntimeResults:
     end: float
     duration: float
     entry_line: str
-    overhead: float
+    overhead: Optional[float]

     def block(self, label: str = "") -> str:
         # The entry line is reported starting at the beginning of the line to trigger
@@ -82,7 +83,7 @@ def block(self, label: str = "") -> str:
             Measuring runtime: {label}
             {self.entry_line}
             run time: {self.duration}
-            overhead: {self.overhead}
+            overhead: {self.overhead if self.overhead is not None else "not measured"}
             """
         )

@@ -94,13 +95,13 @@ class AssertRuntimeResults:
     end: float
     duration: float
     entry_line: str
-    overhead: float
+    overhead: Optional[float]
     limit: float
     ratio: float

     @classmethod
     def from_runtime_results(
-        cls, results: RuntimeResults, limit: float, entry_line: str, overhead: float
+        cls, results: RuntimeResults, limit: float, entry_line: str, overhead: Optional[float]
     ) -> AssertRuntimeResults:
         return cls(
             start=results.start,
@@ -121,7 +122,7 @@ def block(self, label: str = "") -> str:
             Asserting maximum duration: {label}
             {self.entry_line}
             run time: {self.duration}
-            overhead: {self.overhead}
+            overhead: {self.overhead if self.overhead is not None else "not measured"}
             allowed: {self.limit}
             percent: {self.percent_str()}
             """
@@ -164,19 +165,11 @@ def measure_runtime(
     label: str = "",
     clock: Callable[[], float] = thread_time,
     gc_mode: GcMode = GcMode.disable,
-    calibrate: bool = True,
+    overhead: Optional[float] = None,
     print_results: bool = True,
 ) -> Iterator[Future[RuntimeResults]]:
     entry_line = caller_file_and_line()

-    def manager_maker() -> contextlib.AbstractContextManager[Future[RuntimeResults]]:
-        return measure_runtime(clock=clock, gc_mode=gc_mode, calibrate=False, print_results=False)
-
-    if calibrate:
-        overhead = measure_overhead(manager_maker=manager_maker)
-    else:
-        overhead = 0
-
     results_future: Future[RuntimeResults] = Future()

     with manage_gc(mode=gc_mode):
@@ -188,7 +181,8 @@ def manager_maker() -> contextlib.AbstractContextManager[Future[RuntimeResults]]:
             end = clock()

             duration = end - start
-            duration -= overhead
+            if overhead is not None:
+                duration -= overhead

             results = RuntimeResults(
                 start=start,
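With the calibrate flag removed, callers that want overhead-corrected timings supply a previously measured value; leaving it as None keeps the raw duration. A usage sketch with hypothetical names:

# precomputed_overhead would come from measure_overhead(); None skips the correction.
with measure_runtime(overhead=precomputed_overhead, print_results=False) as results_future:
    do_work()  # hypothetical workload
results = results_future.result()  # overhead already subtracted when one was given
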
@@ -232,9 +226,8 @@ class _AssertRuntime:
     label: str = ""
     clock: Callable[[], float] = thread_time
     gc_mode: GcMode = GcMode.disable
-    calibrate: bool = True
     print: bool = True
-    overhead: float = 0
+    overhead: Optional[float] = None
     entry_line: Optional[str] = None
     _results: Optional[AssertRuntimeResults] = None
     runtime_manager: Optional[contextlib.AbstractContextManager[Future[RuntimeResults]]] = None
@@ -243,15 +236,9 @@ class _AssertRuntime:

     def __enter__(self) -> Future[AssertRuntimeResults]:
         self.entry_line = caller_file_and_line()
-        if self.calibrate:
-
-            def manager_maker() -> contextlib.AbstractContextManager[Future[AssertRuntimeResults]]:
-                return dataclasses.replace(self, seconds=math.inf, calibrate=False, print=False)
-
-            self.overhead = measure_overhead(manager_maker=manager_maker)
-
         self.runtime_manager = measure_runtime(
-            clock=self.clock, gc_mode=self.gc_mode, calibrate=False, print_results=False
+            clock=self.clock, gc_mode=self.gc_mode, overhead=self.overhead, print_results=False
         )
         self.runtime_results_callable = self.runtime_manager.__enter__()
         self.results_callable: Future[AssertRuntimeResults] = Future()
@@ -292,10 +279,12 @@ def __exit__(
 class BenchmarkRunner:
     enable_assertion: bool = True
     label: Optional[str] = None
+    overhead: Optional[float] = None

     @functools.wraps(_AssertRuntime)
     def assert_runtime(self, *args: Any, **kwargs: Any) -> _AssertRuntime:
         kwargs.setdefault("enable_assertion", self.enable_assertion)
+        kwargs.setdefault("overhead", self.overhead)
         if self.label is not None:
             kwargs.setdefault("label", self.label)
         return _AssertRuntime(*args, **kwargs)
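Putting the pieces together, the flow this change enables looks roughly like the following (shown outside pytest, using only names that appear in this diff; the two-second limit and the measured call are placeholders):

overhead = measure_overhead(
    manager_maker=functools.partial(
        _AssertRuntime, gc_mode=GcMode.nothing, seconds=math.inf, print=False
    ),
    cycles=100,
)
runner = BenchmarkRunner(label="example", overhead=overhead)
with runner.assert_runtime(seconds=2):
    run_benchmarked_code()  # hypothetical code under measurement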