From fece91ab1f3927bc2c73be4deadb54692644dc7e Mon Sep 17 00:00:00 2001 From: Jordan Yaker Date: Fri, 22 Nov 2024 11:53:36 -0500 Subject: [PATCH 1/2] Add export_options to export_to_file operator The current implementation of the export_to_file operator doesn't provide any mechanism for overriding or customizing the behavior of the file output. Instead, the operator simply calls one of a number of pandas.DataFrame.to_* functions with default arguments that SDK users are locked into. This modification allows for the provision of a configuration dictionary that enables the customizing or overriding of the file write behavior by passing the parameters to the FileType.create_from_dataframe implementation. --- python-sdk/src/astro/files/base.py | 6 ++++-- python-sdk/src/astro/files/types/base.py | 3 ++- python-sdk/src/astro/files/types/csv.py | 6 ++++-- python-sdk/src/astro/files/types/excel.py | 5 +++-- python-sdk/src/astro/files/types/json.py | 5 +++-- python-sdk/src/astro/files/types/ndjson.py | 5 +++-- python-sdk/src/astro/files/types/parquet.py | 5 +++-- .../src/astro/sql/operators/export_to_file.py | 12 +++++++++--- .../tests/sql/operators/test_export_file.py | 19 +++++++++++++++++++ 9 files changed, 50 insertions(+), 16 deletions(-) diff --git a/python-sdk/src/astro/files/base.py b/python-sdk/src/astro/files/base.py index 1b9ba8f6a..2910d3098 100644 --- a/python-sdk/src/astro/files/base.py +++ b/python-sdk/src/astro/files/base.py @@ -114,18 +114,20 @@ def is_pattern(self) -> bool: """ return not pathlib.PosixPath(self.path).suffix - def create_from_dataframe(self, df: pd.DataFrame, store_as_dataframe: bool = True) -> None: + def create_from_dataframe(self, df: pd.DataFrame, store_as_dataframe: bool = True, export_options: dict | None = None) -> None: """Create a file in the desired location using the values of a dataframe. :param store_as_dataframe: Whether the data should later be deserialized as a dataframe or as a file containing delimited data (e.g. 
csv, parquet, etc.). :param df: pandas dataframe + :param export_options: additional arguments to pass to the underlying write functionality """ self.is_dataframe = store_as_dataframe + opts = export_options or {} with self.location.get_stream() as stream: - self.type.create_from_dataframe(stream=stream, df=df) + self.type.create_from_dataframe(stream=stream, df=df, **opts) @property def openlineage_dataset_namespace(self) -> str: diff --git a/python-sdk/src/astro/files/types/base.py b/python-sdk/src/astro/files/types/base.py index 48dcdda5e..d71a042f3 100644 --- a/python-sdk/src/astro/files/types/base.py +++ b/python-sdk/src/astro/files/types/base.py @@ -27,11 +27,12 @@ def export_to_dataframe(self, stream, **kwargs) -> pd.DataFrame: raise NotImplementedError @abstractmethod - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: """Write file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the underlying write functionality """ raise NotImplementedError diff --git a/python-sdk/src/astro/files/types/csv.py b/python-sdk/src/astro/files/types/csv.py index f5d6e6229..2aae0bf7e 100644 --- a/python-sdk/src/astro/files/types/csv.py +++ b/python-sdk/src/astro/files/types/csv.py @@ -38,13 +38,15 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 """Write csv file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the pandas 
`to_csv` function """ - df.to_csv(stream, index=False) + + df.to_csv(stream, **dict(index=False, **kwargs)) @property def name(self): diff --git a/python-sdk/src/astro/files/types/excel.py b/python-sdk/src/astro/files/types/excel.py index 1073deaaf..e4e6b434a 100644 --- a/python-sdk/src/astro/files/types/excel.py +++ b/python-sdk/src/astro/files/types/excel.py @@ -37,10 +37,11 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 """Write Excel file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the pandas `to_excel` function """ - df.to_excel(stream, index=False) + df.to_excel(stream, **dict(index=False, **kwargs)) diff --git a/python-sdk/src/astro/files/types/json.py b/python-sdk/src/astro/files/types/json.py index 91cf878f7..18a153e70 100644 --- a/python-sdk/src/astro/files/types/json.py +++ b/python-sdk/src/astro/files/types/json.py @@ -42,13 +42,14 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 """Write json file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the pandas `to_json` function """ - df.to_json(stream, orient="records") + df.to_json(stream, **dict(orient="records", **kwargs)) @property def 
name(self): diff --git a/python-sdk/src/astro/files/types/ndjson.py b/python-sdk/src/astro/files/types/ndjson.py index 5bd92b33f..94935167d 100644 --- a/python-sdk/src/astro/files/types/ndjson.py +++ b/python-sdk/src/astro/files/types/ndjson.py @@ -39,13 +39,14 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 """Write ndjson file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the pandas `to_json` function """ - df.to_json(stream, orient="records", lines=True) + df.to_json(stream, **dict(orient="records", lines=True, **kwargs)) @property def name(self): diff --git a/python-sdk/src/astro/files/types/parquet.py b/python-sdk/src/astro/files/types/parquet.py index a6213dda2..1a61e7446 100644 --- a/python-sdk/src/astro/files/types/parquet.py +++ b/python-sdk/src/astro/files/types/parquet.py @@ -57,13 +57,14 @@ def _convert_remote_file_to_byte_stream(stream) -> io.IOBase: return remote_obj_buffer # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201 + def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 """Write parquet file to one of the supported locations :param df: pandas dataframe :param stream: file stream object + :param kwargs: additional arguments to pass to the pandas `to_parquet` method """ - df.to_parquet(stream) + df.to_parquet(stream, **kwargs) @property def name(self): diff --git 
a/python-sdk/src/astro/sql/operators/export_to_file.py b/python-sdk/src/astro/sql/operators/export_to_file.py index e4be82201..27a00a409 100644 --- a/python-sdk/src/astro/sql/operators/export_to_file.py +++ b/python-sdk/src/astro/sql/operators/export_to_file.py @@ -21,20 +21,23 @@ class ExportToFileOperator(AstroSQLBaseOperator): :param input_data: Table to convert to file :param output_file: File object containing the path to the file and connection id. :param if_exists: Overwrite file if exists. Default False. + :param export_options: Additional options to pass to the file export functions. """ - template_fields = ("input_data", "output_file") + template_fields = ("input_data", "output_file", "export_options") def __init__( self, input_data: BaseTable | pd.DataFrame, output_file: File, if_exists: ExportExistsStrategy = "exception", + export_options: dict | None = None, **kwargs, ) -> None: self.output_file = output_file self.input_data = input_data self.if_exists = if_exists + self.export_options = export_options or {} self.kwargs = kwargs datasets = {"output_datasets": self.output_file} if isinstance(input_data, Table): @@ -57,7 +60,7 @@ def execute(self, context: Context) -> File: # skipcq PYL-W0613 raise ValueError(f"Expected input_table to be Table or dataframe. Got {type(self.input_data)}") # Write file if overwrite == True or if file doesn't exist. if self.if_exists == "replace" or not self.output_file.exists(): - self.output_file.create_from_dataframe(df, store_as_dataframe=False) + self.output_file.create_from_dataframe(df, store_as_dataframe=False, export_options=self.export_options) return self.output_file else: raise FileExistsError(f"{self.output_file.path} file already exists.") @@ -144,7 +147,8 @@ def export_to_file( output_file: File, if_exists: ExportExistsStrategy = "exception", task_id: str | None = None, - **kwargs: Any, + export_options: dict | None = None, + **kwargs, ) -> XComArg: """Convert ExportToFileOperator into a function. 
Returns XComArg. @@ -170,6 +174,7 @@ def export_to_file( :param input_data: Input table / dataframe :param if_exists: Overwrite file if exists. Default "exception" :param task_id: task id, optional + :param export_options: Additional options to pass to the file export functions. """ task_id = task_id or get_unique_task_id("export_to_file") @@ -179,5 +184,6 @@ def export_to_file( output_file=output_file, input_data=input_data, if_exists=if_exists, + export_options=export_options, **kwargs, ).output diff --git a/python-sdk/tests/sql/operators/test_export_file.py b/python-sdk/tests/sql/operators/test_export_file.py index b474197e5..cb97b231f 100644 --- a/python-sdk/tests/sql/operators/test_export_file.py +++ b/python-sdk/tests/sql/operators/test_export_file.py @@ -37,6 +37,25 @@ def make_df(): assert df.equals(pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})) +def test_save_dataframe_to_local_with_options(sample_dag): + @aql.dataframe + def make_df(): + return pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}) + + with sample_dag: + df = make_df() + aql.export_to_file( + input_data=df, + output_file=File(path="/tmp/saved_df.csv"), + if_exists="replace", + export_options={"header": None}, + ) + test_utils.run_dag(sample_dag) + + df = pd.read_csv("/tmp/saved_df.csv") + assert df.equals(pd.DataFrame(data={"0": [1, 2], "1": [3, 4]})) + + @pytest.mark.parametrize("database_table_fixture", [{"database": Database.SQLITE}], indirect=True) def test_save_temp_table_to_local(sample_dag, database_table_fixture): _, test_table = database_table_fixture From 7eda15f8790a3f4d94f089058290d976c9f1cba3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Nov 2024 21:32:06 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python-sdk/src/astro/files/base.py | 4 +++- python-sdk/src/astro/files/types/csv.py | 4 +++- 
python-sdk/src/astro/files/types/excel.py | 4 +++- python-sdk/src/astro/files/types/json.py | 4 +++- python-sdk/src/astro/files/types/ndjson.py | 4 +++- python-sdk/src/astro/files/types/parquet.py | 4 +++- python-sdk/src/astro/sql/operators/export_to_file.py | 6 +++--- 7 files changed, 21 insertions(+), 9 deletions(-) diff --git a/python-sdk/src/astro/files/base.py b/python-sdk/src/astro/files/base.py index 2910d3098..6ace3ef05 100644 --- a/python-sdk/src/astro/files/base.py +++ b/python-sdk/src/astro/files/base.py @@ -114,7 +114,9 @@ def is_pattern(self) -> bool: """ return not pathlib.PosixPath(self.path).suffix - def create_from_dataframe(self, df: pd.DataFrame, store_as_dataframe: bool = True, export_options: dict | None = None) -> None: + def create_from_dataframe( + self, df: pd.DataFrame, store_as_dataframe: bool = True, export_options: dict | None = None + ) -> None: """Create a file in the desired location using the values of a dataframe. :param store_as_dataframe: Whether the data should later be deserialized as a dataframe or as a file containing diff --git a/python-sdk/src/astro/files/types/csv.py b/python-sdk/src/astro/files/types/csv.py index 2aae0bf7e..e830206ee 100644 --- a/python-sdk/src/astro/files/types/csv.py +++ b/python-sdk/src/astro/files/types/csv.py @@ -38,7 +38,9 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 + def create_from_dataframe( + self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs + ) -> None: # skipcq PYL-R0201 """Write csv file to one of the supported locations :param df: pandas dataframe diff --git a/python-sdk/src/astro/files/types/excel.py b/python-sdk/src/astro/files/types/excel.py index e4e6b434a..ae49b497f 100644 --- a/python-sdk/src/astro/files/types/excel.py +++ 
b/python-sdk/src/astro/files/types/excel.py @@ -37,7 +37,9 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 + def create_from_dataframe( + self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs + ) -> None: # skipcq PYL-R0201 """Write Excel file to one of the supported locations :param df: pandas dataframe diff --git a/python-sdk/src/astro/files/types/json.py b/python-sdk/src/astro/files/types/json.py index 18a153e70..8e5802bbb 100644 --- a/python-sdk/src/astro/files/types/json.py +++ b/python-sdk/src/astro/files/types/json.py @@ -42,7 +42,9 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 + def create_from_dataframe( + self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs + ) -> None: # skipcq PYL-R0201 """Write json file to one of the supported locations :param df: pandas dataframe diff --git a/python-sdk/src/astro/files/types/ndjson.py b/python-sdk/src/astro/files/types/ndjson.py index 94935167d..2a0bd63ed 100644 --- a/python-sdk/src/astro/files/types/ndjson.py +++ b/python-sdk/src/astro/files/types/ndjson.py @@ -39,7 +39,9 @@ def export_to_dataframe( return PandasDataframe.from_pandas_df(df) # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 + def create_from_dataframe( + self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs + ) -> None: # skipcq PYL-R0201 """Write ndjson file to one of the supported locations :param df: pandas 
dataframe diff --git a/python-sdk/src/astro/files/types/parquet.py b/python-sdk/src/astro/files/types/parquet.py index 1a61e7446..c6bf0b010 100644 --- a/python-sdk/src/astro/files/types/parquet.py +++ b/python-sdk/src/astro/files/types/parquet.py @@ -57,7 +57,9 @@ def _convert_remote_file_to_byte_stream(stream) -> io.IOBase: return remote_obj_buffer # We need skipcq because it's a method overloading so we don't want to make it a static method - def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201 + def create_from_dataframe( + self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs + ) -> None: # skipcq PYL-R0201 """Write parquet file to one of the supported locations :param df: pandas dataframe diff --git a/python-sdk/src/astro/sql/operators/export_to_file.py b/python-sdk/src/astro/sql/operators/export_to_file.py index 27a00a409..fff2d3a5b 100644 --- a/python-sdk/src/astro/sql/operators/export_to_file.py +++ b/python-sdk/src/astro/sql/operators/export_to_file.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any - import pandas as pd from airflow.decorators.base import get_unique_task_id from airflow.models.xcom_arg import XComArg @@ -60,7 +58,9 @@ def execute(self, context: Context) -> File: # skipcq PYL-W0613 raise ValueError(f"Expected input_table to be Table or dataframe. Got {type(self.input_data)}") # Write file if overwrite == True or if file doesn't exist. if self.if_exists == "replace" or not self.output_file.exists(): - self.output_file.create_from_dataframe(df, store_as_dataframe=False, export_options=self.export_options) + self.output_file.create_from_dataframe( + df, store_as_dataframe=False, export_options=self.export_options + ) return self.output_file else: raise FileExistsError(f"{self.output_file.path} file already exists.")