diff --git a/unify/clients.py b/unify/clients.py
index c85f986..a204248 100644
--- a/unify/clients.py
+++ b/unify/clients.py
@@ -126,6 +126,7 @@ def generate(  # noqa: WPS234, WPS211
         user_prompt: Optional[str] = None,
         system_prompt: Optional[str] = None,
         messages: Optional[List[Dict[str, str]]] = None,
+        max_tokens: Optional[int] = None,
         stream: bool = False,
     ) -> Union[Generator[str, None, None], str]:  # noqa: DAR101, DAR201, DAR401
         """Generate content using the Unify API.
@@ -140,6 +141,9 @@ def generate(  # noqa: WPS234, WPS211
             messages (List[Dict[str, str]]): A list of dictionaries containing the
             conversation history. If provided, user_prompt must be None.
 
+            max_tokens (Optional[int]): The max number of output tokens, defaults
+            to the provider's default max_tokens when the value is None.
+
             stream (bool): If True, generates content as a stream.
             If False, generates content as a single response. Defaults to False.
 
@@ -164,8 +168,8 @@ def generate(  # noqa: WPS234, WPS211
             raise UnifyError("You must provider either the user_prompt or messages!")
 
         if stream:
-            return self._generate_stream(contents, self._endpoint)
-        return self._generate_non_stream(contents, self._endpoint)
+            return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens)
+        return self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens)
 
 
     def get_credit_balance(self) -> float:  # noqa: DAR201, DAR401
@@ -197,11 +201,13 @@ def _generate_stream(
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
+        max_tokens: Optional[int] = None
     ) -> Generator[str, None, None]:
         try:
             chat_completion = self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages,  # type: ignore[arg-type]
+                max_tokens=max_tokens,
                 stream=True,
             )
             for chunk in chat_completion:
@@ -216,11 +222,13 @@ def _generate_non_stream(
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
+        max_tokens: Optional[int] = None
    ) -> str:
         try:
             chat_completion = self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages,  # type: ignore[arg-type]
+                max_tokens=max_tokens,
                 stream=False,
             )
             self.set_provider(
@@ -379,6 +387,7 @@ async def generate(  # noqa: WPS234, WPS211
         user_prompt: Optional[str] = None,
         system_prompt: Optional[str] = None,
         messages: Optional[List[Dict[str, str]]] = None,
+        max_tokens: Optional[int] = None,
         stream: bool = False,
     ) -> Union[AsyncGenerator[str, None], str]:  # noqa: DAR101, DAR201, DAR401
         """Generate content asynchronously using the Unify API.
@@ -393,6 +402,9 @@ async def generate(  # noqa: WPS234, WPS211
             messages (List[Dict[str, str]]): A list of dictionaries containing the
             conversation history. If provided, user_prompt must be None.
 
+            max_tokens (Optional[int]): The max number of output tokens, defaults
+            to the provider's default max_tokens when the value is None.
+
             stream (bool): If True, generates content as a stream.
             If False, generates content as a single response. Defaults to False.
 
@@ -417,18 +429,20 @@ async def generate(  # noqa: WPS234, WPS211
             raise UnifyError("You must provide either the user_prompt or messages!")
 
         if stream:
-            return self._generate_stream(contents, self._endpoint)
-        return await self._generate_non_stream(contents, self._endpoint)
+            return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens)
+        return await self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens)
 
     async def _generate_stream(
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
+        max_tokens: Optional[int] = None,
     ) -> AsyncGenerator[str, None]:
         try:
             async_stream = await self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages,  # type: ignore[arg-type]
+                max_tokens=max_tokens,
                 stream=True,
             )
             async for chunk in async_stream:  # type: ignore[union-attr]
@@ -441,11 +455,13 @@ async def _generate_non_stream(
         self,
         messages: List[Dict[str, str]],
         endpoint: str,
+        max_tokens: Optional[int] = None,
     ) -> str:
         try:
             async_response = await self.client.chat.completions.create(
                 model=endpoint,
                 messages=messages,  # type: ignore[arg-type]
+                max_tokens=max_tokens,
                 stream=False,
             )
             self.set_provider(async_response.model.split("@")[-1])  # type: ignore