SqlToS3Operator - Add feature to partition SQL table (#30460)
utkarsharma2 authored Apr 18, 2023
1 parent 372a088 commit d7cef58
Showing 4 changed files with 143 additions and 9 deletions.
32 changes: 23 additions & 9 deletions airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -80,6 +80,7 @@ class SqlToS3Operator(BaseOperator):
CA cert bundle than the one used by botocore.
:param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
:param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
:param groupby_kwargs: arguments to include in DataFrame ``groupby()``.
"""

template_fields: Sequence[str] = (
@@ -107,6 +108,7 @@ def __init__(
verify: bool | str | None = None,
file_format: Literal["csv", "json", "parquet"] = "csv",
pd_kwargs: dict | None = None,
groupby_kwargs: dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -119,6 +121,7 @@ def __init__(
self.replace = replace
self.pd_kwargs = pd_kwargs or {}
self.parameters = parameters
self.groupby_kwargs = groupby_kwargs or {}

if "path_or_buf" in self.pd_kwargs:
raise AirflowException("The argument path_or_buf is not allowed, please remove it")
@@ -170,15 +173,26 @@ def execute(self, context: Context) -> None:
self._fix_dtypes(data_df, self.file_format)
file_options = FILE_OPTIONS_MAP[self.file_format]

with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:

self.log.info("Writing data to temp file")
getattr(data_df, file_options.function)(tmp_file.name, **self.pd_kwargs)

self.log.info("Uploading data to S3")
s3_conn.load_file(
filename=tmp_file.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=self.replace
)
for group_name, df in self._partition_dataframe(df=data_df):
with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:

self.log.info("Writing data to temp file")
getattr(df, file_options.function)(tmp_file.name, **self.pd_kwargs)

self.log.info("Uploading data to S3")
object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
s3_conn.load_file(
filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
)

def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]:
"""Partition dataframe using pandas groupby() method"""
if not self.groupby_kwargs:
yield "", df
else:
grouped_df = df.groupby(**self.groupby_kwargs)
for group_label in grouped_df.groups.keys():
yield group_label, grouped_df.get_group(group_label).reset_index(drop=True)

def _get_hook(self) -> DbApiHook:
self.log.debug("Get connection for %s", self.sql_conn_id)
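For orientation, here is a minimal standalone sketch of the partitioning behaviour introduced above, using the same sample data as the new unit tests further down; the ``partition()`` helper and the printed keys are illustrative only, but they mirror ``_partition_dataframe()`` and the ``f"{self.s3_key}_{group_name}"`` key naming in ``execute()``:

import pandas as pd

s3_key = "key"  # illustrative base key, standing in for the operator's s3_key argument
groupby_kwargs = {"by": "Team"}

df = pd.DataFrame(
    {
        "Team": ["Australia", "Australia", "India", "India"],
        "Runs": [345, 490, 672, 560],
    }
)

def partition(frame):
    # Mirrors _partition_dataframe(): yield the whole frame under an empty name when
    # no groupby_kwargs is given, otherwise one (group_name, frame) pair per group.
    if not groupby_kwargs:
        yield "", frame
        return
    grouped = frame.groupby(**groupby_kwargs)
    for label in grouped.groups:
        yield label, grouped.get_group(label).reset_index(drop=True)

for group_name, part in partition(df):
    object_key = f"{s3_key}_{group_name}" if group_name else s3_key
    print(object_key, len(part))  # -> key_Australia 2, then key_India 2
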
13 changes: 13 additions & 0 deletions docs/apache-airflow-providers-amazon/transfer/sql_to_s3.rst
@@ -50,6 +50,19 @@ Example usage:
:start-after: [START howto_transfer_sql_to_s3]
:end-before: [END howto_transfer_sql_to_s3]

Grouping
--------

We can group the data in the table by passing the ``groupby_kwargs`` parameter. It accepts a ``dict`` that is passed to pandas `groupby() <https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.groupby.html#pandas.DataFrame.groupby>`_ as keyword arguments, and one file is uploaded to S3 per resulting group.

Example usage:

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sql_to_s3.py
:language: python
:dedent: 4
:start-after: [START howto_transfer_sql_to_s3_with_groupby_param]
:end-before: [END howto_transfer_sql_to_s3_with_groupby_param]

Reference
---------

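A quick sketch of the documented behaviour (the sample rows below are made up): the ``groupby_kwargs`` dict is unpacked straight into ``DataFrame.groupby()``, so ``{"by": "color"}`` behaves like calling ``df.groupby(by="color")``:

import pandas as pd

df = pd.DataFrame({"color": ["red", "blue", "red"], "value": [1, 2, 3]})  # made-up sample rows
groups = df.groupby(**{"by": "color"})  # equivalent to df.groupby(by="color")
print(list(groups.groups))  # -> ['blue', 'red'], one uploaded file per group
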
94 changes: 94 additions & 0 deletions tests/providers/amazon/aws/transfers/test_sql_to_s3.py
@@ -175,3 +175,97 @@ def test_invalid_file_format(self):
file_format="invalid_format",
dag=None,
)

def test_with_groupby_kwarg(self):
"""
Test the operator when groupby_kwargs is specified
"""
query = "query"
s3_bucket = "bucket"
s3_key = "key"

op = SqlToS3Operator(
query=query,
s3_bucket=s3_bucket,
s3_key=s3_key,
sql_conn_id="mysql_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
replace=True,
pd_kwargs={"index": False, "header": False},
groupby_kwargs={"by": "Team"},
dag=None,
)
example = {
"Team": ["Australia", "Australia", "India", "India"],
"Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
"Runs": [345, 490, 672, 560],
}

df = pd.DataFrame(example)
data = []
for group_name, df in op._partition_dataframe(df):
data.append((group_name, df))
data.sort(key=lambda d: d[0])
team, df = data[0]
assert df.equals(
pd.DataFrame(
{
"Team": ["Australia", "Australia"],
"Player": ["Ricky", "David Warner"],
"Runs": [345, 490],
}
)
)
team, df = data[1]
assert df.equals(
pd.DataFrame(
{
"Team": ["India", "India"],
"Player": ["Virat Kohli", "Rohit Sharma"],
"Runs": [672, 560],
}
)
)

def test_without_groupby_kwarg(self):
"""
Test the operator when groupby_kwargs is not specified
"""
query = "query"
s3_bucket = "bucket"
s3_key = "key"

op = SqlToS3Operator(
query=query,
s3_bucket=s3_bucket,
s3_key=s3_key,
sql_conn_id="mysql_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
replace=True,
pd_kwargs={"index": False, "header": False},
dag=None,
)
example = {
"Team": ["Australia", "Australia", "India", "India"],
"Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
"Runs": [345, 490, 672, 560],
}

df = pd.DataFrame(example)
data = []
for group_name, df in op._partition_dataframe(df):
data.append((group_name, df))

assert len(data) == 1
team, df = data[0]
assert df.equals(
pd.DataFrame(
{
"Team": ["Australia", "Australia", "India", "India"],
"Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
"Runs": [345, 490, 672, 560],
}
)
)
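A small sketch (same sample data as the tests above) of why the ``reset_index(drop=True)`` call in ``_partition_dataframe`` matters for these assertions: ``DataFrame.equals()`` compares the row index as well as the values, so a group sliced out of the original frame keeps its original labels and would not compare equal to a freshly constructed expected frame:

import pandas as pd

df = pd.DataFrame(
    {
        "Team": ["Australia", "Australia", "India", "India"],
        "Runs": [345, 490, 672, 560],
    }
)
india = df.groupby(by="Team").get_group("India")
print(india.index.tolist())                         # [2, 3] - original row labels survive the slice
print(india.reset_index(drop=True).index.tolist())  # [0, 1] - matches the expected frame in the test
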
13 changes: 13 additions & 0 deletions tests/system/providers/amazon/aws/example_sql_to_s3.py
@@ -173,6 +173,18 @@ def delete_security_group(sec_group_id: str, sec_group_name: str):
)
# [END howto_transfer_sql_to_s3]

# [START howto_transfer_sql_to_s3_with_groupby_param]
sql_to_s3_task_with_groupby = SqlToS3Operator(
task_id="sql_to_s3_with_groupby_task",
sql_conn_id=conn_id_name,
query=SQL_QUERY,
s3_bucket=bucket_name,
s3_key=key,
replace=True,
groupby_kwargs={"by": "color"},
)
# [END howto_transfer_sql_to_s3_with_groupby_param]
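# Note (illustrative, with hypothetical column values): with groupby_kwargs={"by": "color"},
# one object is uploaded per distinct value of the "color" column, e.g. values "red" and
# "blue" would produce keys f"{key}_red" and f"{key}_blue", while sql_to_s3_task above
# still writes a single object under `key`.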

delete_bucket = S3DeleteBucketOperator(
task_id="delete_bucket",
bucket_name=bucket_name,
@@ -202,6 +214,7 @@ def delete_security_group(sec_group_id: str, sec_group_name: str):
insert_data,
# TEST BODY
sql_to_s3_task,
sql_to_s3_task_with_groupby,
# TEST TEARDOWN
delete_bucket,
delete_cluster,
