Skip to content

Commit

Permalink
Add n_jobs param to BenchmarkConfig
Browse files Browse the repository at this point in the history
GitOrigin-RevId: feaf8a226e10527cd1b778bdfa8b7fc5933224c6
  • Loading branch information
kboyd committed Apr 10, 2024
1 parent f175bb9 commit c4181f9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/gretel_trainer/benchmark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ def __init__(
trainer: bool = False,
working_dir: Optional[Union[str, Path]] = None,
additional_report_scores: Optional[list[str]] = None,
n_jobs: int = 5,
):
self.project_display_name = project_display_name or _default_name()
self.working_dir = Path(working_dir or self.project_display_name)
self.refresh_interval = refresh_interval
self.trainer = trainer
self.additional_report_scores = additional_report_scores or []
self.n_jobs = n_jobs


class Timer:
Expand Down
2 changes: 1 addition & 1 deletion src/gretel_trainer/benchmark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
self._gretel_executors: dict[RunKey, Executor] = {}
self._custom_executors: dict[RunKey, Executor] = {}
self._trainer_project_names: dict[str, str] = {}
self._thread_pool = ThreadPoolExecutor(5)
self._thread_pool = ThreadPoolExecutor(self._config.n_jobs)
self._futures: dict[FutureKeyT, Future] = {}

self._report_scores = {
Expand Down

0 comments on commit c4181f9

Please sign in to comment.