Commit

black formatting
guillesanbri committed Jul 17, 2024
1 parent 73f4af5 commit 70ac6a5
Showing 4 changed files with 74 additions and 33 deletions.
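All four diffs below are formatting-only changes produced by the black autoformatter: over-long signatures and call sites are split to one argument per line with a trailing comma, and string quotes are normalized to double quotes. A minimal illustrative sketch of that behavior (hypothetical helper, not code from this repository):

def request(url, headers=None, timeout=10, retries=3):
    """Hypothetical helper used only to show black's output style."""
    return url, headers, timeout, retries

# Before black, the over-long call might be hand-wrapped like this:
# result = request('https://api.example.com', headers={'Accept': 'application/json'},
#                  timeout=30, retries=5)

# After running `black .`, the call exceeds the line-length limit and is
# rewrapped to one argument per line with a trailing comma:
result = request(
    "https://api.example.com",
    headers={"Accept": "application/json"},
    timeout=30,
    retries=5,
)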
6 changes: 3 additions & 3 deletions unify/__init__.py
@@ -3,9 +3,9 @@
from unify.chat import ChatBot # noqa: F403
from unify.clients import AsyncUnify, Unify # noqa: F403
from unify.utils import (
list_endpoints,
list_models,
list_providers,
list_endpoints,
list_models,
list_providers,
upload_dataset_from_file,
upload_dataset_from_dictionary,
delete_dataset,
54 changes: 36 additions & 18 deletions unify/clients.py
@@ -146,8 +146,8 @@ def generate( # noqa: WPS234, WPS211
max_tokens (Optional[int]): The max number of output tokens.
Defaults to the provider's default max_tokens when the value is None.
temperature (Optional[float]): What sampling temperature to use, between 0 and 2.
Higher values like 0.8 will make the output more random,
temperature (Optional[float]): What sampling temperature to use, between 0 and 2.
Higher values like 0.8 will make the output more random,
while lower values like 0.2 will make it more focused and deterministic.
Defaults to the provider's default max_tokens when the value is None.
@@ -176,14 +176,20 @@ def generate( # noqa: WPS234, WPS211
raise UnifyError("You must provider either the user_prompt or messages!")

if stream:
return self._generate_stream(contents, self._endpoint,
max_tokens=max_tokens,
temperature=temperature,
stop=stop)
return self._generate_non_stream(contents, self._endpoint,
max_tokens=max_tokens,
temperature=temperature,
stop=stop)
return self._generate_stream(
contents,
self._endpoint,
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
return self._generate_non_stream(
contents,
self._endpoint,
max_tokens=max_tokens,
temperature=temperature,
stop=stop,
)
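
For context, a minimal usage sketch of the two call paths above, based on the constructor and generate arguments exercised in unify/tests.py further down (API key and endpoint strings are placeholders):

from unify import Unify

client = Unify(api_key="YOUR_UNIFY_KEY", endpoint="llama-3-8b-chat@together-ai")

# stream=False returns the complete reply as a single string.
reply = client.generate(
    user_prompt="hello",
    stream=False,
    temperature=0.2,
    max_tokens=64,
)
print(reply)

# stream=True returns a generator that yields the reply chunk by chunk.
for chunk in client.generate(user_prompt="hello", stream=True):
    print(chunk, end="")

# Remaining credits, as a float.
print(client.get_credit_balance())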

def get_credit_balance(self) -> float:
# noqa: DAR201, DAR401
@@ -227,7 +233,7 @@ def _generate_stream(
temperature=temperature,
stop=stop,
stream=True,
extra_body={"signature": "package"}
extra_body={"signature": "package"},
)
for chunk in chat_completion:
content = chunk.choices[0].delta.content # type: ignore[union-attr]
@@ -253,7 +259,7 @@ def _generate_non_stream(
temperature=temperature,
stop=stop,
stream=False,
extra_body={"signature": "package"}
extra_body={"signature": "package"},
)
self.set_provider(
chat_completion.model.split( # type: ignore[union-attr]
@@ -431,8 +437,8 @@ async def generate( # noqa: WPS234, WPS211
max_tokens (Optional[int]): The max number of output tokens, defaults
to the provider's default max_tokens when the value is None.
temperature (Optional[float]): What sampling temperature to use, between 0 and 2.
Higher values like 0.8 will make the output more random,
temperature (Optional[float]): What sampling temperature to use, between 0 and 2.
Higher values like 0.8 will make the output more random,
while lower values like 0.2 will make it more focused and deterministic.
Defaults to the provider's default max_tokens when the value is None.
@@ -462,8 +468,20 @@ async def generate( # noqa: WPS234, WPS211
raise UnifyError("You must provide either the user_prompt or messages!")

if stream:
return self._generate_stream(contents, self._endpoint, max_tokens=max_tokens, stop=stop, temperature=temperature)
return await self._generate_non_stream(contents, self._endpoint, max_tokens=max_tokens, stop=stop, temperature=temperature)
return self._generate_stream(
contents,
self._endpoint,
max_tokens=max_tokens,
stop=stop,
temperature=temperature,
)
return await self._generate_non_stream(
contents,
self._endpoint,
max_tokens=max_tokens,
stop=stop,
temperature=temperature,
)
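
And a matching sketch for the async client, assuming AsyncUnify takes the same constructor arguments as Unify (placeholders again). Note that with stream=True, awaiting generate yields an async generator rather than a string:

import asyncio

from unify import AsyncUnify


async def main() -> None:
    client = AsyncUnify(
        api_key="YOUR_UNIFY_KEY",
        endpoint="llama-3-8b-chat@together-ai",
    )

    # Non-streaming: await the coroutine to get the full reply string.
    reply = await client.generate(user_prompt="hello", stream=False)
    print(reply)

    # Streaming: awaiting generate returns an async generator of chunks.
    stream = await client.generate(user_prompt="hello", stream=True)
    async for chunk in stream:
        print(chunk, end="")


asyncio.run(main())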

async def _generate_stream(
self,
@@ -481,7 +499,7 @@ async def _generate_stream(
temperature=temperature,
stop=stop,
stream=True,
extra_body={"signature": "package"}
extra_body={"signature": "package"},
)
async for chunk in async_stream: # type: ignore[union-attr]
self.set_provider(chunk.model.split("@")[-1])
@@ -505,7 +523,7 @@ async def _generate_non_stream(
temperature=temperature,
stop=stop,
stream=False,
extra_body={"signature": "package"}
extra_body={"signature": "package"},
)
self.set_provider(async_response.model.split("@")[-1]) # type: ignore
return async_response.choices[0].message.content.strip(" ") # type: ignore # noqa: E501, WPS219
8 changes: 6 additions & 2 deletions unify/tests.py
@@ -34,15 +34,19 @@ def test_incorrect_model_name_raises_internal_server_error(self) -> None:

def test_generate_returns_string_when_stream_false(self) -> None:
# Instantiate Unify with a valid API key
unify = Unify(api_key=self.valid_api_key, endpoint="llama-3-8b-chat@together-ai")
unify = Unify(
api_key=self.valid_api_key, endpoint="llama-3-8b-chat@together-ai"
)
# Call generate with stream=False
result = unify.generate(user_prompt="hello", stream=False)
# Assert that the result is a string
self.assertIsInstance(result, str)

def test_generate_returns_generator_when_stream_true(self) -> None:
# Instantiate Unify with a valid API key
unify = Unify(api_key=self.valid_api_key, endpoint="llama-3-8b-chat@together-ai")
unify = Unify(
api_key=self.valid_api_key, endpoint="llama-3-8b-chat@together-ai"
)
# Call generate with stream=True
result = unify.generate(user_prompt="hello", stream=True)
# Assert that the result is a generator
39 changes: 29 additions & 10 deletions unify/utils.py
@@ -72,6 +72,7 @@ def _validate_endpoint( # noqa: WPS231
provider = None
return endpoint, model, provider


def list_models() -> List[str]:
"""
Get a list of available models.
@@ -122,7 +123,9 @@ def list_providers(model: str) -> List[str]:
return _res_to_list(requests.get(url, params={"model": model}))


def upload_dataset_from_file(name: str, path: str, api_key: Optional[str]=None) -> str:
def upload_dataset_from_file(
name: str, path: str, api_key: Optional[str] = None
) -> str:
"""
Uploads a local file as a dataset to the platform.
@@ -147,12 +150,17 @@ def upload_dataset_from_file(name: str, path: str, api_key: Optional[str]=None)
files = {"file": ("dataset", file_content, "application/x-jsonlines")}
data = {"name": name}
# Send POST request to the /dataset endpoint
response = requests.post(_base_url + "/dataset", headers=headers, data=data, files=files)
response = requests.post(
_base_url + "/dataset", headers=headers, data=data, files=files
)
if response.status_code != 200:
raise ValueError(response.text)
return json.loads(response.text)["info"]
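
A usage sketch for the reformatted upload helper above (dataset name and path are placeholders; the file is sent as JSON Lines, and presumably each line carries a "prompt" field, mirroring the dictionary variant below). The API key falls back to the UNIFY_KEY environment variable when omitted:

from unify.utils import upload_dataset_from_file

# Upload a local JSON Lines file as a named dataset.
info = upload_dataset_from_file(name="my-eval-set", path="data/prompts.jsonl")
print(info)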

def upload_dataset_from_dictionary(name: str, content: List[Dict[str, str]], api_key: Optional[str] = None) -> str:

def upload_dataset_from_dictionary(
name: str, content: List[Dict[str, str]], api_key: Optional[str] = None
) -> str:
"""
Uploads a list of dictionaries as a dataset to the platform.
Each dictionary in the list must contain a `prompt` key.
@@ -177,11 +185,14 @@ def upload_dataset_from_dictionary(name: str, content: List[Dict[str, str]], api
files = {"file": ("dataset", content_str, "application/x-jsonlines")}
data = {"name": name}
# Send POST request to the /dataset endpoint
response = requests.post(_base_url + "/dataset", headers=headers, data=data, files=files)
response = requests.post(
_base_url + "/dataset", headers=headers, data=data, files=files
)
if response.status_code != 200:
raise ValueError(response.text)
return json.loads(response.text)["info"]
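
The in-memory variant can be exercised the same way; per the docstring above, every dictionary must contain a "prompt" key:

from unify.utils import upload_dataset_from_dictionary

entries = [
    {"prompt": "What is the capital of France?"},
    {"prompt": "Summarize the plot of Hamlet in one sentence."},
]
info = upload_dataset_from_dictionary(name="my-eval-set", content=entries)
print(info)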


def delete_dataset(name: str, api_key: Optional[str] = None) -> str:
"""
Deletes a dataset from the platform.
@@ -210,7 +221,9 @@ def delete_dataset(name: str, api_key: Optional[str] = None) -> str:
return json.loads(response.text)["info"]


def download_dataset(name: str, path: Optional[str] = None, api_key: Optional[str] = None) -> Optional[str]:
def download_dataset(
name: str, path: Optional[str] = None, api_key: Optional[str] = None
) -> Optional[str]:
"""
Downloads a dataset from the platform.
@@ -237,7 +250,7 @@ def download_dataset(name: str, path: Optional[str] = None, api_key: Optional[st
if response.status_code != 200:
raise ValueError(response.text)
if path:
with open(path, 'w+') as f:
with open(path, "w+") as f:
f.write("\n".join([json.dumps(d) for d in json.loads(response.text)]))
return None
return json.loads(response.text)
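
Usage sketch for download_dataset as reformatted above: with a path it writes one JSON object per line to disk and returns None, otherwise it returns the parsed dataset (placeholder names):

from unify.utils import download_dataset

# Write the dataset to disk as JSON Lines.
download_dataset(name="my-eval-set", path="my-eval-set.jsonl")

# Or fetch it directly as Python objects.
dataset = download_dataset(name="my-eval-set")
print(len(dataset))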
@@ -291,7 +304,9 @@ def evaluate(dataset: str, endpoints: List[str], api_key: Optional[str] = None)
for endpoint in endpoints:
data = {"dataset": dataset, "endpoint": endpoint}
# Send POST request to the /evaluation endpoint
response = requests.post(_base_url + "/evaluation", headers=headers, params=data)
response = requests.post(
_base_url + "/evaluation", headers=headers, params=data
)
if response.status_code != 200:
raise ValueError(f"Error in endpoint {endpoint}: {response.text}")
return json.loads(response.text)["info"]
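
Usage sketch for the evaluation trigger above; the endpoint strings are placeholders in the same model@provider format used elsewhere in this commit:

from unify.utils import evaluate

# Queue evaluations of one dataset across several endpoints.
info = evaluate(
    dataset="my-eval-set",
    endpoints=["llama-3-8b-chat@together-ai", "llama-3-70b-chat@together-ai"],
)
print(info)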
@@ -319,18 +334,22 @@ def delete_evaluation(name: str, endpoint: str, api_key: Optional[str] = None) -
}
params = {"dataset": name, "endpoint": endpoint}
# Send DELETE request to the /evaluation endpoint
response = requests.delete(_base_url + "/evaluation", headers=headers, params=params)
response = requests.delete(
_base_url + "/evaluation", headers=headers, params=params
)
if response.status_code != 200:
raise ValueError(response.text)
return json.loads(response.text)["info"]


def list_evaluations(dataset: Optional[str] = None, api_key: Optional[str] = None) -> List[str]:
def list_evaluations(
dataset: Optional[str] = None, api_key: Optional[str] = None
) -> List[str]:
"""
Fetches a list of all evaluations.
Args:
dataset (str): Name of the dataset to fetch evaluation from.
dataset (str): Name of the dataset to fetch evaluation from.
If not specified, all evaluations will be returned.
api_key (str): If specified, unify API key to be used. Defaults
to the value in the `UNIFY_KEY` environment variable.
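
Rounding out the evaluation helpers, a short sketch of listing and deleting evaluations (same placeholder names as above):

from unify.utils import delete_evaluation, list_evaluations

# List every stored evaluation, or only those for one dataset.
print(list_evaluations())
print(list_evaluations(dataset="my-eval-set"))

# Delete a single evaluation, identified by dataset name and endpoint.
delete_evaluation(name="my-eval-set", endpoint="llama-3-8b-chat@together-ai")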
