Skip to content

Commit

Permalink
Move imports in AWS SqlToS3Operator transfer to callable function (#2…
Browse files Browse the repository at this point in the history
…9045)

* Move imports to callable function

* Move DataFrame import to if TYPE_CHECKING block

---------

Co-authored-by: Niko Oliveira <[email protected]>
  • Loading branch information
joarobles and o-nikolas authored Jan 30, 2023
1 parent af0bbe6 commit 6282567
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions airflow/providers/amazon/aws/transfers/sql_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Iterable, Mapping, Sequence

try:
import numpy as np
import pandas as pd
except ImportError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException(e)
from typing_extensions import Literal

from airflow.exceptions import AirflowException
Expand All @@ -38,6 +31,8 @@
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
from pandas import DataFrame

from airflow.utils.context import Context


Expand Down Expand Up @@ -134,11 +129,19 @@ def __init__(
raise AirflowException(f"The argument file_format doesn't support {file_format} value.")

@staticmethod
def _fix_dtypes(df: pd.DataFrame, file_format: FILE_FORMAT) -> None:
def _fix_dtypes(df: DataFrame, file_format: FILE_FORMAT) -> None:
"""
Mutate DataFrame to set dtypes for float columns containing NaN values.
Set dtype of object to str to allow for downstream transformations.
"""
try:
import numpy as np
from pandas import Float64Dtype, Int64Dtype
except ImportError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException(e)

for col in df:

if df[col].dtype.name == "object" and file_format == "parquet":
Expand All @@ -152,11 +155,11 @@ def _fix_dtypes(df: pd.DataFrame, file_format: FILE_FORMAT) -> None:
if np.equal(notna_series, notna_series.astype(int)).all():
# set to dtype that retains integers and supports NaNs
df[col] = np.where(df[col].isnull(), None, df[col])
df[col] = df[col].astype(pd.Int64Dtype())
df[col] = df[col].astype(Int64Dtype())
elif np.isclose(notna_series, notna_series.astype(int)).all():
# set to float dtype that retains floats and supports NaNs
df[col] = np.where(df[col].isnull(), None, df[col])
df[col] = df[col].astype(pd.Float64Dtype())
df[col] = df[col].astype(Float64Dtype())

def execute(self, context: Context) -> None:
sql_hook = self._get_hook()
Expand Down

0 comments on commit 6282567

Please sign in to comment.