Skip to content

Commit

Permalink
Add export_options to export_to_file operator
Browse files Browse the repository at this point in the history
The current implementation of the export_to_file operator doesn't provide any mechanism for overriding or customizing how the output file is written. Instead, the operator simply calls one of several pandas.DataFrame.to_* methods with default arguments that SDK users are locked into.

This modification lets callers supply a configuration dictionary (export_options) whose entries are passed as keyword arguments to the FileType.create_from_dataframe implementation, so the file-write behavior can be customized or overridden.
  • Loading branch information
jordanyakerirt committed Nov 22, 2024
1 parent 33ca675 commit fece91a
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 16 deletions.
6 changes: 4 additions & 2 deletions python-sdk/src/astro/files/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,20 @@ def is_pattern(self) -> bool:
"""
return not pathlib.PosixPath(self.path).suffix

def create_from_dataframe(self, df: pd.DataFrame, store_as_dataframe: bool = True) -> None:
def create_from_dataframe(self, df: pd.DataFrame, store_as_dataframe: bool = True, export_options: dict | None = None) -> None:
"""Create a file in the desired location using the values of a dataframe.

Records ``store_as_dataframe`` on ``self.is_dataframe``, opens a stream from ``self.location`` and delegates
the actual write to ``self.type.create_from_dataframe``.

:param store_as_dataframe: Whether the data should later be deserialized as a dataframe or as a file containing
delimited data (e.g. csv, parquet, etc.).
:param df: pandas dataframe
:param export_options: additional arguments to pass to the underlying write functionality; the entries are
unpacked as keyword arguments of ``self.type.create_from_dataframe``. ``None`` (the default) or an empty
dict passes no extra arguments.
"""

self.is_dataframe = store_as_dataframe
opts = export_options or {}

with self.location.get_stream() as stream:
self.type.create_from_dataframe(stream=stream, df=df)
self.type.create_from_dataframe(stream=stream, df=df, **opts)

@property
def openlineage_dataset_namespace(self) -> str:
Expand Down
3 changes: 2 additions & 1 deletion python-sdk/src/astro/files/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ def export_to_dataframe(self, stream, **kwargs) -> pd.DataFrame:
raise NotImplementedError

@abstractmethod
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None:
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None:
"""Write file to one of the supported locations

Abstract hook: this base version only raises ``NotImplementedError``.

:param df: pandas dataframe
:param stream: file stream object
:param kwargs: additional arguments to pass to the underlying write functionality
:raises NotImplementedError: always, in this base implementation
"""
raise NotImplementedError

Expand Down
6 changes: 4 additions & 2 deletions python-sdk/src/astro/files/types/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ def export_to_dataframe(
return PandasDataframe.from_pandas_df(df)

# We need skipcq because it's a method overloading so we don't want to make it a static method
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None:  # skipcq PYL-R0201
    """Write csv file to one of the supported locations.

    The row index is omitted by default (``index=False``). Any keyword supplied
    through ``kwargs`` takes precedence over that default, so ``index=True``
    re-enables the index column.

    :param df: pandas dataframe
    :param stream: file stream object
    :param kwargs: additional arguments to pass to the pandas `to_csv` function
    """
    # Merge with a dict display so caller-supplied keys override the defaults.
    # ``dict(index=False, **kwargs)`` raises TypeError when ``kwargs`` also
    # contains ``index``, which made the default impossible to override.
    df.to_csv(stream, **{"index": False, **kwargs})

@property
def name(self):
Expand Down
5 changes: 3 additions & 2 deletions python-sdk/src/astro/files/types/excel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ def export_to_dataframe(
return PandasDataframe.from_pandas_df(df)

# We need skipcq because it's a method overloading so we don't want to make it a static method
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None:  # skipcq PYL-R0201
    """Write Excel file to one of the supported locations.

    The row index is omitted by default (``index=False``). Any keyword supplied
    through ``kwargs`` takes precedence over that default.

    :param df: pandas dataframe
    :param stream: file stream object
    :param kwargs: additional arguments to pass to the pandas `to_excel` function
    """
    # Merge with a dict display so caller-supplied keys override the defaults.
    # ``dict(index=False, **kwargs)`` raises TypeError when ``kwargs`` also
    # contains ``index``.
    df.to_excel(stream, **{"index": False, **kwargs})
5 changes: 3 additions & 2 deletions python-sdk/src/astro/files/types/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ def export_to_dataframe(
return PandasDataframe.from_pandas_df(df)

# We need skipcq because it's a method overloading so we don't want to make it a static method
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None:  # skipcq PYL-R0201
    """Write json file to one of the supported locations.

    Rows are serialized with ``orient="records"`` by default. Any keyword
    supplied through ``kwargs`` takes precedence over that default.

    :param df: pandas dataframe
    :param stream: file stream object
    :param kwargs: additional arguments to pass to the pandas `to_json` function
    """
    # Merge with a dict display so caller-supplied keys override the defaults.
    # ``dict(orient="records", **kwargs)`` raises TypeError when ``kwargs``
    # also contains ``orient``.
    df.to_json(stream, **{"orient": "records", **kwargs})

@property
def name(self):
Expand Down
5 changes: 3 additions & 2 deletions python-sdk/src/astro/files/types/ndjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ def export_to_dataframe(
return PandasDataframe.from_pandas_df(df)

# We need skipcq because it's a method overloading so we don't want to make it a static method
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None:  # skipcq PYL-R0201
    """Write ndjson file to one of the supported locations.

    Rows are serialized with ``orient="records"`` and ``lines=True`` by
    default (one JSON object per line). Any keyword supplied through
    ``kwargs`` takes precedence over those defaults.

    :param df: pandas dataframe
    :param stream: file stream object
    :param kwargs: additional arguments to pass to the pandas `to_json` function
    """
    # Merge with a dict display so caller-supplied keys override the defaults.
    # ``dict(orient="records", lines=True, **kwargs)`` raises TypeError when
    # ``kwargs`` repeats either key.
    df.to_json(stream, **{"orient": "records", "lines": True, **kwargs})

@property
def name(self):
Expand Down
5 changes: 3 additions & 2 deletions python-sdk/src/astro/files/types/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ def _convert_remote_file_to_byte_stream(stream) -> io.IOBase:
return remote_obj_buffer

# We need skipcq because it's a method overloading so we don't want to make it a static method
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201
"""Write parquet file to one of the supported locations

