Skip to content

Commit

Permalink
feat(internal): handle streaming error (#119)
Browse files Browse the repository at this point in the history
Co-authored-by: Jordan Wu <[email protected]>
  • Loading branch information
jordan-wu-97 and jordan-wu-97 authored Sep 3, 2024
1 parent 89b9c14 commit 3722579
Showing 1 changed file with 76 additions and 3 deletions.
79 changes: 76 additions & 3 deletions src/groq/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

import httpx

from ._utils import extract_type_var_from_base
from ._utils import is_mapping, extract_type_var_from_base
from ._exceptions import APIError

if TYPE_CHECKING:
from ._client import Groq, AsyncGroq
Expand Down Expand Up @@ -57,7 +58,43 @@ def __stream__(self) -> Iterator[_T]:
for sse in iterator:
if sse.data.startswith("[DONE]"):
break
yield process_data(data=sse.json(), cast_to=cast_to, response=response)

if sse.event is None:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data=data, cast_to=cast_to, response=response)

else:
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)

# Ensure the entire stream is consumed
for _sse in iterator:
Expand Down Expand Up @@ -123,7 +160,43 @@ async def __stream__(self) -> AsyncIterator[_T]:
async for sse in iterator:
if sse.data.startswith("[DONE]"):
break
yield process_data(data=sse.json(), cast_to=cast_to, response=response)

if sse.event is None:
data = sse.json()
if is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data=data, cast_to=cast_to, response=response)

else:
data = sse.json()

if sse.event == "error" and is_mapping(data) and data.get("error"):
message = None
error = data.get("error")
if is_mapping(error):
message = error.get("message")
if not message or not isinstance(message, str):
message = "An error occurred during streaming"

raise APIError(
message=message,
request=self.response.request,
body=data["error"],
)

yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)

# Ensure the entire stream is consumed
async for _sse in iterator:
Expand Down

0 comments on commit 3722579

Please sign in to comment.