From fab9f6f7e52c89225c04ffb40668129b8a79f746 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 7 Apr 2020 18:24:54 -0700 Subject: [PATCH] Help misuse of options argument. (#1402) I have sometimes seen the keyword argument `options` misused in read_xxx/to_xxx functions. E.g., ```py kdf = ks.read_csv(..., options={ ... }) ``` In this case, the `**options` dict received by the function is actually `{'options': { ... }}`, which is not what the user intended. We can help in those cases by unwrapping the `'options'` value when it is the only keyword argument passed and its value is a dict. --- databricks/koalas/frame.py | 12 ++++++++++++ databricks/koalas/generic.py | 6 ++++++ databricks/koalas/namespace.py | 24 ++++++++++++++++++++++++ 3 files changed, 42 insertions(+) diff --git a/databricks/koalas/frame.py b/databricks/koalas/frame.py index d583255d6b..fcbb0b026a 100644 --- a/databricks/koalas/frame.py +++ b/databricks/koalas/frame.py @@ -4072,6 +4072,9 @@ def to_table( >>> df.to_table('%s.my_table' % db, partition_cols='date') """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + self.to_spark(index_col=index_col).write.saveAsTable( name=name, format=format, mode=mode, partitionBy=partition_cols, **options ) @@ -4141,6 +4144,9 @@ def to_delta( >>> df.to_delta('%s/to_delta/bar' % path, ... mode='overwrite', replaceWhere='date >= "2012-01-01"') """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + self.to_spark_io( path=path, mode=mode, @@ -4212,6 +4218,9 @@ def to_parquet( ... mode = 'overwrite', ... 
partition_cols=['date', 'country']) """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + builder = self.to_spark(index_col=index_col).write.mode(mode) OptionUtils._set_opts( builder, mode=mode, partitionBy=partition_cols, compression=compression @@ -4277,6 +4286,9 @@ def to_spark_io( >>> df.to_spark_io(path='%s/to_spark_io/foo.json' % path, format='json') """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + self.to_spark(index_col=index_col).write.save( path=path, format=format, mode=mode, partitionBy=partition_cols, **options ) diff --git a/databricks/koalas/generic.py b/databricks/koalas/generic.py index 814f990975..476c4ea4bc 100644 --- a/databricks/koalas/generic.py +++ b/databricks/koalas/generic.py @@ -669,6 +669,9 @@ def to_csv( ... ... 2012-02-29 12:00:00 ... ... 2012-03-31 12:00:00 """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + if path is None: # If path is none, just collect and use pandas's to_csv. kdf_or_ser = self @@ -826,6 +829,9 @@ def to_json( 0 a 1 c """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + if path is None: # If path is none, just collect and use pandas's to_json. 
kdf_or_ser = self diff --git a/databricks/koalas/namespace.py b/databricks/koalas/namespace.py index 2fdac86b9a..9f3244e53f 100644 --- a/databricks/koalas/namespace.py +++ b/databricks/koalas/namespace.py @@ -249,6 +249,9 @@ def read_csv( -------- >>> ks.read_csv('data.csv') # doctest: +SKIP """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + if mangle_dupe_cols is not True: raise ValueError("mangle_dupe_cols can only be `True`: %s" % mangle_dupe_cols) if parse_dates is not False: @@ -396,6 +399,9 @@ def read_json(path: str, index_col: Optional[Union[str, List[str]]] = None, **op 0 a b 1 c d """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + return read_spark_io(path, format="json", index_col=index_col, **options) @@ -472,6 +478,9 @@ def read_delta( 3 13 4 14 """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + if version is not None: options["versionAsOf"] = version if timestamp is not None: @@ -592,6 +601,9 @@ def read_spark_io( 3 13 4 14 """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + sdf = default_session().read.load(path=path, format=format, schema=schema, **options) index_map = _get_index_map(sdf, index_col) @@ -639,6 +651,9 @@ def read_parquet(path, columns=None, index_col=None, **options) -> DataFrame: index 0 0 """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + if columns is not None: columns = list(columns) @@ -1111,6 +1126,9 @@ def read_sql_table(table_name, con, schema=None, index_col=None, columns=None, * -------- >>> ks.read_sql_table('table_name', 
'jdbc:postgresql:db_name') # doctest: +SKIP """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + reader = default_session().read reader.option("dbtable", table_name) reader.option("url", con) @@ -1164,6 +1182,9 @@ def read_sql_query(sql, con, index_col=None, **options): -------- >>> ks.read_sql_query('SELECT * FROM table_name', 'jdbc:postgresql:db_name') # doctest: +SKIP """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + reader = default_session().read reader.option("query", sql) reader.option("url", con) @@ -1218,6 +1239,9 @@ def read_sql(sql, con, index_col=None, columns=None, **options): >>> ks.read_sql('table_name', 'jdbc:postgresql:db_name') # doctest: +SKIP >>> ks.read_sql('SELECT * FROM table_name', 'jdbc:postgresql:db_name') # doctest: +SKIP """ + if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1: + options = options.get("options") # type: ignore + striped = sql.strip() if " " not in striped: # TODO: identify the table name or not more precisely. return read_sql_table(sql, con, index_col=index_col, columns=columns, **options)