diff --git a/databricks/koalas/internal.py b/databricks/koalas/internal.py
index 16687e9e9e..16af2a94ec 100644
--- a/databricks/koalas/internal.py
+++ b/databricks/koalas/internal.py
@@ -534,13 +534,15 @@ def __init__(
 
         index_dtypes = [
             spark_type_to_pandas_dtype(spark_frame.select(scol).schema[0].dataType)
-            if dtype is None
+            if dtype is None or dtype == np.dtype("object")
             else dtype
             for dtype, scol in zip(index_dtypes, index_spark_columns)
         ]
 
         assert all(
-            as_spark_type(dtype, raise_error=False) is not None for dtype in index_dtypes
+            isinstance(dtype, Dtype.__args__)  # type: ignore
+            and (dtype == np.dtype("object") or as_spark_type(dtype, raise_error=False) is not None)
+            for dtype in index_dtypes
         ), index_dtypes
 
         self._index_dtypes = index_dtypes
@@ -593,13 +595,15 @@ def __init__(
 
         data_dtypes = [
             spark_type_to_pandas_dtype(spark_frame.select(scol).schema[0].dataType)
-            if dtype is None
+            if dtype is None or dtype == np.dtype("object")
             else dtype
             for dtype, scol in zip(data_dtypes, data_spark_columns)
         ]
 
         assert all(
-            as_spark_type(dtype, raise_error=False) is not None for dtype in data_dtypes
+            isinstance(dtype, Dtype.__args__)  # type: ignore
+            and (dtype == np.dtype("object") or as_spark_type(dtype, raise_error=False) is not None)
+            for dtype in data_dtypes
         ), data_dtypes
 
         self._data_dtypes = data_dtypes