Skip to content

Commit

Permalink
add api key check to list endpoints (#32)
Browse files Browse the repository at this point in the history
* add api key check to list endpoints
  • Loading branch information
guillesanbri authored Aug 9, 2024
1 parent be29d12 commit 4183086
Showing 1 changed file with 41 additions and 11 deletions.
52 changes: 41 additions & 11 deletions unify/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,46 @@ def _validate_api_key(api_key: Optional[str]) -> str:
return api_key


def list_models(provider: Optional[str] = None) -> List[str]:
def list_models(
provider: Optional[str] = None, api_key: Optional[str] = None
) -> List[str]:
"""
Get a list of available models, either in total or for a specific provider.
Args:
provider (str): If specified, returns the list of models supporting this provider.
api_key (str): If specified, unify API key to be used. Defaults
to the value in the `UNIFY_KEY` environment variable.
Returns:
List[str]: A list of available model names if successful, otherwise an empty list.
Raises:
BadRequestError: If there was an HTTP error.
ValueError: If there was an error parsing the JSON response.
"""
# ToDo: remove list(set()) hack once HTTP API is fixed
api_key = _validate_api_key(api_key)
headers = {
"accept": "application/json",
"Authorization": f"Bearer {api_key}",
}
url = f"{_base_url}/models"
if provider:
return list(set(_res_to_list(requests.get(url, params={"provider": provider}))))
return list(set(_res_to_list(requests.get(url))))
return _res_to_list(
requests.get(url, headers=headers, params={"provider": provider})
)
return _res_to_list(requests.get(url, headers=headers))


def list_providers(model: Optional[str] = None) -> List[str]:
def list_providers(
model: Optional[str] = None, api_key: Optional[str] = None
) -> List[str]:
"""
Get a list of available providers, either in total or for a specific model.
Args:
model (str): If specified, returns the list of providers supporting this model.
api_key (str): If specified, unify API key to be used. Defaults
to the value in the `UNIFY_KEY` environment variable.
Returns:
List[str]: A list of provider names associated with the model if successful,
Expand All @@ -67,36 +81,52 @@ def list_providers(model: Optional[str] = None) -> List[str]:
BadRequestError: If there was an HTTP error.
ValueError: If there was an error parsing the JSON response.
"""
api_key = _validate_api_key(api_key)
headers = {
"accept": "application/json",
"Authorization": f"Bearer {api_key}",
}
url = f"{_base_url}/providers"
if model:
return _res_to_list(requests.get(url, params={"model": model}))
return _res_to_list(requests.get(url))
return _res_to_list(requests.get(url, headers=headers, params={"model": model}))
return _res_to_list(requests.get(url, headers=headers))


def list_endpoints(
model: Optional[str] = None, provider: Optional[str] = None
model: Optional[str] = None,
provider: Optional[str] = None,
api_key: Optional[str] = None,
) -> List[str]:
"""
Get a list of available endpoint, either in total or for a specific model or provider.
Args:
model (str): If specified, returns the list of endpoint supporting this model.
provider (str): If specified, returns the list of endpoint supporting this provider.
api_key (str): If specified, unify API key to be used. Defaults
to the value in the `UNIFY_KEY` environment variable.
Returns:
List[str]: A list of endpoint names if successful, otherwise an empty list.
Raises:
BadRequestError: If there was an HTTP error.
ValueError: If there was an error parsing the JSON response.
"""
api_key = _validate_api_key(api_key)
headers = {
"accept": "application/json",
"Authorization": f"Bearer {api_key}",
}
url = f"{_base_url}/endpoints"
if model and provider:
raise ValueError("Please specify either model OR provider, not both.")
elif model:
return _res_to_list(requests.get(url, params={"model": model}))
return _res_to_list(requests.get(url, headers=headers, params={"model": model}))
elif provider:
return _res_to_list(requests.get(url, params={"provider": provider}))
return _res_to_list(requests.get(url))
return _res_to_list(
requests.get(url, headers=headers, params={"provider": provider})
)
return _res_to_list(requests.get(url, headers=headers))


def upload_dataset_from_file(
Expand Down

0 comments on commit 4183086

Please sign in to comment.