Use load_options param to pass pandasOptions in load_file #1466

Merged · 23 commits · Jan 9, 2023

Changes from all commits

35 changes: 28 additions & 7 deletions python-sdk/src/astro/databases/base.py
@@ -11,6 +11,7 @@
from pandas.io.sql import SQLDatabase
from sqlalchemy import column, insert, select

from astro.dataframes.load_options import PandasLoadOptions
from astro.dataframes.pandas import PandasDataframe

if TYPE_CHECKING: # pragma: no cover
@@ -246,6 +247,7 @@ def create_table_using_schema_autodetection(
file: File | None = None,
dataframe: pd.DataFrame | None = None,
columns_names_capitalization: ColumnCapitalization = "original", # skipcq
load_options: LoadOptions | None = None,
) -> None:
"""
Create a SQL table, automatically inferring the schema using the given file.
@@ -263,7 +265,9 @@
)
source_dataframe = dataframe
else:
source_dataframe = file.export_to_dataframe(nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT)
source_dataframe = file.export_to_dataframe(
nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT, load_options=load_options
)

db = SQLDatabase(engine=self.sqlalchemy_engine)
db.prep_table(
@@ -291,6 +295,7 @@ def create_table(
file: File | None = None,
dataframe: pd.DataFrame | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
load_options: LoadOptions | None = None,
) -> None:
"""
Create a table either using its explicitly defined columns or inferring
@@ -307,7 +312,13 @@
elif file and self.is_native_autodetect_schema_available(file):
self.create_table_using_native_schema_autodetection(table, file)
else:
self.create_table_using_schema_autodetection(table, file, dataframe, columns_names_capitalization)
self.create_table_using_schema_autodetection(
table,
file=file,
dataframe=dataframe,
columns_names_capitalization=columns_names_capitalization,
load_options=load_options,
)

def create_table_from_select_statement(
self,
@@ -348,6 +359,7 @@ def create_schema_and_table_if_needed(
columns_names_capitalization: ColumnCapitalization = "original",
if_exists: LoadExistStrategy = "replace",
use_native_support: bool = True,
load_options: LoadOptions | None = None,
):
"""
Checks if the autodetect schema exists for native support, else creates the schema and table
@@ -387,6 +399,7 @@
# We only use the first file for inferring the table schema
files[0],
columns_names_capitalization=columns_names_capitalization,
load_options=load_options,
)

def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
@@ -442,6 +455,7 @@ def load_file_to_table(
columns_names_capitalization=columns_names_capitalization,
if_exists=if_exists,
normalize_config=normalize_config,
load_options=load_options,
)

if use_native_support and self.is_native_load_file_available(
@@ -456,6 +470,7 @@
native_support_kwargs=native_support_kwargs,
enable_native_fallback=enable_native_fallback,
chunk_size=chunk_size,
load_options=load_options,
)
else:
self.load_file_to_table_using_pandas(
@@ -464,20 +479,22 @@
normalize_config=normalize_config,
if_exists="append",
chunk_size=chunk_size,
load_options=load_options,
)

@staticmethod
def get_dataframe_from_file(file: File):
def get_dataframe_from_file(file: File, load_options: LoadOptions | None = None):
"""
Get a pandas dataframe from the file. We use export_to_dataframe() for BigQuery, Snowflake and Redshift, but not for Postgres.
For Postgres we override this method and use export_to_dataframe_via_byte_stream(),
which copies the file into a buffer and then uses that buffer to ingest the data.
This approach gives a significant performance boost for Postgres.

:param file: File path and conn_id for object stores
:param load_options: pandas options to use while reading the file
"""

return file.export_to_dataframe()
return file.export_to_dataframe(load_options=load_options)

@staticmethod
def _assert_not_empty_df(df):
@@ -495,6 +512,7 @@ def load_file_to_table_using_pandas(
normalize_config: dict | None = None,
if_exists: LoadExistStrategy = "replace",
chunk_size: int = DEFAULT_CHUNK_SIZE,
load_options: PandasLoadOptions | LoadOptions | None = None,
):
logging.info("Loading file(s) with Pandas...")
input_files = resolve_file_path_pattern(
@@ -506,7 +524,7 @@

for file in input_files:
self.load_pandas_dataframe_to_table(
self.get_dataframe_from_file(file),
self.get_dataframe_from_file(file, load_options),
output_table,
chunk_size=chunk_size,
if_exists=if_exists,
@@ -521,6 +539,7 @@ def load_file_to_table_natively_with_fallback(
native_support_kwargs: dict | None = None,
enable_native_fallback: bool | None = LOAD_FILE_ENABLE_NATIVE_FALLBACK,
chunk_size: int = DEFAULT_CHUNK_SIZE,
load_options: PandasLoadOptions | LoadOptions | None = None,
**kwargs,
):
"""
@@ -533,6 +552,7 @@
:param native_support_kwargs: kwargs to be used by method involved in native support flow
:param enable_native_fallback: Use enable_native_fallback=True to fall back to default transfer
:param normalize_config: pandas json_normalize params config
:param load_options: pandas options to use while reading the file
"""

try:
@@ -557,6 +577,7 @@
normalize_config=normalize_config,
if_exists=if_exists,
chunk_size=chunk_size,
load_options=load_options,
)
else:
raise load_exception
@@ -724,8 +745,8 @@ def schema_exists(self, schema: str) -> bool:
# ---------------------------------------------------------

def get_sqlalchemy_template_table_identifier_and_parameter(
self, table: BaseTable, jinja_table_identifier: str
) -> tuple[str, str]: # skipcq PYL-W0613
self, table: BaseTable, jinja_table_identifier: str # skipcq PYL-W0613
) -> tuple[str, str]:
"""
During the conversion from a Jinja-templated SQL query to a SQLAlchemy query, there is the need to
convert a Jinja table identifier to a safe SQLAlchemy-compatible table identifier.
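Taken together, the base.py changes thread a single optional parameter from the top-level load entry point down to the layer that actually reads the file, defaulting to None at every hop so existing callers keep working. A toy model of that pattern (a standalone sketch with made-up names, not the real classes):

```python
# Toy model of the parameter threading in base.py; FakeLoadOptions and the
# function names below are illustrative stand-ins, not the real astro classes.
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeLoadOptions:
    delimiter: str = ","


def get_dataframe_from_file(path: str, load_options: FakeLoadOptions | None = None) -> None:
    # Innermost layer: only here do the options influence behaviour.
    sep = load_options.delimiter if load_options else ","
    print(f"reading {path} with delimiter {sep!r}")


def load_file_to_table_using_pandas(path: str, load_options: FakeLoadOptions | None = None) -> None:
    get_dataframe_from_file(path, load_options)


def load_file_to_table(path: str, load_options: FakeLoadOptions | None = None) -> None:
    # Every layer defaults the parameter to None, so callers that
    # predate the feature are unaffected.
    load_file_to_table_using_pandas(path, load_options=load_options)


load_file_to_table("data.csv", FakeLoadOptions(delimiter=";"))
```
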
4 changes: 3 additions & 1 deletion python-sdk/src/astro/databases/postgres.py
@@ -12,6 +12,7 @@
from astro.constants import DEFAULT_CHUNK_SIZE, LoadExistStrategy, MergeConflictStrategy
from astro.databases.base import BaseDatabase
from astro.files import File
from astro.options import LoadOptions
from astro.settings import POSTGRES_SCHEMA
from astro.table import BaseTable, Metadata

@@ -200,11 +201,12 @@ def identifier_args(table: BaseTable):
self.run_sql(sql=sql)

@staticmethod
def get_dataframe_from_file(file: File):
def get_dataframe_from_file(file: File, load_options: LoadOptions | None = None): # skipcq: PYL-W0613
"""
Get a pandas dataframe from the file

:param file: File path and conn_id for object stores
:param load_options: pandas options to use while reading the file
"""
return file.export_to_dataframe_via_byte_stream()

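Note that the Postgres override accepts load_options only to keep its signature aligned with BaseDatabase.get_dataframe_from_file; the byte-stream path never consults pandas read options, hence the skipcq marker for the unused argument. A reduced sketch of that pattern (simplified, not the real classes):

```python
# Unused-argument-for-interface-parity: the subclass keeps the parameter so
# both methods can be called interchangeably, but deliberately ignores it.
class Base:
    @staticmethod
    def get_dataframe_from_file(file, load_options=None):
        return file.export_to_dataframe(load_options=load_options)


class Postgres(Base):
    @staticmethod
    def get_dataframe_from_file(file, load_options=None):  # intentionally unused
        # The faster byte-stream ingestion bypasses pandas read options.
        return file.export_to_dataframe_via_byte_stream()
```
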
6 changes: 5 additions & 1 deletion python-sdk/src/astro/databases/snowflake.py
@@ -38,6 +38,7 @@
from astro.databases.base import BaseDatabase
from astro.exceptions import DatabaseCustomError
from astro.files import File
from astro.options import LoadOptions
from astro.settings import LOAD_TABLE_AUTODETECT_ROWS_COUNT, SNOWFLAKE_SCHEMA
from astro.table import BaseTable, Metadata

@@ -518,6 +519,7 @@ def create_table_using_schema_autodetection(
file: File | None = None,
dataframe: pd.DataFrame | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
load_options: LoadOptions | None = None,
) -> None: # skipcq PYL-W0613
"""
Create a SQL table, automatically inferring the schema using the given file.
@@ -534,7 +536,9 @@
)
source_dataframe = dataframe
else:
source_dataframe = file.export_to_dataframe(nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT)
source_dataframe = file.export_to_dataframe(
nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT, load_options=load_options
)

# We are changing the case of the table name to ease the requirement of adding quotes in raw queries.
# TODO - Currently, we cannot append using load_file to a table whose name is in lower case.
32 changes: 32 additions & 0 deletions python-sdk/src/astro/dataframes/load_options.py
@@ -0,0 +1,32 @@
from __future__ import annotations

from attr import define
from pandas._typing import DtypeArg

from astro.options import LoadOptions


@define
class PandasLoadOptions(LoadOptions):
pass


@define
class PandasCsvLoadOptions(PandasLoadOptions):
delimiter: str | None = None
dtype: DtypeArg | None = None


@define
class PandasJsonLoadOptions(PandasLoadOptions):
encoding: str | None = None


@define
class PandasNdjsonLoadOptions(PandasLoadOptions):
normalize_sep: str = "_"


@define
class PandasParquetLoadOptions(PandasLoadOptions):
columns: list[str] | None = None
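
With these option classes in place, a DAG author can pass pandas reader settings through load_file. A hypothetical invocation, assuming load_file forwards a single options object as this diff suggests; the bucket path, connection IDs, and table name are placeholders:

```python
from astro import sql as aql
from astro.dataframes.load_options import PandasCsvLoadOptions
from astro.files import File
from astro.table import Table

# Hypothetical usage inside a DAG definition; path, conn_ids, and
# table name are made up for illustration.
movies = aql.load_file(
    input_file=File("s3://my-bucket/movies.csv", conn_id="aws_default"),
    output_table=Table(name="movies", conn_id="postgres_default"),
    load_options=PandasCsvLoadOptions(delimiter=";"),  # forwarded to pd.read_csv
)
```
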
8 changes: 6 additions & 2 deletions python-sdk/src/astro/files/base.py
@@ -10,9 +10,11 @@

from astro import constants
from astro.airflow.datasets import Dataset
from astro.dataframes.load_options import PandasLoadOptions
from astro.files.locations import create_file_location
from astro.files.locations.base import BaseFileLocation
from astro.files.types import FileType, create_file_type
from astro.options import LoadOptions


@define
@@ -124,11 +126,13 @@ def is_directory(self) -> bool:

return pathlib.Path(self.path).is_dir()

def export_to_dataframe(self, **kwargs) -> pd.DataFrame:
def export_to_dataframe(
self, load_options: LoadOptions | PandasLoadOptions | None = None, **kwargs
) -> pd.DataFrame:
"""Read file from all supported location and convert them into dataframes."""
mode = "rb" if self.is_binary() else "r"
with smart_open.open(self.path, mode=mode, transport_params=self.location.transport_params) as stream:
return self.type.export_to_dataframe(stream, **kwargs)
return self.type.export_to_dataframe(stream, load_options, **kwargs)

def _convert_remote_file_to_byte_stream(self) -> io.IOBase:
"""
8 changes: 7 additions & 1 deletion python-sdk/src/astro/files/types/base.py
@@ -5,6 +5,9 @@

import pandas as pd

from astro.dataframes.load_options import PandasLoadOptions
from astro.options import LoadOptions


class FileType(ABC):
"""Abstract File type class, meant to be the interface to all client code for all supported file types"""
@@ -14,10 +17,13 @@ def __init__(self, path: str, normalize_config: dict | None = None):
self.normalize_config = normalize_config

@abstractmethod
def export_to_dataframe(self, stream, **kwargs) -> pd.DataFrame:
def export_to_dataframe(
self, stream, load_options: LoadOptions | PandasLoadOptions | None = None, **kwargs
) -> pd.DataFrame:
"""read file from one of the supported locations and return dataframe

:param stream: file stream object
:param load_options: pandas options to pass to the pandas library while reading the file
"""
raise NotImplementedError

12 changes: 11 additions & 1 deletion python-sdk/src/astro/files/types/csv.py
@@ -2,11 +2,14 @@

import io

import attr
import pandas as pd

from astro.constants import FileType as FileTypeConstants
from astro.dataframes.load_options import PandasLoadOptions
from astro.dataframes.pandas import PandasDataframe
from astro.files.types.base import FileType
from astro.options import LoadOptions
from astro.utils.dataframe import convert_columns_names_capitalization


@@ -15,14 +18,21 @@ class CSVFileType(FileType):

# We need skipcq because it's a method override, so we don't want to make it a static method
def export_to_dataframe(
self, stream, columns_names_capitalization="original", **kwargs
self,
stream,
load_options: LoadOptions | PandasLoadOptions | None = None,
columns_names_capitalization="original",
**kwargs,
) -> pd.DataFrame: # skipcq PYL-R0201
"""read csv file from one of the supported locations and return dataframe

:param stream: file stream object
:param load_options: pandas options to pass to the pandas library while reading the CSV file
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
"""
if isinstance(load_options, PandasLoadOptions):
kwargs.update(attr.asdict(load_options))
df = pd.read_csv(stream, **kwargs)
df = convert_columns_names_capitalization(
df=df, columns_names_capitalization=columns_names_capitalization
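The merge in CSVFileType.export_to_dataframe is a plain dict update, so every attribute of the options object is forwarded to pd.read_csv, including fields left at None (which coincide with the pandas defaults here), and option values win over kwargs supplied by the caller. A standalone illustration of that behaviour, using a sketch class that mirrors the shape of PandasCsvLoadOptions:

```python
from __future__ import annotations

import attr


@attr.define
class CsvOptionsSketch:  # mirrors the shape of PandasCsvLoadOptions
    delimiter: str | None = None
    dtype: dict | None = None


kwargs = {"nrows": 100, "delimiter": ","}
kwargs.update(attr.asdict(CsvOptionsSketch(delimiter=";")))
print(kwargs)  # {'nrows': 100, 'delimiter': ';', 'dtype': None}
# The caller's delimiter was overridden, and the unset dtype=None is
# forwarded too; pd.read_csv accepts dtype=None as its default.
```
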
7 changes: 7 additions & 0 deletions python-sdk/src/astro/files/types/json.py
@@ -2,11 +2,14 @@

import io

import attr
import pandas as pd

from astro.constants import FileType as FileTypeConstants
from astro.dataframes.load_options import PandasLoadOptions
from astro.dataframes.pandas import PandasDataframe
from astro.files.types.base import FileType
from astro.options import LoadOptions
from astro.utils.dataframe import convert_columns_names_capitalization


@@ -17,18 +20,22 @@ class JSONFileType(FileType):
def export_to_dataframe(
self,
stream: io.TextIOWrapper,
load_options: LoadOptions | PandasLoadOptions | None = None,
columns_names_capitalization="original",
**kwargs,
) -> pd.DataFrame: # skipcq PYL-R0201
"""read json file from one of the supported locations and return dataframe

:param stream: file stream object
:param load_options: pandas options to pass to the pandas library while reading the JSON file
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
"""
kwargs_copy = dict(kwargs)
# Pandas `read_json` does not support the `nrows` parameter unless we're using NDJSON
kwargs_copy.pop("nrows", None)
if isinstance(load_options, PandasLoadOptions):
kwargs_copy.update(attr.asdict(load_options))
df = pd.read_json(stream, **kwargs_copy)
df = convert_columns_names_capitalization(
df=df, columns_names_capitalization=columns_names_capitalization
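The kwargs_copy.pop("nrows", None) guard reflects a real pandas constraint: read_json only accepts nrows for line-delimited (NDJSON) input and raises a ValueError otherwise. A quick demonstration of the pandas behaviour the comment refers to:

```python
import io

import pandas as pd

ndjson = io.StringIO('{"a": 1}\n{"a": 2}\n{"a": 3}\n')
print(pd.read_json(ndjson, lines=True, nrows=2))  # fine: reads two rows

plain = io.StringIO('[{"a": 1}, {"a": 2}]')
try:
    pd.read_json(plain, nrows=2)  # not line-delimited
except ValueError as exc:
    print(exc)  # "nrows can only be passed if lines=True"
```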