Commit

added max tokens keyword arg (#10)
* add max tokens support

Infrared1029 authored May 24, 2024
1 parent e8c1c6b commit 8346ef1
Showing 1 changed file with 20 additions and 4 deletions.
unify/clients.py: 24 changes (20 additions & 4 deletions)
@@ -126,6 +126,7 @@ def generate( # noqa: WPS234, WPS211
         user_prompt: Optional[str] = None,
         system_prompt: Optional[str] = None,
         messages: Optional[List[Dict[str, str]]] = None,
+        max_tokens: Optional[int] = None,
         stream: bool = False,
     ) -> Union[Generator[str, None, None], str]: # noqa: DAR101, DAR201, DAR401
         """Generate content using the Unify API.
@@ -140,6 +141,9 @@ def generate( # noqa: WPS234, WPS211
             messages (List[Dict[str, str]]): A list of dictionaries containing the
                 conversation history. If provided, user_prompt must be None.
 
+            max_tokens (Optional[int]): The max number of output tokens, defaults
+                to the provider's default max_tokens when the value is None.
+
             stream (bool): If True, generates content as a stream.
                 If False, generates content as a single response.
                 Defaults to False.
@@ -164,8 +168,8 @@ def generate( # noqa: WPS234, WPS211
             raise UnifyError("You must provide either the user_prompt or messages!")
 
         if stream:
-            return self._generate_stream(contents, self._endpoint)
-        return self._generate_non_stream(contents, self._endpoint)
+            return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens)
+        return self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens)
 
     def get_credit_balance(self) -> float:
         # noqa: DAR201, DAR401
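For orientation, here is a minimal usage sketch of the new argument on the synchronous client. The import path, constructor arguments, and endpoint string below are assumptions for illustration; only the generate() parameters (user_prompt, max_tokens, stream) come from the diff itself.

from unify import Unify  # import path assumed from the package layout

client = Unify(
    api_key="YOUR_UNIFY_KEY",              # hypothetical credential
    endpoint="llama-2-70b-chat@anyscale",  # hypothetical "model@provider" endpoint
)

# Cap the response length; leaving max_tokens=None keeps the provider's default limit.
reply = client.generate(user_prompt="Summarize this repo in one line.", max_tokens=50)
print(reply)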
@@ -197,11 +201,13 @@ def _generate_stream(
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
+        max_tokens: Optional[int] = None
     ) -> Generator[str, None, None]:
         try:
             chat_completion = self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages, # type: ignore[arg-type]
+                max_tokens=max_tokens,
                 stream=True,
             )
             for chunk in chat_completion:
@@ -216,11 +222,13 @@ def _generate_non_stream(
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
+        max_tokens: Optional[int] = None
     ) -> str:
         try:
             chat_completion = self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages, # type: ignore[arg-type]
+                max_tokens=max_tokens,
                 stream=False,
             )
             self.set_provider(
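Because the streaming path returns a generator of text chunks, a caller iterates over the result of generate(). A sketch under the same assumptions as the previous example:

# Continuing with the `client` from the sketch above.
# stream=True yields text chunks as they arrive; max_tokens still bounds the total output.
stream = client.generate(
    user_prompt="Write a haiku about rate limits.",
    max_tokens=64,
    stream=True,
)
for chunk in stream:
    print(chunk, end="", flush=True)
print()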
@@ -379,6 +387,7 @@ async def generate( # noqa: WPS234, WPS211
         user_prompt: Optional[str] = None,
         system_prompt: Optional[str] = None,
         messages: Optional[List[Dict[str, str]]] = None,
+        max_tokens: Optional[int] = None,
         stream: bool = False,
     ) -> Union[AsyncGenerator[str, None], str]: # noqa: DAR101, DAR201, DAR401
         """Generate content asynchronously using the Unify API.
@@ -393,6 +402,9 @@ async def generate( # noqa: WPS234, WPS211
             messages (List[Dict[str, str]]): A list of dictionaries containing the
                 conversation history. If provided, user_prompt must be None.
 
+            max_tokens (Optional[int]): The max number of output tokens, defaults
+                to the provider's default max_tokens when the value is None.
+
             stream (bool): If True, generates content as a stream.
                 If False, generates content as a single response.
                 Defaults to False.
@@ -417,18 +429,20 @@ async def generate( # noqa: WPS234, WPS211
             raise UnifyError("You must provide either the user_prompt or messages!")
 
         if stream:
-            return self._generate_stream(contents, self._endpoint)
-        return await self._generate_non_stream(contents, self._endpoint)
+            return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens)
+        return await self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens)
 
     async def _generate_stream(
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
+        max_tokens: Optional[int] = None,
     ) -> AsyncGenerator[str, None]:
         try:
             async_stream = await self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages, # type: ignore[arg-type]
+                max_tokens=max_tokens,
                 stream=True,
             )
             async for chunk in async_stream: # type: ignore[union-attr]
@@ -441,11 +455,13 @@ async def _generate_non_stream(
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
+        max_tokens: Optional[int] = None,
     ) -> str:
         try:
             async_response = await self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages, # type: ignore[arg-type]
+                max_tokens=max_tokens,
                 stream=False,
             )
             self.set_provider(async_response.model.split("@")[-1]) # type: ignore
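The asynchronous client mirrors the same signature. A hedged sketch, assuming an AsyncUnify class is exported alongside Unify and constructed the same way; note that awaiting generate() with stream=True returns an async generator, since _generate_stream is returned without being awaited:

import asyncio

from unify import AsyncUnify  # class name and import path assumed for illustration


async def main() -> None:
    client = AsyncUnify(
        api_key="YOUR_UNIFY_KEY",              # hypothetical credential
        endpoint="llama-2-70b-chat@anyscale",  # hypothetical "model@provider" endpoint
    )

    # Non-streaming call: awaited directly; max_tokens=None falls back to the provider default.
    text = await client.generate(user_prompt="One-line summary, please.", max_tokens=40)
    print(text)

    # Streaming call: awaiting generate() yields an async generator of text chunks.
    chunks = await client.generate(user_prompt="Now stream it.", max_tokens=40, stream=True)
    async for chunk in chunks:
        print(chunk, end="", flush=True)


asyncio.run(main())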
