Skip to content

Commit

Permalink
Fix .inference.get_azure_openai_client() for the async AzureAIClient (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
dargilco authored Oct 10, 2024
1 parent 8935a4d commit c4d3540
Show file tree
Hide file tree
Showing 53 changed files with 191 additions and 4,517 deletions.
10 changes: 5 additions & 5 deletions sdk/ai/azure-ai-client/azure/ai/client/aio/operations/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,23 @@ async def get_embeddings_client(self) -> "EmbeddingsClient":

return client

async def get_azure_openai_client(self) -> "AzureOpenAI":
async def get_azure_openai_client(self) -> "AsyncAzureOpenAI":
endpoint = await self.outer_instance.endpoints.get_default(
endpoint_type=EndpointType.AZURE_OPEN_AI, populate_secrets=True
)
if not endpoint:
raise ValueError("No Azure OpenAI endpoint found.")

try:
from openai_async import AzureOpenAI
from openai import AsyncAzureOpenAI
except ModuleNotFoundError as _:
raise ModuleNotFoundError("OpenAI SDK is not installed. Please install it using 'pip install openai-async'")

if endpoint.authentication_type == AuthenticationType.API_KEY:
logger.debug(
"[InferenceOperations.get_azure_openai_client] Creating AzureOpenAI using API key authentication"
)
client = AzureOpenAI(
client = AsyncAzureOpenAI(
api_key=endpoint.key,
azure_endpoint=endpoint.endpoint_url,
api_version="2024-08-01-preview", # TODO: Is this needed?
Expand All @@ -129,7 +129,7 @@ async def get_azure_openai_client(self) -> "AzureOpenAI":
raise ModuleNotFoundError(
"azure.identity package not installed. Please install it using 'pip install azure.identity'"
)
client = AzureOpenAI(
client = AsyncAzureOpenAI(
# See https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity?view=azure-python#azure-identity-get-bearer-token-provider
azure_ad_token_provider=get_bearer_token_provider(
endpoint.token_credential, "https://cognitiveservices.azure.com/.default"
Expand All @@ -139,7 +139,7 @@ async def get_azure_openai_client(self) -> "AzureOpenAI":
)
elif endpoint.authentication_type == AuthenticationType.SAS:
logger.debug("[InferenceOperations.get_azure_openai_client] Creating AzureOpenAI using SAS authentication")
client = AzureOpenAI(
client = AsyncAzureOpenAI(
azure_ad_token_provider=get_bearer_token_provider(
endpoint.token_credential, "https://cognitiveservices.azure.com/.default"
),
Expand Down

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

This file was deleted.

Loading

0 comments on commit c4d3540

Please sign in to comment.