diff --git a/profile/run_profile.py b/profile/run_profile.py index 3e85378..4b36f36 100644 --- a/profile/run_profile.py +++ b/profile/run_profile.py @@ -1,5 +1,4 @@ import argparse -import colorama from gptools.stan import compile_model from gptools.util import Timer from gptools.util.fft import transform_irfft @@ -101,17 +100,16 @@ def __main__(args: Optional[List[str]] = None) -> None: "data": data, "seed": args.seed, } - remaining = args.timeout - total_timer.duration if args.method == "sample": iter_warmup = args.iter_warmup or args.iter_sampling fit = call_with_timeout( - remaining, model.sample, iter_sampling=args.iter_sampling, chains=1, + args.timeout, model.sample, iter_sampling=args.iter_sampling, chains=1, threads_per_chain=1, show_progress=args.show_progress, iter_warmup=iter_warmup, **kwargs, ) elif args.method == "variational": fit = call_with_timeout( - remaining, model.variational, output_samples=args.iter_sampling, + args.timeout, model.variational, output_samples=args.iter_sampling, require_converged=not args.ignore_converged, **kwargs, ) else: # pragma: no cover @@ -162,11 +160,6 @@ def __main__(args: Optional[List[str]] = None) -> None: if args.method == "sample" and args.show_diagnostics: print(fit.diagnose()) - # Complain if the timeout was exceeded by more than a second. - if args.timeout is not None and total_timer.duration > args.timeout + 1: - print(f"{colorama.Back.RED}TOTAL DURATION EXCEEDED TIMEOUT: {total_timer.duration:.1f} > " - f"{args.timeout:.1f}{colorama.Back.RESET}") - if __name__ == "__main__": __main__()