diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 68eea393f1221..28e9858e1133b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -334,7 +334,6 @@ repos: # https://github.com/apache/airflow/issues/36484 exclude: | (?x)^( - ^airflow\/providers\/google\/cloud\/operators\/mlengine.py$| ^airflow\/providers\/google\/cloud\/operators\/cloud_storage_transfer_service.py$| ^airflow\/providers\/apache\/spark\/operators\/spark_submit.py\.py$| ^airflow\/providers\/google\/cloud\/operators\/vertex_ai\/auto_ml\.py$| diff --git a/airflow/providers/google/cloud/operators/mlengine.py b/airflow/providers/google/cloud/operators/mlengine.py index 79bde913a36c1..d46a54c8b643a 100644 --- a/airflow/providers/google/cloud/operators/mlengine.py +++ b/airflow/providers/google/cloud/operators/mlengine.py @@ -111,7 +111,7 @@ class MLEngineStartBatchPredictionJobOperator(GoogleCloudBaseOperator): In options 2 and 3, both model and version name should contain the minimal identifier. For instance, call:: - MLEngineBatchPredictionOperator( + MLEngineStartBatchPredictionJobOperator( ..., model_name='my_model', version_name='my_version', @@ -173,15 +173,15 @@ class MLEngineStartBatchPredictionJobOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_job_id", - "_region", - "_input_paths", - "_output_path", - "_model_name", - "_version_name", - "_uri", - "_impersonation_chain", + "project_id", + "job_id", + "region", + "input_paths", + "output_path", + "model_name", + "version_name", + "uri", + "impersonation_chain", ) def __init__( @@ -206,67 +206,66 @@ def __init__( ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._job_id = job_id - self._region = region + self.project_id = project_id + self.job_id = job_id + self.region = region self._data_format = data_format - self._input_paths = input_paths - self._output_path = output_path - self._model_name = model_name - self._version_name = version_name - self._uri = uri + self.input_paths = input_paths + self.output_path = output_path + self.model_name = model_name + self.version_name = version_name + self.uri = uri self._max_worker_count = max_worker_count self._runtime_version = runtime_version self._signature_name = signature_name self._gcp_conn_id = gcp_conn_id self._labels = labels - self._impersonation_chain = impersonation_chain + self.impersonation_chain = impersonation_chain - if not self._project_id: + def execute(self, context: Context): + if not self.project_id: raise AirflowException("Google Cloud project id is required.") - if not self._job_id: + if not self.job_id: raise AirflowException("An unique job id is required for Google MLEngine prediction job.") - if self._uri: - if self._model_name or self._version_name: + if self.uri: + if self.model_name or self.version_name: raise AirflowException( "Ambiguous model origin: Both uri and model/version name are provided." ) - if self._version_name and not self._model_name: + if self.version_name and not self.model_name: raise AirflowException( "Missing model: Batch prediction expects a model name when a version name is provided." ) - if not (self._uri or self._model_name): + if not (self.uri or self.model_name): raise AirflowException( "Missing model origin: Batch prediction expects a model, " "a model & version combination, or a URI to a savedModel." ) - - def execute(self, context: Context): - job_id = _normalize_mlengine_job_id(self._job_id) + job_id = _normalize_mlengine_job_id(self.job_id) prediction_request: dict[str, Any] = { "jobId": job_id, "predictionInput": { "dataFormat": self._data_format, - "inputPaths": self._input_paths, - "outputPath": self._output_path, - "region": self._region, + "inputPaths": self.input_paths, + "outputPath": self.output_path, + "region": self.region, }, } if self._labels: prediction_request["labels"] = self._labels - if self._uri: - prediction_request["predictionInput"]["uri"] = self._uri - elif self._model_name: - origin_name = f"projects/{self._project_id}/models/{self._model_name}" - if not self._version_name: + if self.uri: + prediction_request["predictionInput"]["uri"] = self.uri + elif self.model_name: + origin_name = f"projects/{self.project_id}/models/{self.model_name}" + if not self.version_name: prediction_request["predictionInput"]["modelName"] = origin_name else: prediction_request["predictionInput"]["versionName"] = ( - origin_name + f"/versions/{self._version_name}" + origin_name + f"/versions/{self.version_name}" ) if self._max_worker_count: @@ -278,7 +277,7 @@ def execute(self, context: Context): if self._signature_name: prediction_request["predictionInput"]["signatureName"] = self._signature_name - hook = MLEngineHook(gcp_conn_id=self._gcp_conn_id, impersonation_chain=self._impersonation_chain) + hook = MLEngineHook(gcp_conn_id=self._gcp_conn_id, impersonation_chain=self.impersonation_chain) # Helper method to check if the existing job's prediction input is the # same as the request we get here. @@ -286,7 +285,7 @@ def check_existing_job(existing_job): return existing_job.get("predictionInput") == prediction_request["predictionInput"] finished_prediction_job = hook.create_job( - project_id=self._project_id, job=prediction_request, use_existing_job_fn=check_existing_job + project_id=self.project_id, job=prediction_request, use_existing_job_fn=check_existing_job ) if finished_prediction_job["state"] != "SUCCEEDED": @@ -336,9 +335,9 @@ class MLEngineManageModelOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_model", - "_impersonation_chain", + "project_id", + "model", + "impersonation_chain", ) def __init__( @@ -352,21 +351,21 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._model = model + self.project_id = project_id + self.model = model self._operation = operation self._gcp_conn_id = gcp_conn_id - self._impersonation_chain = impersonation_chain + self.impersonation_chain = impersonation_chain def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) if self._operation == "create": - return hook.create_model(project_id=self._project_id, model=self._model) + return hook.create_model(project_id=self.project_id, model=self.model) elif self._operation == "get": - return hook.get_model(project_id=self._project_id, model_name=self._model["name"]) + return hook.get_model(project_id=self.project_id, model_name=self.model["name"]) else: raise ValueError(f"Unknown operation: {self._operation}") @@ -408,9 +407,9 @@ class MLEngineCreateModelOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_model", - "_impersonation_chain", + "project_id", + "model", + "impersonation_chain", ) operator_extra_links = (MLEngineModelLink(),) @@ -424,27 +423,27 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._model = model + self.project_id = project_id + self.model = model self._gcp_conn_id = gcp_conn_id - self._impersonation_chain = impersonation_chain + self.impersonation_chain = impersonation_chain def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) - project_id = self._project_id or hook.project_id + project_id = self.project_id or hook.project_id if project_id: MLEngineModelLink.persist( context=context, task_instance=self, project_id=project_id, - model_id=self._model["name"], + model_id=self.model["name"], ) - return hook.create_model(project_id=self._project_id, model=self._model) + return hook.create_model(project_id=self.project_id, model=self.model) @deprecated( @@ -484,9 +483,9 @@ class MLEngineGetModelOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_model_name", - "_impersonation_chain", + "project_id", + "model_name", + "impersonation_chain", ) operator_extra_links = (MLEngineModelLink(),) @@ -500,26 +499,26 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._model_name = model_name + self.project_id = project_id + self.model_name = model_name self._gcp_conn_id = gcp_conn_id - self._impersonation_chain = impersonation_chain + self.impersonation_chain = impersonation_chain def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) - project_id = self._project_id or hook.project_id + project_id = self.project_id or hook.project_id if project_id: MLEngineModelLink.persist( context=context, task_instance=self, project_id=project_id, - model_id=self._model_name, + model_id=self.model_name, ) - return hook.get_model(project_id=self._project_id, model_name=self._model_name) + return hook.get_model(project_id=self.project_id, model_name=self.model_name) @deprecated( @@ -563,9 +562,9 @@ class MLEngineDeleteModelOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_model_name", - "_impersonation_chain", + "project_id", + "model_name", + "impersonation_chain", ) operator_extra_links = (MLEngineModelsListLink(),) @@ -580,19 +579,19 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._model_name = model_name + self.project_id = project_id + self.model_name = model_name self._delete_contents = delete_contents self._gcp_conn_id = gcp_conn_id - self._impersonation_chain = impersonation_chain + self.impersonation_chain = impersonation_chain def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) - project_id = self._project_id or hook.project_id + project_id = self.project_id or hook.project_id if project_id: MLEngineModelsListLink.persist( context=context, @@ -601,7 +600,7 @@ def execute(self, context: Context): ) return hook.delete_model( - project_id=self._project_id, model_name=self._model_name, delete_contents=self._delete_contents + project_id=self.project_id, model_name=self.model_name, delete_contents=self._delete_contents ) @@ -667,11 +666,11 @@ class MLEngineManageVersionOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_model_name", - "_version_name", - "_version", - "_impersonation_chain", + "project_id", + "model_name", + "version_name", + "version", + "impersonation_chain", ) def __init__( @@ -687,38 +686,38 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._model_name = model_name - self._version_name = version_name - self._version = version or {} + self.project_id = project_id + self.model_name = model_name + self.version_name = version_name + self.version = version or {} self._operation = operation self._gcp_conn_id = gcp_conn_id - self._impersonation_chain = impersonation_chain + self.impersonation_chain = impersonation_chain def execute(self, context: Context): - if "name" not in self._version: - self._version["name"] = self._version_name + if "name" not in self.version: + self.version["name"] = self.version_name hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) if self._operation == "create": - if not self._version: + if not self.version: raise ValueError(f"version attribute of {self.__class__.__name__} could not be empty") return hook.create_version( - project_id=self._project_id, model_name=self._model_name, version_spec=self._version + project_id=self.project_id, model_name=self.model_name, version_spec=self.version ) elif self._operation == "set_default": return hook.set_default_version( - project_id=self._project_id, model_name=self._model_name, version_name=self._version["name"] + project_id=self.project_id, model_name=self.model_name, version_name=self.version["name"] ) elif self._operation == "list": - return hook.list_versions(project_id=self._project_id, model_name=self._model_name) + return hook.list_versions(project_id=self.project_id, model_name=self.model_name) elif self._operation == "delete": return hook.delete_version( - project_id=self._project_id, model_name=self._model_name, version_name=self._version["name"] + project_id=self.project_id, model_name=self.model_name, version_name=self.version["name"] ) else: raise ValueError(f"Unknown operation: {self._operation}") @@ -762,10 +761,10 @@ class MLEngineCreateVersionOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_model_name", - "_version", - "_impersonation_chain", + "project_id", + "model_name", + "version", + "impersonation_chain", ) operator_extra_links = (MLEngineModelVersionDetailsLink(),) @@ -780,38 +779,38 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._model_name = model_name - self._version = version + self.project_id = project_id + self.model_name = model_name + self.version = version self._gcp_conn_id = gcp_conn_id - self._impersonation_chain = impersonation_chain - self._validate_inputs() + self.impersonation_chain = impersonation_chain def _validate_inputs(self): - if not self._model_name: + if not self.model_name: raise AirflowException("The model_name parameter could not be empty.") - if not self._version: + if not self.version: raise AirflowException("The version parameter could not be empty.") def execute(self, context: Context): + self._validate_inputs() hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) - project_id = self._project_id or hook.project_id + project_id = self.project_id or hook.project_id if project_id: MLEngineModelVersionDetailsLink.persist( context=context, task_instance=self, project_id=project_id, - model_id=self._model_name, - version_id=self._version["name"], + model_id=self.model_name, + version_id=self.version["name"], ) return hook.create_version( - project_id=self._project_id, model_name=self._model_name, version_spec=self._version + project_id=self.project_id, model_name=self.model_name, version_spec=self.version ) @@ -855,10 +854,10 @@ class MLEngineSetDefaultVersionOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_model_name", - "_version_name", - "_impersonation_chain", + "project_id", + "model_name", + "version_name", + "impersonation_chain", ) operator_extra_links = (MLEngineModelVersionDetailsLink(),) @@ -873,38 +872,38 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._model_name = model_name - self._version_name = version_name + self.project_id = project_id + self.model_name = model_name + self.version_name = version_name self._gcp_conn_id = gcp_conn_id - self._impersonation_chain = impersonation_chain - self._validate_inputs() + self.impersonation_chain = impersonation_chain def _validate_inputs(self): - if not self._model_name: + if not self.model_name: raise AirflowException("The model_name parameter could not be empty.") - if not self._version_name: + if not self.version_name: raise AirflowException("The version_name parameter could not be empty.") def execute(self, context: Context): + self._validate_inputs() hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) - project_id = self._project_id or hook.project_id + project_id = self.project_id or hook.project_id if project_id: MLEngineModelVersionDetailsLink.persist( context=context, task_instance=self, project_id=project_id, - model_id=self._model_name, - version_id=self._version_name, + model_id=self.model_name, + version_id=self.version_name, ) return hook.set_default_version( - project_id=self._project_id, model_name=self._model_name, version_name=self._version_name + project_id=self.project_id, model_name=self.model_name, version_name=self.version_name ) @@ -947,9 +946,9 @@ class MLEngineListVersionsOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_model_name", - "_impersonation_chain", + "project_id", + "model_name", + "impersonation_chain", ) operator_extra_links = (MLEngineModelLink(),) @@ -963,34 +962,34 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._model_name = model_name + self.project_id = project_id + self.model_name = model_name self._gcp_conn_id = gcp_conn_id - self._impersonation_chain = impersonation_chain - self._validate_inputs() + self.impersonation_chain = impersonation_chain def _validate_inputs(self): - if not self._model_name: + if not self.model_name: raise AirflowException("The model_name parameter could not be empty.") def execute(self, context: Context): + self._validate_inputs() hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) - project_id = self._project_id or hook.project_id + project_id = self.project_id or hook.project_id if project_id: MLEngineModelLink.persist( context=context, task_instance=self, project_id=project_id, - model_id=self._model_name, + model_id=self.model_name, ) return hook.list_versions( - project_id=self._project_id, - model_name=self._model_name, + project_id=self.project_id, + model_name=self.model_name, ) @@ -1034,10 +1033,10 @@ class MLEngineDeleteVersionOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_model_name", - "_version_name", - "_impersonation_chain", + "project_id", + "model_name", + "version_name", + "impersonation_chain", ) operator_extra_links = (MLEngineModelLink(),) @@ -1052,37 +1051,37 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._model_name = model_name - self._version_name = version_name + self.project_id = project_id + self.model_name = model_name + self.version_name = version_name self._gcp_conn_id = gcp_conn_id - self._impersonation_chain = impersonation_chain - self._validate_inputs() + self.impersonation_chain = impersonation_chain def _validate_inputs(self): - if not self._model_name: + if not self.model_name: raise AirflowException("The model_name parameter could not be empty.") - if not self._version_name: + if not self.version_name: raise AirflowException("The version_name parameter could not be empty.") def execute(self, context: Context): + self._validate_inputs() hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) - project_id = self._project_id or hook.project_id + project_id = self.project_id or hook.project_id if project_id: MLEngineModelLink.persist( context=context, task_instance=self, project_id=project_id, - model_id=self._model_name, + model_id=self.model_name, ) return hook.delete_version( - project_id=self._project_id, model_name=self._model_name, version_name=self._version_name + project_id=self.project_id, model_name=self.model_name, version_name=self.version_name ) @@ -1163,21 +1162,21 @@ class MLEngineStartTrainingJobOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_job_id", - "_region", - "_package_uris", - "_training_python_module", - "_training_args", - "_scale_tier", - "_master_type", - "_master_config", - "_runtime_version", - "_python_version", - "_job_dir", - "_service_account", - "_hyperparameters", - "_impersonation_chain", + "project_id", + "job_id", + "region", + "package_uris", + "training_python_module", + "training_args", + "scale_tier", + "master_type", + "master_config", + "runtime_version", + "python_version", + "job_dir", + "service_account", + "hyperparameters", + "impersonation_chain", ) operator_extra_links = (MLEngineJobDetailsLink(),) @@ -1207,98 +1206,95 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._job_id = job_id - self._region = region - self._package_uris = package_uris - self._training_python_module = training_python_module - self._training_args = training_args - self._scale_tier = scale_tier - self._master_type = master_type - self._master_config = master_config - self._runtime_version = runtime_version - self._python_version = python_version - self._job_dir = job_dir - self._service_account = service_account + self.project_id = project_id + self.job_id = job_id + self.region = region + self.package_uris = package_uris + self.training_python_module = training_python_module + self.training_args = training_args + self.scale_tier = scale_tier + self.master_type = master_type + self.master_config = master_config + self.runtime_version = runtime_version + self.python_version = python_version + self.job_dir = job_dir + self.service_account = service_account self._gcp_conn_id = gcp_conn_id self._mode = mode self._labels = labels - self._hyperparameters = hyperparameters - self._impersonation_chain = impersonation_chain + self.hyperparameters = hyperparameters + self.impersonation_chain = impersonation_chain self.deferrable = deferrable self.cancel_on_kill = cancel_on_kill - custom = self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM" + def _handle_job_error(self, finished_training_job) -> None: + if finished_training_job["state"] != "SUCCEEDED": + self.log.error("MLEngine training job failed: %s", finished_training_job) + raise RuntimeError(finished_training_job["errorMessage"]) + + def execute(self, context: Context): + custom = self.scale_tier is not None and self.scale_tier.upper() == "CUSTOM" custom_image = ( - custom - and self._master_config is not None - and self._master_config.get("imageUri", None) is not None + custom and self.master_config is not None and self.master_config.get("imageUri", None) is not None ) - if not self._project_id: + if not self.project_id: raise AirflowException("Google Cloud project id is required.") - if not self._job_id: + if not self.job_id: raise AirflowException("An unique job id is required for Google MLEngine training job.") - if not self._region: + if not self.region: raise AirflowException("Google Compute Engine region is required.") - if custom and not self._master_type: + if custom and not self.master_type: raise AirflowException("master_type must be set when scale_tier is CUSTOM") - if self._master_config and not self._master_type: + if self.master_config and not self.master_type: raise AirflowException("master_type must be set when master_config is provided") - if not (package_uris and training_python_module) and not custom_image: + if not (self.package_uris and self.training_python_module) and not custom_image: raise AirflowException( "Either a Python package with a Python module or a custom Docker image should be provided." ) - if (package_uris or training_python_module) and custom_image: + if (self.package_uris or self.training_python_module) and custom_image: raise AirflowException( "Either a Python package with a Python module or " "a custom Docker image should be provided but not both." ) - - def _handle_job_error(self, finished_training_job) -> None: - if finished_training_job["state"] != "SUCCEEDED": - self.log.error("MLEngine training job failed: %s", finished_training_job) - raise RuntimeError(finished_training_job["errorMessage"]) - - def execute(self, context: Context): - job_id = _normalize_mlengine_job_id(self._job_id) + job_id = _normalize_mlengine_job_id(self.job_id) self.job_id = job_id training_request: dict[str, Any] = { "jobId": self.job_id, "trainingInput": { - "scaleTier": self._scale_tier, - "region": self._region, + "scaleTier": self.scale_tier, + "region": self.region, }, } - if self._package_uris: - training_request["trainingInput"]["packageUris"] = self._package_uris + if self.package_uris: + training_request["trainingInput"]["packageUris"] = self.package_uris - if self._training_python_module: - training_request["trainingInput"]["pythonModule"] = self._training_python_module + if self.training_python_module: + training_request["trainingInput"]["pythonModule"] = self.training_python_module - if self._training_args: - training_request["trainingInput"]["args"] = self._training_args + if self.training_args: + training_request["trainingInput"]["args"] = self.training_args - if self._master_type: - training_request["trainingInput"]["masterType"] = self._master_type + if self.master_type: + training_request["trainingInput"]["masterType"] = self.master_type - if self._master_config: - training_request["trainingInput"]["masterConfig"] = self._master_config + if self.master_config: + training_request["trainingInput"]["masterConfig"] = self.master_config - if self._runtime_version: - training_request["trainingInput"]["runtimeVersion"] = self._runtime_version + if self.runtime_version: + training_request["trainingInput"]["runtimeVersion"] = self.runtime_version - if self._python_version: - training_request["trainingInput"]["pythonVersion"] = self._python_version + if self.python_version: + training_request["trainingInput"]["pythonVersion"] = self.python_version - if self._job_dir: - training_request["trainingInput"]["jobDir"] = self._job_dir + if self.job_dir: + training_request["trainingInput"]["jobDir"] = self.job_dir - if self._service_account: - training_request["trainingInput"]["serviceAccount"] = self._service_account + if self.service_account: + training_request["trainingInput"]["serviceAccount"] = self.service_account - if self._hyperparameters: - training_request["trainingInput"]["hyperparameters"] = self._hyperparameters + if self.hyperparameters: + training_request["trainingInput"]["hyperparameters"] = self.hyperparameters if self._labels: training_request["labels"] = self._labels @@ -1310,25 +1306,25 @@ def execute(self, context: Context): hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) self.hook = hook try: self.log.info("Executing: %s'", training_request) self.job_id = self.hook.create_job_without_waiting_result( - project_id=self._project_id, + project_id=self.project_id, body=training_request, ) except HttpError as e: if e.resp.status == 409: # If the job already exists retrieve it - self.hook.get_job(project_id=self._project_id, job_id=self.job_id) - if self._project_id: + self.hook.get_job(project_id=self.project_id, job_id=self.job_id) + if self.project_id: MLEngineJobDetailsLink.persist( context=context, task_instance=self, - project_id=self._project_id, + project_id=self.project_id, job_id=self.job_id, ) self.log.error( @@ -1345,30 +1341,30 @@ def execute(self, context: Context): trigger=MLEngineStartTrainingJobTrigger( conn_id=self._gcp_conn_id, job_id=self.job_id, - project_id=self._project_id, - region=self._region, - runtime_version=self._runtime_version, - python_version=self._python_version, - job_dir=self._job_dir, - package_uris=self._package_uris, - training_python_module=self._training_python_module, - training_args=self._training_args, + project_id=self.project_id, + region=self.region, + runtime_version=self.runtime_version, + python_version=self.python_version, + job_dir=self.job_dir, + package_uris=self.package_uris, + training_python_module=self.training_python_module, + training_args=self.training_args, labels=self._labels, gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ), method_name="execute_complete", ) else: - finished_training_job = self._wait_for_job_done(self._project_id, self.job_id) + finished_training_job = self._wait_for_job_done(self.project_id, self.job_id) self._handle_job_error(finished_training_job) gcp_metadata = { "job_id": self.job_id, - "project_id": self._project_id, + "project_id": self.project_id, } context["task_instance"].xcom_push("gcp_metadata", gcp_metadata) - project_id = self._project_id or hook.project_id + project_id = self.project_id or hook.project_id if project_id: MLEngineJobDetailsLink.persist( context=context, @@ -1413,19 +1409,19 @@ def execute_complete(self, context: Context, event: dict[str, Any]): self.task_id, event["message"], ) - if self._project_id: + if self.project_id: MLEngineJobDetailsLink.persist( context=context, task_instance=self, - project_id=self._project_id, - job_id=self._job_id, + project_id=self.project_id, + job_id=self.job_id, ) def on_kill(self) -> None: if self.job_id and self.cancel_on_kill: - self.hook.cancel_job(job_id=self.job_id, project_id=self._project_id) # type: ignore[union-attr] + self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id) # type: ignore[union-attr] else: - self.log.info("Skipping to cancel job: %s:%s.%s", self._project_id, self.job_id) + self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.job_id) @deprecated( @@ -1461,9 +1457,9 @@ class MLEngineTrainingCancelJobOperator(GoogleCloudBaseOperator): """ template_fields: Sequence[str] = ( - "_project_id", - "_job_id", - "_impersonation_chain", + "project_id", + "job_id", + "impersonation_chain", ) operator_extra_links = (MLEngineJobSListLink(),) @@ -1477,21 +1473,49 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) - self._project_id = project_id - self._job_id = job_id + self.project_id = project_id + self.job_id = job_id self._gcp_conn_id = gcp_conn_id - self._impersonation_chain = impersonation_chain + self.impersonation_chain = impersonation_chain - if not self._project_id: - raise AirflowException("Google Cloud project id is required.") + @property + @deprecated( + reason="`_project_id` is deprecated and will be removed in the future. Please use `project_id`" + " instead.", + category=AirflowProviderDeprecationWarning, + ) + def _project_id(self): + """Alias for ``project_id``, used for compatibility (deprecated).""" + return self.project_id + + @property + @deprecated( + reason="`_job_id` is deprecated and will be removed in the future. Please use `job_id` instead.", + category=AirflowProviderDeprecationWarning, + ) + def _job_id(self): + """Alias for ``job_id``, used for compatibility (deprecated).""" + return self.job_id + + @property + @deprecated( + reason="`_impersonation_chain` is deprecated and will be removed in the future." + " Please use `impersonation_chain` instead.", + category=AirflowProviderDeprecationWarning, + ) + def _impersonation_chain(self): + """Alias for ``impersonation_chain``, used for compatibility (deprecated).""" + return self.impersonation_chain def execute(self, context: Context): + if not self.project_id: + raise AirflowException("Google Cloud project id is required.") hook = MLEngineHook( gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self._impersonation_chain, + impersonation_chain=self.impersonation_chain, ) - project_id = self._project_id or hook.project_id + project_id = self.project_id or hook.project_id if project_id: MLEngineJobSListLink.persist( context=context, @@ -1499,4 +1523,4 @@ def execute(self, context: Context): project_id=project_id, ) - hook.cancel_job(project_id=self._project_id, job_id=_normalize_mlengine_job_id(self._job_id)) + hook.cancel_job(project_id=self.project_id, job_id=_normalize_mlengine_job_id(self.job_id)) diff --git a/tests/providers/google/cloud/operators/test_mlengine.py b/tests/providers/google/cloud/operators/test_mlengine.py index bc99de2bf19b5..45c09daa9ae6d 100644 --- a/tests/providers/google/cloud/operators/test_mlengine.py +++ b/tests/providers/google/cloud/operators/test_mlengine.py @@ -66,7 +66,7 @@ MLENGINE_AI_PATH = "airflow.providers.google.cloud.operators.mlengine.{}" -class TestMLEngineBatchPredictionOperator: +class TestMLEngineStartBatchPredictionJobOperator: INPUT_MISSING_ORIGIN = { "dataFormat": "TEXT", "inputPaths": ["gs://legal-bucket/fake-input-path/*"], @@ -307,6 +307,38 @@ def test_failed_job_error(self, mock_hook): assert "A failure message" == str(ctx.value) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineStartBatchPredictionJobOperator, + # Templated fields + project_id="{{ 'project_id' }}", + job_id="{{ 'job_id' }}", + region="{{ 'region' }}", + input_paths="{{ 'input_paths' }}", + output_path="{{ 'output_path' }}", + model_name="{{ 'model_name' }}", + version_name="{{ 'version_name' }}", + uri="{{ 'uri' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + data_format="data_format", + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineStartBatchPredictionJobOperator = ti.task + assert task.project_id == "project_id" + assert task.job_id == "job_id" + assert task.region == "region" + assert task.input_paths == "input_paths" + assert task.output_path == "output_path" + assert task.model_name == "model_name" + assert task.version_name == "version_name" + assert task.uri == "uri" + assert task.impersonation_chain == "impersonation_chain" + class TestMLEngineTrainingCancelJobOperator: TRAINING_DEFAULT_ARGS = { @@ -357,6 +389,25 @@ def test_http_error(self, mock_hook): ) assert http_error_code == ctx.value.resp.status + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineTrainingCancelJobOperator, + # Templated fields + project_id="{{ 'project_id' }}", + job_id="{{ 'job_id' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineTrainingCancelJobOperator = ti.task + assert task.project_id == "project_id" + assert task.job_id == "job_id" + assert task.impersonation_chain == "impersonation_chain" + class TestMLEngineModelOperator: @patch(MLENGINE_AI_PATH.format("MLEngineHook")) @@ -414,6 +465,25 @@ def test_fail(self, mock_hook): with pytest.raises(ValueError): task.execute(None) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineManageModelOperator, + # Templated fields + project_id="{{ 'project_id' }}", + model="{{ 'model' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineManageModelOperator = ti.task + assert task.project_id == "project_id" + assert task.model == "model" + assert task.impersonation_chain == "impersonation_chain" + class TestMLEngineCreateModelOperator: @patch(MLENGINE_AI_PATH.format("MLEngineHook")) @@ -436,6 +506,25 @@ def test_success_create_model(self, mock_hook): project_id=TEST_PROJECT_ID, model=TEST_MODEL ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineCreateModelOperator, + # Templated fields + project_id="{{ 'project_id' }}", + model="{{ 'model' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineCreateModelOperator = ti.task + assert task.project_id == "project_id" + assert task.model == "model" + assert task.impersonation_chain == "impersonation_chain" + class TestMLEngineGetModelOperator: @patch(MLENGINE_AI_PATH.format("MLEngineHook")) @@ -459,6 +548,25 @@ def test_success_get_model(self, mock_hook): ) assert mock_hook.return_value.get_model.return_value == result + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineGetModelOperator, + # Templated fields + project_id="{{ 'project_id' }}", + model_name="{{ 'model_name' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineGetModelOperator = ti.task + assert task.project_id == "project_id" + assert task.model_name == "model_name" + assert task.impersonation_chain == "impersonation_chain" + class TestMLEngineDeleteModelOperator: @patch(MLENGINE_AI_PATH.format("MLEngineHook")) @@ -482,6 +590,25 @@ def test_success_delete_model(self, mock_hook): project_id=TEST_PROJECT_ID, model_name=TEST_MODEL_NAME, delete_contents=True ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineDeleteModelOperator, + # Templated fields + project_id="{{ 'project_id' }}", + model_name="{{ 'model_name' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineDeleteModelOperator = ti.task + assert task.project_id == "project_id" + assert task.model_name == "model_name" + assert task.impersonation_chain == "impersonation_chain" + class TestMLEngineVersionOperator: VERSION_DEFAULT_ARGS = { @@ -509,6 +636,29 @@ def test_success_create_version(self, mock_hook): project_id="test-project", model_name="test-model", version_spec=TEST_VERSION ) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineManageVersionOperator, + # Templated fields + project_id="{{ 'project_id' }}", + model_name="{{ 'model_name' }}", + version="{{ 'version' }}", + version_name="{{ 'version_name' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineManageVersionOperator = ti.task + assert task.project_id == "project_id" + assert task.model_name == "model_name" + assert task.version == "version" + assert task.version_name == "version_name" + assert task.impersonation_chain == "impersonation_chain" + class TestMLEngineCreateVersion: @patch(MLENGINE_AI_PATH.format("MLEngineHook")) @@ -534,23 +684,46 @@ def test_success(self, mock_hook): def test_missing_model_name(self): with pytest.raises(AirflowException): - MLEngineCreateVersionOperator( + task = MLEngineCreateVersionOperator( task_id="task-id", project_id=TEST_PROJECT_ID, model_name=None, version=TEST_VERSION, gcp_conn_id=TEST_GCP_CONN_ID, ) + task.execute(context=MagicMock()) def test_missing_version(self): with pytest.raises(AirflowException): - MLEngineCreateVersionOperator( + task = MLEngineCreateVersionOperator( task_id="task-id", project_id=TEST_PROJECT_ID, model_name=TEST_MODEL_NAME, version=None, gcp_conn_id=TEST_GCP_CONN_ID, ) + task.execute(context=MagicMock()) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineCreateVersionOperator, + # Templated fields + project_id="{{ 'project_id' }}", + model_name="{{ 'model_name' }}", + version="{{ 'version' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineCreateVersionOperator = ti.task + assert task.project_id == "project_id" + assert task.model_name == "model_name" + assert task.version == "version" + assert task.impersonation_chain == "impersonation_chain" class TestMLEngineSetDefaultVersion: @@ -577,23 +750,46 @@ def test_success(self, mock_hook): def test_missing_model_name(self): with pytest.raises(AirflowException): - MLEngineSetDefaultVersionOperator( + task = MLEngineSetDefaultVersionOperator( task_id="task-id", project_id=TEST_PROJECT_ID, model_name=None, version_name=TEST_VERSION_NAME, gcp_conn_id=TEST_GCP_CONN_ID, ) + task.execute(context=MagicMock()) def test_missing_version_name(self): with pytest.raises(AirflowException): - MLEngineSetDefaultVersionOperator( + task = MLEngineSetDefaultVersionOperator( task_id="task-id", project_id=TEST_PROJECT_ID, model_name=TEST_MODEL_NAME, version_name=None, gcp_conn_id=TEST_GCP_CONN_ID, ) + task.execute(context=MagicMock()) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineSetDefaultVersionOperator, + # Templated fields + project_id="{{ 'project_id' }}", + model_name="{{ 'model_name' }}", + version_name="{{ 'version_name' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineSetDefaultVersionOperator = ti.task + assert task.project_id == "project_id" + assert task.model_name == "model_name" + assert task.version_name == "version_name" + assert task.impersonation_chain == "impersonation_chain" class TestMLEngineListVersions: @@ -620,12 +816,32 @@ def test_success(self, mock_hook): def test_missing_model_name(self): with pytest.raises(AirflowException): - MLEngineListVersionsOperator( + task = MLEngineListVersionsOperator( task_id="task-id", project_id=TEST_PROJECT_ID, model_name=None, gcp_conn_id=TEST_GCP_CONN_ID, ) + task.execute(context=MagicMock()) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineListVersionsOperator, + # Templated fields + project_id="{{ 'project_id' }}", + model_name="{{ 'model_name' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineListVersionsOperator = ti.task + assert task.project_id == "project_id" + assert task.model_name == "model_name" + assert task.impersonation_chain == "impersonation_chain" class TestMLEngineDeleteVersion: @@ -652,23 +868,46 @@ def test_success(self, mock_hook): def test_missing_version_name(self): with pytest.raises(AirflowException): - MLEngineDeleteVersionOperator( + task = MLEngineDeleteVersionOperator( task_id="task-id", project_id=TEST_PROJECT_ID, model_name=TEST_MODEL_NAME, version_name=None, gcp_conn_id=TEST_GCP_CONN_ID, ) + task.execute(context=MagicMock()) def test_missing_model_name(self): with pytest.raises(AirflowException): - MLEngineDeleteVersionOperator( + task = MLEngineDeleteVersionOperator( task_id="task-id", project_id=TEST_PROJECT_ID, model_name=None, version_name=TEST_VERSION_NAME, gcp_conn_id=TEST_GCP_CONN_ID, ) + task.execute(context=MagicMock()) + + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineDeleteVersionOperator, + # Templated fields + project_id="{{ 'project_id' }}", + model_name="{{ 'model_name' }}", + version_name="{{ 'version_name' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineDeleteVersionOperator = ti.task + assert task.project_id == "project_id" + assert task.model_name == "model_name" + assert task.version_name == "version_name" + assert task.impersonation_chain == "impersonation_chain" class TestMLEngineStartTrainingJobOperator: @@ -929,6 +1168,49 @@ def test_create_training_job_should_throw_exception_when_job_failed(self, mock_h ) assert "A failure message" == str(ctx.value) + @pytest.mark.db_test + def test_templating(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + MLEngineStartTrainingJobOperator, + # Templated fields + project_id="{{ 'project_id' }}", + job_id="{{ 'job_id' }}", + region="{{ 'region' }}", + package_uris="{{ 'package_uris' }}", + training_python_module="{{ 'training_python_module' }}", + training_args="{{ 'training_args' }}", + scale_tier="{{ 'scale_tier' }}", + master_type="{{ 'master_type' }}", + master_config="{{ 'master_config' }}", + runtime_version="{{ 'runtime_version' }}", + python_version="{{ 'python_version' }}", + job_dir="{{ 'job_dir' }}", + service_account="{{ 'service_account' }}", + hyperparameters="{{ 'hyperparameters' }}", + impersonation_chain="{{ 'impersonation_chain' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: MLEngineStartTrainingJobOperator = ti.task + assert task.project_id == "project_id" + assert task.job_id == "job_id" + assert task.region == "region" + assert task.package_uris == "package_uris" + assert task.training_python_module == "training_python_module" + assert task.training_args == "training_args" + assert task.scale_tier == "scale_tier" + assert task.master_type == "master_type" + assert task.master_config == "master_config" + assert task.runtime_version == "runtime_version" + assert task.python_version == "python_version" + assert task.job_dir == "job_dir" + assert task.service_account == "service_account" + assert task.hyperparameters == "hyperparameters" + assert task.impersonation_chain == "impersonation_chain" + TEST_TASK_ID = "training" TEST_JOB_ID = "1234" diff --git a/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py b/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py index 8d468143a6b07..82260ddc247d3 100644 --- a/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py +++ b/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py @@ -119,13 +119,13 @@ def test_create_evaluate_ops(self, mock_beam_pipeline, mock_python): METRIC_FN_ENCODED = base64.b64encode(dill.dumps(METRIC_FN, recurse=True)).decode() assert TASK_PREFIX_PREDICTION == evaluate_prediction.task_id - assert PROJECT_ID == evaluate_prediction._project_id - assert BATCH_PREDICTION_JOB_ID == evaluate_prediction._job_id - assert REGION == evaluate_prediction._region + assert PROJECT_ID == evaluate_prediction.project_id + assert BATCH_PREDICTION_JOB_ID == evaluate_prediction.job_id + assert REGION == evaluate_prediction.region assert DATA_FORMAT == evaluate_prediction._data_format - assert INPUT_PATHS == evaluate_prediction._input_paths - assert PREDICTION_PATH == evaluate_prediction._output_path - assert MODEL_URI == evaluate_prediction._uri + assert INPUT_PATHS == evaluate_prediction.input_paths + assert PREDICTION_PATH == evaluate_prediction.output_path + assert MODEL_URI == evaluate_prediction.uri assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id assert DATAFLOW_OPTIONS == evaluate_summary.default_pipeline_options @@ -165,14 +165,14 @@ def test_create_evaluate_ops_model_and_version_name(self, mock_beam_pipeline, mo METRIC_FN_ENCODED = base64.b64encode(dill.dumps(METRIC_FN, recurse=True)).decode() assert TASK_PREFIX_PREDICTION == evaluate_prediction.task_id - assert PROJECT_ID == evaluate_prediction._project_id - assert BATCH_PREDICTION_JOB_ID == evaluate_prediction._job_id - assert REGION == evaluate_prediction._region + assert PROJECT_ID == evaluate_prediction.project_id + assert BATCH_PREDICTION_JOB_ID == evaluate_prediction.job_id + assert REGION == evaluate_prediction.region assert DATA_FORMAT == evaluate_prediction._data_format - assert INPUT_PATHS == evaluate_prediction._input_paths - assert PREDICTION_PATH == evaluate_prediction._output_path - assert MODEL_NAME == evaluate_prediction._model_name - assert VERSION_NAME == evaluate_prediction._version_name + assert INPUT_PATHS == evaluate_prediction.input_paths + assert PREDICTION_PATH == evaluate_prediction.output_path + assert MODEL_NAME == evaluate_prediction.model_name + assert VERSION_NAME == evaluate_prediction.version_name assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id assert DATAFLOW_OPTIONS == evaluate_summary.default_pipeline_options @@ -208,14 +208,14 @@ def test_create_evaluate_ops_dag(self, mock_dataflow, mock_python): METRIC_FN_ENCODED = base64.b64encode(dill.dumps(METRIC_FN, recurse=True)).decode() assert TASK_PREFIX_PREDICTION == evaluate_prediction.task_id - assert PROJECT_ID == evaluate_prediction._project_id - assert BATCH_PREDICTION_JOB_ID == evaluate_prediction._job_id - assert REGION == evaluate_prediction._region + assert PROJECT_ID == evaluate_prediction.project_id + assert BATCH_PREDICTION_JOB_ID == evaluate_prediction.job_id + assert REGION == evaluate_prediction.region assert DATA_FORMAT == evaluate_prediction._data_format - assert INPUT_PATHS == evaluate_prediction._input_paths - assert PREDICTION_PATH == evaluate_prediction._output_path - assert MODEL_NAME == evaluate_prediction._model_name - assert VERSION_NAME == evaluate_prediction._version_name + assert INPUT_PATHS == evaluate_prediction.input_paths + assert PREDICTION_PATH == evaluate_prediction.output_path + assert MODEL_NAME == evaluate_prediction.model_name + assert VERSION_NAME == evaluate_prediction.version_name assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id assert DATAFLOW_OPTIONS == evaluate_summary.default_pipeline_options