Skip to content

Commit

Permalink
More consistent retrying (#693)
Browse files: browse the repository at this point in the history
* Add consistent-ish retrying for all SUTs.
  • Loading branch information
wpietri authored Nov 18, 2024
1 parent e224ee9 commit dc034d4
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
12 changes: 8 additions & 4 deletions plugins/anthropic/modelgauge/suts/anthropic_api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from typing import List, Optional

from anthropic import Anthropic
from anthropic.types import TextBlock
from anthropic.types.message import Message as AnthropicMessage
from pydantic import BaseModel
from typing import List, Optional

from modelgauge.general import APIException
from modelgauge.prompt import ChatRole, TextPrompt
from modelgauge.secret_values import InjectSecret, RequiredSecret, SecretDescription
from modelgauge.sut import PromptResponseSUT, SUTCompletion, SUTResponse
from modelgauge.suts.openai_client import OpenAIChatMessage, _ROLE_MAP
from modelgauge.sut_capabilities import AcceptsTextPrompt
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from modelgauge.suts.openai_client import OpenAIChatMessage, _ROLE_MAP


class AnthropicApiKey(RequiredSecret):
Expand Down Expand Up @@ -45,7 +46,10 @@ def __init__(self, uid: str, model: str, api_key: AnthropicApiKey):
self.client: Optional[Anthropic] = None

def _load_client(self) -> Anthropic:
return Anthropic(api_key=self.api_key)
return Anthropic(
api_key=self.api_key,
max_retries=7,
)

def translate_text_prompt(self, prompt: TextPrompt) -> AnthropicRequest:
messages = [OpenAIChatMessage(content=prompt.text, role=_ROLE_MAP[ChatRole.user])]
Expand Down Expand Up @@ -83,7 +87,7 @@ def translate_response(self, request: AnthropicRequest, response: AnthropicMessa

# Claude models registered as SUTs (3.5 Sonnet and 3.5 Haiku); see the official model names at:
# https://docs.anthropic.com/en/docs/about-claude/models#model-names
model_names = ["claude-3-5-sonnet-20241022"]
model_names = ["claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022"]
for model in model_names:
# UID is the model name.
SUTS.register(AnthropicSUT, model, model, ANTHROPIC_SECRET)
6 changes: 5 additions & 1 deletion plugins/huggingface/modelgauge/suts/huggingface_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Optional

import requests # type: ignore
import tenacity
from huggingface_hub import ChatCompletionOutput # type: ignore
from pydantic import BaseModel
from typing import Optional
from tenacity import stop_after_attempt, wait_random_exponential

from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken
from modelgauge.prompt import TextPrompt
Expand Down Expand Up @@ -43,6 +46,7 @@ def translate_text_prompt(self, prompt: TextPrompt) -> HuggingFaceChatRequest:
),
)

@tenacity.retry(stop=stop_after_attempt(7), wait=wait_random_exponential())
def evaluate(self, request: HuggingFaceChatRequest) -> HuggingFaceResponse:
headers = {
"Accept": "application/json",
Expand Down
12 changes: 5 additions & 7 deletions plugins/openai/modelgauge/suts/openai_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Any, Dict, List, Optional, Union

from openai import OpenAI
from openai.types.chat import ChatCompletion
from pydantic import BaseModel

from modelgauge.prompt import ChatPrompt, ChatRole, SUTOptions, TextPrompt
from modelgauge.secret_values import (
InjectSecret,
Expand All @@ -21,9 +25,6 @@
)
from modelgauge.sut_decorator import modelgauge_sut
from modelgauge.sut_registry import SUTS
from openai import OpenAI
from openai.types.chat import ChatCompletion
from pydantic import BaseModel

_SYSTEM_ROLE = "system"
_USER_ROLE = "user"
Expand Down Expand Up @@ -108,10 +109,7 @@ def __init__(self, uid: str, model: str, api_key: OpenAIApiKey, org_id: OpenAIOr
self.org_id = org_id.value

def _load_client(self) -> OpenAI:
return OpenAI(
api_key=self.api_key,
organization=self.org_id,
)
return OpenAI(api_key=self.api_key, organization=self.org_id, max_retries=7)

def translate_text_prompt(self, prompt: TextPrompt) -> OpenAIChatRequest:
messages = [OpenAIChatMessage(content=prompt.text, role=_USER_ROLE)]
Expand Down
2 changes: 1 addition & 1 deletion src/modelgauge/suts/together_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _retrying_post(url, headers, json_payload):
"""HTTP Post with retry behavior."""
session = requests.Session()
retries = Retry(
total=10,
total=7,
backoff_factor=2,
status_forcelist=[
408, # Request Timeout
Expand Down

0 comments on commit dc034d4

Please sign in to comment.