diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 5029f4db6d..ed986a64dc 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -1167,7 +1167,12 @@ def _instantiate_prediction_client( prediction_client=True, ) - def predict(self, instances: List, parameters: Optional[Dict] = None) -> Prediction: + def predict( + self, + instances: List, + parameters: Optional[Dict] = None, + timeout: Optional[float] = None, + ) -> Prediction: """Make a prediction against this Endpoint. Args: @@ -1190,13 +1195,17 @@ def predict(self, instances: List, parameters: Optional[Dict] = None) -> Predict ][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``parameters_schema_uri``. + timeout (float): Optional. The timeout for this request in seconds. Returns: prediction: Prediction with returned predictions and Model Id. """ self.wait() prediction_response = self._prediction_client.predict( - endpoint=self._gca_resource.name, instances=instances, parameters=parameters + endpoint=self._gca_resource.name, + instances=instances, + parameters=parameters, + timeout=timeout, ) return Prediction( @@ -1212,6 +1221,7 @@ def explain( instances: List[Dict], parameters: Optional[Dict] = None, deployed_model_id: Optional[str] = None, + timeout: Optional[float] = None, ) -> Prediction: """Make a prediction with explanations against this Endpoint. @@ -1242,6 +1252,7 @@ def explain( deployed_model_id (str): Optional. If specified, this ExplainRequest will be served by the chosen DeployedModel, overriding this Endpoint's traffic split. + timeout (float): Optional. The timeout for this request in seconds. Returns: prediction: Prediction with returned predictions, explanations and Model Id. """ @@ -1252,6 +1263,7 @@ def explain( instances=instances, parameters=parameters, deployed_model_id=deployed_model_id, + timeout=timeout, ) return Prediction( diff --git a/tests/system/aiplatform/test_e2e_tabular.py b/tests/system/aiplatform/test_e2e_tabular.py index ee9e6bc7b2..ee0692014a 100644 --- a/tests/system/aiplatform/test_e2e_tabular.py +++ b/tests/system/aiplatform/test_e2e_tabular.py @@ -164,13 +164,14 @@ def test_end_to_end_tabular(self, shared_state): is True ) - custom_prediction = custom_endpoint.predict([_INSTANCE]) + custom_prediction = custom_endpoint.predict([_INSTANCE], timeout=180.0) custom_batch_prediction_job.wait() automl_endpoint.wait() automl_prediction = automl_endpoint.predict( - [{k: str(v) for k, v in _INSTANCE.items()}] # Cast int values to strings + [{k: str(v) for k, v in _INSTANCE.items()}], # Cast int values to strings + timeout=180.0, ) # Test lazy loading of Endpoint, check getter was never called after predict() diff --git a/tests/unit/aiplatform/test_end_to_end.py b/tests/unit/aiplatform/test_end_to_end.py index 715cda06c7..09a3db12f4 100644 --- a/tests/unit/aiplatform/test_end_to_end.py +++ b/tests/unit/aiplatform/test_end_to_end.py @@ -174,6 +174,7 @@ def test_dataset_create_to_model_predict( endpoint=test_endpoints._TEST_ENDPOINT_NAME, instances=[[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], parameters={"param": 3.0}, + timeout=None, ) expected_dataset = gca_dataset.Dataset( diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index f8b9d2ac67..2adfc23e01 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -1162,6 +1162,7 @@ def test_predict(self, get_endpoint_mock, predict_client_predict_mock): endpoint=_TEST_ENDPOINT_NAME, instances=_TEST_INSTANCES, parameters={"param": 3.0}, + timeout=None, ) def test_explain(self, get_endpoint_mock, predict_client_explain_mock): @@ -1187,6 +1188,43 @@ def test_explain(self, get_endpoint_mock, predict_client_explain_mock): instances=_TEST_INSTANCES, parameters={"param": 3.0}, deployed_model_id=_TEST_MODEL_ID, + timeout=None, + ) + + @pytest.mark.usefixtures("get_endpoint_mock") + def test_predict_with_timeout(self, predict_client_predict_mock): + + test_endpoint = models.Endpoint(_TEST_ID) + + test_endpoint.predict( + instances=_TEST_INSTANCES, parameters={"param": 3.0}, timeout=10.0 + ) + + predict_client_predict_mock.assert_called_once_with( + endpoint=_TEST_ENDPOINT_NAME, + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + timeout=10.0, + ) + + @pytest.mark.usefixtures("get_endpoint_mock") + def test_explain_with_timeout(self, predict_client_explain_mock): + + test_endpoint = models.Endpoint(_TEST_ID) + + test_endpoint.explain( + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + deployed_model_id=_TEST_MODEL_ID, + timeout=10.0, + ) + + predict_client_explain_mock.assert_called_once_with( + endpoint=_TEST_ENDPOINT_NAME, + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + deployed_model_id=_TEST_MODEL_ID, + timeout=10.0, ) def test_list_models(self, get_endpoint_with_models_mock):