diff --git a/src/gpt_review/_openai.py b/src/gpt_review/_openai.py index d7f43c7e..9bd9debe 100644 --- a/src/gpt_review/_openai.py +++ b/src/gpt_review/_openai.py @@ -1,11 +1,11 @@ """Open AI API Call Wrapper.""" import logging -import time import openai from openai.error import RateLimitError import gpt_review.constants as C +from gpt_review.utils import _retry_with_exponential_backoff from gpt_review.context import _load_azure_openai_context @@ -96,12 +96,7 @@ def _call_gpt( return completion.choices[0].message.content # type: ignore except RateLimitError as error: if retry < C.MAX_RETRIES: - logging.warning("Call to GPT failed due to rate limit, retry attempt %s of %s", retry, C.MAX_RETRIES) - - wait_time = int(error.headers["Retry-After"]) if error.headers["Retry-After"] else retry * 10 - logging.warning("Waiting for %s seconds before retrying.", wait_time) - - time.sleep(wait_time) + _retry_with_exponential_backoff(retry, error.headers["Retry-After"]) return _call_gpt(prompt, temperature, max_tokens, top_p, frequency_penalty, presence_penalty, retry + 1) raise RateLimitError("Retry limit exceeded") from error diff --git a/src/gpt_review/utils.py b/src/gpt_review/utils.py new file mode 100644 index 00000000..7d6d79df --- /dev/null +++ b/src/gpt_review/utils.py @@ -0,0 +1,24 @@ +"""Utility functions""" +import logging +import time +from typing import Optional + +import gpt_review.constants as C + + +def _retry_with_exponential_backoff(current_retry: int, retry_after: Optional[str]) -> None: + """ + Use exponential backoff to retry a request after specific time while staying under the retry count + + Args: + current_retry (int): The current retry count. + retry_after (Optional[str]): The time to wait before retrying. + """ + logging.warning("Call to GPT failed due to rate limit, retry attempt %s of %s", current_retry, C.MAX_RETRIES) + + multiplication_factor = 2 * (1 + current_retry / C.MAX_RETRIES) + wait_time = int(retry_after) * multiplication_factor if retry_after else current_retry * multiplication_factor + + logging.warning("Waiting for %s seconds before retrying.", wait_time) + + time.sleep(wait_time)