Skip to content

Commit

Permalink
Merge pull request #3963 from BerriAI/litellm_set_allowed_fail_policy
Browse files Browse the repository at this point in the history
[FEAT]- set custom AllowedFailsPolicy on litellm.Router
  • Loading branch information
ishaan-jaff authored Jun 2, 2024
2 parents fb49d03 + 9f0ae21 commit 054456c
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 11 deletions.
32 changes: 25 additions & 7 deletions docs/my-website/docs/routing.md
Original file line number Diff line number Diff line change
Expand Up @@ -713,26 +713,43 @@ response = router.completion(model="gpt-3.5-turbo", messages=messages)
print(f"response: {response}")
```

#### Retries based on Error Type
### [Advanced]: Custom Retries, Cooldowns based on Error Type

Use `RetryPolicy` if you want to set a `num_retries` based on the Exception receieved
- Use `RetryPolicy` if you want to set a `num_retries` based on the Exception received
- Use `AllowedFailsPolicy` to set a custom number of `allowed_fails`/minute before cooling down a deployment

Example:
- 3 retries for `ContentPolicyViolationError`
- 0 retries for `AuthenticationError`

```python
retry_policy = RetryPolicy(
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
)

allowed_fails_policy = AllowedFailsPolicy(
ContentPolicyViolationErrorAllowedFails=1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
RateLimitErrorAllowedFails=100, # Allow 100 RateLimitErrors before cooling down a deployment
)
```

Example Usage

```python
from litellm.router import RetryPolicy
from litellm.router import RetryPolicy, AllowedFailsPolicy

retry_policy = RetryPolicy(
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
ContentPolicyViolationErrorRetries=3, # run 3 retries for ContentPolicyViolationErrors
AuthenticationErrorRetries=0, # run 0 retries for AuthenticationErrorRetries
BadRequestErrorRetries=1,
TimeoutErrorRetries=2,
RateLimitErrorRetries=3,
)

allowed_fails_policy = AllowedFailsPolicy(
ContentPolicyViolationErrorAllowedFails=1000, # Allow 1000 ContentPolicyViolationError before cooling down a deployment
RateLimitErrorAllowedFails=100, # Allow 100 RateLimitErrors before cooling down a deployment
)

