From 6cbb3586b92112fc1f27a179b201ac924e5600cc Mon Sep 17 00:00:00 2001 From: Roman Shaptala Date: Fri, 12 Nov 2021 04:24:46 +0200 Subject: [PATCH] [python] Faster categorical column names selection (#4787) * Faster categorical column names selection (#1) * Faster categorical column names selection Change slow and redundant dataframe query by select_dtypes into a dataframe.dtypes list comprehension * Update compat with CategoricalDtype * sort imports * import CategoricalDtype from pandas.api.types * add categorical import try/except --- python-package/lightgbm/basic.py | 5 +++-- python-package/lightgbm/compat.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 569d9383f680..15e373baab04 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -17,7 +17,8 @@ import numpy as np import scipy.sparse -from .compat import PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_DataFrame, pd_Series +from .compat import (PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_CategoricalDtype, pd_DataFrame, + pd_Series) from .libpath import find_lib_path ZERO_THRESHOLD = 1e-35 @@ -567,7 +568,7 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica raise ValueError('Input data must be 2 dimensional and non empty.') if feature_name == 'auto' or feature_name is None: data = data.rename(columns=str) - cat_cols = list(data.select_dtypes(include=['category']).columns) + cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)] cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered] if pandas_categorical is None: # train dataset pandas_categorical = [list(data[col].cat.categories) for col in cat_cols] diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 1b58d256a651..3573a99dd623 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -7,6 +7,10 @@ from pandas import Series as pd_Series from pandas import concat from pandas.api.types import is_sparse as is_dtype_sparse + try: + from pandas import CategoricalDtype as pd_CategoricalDtype + except ImportError: + from pandas.api.types import CategoricalDtype as pd_CategoricalDtype PANDAS_INSTALLED = True except ImportError: PANDAS_INSTALLED = False @@ -23,6 +27,12 @@ class pd_DataFrame: # type: ignore def __init__(self, *args, **kwargs): pass + class pd_CategoricalDtype: + """Dummy class for pandas.CategoricalDtype.""" + + def __init__(self, *args, **kwargs): + pass + concat = None is_dtype_sparse = None