Skip to content

Commit

Permalink
Sql to gcs with exclude columns (#23695)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaegwonseo authored May 22, 2022
1 parent 69f444f commit 65f3b18
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
16 changes: 14 additions & 2 deletions airflow/providers/google/cloud/transfers/sql_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class BaseSQLToGCSOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param exclude_columns: set of column names to exclude from the exported data files and
    from the generated schema; defaults to an empty set (no columns are excluded)
"""

template_fields: Sequence[str] = (
Expand Down Expand Up @@ -103,9 +104,13 @@ def __init__(
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
exclude_columns=None,
**kwargs,
) -> None:
super().__init__(**kwargs)
if exclude_columns is None:
exclude_columns = set()

self.sql = sql
self.bucket = bucket
self.filename = filename
Expand All @@ -120,6 +125,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain
self.exclude_columns = exclude_columns

def execute(self, context: 'Context'):
self.log.info("Executing query")
Expand Down Expand Up @@ -165,7 +171,9 @@ def _write_local_data_files(self, cursor):
names in GCS, and values are file handles to local files that
contain the data for the GCS objects.
"""
schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
schema = [column for column in org_schema if column not in self.exclude_columns]

col_type_dict = self._get_col_type_dict()
file_no = 0

Expand Down Expand Up @@ -314,7 +322,11 @@ def _write_local_schema_file(self, cursor):
schema = self.schema
else:
self.log.info("Starts generating schema")
schema = [self.field_to_bigquery(field) for field in cursor.description]
schema = [
self.field_to_bigquery(field)
for field in cursor.description
if field[0] not in self.exclude_columns
]

if isinstance(schema, list):
schema = json.dumps(schema, sort_keys=True)
Expand Down
Binary file removed tests/providers/google/cloud/transfers/temp-file
Binary file not shown.
29 changes: 29 additions & 0 deletions tests/providers/google/cloud/transfers/test_sql_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@

OUTPUT_DF = pd.DataFrame([['convert_type_return_value'] * 3] * 3, columns=COLUMNS)

# Must be a set literal: ``set('column_c')`` iterates over the string and produces a set of
# single characters, which can never match a multi-character column name and would silently
# turn the exclusion into a no-op.
EXCLUDE_COLUMNS = {'column_c'}
# Column names expected to remain once the excluded ones are dropped (original order kept).
NEW_COLUMNS = [c for c in COLUMNS if c not in EXCLUDE_COLUMNS]
# Expected frame for the exclude-columns tests: three identical rows, one value per remaining column.
OUTPUT_DF_WITH_EXCLUDE_COLUMNS = pd.DataFrame(
    [['convert_type_return_value'] * len(NEW_COLUMNS)] * 3, columns=NEW_COLUMNS
)


class DummySQLToGCSOperator(BaseSQLToGCSOperator):
def field_to_bigquery(self, field) -> Dict[str, str]:
Expand Down Expand Up @@ -287,3 +293,26 @@ def test__write_local_data_files_parquet(self):
file.flush()
df = pd.read_parquet(file.name)
assert df.equals(OUTPUT_DF)

def test__write_local_data_files_json_with_exclude_columns(self):
    """Check the JSON data file written when ``exclude_columns`` is passed to the operator.

    The first local data file produced from the mocked cursor is read back with
    pandas and compared against ``OUTPUT_DF_WITH_EXCLUDE_COLUMNS``.
    """
    operator = DummySQLToGCSOperator(
        task_id=TASK_ID,
        sql=SQL,
        bucket=BUCKET,
        filename=FILENAME,
        schema_filename=SCHEMA_FILE,
        schema=SCHEMA,
        export_format="json",
        gzip=False,
        gcp_conn_id='google_cloud_default',
        exclude_columns=EXCLUDE_COLUMNS,
    )

    # Fake DB cursor: iterating it yields the input rows, and its description
    # supplies the column metadata the operator reads.
    mock_cursor = MagicMock()
    mock_cursor.description = CURSOR_DESCRIPTION
    mock_cursor.__iter__.return_value = INPUT_DATA

    # Keep the generator bound to a local for the whole test, as the data file
    # handle it yielded is still being read below.
    local_files = operator._write_local_data_files(mock_cursor)
    first_file = next(local_files)
    handle = first_file['file_handle']
    handle.flush()  # make sure buffered rows are on disk before reading by name

    exported = pd.read_json(handle.name, orient='records', lines=True)
    assert exported.equals(OUTPUT_DF_WITH_EXCLUDE_COLUMNS)

0 comments on commit 65f3b18

Please sign in to comment.