router = litellm.Router(
model_list=[
{
Expand All @@ -755,6 +772,7 @@ router = litellm.Router(
},
],
retry_policy=retry_policy,
allowed_fails_policy=allowed_fails_policy,
)

response = await router.acompletion(
Expand Down
59 changes: 57 additions & 2 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
updateDeployment,
updateLiteLLMParams,
RetryPolicy,
AllowedFailsPolicy,
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
Expand Down Expand Up @@ -116,6 +117,9 @@ def __init__(
allowed_fails: Optional[
int
] = None, # Number of times a deployment can fail before being added to cooldown
allowed_fails_policy: Optional[
AllowedFailsPolicy
] = None, # set custom allowed fails policy
cooldown_time: Optional[
float
] = None, # (seconds) time to cooldown a deployment after failure
Expand Down Expand Up @@ -361,6 +365,7 @@ def __init__(
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
model_group_retry_policy
)
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = allowed_fails_policy
self.alerting_config: Optional[AlertingConfig] = alerting_config
if self.alerting_config is not None:
self._initialize_alerting()
Expand Down Expand Up @@ -2445,6 +2450,7 @@ def deployment_callback_on_failure(
deployment_id = _model_info.get("id", None)
self._set_cooldown_deployments(
exception_status=exception_status,
original_exception=exception,
deployment=deployment_id,
time_to_cooldown=_time_to_cooldown,
) # setting deployment_id in cooldown deployments
Expand Down Expand Up @@ -2550,6 +2556,7 @@ def _is_cooldown_required(self, exception_status: Union[str, int]):

def _set_cooldown_deployments(
self,
original_exception: Any,
exception_status: Union[str, int],
deployment: Optional[str] = None,
time_to_cooldown: Optional[float] = None,
Expand All @@ -2568,6 +2575,12 @@ def _set_cooldown_deployments(
if self._is_cooldown_required(exception_status=exception_status) == False:
return

_allowed_fails = self.get_allowed_fails_from_policy(
exception=original_exception,
)

allowed_fails = _allowed_fails or self.allowed_fails

dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
# get current fails for deployment
Expand All @@ -2577,7 +2590,7 @@ def _set_cooldown_deployments(
current_fails = self.failed_calls.get_cache(key=deployment) or 0
updated_fails = current_fails + 1
verbose_router_logger.debug(
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {self.allowed_fails}"
f"Attempting to add {deployment} to cooldown list. updated_fails: {updated_fails}; self.allowed_fails: {allowed_fails}"
)
cooldown_time = self.cooldown_time or 1
if time_to_cooldown is not None:
Expand All @@ -2594,7 +2607,8 @@ def _set_cooldown_deployments(
)
exception_status = 500
_should_retry = litellm._should_retry(status_code=exception_status)
if updated_fails > self.allowed_fails or _should_retry == False:

if updated_fails > allowed_fails or _should_retry == False:
# get the current cooldown list for that minute
cooldown_key = f"{current_minute}:cooldown_models" # group cooldown models by minute to reduce number of redis calls
cached_value = self.cache.get_cache(key=cooldown_key)
Expand Down Expand Up @@ -2737,6 +2751,7 @@ async def async_routing_strategy_pre_call_checks(self, deployment: dict):
except litellm.RateLimitError as e:
self._set_cooldown_deployments(
exception_status=e.status_code,
original_exception=e,
deployment=deployment["model_info"]["id"],
time_to_cooldown=self.cooldown_time,
)
Expand Down Expand Up @@ -4429,6 +4444,46 @@ def get_num_retries_from_retry_policy(
):
return retry_policy.ContentPolicyViolationErrorRetries

def get_allowed_fails_from_policy(self, exception: Exception):
    """
    Return the custom allowed-fails count for *exception*, if one is configured.

    Looks up ``self.allowed_fails_policy`` (an ``AllowedFailsPolicy``) for a
    per-exception-type override of how many failures/minute are tolerated
    before a deployment is put in cooldown. Fields consulted:

        BadRequestErrorAllowedFails: Optional[int]
        AuthenticationErrorAllowedFails: Optional[int]
        TimeoutErrorAllowedFails: Optional[int]
        RateLimitErrorAllowedFails: Optional[int]
        ContentPolicyViolationErrorAllowedFails: Optional[int]
        InternalServerErrorAllowedFails: Optional[int]

    Returns:
        The configured ``int`` for the matching exception type, or ``None``
        when no policy is set / the policy has no value for this exception
        (callers fall back to ``self.allowed_fails``).
    """
    allowed_fails_policy: Optional[AllowedFailsPolicy] = self.allowed_fails_policy

    if allowed_fails_policy is None:
        return None

    # Map each litellm exception type to its policy field; the first match wins.
    if (
        isinstance(exception, litellm.BadRequestError)
        and allowed_fails_policy.BadRequestErrorAllowedFails is not None
    ):
        return allowed_fails_policy.BadRequestErrorAllowedFails
    if (
        isinstance(exception, litellm.AuthenticationError)
        and allowed_fails_policy.AuthenticationErrorAllowedFails is not None
    ):
        return allowed_fails_policy.AuthenticationErrorAllowedFails
    if (
        isinstance(exception, litellm.Timeout)
        and allowed_fails_policy.TimeoutErrorAllowedFails is not None
    ):
        return allowed_fails_policy.TimeoutErrorAllowedFails
    if (
        isinstance(exception, litellm.RateLimitError)
        and allowed_fails_policy.RateLimitErrorAllowedFails is not None
    ):
        return allowed_fails_policy.RateLimitErrorAllowedFails
    if (
        isinstance(exception, litellm.ContentPolicyViolationError)
        and allowed_fails_policy.ContentPolicyViolationErrorAllowedFails is not None
    ):
        return allowed_fails_policy.ContentPolicyViolationErrorAllowedFails
    # Fix: InternalServerErrorAllowedFails is declared on AllowedFailsPolicy
    # but was never consulted here, making the field silently inert.
    if (
        isinstance(exception, litellm.InternalServerError)
        and allowed_fails_policy.InternalServerErrorAllowedFails is not None
    ):
        return allowed_fails_policy.InternalServerErrorAllowedFails

    # No override configured for this exception type.
    return None

def _initialize_alerting(self):
from litellm.integrations.slack_alerting import SlackAlerting

Expand Down
8 changes: 7 additions & 1 deletion litellm/tests/test_router_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,17 @@ async def test_router_retries_errors(sync_mode, error_type):
["AuthenticationErrorRetries", "ContentPolicyViolationErrorRetries"], #
)
async def test_router_retry_policy(error_type):
from litellm.router import RetryPolicy
from litellm.router import RetryPolicy, AllowedFailsPolicy

retry_policy = RetryPolicy(
ContentPolicyViolationErrorRetries=3, AuthenticationErrorRetries=0
)

allowed_fails_policy = AllowedFailsPolicy(
ContentPolicyViolationErrorAllowedFails=1000,
RateLimitErrorAllowedFails=100,
)

router = Router(
model_list=[
{
Expand All @@ -156,6 +161,7 @@ async def test_router_retry_policy(error_type):
},
],
retry_policy=retry_policy,
allowed_fails_policy=allowed_fails_policy,
)

customHandler = MyCustomHandler()
Expand Down
21 changes: 20 additions & 1 deletion litellm/types/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ class ModelInfo(BaseModel):
id: Optional[
str
] # Allow id to be optional on input, but it will always be present as a str in the model instance
db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config.
db_model: bool = (
False # used for proxy - to separate models which are stored in the db vs. config.
)
updated_at: Optional[datetime.datetime] = None
updated_by: Optional[str] = None

Expand Down Expand Up @@ -381,6 +383,23 @@ class RouterErrors(enum.Enum):
no_deployments_available = "No deployments available for selected model"


class AllowedFailsPolicy(BaseModel):
    """
    Use this to set a custom number of allowed fails/minute before cooling down a deployment.

    Each field overrides the router's default `allowed_fails` for one exception type;
    a field left as None means "no override" for that exception type.
    If `AuthenticationErrorAllowedFails = 1000`, then 1000 AuthenticationError
    will be allowed before cooling down a deployment.

    Mapping of Exception type to allowed_fails for each exception:
    https://docs.litellm.ai/docs/exception_mapping
    """

    # allowed BadRequestError failures/minute before cooldown
    BadRequestErrorAllowedFails: Optional[int] = None
    # allowed AuthenticationError failures/minute before cooldown
    AuthenticationErrorAllowedFails: Optional[int] = None
    # allowed Timeout failures/minute before cooldown
    TimeoutErrorAllowedFails: Optional[int] = None
    # allowed RateLimitError failures/minute before cooldown
    RateLimitErrorAllowedFails: Optional[int] = None
    # allowed ContentPolicyViolationError failures/minute before cooldown
    ContentPolicyViolationErrorAllowedFails: Optional[int] = None
    # allowed InternalServerError failures/minute before cooldown
    InternalServerErrorAllowedFails: Optional[int] = None


class RetryPolicy(BaseModel):
"""
Use this to set a custom number of retries per exception type
Expand Down

0 comments on commit 054456c

Please sign in to comment.