add functions to upload and evaluate datasets (#27)
* add functions to upload and evaluate datasets

* fix test

* fix mypy

* black formatting
guillesanbri authored Jul 17, 2024
1 parent 9ae54cc commit e71b2f3
Showing 4 changed files with 347 additions and 64 deletions.
14 changes: 13 additions & 1 deletion unify/__init__.py
@@ -2,4 +2,16 @@

from unify.chat import ChatBot # noqa: F403
from unify.clients import AsyncUnify, Unify # noqa: F403
from unify.utils import list_endpoints, list_models, list_providers
from unify.utils import (
list_endpoints,
list_models,
list_providers,
upload_dataset_from_file,
upload_dataset_from_dictionary,
delete_dataset,
download_dataset,
list_datasets,
evaluate,
delete_evaluation,
list_evaluations,
)
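
For orientation, a minimal usage sketch of the helpers exported above; the function names come from this import block, but every argument shown is an assumption, since the implementing unify/utils.py diff is not expanded on this page.

import unify

# Assumed argument names -- the authoritative signatures live in unify/utils.py.
unify.upload_dataset_from_file(name="demo-dataset", path="prompts.jsonl")
print(unify.list_datasets())

unify.evaluate(dataset="demo-dataset", endpoint="llama-3-8b-chat@together-ai")
print(unify.list_evaluations())

unify.delete_dataset(name="demo-dataset")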
54 changes: 36 additions & 18 deletions unify/clients.py
@@ -146,8 +146,8 @@ def generate( # noqa: WPS234, WPS211
max_tokens (Optional[int]): The max number of output tokens.
Defaults to the provider's default max_tokens when the value is None.
temperature (Optional[float]): What sampling temperature to use, between 0 and 2.
Higher values like 0.8 will make the output more random,
while lower values like 0.2 will make it more focused and deterministic.
Defaults to the provider's default temperature when the value is None.
@@ -176,14 +176,20 @@ def generate( # noqa: WPS234, WPS211
raise UnifyError("You must provider either the user_prompt or messages!")

if stream:
return self._generate_stream(contents, self._endpoint,
max_tokens=max_tokens,
temperature=temperature,
stop=stop)
return self._generate_non_stream(contents, self._endpoint,
max_tokens=max_tokens,
temperature=temperature,
stop=stop)
return self._generate_stream(
contents,
self._endpoint,
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
return self._generate_non_stream(
contents,
self._endpoint,
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)

def get_credit_balance(self) -> float:
# noqa: DAR201, DAR401
@@ -227,7 +233,7 @@ def _generate_stream(
temperature=temperature,
stop=stop,
stream=True,
extra_body={"signature": "package"}
extra_body={"signature": "package"},
)
for chunk in chat_completion:
content = chunk.choices[0].delta.content # type: ignore[union-attr]
@@ -253,7 +259,7 @@ def _generate_non_stream(
temperature=temperature,
stop=stop,
stream=False,
extra_body={"signature": "package"}
extra_body={"signature": "package"},
)
self.set_provider(
chat_completion.model.split( # type: ignore[union-attr]
@@ -431,8 +437,8 @@ async def generate( # noqa: WPS234, WPS211
max_tokens (Optional[int]): The max number of output tokens, defaults
to the provider's default max_tokens when the value is None.
temperature (Optional[float]): What sampling temperature to use, between 0 and 2.
Higher values like 0.8 will make the output more random,
while lower values like 0.2 will make it more focused and deterministic.
Defaults to the provider's default temperature when the value is None.
@@ -462,8 +468,20 @@ async def generate( # noqa: WPS234, WPS211
raise UnifyError("You must provide either the user_prompt or messages!")

if stream:
return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens, stop=stop, temperature=temperature)
return await self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens, stop=stop, temperature=temperature)
return self._generate_stream(
contents,
self._endpoint,
max_tokens=max_tokens,
stop=stop,
temperature=temperature,
)
return await self._generate_non_stream(
contents,
self._endpoint,
max_tokens=max_tokens,
stop=stop,
temperature=temperature,
)

async def _generate_stream(
self,
@@ -481,7 +499,7 @@ async def _generate_stream(
temperature=temperature,
stop=stop,
stream=True,
extra_body={"signature": "package"}
extra_body={"signature": "package"},
)
async for chunk in async_stream: # type: ignore[union-attr]
self.set_provider(chunk.model.split("@")[-1])
@@ -505,7 +523,7 @@ async def _generate_non_stream(
temperature=temperature,
stop=stop,
stream=False,
extra_body={"signature": "package"}
extra_body={"signature": "package"},
)
self.set_provider(async_response.model.split("@")[-1]) # type: ignore
return async_response.choices[0].message.content.strip(" ") # type: ignore # noqa: E501, WPS219
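
As a quick orientation for the generate calls reformatted above, here is a short usage sketch of the synchronous client; the endpoint string mirrors the one used in the tests below, and the parameters are the ones listed in the docstring above.

from unify import Unify

client = Unify(api_key="YOUR_UNIFY_KEY", endpoint="llama-3-8b-chat@together-ai")

# stream=False returns the full completion as a single string.
reply = client.generate(user_prompt="hello", stream=False, max_tokens=128, temperature=0.2)
print(reply)

# stream=True returns a generator that yields content chunks as they arrive.
for chunk in client.generate(user_prompt="hello", stream=True):
    print(chunk, end="")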
18 changes: 11 additions & 7 deletions unify/tests.py
@@ -17,15 +17,15 @@ def test_invalid_api_key_raises_authentication_error(self) -> None:
with self.assertRaises(AuthenticationError):
unify = Unify(
api_key="invalid_api_key",
endpoint="llama-2-7b-chat@anyscale",
endpoint="llama-3-8b-chat@together-ai",
)
unify.generate(user_prompt="hello")

@patch("os.environ.get", return_value=None)
def test_missing_api_key_raises_key_error(self, mock_get: MagicMock) -> None:
# Initializing Unify without providing API key should raise KeyError
with self.assertRaises(KeyError):
Unify(endpoint="llama-2-7b-chat@anyscale")
Unify(endpoint="llama-3-8b-chat@together-ai")

def test_incorrect_model_name_raises_internal_server_error(self) -> None:
# Provide incorrect model name
@@ -34,15 +34,19 @@ def test_incorrect_model_name_raises_internal_server_error(self) -> None:

def test_generate_returns_string_when_stream_false(self) -> None:
# Instantiate Unify with a valid API key
unify = Unify(api_key=self.valid_api_key, endpoint="llama-2-7b-chat@anyscale")
unify = Unify(
api_key=self.valid_api_key, endpoint="llama-3-8b-chat@together-ai"
)
# Call generate with stream=False
result = unify.generate(user_prompt="hello", stream=False)
# Assert that the result is a string
self.assertIsInstance(result, str)

def test_generate_returns_generator_when_stream_true(self) -> None:
# Instantiate Unify with a valid API key
unify = Unify(api_key=self.valid_api_key, endpoint="llama-2-7b-chat@anyscale")
unify = Unify(
api_key=self.valid_api_key, endpoint="llama-3-8b-chat@together-ai"
)
# Call generate with stream=True
result = unify.generate(user_prompt="hello", stream=True)
# Assert that the result is a generator
@@ -59,7 +63,7 @@ async def test_invalid_api_key_raises_authentication_error(self) -> None:
with self.assertRaises(AuthenticationError):
async_unify = AsyncUnify(
api_key="invalid_api_key",
endpoint="llama-2-7b-chat@anyscale",
endpoint="llama-3-8b-chat@together-ai",
)
await async_unify.generate(user_prompt="hello")

@@ -80,7 +84,7 @@ async def test_generate_returns_string_when_stream_false(self) -> None:
# Instantiate AsyncUnify with a valid API key
async_unify = AsyncUnify(
api_key=self.valid_api_key,
endpoint="llama-2-7b-chat@anyscale",
endpoint="llama-3-8b-chat@together-ai",
)
# Call generate with stream=False
result = await async_unify.generate(user_prompt="hello", stream=False)
@@ -91,7 +95,7 @@ async def test_generate_returns_generator_when_stream_true(self) -> None:
# Instantiate AsyncUnify with a valid API key
async_unify = AsyncUnify(
api_key=self.valid_api_key,
endpoint="llama-2-7b-chat@anyscale",
endpoint="llama-3-8b-chat@together-ai",
)
# Call generate with stream=True
result = await async_unify.generate(user_prompt="hello", stream=True)
