Skip to content

Commit

Permalink
fix(azure.py): support dropping 'tool_choice=required' for older azur…
Browse files Browse the repository at this point in the history
…e API versions

Closes #3876
  • Loading branch information
krrishdholakia committed Jun 2, 2024
1 parent 054456c commit 7efac4d
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 41 deletions.
122 changes: 104 additions & 18 deletions litellm/llms/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
convert_to_model_response_object,
TranscriptionResponse,
get_secret,
UnsupportedParamsError,
)
from typing import Callable, Optional, BinaryIO, List
from litellm import OpenAIConfig
Expand Down Expand Up @@ -45,9 +46,9 @@ def __init__(
) # Call the base class constructor with the parameters it needs


class AzureOpenAIConfig(OpenAIConfig):
class AzureOpenAIConfig:
"""
Reference: https://platform.openai.com/docs/api-reference/chat/create
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. It inherits from `OpenAIConfig`. Below are the parameters::
Expand Down Expand Up @@ -85,18 +86,103 @@ def __init__(
temperature: Optional[int] = None,
top_p: Optional[int] = None,
) -> None:
super().__init__(
frequency_penalty,
function_call,
functions,
logit_bias,
max_tokens,
n,
presence_penalty,
stop,
temperature,
top_p,
)
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)

@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}

def get_supported_openai_params(self):
return [
"temperature",
"n",
"stream",
"stop",
"max_tokens",
"tools",
"tool_choice",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"function_call",
"functions",
"tools",
"tool_choice",
"top_p",
"log_probs",
"top_logprobs",
"response_format",
"seed",
"extra_headers",
]

def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
api_version: str, # Y-M-D-{optional}
) -> dict:
supported_openai_params = self.get_supported_openai_params()

api_version_times = api_version.split("-")
api_version_year = api_version_times[0]
api_version_month = api_version_times[1]
api_version_day = api_version_times[2]
args = locals()
for param, value in non_default_params.items():
if param == "tool_choice":
"""
This parameter requires API version 2023-12-01-preview or later
tool_choice='required' is not supported as of 2024-05-01-preview
"""
## check if api version supports this param ##
if (
api_version_year < "2023"
or (api_version_year == "2023" and api_version_month < "12")
or (
api_version_year == "2023"
and api_version_month == "12"
and api_version_day < "01"
)
):
if litellm.drop_params == False:
raise UnsupportedParamsError(
status_code=400,
message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
)
elif value == "required" and (
api_version_year == "2024" and api_version_month <= "05"
): ## check if tool_choice value is supported ##
if litellm.drop_params == False:
raise UnsupportedParamsError(
status_code=400,
message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
)
else:
optional_params["tool_choice"] = value
elif param in supported_openai_params:
optional_params[param] = value
return optional_params

def get_mapped_special_auth_params(self) -> dict:
return {"token": "azure_ad_token"}
Expand Down Expand Up @@ -172,9 +258,7 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
possible_azure_ad_token = req_token.json().get("access_token", None)

if possible_azure_ad_token is None:
raise AzureOpenAIError(
status_code=422, message="Azure AD Token not returned"
)
raise AzureOpenAIError(status_code=422, message="Azure AD Token not returned")

return possible_azure_ad_token

Expand Down Expand Up @@ -245,7 +329,9 @@ def completion(
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_ad_token = get_azure_ad_token_from_oidc(
azure_ad_token
)

azure_client_params["azure_ad_token"] = azure_ad_token

Expand Down
36 changes: 36 additions & 0 deletions litellm/tests/test_optional_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,39 @@ def test_openai_extra_headers():
assert optional_params["max_tokens"] == 10
assert optional_params["temperature"] == 0.2
assert optional_params["extra_headers"] == {"AI-Resource Group": "ishaan-resource"}


@pytest.mark.parametrize(
    "api_version",
    [
        "2024-02-01",
        "2024-07-01",  # potential future version with tool_choice="required" supported
        "2023-07-01-preview",
        "2024-03-01-preview",
    ],
)
def test_azure_tool_choice(api_version):
    """
    Test azure tool choice on older + new version.

    With drop_params enabled, tool_choice='required' must be silently dropped
    for API versions that do not support it, and passed through otherwise.
    """
    # Save and restore the global flag so this test does not leak state into
    # other tests (the original left litellm.drop_params permanently True).
    original_drop_params = litellm.drop_params
    litellm.drop_params = True
    try:
        optional_params = litellm.utils.get_optional_params(
            model="chatgpt-v-2",
            user="John",
            custom_llm_provider="azure",
            max_tokens=10,
            temperature=0.2,
            extra_headers={"AI-Resource Group": "ishaan-resource"},
            tool_choice="required",
            api_version=api_version,
        )
    finally:
        litellm.drop_params = original_drop_params

    print(f"{optional_params}")
    if api_version == "2024-07-01":
        # New enough: 'required' passes through untouched.
        assert optional_params["tool_choice"] == "required"
    else:
        # Older versions: 'required' is dropped rather than raised.
        assert (
            "tool_choice" not in optional_params
        ), "tool_choice={} for api version={}".format(
            optional_params["tool_choice"], api_version
        )
40 changes: 17 additions & 23 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6045,6 +6045,22 @@ def _map_and_modify_arg(supported_params: dict, provider: str, model: str):
optional_params=optional_params,
model=model,
)
elif custom_llm_provider == "azure":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider="azure"
)
_check_valid_arg(supported_params=supported_params)
api_version = (
passed_params.get("api_version", None)
or litellm.api_version
or get_secret("AZURE_API_VERSION")
)
optional_params = litellm.AzureOpenAIConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
model=model,
api_version=api_version, # type: ignore
)
else: # assume passing in params for azure openai
supported_params = get_supported_openai_params(
model=model, custom_llm_provider="azure"
Expand Down Expand Up @@ -6481,29 +6497,7 @@ def get_supported_openai_params(
elif custom_llm_provider == "openai":
return litellm.OpenAIConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "azure":
return [
"functions",
"function_call",
"temperature",
"top_p",
"n",
"stream",
"stream_options",
"stop",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
"response_format",
"seed",
"tools",
"tool_choice",
"max_retries",
"logprobs",
"top_logprobs",
"extra_headers",
]
return litellm.AzureOpenAIConfig().get_supported_openai_params()
elif custom_llm_provider == "openrouter":
return [
"functions",
Expand Down

0 comments on commit 7efac4d

Please sign in to comment.