From 8ddddbcf9650aa7bc3dfd94c10da62546aa85150 Mon Sep 17 00:00:00 2001 From: yuzhenmao <57878927+yuzhenmao@users.noreply.github.com> Date: Wed, 30 Sep 2020 23:56:36 +0800 Subject: [PATCH] fix(eda): change dtype 'string' to 'object' --- dataprep/eda/create_report/formatter.py | 10 +++++++++- dataprep/eda/distribution/compute/__init__.py | 4 ++-- dataprep/eda/dtypes.py | 14 ++++++++++++++ dataprep/eda/missing/compute/__init__.py | 3 ++- 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/dataprep/eda/create_report/formatter.py b/dataprep/eda/create_report/formatter.py index f8afb9aa3..3e15d1294 100644 --- a/dataprep/eda/create_report/formatter.py +++ b/dataprep/eda/create_report/formatter.py @@ -18,7 +18,14 @@ from ..distribution.compute.overview import calc_stats from ..distribution.compute.univariate import cont_comps, nom_comps from ..distribution.render import format_cat_stats, format_num_stats, format_ov_stats -from ..dtypes import CATEGORICAL_DTYPES, Continuous, Nominal, detect_dtype, is_dtype +from ..dtypes import ( + CATEGORICAL_DTYPES, + Continuous, + Nominal, + detect_dtype, + is_dtype, + string_dtype_to_object, +) from ..intermediate import Intermediate from ..missing import render_missing from ..missing.compute.nullivariate import compute_missing_nullivariate @@ -51,6 +58,7 @@ def format_report( # pylint: disable=too-many-locals,too-many-statements with ProgressBar(minimum=1, disable=not progress): df = to_dask(df) + df = string_dtype_to_object(df) if mode == "basic": comps = format_basic(df) # elif mode == "full": diff --git a/dataprep/eda/distribution/compute/__init__.py b/dataprep/eda/distribution/compute/__init__.py index b5d9bd95a..e95cd8bea 100644 --- a/dataprep/eda/distribution/compute/__init__.py +++ b/dataprep/eda/distribution/compute/__init__.py @@ -5,7 +5,7 @@ import dask.dataframe as dd import pandas as pd -from ...dtypes import DTypeDef +from ...dtypes import DTypeDef, string_dtype_to_object from ...intermediate import Intermediate from ...utils import to_dask from .bivariate import compute_bivariate @@ -93,9 +93,9 @@ def compute( dtype = {"a": Continuous(), "b": "nominal"} or dtype = Continuous() or dtype = "Continuous" or dtype = Continuous() """ # pylint: disable=too-many-locals - df = to_dask(df) df.columns = df.columns.astype(str) + df = string_dtype_to_object(df) if not any((x, y, z)): return compute_overview(df, bins, ngroups, largest, timeunit, dtype) diff --git a/dataprep/eda/dtypes.py b/dataprep/eda/dtypes.py index 3ae1d29cb..248825635 100644 --- a/dataprep/eda/dtypes.py +++ b/dataprep/eda/dtypes.py @@ -15,6 +15,9 @@ CATEGORICAL_PANDAS_DTYPES = [pd.CategoricalDtype, pd.PeriodDtype] CATEGORICAL_DTYPES = CATEGORICAL_NUMPY_DTYPES + CATEGORICAL_PANDAS_DTYPES +STRING_PANDAS_DTYPES = [pd.StringDtype] +STRING_DTYPES = STRING_PANDAS_DTYPES + NUMERICAL_NUMPY_DTYPES = [np.number] NUMERICAL_DTYPES = NUMERICAL_NUMPY_DTYPES @@ -256,6 +259,17 @@ def is_pandas_categorical(dtype: Any) -> bool: return any(isinstance(dtype, c) for c in CATEGORICAL_PANDAS_DTYPES) +def string_dtype_to_object(df: dd.DataFrame) -> dd.DataFrame: + """ + Convert string dtype to object dtype + """ + for col in df.columns: + if any(isinstance(df[col].dtype, c) for c in STRING_DTYPES): + df[col] = df[col].astype(object) + + return df + + def drop_null( var: Union[dd.Series, pd.DataFrame, dd.DataFrame] ) -> Union[pd.Series, dd.Series, pd.DataFrame, dd.DataFrame]: diff --git a/dataprep/eda/missing/compute/__init__.py b/dataprep/eda/missing/compute/__init__.py index dbf9ad693..757ab9227 100644 --- a/dataprep/eda/missing/compute/__init__.py +++ b/dataprep/eda/missing/compute/__init__.py @@ -5,7 +5,7 @@ from warnings import catch_warnings, filterwarnings from ...data_array import DataArray, DataFrame -from ...dtypes import DTypeDef +from ...dtypes import DTypeDef, string_dtype_to_object from ...intermediate import Intermediate from .bivariate import compute_missing_bivariate from .nullivariate import compute_missing_nullivariate @@ -53,6 +53,7 @@ def compute_missing( # pylint: disable=too-many-arguments >>> plot_missing(df, "HDI_for_year") >>> plot_missing(df, "HDI_for_year", "population") """ + df = string_dtype_to_object(df) df = DataArray(df) # pylint: disable=no-else-raise