Merge pull request #4693 from BerriAI/litellm_bad_req_error_mapping
fix - Raise `BadRequestError` when passing the wrong role
ishaan-jaff authored Jul 13, 2024
2 parents c1a9881 + bba748e commit 1206b0b
Showing 4 changed files with 35 additions and 8 deletions.
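For context, a minimal sketch of the caller-facing behavior the commit title describes. The role-validation path itself is not part of the diffs below, and the model/message values here are illustrative:

import litellm

try:
    # Per the commit title, an unsupported message role should surface as a
    # 400-class litellm.BadRequestError rather than a generic error.
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "not-a-role", "content": "hi"}],
    )
except litellm.BadRequestError as e:
    print(e.status_code, e.message)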
1 change: 1 addition & 0 deletions litellm/__init__.py
@@ -876,6 +876,7 @@ def identify(event_details):
     InternalServerError,
     JSONSchemaValidationError,
     LITELLM_EXCEPTION_TYPES,
+    MockException,
 )
 from .budget_manager import BudgetManager
 from .proxy.proxy_cli import run_server
25 changes: 25 additions & 0 deletions litellm/exceptions.py
@@ -723,3 +723,28 @@ def __init__(self, message, model, llm_provider):
         super().__init__(
             self.message, f"{self.model}"
         )  # Call the base class constructor with the parameters it needs
+
+
+class MockException(openai.APIError):
+    # used for testing
+    def __init__(
+        self,
+        status_code,
+        message,
+        llm_provider,
+        model,
+        request: Optional[httpx.Request] = None,
+        litellm_debug_info: Optional[str] = None,
+        max_retries: Optional[int] = None,
+        num_retries: Optional[int] = None,
+    ):
+        self.status_code = status_code
+        self.message = "litellm.MockException: {}".format(message)
+        self.llm_provider = llm_provider
+        self.model = model
+        self.litellm_debug_info = litellm_debug_info
+        self.max_retries = max_retries
+        self.num_retries = num_retries
+        if request is None:
+            request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        super().__init__(self.message, request=request, body=None)  # type: ignore
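For reference, a minimal usage sketch of the new class as exported via litellm/__init__.py above (all argument values illustrative):

import litellm

err = litellm.MockException(
    status_code=400,
    message="invalid role",
    llm_provider="openai",
    model="gpt-3.5-turbo",
)
print(err.status_code)  # 400
print(err.message)  # "litellm.MockException: invalid role"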
2 changes: 1 addition & 1 deletion litellm/main.py
@@ -479,7 +479,7 @@ def mock_completion(
         if isinstance(mock_response, Exception):
             if isinstance(mock_response, openai.APIError):
                 raise mock_response
-            raise litellm.APIError(
+            raise litellm.MockException(
                 status_code=getattr(mock_response, "status_code", 500),  # type: ignore
                 message=getattr(mock_response, "text", str(mock_response)),
                 llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"),  # type: ignore
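A short sketch of how this change surfaces to callers, assuming mock_response is forwarded from completion() into mock_completion() (values illustrative):

import litellm

try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response=Exception("simulated provider failure"),
    )
except litellm.MockException as e:
    # status_code falls back to 500 when the wrapped exception has none
    print(e.status_code, e.llm_provider)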
15 changes: 8 additions & 7 deletions litellm/tests/test_rules.py
@@ -1,14 +1,18 @@
 #### What this tests ####
 # This tests setting rules before / after making llm api calls
-import sys, os, time
-import traceback, asyncio
+import asyncio
+import os
+import sys
+import time
+import traceback
+
 import pytest
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import litellm
-from litellm import completion, acompletion
+from litellm import acompletion, completion
 
 
 def my_pre_call_rule(input: str):
@@ -126,10 +130,7 @@ def test_post_call_rule_streaming():
         print("Got exception", e)
         print(type(e))
         print(vars(e))
-        assert (
-            "OpenAIException - This violates LiteLLM Proxy Rules. Response too short"
-            in e.message
-        )
+        assert "This violates LiteLLM Proxy Rules. Response too short" in e.message
 
 
 @pytest.mark.asyncio
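For context, a rough sketch of the rule pattern this test exercises. The dict return shape and the post_call_rules hook are assumptions based on litellm's rules feature, not shown in this diff:

import litellm

def my_post_call_rule(input: str):
    # Assumed rule shape: reject short responses with the message the
    # updated assertion checks for.
    if len(input) < 200:
        return {
            "decision": False,
            "message": "This violates LiteLLM Proxy Rules. Response too short",
        }
    return {"decision": True}

litellm.post_call_rules = [my_post_call_rule]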