feat: consolidate azure with openai provider
Signed-off-by: Adrian Cole <[email protected]>
codefromthecrypt committed Sep 25, 2024
1 parent c5dc482 commit 682f14b
Showing 5 changed files with 133 additions and 131 deletions.
90 changes: 10 additions & 80 deletions src/exchange/providers/azure.py
@@ -1,37 +1,16 @@
import os
from typing import Any, Dict, List, Tuple, Type
from typing import Type

import httpx

from exchange.message import Message
from exchange.providers.base import Provider, Usage
from tenacity import retry, wait_fixed, stop_after_attempt
from exchange.providers.utils import retry_if_status
from exchange.providers.utils import (
messages_to_openai_spec,
openai_response_to_message,
openai_single_message_context_length_exceeded,
raise_for_status,
tools_to_openai_spec,
)
from exchange.tool import Tool
from exchange.providers import OpenAiProvider

retry_procedure = retry(
wait=wait_fixed(2),
stop=stop_after_attempt(2),
retry=retry_if_status(codes=[429], above=500),
reraise=True,
)

class AzureProvider(OpenAiProvider):
"""Provides chat completions for models hosted by the Azure OpenAI Service"""

class AzureProvider(Provider):
"""Provides chat completions for models hosted directly by OpenAI"""

def __init__(self, client: httpx.Client, deployment_name: str, api_version: str) -> None:
super().__init__()
self.client = client
self.deployment_name = deployment_name
self.api_version = api_version
def __init__(self, client: httpx.Client) -> None:
super().__init__(client)

@classmethod
def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
@@ -55,61 +34,12 @@ def from_env(cls: Type["AzureProvider"]) -> "AzureProvider":
except KeyError:
raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.")

# format the url host/"openai/deployments/" + deployment_name + "/chat/completions?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}"
# format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version
url = f"{url}/openai/deployments/{deployment_name}/"
client = httpx.Client(
base_url=url,
headers={"api-key": key, "Content-Type": "application/json"},
params={"api-version": api_version},
timeout=httpx.Timeout(60 * 10),
)
return cls(client, deployment_name, api_version)

@staticmethod
def get_usage(data: dict) -> Usage:
usage = data.pop("usage")
input_tokens = usage.get("prompt_tokens")
output_tokens = usage.get("completion_tokens")
total_tokens = usage.get("total_tokens")

if total_tokens is None and input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens

return Usage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=total_tokens,
)

def complete(
self,
model: str,
system: str,
messages: List[Message],
tools: Tuple[Tool],
**kwargs: Dict[str, Any],
) -> Tuple[Message, Usage]:
payload = dict(
messages=[
{"role": "system", "content": system},
*messages_to_openai_spec(messages),
],
tools=tools_to_openai_spec(tools) if tools else [],
**kwargs,
)

payload = {k: v for k, v in payload.items() if v}
request_url = f"{self.client.base_url}/chat/completions?api-version={self.api_version}"
response = self._post(payload, request_url)

# Check for context_length_exceeded error for single, long input message
if "error" in response and len(messages) == 1:
openai_single_message_context_length_exceeded(response["error"])

message = openai_response_to_message(response)
usage = self.get_usage(response)
return message, usage

@retry_procedure
def _post(self, payload: dict, request_url: str) -> dict:
response = self.client.post(request_url, json=payload)
return raise_for_status(response).json()
return cls(client)
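After this consolidation, AzureProvider only overrides construction and inherits completion handling, usage accounting, and retries from OpenAiProvider. The following is a minimal usage sketch under that assumption; the environment variable names and placeholder values are taken from this commit's test fixtures, and the model name is only an example.

```python
import os

from exchange import Message
from exchange.providers.azure import AzureProvider

# Placeholder values from this commit's test fixtures -- substitute real ones.
# Running this for real requires a reachable Azure OpenAI deployment.
os.environ.setdefault("AZURE_CHAT_COMPLETIONS_HOST_NAME", "https://test.openai.azure.com")
os.environ.setdefault("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", "test-azure-deployment")
os.environ.setdefault("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", "2024-05-01-preview")
os.environ.setdefault("AZURE_CHAT_COMPLETIONS_KEY", "test_azure_api_key")

provider = AzureProvider.from_env()

# complete() is inherited from OpenAiProvider; the client's base_url already ends
# in .../openai/deployments/<deployment>/, so the shared chat/completions path and
# the api-version query parameter resolve against it.
reply, usage = provider.complete(
    model="gpt-4o-mini",
    system="You are a helpful assistant.",
    messages=[Message.user("Hello")],
    tools=(),
)
print(reply.content, usage.total_tokens)
```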
68 changes: 68 additions & 0 deletions tests/providers/cassettes/test_azure_complete.yaml
@@ -0,0 +1,68 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello"}], "model": "gpt-4o-mini"}'
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
api-key:
- test_azure_api_key
connection:
- keep-alive
content-length:
- '139'
content-type:
- application/json
host:
- test.openai.azure.com
user-agent:
- python-httpx/0.27.2
method: POST
uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview
response:
body:
string: '{"choices":[{"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"finish_reason":"stop","index":0,"logprobs":null,"message":{"content":"Hello!
How can I assist you today?","role":"assistant"}}],"created":1727230065,"id":"chatcmpl-ABBjN3AoYlxkP7Vg2lBvUhYeA6j5K","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":9,"prompt_tokens":18,"total_tokens":27}}
'
headers:
Cache-Control:
- no-cache, must-revalidate
Content-Length:
- '825'
Content-Type:
- application/json
Date:
- Wed, 25 Sep 2024 02:07:45 GMT
Set-Cookie: test_set_cookie
Strict-Transport-Security:
- max-age=31536000; includeSubDomains; preload
access-control-allow-origin:
- '*'
apim-request-id:
- 82e66ef8-ac07-4a43-b60f-9aecec1d8c81
azureml-model-session:
- d145-20240919052126
openai-organization: test_openai_org_key
x-accel-buffering:
- 'no'
x-content-type-options:
- nosniff
x-ms-client-request-id:
- 82e66ef8-ac07-4a43-b60f-9aecec1d8c81
x-ms-rai-invoked:
- 'true'
x-ms-region:
- Switzerland North
x-ratelimit-remaining-requests:
- '79'
x-ratelimit-remaining-tokens:
- '79984'
x-request-id:
- 38db9001-8b16-4efe-84c9-620e10f18c3c
status:
code: 200
message: OK
version: 1
42 changes: 42 additions & 0 deletions tests/providers/conftest.py
@@ -1,4 +1,5 @@
import os
import re
from typing import Type, Tuple

import pytest
@@ -25,6 +26,32 @@ def default_openai_env(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", OPENAI_API_KEY)


AZURE_ENDPOINT = "https://test.openai.azure.com"
AZURE_DEPLOYMENT_NAME = "test-azure-deployment"
AZURE_API_VERSION = "2024-05-01-preview"
AZURE_API_KEY = "test_azure_api_key"


@pytest.fixture
def default_azure_env(monkeypatch):
"""
This fixture prevents AzureProvider.from_env() from erroring out on missing
environment variables.
When running VCR tests for the first time or after deleting a cassette
recording, set the required environment variables so that real requests don't
fail. Subsequent runs use the recorded data, so they don't need them.
"""
if "AZURE_CHAT_COMPLETIONS_HOST_NAME" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_HOST_NAME", AZURE_ENDPOINT)
if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", AZURE_DEPLOYMENT_NAME)
if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", AZURE_API_VERSION)
if "AZURE_CHAT_COMPLETIONS_KEY" not in os.environ:
monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY)


@pytest.fixture(scope="module")
def vcr_config():
"""
@@ -42,10 +69,25 @@ def vcr_config():
("openai-project", OPENAI_PROJECT_ID),
("cookie", None),
],
"before_record_request": scrub_request_url,
"before_record_response": scrub_response_headers,
}


def scrub_request_url(request):
"""
This scrubs sensitive request data in a provider-specific way. Note that headers
are case-sensitive!
"""
if "openai" in request.uri:
request.uri = re.sub(r"https://[^/]+", AZURE_ENDPOINT, request.uri)
request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri)
request.headers["host"] = AZURE_ENDPOINT.replace("https://", "")
request.headers["api-key"] = AZURE_API_KEY

return request


def scrub_response_headers(response):
"""
This scrubs sensitive response headers. Note they are case-sensitive!
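To illustrate what scrub_request_url does before a cassette is written, here is a small runnable sketch using the same regexes and placeholder constants added above; the input URI is a made-up example, not one recorded in this commit.

```python
import re

AZURE_ENDPOINT = "https://test.openai.azure.com"
AZURE_DEPLOYMENT_NAME = "test-azure-deployment"

# A made-up recorded URI, standing in for a real Azure OpenAI endpoint.
uri = "https://my-resource.openai.azure.com/openai/deployments/prod-gpt4o/chat/completions?api-version=2024-05-01-preview"

# Replace the real host and deployment name with the test placeholders.
uri = re.sub(r"https://[^/]+", AZURE_ENDPOINT, uri)
uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", uri)

print(uri)
# https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview
```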
63 changes: 12 additions & 51 deletions tests/providers/test_azure.py
@@ -1,64 +1,25 @@
import os
from unittest.mock import patch

import pytest
from exchange import Message, Text
from exchange.providers.azure import AzureProvider


@pytest.fixture
@patch.dict(
os.environ,
{
"AZURE_CHAT_COMPLETIONS_HOST_NAME": "https://test.openai.azure.com/",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test-deployment",
"AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "2024-02-15-preview",
"AZURE_CHAT_COMPLETIONS_KEY": "test_api_key",
},
)
def azure_provider():
return AzureProvider.from_env()

from exchange import Text
from exchange.providers.azure import AzureProvider
from .conftest import complete

@patch("httpx.Client.post")
@patch("time.sleep", return_value=None)
@patch("logging.warning")
@patch("logging.error")
def test_azure_completion(mock_error, mock_warning, mock_sleep, mock_post, azure_provider):
mock_response = {
"choices": [{"message": {"role": "assistant", "content": "Hello!"}}],
"usage": {"prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35},
}

mock_post.return_value.json.return_value = mock_response
AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini")

model = "gpt-4"
system = "You are a helpful assistant."
messages = [Message.user("Hello")]
tools = ()

reply_message, reply_usage = azure_provider.complete(model=model, system=system, messages=messages, tools=tools)
@pytest.mark.vcr()
def test_azure_complete(default_azure_env):
reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL)

assert reply_message.content == [Text(text="Hello!")]
assert reply_usage.total_tokens == 35
mock_post.assert_called_once_with(
f"{azure_provider.client.base_url}/chat/completions?api-version={azure_provider.api_version}",
json={
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": "Hello"},
],
},
)
assert reply_message.content == [Text(text="Hello! How can I assist you today?")]
assert reply_usage.total_tokens == 27


@pytest.mark.integration
def test_azure_integration():
provider = AzureProvider.from_env()
system = "You are a helpful assistant."
messages = [Message.user("Hello")]

reply = provider.complete(system=system, messages=messages, tools=None)
def test_azure_complete_integration():
reply = complete(AzureProvider, AZURE_MODEL)

assert reply[0].content is not None
print("Completion content from Azure:", reply[0].content)
print("Complete content from Azure:", reply[0].content)
1 change: 1 addition & 0 deletions tests/test_integration.py
@@ -13,6 +13,7 @@
# Set seed and temperature for more determinism, to avoid flakes
(get_provider("ollama"), os.getenv("OLLAMA_MODEL", OLLAMA_MODEL), dict(seed=3, temperature=0.1)),
(get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini"), dict()),
(get_provider("azure"), os.getenv("AZURE_MODEL", "gpt-4o-mini"), dict()),
(get_provider("databricks"), "databricks-meta-llama-3-70b-instruct", dict()),
(get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0", dict()),
]
