diff --git a/airflow/providers/amazon/aws/operators/dms_create_task.py b/airflow/providers/amazon/aws/operators/dms_create_task.py index 675070aa07911..6eac3d58d58c4 100644 --- a/airflow/providers/amazon/aws/operators/dms_create_task.py +++ b/airflow/providers/amazon/aws/operators/dms_create_task.py @@ -20,7 +20,6 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.dms import DmsHook -from airflow.utils.decorators import apply_defaults class DmsCreateTaskOperator(BaseOperator): @@ -68,7 +67,6 @@ class DmsCreateTaskOperator(BaseOperator): "create_task_kwargs": "json", } - @apply_defaults def __init__( self, *, diff --git a/airflow/providers/amazon/aws/operators/dms_delete_task.py b/airflow/providers/amazon/aws/operators/dms_delete_task.py index fb3cda7570c15..6f12d2b795f30 100644 --- a/airflow/providers/amazon/aws/operators/dms_delete_task.py +++ b/airflow/providers/amazon/aws/operators/dms_delete_task.py @@ -20,7 +20,6 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.dms import DmsHook -from airflow.utils.decorators import apply_defaults class DmsDeleteTaskOperator(BaseOperator): @@ -45,7 +44,6 @@ class DmsDeleteTaskOperator(BaseOperator): template_ext = () template_fields_renderers = {} - @apply_defaults def __init__( self, *, diff --git a/airflow/providers/amazon/aws/operators/dms_describe_tasks.py b/airflow/providers/amazon/aws/operators/dms_describe_tasks.py index 9f4194799512b..cc97044512479 100644 --- a/airflow/providers/amazon/aws/operators/dms_describe_tasks.py +++ b/airflow/providers/amazon/aws/operators/dms_describe_tasks.py @@ -20,7 +20,6 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.dms import DmsHook -from airflow.utils.decorators import apply_defaults class DmsDescribeTasksOperator(BaseOperator): @@ -41,7 +40,6 @@ class DmsDescribeTasksOperator(BaseOperator): template_ext = () template_fields_renderers = {'describe_tasks_kwargs': 'json'} - @apply_defaults def __init__( self, *, diff --git a/airflow/providers/amazon/aws/operators/dms_start_task.py b/airflow/providers/amazon/aws/operators/dms_start_task.py index 50ae6ad4065e1..a2ce635dc9c24 100644 --- a/airflow/providers/amazon/aws/operators/dms_start_task.py +++ b/airflow/providers/amazon/aws/operators/dms_start_task.py @@ -20,7 +20,6 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.dms import DmsHook -from airflow.utils.decorators import apply_defaults class DmsStartTaskOperator(BaseOperator): @@ -54,7 +53,6 @@ class DmsStartTaskOperator(BaseOperator): template_ext = () template_fields_renderers = {'start_task_kwargs': 'json'} - @apply_defaults def __init__( self, *, diff --git a/airflow/providers/amazon/aws/operators/dms_stop_task.py b/airflow/providers/amazon/aws/operators/dms_stop_task.py index c7c1aa2c2d46a..ea45b58b86ae8 100644 --- a/airflow/providers/amazon/aws/operators/dms_stop_task.py +++ b/airflow/providers/amazon/aws/operators/dms_stop_task.py @@ -20,7 +20,6 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.dms import DmsHook -from airflow.utils.decorators import apply_defaults class DmsStopTaskOperator(BaseOperator): @@ -41,7 +40,6 @@ class DmsStopTaskOperator(BaseOperator): template_ext = () template_fields_renderers = {} - @apply_defaults def __init__( self, *, diff --git a/airflow/providers/amazon/aws/operators/emr_containers.py b/airflow/providers/amazon/aws/operators/emr_containers.py index ca3c9363f3666..9466c322cf1c2 100644 --- 
a/airflow/providers/amazon/aws/operators/emr_containers.py +++ b/airflow/providers/amazon/aws/operators/emr_containers.py @@ -27,7 +27,6 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.emr_containers import EMRContainerHook -from airflow.utils.decorators import apply_defaults class EMRContainerOperator(BaseOperator): @@ -63,7 +62,6 @@ class EMRContainerOperator(BaseOperator): template_fields = ["name", "virtual_cluster_id", "execution_role_arn", "release_label", "job_driver"] ui_color = "#f9c915" - @apply_defaults def __init__( # pylint: disable=too-many-arguments self, *, diff --git a/airflow/providers/amazon/aws/sensors/dms_task.py b/airflow/providers/amazon/aws/sensors/dms_task.py index 39af3627c3507..32ed4f9b11c07 100644 --- a/airflow/providers/amazon/aws/sensors/dms_task.py +++ b/airflow/providers/amazon/aws/sensors/dms_task.py @@ -21,7 +21,6 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.dms import DmsHook from airflow.sensors.base import BaseSensorOperator -from airflow.utils.decorators import apply_defaults class DmsTaskBaseSensor(BaseSensorOperator): @@ -45,7 +44,6 @@ class DmsTaskBaseSensor(BaseSensorOperator): template_fields = ['replication_task_arn'] template_ext = () - @apply_defaults def __init__( self, replication_task_arn: str, @@ -104,7 +102,6 @@ class DmsTaskCompletedSensor(DmsTaskBaseSensor): template_fields = ['replication_task_arn'] template_ext = () - @apply_defaults def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.target_statuses = ['stopped'] diff --git a/airflow/providers/apache/drill/operators/drill.py b/airflow/providers/apache/drill/operators/drill.py index 459c623434379..2c330117a0d32 100644 --- a/airflow/providers/apache/drill/operators/drill.py +++ b/airflow/providers/apache/drill/operators/drill.py @@ -21,7 +21,6 @@ from airflow.models import BaseOperator from airflow.providers.apache.drill.hooks.drill import DrillHook -from airflow.utils.decorators import apply_defaults class DrillOperator(BaseOperator): @@ -48,7 +47,6 @@ class DrillOperator(BaseOperator): template_ext = ('.sql',) ui_color = '#ededed' - @apply_defaults def __init__( self, *, diff --git a/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py b/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py index 7c28105fe6564..4c6404f054832 100644 --- a/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py +++ b/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py @@ -23,8 +23,6 @@ from airflow.exceptions import AirflowException from airflow.providers.cncf.kubernetes.backcompat.pod import Port, Resources from airflow.providers.cncf.kubernetes.backcompat.pod_runtime_info_env import PodRuntimeInfoEnv -from airflow.providers.cncf.kubernetes.backcompat.volume import Volume -from airflow.providers.cncf.kubernetes.backcompat.volume_mount import VolumeMount def _convert_kube_model_object(obj, old_class, new_class): @@ -54,6 +52,8 @@ def convert_volume(volume) -> k8s.V1Volume: :param volume: :return: k8s.V1Volume """ + from airflow.providers.cncf.kubernetes.backcompat.volume import Volume + return _convert_kube_model_object(volume, Volume, k8s.V1Volume) @@ -64,6 +64,8 @@ def convert_volume_mount(volume_mount) -> k8s.V1VolumeMount: :param volume_mount: :return: k8s.V1VolumeMount """ + from airflow.providers.cncf.kubernetes.backcompat.volume_mount import VolumeMount + return _convert_kube_model_object(volume_mount, 
VolumeMount, k8s.V1VolumeMount) diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py b/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py index a72ff36f58c5f..e3e96fe3f823b 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_operations.py @@ -33,7 +33,6 @@ BigQueryDeleteTableOperator, BigQueryGetDatasetOperator, BigQueryGetDatasetTablesOperator, - BigQueryPatchDatasetOperator, BigQueryUpdateDatasetOperator, BigQueryUpdateTableOperator, BigQueryUpdateTableSchemaOperator, @@ -134,14 +133,26 @@ # [START howto_operator_bigquery_create_external_table] create_external_table = BigQueryCreateExternalTableOperator( task_id="create_external_table", + table_resource={ + "tableReference": { + "projectId": PROJECT_ID, + "datasetId": DATASET_NAME, + "tableId": "external_table", + }, + "schema": { + "fields": [ + {"name": "name", "type": "STRING"}, + {"name": "post_abbr", "type": "STRING"}, + ] + }, + "externalDataConfiguration": { + "sourceFormat": "CSV", + "compression": "NONE", + "csvOptions": {"skipLeadingRows": 1}, + }, + }, bucket=DATA_SAMPLE_GCS_BUCKET_NAME, source_objects=[DATA_SAMPLE_GCS_OBJECT_NAME], - destination_project_dataset_table=f"{DATASET_NAME}.external_table", - skip_leading_rows=1, - schema_fields=[ - {"name": "name", "type": "STRING"}, - {"name": "post_abbr", "type": "STRING"}, - ], ) # [END howto_operator_bigquery_create_external_table] @@ -191,17 +202,6 @@ ) # [END howto_operator_bigquery_update_table] - # [START howto_operator_bigquery_patch_dataset] - patch_dataset = BigQueryPatchDatasetOperator( - task_id="patch_dataset", - dataset_id=DATASET_NAME, - dataset_resource={ - "friendlyName": "Patched Dataset", - "description": "Patched dataset", - }, - ) - # [END howto_operator_bigquery_patch_dataset] - # [START howto_operator_bigquery_update_dataset] update_dataset = BigQueryUpdateDatasetOperator( task_id="update_dataset", @@ -216,7 +216,7 @@ ) # [END howto_operator_bigquery_delete_dataset] - create_dataset >> patch_dataset >> update_dataset >> get_dataset >> get_dataset_result >> delete_dataset + create_dataset >> update_dataset >> get_dataset >> get_dataset_result >> delete_dataset ( update_dataset diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_queries.py b/airflow/providers/google/cloud/example_dags/example_bigquery_queries.py index 24788c9e05174..06946fa974a3f 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_queries.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_queries.py @@ -29,7 +29,6 @@ BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, - BigQueryExecuteQueryOperator, BigQueryGetDataOperator, BigQueryInsertJobOperator, BigQueryIntervalCheckOperator, @@ -125,25 +124,40 @@ ) # [END howto_operator_bigquery_select_job] - execute_insert_query = BigQueryExecuteQueryOperator( - task_id="execute_insert_query", sql=INSERT_ROWS_QUERY, use_legacy_sql=False, location=location + execute_insert_query = BigQueryInsertJobOperator( + task_id="execute_insert_query", + configuration={ + "query": { + "query": INSERT_ROWS_QUERY, + "useLegacySql": False, + } + }, + location=location, ) - bigquery_execute_multi_query = BigQueryExecuteQueryOperator( + bigquery_execute_multi_query = BigQueryInsertJobOperator( task_id="execute_multi_query", - sql=[ - f"SELECT * FROM {DATASET_NAME}.{TABLE_2}", - f"SELECT COUNT(*) FROM 
{DATASET_NAME}.{TABLE_2}", - ], - use_legacy_sql=False, + configuration={ + "query": { + "query": [ + f"SELECT * FROM {DATASET_NAME}.{TABLE_2}", + f"SELECT COUNT(*) FROM {DATASET_NAME}.{TABLE_2}", + ], + "useLegacySql": False, + } + }, location=location, ) - execute_query_save = BigQueryExecuteQueryOperator( + execute_query_save = BigQueryInsertJobOperator( task_id="execute_query_save", - sql=f"SELECT * FROM {DATASET_NAME}.{TABLE_1}", - use_legacy_sql=False, - destination_dataset_table=f"{DATASET_NAME}.{TABLE_2}", + configuration={ + "query": { + "query": f"SELECT * FROM {DATASET_NAME}.{TABLE_1}", + "useLegacySql": False, + "destinationTable": f"{DATASET_NAME}.{TABLE_2}", + } + }, location=location, ) diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_sensors.py b/airflow/providers/google/cloud/example_dags/example_bigquery_sensors.py index 4423b1c53331f..cde6cd9dbaadd 100644 --- a/airflow/providers/google/cloud/example_dags/example_bigquery_sensors.py +++ b/airflow/providers/google/cloud/example_dags/example_bigquery_sensors.py @@ -27,7 +27,7 @@ BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, - BigQueryExecuteQueryOperator, + BigQueryInsertJobOperator, ) from airflow.providers.google.cloud.sensors.bigquery import ( BigQueryTableExistenceSensor, @@ -80,8 +80,14 @@ ) # [END howto_sensor_bigquery_table] - execute_insert_query = BigQueryExecuteQueryOperator( - task_id="execute_insert_query", sql=INSERT_ROWS_QUERY, use_legacy_sql=False + execute_insert_query = BigQueryInsertJobOperator( + task_id="execute_insert_query", + configuration={ + "query": { + "query": INSERT_ROWS_QUERY, + "useLegacySql": False, + } + }, ) # [START howto_sensor_bigquery_table_partition] diff --git a/airflow/providers/google/cloud/example_dags/example_datacatalog.py b/airflow/providers/google/cloud/example_dags/example_datacatalog.py index a764c11a76e72..3457805f2f094 100644 --- a/airflow/providers/google/cloud/example_dags/example_datacatalog.py +++ b/airflow/providers/google/cloud/example_dags/example_datacatalog.py @@ -22,6 +22,7 @@ from google.cloud.datacatalog_v1beta1 import FieldType, TagField, TagTemplateField from airflow import models +from airflow.models.baseoperator import chain from airflow.operators.bash import BashOperator from airflow.providers.google.cloud.operators.datacatalog import ( CloudDataCatalogCreateEntryGroupOperator, @@ -47,7 +48,6 @@ CloudDataCatalogUpdateTagTemplateOperator, ) from airflow.utils.dates import days_ago -from airflow.utils.helpers import chain PROJECT_ID = "polidea-airflow" LOCATION = "us-central1" diff --git a/airflow/providers/google/cloud/example_dags/example_dataflow.py b/airflow/providers/google/cloud/example_dags/example_dataflow.py index 6e58ff4d67ffe..1761cbaee6532 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataflow.py +++ b/airflow/providers/google/cloud/example_dags/example_dataflow.py @@ -25,11 +25,13 @@ from airflow import models from airflow.exceptions import AirflowException +from airflow.providers.apache.beam.operators.beam import ( + BeamRunJavaPipelineOperator, + BeamRunPythonPipelineOperator, +) from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus from airflow.providers.google.cloud.operators.dataflow import ( CheckJobRunning, - DataflowCreateJavaJobOperator, - DataflowCreatePythonJobOperator, DataflowTemplatedJobStartOperator, ) from airflow.providers.google.cloud.sensors.dataflow import ( @@ -66,17 +68,18 @@ ) as dag_native_java: # [START 
howto_operator_start_java_job_jar_on_gcs] - start_java_job = DataflowCreateJavaJobOperator( + start_java_job = BeamRunJavaPipelineOperator( task_id="start-java-job", jar=GCS_JAR, - job_name='{{task.task_id}}', - options={ + pipeline_options={ 'output': GCS_OUTPUT, }, - poll_sleep=10, job_class='org.apache.beam.examples.WordCount', - check_if_running=CheckJobRunning.IgnoreJob, - location='europe-west3', + dataflow_config={ + "check_if_running": CheckJobRunning.IgnoreJob, + "location": 'europe-west3', + "poll_sleep": 10, + }, ) # [END howto_operator_start_java_job_jar_on_gcs] @@ -88,16 +91,18 @@ filename="/tmp/dataflow-{{ ds_nodash }}.jar", ) - start_java_job_local = DataflowCreateJavaJobOperator( + start_java_job_local = BeamRunJavaPipelineOperator( task_id="start-java-job-local", jar="/tmp/dataflow-{{ ds_nodash }}.jar", - job_name='{{task.task_id}}', - options={ + pipeline_options={ 'output': GCS_OUTPUT, }, - poll_sleep=10, job_class='org.apache.beam.examples.WordCount', - check_if_running=CheckJobRunning.WaitForRun, + dataflow_config={ + "check_if_running": CheckJobRunning.WaitForRun, + "location": 'europe-west3', + "poll_sleep": 10, + }, ) jar_to_local >> start_java_job_local # [END howto_operator_start_java_job_local_jar] @@ -111,27 +116,25 @@ ) as dag_native_python: # [START howto_operator_start_python_job] - start_python_job = DataflowCreatePythonJobOperator( + start_python_job = BeamRunPythonPipelineOperator( task_id="start-python-job", py_file=GCS_PYTHON, py_options=[], - job_name='{{task.task_id}}', - options={ + pipeline_options={ 'output': GCS_OUTPUT, }, py_requirements=['apache-beam[gcp]==2.21.0'], py_interpreter='python3', py_system_site_packages=False, - location='europe-west3', + dataflow_config={'location': 'europe-west3'}, ) # [END howto_operator_start_python_job] - start_python_job_local = DataflowCreatePythonJobOperator( + start_python_job_local = BeamRunPythonPipelineOperator( task_id="start-python-job-local", py_file='apache_beam.examples.wordcount', py_options=['-m'], - job_name='{{task.task_id}}', - options={ + pipeline_options={ 'output': GCS_OUTPUT, }, py_requirements=['apache-beam[gcp]==2.14.0'], @@ -147,19 +150,17 @@ tags=['example'], ) as dag_native_python_async: # [START howto_operator_start_python_job_async] - start_python_job_async = DataflowCreatePythonJobOperator( + start_python_job_async = BeamRunPythonPipelineOperator( task_id="start-python-job-async", py_file=GCS_PYTHON, py_options=[], - job_name='{{task.task_id}}', - options={ + pipeline_options={ 'output': GCS_OUTPUT, }, py_requirements=['apache-beam[gcp]==2.25.0'], py_interpreter='python3', py_system_site_packages=False, - location='europe-west3', - wait_until_finished=False, + dataflow_config={"location": 'europe-west3', "wait_until_finished": False}, ) # [END howto_operator_start_python_job_async] diff --git a/airflow/providers/google/cloud/example_dags/example_dataproc.py b/airflow/providers/google/cloud/example_dags/example_dataproc.py index 914df0ef149fa..9694eb8c78d74 100644 --- a/airflow/providers/google/cloud/example_dags/example_dataproc.py +++ b/airflow/providers/google/cloud/example_dags/example_dataproc.py @@ -170,7 +170,7 @@ update_mask=UPDATE_MASK, graceful_decommission_timeout=TIMEOUT, project_id=PROJECT_ID, - location=REGION, + region=REGION, ) # [END how_to_cloud_dataproc_update_cluster_operator] @@ -179,7 +179,7 @@ task_id="create_workflow_template", template=WORKFLOW_TEMPLATE, project_id=PROJECT_ID, - location=REGION, + region=REGION, ) # [END 
how_to_cloud_dataproc_create_workflow_template] @@ -190,24 +190,24 @@ # [END how_to_cloud_dataproc_trigger_workflow_template] pig_task = DataprocSubmitJobOperator( - task_id="pig_task", job=PIG_JOB, location=REGION, project_id=PROJECT_ID + task_id="pig_task", job=PIG_JOB, region=REGION, project_id=PROJECT_ID ) spark_sql_task = DataprocSubmitJobOperator( - task_id="spark_sql_task", job=SPARK_SQL_JOB, location=REGION, project_id=PROJECT_ID + task_id="spark_sql_task", job=SPARK_SQL_JOB, region=REGION, project_id=PROJECT_ID ) spark_task = DataprocSubmitJobOperator( - task_id="spark_task", job=SPARK_JOB, location=REGION, project_id=PROJECT_ID + task_id="spark_task", job=SPARK_JOB, region=REGION, project_id=PROJECT_ID ) # [START cloud_dataproc_async_submit_sensor] spark_task_async = DataprocSubmitJobOperator( - task_id="spark_task_async", job=SPARK_JOB, location=REGION, project_id=PROJECT_ID, asynchronous=True + task_id="spark_task_async", job=SPARK_JOB, region=REGION, project_id=PROJECT_ID, asynchronous=True ) spark_task_async_sensor = DataprocJobSensor( task_id='spark_task_async_sensor_task', - location=REGION, + region=REGION, project_id=PROJECT_ID, dataproc_job_id=spark_task_async.output, poke_interval=10, @@ -216,20 +216,20 @@ # [START how_to_cloud_dataproc_submit_job_to_cluster_operator] pyspark_task = DataprocSubmitJobOperator( - task_id="pyspark_task", job=PYSPARK_JOB, location=REGION, project_id=PROJECT_ID + task_id="pyspark_task", job=PYSPARK_JOB, region=REGION, project_id=PROJECT_ID ) # [END how_to_cloud_dataproc_submit_job_to_cluster_operator] sparkr_task = DataprocSubmitJobOperator( - task_id="sparkr_task", job=SPARKR_JOB, location=REGION, project_id=PROJECT_ID + task_id="sparkr_task", job=SPARKR_JOB, region=REGION, project_id=PROJECT_ID ) hive_task = DataprocSubmitJobOperator( - task_id="hive_task", job=HIVE_JOB, location=REGION, project_id=PROJECT_ID + task_id="hive_task", job=HIVE_JOB, region=REGION, project_id=PROJECT_ID ) hadoop_task = DataprocSubmitJobOperator( - task_id="hadoop_task", job=HADOOP_JOB, location=REGION, project_id=PROJECT_ID + task_id="hadoop_task", job=HADOOP_JOB, region=REGION, project_id=PROJECT_ID ) # [START how_to_cloud_dataproc_delete_cluster_operator] diff --git a/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py index e5d67f04cf19e..920935da1f23d 100644 --- a/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_facebook_ads_to_gcs.py @@ -27,7 +27,7 @@ BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, - BigQueryExecuteQueryOperator, + BigQueryInsertJobOperator, ) from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.transfers.facebook_ads_to_gcs import FacebookAdsReportToGcsOperator @@ -105,10 +105,14 @@ write_disposition='WRITE_TRUNCATE', ) - read_data_from_gcs_many_chunks = BigQueryExecuteQueryOperator( + read_data_from_gcs_many_chunks = BigQueryInsertJobOperator( task_id="read_data_from_gcs_many_chunks", - sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{TABLE_NAME}`", - use_legacy_sql=False, + configuration={ + "query": { + "query": f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{TABLE_NAME}`", + "useLegacySql": False, + } + }, ) delete_bucket = GCSDeleteBucketOperator( diff --git 
a/airflow/providers/google/cloud/example_dags/example_gdrive_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_gdrive_to_gcs.py index 3a604681901f0..974fa66d0c77f 100644 --- a/airflow/providers/google/cloud/example_dags/example_gdrive_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_gdrive_to_gcs.py @@ -44,8 +44,8 @@ task_id="upload_gdrive_object_to_gcs", folder_id=FOLDER_ID, file_name=FILE_NAME, - destination_bucket=BUCKET, - destination_object=OBJECT, + bucket_name=BUCKET, + object_name=OBJECT, ) # [END upload_gdrive_to_gcs] detect_file >> upload_gdrive_to_gcs diff --git a/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py index 759c429e65e7a..cf82a4800cb39 100644 --- a/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_presto_to_gcs.py @@ -26,7 +26,7 @@ BigQueryCreateEmptyDatasetOperator, BigQueryCreateExternalTableOperator, BigQueryDeleteDatasetOperator, - BigQueryExecuteQueryOperator, + BigQueryInsertJobOperator, ) from airflow.providers.google.cloud.transfers.presto_to_gcs import PrestoToGCSOperator from airflow.utils.dates import days_ago @@ -84,16 +84,37 @@ def safe_name(s: str) -> str: task_id="create_external_table_multiple_types", bucket=GCS_BUCKET, source_objects=[f"{safe_name(SOURCE_MULTIPLE_TYPES)}.*.json"], - source_format="NEWLINE_DELIMITED_JSON", - destination_project_dataset_table=f"{DATASET_NAME}.{safe_name(SOURCE_MULTIPLE_TYPES)}", + table_resource={ + "tableReference": { + "projectId": GCP_PROJECT_ID, + "datasetId": DATASET_NAME, + "tableId": f"{safe_name(SOURCE_MULTIPLE_TYPES)}", + }, + "schema": { + "fields": [ + {"name": "name", "type": "STRING"}, + {"name": "post_abbr", "type": "STRING"}, + ] + }, + "externalDataConfiguration": { + "sourceFormat": "NEWLINE_DELIMITED_JSON", + "compression": "NONE", + "csvOptions": {"skipLeadingRows": 1}, + }, + }, schema_object=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", ) # [END howto_operator_create_external_table_multiple_types] - read_data_from_gcs_multiple_types = BigQueryExecuteQueryOperator( + read_data_from_gcs_multiple_types = BigQueryInsertJobOperator( task_id="read_data_from_gcs_multiple_types", - sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{safe_name(SOURCE_MULTIPLE_TYPES)}`", - use_legacy_sql=False, + configuration={ + "query": { + "query": f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}." 
+ f"{safe_name(SOURCE_MULTIPLE_TYPES)}`", + "useLegacySql": False, + } + }, ) # [START howto_operator_presto_to_gcs_many_chunks] @@ -111,17 +132,38 @@ def safe_name(s: str) -> str: create_external_table_many_chunks = BigQueryCreateExternalTableOperator( task_id="create_external_table_many_chunks", bucket=GCS_BUCKET, + table_resource={ + "tableReference": { + "projectId": GCP_PROJECT_ID, + "datasetId": DATASET_NAME, + "tableId": f"{safe_name(SOURCE_CUSTOMER_TABLE)}", + }, + "schema": { + "fields": [ + {"name": "name", "type": "STRING"}, + {"name": "post_abbr", "type": "STRING"}, + ] + }, + "externalDataConfiguration": { + "sourceFormat": "NEWLINE_DELIMITED_JSON", + "compression": "NONE", + "csvOptions": {"skipLeadingRows": 1}, + }, + }, source_objects=[f"{safe_name(SOURCE_CUSTOMER_TABLE)}.*.json"], - source_format="NEWLINE_DELIMITED_JSON", - destination_project_dataset_table=f"{DATASET_NAME}.{safe_name(SOURCE_CUSTOMER_TABLE)}", schema_object=f"{safe_name(SOURCE_CUSTOMER_TABLE)}-schema.json", ) # [START howto_operator_read_data_from_gcs_many_chunks] - read_data_from_gcs_many_chunks = BigQueryExecuteQueryOperator( + read_data_from_gcs_many_chunks = BigQueryInsertJobOperator( task_id="read_data_from_gcs_many_chunks", - sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{safe_name(SOURCE_CUSTOMER_TABLE)}`", - use_legacy_sql=False, + configuration={ + "query": { + "query": f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}." + f"{safe_name(SOURCE_CUSTOMER_TABLE)}`", + "useLegacySql": False, + } + }, ) # [END howto_operator_read_data_from_gcs_many_chunks] diff --git a/airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py index a49b267b34a4d..be28864084e25 100644 --- a/airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_salesforce_to_gcs.py @@ -25,7 +25,7 @@ BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, - BigQueryExecuteQueryOperator, + BigQueryInsertJobOperator, ) from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator @@ -98,10 +98,14 @@ write_disposition='WRITE_TRUNCATE', ) - read_data_from_gcs = BigQueryExecuteQueryOperator( + read_data_from_gcs = BigQueryInsertJobOperator( task_id="read_data_from_gcs", - sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{TABLE_NAME}`", - use_legacy_sql=False, + configuration={ + "query": { + "query": f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{TABLE_NAME}`", + "useLegacySql": False, + } + }, ) delete_bucket = GCSDeleteBucketOperator( diff --git a/airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py index 209c51e2c2d9c..0d8ef0296af9b 100644 --- a/airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py +++ b/airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py @@ -26,7 +26,7 @@ BigQueryCreateEmptyDatasetOperator, BigQueryCreateExternalTableOperator, BigQueryDeleteDatasetOperator, - BigQueryExecuteQueryOperator, + BigQueryInsertJobOperator, ) from airflow.providers.google.cloud.transfers.trino_to_gcs import TrinoToGCSOperator from airflow.utils.dates import days_ago @@ -83,17 +83,38 @@ def safe_name(s: str) -> str: create_external_table_multiple_types = 
BigQueryCreateExternalTableOperator( task_id="create_external_table_multiple_types", bucket=GCS_BUCKET, + table_resource={ + "tableReference": { + "projectId": GCP_PROJECT_ID, + "datasetId": DATASET_NAME, + "tableId": f"{safe_name(SOURCE_MULTIPLE_TYPES)}", + }, + "schema": { + "fields": [ + {"name": "name", "type": "STRING"}, + {"name": "post_abbr", "type": "STRING"}, + ] + }, + "externalDataConfiguration": { + "sourceFormat": "NEWLINE_DELIMITED_JSON", + "compression": "NONE", + "csvOptions": {"skipLeadingRows": 1}, + }, + }, source_objects=[f"{safe_name(SOURCE_MULTIPLE_TYPES)}.*.json"], - source_format="NEWLINE_DELIMITED_JSON", - destination_project_dataset_table=f"{DATASET_NAME}.{safe_name(SOURCE_MULTIPLE_TYPES)}", schema_object=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", ) # [END howto_operator_create_external_table_multiple_types] - read_data_from_gcs_multiple_types = BigQueryExecuteQueryOperator( + read_data_from_gcs_multiple_types = BigQueryInsertJobOperator( task_id="read_data_from_gcs_multiple_types", - sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{safe_name(SOURCE_MULTIPLE_TYPES)}`", - use_legacy_sql=False, + configuration={ + "query": { + "query": f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}." + f"{safe_name(SOURCE_MULTIPLE_TYPES)}`", + "useLegacySql": False, + } + }, ) # [START howto_operator_trino_to_gcs_many_chunks] @@ -111,17 +132,38 @@ def safe_name(s: str) -> str: create_external_table_many_chunks = BigQueryCreateExternalTableOperator( task_id="create_external_table_many_chunks", bucket=GCS_BUCKET, + table_resource={ + "tableReference": { + "projectId": GCP_PROJECT_ID, + "datasetId": DATASET_NAME, + "tableId": f"{safe_name(SOURCE_CUSTOMER_TABLE)}", + }, + "schema": { + "fields": [ + {"name": "name", "type": "STRING"}, + {"name": "post_abbr", "type": "STRING"}, + ] + }, + "externalDataConfiguration": { + "sourceFormat": "NEWLINE_DELIMITED_JSON", + "compression": "NONE", + "csvOptions": {"skipLeadingRows": 1}, + }, + }, source_objects=[f"{safe_name(SOURCE_CUSTOMER_TABLE)}.*.json"], - source_format="NEWLINE_DELIMITED_JSON", - destination_project_dataset_table=f"{DATASET_NAME}.{safe_name(SOURCE_CUSTOMER_TABLE)}", schema_object=f"{safe_name(SOURCE_CUSTOMER_TABLE)}-schema.json", ) # [START howto_operator_read_data_from_gcs_many_chunks] - read_data_from_gcs_many_chunks = BigQueryExecuteQueryOperator( + read_data_from_gcs_many_chunks = BigQueryInsertJobOperator( task_id="read_data_from_gcs_many_chunks", - sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{safe_name(SOURCE_CUSTOMER_TABLE)}`", - use_legacy_sql=False, + configuration={ + "query": { + "query": f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}." + f"{safe_name(SOURCE_CUSTOMER_TABLE)}`", + "useLegacySql": False, + } + }, ) # [END howto_operator_read_data_from_gcs_many_chunks] diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py index 0f138d723c48c..e353ef7f83ae5 100644 --- a/airflow/providers/google/cloud/hooks/dataproc.py +++ b/airflow/providers/google/cloud/hooks/dataproc.py @@ -216,7 +216,7 @@ def get_cluster_client( "Parameter `location` will be deprecated. " "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location client_options = None @@ -236,7 +236,7 @@ def get_template_client( "Parameter `location` will be deprecated. 
" "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location client_options = None @@ -256,7 +256,7 @@ def get_job_client( "Parameter `location` will be deprecated. " "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location client_options = None @@ -587,7 +587,7 @@ def update_cluster( "Parameter `location` will be deprecated. " "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location else: @@ -647,7 +647,7 @@ def create_workflow_template( "Parameter `location` will be deprecated. " "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location else: @@ -712,7 +712,7 @@ def instantiate_workflow_template( "Parameter `location` will be deprecated. " "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location else: @@ -771,7 +771,7 @@ def instantiate_inline_workflow_template( "Parameter `location` will be deprecated. " "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location else: @@ -819,7 +819,7 @@ def wait_for_job( "Parameter `location` will be deprecated. " "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location else: @@ -878,7 +878,7 @@ def get_job( "Parameter `location` will be deprecated. " "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location else: @@ -935,7 +935,7 @@ def submit_job( "Parameter `location` will be deprecated. " "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location else: @@ -1010,7 +1010,7 @@ def cancel_job( "Parameter `location` will be deprecated. " "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 4a0f6c007adc1..fe853c22aca35 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -489,6 +489,9 @@ class BigQueryExecuteQueryOperator(BaseOperator): Executes BigQuery SQL queries in a specific BigQuery database. This operator does not assert idempotency. + This operator is deprecated. + Please use :class:`airflow.providers.google.cloud.operators.bigquery.BigQueryInsertJobOperator` + :param sql: the sql code to be executed (templated) :type sql: Can receive a str representing a sql statement, a list of str (sql statements), or reference to a template file. 
@@ -1084,14 +1087,14 @@ def __init__(
         *,
         bucket: str,
         source_objects: List,
-        destination_project_dataset_table: str,
+        destination_project_dataset_table: str = None,
         table_resource: Optional[Dict[str, Any]] = None,
         schema_fields: Optional[List] = None,
         schema_object: Optional[str] = None,
-        source_format: str = 'CSV',
-        compression: str = 'NONE',
-        skip_leading_rows: int = 0,
-        field_delimiter: str = ',',
+        source_format: Optional[str] = None,
+        compression: Optional[str] = None,
+        skip_leading_rows: Optional[int] = None,
+        field_delimiter: Optional[str] = None,
         max_bad_records: int = 0,
         quote_character: Optional[str] = None,
         allow_quoted_newlines: bool = False,
@@ -1140,6 +1143,14 @@ def __init__(
                 DeprecationWarning,
                 stacklevel=2,
             )
+            if not source_format:
+                source_format = 'CSV'
+            if not compression:
+                compression = 'NONE'
+            if not skip_leading_rows:
+                skip_leading_rows = 0
+            if not field_delimiter:
+                field_delimiter = ","

         if table_resource and kwargs_passed:
             raise ValueError("You provided both `table_resource` and exclusive keywords arguments.")
@@ -1579,9 +1590,8 @@ class BigQueryPatchDatasetOperator(BaseOperator):
     This operator is used to patch dataset for your Project in BigQuery.
     It only replaces fields that are provided in the submitted dataset resource.

-    .. seealso::
-        For more information on how to use this operator, take a look at the guide:
-        :ref:`howto/operator:BigQueryPatchDatasetOperator`
+    This operator is deprecated.
+    Please use :class:`airflow.providers.google.cloud.operators.bigquery.BigQueryUpdateDatasetOperator`

     :param dataset_id: The id of dataset. Don't need to provide,
         if datasetId in dataset_reference.
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 7988aa52ce9a4..96d3cedfdc1f5 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -1673,7 +1673,7 @@ def __init__(
                     "Parameter `location` will be deprecated. "
                     "Please provide value through `region` parameter instead.",
                     DeprecationWarning,
-                    stacklevel=1,
+                    stacklevel=2,
                 )
                 region = location
             else:
@@ -1974,7 +1974,7 @@ def __init__(
                     "Parameter `location` will be deprecated. "
                     "Please provide value through `region` parameter instead.",
                     DeprecationWarning,
-                    stacklevel=1,
+                    stacklevel=2,
                 )
                 region = location
             else:
@@ -2116,7 +2116,7 @@ def __init__(
                     "Parameter `location` will be deprecated. "
                     "Please provide value through `region` parameter instead.",
                     DeprecationWarning,
-                    stacklevel=1,
+                    stacklevel=2,
                 )
                 region = location
             else:
diff --git a/airflow/providers/google/cloud/sensors/dataproc.py b/airflow/providers/google/cloud/sensors/dataproc.py
index 68b4c701ca6ff..2bcfbe138a7be 100644
--- a/airflow/providers/google/cloud/sensors/dataproc.py
+++ b/airflow/providers/google/cloud/sensors/dataproc.py
@@ -63,7 +63,7 @@ def __init__(
                "Parameter `location` will be deprecated. 
" "Please provide value through `region` parameter instead.", DeprecationWarning, - stacklevel=1, + stacklevel=2, ) region = location else: diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py index 553a8bfe8812c..a384a0f84064b 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py @@ -23,7 +23,6 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook -from airflow.utils.decorators import apply_defaults class BigQueryToMsSqlOperator(BaseOperator): @@ -89,7 +88,6 @@ class BigQueryToMsSqlOperator(BaseOperator): 'impersonation_chain', ) - @apply_defaults def __init__( self, *, diff --git a/airflow/providers/google/cloud/utils/mlengine_operator_utils.py b/airflow/providers/google/cloud/utils/mlengine_operator_utils.py index 5551d8fd4bb2b..90c05bb14aeb7 100644 --- a/airflow/providers/google/cloud/utils/mlengine_operator_utils.py +++ b/airflow/providers/google/cloud/utils/mlengine_operator_utils.py @@ -30,8 +30,8 @@ from airflow import DAG from airflow.exceptions import AirflowException from airflow.operators.python import PythonOperator +from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook -from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator from airflow.providers.google.cloud.operators.mlengine import MLEngineStartBatchPredictionJobOperator T = TypeVar("T", bound=Callable) @@ -242,11 +242,11 @@ def validate_err_and_count(summary): ) metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True)).decode() - evaluate_summary = DataflowCreatePythonJobOperator( + evaluate_summary = BeamRunPythonPipelineOperator( task_id=(task_prefix + "-summary"), py_file=os.path.join(os.path.dirname(__file__), 'mlengine_prediction_summary.py'), - dataflow_default_options=dataflow_options, - options={ + default_pipeline_options=dataflow_options, + pipeline_options={ "prediction_path": prediction_path, "metric_fn_encoded": metric_fn_encoded, "metric_keys": ','.join(metric_keys), diff --git a/airflow/providers/google/firebase/example_dags/example_firestore.py b/airflow/providers/google/firebase/example_dags/example_firestore.py index df0bb3a0b50ea..041b266222b83 100644 --- a/airflow/providers/google/firebase/example_dags/example_firestore.py +++ b/airflow/providers/google/firebase/example_dags/example_firestore.py @@ -51,7 +51,7 @@ BigQueryCreateEmptyDatasetOperator, BigQueryCreateExternalTableOperator, BigQueryDeleteDatasetOperator, - BigQueryExecuteQueryOperator, + BigQueryInsertJobOperator, ) from airflow.providers.google.firebase.operators.firestore import CloudFirestoreExportDatabaseOperator from airflow.utils import dates @@ -99,19 +99,39 @@ create_external_table_multiple_types = BigQueryCreateExternalTableOperator( task_id="create_external_table", bucket=BUCKET_NAME, + table_resource={ + "tableReference": { + "projectId": GCP_PROJECT_ID, + "datasetId": DATASET_NAME, + "tableId": "firestore_data", + }, + "schema": { + "fields": [ + {"name": "name", "type": "STRING"}, + {"name": "post_abbr", "type": "STRING"}, + ] + }, + "externalDataConfiguration": { + "sourceFormat": "DATASTORE_BACKUP", + "compression": "NONE", + "csvOptions": {"skipLeadingRows": 1}, + }, + }, 
source_objects=[ f"{EXPORT_PREFIX}/all_namespaces/kind_{EXPORT_COLLECTION_ID}" f"/all_namespaces_kind_{EXPORT_COLLECTION_ID}.export_metadata" ], - source_format="DATASTORE_BACKUP", - destination_project_dataset_table=f"{GCP_PROJECT_ID}.{DATASET_NAME}.firestore_data", ) # [END howto_operator_create_external_table_multiple_types] - read_data_from_gcs_multiple_types = BigQueryExecuteQueryOperator( + read_data_from_gcs_multiple_types = BigQueryInsertJobOperator( task_id="execute_query", - sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.firestore_data`", - use_legacy_sql=False, + configuration={ + "query": { + "query": f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.firestore_data`", + "useLegacySql": False, + } + }, ) # Firestore diff --git a/airflow/providers/singularity/example_dags/example_singularity.py b/airflow/providers/singularity/example_dags/example_singularity.py index b37cab72261d4..83c9a7b787a5c 100644 --- a/airflow/providers/singularity/example_dags/example_singularity.py +++ b/airflow/providers/singularity/example_dags/example_singularity.py @@ -19,7 +19,7 @@ from datetime import timedelta from airflow import DAG -from airflow.operators.bash_operator import BashOperator +from airflow.operators.bash import BashOperator from airflow.providers.singularity.operators.singularity import SingularityOperator from airflow.utils.dates import days_ago diff --git a/dev/import_all_classes.py b/dev/import_all_classes.py index a1611f3e0cc7f..67a76c840a1bf 100755 --- a/dev/import_all_classes.py +++ b/dev/import_all_classes.py @@ -20,8 +20,10 @@ import pkgutil import sys import traceback +import warnings from inspect import isclass -from typing import List, Set +from typing import List, Set, Tuple +from warnings import WarningMessage from rich import print @@ -32,7 +34,7 @@ def import_all_classes( provider_ids: List[str] = None, print_imports: bool = False, print_skips: bool = False, -) -> List[str]: +) -> Tuple[List[str], List[WarningMessage]]: """ Imports all classes in providers packages. This method loads and imports all the classes found in providers, so that we can find all the subclasses @@ -43,7 +45,7 @@ def import_all_classes( :param provider_ids - provider ids that should be loaded. :param print_imports - if imported class should also be printed in output :param print_skips - if skipped classes should also be printed in output - :return: list of all imported classes + :return: tupple of list of all imported classes and all warnings generated """ imported_classes = [] tracebacks = [] @@ -63,6 +65,7 @@ def onerror(_): if any(provider_prefix in exception_string for provider_prefix in provider_prefixes): tracebacks.append(exception_string) + all_warnings: List[WarningMessage] = [] for modinfo in pkgutil.walk_packages(path=paths, prefix=prefix, onerror=onerror): if not any(modinfo.name.startswith(provider_prefix) for provider_prefix in provider_prefixes): if print_skips: @@ -74,12 +77,16 @@ def onerror(_): printed_packages.add(package_to_print) print(f"Importing package: {package_to_print}") try: - _module = importlib.import_module(modinfo.name) - for attribute_name in dir(_module): - class_name = modinfo.name + "." 
+ attribute_name - attribute = getattr(_module, attribute_name) - if isclass(attribute): - imported_classes.append(class_name) + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings("always", category=DeprecationWarning) + _module = importlib.import_module(modinfo.name) + for attribute_name in dir(_module): + class_name = modinfo.name + "." + attribute_name + attribute = getattr(_module, attribute_name) + if isclass(attribute): + imported_classes.append(class_name) + if w: + all_warnings.extend(w) except Exception: exception_str = traceback.format_exc() tracebacks.append(exception_str) @@ -96,7 +103,7 @@ def onerror(_): print("[red]----------------------------------------[/]", file=sys.stderr) sys.exit(1) else: - return imported_classes + return imported_classes, all_warnings if __name__ == '__main__': @@ -109,10 +116,18 @@ def onerror(_): print() print(f"Walking all packages in {args.path} with prefix {args.prefix}") print() - classes = import_all_classes(print_imports=True, print_skips=True, paths=args.path, prefix=args.prefix) + classes, warns = import_all_classes( + print_imports=True, print_skips=True, paths=args.path, prefix=args.prefix + ) if len(classes) == 0: print("[red]Something is seriously wrong - no classes imported[/]") sys.exit(1) + if warns: + print("[yellow]There were warnings generated during the import[/]") + for w in warns: + one_line_message = str(w.message).replace('\n', ' ') + print(f"[yellow]{w.filename}:{w.lineno}: {one_line_message}[/]") + print() print(f"[green]SUCCESS: All provider packages are importable! Imported {len(classes)} classes.[/]") print() diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py index 23c51c74b2f01..b45ee3c5f42cb 100755 --- a/dev/provider_packages/prepare_provider_packages.py +++ b/dev/provider_packages/prepare_provider_packages.py @@ -31,6 +31,7 @@ import sys import tempfile import textwrap +import warnings from contextlib import contextmanager from copy import deepcopy from datetime import datetime, timedelta @@ -43,20 +44,15 @@ import click import jsonschema -import yaml from github import Github, PullRequest, UnknownObjectException from packaging.version import Version -from rich import print from rich.console import Console from rich.progress import Progress from rich.syntax import Syntax -ALL_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] +from airflow.utils.yaml import safe_load -try: - from yaml import CSafeLoader as SafeLoader -except ImportError: - from yaml import SafeLoader +ALL_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] INITIAL_CHANGELOG_CONTENT = """ @@ -118,6 +114,8 @@ PY3 = sys.version_info[0] == 3 +console = Console(width=400, color_system="standard") + @click.group(context_settings={'help_option_names': ['-h', '--help'], 'max_content_width': 500}) def cli(): @@ -172,12 +170,12 @@ def with_group(title): https://docs.github.com/en/free-pro-team@latest/actions/reference/workflow-commands-for-github-actions#grouping-log-lines """ if os.environ.get('GITHUB_ACTIONS', 'false') != "true": - print("[blue]" + "#" * 10 + ' ' + title + ' ' + "#" * 10 + "[/]") + console.print("[blue]" + "#" * 10 + ' ' + title + ' ' + "#" * 10 + "[/]") yield return - print(f"::group::{title}") + console.print(f"::group::{title}") yield - print("::endgroup::") + console.print("::endgroup::") class EntityType(Enum): @@ -189,11 +187,8 @@ class EntityType(Enum): class EntityTypeSummary(NamedTuple): - entities: Set[str] - new_entities: List[str] - moved_entities: 
Dict[str, str] + entities: List[str] new_entities_table: str - moved_entities_table: str wrong_entities: List[Tuple[type, str]] @@ -221,12 +216,12 @@ class ProviderPackageDetails(NamedTuple): EntityType.Secrets: "Secrets", } -TOTALS: Dict[EntityType, List[int]] = { - EntityType.Operators: [0, 0], - EntityType.Hooks: [0, 0], - EntityType.Sensors: [0, 0], - EntityType.Transfers: [0, 0], - EntityType.Secrets: [0, 0], +TOTALS: Dict[EntityType, int] = { + EntityType.Operators: 0, + EntityType.Hooks: 0, + EntityType.Sensors: 0, + EntityType.Transfers: 0, + EntityType.Secrets: 0, } OPERATORS_PATTERN = r".*Operator$" @@ -301,14 +296,6 @@ def get_target_providers_package_folder(provider_package_id: str) -> str: DEPENDENCIES_JSON_FILE = os.path.join(PROVIDERS_PATH, "dependencies.json") -MOVED_ENTITIES: Dict[EntityType, Dict[str, str]] = { - EntityType.Operators: {value[0]: value[1] for value in tests.deprecated_classes.OPERATORS}, - EntityType.Sensors: {value[0]: value[1] for value in tests.deprecated_classes.SENSORS}, - EntityType.Hooks: {value[0]: value[1] for value in tests.deprecated_classes.HOOKS}, - EntityType.Secrets: {value[0]: value[1] for value in tests.deprecated_classes.SECRETS}, - EntityType.Transfers: {value[0]: value[1] for value in tests.deprecated_classes.TRANSFERS}, -} - def get_pip_package_name(provider_package_id: str) -> str: """ @@ -565,49 +552,19 @@ def find_all_entities( return VerifiedEntities(all_entities=found_entities, wrong_entities=wrong_entities) -def convert_new_classes_to_table( - entity_type: EntityType, new_entities: List[str], full_package_name: str -) -> str: +def convert_classes_to_table(entity_type: EntityType, entities: List[str], full_package_name: str) -> str: """ Converts new entities tp a markdown table. - :param entity_type: list of entities to convert to markup - :param new_entities: list of new entities + :param entity_type: entity type to convert to markup + :param entities: list of entities :param full_package_name: name of the provider package :return: table of new classes """ from tabulate import tabulate headers = [f"New Airflow 2.0 {entity_type.value.lower()}: `{full_package_name}` package"] - table = [(get_class_code_link(full_package_name, class_name, "main"),) for class_name in new_entities] - return tabulate(table, headers=headers, tablefmt="pipe") - - -def convert_moved_classes_to_table( - entity_type: EntityType, - moved_entities: Dict[str, str], - full_package_name: str, -) -> str: - """ - Converts moved entities to a markdown table - :param entity_type: type of entities -> operators, sensors etc. 
- :param moved_entities: dictionary of moved entities `to -> from` - :param full_package_name: name of the provider package - :return: table of moved classes - """ - from tabulate import tabulate - - headers = [ - f"Airflow 2.0 {entity_type.value.lower()}: `{full_package_name}` package", - "Airflow 1.10.* previous location (usually `airflow.contrib`)", - ] - table = [ - ( - get_class_code_link(full_package_name, to_class, "main"), - get_class_code_link("airflow", moved_entities[to_class], "v1-10-stable"), - ) - for to_class in sorted(moved_entities.keys()) - ] + table = [(get_class_code_link(full_package_name, class_name, "main"),) for class_name in entities] return tabulate(table, headers=headers, tablefmt="pipe") @@ -618,8 +575,7 @@ def get_details_about_classes( full_package_name: str, ) -> EntityTypeSummary: """ - Splits the set of entities into new and moved, depending on their presence in the dict of objects - retrieved from the test_contrib_to_core. Updates all_entities with the split class. + Get details about entities.. :param entity_type: type of entity (Operators, Hooks etc.) :param entities: set of entities found @@ -627,30 +583,14 @@ def get_details_about_classes( :param full_package_name: full package name :return: """ - dict_of_moved_classes = MOVED_ENTITIES[entity_type] - new_entities = [] - moved_entities = {} - for obj in entities: - if obj in dict_of_moved_classes: - moved_entities[obj] = dict_of_moved_classes[obj] - del dict_of_moved_classes[obj] - else: - new_entities.append(obj) - new_entities.sort() - TOTALS[entity_type][0] += len(new_entities) - TOTALS[entity_type][1] += len(moved_entities) + all_entities = list(entities) + all_entities.sort() + TOTALS[entity_type] += len(all_entities) return EntityTypeSummary( - entities=entities, - new_entities=new_entities, - moved_entities=moved_entities, - new_entities_table=convert_new_classes_to_table( - entity_type=entity_type, - new_entities=new_entities, - full_package_name=full_package_name, - ), - moved_entities_table=convert_moved_classes_to_table( + entities=all_entities, + new_entities_table=convert_classes_to_table( entity_type=entity_type, - moved_entities=moved_entities, + entities=all_entities, full_package_name=full_package_name, ), wrong_entities=wrong_entities, @@ -701,9 +641,9 @@ def print_wrong_naming(entity_type: EntityType, wrong_classes: List[Tuple[type, :param wrong_classes: list of wrong entities """ if wrong_classes: - print(f"\n[red]There are wrongly named entities of type {entity_type}:[/]\n", file=sys.stderr) + console.print(f"\n[red]There are wrongly named entities of type {entity_type}:[/]\n") for wrong_entity_type, message in wrong_classes: - print(f"{wrong_entity_type}: {message}", file=sys.stderr) + console.print(f"{wrong_entity_type}: {message}") def get_package_class_summary( @@ -1050,10 +990,9 @@ def check_if_release_version_ok( current_release_version = (datetime.today() + timedelta(days=5)).strftime('%Y.%m.%d') if previous_release_version: if Version(current_release_version) < Version(previous_release_version): - print( + console.print( f"[red]The release {current_release_version} must be not less than " - f"{previous_release_version} - last release for the package[/]", - file=sys.stderr, + f"{previous_release_version} - last release for the package[/]" ) raise Exception("Bad release version") return current_release_version, previous_release_version @@ -1088,7 +1027,7 @@ def make_sure_remote_apache_exists_and_fetch(git_update: bool, verbose: bool): try: check_remote_command = ["git", 
"remote", "get-url", HTTPS_REMOTE] if verbose: - print(f"Running command: '{' '.join(check_remote_command)}'") + console.print(f"Running command: '{' '.join(check_remote_command)}'") subprocess.check_call( check_remote_command, stdout=subprocess.DEVNULL, @@ -1108,19 +1047,19 @@ def make_sure_remote_apache_exists_and_fetch(git_update: bool, verbose: bool): "https://github.com/apache/airflow.git", ] if verbose: - print(f"Running command: '{' '.join(remote_add_command)}'") + console.print(f"Running command: '{' '.join(remote_add_command)}'") try: subprocess.check_output( remote_add_command, stderr=subprocess.STDOUT, ) except subprocess.CalledProcessError as ex: - print("[red]Error: when adding remote:[/]", ex) + console.print("[red]Error: when adding remote:[/]", ex) else: raise if verbose: - print("Fetching full history and tags from remote. ") - print("This might override your local tags!") + console.print("Fetching full history and tags from remote. ") + console.print("This might override your local tags!") is_shallow_repo = ( subprocess.check_output(["git", "rev-parse", "--is-shallow-repository"], stderr=subprocess.DEVNULL) == 'true' @@ -1128,13 +1067,13 @@ def make_sure_remote_apache_exists_and_fetch(git_update: bool, verbose: bool): fetch_command = ["git", "fetch", "--tags", "--force", HTTPS_REMOTE] if is_shallow_repo: if verbose: - print( + console.print( "This will also unshallow the repository, " "making all history available and increasing storage!" ) fetch_command.append("--unshallow") if verbose: - print(f"Running command: '{' '.join(fetch_command)}'") + console.print(f"Running command: '{' '.join(fetch_command)}'") subprocess.check_call( fetch_command, stderr=subprocess.DEVNULL, @@ -1163,7 +1102,7 @@ def get_git_log_command( git_cmd.append(from_commit) git_cmd.extend(['--', '.']) if verbose: - print(f"Command to run: '{' '.join(git_cmd)}'") + console.print(f"Command to run: '{' '.join(git_cmd)}'") return git_cmd @@ -1261,13 +1200,13 @@ def check_if_classes_are_properly_named( _, class_name = class_full_name.rsplit(".", maxsplit=1) error_encountered = False if not is_camel_case_with_acronyms(class_name): - print( + console.print( f"[red]The class {class_full_name} is wrongly named. The " f"class name should be CamelCaseWithACRONYMS ![/]" ) error_encountered = True if not class_name.endswith(class_suffix): - print( + console.print( f"[red]The class {class_full_name} is wrongly named. It is one of the {entity_type.value}" f" so it should end with {class_suffix}[/]" ) @@ -1296,7 +1235,7 @@ def validate_provider_info_with_runtime_schema(provider_info: Dict[str, Any]) -> try: jsonschema.validate(provider_info, schema=schema) except jsonschema.ValidationError as ex: - print("[red]Provider info not validated against runtime schema[/]") + console.print("[red]Provider info not validated against runtime schema[/]") raise Exception( "Error when validating schema. 
The schema must be compatible with " + "airflow/provider_info.schema.json.", @@ -1316,7 +1255,7 @@ def get_provider_yaml(provider_package_id: str) -> Dict[str, Any]: if not os.path.exists(provider_yaml_file_name): raise Exception(f"The provider.yaml file is missing: {provider_yaml_file_name}") with open(provider_yaml_file_name) as provider_file: - provider_yaml_dict = yaml.load(provider_file, SafeLoader) + provider_yaml_dict = safe_load(provider_file) return provider_yaml_dict @@ -1339,7 +1278,6 @@ def get_version_tag(version: str, provider_package_id: str, version_suffix: str def print_changes_table(changes_table): syntax = Syntax(changes_table, "rst", theme="ansi_dark") - console = Console(width=200) console.print(syntax) @@ -1360,14 +1298,14 @@ def get_all_changes_for_package( current_version = versions[0] current_tag_no_suffix = get_version_tag(current_version, provider_package_id) if verbose: - print(f"Checking if tag '{current_tag_no_suffix}' exist.") + console.print(f"Checking if tag '{current_tag_no_suffix}' exist.") if not subprocess.call( get_git_tag_check_command(current_tag_no_suffix), cwd=source_provider_package_path, stderr=subprocess.DEVNULL, ): if verbose: - print(f"The tag {current_tag_no_suffix} exists.") + console.print(f"The tag {current_tag_no_suffix} exists.") # The tag already exists changes = subprocess.check_output( get_git_log_command(verbose, HEAD_OF_HTTPS_REMOTE, current_tag_no_suffix), @@ -1389,21 +1327,23 @@ def get_all_changes_for_package( universal_newlines=True, ) if not changes_since_last_doc_only_check: - print() - print("[yellow]The provider has doc-only changes since the last release. Skipping[/]") + console.print() + console.print( + "[yellow]The provider has doc-only changes since the last release. Skipping[/]" + ) # Returns 66 in case of doc-only changes sys.exit(66) except subprocess.CalledProcessError: # ignore when the commit mentioned as last doc-only change is obsolete pass - print(f"[yellow]The provider {provider_package_id} has changes since last release[/]") - print() - print( + console.print(f"[yellow]The provider {provider_package_id} has changes since last release[/]") + console.print() + console.print( "[yellow]Please update version in " f"'airflow/providers/{provider_package_id.replace('-','/')}/'" "provider.yaml'[/]\n" ) - print("[yellow]Or mark the changes as doc-only[/]") + console.print("[yellow]Or mark the changes as doc-only[/]") changes_table, array_of_changes = convert_git_changes_to_table( "UNKNOWN", changes, @@ -1413,14 +1353,16 @@ def get_all_changes_for_package( print_changes_table(changes_table) return False, array_of_changes[0], changes_table else: - print(f"No changes for {provider_package_id}") + console.print(f"No changes for {provider_package_id}") return False, None, "" if verbose: - print("The tag does not exist. ") + console.print("The tag does not exist. 
") if len(versions) == 1: - print(f"The provider '{provider_package_id}' has never been released but it is ready to release!\n") + console.print( + f"The provider '{provider_package_id}' has never been released but it is ready to release!\n" + ) else: - print(f"New version of the '{provider_package_id}' package is ready to be released!\n") + console.print(f"New version of the '{provider_package_id}' package is ready to be released!\n") next_version_tag = HEAD_OF_HTTPS_REMOTE changes_table = '' current_version = versions[0] @@ -1565,7 +1507,7 @@ def confirm(message: str): """ answer = "" while answer not in ["y", "n", "q"]: - print(f"[yellow]{message}[Y/N/Q]?[/] ", end='') + console.print(f"[yellow]{message}[Y/N/Q]?[/] ", end='') answer = input("").lower() if answer == "q": # Returns 65 in case user decided to quit @@ -1576,7 +1518,7 @@ def confirm(message: str): def mark_latest_changes_as_documentation_only( provider_details: ProviderPackageDetails, latest_change: Change ): - print( + console.print( f"Marking last change: {latest_change.short_hash} and all above changes since the last release " "as doc-only changes!" ) @@ -1626,11 +1568,11 @@ def update_release_notes( if interactive and not confirm("Provider marked for release. Proceed?"): return False elif not latest_change: - print() - print( + console.print() + console.print( f"[yellow]Provider: {provider_package_id} - skipping documentation generation. No changes![/]" ) - print() + console.print() return False else: if interactive and confirm("Are those changes documentation-only?"): @@ -1666,9 +1608,9 @@ def update_setup_files( current_release_version=current_release_version, version_suffix=version_suffix, ) - print() - print(f"Generating setup files for {provider_package_id}") - print() + console.print() + console.print(f"Generating setup files for {provider_package_id}") + console.print() prepare_setup_py_file(jinja_context) prepare_setup_cfg_file(jinja_context) prepare_get_provider_info_py_file(jinja_context, provider_package_id) @@ -1685,9 +1627,9 @@ def replace_content(file_path, old_text, new_text, provider_package_id): copyfile(file_path, temp_file_path) with open(file_path, "wt") as readme_file: readme_file.write(new_text) - print() - print(f"Generated {file_path} file for the {provider_package_id} provider") - print() + console.print() + console.print(f"Generated {file_path} file for the {provider_package_id} provider") + console.print() if old_text != "": subprocess.call(["diff", "--color=always", temp_file_path, file_path]) finally: @@ -1838,9 +1780,9 @@ def verify_provider_package(provider_package_id: str) -> str: :return: None """ if provider_package_id not in get_provider_packages(): - print(f"[red]Wrong package name: {provider_package_id}[/]") - print("Use one of:") - print(get_provider_packages()) + console.print(f"[red]Wrong package name: {provider_package_id}[/]") + console.print("Use one of:") + console.print(get_provider_packages()) raise Exception(f"The package {provider_package_id} is not a provider package.") @@ -1848,17 +1790,16 @@ def verify_changelog_exists(package: str) -> str: provider_details = get_provider_details(package) changelog_path = os.path.join(provider_details.source_provider_package_path, "CHANGELOG.rst") if not os.path.isfile(changelog_path): - print(f"[red]ERROR: Missing ${changelog_path}[/]") - print("Please add the file with initial content:") - print() + console.print(f"[red]ERROR: Missing ${changelog_path}[/]") + console.print("Please add the file with initial content:") + console.print() 
syntax = Syntax( INITIAL_CHANGELOG_CONTENT, "rst", theme="ansi_dark", ) - console = Console(width=200) console.print(syntax) - print() + console.print() raise Exception(f"Missing {changelog_path}") return changelog_path @@ -1868,7 +1809,7 @@ def list_providers_packages(): """List all provider packages.""" providers = get_all_providers() for provider in providers: - print(provider) + console.print(provider) @cli.command() @@ -1894,7 +1835,7 @@ def update_package_documentation( provider_package_id = package_id verify_provider_package(provider_package_id) with with_group(f"Update release notes for package '{provider_package_id}' "): - print("Updating documentation for the latest release version.") + console.print("Updating documentation for the latest release version.") make_sure_remote_apache_exists_and_fetch(git_update, verbose) if not update_release_notes( provider_package_id, version_suffix, force=force, verbose=verbose, interactive=interactive @@ -1906,7 +1847,7 @@ def update_package_documentation( def tag_exists_for_version(provider_package_id: str, current_tag: str, verbose: bool): provider_details = get_provider_details(provider_package_id) if verbose: - print(f"Checking if tag `{current_tag}` exists.") + console.print(f"Checking if tag `{current_tag}` exists.") if not subprocess.call( get_git_tag_check_command(current_tag), cwd=provider_details.source_provider_package_path, @@ -1914,10 +1855,10 @@ def tag_exists_for_version(provider_package_id: str, current_tag: str, verbose: stdout=subprocess.DEVNULL, ): if verbose: - print(f"Tag `{current_tag}` exists.") + console.print(f"Tag `{current_tag}` exists.") return True if verbose: - print(f"Tag `{current_tag}` does not exist.") + console.print(f"Tag `{current_tag}` does not exist.") return False @@ -1936,7 +1877,7 @@ def generate_setup_files(version_suffix: str, git_update: bool, package_id: str, with with_group(f"Generate setup files for '{provider_package_id}'"): current_tag = get_current_tag(provider_package_id, version_suffix, git_update, verbose) if tag_exists_for_version(provider_package_id, current_tag, verbose): - print(f"[yellow]The tag {current_tag} exists. Not preparing the package.[/]") + console.print(f"[yellow]The tag {current_tag} exists. Not preparing the package.[/]") # Returns 1 in case of skipped package sys.exit(1) else: @@ -1944,7 +1885,7 @@ def generate_setup_files(version_suffix: str, git_update: bool, package_id: str, provider_package_id, version_suffix, ): - print(f"[green]Generated regular package setup files for {provider_package_id}[/]") + console.print(f"[green]Generated regular package setup files for {provider_package_id}[/]") else: # Returns 64 in case of skipped package sys.exit(64) @@ -1962,7 +1903,7 @@ def get_current_tag(provider_package_id: str, suffix: str, git_update: bool, ver def cleanup_remnants(verbose: bool): if verbose: - print("Cleaning remnants") + console.print("Cleaning remnants") files = glob.glob("*.egg-info") for file in files: shutil.rmtree(file, ignore_errors=True) @@ -1976,11 +1917,11 @@ def verify_setup_py_prepared(provider_package): setup_content = f.read() search_for = f"providers-{provider_package.replace('.','-')} for Apache Airflow" if search_for not in setup_content: - print( + console.print( f"[red]The setup.py is probably prepared for another package. 
" f"It does not contain [bold]{search_for}[/bold]![/]" ) - print( + console.print( f"\nRun:\n\n[bold]./dev/provider_packages/prepare_provider_packages.py " f"generate-setup-files {provider_package}[/bold]\n" ) @@ -2022,15 +1963,15 @@ def build_provider_packages( with with_group(f"Prepare provider package for '{provider_package_id}'"): current_tag = get_current_tag(provider_package_id, version_suffix, git_update, verbose) if tag_exists_for_version(provider_package_id, current_tag, verbose): - print(f"[yellow]The tag {current_tag} exists. Skipping the package.[/]") + console.print(f"[yellow]The tag {current_tag} exists. Skipping the package.[/]") return False - print(f"Changing directory to ${TARGET_PROVIDER_PACKAGES_PATH}") + console.print(f"Changing directory to ${TARGET_PROVIDER_PACKAGES_PATH}") os.chdir(TARGET_PROVIDER_PACKAGES_PATH) cleanup_remnants(verbose) provider_package = package_id verify_setup_py_prepared(provider_package) - print(f"Building provider package: {provider_package} in format {package_format}") + console.print(f"Building provider package: {provider_package} in format {package_format}") command = ["python3", "setup.py", "build", "--build-temp", tmp_build_dir] if version_suffix is not None: command.extend(['egg_info', '--tag-build', version_suffix]) @@ -2038,13 +1979,15 @@ def build_provider_packages( command.append("sdist") if package_format in ['wheel', 'both']: command.extend(["bdist_wheel", "--bdist-dir", tmp_dist_dir]) - print(f"Executing command: '{' '.join(command)}'") + console.print(f"Executing command: '{' '.join(command)}'") try: subprocess.check_call(command, stdout=subprocess.DEVNULL) except subprocess.CalledProcessError as ex: - print(ex.output.decode()) + console.print(ex.output.decode()) raise Exception("The command returned an error %s", command) - print(f"[green]Prepared provider package {provider_package} in format {package_format}[/]") + console.print( + f"[green]Prepared provider package {provider_package} in format {package_format}[/]" + ) finally: shutil.rmtree(tmp_build_dir, ignore_errors=True) shutil.rmtree(tmp_dist_dir, ignore_errors=True) @@ -2057,35 +2000,148 @@ def verify_provider_classes_for_single_provider(imported_classes: List[str], pro total, bad = check_if_classes_are_properly_named(entity_summaries) bad += sum(len(entity_summary.wrong_entities) for entity_summary in entity_summaries.values()) if bad != 0: - print() - print(f"[red]There are {bad} errors of {total} entities for {provider_package_id}[/]") - print() + console.print() + console.print(f"[red]There are {bad} errors of {total} entities for {provider_package_id}[/]") + console.print() return total, bad -def summarise_total_vs_bad(total: int, bad: int): - """Summarises Bad/Good class names for providers""" +def summarise_total_vs_bad_and_warnings(total: int, bad: int, warns: List[warnings.WarningMessage]) -> bool: + """Summarises Bad/Good class names for providers and warnings""" + raise_error = False if bad == 0: - print() - print(f"[green]All good! 
All {total} entities are properly named[/]") - print() - print("Totals:") - print() - print("New:") - print() - for entity in EntityType: - print(f"{entity.value}: {TOTALS[entity][0]}") - print() - print("Moved:") - print() + console.print() + console.print(f"[green]OK: All {total} entities are properly named[/]") + console.print() + console.print("Totals:") + console.print() for entity in EntityType: - print(f"{entity.value}: {TOTALS[entity][1]}") - print() + console.print(f"{entity.value}: {TOTALS[entity]}") + console.print() else: - print() - print(f"[red]There are in total: {bad} entities badly named out of {total} entities[/]") - print() - raise Exception("Badly names entities") + console.print() + console.print( + f"[red]ERROR! There are in total: {bad} entities badly named out of {total} entities[/]" + ) + console.print() + raise_error = True + if warns: + console.print() + console.print("[red]Unknown warnings generated:[/]") + console.print() + for w in warns: + one_line_message = str(w.message).replace('\n', ' ') + console.print(f"{w.filename}:{w.lineno}:[yellow]{one_line_message}[/]") + console.print() + console.print(f"[red]ERROR! There were {len(warns)} warnings generated during the import[/]") + console.print() + console.print("[yellow]Ideally, fix it, so that no warnings are generated during import.[/]") + console.print("[yellow]There are two cases that are legitimate deprecation warnings though:[/]") + console.print("[yellow] 1) when you deprecate whole module or class and replace it in provider[/]") + console.print("[yellow] 2) when 3rd-party module generates Deprecation and you cannot upgrade it[/]") + console.print() + console.print( + "[yellow]In case 1), add the deprecation message to " + "the KNOWN_DEPRECATED_DIRECT_IMPORTS in prepare_provider_packages.py[/]" + ) + console.print( + "[yellow]In case 2), add the deprecation message together with module it generates to " + "the KNOWN_DEPRECATED_MESSAGES in prepare_provider_packages.py[/]" + ) + console.print() + raise_error = True + else: + console.print() + console.print("[green]OK: No warnings generated[/]") + console.print() + + if raise_error: + console.print("[red]Please fix the problems listed above [/]") + return False + return True + + +# The set of known deprecation messages that we know about. +# It contains tuples of "message" and the module that generates the warning - so when the +# Same warning is generated by different module, it is not treated as "known" warning. +KNOWN_DEPRECATED_MESSAGES: Set[Tuple[str, str]] = { + ( + 'This version of Apache Beam has not been sufficiently tested on Python 3.9. 
' + 'You may encounter bugs or missing features.', + "apache_beam", + ), + ( + "Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since" + " Python 3.3, and in 3.10 it will stop working", + "apache_beam", + ), + ( + "Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since" + " Python 3.3, and in 3.10 it will stop working", + "dns", + ), + ( + 'pyarrow.HadoopFileSystem is deprecated as of 2.0.0, please use pyarrow.fs.HadoopFileSystem instead.', + "papermill", + ), + ( + "You have an incompatible version of 'pyarrow' installed (4.0.1), please install a version that " + "adheres to: 'pyarrow<3.1.0,>=3.0.0; extra == \"pandas\"'", + "apache_beam", + ), + ( + "You have an incompatible version of 'pyarrow' installed (4.0.1), please install a version that " + "adheres to: 'pyarrow<5.1.0,>=5.0.0; extra == \"pandas\"'", + "snowflake", + ), + ("dns.hash module will be removed in future versions. Please use hashlib instead.", "dns"), + ("PKCS#7 support in pyOpenSSL is deprecated. You should use the APIs in cryptography.", "eventlet"), + ("PKCS#12 support in pyOpenSSL is deprecated. You should use the APIs in cryptography.", "eventlet"), + ( + "the imp module is deprecated in favour of importlib; see the module's documentation" + " for alternative uses", + "hdfs", + ), + ("This operator is deprecated. Please use `airflow.providers.tableau.operators.tableau`.", "salesforce"), + ( + "You have an incompatible version of 'pyarrow' installed (4.0.1), please install a version that" + " adheres to: 'pyarrow<3.1.0,>=3.0.0; extra == \"pandas\"'", + "snowflake", + ), + ("SelectableGroups dict interface is deprecated. Use select.", "kombu"), + ("The module cloudant is now deprecated. The replacement is ibmcloudant.", "cloudant"), +} + +# The set of warning messages generated by direct importing of some deprecated modules. We should only +# ignore those messages when the warnings are generated directly by importlib - which means that +# we imported it directly during module walk by the importlib library +KNOWN_DEPRECATED_DIRECT_IMPORTS: Set[str] = { + "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.", + "This module is deprecated. Please use `airflow.providers.tableau.operators.tableau_refresh_workbook`.", + "This module is deprecated. Please use `airflow.providers.tableau.sensors.tableau_job_status`.", + "This module is deprecated. Please use `airflow.providers.tableau.hooks.tableau`.", + "This module is deprecated. Please use `kubernetes.client.models.V1Volume`.", + "This module is deprecated. Please use `kubernetes.client.models.V1VolumeMount`.", + 'numpy.ufunc size changed, may indicate binary incompatibility. 
Expected 192 from C header,' + ' got 216 from PyObject', +} + + +def filter_known_warnings(warn: warnings.WarningMessage) -> bool: + msg_string = str(warn.message).replace("\n", " ") + for m in KNOWN_DEPRECATED_MESSAGES: + expected_package_string = "/" + m[1] + "/" + if msg_string == m[0] and warn.filename.find(expected_package_string) != -1: + return False + return True + + +def filter_direct_importlib_warning(warn: warnings.WarningMessage) -> bool: + msg_string = str(warn.message).replace("\n", " ") + for m in KNOWN_DEPRECATED_DIRECT_IMPORTS: + if msg_string == m and warn.filename.find("/importlib/") != -1: + return False + return True @cli.command() @@ -2093,9 +2149,9 @@ def verify_provider_classes(): """Verifies names for all provider classes.""" with with_group("Verifies names for all provider classes"): provider_ids = get_all_providers() - imported_classes = import_all_classes( + imported_classes, warns = import_all_classes( provider_ids=provider_ids, - print_imports=False, + print_imports=True, paths=[PROVIDERS_PATH], prefix="airflow.providers.", ) @@ -2107,7 +2163,10 @@ def verify_provider_classes(): ) total += inc_total bad += inc_bad - summarise_total_vs_bad(total, bad) + warns = list(filter(filter_known_warnings, warns)) + warns = list(filter(filter_direct_importlib_warning, warns)) + if not summarise_total_vs_bad_and_warnings(total, bad, warns): + sys.exit(1) def find_insertion_index_for_version(content: List[str], version: str) -> Tuple[int, bool]: @@ -2199,12 +2258,14 @@ def _update_changelog(package_id: str, verbose: bool) -> bool: verbose, ) if not proceed: - print(f"[yellow]The provider {package_id} is not being released. Skipping the package.[/]") + console.print( + f"[yellow]The provider {package_id} is not being released. Skipping the package.[/]" + ) return True generate_new_changelog(package_id, provider_details, changelog_path, changes) - print() - print(f"Update index.rst for {package_id}") - print() + console.print() + console.print(f"Update index.rst for {package_id}") + console.print() update_index_rst(jinja_context, package_id, provider_details.documentation_provider_package_path) return False @@ -2217,7 +2278,7 @@ def generate_new_changelog(package_id, provider_details, changelog_path, changes insertion_index, append = find_insertion_index_for_version(current_changelog_lines, latest_version) if append: if not changes: - print( + console.print( f"[green]The provider {package_id} changelog for `{latest_version}` " "has first release. Not updating the changelog.[/]" ) @@ -2226,7 +2287,7 @@ def generate_new_changelog(package_id, provider_details, changelog_path, changes change for change in changes[0] if change.pr and "(#" + change.pr + ")" not in current_changelog ] if not new_changes: - print( + console.print( f"[green]The provider {package_id} changelog for `{latest_version}` " "has no new changes. Not updating the changelog.[/]" ) @@ -2250,15 +2311,16 @@ def generate_new_changelog(package_id, provider_details, changelog_path, changes new_changelog_lines.extend(current_changelog_lines[insertion_index:]) diff = "\n".join(difflib.context_diff(current_changelog_lines, new_changelog_lines, n=5)) syntax = Syntax(diff, "diff") - console = Console(width=200) console.print(syntax) if not append: - print( + console.print( f"[green]The provider {package_id} changelog for `{latest_version}` " "version is missing. 
Generating fresh changelog.[/]" ) else: - print(f"[green]Appending the provider {package_id} changelog for" f"`{latest_version}` version.[/]") + console.print( + f"[green]Appending the provider {package_id} changelog for" f"`{latest_version}` version.[/]" + ) with open(changelog_path, "wt") as changelog: changelog.write("\n".join(new_changelog_lines)) changelog.write("\n") @@ -2340,7 +2402,6 @@ def generate_issue_content(package_ids: List[str], github_token: str, suffix: st excluded_prs = [int(pr) for pr in excluded_pr_list.split(",")] else: excluded_prs = [] - console = Console(width=200, color_system="standard") all_prs: Set[int] = set() provider_prs: Dict[str, List[int]] = {} for package_id in package_ids: diff --git a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst index b99971a193065..5cd2548df8e4b 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst @@ -99,23 +99,6 @@ method only replaces fields that are provided in the submitted Table resource. :start-after: [START howto_operator_bigquery_update_table] :end-before: [END howto_operator_bigquery_update_table] -.. _howto/operator:BigQueryPatchDatasetOperator: - -Patch dataset -""""""""""""" - -To patch a dataset in BigQuery you can use -:class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryPatchDatasetOperator`. - -Note, this operator only replaces fields that are provided in the submitted dataset -resource. - -.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_bigquery_operations.py - :language: python - :dedent: 4 - :start-after: [START howto_operator_bigquery_patch_dataset] - :end-before: [END howto_operator_bigquery_patch_dataset] - .. _howto/operator:BigQueryUpdateDatasetOperator: Update dataset diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index d4d86455dc3ba..5b40d74554402 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -162,39 +162,39 @@ class TestGoogleProviderProjectStructure(unittest.TestCase): ('ads', 'ads_to_gcs'), } - MISSING_EXAMPLES_FOR_OPERATORS = { - # Deprecated operator. Ignore it. + # Those operators are deprecated and we do not need examples for them + DEPRECATED_OPERATORS = { 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service' '.CloudDataTransferServiceS3ToGCSOperator', - # Deprecated operator. Ignore it. 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service' '.CloudDataTransferServiceGCSToGCSOperator', - # Deprecated operator. Ignore it. 'airflow.providers.google.cloud.sensors.gcs.GCSObjectsWtihPrefixExistenceSensor', - # Base operator. Ignore it. - 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator', - # Deprecated operator. Ignore it 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator', - 'airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateInlineWorkflowTemplateOperator', - # Deprecated operator. Ignore it 'airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator', - # Base operator. Ignore it - 'airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator', - # Deprecated operator. Ignore it 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator', - # Deprecated operator. 
Ignore it 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator', - # Deprecated operator. Ignore it 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator', - # Deprecated operator. Ignore it 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator', - # Deprecated operator. Ignore it 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator', - 'airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator', - # Deprecated operator. Ignore it 'airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator', - # Deprecated operator. Ignore it 'airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator', + 'airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator', + 'airflow.providers.google.cloud.operators.bigquery.BigQueryPatchDatasetOperator', + 'airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator', + 'airflow.providers.google.cloud.operators.bigquery.BigQueryExecuteQueryOperator', + } + + # Those operators should not have examples as they are never used standalone (they are abstract) + BASE_OPERATORS = { + 'airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator', + 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator', + 'airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator', + } + + # Please at the examples to those operators at the earliest convenience :) + MISSING_EXAMPLES_FOR_OPERATORS = { + 'airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateInlineWorkflowTemplateOperator', + 'airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator', 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetStoredInfoTypeOperator', 'airflow.providers.google.cloud.operators.dlp.CloudDLPReidentifyContentOperator', 'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateDeidentifyTemplateOperator', @@ -217,8 +217,6 @@ class TestGoogleProviderProjectStructure(unittest.TestCase): 'airflow.providers.google.cloud.operators.dlp.CloudDLPRedactImageOperator', 'airflow.providers.google.cloud.operators.datastore.CloudDatastoreDeleteOperationOperator', 'airflow.providers.google.cloud.operators.datastore.CloudDatastoreGetOperationOperator', - # Base operator. 
Ignore it - 'airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator', 'airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor', 'airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor', } @@ -299,7 +297,11 @@ def test_missing_example_for_operator(self): print("example_paths=", example_paths) operators_paths = set(get_classes_from_file(f"{ROOT_FOLDER}/{filepath}")) missing_operators.extend(operators_paths - example_paths) - assert set(missing_operators) == self.MISSING_EXAMPLES_FOR_OPERATORS + full_set = set() + full_set.update(self.MISSING_EXAMPLES_FOR_OPERATORS) + full_set.update(self.DEPRECATED_OPERATORS) + full_set.update(self.BASE_OPERATORS) + assert set(missing_operators) == full_set @parameterized.expand( itertools.product(["_system.py", "_system_helper.py"], ["operators", "sensors", "transfers"]) diff --git a/tests/providers/google/cloud/operators/test_mlengine_utils.py b/tests/providers/google/cloud/operators/test_mlengine_utils.py index 37a753a03f13f..27f08862529bf 100644 --- a/tests/providers/google/cloud/operators/test_mlengine_utils.py +++ b/tests/providers/google/cloud/operators/test_mlengine_utils.py @@ -106,34 +106,16 @@ def test_successful_run(self): ) assert success_message['predictionOutput'] == result - with patch( - 'airflow.providers.google.cloud.operators.dataflow.DataflowHook' - ) as mock_dataflow_hook, patch( - 'airflow.providers.google.cloud.operators.dataflow.BeamHook' - ) as mock_beam_hook: - dataflow_hook_instance = mock_dataflow_hook.return_value - dataflow_hook_instance.start_python_dataflow.return_value = None + with patch('airflow.providers.apache.beam.operators.beam.BeamHook') as mock_beam_hook: beam_hook_instance = mock_beam_hook.return_value summary.execute(None) - mock_dataflow_hook.assert_called_once_with( - gcp_conn_id='google_cloud_default', - delegate_to=None, - poll_sleep=10, - drain_pipeline=False, - cancel_timeout=600, - wait_until_finished=None, - impersonation_chain=None, - ) - mock_beam_hook.assert_called_once_with(runner="DataflowRunner") + mock_beam_hook.assert_called_once_with(runner="DirectRunner") beam_hook_instance.start_python_pipeline.assert_called_once_with( variables={ 'prediction_path': 'gs://legal-bucket/fake-output-path', 'labels': {'airflow-version': TEST_VERSION}, 'metric_keys': 'err', 'metric_fn_encoded': self.metric_fn_encoded, - 'project': 'test-project', - 'region': 'us-central1', - 'job_name': mock.ANY, }, py_file=mock.ANY, py_options=[], @@ -142,9 +124,6 @@ def test_successful_run(self): py_system_site_packages=False, process_line_callback=mock.ANY, ) - dataflow_hook_instance.wait_for_done.assert_called_once_with( - job_name=mock.ANY, location='us-central1', job_id=mock.ANY, multiple_jobs=False - ) with patch('airflow.providers.google.cloud.utils.mlengine_operator_utils.GCSHook') as mock_gcs_hook: hook_instance = mock_gcs_hook.return_value diff --git a/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py b/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py index 5e19c17a35ce9..b50815cb2ca1a 100644 --- a/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py +++ b/tests/providers/google/cloud/utils/test_mlengine_operator_utils.py @@ -27,8 +27,8 @@ from airflow.exceptions import AirflowException from airflow.models import DAG from airflow.operators.python import PythonOperator +from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook -from 
airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator from airflow.providers.google.cloud.utils.mlengine_operator_utils import create_evaluate_ops TASK_PREFIX = "test-task-prefix" @@ -93,8 +93,8 @@ def validate_err_and_count(summary): class TestMlengineOperatorUtils(unittest.TestCase): @mock.patch.object(PythonOperator, "set_upstream") - @mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream") - def test_create_evaluate_ops(self, mock_dataflow, mock_python): + @mock.patch.object(BeamRunPythonPipelineOperator, "set_upstream") + def test_create_evaluate_ops(self, mock_beam_pipeline, mock_python): result = create_evaluate_ops( task_prefix=TASK_PREFIX, data_format=DATA_FORMAT, @@ -111,7 +111,7 @@ def test_create_evaluate_ops(self, mock_dataflow, mock_python): evaluate_prediction, evaluate_summary, evaluate_validation = result - mock_dataflow.assert_called_once_with(evaluate_prediction) + mock_beam_pipeline.assert_called_once_with(evaluate_prediction) mock_python.assert_called_once_with(evaluate_summary) assert TASK_PREFIX_PREDICTION == evaluate_prediction.task_id @@ -124,17 +124,17 @@ def test_create_evaluate_ops(self, mock_dataflow, mock_python): assert MODEL_URI == evaluate_prediction._uri assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id - assert DATAFLOW_OPTIONS == evaluate_summary.dataflow_default_options - assert PREDICTION_PATH == evaluate_summary.options["prediction_path"] - assert METRIC_FN_ENCODED == evaluate_summary.options["metric_fn_encoded"] - assert METRIC_KEYS_EXPECTED == evaluate_summary.options["metric_keys"] + assert DATAFLOW_OPTIONS == evaluate_summary.default_pipeline_options + assert PREDICTION_PATH == evaluate_summary.pipeline_options["prediction_path"] + assert METRIC_FN_ENCODED == evaluate_summary.pipeline_options["metric_fn_encoded"] + assert METRIC_KEYS_EXPECTED == evaluate_summary.pipeline_options["metric_keys"] assert TASK_PREFIX_VALIDATION == evaluate_validation.task_id assert PREDICTION_PATH == evaluate_validation.templates_dict["prediction_path"] @mock.patch.object(PythonOperator, "set_upstream") - @mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream") - def test_create_evaluate_ops_model_and_version_name(self, mock_dataflow, mock_python): + @mock.patch.object(BeamRunPythonPipelineOperator, "set_upstream") + def test_create_evaluate_ops_model_and_version_name(self, mock_beam_pipeline, mock_python): result = create_evaluate_ops( task_prefix=TASK_PREFIX, data_format=DATA_FORMAT, @@ -152,7 +152,7 @@ def test_create_evaluate_ops_model_and_version_name(self, mock_dataflow, mock_py evaluate_prediction, evaluate_summary, evaluate_validation = result - mock_dataflow.assert_called_once_with(evaluate_prediction) + mock_beam_pipeline.assert_called_once_with(evaluate_prediction) mock_python.assert_called_once_with(evaluate_summary) assert TASK_PREFIX_PREDICTION == evaluate_prediction.task_id @@ -166,16 +166,16 @@ def test_create_evaluate_ops_model_and_version_name(self, mock_dataflow, mock_py assert VERSION_NAME == evaluate_prediction._version_name assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id - assert DATAFLOW_OPTIONS == evaluate_summary.dataflow_default_options - assert PREDICTION_PATH == evaluate_summary.options["prediction_path"] - assert METRIC_FN_ENCODED == evaluate_summary.options["metric_fn_encoded"] - assert METRIC_KEYS_EXPECTED == evaluate_summary.options["metric_keys"] + assert DATAFLOW_OPTIONS == evaluate_summary.default_pipeline_options + assert PREDICTION_PATH == 
evaluate_summary.pipeline_options["prediction_path"] + assert METRIC_FN_ENCODED == evaluate_summary.pipeline_options["metric_fn_encoded"] + assert METRIC_KEYS_EXPECTED == evaluate_summary.pipeline_options["metric_keys"] assert TASK_PREFIX_VALIDATION == evaluate_validation.task_id assert PREDICTION_PATH == evaluate_validation.templates_dict["prediction_path"] @mock.patch.object(PythonOperator, "set_upstream") - @mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream") + @mock.patch.object(BeamRunPythonPipelineOperator, "set_upstream") def test_create_evaluate_ops_dag(self, mock_dataflow, mock_python): result = create_evaluate_ops( task_prefix=TASK_PREFIX, @@ -204,18 +204,18 @@ def test_create_evaluate_ops_dag(self, mock_dataflow, mock_python): assert VERSION_NAME == evaluate_prediction._version_name assert TASK_PREFIX_SUMMARY == evaluate_summary.task_id - assert DATAFLOW_OPTIONS == evaluate_summary.dataflow_default_options - assert PREDICTION_PATH == evaluate_summary.options["prediction_path"] - assert METRIC_FN_ENCODED == evaluate_summary.options["metric_fn_encoded"] - assert METRIC_KEYS_EXPECTED == evaluate_summary.options["metric_keys"] + assert DATAFLOW_OPTIONS == evaluate_summary.default_pipeline_options + assert PREDICTION_PATH == evaluate_summary.pipeline_options["prediction_path"] + assert METRIC_FN_ENCODED == evaluate_summary.pipeline_options["metric_fn_encoded"] + assert METRIC_KEYS_EXPECTED == evaluate_summary.pipeline_options["metric_keys"] assert TASK_PREFIX_VALIDATION == evaluate_validation.task_id assert PREDICTION_PATH == evaluate_validation.templates_dict["prediction_path"] @mock.patch.object(GCSHook, "download") @mock.patch.object(PythonOperator, "set_upstream") - @mock.patch.object(DataflowCreatePythonJobOperator, "set_upstream") - def test_apply_validate_fn(self, mock_dataflow, mock_python, mock_download): + @mock.patch.object(BeamRunPythonPipelineOperator, "set_upstream") + def test_apply_validate_fn(self, mock_beam_pipeline, mock_python, mock_download): result = create_evaluate_ops( task_prefix=TASK_PREFIX, data_format=DATA_FORMAT,