
Commit

Merge pull request #4532 from oobabooga/dev
Merge dev branch
oobabooga authored Nov 9, 2023
2 parents 4da00b6 + effb3ae commit f7534b2
Showing 1 changed file with 12 additions and 6 deletions.
extensions/openai/script.py: 12 additions, 6 deletions
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import os
 import traceback
@@ -46,6 +47,9 @@
 }
 
 
+streaming_semaphore = asyncio.Semaphore(1)
+
+
 def verify_api_key(authorization: str = Header(None)) -> None:
     expected_api_key = shared.args.api_key
     if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
@@ -84,9 +88,10 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
 
     if request_data.stream:
         async def generator():
-            response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
-            for resp in response:
-                yield {"data": json.dumps(resp)}
+            async with streaming_semaphore:
+                response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
+                for resp in response:
+                    yield {"data": json.dumps(resp)}
 
         return EventSourceResponse(generator())  # SSE streaming
 
@@ -102,9 +107,10 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
 
     if request_data.stream:
         async def generator():
-            response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
-            for resp in response:
-                yield {"data": json.dumps(resp)}
+            async with streaming_semaphore:
+                response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
+                for resp in response:
+                    yield {"data": json.dumps(resp)}
 
         return EventSourceResponse(generator())  # SSE streaming
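
The change is small but worth spelling out: a module-level asyncio.Semaphore(1) now guards both streaming generators, so if two clients request SSE streams at the same time, the second one waits until the first stream has finished rather than interleaving with it. Below is a minimal, self-contained sketch of the same pattern, not taken from the repository; fake_stream, stream_response, and consume are hypothetical stand-ins for OAIcompletions.stream_completions and the FastAPI/EventSourceResponse plumbing.

# Minimal sketch (assumed names, Python 3.10+): a single asyncio.Semaphore(1)
# shared by all streaming requests makes concurrent streams run one at a time
# instead of interleaving.
import asyncio
import json

streaming_semaphore = asyncio.Semaphore(1)


def fake_stream(prompt):
    # Hypothetical stand-in for OAIcompletions.stream_completions():
    # a plain generator that yields partial responses chunk by chunk.
    for i, word in enumerate(prompt.split()):
        yield {"index": i, "text": word}


async def stream_response(prompt):
    # Mirrors the generator() closures in script.py: the semaphore is held for
    # the entire lifetime of the stream, so a second request blocks on acquire
    # until the first stream is exhausted.
    async with streaming_semaphore:
        for resp in fake_stream(prompt):
            yield {"data": json.dumps(resp)}
            await asyncio.sleep(0.01)  # simulate generation time between chunks


async def main():
    async def consume(name, prompt):
        async for event in stream_response(prompt):
            print(name, event)

    # With the semaphore, all of A's chunks print before any of B's;
    # without it, the two streams would interleave.
    await asyncio.gather(consume("A", "hello from A"), consume("B", "hello from B"))


asyncio.run(main())

In the real endpoints the generator is handed to EventSourceResponse, so the semaphore is released when the stream finishes, or when the generator is closed (for example after a client disconnect), since closing an async generator unwinds the async with block.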
