Merge pull request #4701 from BerriAI/litellm_rpm_support_passthrough
Support key-rpm limits on pass-through endpoints
krrishdholakia authored Jul 13, 2024
2 parents 1206b0b + a6deb9c commit bc58e44
Showing 11 changed files with 177 additions and 27 deletions.
5 changes: 4 additions & 1 deletion docs/my-website/docs/proxy/pass_through.md
@@ -156,6 +156,8 @@ POST /api/public/ingestion HTTP/1.1" 207 Multi-Status
 Use this if you want the pass through endpoint to honour LiteLLM keys/authentication
+This also enforces the key's rpm limits on pass-through endpoints.
 Usage - set `auth: true` on the config
 ```yaml
 general_settings:
@@ -361,4 +363,5 @@ curl --location 'http://0.0.0.0:4000/v1/messages' \
         {"role": "user", "content": "Hello, world"}
     ]
 }'
 ```
 ```
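For reference, a minimal standalone sketch of the documented option. The `/v1/rerank` entry mirrors the `litellm/proxy/_new_secret_config.yaml` change later in this PR; the inline comments are added here for illustration:

```yaml
general_settings:
  master_key: sk-1234
  pass_through_endpoints:
    - path: "/v1/rerank"                          # route exposed on the LiteLLM Proxy Server
      target: "https://api.cohere.com/v1/rerank"  # URL requests are forwarded to
      auth: true                                  # 👈 enforce LiteLLM keys and their rpm limits
      headers:
        Authorization: "bearer os.environ/COHERE_API_KEY"
        content-type: application/json
        accept: application/json
```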
1 change: 1 addition & 0 deletions litellm/integrations/custom_logger.py
@@ -99,6 +99,7 @@ async def async_pre_call_hook(
             "image_generation",
             "moderation",
             "audio_transcription",
+            "pass_through_endpoint",
         ],
     ) -> Optional[
         Union[Exception, str, dict]
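To illustrate what the widened `Literal` enables: every `CustomLogger` pre-call hook now also fires for pass-through traffic. A minimal sketch, assuming the hook signature shown in this PR's diffs (the class name and metadata tagging are illustrative, and the first three `Literal` values come from the surrounding codebase rather than this hunk):

```python
from typing import Literal, Optional, Union

from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class PassThroughTagger(CustomLogger):  # illustrative subclass
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
        ],
    ) -> Optional[Union[Exception, str, dict]]:
        # Pass-through requests now reach the same hook as regular LLM calls,
        # which is what lets the parallel-request limiter count them per key.
        if call_type == "pass_through_endpoint":
            data.setdefault("metadata", {})["hit_pass_through"] = True  # illustrative
        return data
```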
13 changes: 7 additions & 6 deletions litellm/proxy/_new_secret_config.yaml
@@ -12,11 +12,12 @@ model_list:
 
 
 general_settings:
-  alerting: ["slack"]
-  alerting_threshold: 10
   master_key: sk-1234
   pass_through_endpoints:
-    - path: "/v1/test-messages" # route you want to add to LiteLLM Proxy Server
-      target: litellm.adapters.anthropic_adapter.anthropic_adapter # URL this route should forward requests to
-      headers: # headers to forward to this URL
-        litellm_user_api_key: "x-my-test-key"
+    - path: "/v1/rerank"
+      target: "https://api.cohere.com/v1/rerank"
+      auth: true # 👈 Key change to use LiteLLM Auth / Keys
+      headers:
+        Authorization: "bearer os.environ/COHERE_API_KEY"
+        content-type: application/json
+        accept: application/json
1 change: 0 additions & 1 deletion litellm/proxy/auth/user_api_key_auth.py
@@ -96,7 +96,6 @@ async def user_api_key_auth(
         anthropic_api_key_header
     ),
 ) -> UserAPIKeyAuth:
-
     from litellm.proxy.proxy_server import (
         allowed_routes_check,
         common_checks,
8 changes: 5 additions & 3 deletions litellm/proxy/custom_callbacks1.py
@@ -1,7 +1,8 @@
-from litellm.integrations.custom_logger import CustomLogger
+from typing import Literal, Optional
+
 import litellm
-from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache
-from typing import Optional, Literal
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy.proxy_server import DualCache, UserAPIKeyAuth
 
 
 # This file includes the custom callbacks for LiteLLM Proxy
@@ -27,6 +28,7 @@ async def async_pre_call_hook(
             "image_generation",
             "moderation",
             "audio_transcription",
+            "pass_through_endpoint",
         ],
     ):
         return data
1 change: 1 addition & 0 deletions litellm/proxy/hooks/dynamic_rate_limiter.py
@@ -198,6 +198,7 @@ async def async_pre_call_hook(
             "image_generation",
             "moderation",
             "audio_transcription",
+            "pass_through_endpoint",
         ],
     ) -> Optional[
         Union[Exception, str, dict]
16 changes: 10 additions & 6 deletions litellm/proxy/hooks/parallel_request_limiter.py
@@ -1,12 +1,16 @@
+import sys
+import traceback
+from datetime import datetime
 from typing import Optional
-import litellm, traceback, sys
-from litellm.caching import DualCache
-from litellm.proxy._types import UserAPIKeyAuth
-from litellm.integrations.custom_logger import CustomLogger
+
 from fastapi import HTTPException
-from litellm._logging import verbose_proxy_logger
+
+import litellm
 from litellm import ModelResponse
-from datetime import datetime
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
 
 
 class _PROXY_MaxParallelRequestsHandler(CustomLogger):
96 changes: 86 additions & 10 deletions litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -3,6 +3,7 @@
 import json
 import traceback
 from base64 import b64encode
+from typing import Optional
 
 import httpx
 from fastapi import (
@@ -22,8 +23,6 @@
 from litellm.proxy._types import ProxyException, UserAPIKeyAuth
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 
-async_client = httpx.AsyncClient()
-
 
 async def set_env_variables_in_header(custom_headers: dict):
     """
@@ -240,21 +239,44 @@ async def chat_completion_pass_through_endpoint(
     )
 
 
-async def pass_through_request(request: Request, target: str, custom_headers: dict):
+async def pass_through_request(
+    request: Request,
+    target: str,
+    custom_headers: dict,
+    user_api_key_dict: UserAPIKeyAuth,
+):
     try:
+        import time
+        import uuid
+
+        from litellm.litellm_core_utils.litellm_logging import Logging
+        from litellm.proxy.proxy_server import proxy_logging_obj
+
         url = httpx.URL(target)
         headers = custom_headers
 
         request_body = await request.body()
-        _parsed_body = ast.literal_eval(request_body.decode("utf-8"))
+        body_str = request_body.decode()
+        try:
+            _parsed_body = ast.literal_eval(body_str)
+        except:
+            _parsed_body = json.loads(body_str)
+
+        verbose_proxy_logger.debug(
+            "Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
+                url, headers, _parsed_body
+            )
+        )
+
+        ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+        _parsed_body = await proxy_logging_obj.pre_call_hook(
+            user_api_key_dict=user_api_key_dict,
+            data=_parsed_body,
+            call_type="pass_through_endpoint",
+        )
+
+        async_client = httpx.AsyncClient()
+
         response = await async_client.request(
             method=request.method,
             url=url,
@@ -267,15 +289,56 @@ async def pass_through_request(request: Request, target: str, custom_headers: dict):
             raise HTTPException(status_code=response.status_code, detail=response.text)
 
         content = await response.aread()
+
+        ## LOG SUCCESS
+        start_time = time.time()
+        end_time = time.time()
+        # create logging object
+        logging_obj = Logging(
+            model="unknown",
+            messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
+            stream=False,
+            call_type="pass_through_endpoint",
+            start_time=start_time,
+            litellm_call_id=str(uuid.uuid4()),
+            function_id="1245",
+        )
+        # done for supporting 'parallel_request_limiter.py' with pass-through endpoints
+        kwargs = {
+            "litellm_params": {
+                "metadata": {
+                    "user_api_key": user_api_key_dict.api_key,
+                    "user_api_key_user_id": user_api_key_dict.user_id,
+                    "user_api_key_team_id": user_api_key_dict.team_id,
+                    "user_api_key_end_user_id": user_api_key_dict.user_id,
+                }
+            },
+            "call_type": "pass_through_endpoint",
+        }
+        logging_obj.update_environment_variables(
+            model="unknown",
+            user="unknown",
+            optional_params={},
+            litellm_params=kwargs["litellm_params"],
+            call_type="pass_through_endpoint",
+        )
+
+        await logging_obj.async_success_handler(
+            result="",
+            start_time=start_time,
+            end_time=end_time,
+            cache_hit=False,
+        )
+
         return Response(
             content=content,
             status_code=response.status_code,
             headers=dict(response.headers),
         )
     except Exception as e:
         verbose_proxy_logger.error(
-            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format(
-                str(e)
+            "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}\n{}".format(
+                str(e), traceback.format_exc()
             )
         )
         verbose_proxy_logger.debug(traceback.format_exc())
Expand All @@ -296,7 +359,9 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
)


def create_pass_through_route(endpoint, target: str, custom_headers=None):
def create_pass_through_route(
endpoint, target: str, custom_headers: Optional[dict] = None
):
# check if target is an adapter.py or a url
import uuid

@@ -325,8 +390,17 @@ async def endpoint_func(
     except Exception:
         verbose_proxy_logger.warning("Defaulting to target being a url.")
 
-    async def endpoint_func(request: Request):  # type: ignore
-        return await pass_through_request(request, target, custom_headers)
+    async def endpoint_func(
+        request: Request,
+        fastapi_response: Response,
+        user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+    ):
+        return await pass_through_request(
+            request=request,
+            target=target,
+            custom_headers=custom_headers or {},
+            user_api_key_dict=user_api_key_dict,
+        )
 
     return endpoint_func
 
@@ -349,7 +423,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
         if _auth is not None and str(_auth).lower() == "true":
             if premium_user is not True:
                 raise ValueError(
-                    f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}"
+                    "Error Setting Authentication on Pass Through Endpoint: {}".format(
+                        CommonProxyErrors.not_premium_user.value
+                    )
                 )
             _dependencies = [Depends(user_api_key_auth)]
             LiteLLMRoutes.openai_routes.value.append(_path)
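Taken together, a hypothetical client call against the authenticated route. The endpoint path, request body, and key are lifted from the new test below; the proxy URL and the key's configured rpm limit are assumptions:

```python
import httpx

# Assumes a proxy on 0.0.0.0:4000 with the /v1/rerank pass-through endpoint
# configured with auth: true, and a LiteLLM virtual key that has an rpm_limit.
response = httpx.post(
    "http://0.0.0.0:4000/v1/rerank",
    headers={"Authorization": "Bearer sk-my-test-key"},
    json={
        "model": "rerank-english-v3.0",
        "query": "What is the capital of the United States?",
        "top_n": 3,
        "documents": [
            "Carson City is the capital city of the American state of Nevada."
        ],
    },
)

# Per the test matrix: 200 while the key is under its rpm limit,
# 429 once the limiter trips, and 401 without a valid key.
print(response.status_code)
```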
1 change: 1 addition & 0 deletions litellm/proxy/utils.py
@@ -299,6 +299,7 @@ async def pre_call_hook(
             "image_generation",
             "moderation",
             "audio_transcription",
+            "pass_through_endpoint",
         ],
     ) -> dict:
         """
61 changes: 61 additions & 0 deletions litellm/tests/test_pass_through_endpoints.py
@@ -85,6 +85,67 @@ async def test_pass_through_endpoint_rerank(client):
     assert response.status_code == 200
 
 
+@pytest.mark.parametrize(
+    "auth, rpm_limit, expected_error_code",
+    [(True, 0, 429), (True, 1, 200), (False, 0, 401)],
+)
+@pytest.mark.asyncio
+async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_limit):
+    client = TestClient(app)
+    import litellm
+    from litellm.proxy._types import UserAPIKeyAuth
+    from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
+
+    mock_api_key = "sk-my-test-key"
+    cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
+
+    _cohere_api_key = os.environ.get("COHERE_API_KEY")
+
+    user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
+
+    proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
+    proxy_logging_obj._init_litellm_callbacks()
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
+    setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
+
+    # Define a pass-through endpoint
+    pass_through_endpoints = [
+        {
+            "path": "/v1/rerank",
+            "target": "https://api.cohere.com/v1/rerank",
+            "auth": auth,
+            "headers": {"Authorization": f"bearer {_cohere_api_key}"},
+        }
+    ]
+
+    # Initialize the pass-through endpoint
+    await initialize_pass_through_endpoints(pass_through_endpoints)
+
+    _json_data = {
+        "model": "rerank-english-v3.0",
+        "query": "What is the capital of the United States?",
+        "top_n": 3,
+        "documents": [
+            "Carson City is the capital city of the American state of Nevada."
+        ],
+    }
+
+    # Make a request to the pass-through endpoint
+    response = client.post(
+        "/v1/rerank",
+        json=_json_data,
+        headers={"Authorization": "Bearer {}".format(mock_api_key)},
+    )
+
+    print("JSON response: ", _json_data)
+
+    # Assert the response
+    assert response.status_code == expected_error_code
+
+
 @pytest.mark.asyncio
 async def test_pass_through_endpoint_anthropic(client):
     import litellm
1 change: 1 addition & 0 deletions litellm/tests/test_proxy_reject_logging.py
@@ -66,6 +66,7 @@ async def async_pre_call_hook(
             "image_generation",
             "moderation",
             "audio_transcription",
+            "pass_through_endpoint",
         ],
     ):
         raise HTTPException(
