Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Azure authentication using Azure AD Tokens #165

Merged
merged 15 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ MIT License

Copyright (c) 2023-2024 Philip May
Copyright (c) 2023-2024 Philip May, Deutsche Telekom AG
Copyright (c) 2023-2024 Alaeddine Abdessalem, Deutsche Telekom AG

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ To install those module specific dependencies see
## Licensing

Copyright (c) 2023-2024 [Philip May](https://philipmay.org)\
Copyright (c) 2023-2024 [Philip May](https://philipmay.org), [Deutsche Telekom AG](https://www.telekom.de/)
Copyright (c) 2023-2024 [Philip May](https://philipmay.org), [Deutsche Telekom AG](https://www.telekom.de/)\
Copyright (c) 2023-2024 Alaeddine Abdessalem, [Deutsche Telekom AG](https://www.telekom.de/)

Licensed under the **MIT License** (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License by reviewing the file
Expand Down
58 changes: 48 additions & 10 deletions mltb2/openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2023-2024 Philip May
# Copyright (c) 2024 Philip May, Deutsche Telekom AG
# Copyright (c) 2024 Alaeddine Abdessalem, Deutsche Telekom AG
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

Expand Down Expand Up @@ -171,37 +172,40 @@ class OpenAiChat:
model: The OpenAI model name.
"""

api_key: str
model: str
client: Union[OpenAI, AzureOpenAI] = field(init=False, repr=False)
async_client: Union[AsyncOpenAI, AsyncAzureOpenAI] = field(init=False, repr=False)
api_key: Optional[str] = None

def __post_init__(self) -> None:
"""Do post init."""
self.client = OpenAI(api_key=self.api_key)
self.async_client = AsyncOpenAI(api_key=self.api_key)

@classmethod
def from_yaml(cls, yaml_file):
def from_yaml(cls, yaml_file, api_key: Optional[str] = None, **kwargs):
"""Construct this class from a yaml file.

If the ``api_key`` is not set in the yaml file,
it will be loaded from the environment variable ``OPENAI_API_KEY``.

Args:
yaml_file: The yaml file.
api_key: The OpenAI API key.
kwargs: extra kwargs to override parameters
Returns:
The constructed class.
"""
with open(yaml_file, "r") as file:
completion_kwargs = yaml.safe_load(file)

# load api_key from environment variable if it is not set in the yaml file
if "api_key" not in completion_kwargs:
api_key = os.getenv("OPENAI_API_KEY")
if api_key is not None:
completion_kwargs["api_key"] = api_key
# set api_key according to this priority:
# method parameter > yaml > environment variable
api_key = api_key or completion_kwargs.get("api_key") or os.getenv("OPENAI_API_KEY")
completion_kwargs["api_key"] = api_key

if kwargs:
completion_kwargs.update(kwargs)
return cls(**completion_kwargs)

def create_completions(
Expand Down Expand Up @@ -323,8 +327,16 @@ async def create_completions_async(
return result


# there is a limitation with python dataclasses when it comes to defining a subclass with positional arguments, while
# the parent class already defines keyword arguemnts (positional arguments cannot follow keyword arguments)
# workaroung is defined here: https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class OpenAiAzureChat(OpenAiChat):
class _OpenAiAzureChatBase:
azure_endpoint: str


@dataclass
class OpenAiAzureChat(OpenAiChat, _OpenAiAzureChatBase):
"""Tool to interact with Azure OpenAI chat models.

This can also be constructed with :meth:`~OpenAiChat.from_yaml`.
Expand All @@ -341,18 +353,44 @@ class OpenAiAzureChat(OpenAiChat):
azure_endpoint: The Azure endpoint.
"""

api_version: str
azure_endpoint: str
api_version: Optional[str] = None
api_key: Optional[str] = None
azure_ad_token: Optional[str] = None

def __post_init__(self) -> None:
"""Do post init."""
self.client = AzureOpenAI(
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_ad_token=self.azure_ad_token,
)
self.async_client = AsyncAzureOpenAI(
api_key=self.api_key,
api_version=self.api_version,
azure_endpoint=self.azure_endpoint,
azure_ad_token=self.azure_ad_token,
)

@classmethod
def from_yaml(cls, yaml_file, api_key: Optional[str] = None, azure_ad_token: Optional[str] = None, **kwargs):
"""Construct this class from a yaml file.

If the ``api_key`` is not set in the yaml file,
it will be loaded from the environment variable ``OPENAI_API_KEY``.

Args:
yaml_file: The yaml file.
api_key: The OpenAI API key.
azure_ad_token: Azure AD token
kwargs: extra kwargs to override parameters
Returns:
The constructed class.
"""
with open(yaml_file, "r") as file:
completion_kwargs = yaml.safe_load(file)

# set azure_ad_token according to this priority:
# method parameter > yaml > environment variable
azure_ad_token = azure_ad_token or completion_kwargs.get("AZURE_AD_TOKEN") or os.getenv("AZURE_AD_TOKEN")
return super().from_yaml(yaml_file, api_key=api_key, azure_ad_token=azure_ad_token, **kwargs)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mltb2"
version = "1.0.1rc2"
version = "1.0.1rc3"
description = "Machine Learning Toolbox 2"
authors = ["PhilipMay <[email protected]>"]
readme = "README.md"
Expand Down
Loading