diff --git a/src/exchange/providers/anthropic.py b/src/exchange/providers/anthropic.py index 02e6361..e12d210 100644 --- a/src/exchange/providers/anthropic.py +++ b/src/exchange/providers/anthropic.py @@ -1,5 +1,4 @@ import os -import time from typing import Any, Dict, List, Tuple, Type import httpx @@ -7,6 +6,7 @@ from exchange import Message, Tool from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage +from exchange.providers.retry_with_back_off_decorator import retry_httpx_request from exchange.providers.utils import raise_for_status ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" @@ -138,26 +138,14 @@ def complete( ) payload = {k: v for k, v in payload.items() if v} - max_retries = 5 - initial_wait = 10 # Start with 10 seconds - backoff_factor = 1 - for retry in range(max_retries): - response = self.client.post(ANTHROPIC_HOST, json=payload) - if response.status_code not in (429, 529, 500): - break - else: - sleep_time = initial_wait + (backoff_factor * (2**retry)) - time.sleep(sleep_time) - - if response.status_code in (429, 529, 500): - raise httpx.HTTPStatusError( - f"Failed after {max_retries} retries due to rate limiting", - request=response.request, - response=response, - ) + response = self._send_request(payload) response_data = raise_for_status(response).json() message = self.anthropic_response_to_message(response_data) usage = self.get_usage(response_data) return message, usage + + @retry_httpx_request() + def _send_request(self, payload: Dict[str, Any]) -> httpx.Response: + return self.client.post(ANTHROPIC_HOST, json=payload) diff --git a/src/exchange/providers/azure.py b/src/exchange/providers/azure.py index aadec03..c3b9ca9 100644 --- a/src/exchange/providers/azure.py +++ b/src/exchange/providers/azure.py @@ -5,6 +5,7 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage +from exchange.providers.retry_with_back_off_decorator import retry_httpx_request from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -98,7 +99,7 @@ def complete( payload = {k: v for k, v in payload.items() if v} request_url = f"{self.client.base_url}/chat/completions?api-version={self.api_version}" - response = self.client.post(request_url, json=payload) + response = self._send_request(payload, request_url) # Check for context_length_exceeded error for single, long input message if "error" in response.json() and len(messages) == 1: @@ -109,3 +110,7 @@ def complete( message = openai_response_to_message(data) usage = self.get_usage(data) return message, usage + + @retry_httpx_request() + def _send_request(self, payload: Any, request_url: str) -> httpx.Response: # noqa: ANN401 + return self.client.post(request_url, json=payload) diff --git a/src/exchange/providers/bedrock.py b/src/exchange/providers/bedrock.py index 0dd59c0..739c2e1 100644 --- a/src/exchange/providers/bedrock.py +++ b/src/exchange/providers/bedrock.py @@ -12,6 +12,7 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.message import Message from exchange.providers import Provider, Usage +from exchange.providers.retry_with_back_off_decorator import retry_httpx_request from exchange.providers.utils import raise_for_status from exchange.tool import Tool @@ -204,7 +205,7 @@ def complete( path = f"model/{model}/converse" - response = self.client.post(path, json=payload) + response = self._send_request(payload, path) raise_for_status(response) response_message = response.json()["output"]["message"] @@ -217,6 +218,10 @@ def complete( return self.response_to_message(response_message), usage + @retry_httpx_request() + def _send_request(self, payload: Any, path:str) -> httpx.Response: # noqa: ANN401 + return self.client.post(path, json=payload) + @staticmethod def message_to_bedrock_spec(message: Message) -> dict: bedrock_content = [] diff --git a/src/exchange/providers/databricks.py b/src/exchange/providers/databricks.py index 09453f0..f946c52 100644 --- a/src/exchange/providers/databricks.py +++ b/src/exchange/providers/databricks.py @@ -5,6 +5,7 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage +from exchange.providers.retry_with_back_off_decorator import retry_httpx_request from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -79,11 +80,15 @@ def complete( **kwargs, ) payload = {k: v for k, v in payload.items() if v} - response = self.client.post( - f"serving-endpoints/{model}/invocations", - json=payload, - ) + response = self._send_request(model, payload) data = raise_for_status(response).json() message = openai_response_to_message(data) usage = self.get_usage(data) return message, usage + + @retry_httpx_request() + def _send_request(self, model: str, payload: Any) -> httpx.Response: # noqa: ANN401 + return self.client.post( + f"serving-endpoints/{model}/invocations", + json=payload, + ) diff --git a/src/exchange/providers/ollama.py b/src/exchange/providers/ollama.py index a7af7c6..cb954a5 100644 --- a/src/exchange/providers/ollama.py +++ b/src/exchange/providers/ollama.py @@ -5,6 +5,7 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage +from exchange.providers.retry_with_back_off_decorator import retry_httpx_request from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -80,7 +81,7 @@ def complete( **kwargs, ) payload = {k: v for k, v in payload.items() if v} - response = self.client.post("v1/chat/completions", json=payload) + response = self._send_request(payload) # Check for context_length_exceeded error for single, long input message if "error" in response.json() and len(messages) == 1: @@ -91,3 +92,7 @@ def complete( message = openai_response_to_message(data) usage = self.get_usage(data) return message, usage + + @retry_httpx_request() + def _send_request(self, payload: Any) -> httpx.Response: # noqa: ANN401 + return self.client.post("v1/chat/completions", json=payload) diff --git a/src/exchange/providers/openai.py b/src/exchange/providers/openai.py index 1f1ac23..1f3133b 100644 --- a/src/exchange/providers/openai.py +++ b/src/exchange/providers/openai.py @@ -5,6 +5,7 @@ from exchange.message import Message from exchange.providers.base import Provider, Usage +from exchange.providers.retry_with_back_off_decorator import retry_httpx_request from exchange.providers.utils import ( messages_to_openai_spec, openai_response_to_message, @@ -74,7 +75,7 @@ def complete( **kwargs, ) payload = {k: v for k, v in payload.items() if v} - response = self.client.post("v1/chat/completions", json=payload) + response = self._send_request(payload) # Check for context_length_exceeded error for single, long input message if "error" in response.json() and len(messages) == 1: @@ -85,3 +86,7 @@ def complete( message = openai_response_to_message(data) usage = self.get_usage(data) return message, usage + + @retry_httpx_request() + def _send_request(self, payload: Any) -> httpx.Response: # noqa: ANN401 + return self.client.post("v1/chat/completions", json=payload) diff --git a/src/exchange/providers/retry_with_back_off_decorator.py b/src/exchange/providers/retry_with_back_off_decorator.py new file mode 100644 index 0000000..af6de50 --- /dev/null +++ b/src/exchange/providers/retry_with_back_off_decorator.py @@ -0,0 +1,56 @@ +import time +from functools import wraps +from typing import Any, Callable, Dict, Iterable, List, Optional + +from httpx import HTTPStatusError, Response + + +def retry_with_backoff( + should_retry: Callable, + max_retries: Optional[int] = 5, + initial_wait: Optional[float] = 10, + backoff_factor: Optional[float] = 1, + handle_retry_exhausted: Optional[Callable] = None) -> Callable: + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args: List, **kwargs: Dict) -> Any: # noqa: ANN401 + result = None + for retry in range(max_retries): + result = func(*args, **kwargs) + if not should_retry(result): + return result + if (retry + 1) == max_retries: + break + sleep_time = initial_wait + (backoff_factor * (2 ** retry)) + time.sleep(sleep_time) + if handle_retry_exhausted: + handle_retry_exhausted(result, max_retries) + return result + return wrapper + return decorator + +def retry_httpx_request( + retry_on_status_code: Optional[Iterable[int]] = None, + max_retries: Optional[int] = 5, + initial_wait: Optional[float] = 10, + backoff_factor: Optional[float] = 1, +) -> Callable: + if retry_on_status_code is None: + retry_on_status_code = set(range(401, 999)) + def should_retry(response: Response) -> bool: + return response.status_code in retry_on_status_code + + def handle_retry_exhausted(response: Response, max_retries: int) -> None: + raise HTTPStatusError( + f"Failed after {max_retries}.", + request=response.request, + response=response, + ) + + return retry_with_backoff( + max_retries=max_retries, + initial_wait=initial_wait, + backoff_factor=backoff_factor, + should_retry=should_retry, + handle_retry_exhausted=handle_retry_exhausted + ) \ No newline at end of file diff --git a/tests/providers/test_retry_with_back_off_decorator.py b/tests/providers/test_retry_with_back_off_decorator.py new file mode 100644 index 0000000..1b56bbc --- /dev/null +++ b/tests/providers/test_retry_with_back_off_decorator.py @@ -0,0 +1,175 @@ +from unittest.mock import MagicMock + +import pytest +from exchange.providers.retry_with_back_off_decorator import retry_httpx_request, retry_with_backoff +from httpx import HTTPStatusError, Response + + +def create_mock_function(): + mock_function = MagicMock() + mock_function.side_effect = [3, 5, 7] + return mock_function + +def test_retry_with_backoff_retry_exhausted(): + mock_function = create_mock_function() + handle_retry_exhausted_function = MagicMock() + + def should_try(result): + return result < 7 + + @retry_with_backoff( + should_retry = should_try, + max_retries=2, + initial_wait=0, + backoff_factor=0.001, + handle_retry_exhausted=handle_retry_exhausted_function) + def test_func(): + return mock_function() + + assert test_func() == 5 + + assert mock_function.call_count == 2 + handle_retry_exhausted_function.assert_called_once() + handle_retry_exhausted_function.assert_called_with(5, 2) + +def test_retry_with_backoff_retry_successful(): + mock_function = create_mock_function() + handle_retry_exhausted_function = MagicMock() + + def should_try(result): + return result < 4 + + @retry_with_backoff( + should_retry = should_try, + max_retries=2, + initial_wait=0, + backoff_factor=0.001, + handle_retry_exhausted=handle_retry_exhausted_function) + def test_func(): + return mock_function() + + assert test_func() == 5 + + assert mock_function.call_count == 2 + handle_retry_exhausted_function.assert_not_called() + +def test_retry_with_backoff_without_retry(): + mock_function = create_mock_function() + handle_retry_exhausted_function = MagicMock() + + def should_try(result): + return result < 2 + + @retry_with_backoff( + should_retry = should_try, + max_retries=2, + initial_wait=0, + backoff_factor=0.001, + handle_retry_exhausted=handle_retry_exhausted_function) + def test_func(): + return mock_function() + + assert test_func() == 3 + + assert mock_function.call_count == 1 + handle_retry_exhausted_function.assert_not_called() + +def create_mock_httpx_request_call_function(responses=[500, 429, 200]): + mock_function = MagicMock() + mock_responses = [] + for response_code in responses: + response = MagicMock() + response.status_code = response_code + mock_responses.append(response) + + mock_function.side_effect = mock_responses + return mock_function + +def test_retry_httpx_request_backoff_retry_exhausted(): + mock_httpx_request_call_function = create_mock_httpx_request_call_function() + + @retry_httpx_request( + retry_on_status_code=[500, 429], + max_retries=2, + initial_wait=0, + backoff_factor=0.001) + def test_func() -> Response: + return mock_httpx_request_call_function() + + with pytest.raises(HTTPStatusError): + test_func() + + assert mock_httpx_request_call_function.call_count == 2 + +def test_retry_httpx_request_backoff_retry_successful(): + mock_httpx_request_call_function = create_mock_httpx_request_call_function() + + @retry_httpx_request( + retry_on_status_code=[500], + max_retries=2, + initial_wait=0, + backoff_factor=0.001) + def test_func() -> Response: + return mock_httpx_request_call_function() + + assert test_func().status_code == 429 + + assert mock_httpx_request_call_function.call_count == 2 + +def test_retry_httpx_request_backoff_without_retry(): + mock_httpx_request_call_function = create_mock_httpx_request_call_function() + + @retry_httpx_request( + retry_on_status_code=[503], + max_retries=2, + initial_wait=0, + backoff_factor=0.001) + def test_func() -> Response: + return mock_httpx_request_call_function() + + assert test_func().status_code == 500 + + assert mock_httpx_request_call_function.call_count == 1 + +def test_retry_httpx_request_backoff_range(): + mock_httpx_request_call_function = create_mock_httpx_request_call_function(responses=[200]) + + @retry_httpx_request(max_retries=2, initial_wait=0, backoff_factor=0.001) + def test_func() -> Response: + return mock_httpx_request_call_function() + + assert test_func().status_code == 200 + + assert mock_httpx_request_call_function.call_count == 1 + + +def test_retry_httpx_request_backoff_range_retry_never_succeed(): + mock_httpx_request_call_function = create_mock_httpx_request_call_function(responses=[401, 500, 500]) + + @retry_httpx_request(max_retries=3, initial_wait=0, backoff_factor=0.001) + def test_func() -> Response: + return mock_httpx_request_call_function() + + # Never gets a successful response + with pytest.raises(HTTPStatusError): + f = test_func() + # last error is 500 + assert f.status_code == 500 + + # Has been retried 3 times + assert mock_httpx_request_call_function.call_count == 3 + + +def test_retry_httpx_request_backoff_range_retry_succeed(): + mock_httpx_request_call_function = create_mock_httpx_request_call_function(responses=[401, 500, 200]) + + @retry_httpx_request(max_retries=3, initial_wait=0, backoff_factor=0.001) + def test_func() -> Response: + return mock_httpx_request_call_function() + + # Retries and raises no error + f = test_func() + assert f.status_code == 200 + + # Has been retried 3 times + assert mock_httpx_request_call_function.call_count == 3