Add pandas_options in load_file (#1466)
# Description
closes: #1519

## What is the current behavior?
Currently, `load_file` does not have an option to pass pandas-related parameters while reading a file.

## What is the new behavior?
`load_file` now accepts `load_options` and passes the given values to pandas while reading files via the pandas path.
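
For illustration, a minimal sketch of the new surface, based on the hunks below. The `s3://` path and connection ID are hypothetical, and `File.export_to_dataframe` is called directly since the exact `load_file` operator signature is not shown in this diff:

```python
from astro.dataframes.load_options import PandasCsvLoadOptions
from astro.files import File

# Hypothetical object-store path and connection ID.
input_file = File("s3://my-bucket/data.psv", conn_id="aws_default")

# The options object is expanded into pandas.read_csv(**kwargs) by
# CSVFileType.export_to_dataframe (see the csv.py hunk below).
df = input_file.export_to_dataframe(load_options=PandasCsvLoadOptions(delimiter="|"))
```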

## Does this introduce a breaking change?
No

### Checklist
- [ ] Created tests which fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and utkarsharma2 committed Jan 17, 2023
1 parent ad0c4df commit 6850b44
Showing 17 changed files with 211 additions and 15 deletions.
35 changes: 28 additions & 7 deletions python-sdk/src/astro/databases/base.py
@@ -11,6 +11,7 @@
from pandas.io.sql import SQLDatabase
from sqlalchemy import column, insert, select

from astro.dataframes.load_options import PandasLoadOptions
from astro.dataframes.pandas import PandasDataframe

if TYPE_CHECKING: # pragma: no cover
@@ -246,6 +247,7 @@ def create_table_using_schema_autodetection(
file: File | None = None,
dataframe: pd.DataFrame | None = None,
columns_names_capitalization: ColumnCapitalization = "original", # skipcq
load_options: LoadOptions | None = None,
) -> None:
"""
Create a SQL table, automatically inferring the schema using the given file.
@@ -263,7 +265,9 @@
)
source_dataframe = dataframe
else:
source_dataframe = file.export_to_dataframe(nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT)
source_dataframe = file.export_to_dataframe(
nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT, load_options=load_options
)

db = SQLDatabase(engine=self.sqlalchemy_engine)
db.prep_table(
@@ -291,6 +295,7 @@ def create_table(
file: File | None = None,
dataframe: pd.DataFrame | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
load_options: LoadOptions | None = None,
) -> None:
"""
Create a table either using its explicitly defined columns or inferring
@@ -307,7 +312,13 @@
elif file and self.is_native_autodetect_schema_available(file):
self.create_table_using_native_schema_autodetection(table, file)
else:
self.create_table_using_schema_autodetection(table, file, dataframe, columns_names_capitalization)
self.create_table_using_schema_autodetection(
table,
file=file,
dataframe=dataframe,
columns_names_capitalization=columns_names_capitalization,
load_options=load_options,
)

def create_table_from_select_statement(
self,
@@ -348,6 +359,7 @@ def create_schema_and_table_if_needed(
columns_names_capitalization: ColumnCapitalization = "original",
if_exists: LoadExistStrategy = "replace",
use_native_support: bool = True,
load_options: LoadOptions | None = None,
):
"""
Checks if the autodetect schema exists for native support, else creates the schema and table
@@ -387,6 +399,7 @@
# We only use the first file for inferring the table schema
files[0],
columns_names_capitalization=columns_names_capitalization,
load_options=load_options,
)

def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
@@ -442,6 +455,7 @@ def load_file_to_table(
columns_names_capitalization=columns_names_capitalization,
if_exists=if_exists,
normalize_config=normalize_config,
load_options=load_options,
)

if use_native_support and self.is_native_load_file_available(
@@ -456,6 +470,7 @@
native_support_kwargs=native_support_kwargs,
enable_native_fallback=enable_native_fallback,
chunk_size=chunk_size,
load_options=load_options,
)
else:
self.load_file_to_table_using_pandas(
@@ -464,20 +479,22 @@
normalize_config=normalize_config,
if_exists="append",
chunk_size=chunk_size,
load_options=load_options,
)

@staticmethod
def get_dataframe_from_file(file: File):
def get_dataframe_from_file(file: File, load_options: LoadOptions | None = None):
"""
Get a pandas dataframe from the file. We use export_to_dataframe() for BigQuery, Snowflake and Redshift, but not for Postgres.
For Postgres we override this method and use export_to_dataframe_via_byte_stream(),
which copies the file into a buffer and then uses that buffer to ingest data.
This approach gives a significant performance boost for Postgres.
:param file: File path and conn_id for object stores
:param load_options: pandas options to use while reading the file
"""

return file.export_to_dataframe()
return file.export_to_dataframe(load_options=load_options)

@staticmethod
def _assert_not_empty_df(df):
@@ -495,6 +512,7 @@ def load_file_to_table_using_pandas(
normalize_config: dict | None = None,
if_exists: LoadExistStrategy = "replace",
chunk_size: int = DEFAULT_CHUNK_SIZE,
load_options: PandasLoadOptions | LoadOptions | None = None,
):
logging.info("Loading file(s) with Pandas...")
input_files = resolve_file_path_pattern(
@@ -506,7 +524,7 @@

for file in input_files:
self.load_pandas_dataframe_to_table(
self.get_dataframe_from_file(file),
self.get_dataframe_from_file(file, load_options),
output_table,
chunk_size=chunk_size,
if_exists=if_exists,
@@ -521,6 +539,7 @@ def load_file_to_table_natively_with_fallback(
native_support_kwargs: dict | None = None,
enable_native_fallback: bool | None = LOAD_FILE_ENABLE_NATIVE_FALLBACK,
chunk_size: int = DEFAULT_CHUNK_SIZE,
load_options: PandasLoadOptions | LoadOptions | None = None,
**kwargs,
):
"""
@@ -533,6 +552,7 @@
:param native_support_kwargs: kwargs to be used by the method involved in the native support flow
:param enable_native_fallback: Use enable_native_fallback=True to fall back to the default transfer
:param normalize_config: pandas json_normalize params config
:param load_options: pandas options to use while reading the file
"""

try:
@@ -557,6 +577,7 @@
normalize_config=normalize_config,
if_exists=if_exists,
chunk_size=chunk_size,
load_options=load_options,
)
else:
raise load_exception
@@ -724,8 +745,8 @@ def schema_exists(self, schema: str) -> bool:
# ---------------------------------------------------------

def get_sqlalchemy_template_table_identifier_and_parameter(
self, table: BaseTable, jinja_table_identifier: str
) -> tuple[str, str]: # skipcq PYL-W0613
self, table: BaseTable, jinja_table_identifier: str # skipcq PYL-W0613
) -> tuple[str, str]:
"""
During the conversion from a Jinja-templated SQL query to a SQLAlchemy query, there is the need to
convert a Jinja table identifier to a safe SQLAlchemy-compatible table identifier.
4 changes: 3 additions & 1 deletion python-sdk/src/astro/databases/postgres.py
@@ -12,6 +12,7 @@
from astro.constants import DEFAULT_CHUNK_SIZE, LoadExistStrategy, MergeConflictStrategy
from astro.databases.base import BaseDatabase
from astro.files import File
from astro.options import LoadOptions
from astro.settings import POSTGRES_SCHEMA
from astro.table import BaseTable, Metadata

@@ -200,11 +201,12 @@ def identifier_args(table: BaseTable):
self.run_sql(sql=sql)

@staticmethod
def get_dataframe_from_file(file: File):
def get_dataframe_from_file(file: File, load_options: LoadOptions | None = None): # skipcq: PYL-W0613
"""
Get pandas dataframe file
:param file: File path and conn_id for object stores
:param load_options: pandas options to use while reading the file
"""
return file.export_to_dataframe_via_byte_stream()

6 changes: 5 additions & 1 deletion python-sdk/src/astro/databases/snowflake.py
@@ -38,6 +38,7 @@
from astro.databases.base import BaseDatabase
from astro.exceptions import DatabaseCustomError
from astro.files import File
from astro.options import LoadOptions
from astro.settings import LOAD_TABLE_AUTODETECT_ROWS_COUNT, SNOWFLAKE_SCHEMA
from astro.table import BaseTable, Metadata

@@ -518,6 +519,7 @@ def create_table_using_schema_autodetection(
file: File | None = None,
dataframe: pd.DataFrame | None = None,
columns_names_capitalization: ColumnCapitalization = "original",
load_options: LoadOptions | None = None,
) -> None: # skipcq PYL-W0613
"""
Create a SQL table, automatically inferring the schema using the given file.
@@ -534,7 +536,9 @@
)
source_dataframe = dataframe
else:
source_dataframe = file.export_to_dataframe(nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT)
source_dataframe = file.export_to_dataframe(
nrows=LOAD_TABLE_AUTODETECT_ROWS_COUNT, load_options=load_options
)

# We are changing the case of the table name to ease the requirement of adding quotes in raw queries.
# TODO - Currently, we cannot append using load_file to a table whose name is in lower case.
32 changes: 32 additions & 0 deletions python-sdk/src/astro/dataframes/load_options.py
@@ -0,0 +1,32 @@
from __future__ import annotations

from attr import define
from pandas._typing import DtypeArg

from astro.options import LoadOptions


@define
class PandasLoadOptions(LoadOptions):
pass


@define
class PandasCsvLoadOptions(PandasLoadOptions):
delimiter: str | None = None
dtype: DtypeArg | None = None


@define
class PandasJsonLoadOptions(PandasLoadOptions):
encoding: str | None = None


@define
class PandasNdjsonLoadOptions(PandasLoadOptions):
normalize_sep: str = "_"


@define
class PandasParquetLoadOptions(PandasLoadOptions):
columns: list[str] | None = None
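
Each option class maps onto keyword arguments of the matching `pandas.read_*` call. As a quick sketch of the expansion mechanism used by the file-type classes further down, assuming the `LoadOptions` base class contributes no fields of its own:

```python
import attr

from astro.dataframes.load_options import PandasCsvLoadOptions

# attr.asdict converts the attrs-based options object into plain pandas kwargs;
# CSVFileType.export_to_dataframe then does kwargs.update(attr.asdict(load_options)).
options = PandasCsvLoadOptions(delimiter="|")
print(attr.asdict(options))  # {'delimiter': '|', 'dtype': None}
```

Note that unset fields are forwarded to pandas as explicit `None` kwargs, which for `read_csv` matches the pandas defaults.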
8 changes: 6 additions & 2 deletions python-sdk/src/astro/files/base.py
@@ -10,9 +10,11 @@

from astro import constants
from astro.airflow.datasets import Dataset
from astro.dataframes.load_options import PandasLoadOptions
from astro.files.locations import create_file_location
from astro.files.locations.base import BaseFileLocation
from astro.files.types import FileType, create_file_type
from astro.options import LoadOptions


@define
@@ -124,11 +126,13 @@ def is_directory(self) -> bool:

return pathlib.Path(self.path).is_dir()

def export_to_dataframe(self, **kwargs) -> pd.DataFrame:
def export_to_dataframe(
self, load_options: LoadOptions | PandasLoadOptions | None = None, **kwargs
) -> pd.DataFrame:
"""Read file from all supported location and convert them into dataframes."""
mode = "rb" if self.is_binary() else "r"
with smart_open.open(self.path, mode=mode, transport_params=self.location.transport_params) as stream:
return self.type.export_to_dataframe(stream, **kwargs)
return self.type.export_to_dataframe(stream, load_options, **kwargs)

def _convert_remote_file_to_byte_stream(self) -> io.IOBase:
"""
8 changes: 7 additions & 1 deletion python-sdk/src/astro/files/types/base.py
@@ -5,6 +5,9 @@

import pandas as pd

from astro.dataframes.load_options import PandasLoadOptions
from astro.options import LoadOptions


class FileType(ABC):
"""Abstract File type class, meant to be the interface to all client code for all supported file types"""
@@ -14,10 +17,13 @@ def __init__(self, path: str, normalize_config: dict | None = None):
self.normalize_config = normalize_config

@abstractmethod
def export_to_dataframe(self, stream, **kwargs) -> pd.DataFrame:
def export_to_dataframe(
self, stream, load_options: LoadOptions | PandasLoadOptions | None = None, **kwargs
) -> pd.DataFrame:
"""read file from one of the supported locations and return dataframe
:param stream: file stream object
:param load_options: Pandas option to pass to the Pandas lib while reading file
"""
raise NotImplementedError

12 changes: 11 additions & 1 deletion python-sdk/src/astro/files/types/csv.py
@@ -2,11 +2,14 @@

import io

import attr
import pandas as pd

from astro.constants import FileType as FileTypeConstants
from astro.dataframes.load_options import PandasLoadOptions
from astro.dataframes.pandas import PandasDataframe
from astro.files.types.base import FileType
from astro.options import LoadOptions
from astro.utils.dataframe import convert_columns_names_capitalization


@@ -15,14 +18,21 @@ class CSVFileType(FileType):

# We need skipcq because it's a method override, so we don't want to make it a static method
def export_to_dataframe(
self, stream, columns_names_capitalization="original", **kwargs
self,
stream,
load_options: LoadOptions | PandasLoadOptions | None = None,
columns_names_capitalization="original",
**kwargs,
) -> pd.DataFrame: # skipcq PYL-R0201
"""read csv file from one of the supported locations and return dataframe
:param stream: file stream object
:param load_options: Pandas option to pass to the Pandas lib while reading csv
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
"""
if isinstance(load_options, PandasLoadOptions):
kwargs.update(attr.asdict(load_options))
df = pd.read_csv(stream, **kwargs)
df = convert_columns_names_capitalization(
df=df, columns_names_capitalization=columns_names_capitalization
7 changes: 7 additions & 0 deletions python-sdk/src/astro/files/types/json.py
@@ -2,11 +2,14 @@

import io

import attr
import pandas as pd

from astro.constants import FileType as FileTypeConstants
from astro.dataframes.load_options import PandasLoadOptions
from astro.dataframes.pandas import PandasDataframe
from astro.files.types.base import FileType
from astro.options import LoadOptions
from astro.utils.dataframe import convert_columns_names_capitalization


@@ -17,18 +20,22 @@
def export_to_dataframe(
self,
stream: io.TextIOWrapper,
load_options: LoadOptions | PandasLoadOptions | None = None,
columns_names_capitalization="original",
**kwargs,
) -> pd.DataFrame: # skipcq PYL-R0201
"""read json file from one of the supported locations and return dataframe
:param stream: file stream object
:param load_options: Pandas option to pass to the Pandas lib while reading json
:param columns_names_capitalization: determines whether to convert all columns to lowercase/uppercase
in the resulting dataframe
"""
kwargs_copy = dict(kwargs)
# Pandas `read_json` does not support the `nrows` parameter unless we're using NDJSON
kwargs_copy.pop("nrows", None)
if isinstance(load_options, PandasLoadOptions):
kwargs_copy.update(attr.asdict(load_options))
df = pd.read_json(stream, **kwargs_copy)
df = convert_columns_names_capitalization(
df=df, columns_names_capitalization=columns_names_capitalization
9 more changed files not shown.