diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py
index 9dd3c41630..24ac1b7fce 100644
--- a/databricks/koalas/frame.py
+++ b/databricks/koalas/frame.py
@@ -3979,7 +3979,6 @@ def select_dtypes(self, include=None, exclude=None):
                 inc_ex=set(include).intersection(set(exclude))))
 
         # Handle Spark types
-        columns = []
         include_spark_type = []
         for inc in include:
             try:
@@ -4009,23 +4008,26 @@
             except:
                 pass
 
-        for col in self._internal.data_columns:
+        columns = []
+        column_index = []
+        for idx, col in zip(self._internal.column_index, self._internal.data_columns):
             if len(include) > 0:
                 should_include = (
-                    infer_dtype_from_object(self[col].dtype.name) in include_numpy_type or
-                    self._sdf.schema[col].dataType in include_spark_type)
+                    infer_dtype_from_object(self[idx].dtype.name) in include_numpy_type or
+                    self._internal.spark_type_for(col) in include_spark_type)
             else:
                 should_include = not (
-                    infer_dtype_from_object(self[col].dtype.name) in exclude_numpy_type or
-                    self._sdf.schema[col].dataType in exclude_spark_type)
+                    infer_dtype_from_object(self[idx].dtype.name) in exclude_numpy_type or
+                    self._internal.spark_type_for(col) in exclude_spark_type)
 
             if should_include:
-                columns += col
+                columns.append(col)
+                column_index.append(idx)
 
         return DataFrame(self._internal.copy(
             sdf=self._sdf.select(self._internal.index_scols +
-                                 [scol_for(self._sdf, col) for col in columns]),
-            data_columns=columns))
+                                 [self._internal.scol_for(col) for col in columns]),
+            data_columns=columns, column_index=column_index))
 
     def count(self, axis=None):
         """
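
The patch fixes two problems in select_dtypes: `columns += col` extended the result list with the individual characters of each column name rather than appending the name itself, and the multi-level column labels were dropped because `column_index` was never carried into the new frame. Below is a minimal usage sketch of the patched behavior, assuming a koalas build that includes this change; the DataFrame and its column labels are illustrative, not taken from the patch.

import pandas as pd
import databricks.koalas as ks

# A frame with MultiIndex columns mixing numeric and string dtypes.
pdf = pd.DataFrame(
    [[1, 1.0, 'x'], [2, 2.0, 'y']],
    columns=pd.MultiIndex.from_tuples([('num', 'a'), ('num', 'b'), ('str', 'c')]))
kdf = ks.from_pandas(pdf)

# With column_index threaded through alongside data_columns, the selected
# frame keeps its MultiIndex column labels intact.
selected = kdf.select_dtypes(include=['float'])
print(selected.columns)  # a MultiIndex containing only ('num', 'b')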