From 0884d7c3dc173cd62809f7f7f01b15d7027c903d Mon Sep 17 00:00:00 2001 From: William Pietri Date: Thu, 3 Oct 2024 10:04:47 -0500 Subject: [PATCH] Extend the timeout to match observed behavior (#555) * Extend the timeout to match observed behavior. * Oops. --- .../modelgauge/suts/huggingface_inference.py | 13 +++++++------ plugins/validation_tests/test_object_creation.py | 9 ++++++--- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/plugins/huggingface/modelgauge/suts/huggingface_inference.py b/plugins/huggingface/modelgauge/suts/huggingface_inference.py index 6b71886d..4d41fb98 100644 --- a/plugins/huggingface/modelgauge/suts/huggingface_inference.py +++ b/plugins/huggingface/modelgauge/suts/huggingface_inference.py @@ -7,6 +7,7 @@ InferenceEndpointStatus, ) from huggingface_hub.utils import HfHubHTTPError # type: ignore +from pydantic import BaseModel from modelgauge.auth.huggingface_inference_token import HuggingFaceInferenceToken from modelgauge.prompt import TextPrompt @@ -15,7 +16,8 @@ from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_decorator import modelgauge_sut from modelgauge.sut_registry import SUTS -from pydantic import BaseModel + +HUGGING_FACE_TIMEOUT = 60 * 15 class ChatMessage(BaseModel): @@ -43,22 +45,21 @@ def __init__(self, uid: str, inference_endpoint: str, token: HuggingFaceInferenc def _create_client(self): endpoint = get_inference_endpoint(self.inference_endpoint, token=self.token.value) - timeout = 60 * 10 if endpoint.status in [ InferenceEndpointStatus.PENDING, InferenceEndpointStatus.INITIALIZING, InferenceEndpointStatus.UPDATING, ]: - print(f"Endpoint starting. Status: {endpoint.status}. Waiting up to {timeout}s to start.") - endpoint.wait(timeout) + print(f"Endpoint starting. Status: {endpoint.status}. Waiting up to {HUGGING_FACE_TIMEOUT}s to start.") + endpoint.wait(HUGGING_FACE_TIMEOUT) elif endpoint.status == InferenceEndpointStatus.SCALED_TO_ZERO: print("Endpoint scaled to zero... requesting to resume.") try: endpoint.resume(running_ok=True) except HfHubHTTPError: raise ConnectionError("Failed to resume endpoint. Please resume manually.") - print(f"Requested resume. Waiting up to {timeout}s to start.") - endpoint.wait(timeout) + print(f"Requested resume. Waiting up to {HUGGING_FACE_TIMEOUT}s to start.") + endpoint.wait(HUGGING_FACE_TIMEOUT) elif endpoint.status != InferenceEndpointStatus.RUNNING: raise ConnectionError( f"Endpoint is not running: Please contact admin to ensure endpoint is starting or running (status: {endpoint.status})" diff --git a/plugins/validation_tests/test_object_creation.py b/plugins/validation_tests/test_object_creation.py index 269e83c9..70e77c98 100644 --- a/plugins/validation_tests/test_object_creation.py +++ b/plugins/validation_tests/test_object_creation.py @@ -1,6 +1,10 @@ import os + import pytest from flaky import flaky # type: ignore +from modelgauge_tests.fake_secrets import fake_all_secrets +from modelgauge_tests.utilities import expensive_tests + from modelgauge.base_test import PromptResponseTest from modelgauge.config import load_secrets_from_config from modelgauge.dependency_helper import FromSourceDependencyHelper @@ -10,9 +14,8 @@ from modelgauge.sut import PromptResponseSUT, SUTResponse from modelgauge.sut_capabilities import AcceptsTextPrompt from modelgauge.sut_registry import SUTS +from modelgauge.suts.huggingface_inference import HUGGING_FACE_TIMEOUT from modelgauge.test_registry import TESTS -from modelgauge_tests.fake_secrets import fake_all_secrets -from modelgauge_tests.utilities import expensive_tests # Ensure all the plugins are available during testing. load_plugins() @@ -75,7 +78,7 @@ def test_all_suts_construct_and_record_init(sut_name): # but still fails if the external service really is flaky or slow, so we can # get a sense of a real user's experience. @expensive_tests -@pytest.mark.timeout(650) # up to 10 minutes for Hugging Face spinup, plus some time for the test itself +@pytest.mark.timeout(HUGGING_FACE_TIMEOUT + 45) # Hugging Face spinup, plus some time for the test itself @pytest.mark.parametrize("sut_name", set(SUTS.keys()) - SUTS_THAT_WE_DONT_CARE_ABOUT_FAILING) def test_all_suts_can_evaluate(sut_name):