Refactor chat bot #5

Merged 6 commits on Apr 15, 2024
12 changes: 6 additions & 6 deletions README.md
@@ -21,7 +21,7 @@ pip install unifyai
## Basic Usage
```python
import os
-from unifyai import Unify
+from unify import Unify
unify = Unify(
    # This is the default and optional to include.
    api_key=os.environ.get("UNIFY_KEY"),
@@ -88,7 +88,7 @@ To use the AsyncUnify client, simply import `AsyncUnify` instead
of `Unify` and use `await` with the `.generate` function.

```python
-from unifyai import AsyncUnify
+from unify import AsyncUnify
import os
import asyncio
async_unify = AsyncUnify(
@@ -110,7 +110,7 @@ You can enable streaming responses by setting `stream=True` in the `.generate` f

```python
import os
-from unifyai import Unify
+from unify import Unify
unify = Unify(
    # This is the default and optional to include.
    api_key=os.environ.get("UNIFY_KEY"),
@@ -124,7 +124,7 @@ for chunk in stream:
It works in exactly the same way with Async clients.

```python
-from unifyai import AsyncUnify
+from unify import AsyncUnify
import os
import asyncio
async_unify = AsyncUnify(
@@ -152,7 +152,7 @@ As evidenced by our [benchmarks](https://unify.ai/hub), the optimal provider for

```python
import os
-from unifyai import Unify
+from unify import Unify
unify = Unify(
    # This is the default and optional to include.
    api_key=os.environ.get("UNIFY_KEY"),
@@ -172,7 +172,7 @@ Dynamic routing works with both Synchronous and Asynchronous clients. For more i
Our `ChatBot` allows you to start an interactive chat session with any of our supported LLM endpoints with only a few lines of code:

```python
-from unifyai import ChatBot
+from unify import ChatBot
agent = ChatBot(
    # This is the default and optional to include.
    api_key=os.environ.get("UNIFY_KEY"),
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -1,10 +1,11 @@
[tool.poetry]
name = "unifyai"
version = "0.7.2"
packages = [{include = "unify"}]
version = "0.8.0"
readme = "README.md"
description = "A Python package for interacting with the Unify API"
authors = ["Unify <[email protected]>"]
repository = "https://github.com/unifyai/unify-llm-python"
repository = "https://github.com/unifyai/unify"

[tool.poetry.dependencies]
python = "^3.9"
4 changes: 4 additions & 0 deletions unify/__init__.py
@@ -0,0 +1,4 @@
"""Unify python module."""

from unify.clients import AsyncUnify, Unify # noqa: F403
from unify.chat import ChatBot # noqa: F403
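These re-exports define the new top-level API. A minimal smoke test, assuming the package built from this branch is installed:

```python
# Minimal smoke test for the new import surface (assumes `pip install unifyai`
# of the version built from this branch).
import unify

print(unify.Unify, unify.AsyncUnify, unify.ChatBot)

# The old module path no longer exists after this PR:
# import unifyai  # would raise ModuleNotFoundError
```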
232 changes: 232 additions & 0 deletions unify/chat.py
@@ -0,0 +1,232 @@
import sys
from typing import Optional

from unify.clients import Unify
from unify.exceptions import UnifyError


class ChatBot:  # noqa: WPS338
    """ChatBot represents an interactive LLM chat agent."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        model: Optional[str] = None,
        provider: Optional[str] = None,
    ) -> None:
        """
        Initializes the ChatBot object.

        Args:
            api_key (str, optional): API key for accessing the Unify API.
                If None, it attempts to retrieve the API key from the
                environment variable UNIFY_KEY.
                Defaults to None.

            endpoint (str, optional): Endpoint name in OpenAI API format:
                <uploaded_by>/<model_name>@<provider_name>
                Defaults to None.

            model (str, optional): Name of the model. If None,
                endpoint must be provided.

            provider (str, optional): Name of the provider. If None,
                endpoint must be provided.

        Raises:
            UnifyError: If the API key is missing.
        """
        self._message_history = []
        self._paused = False
        self._client = Unify(
            api_key=api_key,
            endpoint=endpoint,
            model=model,
            provider=provider,
        )

    @property
    def client(self) -> Unify:
        """
        Get the client object.  # noqa: DAR201.

        Returns:
            Unify: The underlying Unify client.
        """
        return self._client

    def set_client(self, value: Unify) -> None:
        """
        Set the client.  # noqa: DAR101.

        Args:
            value: The Unify client.

        Raises:
            UnifyError: If the value is not a Unify client.
        """
        if isinstance(value, Unify):
            self._client = value
        else:
            raise UnifyError("Invalid client!")

    @property
    def model(self) -> str:
        """
        Get the model name.  # noqa: DAR201.

        Returns:
            str: The model name.
        """
        return self._client.model

    def set_model(self, value: str) -> None:
        """
        Set the model name.  # noqa: DAR101.

        Args:
            value (str): The model name.
        """
        self._client.set_model(value)
        if self._client.provider:
            self._client.set_endpoint("@".join([value, self._client.provider]))
        else:
            provider = self._client.endpoint.split("@")[1]
            self._client.set_endpoint("@".join([value, provider]))

    @property
    def provider(self) -> Optional[str]:
        """
        Get the provider name.  # noqa: DAR201.

        Returns:
            str: The provider name.
        """
        return self._client.provider

    def set_provider(self, value: str) -> None:
        """
        Set the provider name.  # noqa: DAR101.

        Args:
            value (str): The provider name.
        """
        self._client.set_provider(value)
        self._client.set_endpoint("@".join([self._client.model, value]))

    @property
    def endpoint(self) -> str:
        """
        Get the endpoint name.  # noqa: DAR201.

        Returns:
            str: The endpoint name.
        """
        return self._client.endpoint

    def set_endpoint(self, value: str) -> None:
        """
        Set the endpoint name.  # noqa: DAR101.

        Args:
            value (str): The endpoint name.
        """
        self._client.set_endpoint(value)
        self._client.set_model(value.split("@")[0])
        self._client.set_provider(value.split("@")[1])

    def _get_credits(self) -> float:
        """
        Retrieves the current credit balance associated with the Unify account.

        Returns:
            float: Current credit balance.
        """
        return self._client.get_credit_balance()

    def _process_input(self, inp: str, show_credits: bool, show_provider: bool):
        """
        Processes the user input to generate an AI response.

        Args:
            inp (str): User input message.
            show_credits (bool): Whether to show credit consumption.
            show_provider (bool): Whether to show the provider used.

        Yields:
            str: Generated AI response chunks.
        """
        self._update_message_history(role="user", content=inp)
        initial_credit_balance = self._get_credits()
        stream = self._client.generate(
            messages=self._message_history,
            stream=True,
        )
        words = ""
        for chunk in stream:
            words += chunk
            yield chunk

        self._update_message_history(
            role="assistant",
            content=words,
        )
        final_credit_balance = self._get_credits()
        if show_credits:
            sys.stdout.write(
                "\n(spent {:.6f} credits)".format(
                    initial_credit_balance - final_credit_balance,
                ),
            )
        if show_provider:
            sys.stdout.write("\n(provider: {})".format(self._client.provider))

    def _update_message_history(self, role: str, content: str) -> None:
        """
        Updates the message history.

        Args:
            role (str): Either "assistant" or "user".
            content (str): The message content.
        """
        self._message_history.append(
            {
                "role": role,
                "content": content,
            },
        )

    def clear_chat_history(self) -> None:
        """Clears the chat history."""
        self._message_history.clear()

    def run(self, show_credits: bool = False, show_provider: bool = False) -> None:
        """
        Starts the chat interaction loop.

        Args:
            show_credits (bool, optional): Whether to show credit consumption.
                Defaults to False.
            show_provider (bool, optional): Whether to show the provider used.
                Defaults to False.
        """
        if not self._paused:
            sys.stdout.write(
                "Let's have a chat. (Enter `pause` to pause and `quit` to exit)\n",
            )
            self.clear_chat_history()
        else:
            sys.stdout.write(
                "Welcome back! (Remember, enter `pause` to pause and `quit` to exit)\n",
            )
            self._paused = False
        while True:
            sys.stdout.write("> ")
            inp = input()
            if inp == "quit":
                self.clear_chat_history()
                break
            elif inp == "pause":
                self._paused = True
                break
            for word in self._process_input(inp, show_credits, show_provider):
                sys.stdout.write(word)
                sys.stdout.flush()
            sys.stdout.write("\n")
14 changes: 7 additions & 7 deletions unifyai/clients.py → unify/clients.py
@@ -2,8 +2,8 @@

import openai
import requests
-from unifyai.exceptions import BadRequestError, UnifyError, status_error_map
-from unifyai.utils import ( # noqa:WPS450
+from unify.exceptions import BadRequestError, UnifyError, status_error_map
+from unify.utils import ( # noqa:WPS450
    _available_dynamic_modes,
    _validate_api_key,
    _validate_endpoint,
@@ -205,7 +205,7 @@ def _generate_stream(
            )
            for chunk in chat_completion:
                content = chunk.choices[0].delta.content  # type: ignore[union-attr]
-                self._provider = chunk.model.split("@")[-1]  # type: ignore[union-attr]
+                self.set_provider(chunk.model.split("@")[-1])  # type: ignore[union-attr]
                if content is not None:
                    yield content
        except openai.APIStatusError as e:
@@ -222,9 +222,9 @@ def _generate_non_stream(
                messages=messages,  # type: ignore[arg-type]
                stream=False,
            )
-            self._provider = chat_completion.model.split(  # type: ignore[union-attr]
+            self.set_provider(chat_completion.model.split(  # type: ignore[union-attr]
                "@",
-            )[-1]
+            )[-1])

            return chat_completion.choices[0].message.content.strip(" ")  # type: ignore  # noqa: E501, WPS219
        except openai.APIStatusError as e:
@@ -429,7 +429,7 @@ async def _generate_stream(
                stream=True,
            )
            async for chunk in async_stream:  # type: ignore[union-attr]
-                self._provider = chunk.model.split("@")[-1]
+                self.set_provider(chunk.model.split("@")[-1])
                yield chunk.choices[0].delta.content or ""
        except openai.APIStatusError as e:
            raise status_error_map[e.status_code](e.message) from None
@@ -445,7 +445,7 @@
                messages=messages,  # type: ignore[arg-type]
                stream=False,
            )
-            self._provider = async_response.model.split("@")[-1]  # type: ignore
+            self.set_provider(async_response.model.split("@")[-1])  # type: ignore
            return async_response.choices[0].message.content.strip(" ")  # type: ignore  # noqa: E501, WPS219
        except openai.APIStatusError as e:
            raise status_error_map[e.status_code](e.message) from None
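The recurring change in this file routes provider updates through `set_provider(...)` instead of writing `self._provider` directly, so validation and any dependent state live in one place. A stripped-down sketch of the pattern (simplified, not the actual Unify implementation):

```python
# Simplified sketch of the setter pattern adopted in this diff; hypothetical,
# not the real Unify client.
from typing import Optional


class ClientSketch:
    def __init__(self) -> None:
        self._provider: Optional[str] = None

    def set_provider(self, value: str) -> None:
        # One choke point for validation and for keeping dependent
        # state (e.g. an endpoint string) in sync.
        if not value:
            raise ValueError("provider must be non-empty")
        self._provider = value
```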
unifyai/exceptions.py → unify/exceptions.py: File renamed without changes.
4 changes: 2 additions & 2 deletions unifyai/tests.py → unify/tests.py
@@ -3,8 +3,8 @@
from types import AsyncGeneratorType, GeneratorType
from unittest.mock import MagicMock, patch

-from unifyai.clients import AsyncUnify, Unify
-from unifyai.exceptions import AuthenticationError, UnifyError
+from unify.clients import AsyncUnify, Unify
+from unify.exceptions import AuthenticationError, UnifyError


class TestUnify(unittest.TestCase):
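The tests move with the package; one way to run the relocated module from the repository root, assuming standard unittest discovery:

```python
# Hypothetical runner for the relocated tests (from the repository root).
import unittest

suite = unittest.defaultTestLoader.discover("unify", pattern="tests.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```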
unifyai/utils.py → unify/utils.py: File renamed without changes.
4 changes: 0 additions & 4 deletions unifyai/__init__.py

This file was deleted.
