Skip to content

Commit

Permalink
fix: Add timeout to prediction rawPredict/streamRawPredict
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695761143
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 12, 2024
1 parent a1857ed commit b7de16a
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 2 deletions.
12 changes: 10 additions & 2 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2229,6 +2229,7 @@ def predict(
body=json.dumps({"instances": instances, "parameters": parameters}),
headers={"Content-Type": "application/json"},
use_dedicated_endpoint=use_dedicated_endpoint,
timeout=timeout,
)
json_response = raw_predict_response.json()
return Prediction(
Expand Down Expand Up @@ -2277,6 +2278,7 @@ def predict(
}
),
headers=headers,
timeout=timeout,
)

prediction_response = json.loads(response.text)
Expand Down Expand Up @@ -2382,6 +2384,7 @@ def raw_predict(
headers: Dict[str, str],
*,
use_dedicated_endpoint: Optional[bool] = False,
timeout: Optional[float] = None,
) -> requests.models.Response:
"""Makes a prediction request using arbitrary headers.
Expand All @@ -2408,6 +2411,7 @@ def raw_predict(
use_dedicated_endpoint (bool):
Optional. Default value is False. If set to True, the underlying prediction call will be made
using the dedicated endpoint dns.
timeout (float): Optional. The timeout for this request in seconds.
Returns:
A requests.models.Response object containing the status code and prediction results.
Expand Down Expand Up @@ -2435,15 +2439,17 @@ def raw_predict(
"and model are ready before making a prediction."
)
url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:rawPredict"

return self.authorized_session.post(url=url, data=body, headers=headers)
return self.authorized_session.post(
url=url, data=body, headers=headers, timeout=timeout
)

def stream_raw_predict(
self,
body: bytes,
headers: Dict[str, str],
*,
use_dedicated_endpoint: Optional[bool] = False,
timeout: Optional[float] = None,
) -> Iterator[requests.models.Response]:
"""Makes a streaming prediction request using arbitrary headers.
Expand Down Expand Up @@ -2480,6 +2486,7 @@ def stream_raw_predict(
use_dedicated_endpoint (bool):
Optional. Default value is False. If set to True, the underlying prediction call will be made
using the dedicated endpoint dns.
timeout (float): Optional. The timeout for this request in seconds.
Yields:
predictions (Iterator[requests.models.Response]):
Expand Down Expand Up @@ -2513,6 +2520,7 @@ def stream_raw_predict(
url=url,
data=body,
headers=headers,
timeout=timeout,
stream=True,
) as resp:
for line in resp.iter_lines():
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
_TEST_DESCRIPTION = "test-description"
_TEST_REQUEST_METADATA = ()
_TEST_TIMEOUT = None
_TEST_PREDICT_TIMEOUT = 100

_TEST_ENDPOINT_NAME = test_constants.EndpointConstants._TEST_ENDPOINT_NAME
_TEST_ENDPOINT_NAME_2 = test_constants.EndpointConstants._TEST_ENDPOINT_NAME_2
Expand Down Expand Up @@ -2387,6 +2388,34 @@ def test_predict_dedicated_endpoint(self, predict_endpoint_http_mock):
url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:predict",
data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], "parameters": {"param": 3.0}}',
headers={"Content-Type": "application/json"},
timeout=None,
)

@pytest.mark.usefixtures("get_dedicated_endpoint_mock")
def test_predict_dedicated_endpoint_with_timeout(self, predict_endpoint_http_mock):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)

test_prediction = test_endpoint.predict(
instances=_TEST_INSTANCES,
parameters={"param": 3.0},
use_dedicated_endpoint=True,
timeout=_TEST_PREDICT_TIMEOUT,
)

true_prediction = models.Prediction(
predictions=_TEST_PREDICTION,
deployed_model_id=_TEST_ID,
metadata=_TEST_METADATA,
model_version_id=_TEST_VERSION_ID,
model_resource_name=_TEST_MODEL_NAME,
)

assert true_prediction == test_prediction
predict_endpoint_http_mock.assert_called_once_with(
url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:predict",
data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], "parameters": {"param": 3.0}}',
headers={"Content-Type": "application/json"},
timeout=_TEST_PREDICT_TIMEOUT,
)

@pytest.mark.usefixtures("get_endpoint_mock")
Expand Down Expand Up @@ -2432,6 +2461,40 @@ def test_raw_predict_dedicated_endpoint(self, predict_endpoint_http_mock):
url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:rawPredict",
data=_TEST_RAW_INPUTS,
headers={"Content-Type": "application/json"},
timeout=None,
)

@pytest.mark.usefixtures("get_dedicated_endpoint_mock")
def test_raw_predict_dedicated_endpoint_with_timeout(
self, predict_endpoint_http_mock
):
test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)

test_prediction = test_endpoint.raw_predict(
body=_TEST_RAW_INPUTS,
headers={"Content-Type": "application/json"},
use_dedicated_endpoint=True,
timeout=_TEST_PREDICT_TIMEOUT,
)

true_prediction = requests.Response()
true_prediction.status_code = 200
true_prediction._content = json.dumps(
{
"predictions": _TEST_PREDICTION,
"metadata": _TEST_METADATA,
"deployedModelId": _TEST_DEPLOYED_MODELS[0].id,
"model": _TEST_MODEL_NAME,
"modelVersionId": "1",
}
).encode("utf-8")
assert true_prediction.status_code == test_prediction.status_code
assert true_prediction.text == test_prediction.text
predict_endpoint_http_mock.assert_called_once_with(
url=f"https://{_TEST_DEDICATED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:rawPredict",
data=_TEST_RAW_INPUTS,
headers={"Content-Type": "application/json"},
timeout=_TEST_PREDICT_TIMEOUT,
)

@pytest.mark.usefixtures("get_endpoint_mock")
Expand Down
1 change: 1 addition & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3900,6 +3900,7 @@ def test_raw_predict(self, raw_predict_mock):
url=_TEST_RAW_PREDICT_URL,
data=_TEST_RAW_PREDICT_DATA,
headers=_TEST_RAW_PREDICT_HEADER,
timeout=None,
)

@pytest.mark.parametrize(
Expand Down

0 comments on commit b7de16a

Please sign in to comment.