From 80c1ce76e19d363916f2253cdd536372f6a43aee Mon Sep 17 00:00:00 2001 From: Wojciech Januszek Date: Mon, 6 Jun 2022 15:02:35 +0200 Subject: [PATCH] Cloud Storage assets & StorageLink update (#23865) Co-authored-by: Wojciech Januszek --- .../cloud/operators/dataproc_metastore.py | 2 +- .../google/cloud/operators/datastore.py | 1 + .../providers/google/cloud/operators/gcs.py | 57 +++++++++++++++++++ .../providers/google/common/links/storage.py | 4 +- .../google/cloud/operators/test_gcs.py | 14 +++-- 5 files changed, 69 insertions(+), 9 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc_metastore.py b/airflow/providers/google/cloud/operators/dataproc_metastore.py index d0ca4a5f28672..4bdf519d2f5b5 100644 --- a/airflow/providers/google/cloud/operators/dataproc_metastore.py +++ b/airflow/providers/google/cloud/operators/dataproc_metastore.py @@ -711,7 +711,7 @@ def execute(self, context: "Context"): DataprocMetastoreLink.persist(context=context, task_instance=self, url=METASTORE_EXPORT_LINK) uri = self._get_uri_from_destination(MetadataExport.to_dict(metadata_export)["destination_gcs_uri"]) - StorageLink.persist(context=context, task_instance=self, uri=uri) + StorageLink.persist(context=context, task_instance=self, uri=uri, project_id=self.project_id) return MetadataExport.to_dict(metadata_export) def _get_uri_from_destination(self, destination_uri: str): diff --git a/airflow/providers/google/cloud/operators/datastore.py b/airflow/providers/google/cloud/operators/datastore.py index 8a92665e3694e..db08d53ba787b 100644 --- a/airflow/providers/google/cloud/operators/datastore.py +++ b/airflow/providers/google/cloud/operators/datastore.py @@ -140,6 +140,7 @@ def execute(self, context: 'Context') -> dict: context=context, task_instance=self, uri=f"{self.bucket}/{result['response']['outputUrl'].split('/')[3]}", + project_id=self.project_id or ds_hook.project_id, ) return result diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index 27cc6f79bd108..bfc2d9691975b 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -35,6 +35,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.common.links.storage import FileDetailsLink, StorageLink from airflow.utils import timezone @@ -107,6 +108,7 @@ class GCSCreateBucketOperator(BaseOperator): 'impersonation_chain', ) ui_color = '#f0eee4' + operator_extra_links = (StorageLink(),) def __init__( self, @@ -139,6 +141,12 @@ def execute(self, context: "Context") -> None: delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) + StorageLink.persist( + context=context, + task_instance=self, + uri=self.bucket_name, + project_id=self.project_id or hook.project_id, + ) try: hook.create_bucket( bucket_name=self.bucket_name, @@ -200,6 +208,8 @@ class GCSListObjectsOperator(BaseOperator): ui_color = '#f0eee4' + operator_extra_links = (StorageLink(),) + def __init__( self, *, @@ -234,6 +244,13 @@ def execute(self, context: "Context") -> list: self.prefix, ) + StorageLink.persist( + context=context, + task_instance=self, + uri=self.bucket, + project_id=hook.project_id, + ) + return hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter) @@ -346,6 +363,7 @@ class GCSBucketCreateAclEntryOperator(BaseOperator): 'impersonation_chain', ) # [END gcs_bucket_create_acl_template_fields] + operator_extra_links = (StorageLink(),) def __init__( self, @@ -371,6 +389,12 @@ def execute(self, context: "Context") -> None: gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) + StorageLink.persist( + context=context, + task_instance=self, + uri=self.bucket, + project_id=hook.project_id, + ) hook.insert_bucket_acl( bucket_name=self.bucket, entity=self.entity, role=self.role, user_project=self.user_project ) @@ -418,6 +442,7 @@ class GCSObjectCreateAclEntryOperator(BaseOperator): 'impersonation_chain', ) # [END gcs_object_create_acl_template_fields] + operator_extra_links = (FileDetailsLink(),) def __init__( self, @@ -447,6 +472,12 @@ def execute(self, context: "Context") -> None: gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) + FileDetailsLink.persist( + context=context, + task_instance=self, + uri=f"{self.bucket}/{self.object_name}", + project_id=hook.project_id, + ) hook.insert_object_acl( bucket_name=self.bucket, object_name=self.object_name, @@ -498,6 +529,7 @@ class GCSFileTransformOperator(BaseOperator): 'transform_script', 'impersonation_chain', ) + operator_extra_links = (FileDetailsLink(),) def __init__( self, @@ -549,6 +581,12 @@ def execute(self, context: "Context") -> None: self.log.info("Transformation succeeded. Output temporarily located at %s", destination_file.name) self.log.info("Uploading file to %s as %s", self.destination_bucket, self.destination_object) + FileDetailsLink.persist( + context=context, + task_instance=self, + uri=f"{self.destination_bucket}/{self.destination_object}", + project_id=hook.project_id, + ) hook.upload( bucket_name=self.destination_bucket, object_name=self.destination_object, @@ -628,6 +666,7 @@ class GCSTimeSpanFileTransformOperator(BaseOperator): 'source_impersonation_chain', 'destination_impersonation_chain', ) + operator_extra_links = (StorageLink(),) @staticmethod def interpolate_prefix(prefix: str, dt: datetime.datetime) -> Optional[str]: @@ -718,6 +757,12 @@ def execute(self, context: "Context") -> List[str]: gcp_conn_id=self.destination_gcp_conn_id, impersonation_chain=self.destination_impersonation_chain, ) + StorageLink.persist( + context=context, + task_instance=self, + uri=self.destination_bucket, + project_id=destination_hook.project_id, + ) # Fetch list of files. blobs_to_transform = source_hook.list_by_timespan( @@ -904,6 +949,7 @@ class GCSSynchronizeBucketsOperator(BaseOperator): 'delegate_to', 'impersonation_chain', ) + operator_extra_links = (StorageLink(),) def __init__( self, @@ -938,6 +984,12 @@ def execute(self, context: "Context") -> None: delegate_to=self.delegate_to, impersonation_chain=self.impersonation_chain, ) + StorageLink.persist( + context=context, + task_instance=self, + uri=self._get_uri(self.destination_bucket, self.destination_object), + project_id=hook.project_id, + ) hook.sync( source_bucket=self.source_bucket, destination_bucket=self.destination_bucket, @@ -947,3 +999,8 @@ def execute(self, context: "Context") -> None: delete_extra_files=self.delete_extra_files, allow_overwrite=self.allow_overwrite, ) + + def _get_uri(self, gcs_bucket: str, gcs_object: Optional[str]) -> str: + if gcs_object and gcs_object[-1] == "/": + gcs_object = gcs_object[:-1] + return f"{gcs_bucket}/{gcs_object}" if gcs_object else gcs_bucket diff --git a/airflow/providers/google/common/links/storage.py b/airflow/providers/google/common/links/storage.py index 7934d95d33419..013dcc25f9b12 100644 --- a/airflow/providers/google/common/links/storage.py +++ b/airflow/providers/google/common/links/storage.py @@ -36,11 +36,11 @@ class StorageLink(BaseGoogleLink): format_str = GCS_STORAGE_LINK @staticmethod - def persist(context: "Context", task_instance, uri: str): + def persist(context: "Context", task_instance, uri: str, project_id: Optional[str]): task_instance.xcom_push( context=context, key=StorageLink.key, - value={"uri": uri, "project_id": task_instance.project_id}, + value={"uri": uri, "project_id": project_id}, ) diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py index cac11ccf03f74..3d6cba0374e4a 100644 --- a/tests/providers/google/cloud/operators/test_gcs.py +++ b/tests/providers/google/cloud/operators/test_gcs.py @@ -57,7 +57,7 @@ def test_execute(self, mock_hook): project_id=TEST_PROJECT, ) - operator.execute(None) + operator.execute(context=mock.MagicMock()) mock_hook.return_value.create_bucket.assert_called_once_with( bucket_name=TEST_BUCKET, storage_class="MULTI_REGIONAL", @@ -78,7 +78,7 @@ def test_bucket_create_acl(self, mock_hook): user_project="test-user-project", task_id="id", ) - operator.execute(None) + operator.execute(context=mock.MagicMock()) mock_hook.return_value.insert_bucket_acl.assert_called_once_with( bucket_name="test-bucket", entity="test-entity", @@ -97,7 +97,7 @@ def test_object_create_acl(self, mock_hook): user_project="test-user-project", task_id="id", ) - operator.execute(None) + operator.execute(context=mock.MagicMock()) mock_hook.return_value.insert_object_acl.assert_called_once_with( bucket_name="test-bucket", object_name="test-object", @@ -148,7 +148,7 @@ def test_execute(self, mock_hook): task_id=TASK_ID, bucket=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER ) - files = operator.execute(None) + files = operator.execute(context=mock.MagicMock()) mock_hook.return_value.list.assert_called_once_with( bucket_name=TEST_BUCKET, prefix=PREFIX, delimiter=DELIMITER ) @@ -197,7 +197,7 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempfile): destination_bucket=destination_bucket, transform_script=transform_script, ) - op.execute(None) + op.execute(context=mock.MagicMock()) mock_hook.return_value.download.assert_called_once_with( bucket_name=source_bucket, object_name=source_object, filename=source @@ -273,9 +273,11 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempdir): timespan_end = timespan_start + timedelta(hours=1) mock_dag = mock.Mock() mock_dag.following_schedule = lambda x: x + timedelta(hours=1) + mock_ti = mock.Mock() context = dict( execution_date=timespan_start, dag=mock_dag, + ti=mock_ti, ) mock_tempdir.return_value.__enter__.side_effect = [source, destination] @@ -397,7 +399,7 @@ def test_execute(self, mock_hook): delegate_to="DELEGATE_TO", impersonation_chain=IMPERSONATION_CHAIN, ) - task.execute({}) + task.execute(context=mock.MagicMock()) mock_hook.assert_called_once_with( gcp_conn_id='GCP_CONN_ID', delegate_to='DELEGATE_TO',