Help with misuse of the `options` argument. (#1402)
I have sometimes seen the keyword argument `options` misused in the read_xxx/to_xxx functions. E.g.,

```py
kdf = ks.read_csv(..., options={ ... })
```

In this case, the `**options` parameter actually captures `{'options': { ... }}`, which is not what the user intends.

We can help in those cases by unwrapping the nested `'options'` value.
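
For reference, here is a minimal standalone sketch of the guard this commit adds. The helper name `_unwrap_options` is hypothetical; the actual change inlines the same check in each function rather than factoring it out:

```py
def _unwrap_options(options):
    """Unwrap a mistakenly nested `options` dict.

    When a caller writes `ks.read_csv(path, options={...})`, the
    `**options` parameter captures `{'options': {...}}`. If that nested
    dict is the only keyword argument, treat its contents as the real
    reader/writer options instead.
    """
    if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
        options = options["options"]
    return options


# Equivalent calls after unwrapping:
#   ks.read_csv(path, sep=',')               # options == {'sep': ','}
#   ks.read_csv(path, options={'sep': ','})  # unwrapped to {'sep': ','}
```

The `len(options) == 1` condition keeps the rewrite conservative: if any other keyword argument was passed alongside `options`, everything is left untouched.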
ueshin authored Apr 8, 2020
1 parent abee019 commit fab9f6f
Showing 3 changed files with 42 additions and 0 deletions.
12 changes: 12 additions & 0 deletions databricks/koalas/frame.py
@@ -4072,6 +4072,9 @@ def to_table(
>>> df.to_table('%s.my_table' % db, partition_cols='date')
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

self.to_spark(index_col=index_col).write.saveAsTable(
name=name, format=format, mode=mode, partitionBy=partition_cols, **options
)
@@ -4141,6 +4144,9 @@ def to_delta(
>>> df.to_delta('%s/to_delta/bar' % path,
... mode='overwrite', replaceWhere='date >= "2012-01-01"')
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

self.to_spark_io(
path=path,
mode=mode,
@@ -4212,6 +4218,9 @@ def to_parquet(
... mode = 'overwrite',
... partition_cols=['date', 'country'])
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

builder = self.to_spark(index_col=index_col).write.mode(mode)
OptionUtils._set_opts(
builder, mode=mode, partitionBy=partition_cols, compression=compression
@@ -4277,6 +4286,9 @@ def to_spark_io(
>>> df.to_spark_io(path='%s/to_spark_io/foo.json' % path, format='json')
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

self.to_spark(index_col=index_col).write.save(
path=path, format=format, mode=mode, partitionBy=partition_cols, **options
)
6 changes: 6 additions & 0 deletions databricks/koalas/generic.py
@@ -669,6 +669,9 @@ def to_csv(
... ... 2012-02-29 12:00:00
... ... 2012-03-31 12:00:00
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

if path is None:
# If path is none, just collect and use pandas's to_csv.
kdf_or_ser = self
@@ -826,6 +829,9 @@ def to_json(
0 a
1 c
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

if path is None:
# If path is none, just collect and use pandas's to_json.
kdf_or_ser = self
24 changes: 24 additions & 0 deletions databricks/koalas/namespace.py
@@ -249,6 +249,9 @@ def read_csv(
--------
>>> ks.read_csv('data.csv') # doctest: +SKIP
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

if mangle_dupe_cols is not True:
raise ValueError("mangle_dupe_cols can only be `True`: %s" % mangle_dupe_cols)
if parse_dates is not False:
@@ -396,6 +399,9 @@ def read_json(path: str, index_col: Optional[Union[str, List[str]]] = None, **op
0 a b
1 c d
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

return read_spark_io(path, format="json", index_col=index_col, **options)


@@ -472,6 +478,9 @@ def read_delta(
3 13
4 14
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

if version is not None:
options["versionAsOf"] = version
if timestamp is not None:
@@ -592,6 +601,9 @@ def read_spark_io(
3 13
4 14
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

sdf = default_session().read.load(path=path, format=format, schema=schema, **options)
index_map = _get_index_map(sdf, index_col)

@@ -639,6 +651,9 @@ def read_parquet(path, columns=None, index_col=None, **options) -> DataFrame:
index
0 0
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

if columns is not None:
columns = list(columns)

@@ -1111,6 +1126,9 @@ def read_sql_table(table_name, con, schema=None, index_col=None, columns=None, *
--------
>>> ks.read_sql_table('table_name', 'jdbc:postgresql:db_name') # doctest: +SKIP
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

reader = default_session().read
reader.option("dbtable", table_name)
reader.option("url", con)
@@ -1164,6 +1182,9 @@ def read_sql_query(sql, con, index_col=None, **options):
--------
>>> ks.read_sql_query('SELECT * FROM table_name', 'jdbc:postgresql:db_name') # doctest: +SKIP
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

reader = default_session().read
reader.option("query", sql)
reader.option("url", con)
@@ -1218,6 +1239,9 @@ def read_sql(sql, con, index_col=None, columns=None, **options):
>>> ks.read_sql('table_name', 'jdbc:postgresql:db_name') # doctest: +SKIP
>>> ks.read_sql('SELECT * FROM table_name', 'jdbc:postgresql:db_name') # doctest: +SKIP
"""
if "options" in options and isinstance(options.get("options"), dict) and len(options) == 1:
options = options.get("options") # type: ignore

striped = sql.strip()
if " " not in striped: # TODO: identify the table name or not more precisely.
return read_sql_table(sql, con, index_col=index_col, columns=columns, **options)
