fix(DRAFT): Add multiple fallbacks for pyarrow JSON
dangotbanned committed Nov 12, 2024
1 parent 403b787 commit 3fbc759
Showing 2 changed files with 94 additions and 8 deletions.
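
At a glance, the change swaps the single `pyarrow.json.read_json` reader for a chain of alternatives resolved when `_PyArrowReader` is constructed. A condensed sketch of the resolution order (the helper name and return strings are illustrative only, not the committed API):

    from importlib.util import find_spec

    def describe_json_fallback() -> str:
        # Mirrors the priority order introduced in `_PyArrowReader.__init__` below.
        if find_spec("polars") is not None:
            return "(1) polars: pl.read_json(source).to_arrow()"
        if find_spec("pandas") is not None:
            return "(2) pandas: pd.read_json(source, dtype_backend='pyarrow')"
        return "(3) stdlib: json.load(...) -> pa.Table.from_pylist(...)"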
62 changes: 55 additions & 7 deletions altair/datasets/_readers.py
@@ -11,6 +11,7 @@
 
 import os
 import urllib.request
+from collections.abc import Mapping, Sequence
 from functools import partial
 from http.client import HTTPResponse
 from importlib import import_module
@@ -34,6 +35,7 @@
 from narwhals.typing import IntoDataFrameT, IntoExpr, IntoFrameT
 
 if TYPE_CHECKING:
+    import json  # noqa: F401
     import sys
     from urllib.request import OpenerDirector

@@ -346,25 +348,71 @@ class _PyArrowReader(_Reader["pa.Table", "pa.Table"]):
     def __init__(self, name: _PyArrow, /) -> None:
         self._name = _requirements(name)
         if not TYPE_CHECKING:
-            pa = self._import(self._name)  # noqa: F841
+            pa = self._import(self._name)
             pa_csv = self._import(f"{self._name}.csv")
             pa_feather = self._import(f"{self._name}.feather")
-            pa_json = self._import(f"{self._name}.json")
             pa_parquet = self._import(f"{self._name}.parquet")
 
             pa_read_csv = pa_csv.read_csv
             pa_read_feather = pa_feather.read_table
-            pa_read_json = pa_json.read_json
             pa_read_parquet = pa_parquet.read_table
 
-            # opt1 = ParseOptions(delimiter="\t")  # type: ignore
+            # HACK: Multiple alternatives to `pyarrow.json.read_json`
+            # -------------------------------------------------------
+            # NOTE: Prefer `polars` since it is zero-copy and fast (1)
+            if find_spec("polars") is not None:
+                import polars as pl
+
+                def pa_read_json(source: StrPath, /, **kwds) -> pa.Table:
+                    return pl.read_json(source).to_arrow()
+
+            else:
+                import json
+
+                def stdlib_read_json(source: Any, /, **kwds) -> pa.Table:
+                    if not isinstance(source, Path):
+                        obj = json.load(source)
+                    else:
+                        with Path(source).open(encoding="utf-8") as f:
+                            obj = json.load(f)
+                    # Very naive check, but still less likely to fail than pyarrow
+                    if isinstance(obj, Sequence) and obj and isinstance(obj[0], Mapping):
+                        return pa.Table.from_pylist(obj)
+                    else:
+                        # NOTE: Almost certainly will fail on read as of `v2.9.0`
+                        pa_json = self._import(f"{self._name}.json")
+                        return pa_json.read_json(source)
+
+                # NOTE: Use `pandas` as a slower fallback (2)
+                if find_spec("pandas") is not None:
+                    import pandas as pd
+
+                    def pa_read_json(source: StrPath, /, **kwds) -> pa.Table:
+                        try:
+                            table = (
+                                nw.from_native(
+                                    pd.read_json(
+                                        source, dtype_backend="pyarrow"
+                                    ).convert_dtypes(dtype_backend="pyarrow")
+                                )
+                                .with_columns(
+                                    nw.selectors.by_dtype(nw.Object).cast(nw.String)
+                                )
+                                .to_arrow()
+                            )
+                        except ValueError:
+                            table = stdlib_read_json(source)
+                        return table
+                else:
+                    # NOTE: Convert inline via stdlib `json` (3)
+                    pa_read_json = stdlib_read_json
 
             # Stubs suggest using a dataclass, but no way to construct it
-            opt2: Any = {"delimiter": "\t"}
+            tab_sep: Any = {"delimiter": "\t"}
 
             self._read_fn = {
                 ".csv": pa_read_csv,
                 ".json": pa_read_json,
-                ".tsv": partial(pa_read_csv, parse_options=opt2),
+                ".tsv": partial(pa_read_csv, parse_options=tab_sep),
                 ".arrow": pa_read_feather,
             }
             self._scan_fn = {".parquet": pa_read_parquet}
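
Fallback (3) relies on `pa.Table.from_pylist`, which builds a table from a list of row mappings, the shape of most record-oriented vega-datasets `.json` files. A minimal sketch with made-up data:

    import json

    import pyarrow as pa

    # Record-oriented JSON: a list of mappings, one mapping per row.
    records = json.loads('[{"year": 1875, "wheat": 42.5}, {"year": 1880, "wheat": 45.2}]')
    table = pa.Table.from_pylist(records)
    print(table.num_rows, table.column_names)  # 2 ['year', 'wheat']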
40 changes: 39 additions & 1 deletion tests/test_datasets.py
@@ -17,7 +17,8 @@
     from pathlib import Path
     from typing import Literal
 
-    from altair.datasets._readers import _Backend
+    from altair.datasets._readers import _Backend, _Pandas, _Polars
+    from altair.datasets._typing import DatasetName
 
 CACHE_ENV_VAR: Literal["ALTAIR_DATASETS_DIR"] = "ALTAIR_DATASETS_DIR"

Expand Down Expand Up @@ -274,3 +275,40 @@ def test_reader_cache(

assert len(tuple(cache_dir.iterdir())) == 4
assert cached_paths == tuple(cache_dir.iterdir())


@pytest.mark.parametrize(
"dataset",
[
"cars",
"movies",
"wheat",
"barley",
"gapminder",
"income",
"burtin",
pytest.param(
"earthquakes",
marks=pytest.mark.xfail(
reason="GeoJSON seems to not work with pandas -> pyarrow"
),
),
],
)
@pytest.mark.parametrize("fallback", ["polars", "pandas", None])
@skip_requires_pyarrow
def test_pyarrow_read_json(
fallback: _Polars | _Pandas | None,
dataset: DatasetName,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv(CACHE_ENV_VAR, "")

if fallback == "polars" or fallback is None:
monkeypatch.delitem(sys.modules, "pandas", raising=False)
elif fallback == "pandas" or fallback is None:
monkeypatch.setitem(sys.modules, "polars", None)

data = Loader.with_backend("pyarrow")

data(dataset, ".json")
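
The `monkeypatch.setitem(sys.modules, <name>, None)` calls above work because `importlib.util.find_spec` reports a module whose `sys.modules` entry is `None` as unavailable, which is exactly the check `_PyArrowReader.__init__` performs. A standalone illustration, assuming `polars` is installed:

    import sys
    from importlib.util import find_spec

    assert find_spec("polars") is not None  # normally discoverable
    sys.modules["polars"] = None            # simulate an uninstalled package
    assert find_spec("polars") is None      # the reader now skips this branch
    del sys.modules["polars"]               # restore normal importing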
