Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add export_options to export_to_file operator #2196

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python-sdk/src/astro/files/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,22 @@ def is_pattern(self) -> bool:
"""
return not pathlib.PosixPath(self.path).suffix

def create_from_dataframe(self, df: pd.DataFrame, store_as_dataframe: bool = True) -> None:
def create_from_dataframe(
self, df: pd.DataFrame, store_as_dataframe: bool = True, export_options: dict | None = None
) -> None:
"""Create a file in the desired location using the values of a dataframe.

:param store_as_dataframe: Whether the data should later be deserialized as a dataframe or as a file containing
delimited data (e.g. csv, parquet, etc.).
:param df: pandas dataframe
:param export_options: additional arguments to pass to the underlying write functionality
"""

self.is_dataframe = store_as_dataframe
opts = export_options or {}

with self.location.get_stream() as stream:
self.type.create_from_dataframe(stream=stream, df=df)
self.type.create_from_dataframe(stream=stream, df=df, **opts)

@property
def openlineage_dataset_namespace(self) -> str:
Expand Down
3 changes: 2 additions & 1 deletion python-sdk/src/astro/files/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ def export_to_dataframe(self, stream, **kwargs) -> pd.DataFrame:
raise NotImplementedError

@abstractmethod
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None:
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None:
"""Write file to one of the supported locations

:param df: pandas dataframe
:param stream: file stream object
:param kwargs: additional arguments to pass to the underlying write functionality
"""
raise NotImplementedError

Expand Down
8 changes: 6 additions & 2 deletions python-sdk/src/astro/files/types/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,17 @@ def export_to_dataframe(
return PandasDataframe.from_pandas_df(df)

# We need skipcq because it's a method overloading so we don't want to make it a static method
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
def create_from_dataframe(
self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs
) -> None: # skipcq PYL-R0201
"""Write csv file to one of the supported locations

:param df: pandas dataframe
:param stream: file stream object
:param kwargs: additional arguments to pass to the pandas `to_csv` function
"""
df.to_csv(stream, index=False)

df.to_csv(stream, **dict(index=False, **kwargs))

@property
def name(self):
Expand Down
7 changes: 5 additions & 2 deletions python-sdk/src/astro/files/types/excel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@ def export_to_dataframe(
return PandasDataframe.from_pandas_df(df)

# We need skipcq because it's a method overloading so we don't want to make it a static method
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
def create_from_dataframe(
self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs
) -> None: # skipcq PYL-R0201
"""Write Excel file to one of the supported locations

:param df: pandas dataframe
:param stream: file stream object
:param kwargs: additional arguments to pass to the pandas `to_excel` function
"""
df.to_excel(stream, index=False)
df.to_excel(stream, **dict(index=False, **kwargs))
7 changes: 5 additions & 2 deletions python-sdk/src/astro/files/types/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,16 @@ def export_to_dataframe(
return PandasDataframe.from_pandas_df(df)

# We need skipcq because it's a method overloading so we don't want to make it a static method
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
def create_from_dataframe(
self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs
) -> None: # skipcq PYL-R0201
"""Write json file to one of the supported locations

:param df: pandas dataframe
:param stream: file stream object
:param kwargs: additional arguments to pass to the pandas `to_json` function
"""
df.to_json(stream, orient="records")
df.to_json(stream, **dict(orient="records", **kwargs))

@property
def name(self):
Expand Down
7 changes: 5 additions & 2 deletions python-sdk/src/astro/files/types/ndjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@ def export_to_dataframe(
return PandasDataframe.from_pandas_df(df)

# We need skipcq because it's a method overloading so we don't want to make it a static method
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
def create_from_dataframe(
self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs
) -> None: # skipcq PYL-R0201
"""Write ndjson file to one of the supported locations

