Consolidate import and usage of pandas #33480

Merged (1 commit) on Aug 17, 2023
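All files in the diff below converge on one convention: pandas is referenced through the `pd` alias (`import pandas as pd`) instead of a mix of `import pandas` and `from pandas import DataFrame`. Annotations keep the import under `TYPE_CHECKING`, and modules that treat pandas as an optional dependency still import it lazily at call time. A minimal sketch of the pattern, with an invented helper name for illustration:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # seen only by the type checker; nothing is imported at runtime here
    import pandas as pd


def to_dataframe(rows: list[tuple]) -> pd.DataFrame:
    # lazy runtime import keeps pandas an optional dependency of the module
    import pandas as pd

    return pd.DataFrame(rows)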
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -29,7 +29,7 @@
from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
from pandas import DataFrame
import pandas as pd

from airflow.utils.context import Context

@@ -134,15 +134,15 @@ def __init__(
raise AirflowException(f"The argument file_format doesn't support {file_format} value.")

@staticmethod
def _fix_dtypes(df: DataFrame, file_format: FILE_FORMAT) -> None:
def _fix_dtypes(df: pd.DataFrame, file_format: FILE_FORMAT) -> None:
"""
Mutate DataFrame to set dtypes for float columns containing NaN values.

Set dtype of object to str to allow for downstream transformations.
"""
try:
import numpy as np
from pandas import Float64Dtype, Int64Dtype
import pandas as pd
except ImportError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

@@ -163,13 +163,13 @@ def _fix_dtypes(df: DataFrame, file_format: FILE_FORMAT) -> None:
# The type ignore can be removed here if https://github.com/numpy/numpy/pull/23690
# is merged and released as currently NumPy does not consider None as valid for x/y.
df[col] = np.where(df[col].isnull(), None, df[col]) # type: ignore[call-overload]
df[col] = df[col].astype(Int64Dtype())
df[col] = df[col].astype(pd.Int64Dtype())
elif np.isclose(notna_series, notna_series.astype(int)).all():
# set to float dtype that retains floats and supports NaNs
# The type ignore can be removed here if https://github.com/numpy/numpy/pull/23690
# is merged and released
df[col] = np.where(df[col].isnull(), None, df[col]) # type: ignore[call-overload]
df[col] = df[col].astype(Float64Dtype())
df[col] = df[col].astype(pd.Float64Dtype())

def execute(self, context: Context) -> None:
sql_hook = self._get_hook()
@@ -192,7 +192,7 @@ def execute(self, context: Context) -> None:
filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
)

def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]:
def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
"""Partition dataframe using pandas groupby() method."""
if not self.groupby_kwargs:
yield "", df
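The `_fix_dtypes` hunk above leans on pandas' nullable extension dtypes: a column of whole numbers that contains NaN is read back as float64, and casting it straight to a plain integer dtype fails on the missing values. A minimal sketch of the trick with invented data (not from the PR): NaN is first replaced with None so the subsequent cast to `pd.Int64Dtype()` keeps the missing value instead of raising.

import numpy as np
import pandas as pd

df = pd.DataFrame({"a": [1.0, 2.0, np.nan]})      # whole numbers, but float64 because of the NaN
col = df["a"]
notna = col.dropna()

if (notna == notna.astype(int)).all():            # every non-null value is integral
    df["a"] = np.where(col.isnull(), None, col)   # swap NaN for None ahead of the cast
    df["a"] = df["a"].astype(pd.Int64Dtype())     # nullable integer dtype preserves the gap

print(df["a"])  # dtype: Int64, missing value shown as <NA>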
10 changes: 5 additions & 5 deletions airflow/providers/apache/hive/hooks/hive.py
@@ -31,7 +31,7 @@
from airflow.exceptions import AirflowProviderDeprecationWarning

try:
import pandas
import pandas as pd
except ImportError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException

@@ -336,7 +336,7 @@ def test_hql(self, hql: str) -> None:

def load_df(
self,
df: pandas.DataFrame,
df: pd.DataFrame,
table: str,
field_dict: dict[Any, Any] | None = None,
delimiter: str = ",",
@@ -361,7 +361,7 @@ def load_df(
:param kwargs: passed to self.load_file
"""

def _infer_field_types_from_df(df: pandas.DataFrame) -> dict[Any, Any]:
def _infer_field_types_from_df(df: pd.DataFrame) -> dict[Any, Any]:
dtype_kind_hive_type = {
"b": "BOOLEAN", # boolean
"i": "BIGINT", # signed integer
@@ -1037,7 +1037,7 @@ def get_pandas_df( # type: ignore
schema: str = "default",
hive_conf: dict[Any, Any] | None = None,
**kwargs,
) -> pandas.DataFrame:
) -> pd.DataFrame:
"""
Get a pandas dataframe from a Hive query.

@@ -1056,5 +1056,5 @@ def get_pandas_df( # type: ignore
:return: pandas.DataFrame
"""
res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
df = pandas.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
return df
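`load_df` infers a Hive schema from the DataFrame when no `field_dict` is passed, keyed on the numpy dtype "kind" code of each column. A minimal sketch of that idea; only the BOOLEAN and BIGINT entries are visible in the hunk above, so the remaining entries and the dict comprehension are assumptions for illustration:

import pandas as pd

dtype_kind_hive_type = {
    "b": "BOOLEAN",   # boolean
    "i": "BIGINT",    # signed integer
    "u": "BIGINT",    # unsigned integer (assumed)
    "f": "DOUBLE",    # floating point (assumed)
    "O": "STRING",    # object / strings (assumed)
}

df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"], "score": [0.5, 0.7]})
field_dict = {col: dtype_kind_hive_type.get(dtype.kind, "STRING") for col, dtype in df.dtypes.items()}
print(field_dict)  # {'id': 'BIGINT', 'name': 'STRING', 'score': 'DOUBLE'}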
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -30,6 +30,7 @@
from datetime import datetime, timedelta
from typing import Any, Iterable, Mapping, NoReturn, Sequence, Union, cast

import pandas as pd
from aiohttp import ClientSession as ClientSession
from gcloud.aio.bigquery import Job, Table as Table_async
from google.api_core.page_iterator import HTTPIterator
@@ -49,7 +50,6 @@
from google.cloud.bigquery.table import EncryptionConfiguration, Row, RowIterator, Table, TableReference
from google.cloud.exceptions import NotFound
from googleapiclient.discovery import Resource, build
from pandas import DataFrame
from pandas_gbq import read_gbq
from pandas_gbq.gbq import GbqConnector # noqa
from requests import Session
@@ -244,7 +244,7 @@ def get_pandas_df(
parameters: Iterable | Mapping[str, Any] | None = None,
dialect: str | None = None,
**kwargs,
) -> DataFrame:
) -> pd.DataFrame:
"""Get a Pandas DataFrame for the BigQuery results.

The DbApiHook method must be overridden because Pandas doesn't support
6 changes: 3 additions & 3 deletions airflow/providers/presto/hooks/presto.py
@@ -158,7 +158,7 @@ def get_first(
raise PrestoException(e)

def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
import pandas
import pandas as pd

cursor = self.get_cursor()
try:
@@ -168,10 +168,10 @@ def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
raise PrestoException(e)
column_descriptions = cursor.description
if data:
df = pandas.DataFrame(data, **kwargs)
df = pd.DataFrame(data, **kwargs)
df.columns = [c[0] for c in column_descriptions]
else:
df = pandas.DataFrame(**kwargs)
df = pd.DataFrame(**kwargs)
return df

def insert_rows(
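The Presto hook's `get_pandas_df` (and the identical Trino version further down) builds the frame in two steps: rows first, then column names taken from the DB-API cursor description. A small standalone sketch, with `rows` and `description` standing in for `cursor.fetchall()` and `cursor.description`:

import pandas as pd

rows = [(1, "alice"), (2, "bob")]                 # would come from cursor.fetchall()
description = [("id", None), ("name", None)]      # would come from cursor.description

if rows:
    df = pd.DataFrame(rows)
    df.columns = [c[0] for c in description]      # first element of each description tuple is the column name
else:
    df = pd.DataFrame()                           # empty frame when the query returned nothing

print(df)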
4 changes: 2 additions & 2 deletions airflow/providers/slack/transfers/sql_to_slack.py
@@ -19,7 +19,7 @@
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence

from pandas import DataFrame
import pandas as pd
from tabulate import tabulate

from airflow.exceptions import AirflowException
@@ -70,7 +70,7 @@ def _get_hook(self) -> DbApiHook:
)
return hook

def _get_query_results(self) -> DataFrame:
def _get_query_results(self) -> pd.DataFrame:
sql_hook = self._get_hook()

self.log.info("Running SQL query: %s", self.sql)
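`sql_to_slack.py` imports `tabulate` alongside pandas, and the DataFrame returned by `_get_query_results` is presumably rendered as a text table before being posted to Slack. A rough sketch of that rendering step; the `tabulate` arguments here are assumptions and not part of this diff:

import pandas as pd
from tabulate import tabulate

df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
message = tabulate(df, headers="keys", tablefmt="pretty", showindex=False)  # plain-text table for the Slack message
print(message)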
6 changes: 3 additions & 3 deletions airflow/providers/trino/hooks/trino.py
@@ -178,7 +178,7 @@ def get_first(
def get_pandas_df(
self, sql: str = "", parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
): # type: ignore[override]
import pandas
import pandas as pd

cursor = self.get_cursor()
try:
@@ -188,10 +188,10 @@ def get_pandas_df(
raise TrinoException(e)
column_descriptions = cursor.description
if data:
df = pandas.DataFrame(data, **kwargs)
df = pd.DataFrame(data, **kwargs)
df.columns = [c[0] for c in column_descriptions]
else:
df = pandas.DataFrame(**kwargs)
df = pd.DataFrame(**kwargs)
return df

def insert_rows(
8 changes: 4 additions & 4 deletions airflow/serialization/serializers/pandas.py
@@ -28,19 +28,19 @@
deserializers = serializers

if TYPE_CHECKING:
from pandas import DataFrame
import pandas as pd

from airflow.serialization.serde import U

__version__ = 1


def serialize(o: object) -> tuple[U, str, int, bool]:
import pandas as pd
import pyarrow as pa
from pandas import DataFrame
from pyarrow import parquet as pq

if not isinstance(o, DataFrame):
if not isinstance(o, pd.DataFrame):
return "", "", 0, False

# for now, we *always* serialize into in memory
@@ -53,7 +53,7 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return buf.getvalue().hex().decode("utf-8"), qualname(o), __version__, True


def deserialize(classname: str, version: int, data: object) -> DataFrame:
def deserialize(classname: str, version: int, data: object) -> pd.DataFrame:
if version > __version__:
raise TypeError(f"serialized {version} of {classname} > {__version__}")

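The serializer round-trips DataFrames through an in-memory Parquet buffer; `serialize()` additionally hex-encodes the buffer, which is omitted here. A minimal sketch of the round trip (the `snappy` compression choice is an assumption):

import pandas as pd
import pyarrow as pa
from pyarrow import parquet as pq

df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})

buf = pa.BufferOutputStream()                                # serialize: DataFrame -> Parquet bytes in memory
pq.write_table(pa.Table.from_pandas(df), buf, compression="snappy")
raw = buf.getvalue()

restored = pq.read_table(pa.BufferReader(raw)).to_pandas()   # deserialize: bytes -> DataFrame
assert df.equals(restored)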
4 changes: 2 additions & 2 deletions tests/serialization/serializers/test_serializers.py
@@ -20,7 +20,7 @@
import decimal

import numpy
import pandas
import pandas as pd
import pendulum.tz
import pytest
from pendulum import DateTime
@@ -94,7 +94,7 @@ def test_params(self):
assert i["x"] == d["x"]

def test_pandas(self):
i = pandas.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
i = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
e = serialize(i)
d = deserialize(e)
assert i.equals(d)