Skip to content

Commit

Permalink
feat: Add timeout arguments to Endpoint.predict and Endpoint.explain (#…
Browse files Browse the repository at this point in the history
…1094)

Fixes # [b/224990641](b/224990641) 🦕
  • Loading branch information
sasha-gitg authored Mar 21, 2022
1 parent 25b546a commit cc59e60
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 4 deletions.
16 changes: 14 additions & 2 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,12 @@ def _instantiate_prediction_client(
prediction_client=True,
)

def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction:
def predict(
self,
instances: List,
parameters: Optional[Dict] = None,
timeout: Optional[float] = None,
) -> Prediction:
"""Make a prediction against this Endpoint.
Args:
Expand All @@ -1190,13 +1195,17 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict
][google.cloud.aiplatform.v1beta1.DeployedModel.model]
[PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata]
``parameters_schema_uri``.
timeout (float): Optional. The timeout for this request in seconds.
Returns:
prediction: Prediction with returned predictions and Model Id.
"""
self.wait()

prediction_response = self._prediction_client.predict(
endpoint=self._gca_resource.name, instances=instances, parameters=parameters
endpoint=self._gca_resource.name,
instances=instances,
parameters=parameters,
timeout=timeout,
)

return Prediction(
Expand All @@ -1212,6 +1221,7 @@ def explain(
instances: List[Dict],
parameters: Optional[Dict] = None,
deployed_model_id: Optional[str] = None,
timeout: Optional[float] = None,
) -> Prediction:
"""Make a prediction with explanations against this Endpoint.
Expand Down Expand Up @@ -1242,6 +1252,7 @@ def explain(
deployed_model_id (str):
Optional. If specified, this ExplainRequest will be served by the
chosen DeployedModel, overriding this Endpoint's traffic split.
timeout (float): Optional. The timeout for this request in seconds.
Returns:
prediction: Prediction with returned predictions, explanations and Model Id.
"""
Expand All @@ -1252,6 +1263,7 @@ def explain(
instances=instances,
parameters=parameters,
deployed_model_id=deployed_model_id,
timeout=timeout,
)

return Prediction(
Expand Down
5 changes: 3 additions & 2 deletions tests/system/aiplatform/test_e2e_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,14 @@ def test_end_to_end_tabular(self, shared_state):
is True
)

custom_prediction = custom_endpoint.predict([_INSTANCE])
custom_prediction = custom_endpoint.predict([_INSTANCE], timeout=180.0)

custom_batch_prediction_job.wait()

automl_endpoint.wait()
automl_prediction = automl_endpoint.predict(
[{k: str(v) for k, v in _INSTANCE.items()}] # Cast int values to strings
[{k: str(v) for k, v in _INSTANCE.items()}], # Cast int values to strings
timeout=180.0,
)

# Test lazy loading of Endpoint, check getter was never called after predict()
Expand Down
1 change: 1 addition & 0 deletions tests/unit/aiplatform/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_dataset_create_to_model_predict(
endpoint=test_endpoints._TEST_ENDPOINT_NAME,
instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]],
parameters={"param": 3.0},
timeout=None,
)

expected_dataset = gca_dataset.Dataset(
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/aiplatform/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,7 @@ def test_predict(self, get_endpoint_mock, predict_client_predict_mock):
endpoint=_TEST_ENDPOINT_NAME,
instances=_TEST_INSTANCES,
parameters={"param": 3.0},
timeout=None,
)

def test_explain(self, get_endpoint_mock, predict_client_explain_mock):
Expand All @@ -1187,6 +1188,43 @@ def test_explain(self, get_endpoint_mock, predict_client_explain_mock):
instances=_TEST_INSTANCES,
parameters={"param": 3.0},
deployed_model_id=_TEST_MODEL_ID,
timeout=None,
)

@pytest.mark.usefixtures("get_endpoint_mock")
def test_predict_with_timeout(self, predict_client_predict_mock):

test_endpoint = models.Endpoint(_TEST_ID)

test_endpoint.predict(
instances=_TEST_INSTANCES, parameters={"param": 3.0}, timeout=10.0
)

predict_client_predict_mock.assert_called_once_with(
endpoint=_TEST_ENDPOINT_NAME,
instances=_TEST_INSTANCES,
parameters={"param": 3.0},
timeout=10.0,
)

@pytest.mark.usefixtures("get_endpoint_mock")
def test_explain_with_timeout(self, predict_client_explain_mock):

test_endpoint = models.Endpoint(_TEST_ID)

test_endpoint.explain(
instances=_TEST_INSTANCES,
parameters={"param": 3.0},
deployed_model_id=_TEST_MODEL_ID,
timeout=10.0,
)

predict_client_explain_mock.assert_called_once_with(
endpoint=_TEST_ENDPOINT_NAME,
instances=_TEST_INSTANCES,
parameters={"param": 3.0},
deployed_model_id=_TEST_MODEL_ID,
timeout=10.0,
)

def test_list_models(self, get_endpoint_with_models_mock):
Expand Down

0 comments on commit cc59e60

Please sign in to comment.