diff --git a/airflow/providers/google/cloud/transfers/sql_to_gcs.py b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
index 0df46bc7773cb..46e1ad505d784 100644
--- a/airflow/providers/google/cloud/transfers/sql_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -71,6 +71,7 @@ class BaseSQLToGCSOperator(BaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param exclude_columns: set of columns to exclude from transmission
     """
 
     template_fields: Sequence[str] = (
@@ -103,9 +104,13 @@ def __init__(
         gcp_conn_id: str = 'google_cloud_default',
         delegate_to: Optional[str] = None,
         impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        exclude_columns=None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
+        if exclude_columns is None:
+            exclude_columns = set()
+
         self.sql = sql
         self.bucket = bucket
         self.filename = filename
@@ -120,6 +125,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
+        self.exclude_columns = exclude_columns
 
     def execute(self, context: 'Context'):
         self.log.info("Executing query")
@@ -165,7 +171,9 @@ def _write_local_data_files(self, cursor):
             names in GCS, and values are file handles to local files that
             contain the data for the GCS objects.
         """
-        schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
+        org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
+        schema = [column for column in org_schema if column not in self.exclude_columns]
+
         col_type_dict = self._get_col_type_dict()
         file_no = 0
 
@@ -314,7 +322,11 @@ def _write_local_schema_file(self, cursor):
             schema = self.schema
         else:
             self.log.info("Starts generating schema")
-            schema = [self.field_to_bigquery(field) for field in cursor.description]
+            schema = [
+                self.field_to_bigquery(field)
+                for field in cursor.description
+                if field[0] not in self.exclude_columns
+            ]
 
         if isinstance(schema, list):
             schema = json.dumps(schema, sort_keys=True)
diff --git a/tests/providers/google/cloud/transfers/temp-file b/tests/providers/google/cloud/transfers/temp-file
deleted file mode 100644
index d2282fc46c665..0000000000000
Binary files a/tests/providers/google/cloud/transfers/temp-file and /dev/null differ
diff --git a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
index 525e04bd0e0a7..824ab8ff317f3 100644
--- a/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
+++ b/tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -61,6 +61,12 @@
 
 OUTPUT_DF = pd.DataFrame([['convert_type_return_value'] * 3] * 3, columns=COLUMNS)
 
+EXCLUDE_COLUMNS = {'column_c'}
+NEW_COLUMNS = [c for c in COLUMNS if c not in EXCLUDE_COLUMNS]
+OUTPUT_DF_WITH_EXCLUDE_COLUMNS = pd.DataFrame(
+    [['convert_type_return_value'] * len(NEW_COLUMNS)] * 3, columns=NEW_COLUMNS
+)
+
 
 class DummySQLToGCSOperator(BaseSQLToGCSOperator):
     def field_to_bigquery(self, field) -> Dict[str, str]:
@@ -287,3 +293,26 @@ def test__write_local_data_files_parquet(self):
         file.flush()
         df = pd.read_parquet(file.name)
         assert df.equals(OUTPUT_DF)
+
+    def test__write_local_data_files_json_with_exclude_columns(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="json",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id='google_cloud_default',
+            exclude_columns=EXCLUDE_COLUMNS,
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        files = op._write_local_data_files(cursor)
+        file = next(files)['file_handle']
+        file.flush()
+        df = pd.read_json(file.name, orient='records', lines=True)
+        assert df.equals(OUTPUT_DF_WITH_EXCLUDE_COLUMNS)
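For context, a minimal sketch of how the new keyword might be used from a DAG, assuming the concrete MySQLToGCSOperator subclass (any BaseSQLToGCSOperator subclass accepts `exclude_columns` through the shared base constructor); the DAG id, table, bucket, and column names below are illustrative only:

```python
# Illustrative usage sketch only: DAG id, SQL, bucket, and column names are assumed.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

with DAG(
    dag_id="example_sql_to_gcs_exclude_columns",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    export_users = MySQLToGCSOperator(
        task_id="export_users",
        sql="SELECT * FROM users",
        bucket="my-export-bucket",              # assumed bucket name
        filename="users/export_{}.json",
        export_format="json",
        # Columns listed here are dropped from both the exported rows
        # and the generated BigQuery schema file.
        exclude_columns={"password_hash"},
    )
```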