No defaults are set here: ``kwargs`` is forwarded to ``to_parquet`` unchanged.

:param df: pandas dataframe
:param stream: file stream object
:param kwargs: additional arguments to pass to the pandas `to_parquet` method
"""
# NOTE(review): stream is annotated as a text wrapper, but parquet is a binary format — confirm the
# stream handed in by callers is opened in binary mode.
df.to_parquet(stream)
df.to_parquet(stream, **kwargs)

@property
def name(self):
Expand Down
12 changes: 9 additions & 3 deletions python-sdk/src/astro/sql/operators/export_to_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,23 @@ class ExportToFileOperator(AstroSQLBaseOperator):
:param input_data: Table to convert to file
:param output_file: File object containing the path to the file and connection id.
:param if_exists: Overwrite file if exists. Default False.
:param export_options: Additional options to pass to the file export functions.
"""

template_fields = ("input_data", "output_file")
template_fields = ("input_data", "output_file", "export_options")

def __init__(
self,
input_data: BaseTable | pd.DataFrame,
output_file: File,
if_exists: ExportExistsStrategy = "exception",
export_options: dict | None = None,
**kwargs,
) -> None:
self.output_file = output_file
self.input_data = input_data
self.if_exists = if_exists
self.export_options = export_options or {}
self.kwargs = kwargs
datasets = {"output_datasets": self.output_file}
if isinstance(input_data, Table):
Expand All @@ -57,7 +60,7 @@ def execute(self, context: Context) -> File: # skipcq PYL-W0613
raise ValueError(f"Expected input_table to be Table or dataframe. Got {type(self.input_data)}")
# Write file if overwrite == True or if file doesn't exist.
if self.if_exists == "replace" or not self.output_file.exists():
self.output_file.create_from_dataframe(df, store_as_dataframe=False)
self.output_file.create_from_dataframe(df, store_as_dataframe=False, export_options=self.export_options)
return self.output_file
else:
raise FileExistsError(f"{self.output_file.path} file already exists.")
Expand Down Expand Up @@ -144,7 +147,8 @@ def export_to_file(
output_file: File,
if_exists: ExportExistsStrategy = "exception",
task_id: str | None = None,
**kwargs: Any,
export_options: dict | None = None,
**kwargs,
) -> XComArg:
"""Convert ExportToFileOperator into a function. Returns XComArg.
Expand All @@ -170,6 +174,7 @@ def export_to_file(
:param input_data: Input table / dataframe
:param if_exists: Overwrite file if exists. Default "exception"
:param task_id: task id, optional
:param export_options: Additional options to pass to the file export functions.
"""

task_id = task_id or get_unique_task_id("export_to_file")
Expand All @@ -179,5 +184,6 @@ def export_to_file(
output_file=output_file,
input_data=input_data,
if_exists=if_exists,
export_options=export_options,
**kwargs,
).output
19 changes: 19 additions & 0 deletions python-sdk/tests/sql/operators/test_export_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ def make_df():
assert df.equals(pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}))


def test_save_dataframe_to_local_with_options(sample_dag):
    """Check that ``export_options`` reaches the pandas writer.

    The dataframe is exported with ``header=None``, so the csv file contains
    only the data rows. It must therefore be read back with ``header=None``
    as well; otherwise pandas consumes the first data row as the header.
    """

    @aql.dataframe
    def make_df():
        return pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})

    with sample_dag:
        df = make_df()
        aql.export_to_file(
            input_data=df,
            output_file=File(path="/tmp/saved_df.csv"),
            if_exists="replace",
            export_options={"header": None},
        )
    test_utils.run_dag(sample_dag)

    # A headerless read labels the columns with the integers 0 and 1, not
    # the strings "0" and "1".
    df = pd.read_csv("/tmp/saved_df.csv", header=None)
    assert df.equals(pd.DataFrame(data={0: [1, 2], 1: [3, 4]}))


@pytest.mark.parametrize("database_table_fixture", [{"database": Database.SQLITE}], indirect=True)
def test_save_temp_table_to_local(sample_dag, database_table_fixture):
_, test_table = database_table_fixture
Expand Down

0 comments on commit fece91a

Please sign in to comment.