Skip to content

Commit

Permalink
Use full timeout duration for each run rather than just remainder.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Sep 14, 2023
1 parent 300c384 commit 71c0836
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions profile/run_profile.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import colorama
from gptools.stan import compile_model
from gptools.util import Timer
from gptools.util.fft import transform_irfft
Expand Down Expand Up @@ -101,17 +100,16 @@ def __main__(args: Optional[List[str]] = None) -> None:
"data": data,
"seed": args.seed,
}
remaining = args.timeout - total_timer.duration
if args.method == "sample":
iter_warmup = args.iter_warmup or args.iter_sampling
fit = call_with_timeout(
remaining, model.sample, iter_sampling=args.iter_sampling, chains=1,
args.timeout, model.sample, iter_sampling=args.iter_sampling, chains=1,
threads_per_chain=1, show_progress=args.show_progress,
iter_warmup=iter_warmup, **kwargs,
)
elif args.method == "variational":
fit = call_with_timeout(
remaining, model.variational, output_samples=args.iter_sampling,
args.timeout, model.variational, output_samples=args.iter_sampling,
require_converged=not args.ignore_converged, **kwargs,
)
else: # pragma: no cover
Expand Down Expand Up @@ -162,11 +160,6 @@ def __main__(args: Optional[List[str]] = None) -> None:
if args.method == "sample" and args.show_diagnostics:
print(fit.diagnose())

# Complain if the timeout was exceeded by more than a second.
if args.timeout is not None and total_timer.duration > args.timeout + 1:
print(f"{colorama.Back.RED}TOTAL DURATION EXCEEDED TIMEOUT: {total_timer.duration:.1f} > "
f"{args.timeout:.1f}{colorama.Back.RESET}")


if __name__ == "__main__":
__main__()

0 comments on commit 71c0836

Please sign in to comment.