From 7968757242445ea27fe8f0d3447a5f72970b10ba Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Wed, 10 Mar 2021 18:52:51 -0800
Subject: [PATCH 1/2] Check index_dtypes and data_dtypes more strictly.

---
 databricks/koalas/internal.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/databricks/koalas/internal.py b/databricks/koalas/internal.py
index 16687e9e9e..3c84fe1851 100644
--- a/databricks/koalas/internal.py
+++ b/databricks/koalas/internal.py
@@ -534,13 +534,15 @@ def __init__(
 
         index_dtypes = [
             spark_type_to_pandas_dtype(spark_frame.select(scol).schema[0].dataType)
-            if dtype is None
+            if dtype is None or dtype == np.dtype("object")
             else dtype
             for dtype, scol in zip(index_dtypes, index_spark_columns)
         ]
 
         assert all(
-            as_spark_type(dtype, raise_error=False) is not None for dtype in index_dtypes
+            isinstance(dtype, Dtype.__args__)
+            and (dtype == np.dtype("object") or as_spark_type(dtype, raise_error=False) is not None)
+            for dtype in index_dtypes
         ), index_dtypes
 
         self._index_dtypes = index_dtypes
@@ -593,13 +595,15 @@ def __init__(
 
         data_dtypes = [
             spark_type_to_pandas_dtype(spark_frame.select(scol).schema[0].dataType)
-            if dtype is None
+            if dtype is None or dtype == np.dtype("object")
             else dtype
             for dtype, scol in zip(data_dtypes, data_spark_columns)
         ]
 
         assert all(
-            as_spark_type(dtype, raise_error=False) is not None for dtype in data_dtypes
+            isinstance(dtype, Dtype.__args__)
+            and (dtype == np.dtype("object") or as_spark_type(dtype, raise_error=False) is not None)
+            for dtype in data_dtypes
         ), data_dtypes
 
         self._data_dtypes = data_dtypes

From 6826d12dbf8b5f1680db79b6ec48c85c36ea695f Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Wed, 10 Mar 2021 19:03:12 -0800
Subject: [PATCH 2/2] Fix.

---
 databricks/koalas/internal.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/databricks/koalas/internal.py b/databricks/koalas/internal.py
index 3c84fe1851..16af2a94ec 100644
--- a/databricks/koalas/internal.py
+++ b/databricks/koalas/internal.py
@@ -540,7 +540,7 @@ def __init__(
         ]
 
         assert all(
-            isinstance(dtype, Dtype.__args__)
+            isinstance(dtype, Dtype.__args__)  # type: ignore
             and (dtype == np.dtype("object") or as_spark_type(dtype, raise_error=False) is not None)
             for dtype in index_dtypes
         ), index_dtypes
@@ -601,7 +601,7 @@ def __init__(
         ]
 
         assert all(
-            isinstance(dtype, Dtype.__args__)
+            isinstance(dtype, Dtype.__args__)  # type: ignore
            and (dtype == np.dtype("object") or as_spark_type(dtype, raise_error=False) is not None)
            for dtype in data_dtypes
        ), data_dtypes
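
Illustration (not part of the patch): the strengthened assertion accepts the plain numpy `object` dtype even though it has no Spark counterpart, and otherwise requires an actual Dtype instance that maps to a Spark type. The sketch below assumes the koalas typedef helpers that internal.py already imports (Dtype, as_spark_type) and factors the check into a helper named `_is_valid_dtype`; that name and the standalone asserts are illustrative only, not part of the koalas API.

    # A minimal sketch of the new dtype check, not part of the patch.
    # Assumes the koalas typedef helpers already used by internal.py.
    import numpy as np
    from databricks.koalas.typedef import Dtype, as_spark_type

    def _is_valid_dtype(dtype) -> bool:  # hypothetical helper, for illustration only
        """Mirror the strengthened assertion: dtype must be a real Dtype instance
        (np.dtype or pandas ExtensionDtype) and either be the generic ``object``
        dtype or be convertible to a Spark type."""
        return isinstance(dtype, Dtype.__args__) and (  # type: ignore
            dtype == np.dtype("object")
            or as_spark_type(dtype, raise_error=False) is not None
        )

    # `object` is now accepted without a Spark counterpart, while values that are
    # not Dtype instances (e.g. the bare `int` class) are rejected.
    assert _is_valid_dtype(np.dtype("int64"))
    assert _is_valid_dtype(np.dtype("object"))
    assert not _is_valid_dtype(int)

This mirrors the other half of the patch as well: an explicit `object` dtype is now treated like `None` and re-inferred from the Spark schema, which is consistent with accepting it in the assertion.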