:param df: pandas dataframe
:param stream: file stream object
:param kwargs: additional arguments to pass to the pandas `to_json` function
"""
df.to_json(stream, orient="records", lines=True)
df.to_json(stream, **dict(orient="records", lines=True, **kwargs))

@property
def name(self):
Expand Down
7 changes: 5 additions & 2 deletions python-sdk/src/astro/files/types/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,16 @@ def _convert_remote_file_to_byte_stream(stream) -> io.IOBase:
return remote_obj_buffer

# We need skipcq because it's a method overloading so we don't want to make it a static method
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
def create_from_dataframe(
self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs
) -> None: # skipcq PYL-R0201
"""Write parquet file to one of the supported locations

:param df: pandas dataframe
:param stream: file stream object
:param kwargs: additional arguments to pass to the pandas `to_parquet` method
"""
df.to_parquet(stream)
df.to_parquet(stream, **kwargs)

@property
def name(self):
Expand Down
16 changes: 11 additions & 5 deletions python-sdk/src/astro/sql/operators/export_to_file.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

from typing import Any

import pandas as pd
from airflow.decorators.base import get_unique_task_id
from airflow.models.xcom_arg import XComArg
Expand All @@ -21,20 +19,23 @@ class ExportToFileOperator(AstroSQLBaseOperator):
:param input_data: Table to convert to file
:param output_file: File object containing the path to the file and connection id.
:param if_exists: Overwrite file if exists. Default False.
:param export_options: Additional options to pass to the file export functions.
"""

template_fields = ("input_data", "output_file")
template_fields = ("input_data", "output_file", "export_options")

def __init__(
self,
input_data: BaseTable | pd.DataFrame,
output_file: File,
if_exists: ExportExistsStrategy = "exception",
export_options: dict | None = None,
**kwargs,
) -> None:
self.output_file = output_file
self.input_data = input_data
self.if_exists = if_exists
self.export_options = export_options or {}
self.kwargs = kwargs
datasets = {"output_datasets": self.output_file}
if isinstance(input_data, Table):
Expand All @@ -57,7 +58,9 @@ def execute(self, context: Context) -> File: # skipcq PYL-W0613
raise ValueError(f"Expected input_table to be Table or dataframe. Got {type(self.input_data)}")
# Write file if overwrite == True or if file doesn't exist.
if self.if_exists == "replace" or not self.output_file.exists():
self.output_file.create_from_dataframe(df, store_as_dataframe=False)
self.output_file.create_from_dataframe(
df, store_as_dataframe=False, export_options=self.export_options
)
return self.output_file
else:
raise FileExistsError(f"{self.output_file.path} file already exists.")
Expand Down Expand Up @@ -144,7 +147,8 @@ def export_to_file(
output_file: File,
if_exists: ExportExistsStrategy = "exception",
task_id: str | None = None,
**kwargs: Any,
export_options: dict | None = None,
**kwargs,
) -> XComArg:
"""Convert ExportToFileOperator into a function. Returns XComArg.

Expand All @@ -170,6 +174,7 @@ def export_to_file(
:param input_data: Input table / dataframe
:param if_exists: Overwrite file if exists. Default "exception"
:param task_id: task id, optional
:param export_options: Additional options to pass to the file export functions.
"""

task_id = task_id or get_unique_task_id("export_to_file")
Expand All @@ -179,5 +184,6 @@ def export_to_file(
output_file=output_file,
input_data=input_data,
if_exists=if_exists,
export_options=export_options,
**kwargs,
).output
19 changes: 19 additions & 0 deletions python-sdk/tests/sql/operators/test_export_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ def make_df():
assert df.equals(pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}))


def test_save_dataframe_to_local_with_options(sample_dag):
@aql.dataframe
def make_df():
return pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})

with sample_dag:
df = make_df()
aql.export_to_file(
input_data=df,
output_file=File(path="/tmp/saved_df.csv"),
if_exists="replace",
export_options={"header": None},
)
test_utils.run_dag(sample_dag)

df = pd.read_csv("/tmp/saved_df.csv")
assert df.equals(pd.DataFrame(data={"0": [1, 2], "1": [3, 4]}))


@pytest.mark.parametrize("database_table_fixture", [{"database": Database.SQLITE}], indirect=True)
def test_save_temp_table_to_local(sample_dag, database_table_fixture):
_, test_table = database_table_fixture
Expand Down