fix: Handle Date32 columns in Arrow tables and Polars DataFrames #3377

Merged: 8 commits, Mar 24, 2024

Changes from 7 commits
2 changes: 2 additions & 0 deletions altair/utils/__init__.py
@@ -2,6 +2,7 @@
infer_vegalite_type,
infer_encoding_types,
sanitize_dataframe,
+    sanitize_arrow_table,
parse_shorthand,
use_signature,
update_nested,
@@ -18,6 +19,7 @@
"infer_vegalite_type",
"infer_encoding_types",
"sanitize_dataframe",
"sanitize_arrow_table",
"spec_to_html",
"parse_shorthand",
"use_signature",
11 changes: 6 additions & 5 deletions altair/utils/core.py
@@ -37,7 +37,7 @@
else:
from typing_extensions import ParamSpec

-from typing import Literal, Protocol, TYPE_CHECKING
+from typing import Literal, Protocol, TYPE_CHECKING, runtime_checkable

if TYPE_CHECKING:
from pandas.core.interchange.dataframe_protocol import Column as PandasColumn
Expand All @@ -46,6 +46,7 @@
P = ParamSpec("P")


+@runtime_checkable
class DataFrameLike(Protocol):
def __dataframe__(
self, nan_as_null: bool = False, allow_copy: bool = True
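Aside (not part of the diff): making the protocol `runtime_checkable` is what allows the `hasattr(data, "__dataframe__")` checks in `data.py` below to be replaced with `isinstance(data, DataFrameLike)`. A minimal sketch of the mechanics, using a hypothetical `Fake` class:

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class DataFrameLike(Protocol):
    def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): ...

class Fake:
    # Hypothetical stand-in for any interchange-protocol dataframe
    def __dataframe__(self, nan_as_null=False, allow_copy=True):
        return object()

# isinstance() on a runtime_checkable Protocol checks method presence only
print(isinstance(Fake(), DataFrameLike))    # True
print(isinstance(object(), DataFrameLike))  # False
```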
@@ -429,15 +430,15 @@ def sanitize_arrow_table(pa_table):
schema = pa_table.schema
for name in schema.names:
array = pa_table[name]
-        dtype = schema.field(name).type
-        if str(dtype).startswith("timestamp"):
+        dtype_name = str(schema.field(name).type)
+        if dtype_name.startswith("timestamp") or dtype_name.startswith("date"):
arrays.append(pc.strftime(array))
-        elif str(dtype).startswith("duration"):
+        elif dtype_name.startswith("duration"):
raise ValueError(
'Field "{col_name}" has type "{dtype}" which is '
"not supported by Altair. Please convert to "
"either a timestamp or a numerical value."
"".format(col_name=name, dtype=dtype)
"".format(col_name=name, dtype=dtype_name)
)
else:
arrays.append(array)
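Aside: the reason `date` columns can share the `timestamp` code path is that pyarrow's `strftime` kernel also accepts date32 arrays, which is the behavior this PR relies on. A minimal sketch, not part of the diff:

```python
import datetime

import pyarrow as pa
import pyarrow.compute as pc

dates = pa.array([datetime.date(2012, 1, 1), datetime.date(2012, 1, 2)], type=pa.date32())
# strftime's default format is "%Y-%m-%dT%H:%M:%S", which Vega-Lite parses as a datetime
print(pc.strftime(dates).to_pylist())  # ['2012-01-01T00:00:00', '2012-01-02T00:00:00']
```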
50 changes: 31 additions & 19 deletions altair/utils/data.py
@@ -105,13 +105,12 @@ def raise_max_rows_error():
# mypy gets confused as it doesn't see Dict[Any, Any]
# as equivalent to TDataType
return data # type: ignore[return-value]
-    elif hasattr(data, "__dataframe__"):
-        pi = import_pyarrow_interchange()
-        pa_table = pi.from_dataframe(data)
+    elif isinstance(data, DataFrameLike):
+        pa_table = arrow_table_from_dfi_dataframe(data)
if max_rows is not None and pa_table.num_rows > max_rows:
raise_max_rows_error()
# Return pyarrow Table instead of input since the
-        # `from_dataframe` call may be expensive
+        # `arrow_table_from_dfi_dataframe` call above may be expensive
return pa_table

if max_rows is not None and len(values) > max_rows:
@@ -142,10 +141,8 @@ def sample(
else:
# Maybe this should raise an error or return something useful?
return None
-    elif hasattr(data, "__dataframe__"):
-        # experimental interchange dataframe support
-        pi = import_pyarrow_interchange()
-        pa_table = pi.from_dataframe(data)
+    elif isinstance(data, DataFrameLike):
+        pa_table = arrow_table_from_dfi_dataframe(data)
if not n:
if frac is None:
raise ValueError(
@@ -232,10 +229,8 @@ def to_values(data: DataType) -> ToValuesReturnType:
if "values" not in data:
raise KeyError("values expected in data dict, but not present.")
return data
-    elif hasattr(data, "__dataframe__"):
-        # experimental interchange dataframe support
-        pi = import_pyarrow_interchange()
-        pa_table = sanitize_arrow_table(pi.from_dataframe(data))
+    elif isinstance(data, DataFrameLike):
+        pa_table = sanitize_arrow_table(arrow_table_from_dfi_dataframe(data))
return {"values": pa_table.to_pylist()}
else:
# Should never reach this state as tested by check_data_type
@@ -277,10 +272,8 @@ def _data_to_json_string(data: DataType) -> str:
if "values" not in data:
raise KeyError("values expected in data dict, but not present.")
return json.dumps(data["values"], sort_keys=True)
-    elif hasattr(data, "__dataframe__"):
-        # experimental interchange dataframe support
-        pi = import_pyarrow_interchange()
-        pa_table = pi.from_dataframe(data)
+    elif isinstance(data, DataFrameLike):
+        pa_table = arrow_table_from_dfi_dataframe(data)
return json.dumps(pa_table.to_pylist())
else:
raise NotImplementedError(
@@ -303,13 +296,12 @@ def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str:
if "values" not in data:
raise KeyError("values expected in data dict, but not present")
return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
-    elif hasattr(data, "__dataframe__"):
+    elif isinstance(data, DataFrameLike):
# experimental interchange dataframe support
-        pi = import_pyarrow_interchange()
import pyarrow as pa
import pyarrow.csv as pa_csv

-        pa_table = pi.from_dataframe(data)
+        pa_table = arrow_table_from_dfi_dataframe(data)
csv_buffer = pa.BufferOutputStream()
pa_csv.write_csv(pa_table, csv_buffer)
return csv_buffer.getvalue().to_pybytes().decode()
@@ -346,3 +338,23 @@ def curry(*args, **kwargs):
stacklevel=1,
)
return curried.curry(*args, **kwargs)


+def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> "pyarrow.lib.Table":
+    """Convert a DataFrame Interchange Protocol compatible object to an Arrow Table"""
+    import pyarrow as pa
+
+    # First check if the dataframe object has a method to convert to arrow.
+    # Give this preference over the pyarrow from_dataframe function since the object
+    # has more control over the conversion, and may have broader compatibility.
+    # This is the case for Polars, which supports Date32 columns in direct conversion
+    # while pyarrow does not yet support this type in from_dataframe
+    for convert_method_name in ("arrow", "to_arrow", "to_arrow_table"):
+        convert_method = getattr(dfi_df, convert_method_name, None)
+        if callable(convert_method):
+            result = convert_method()
+            if isinstance(result, pa.Table):
+                return result
+
+    pi = import_pyarrow_interchange()
+    return pi.from_dataframe(dfi_df)

Contributor comment on lines +352 to +357: It is a joy to read this type of code diff. Really nice approach!
2 changes: 1 addition & 1 deletion altair/vegalite/v5/api.py
@@ -56,7 +56,7 @@ def _dataset_name(values: Union[dict, list, core.InlineDataset]) -> str:
values = values.to_dict()
if values == [{}]:
return "empty"
-    values_json = json.dumps(values, sort_keys=True)
+    values_json = json.dumps(values, sort_keys=True, default=str)
Contributor Author comment: For computing the values hash, I think it's fine to fall back to the string representation for types not supported by json.dumps (datetime.date in this case).

hsh = hashlib.sha256(values_json.encode()).hexdigest()[:32]
return "data-" + hsh

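Aside: a minimal repro (not Altair code) of why `default=str` is needed here, since `datetime.date` values can now reach this hashing path:

```python
import datetime
import json

values = [{"d": datetime.date(2024, 3, 1)}]
# Without default=str this raises:
#   TypeError: Object of type date is not JSON serializable
print(json.dumps(values, sort_keys=True, default=str))  # [{"d": "2024-03-01"}]
```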
1 change: 1 addition & 0 deletions doc/releases/changes.rst
@@ -28,6 +28,7 @@ Bug Fixes
~~~~~~~~~
- Fix type hints for libraries such as Polars where Altair uses the dataframe interchange protocol (#3297)
- Fix anywidget deprecation warning (#3364)
+- Fix handling of Date32 columns in arrow tables and Polars DataFrames (#3377)

Backward-Incompatible Changes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -61,7 +61,7 @@ all = [
"vega_datasets>=0.9.0",
"vl-convert-python>=1.3.0",
"pyarrow>=11",
"vegafusion[embed]>=1.5.0",
"vegafusion[embed]>=1.6.6",
"anywidget>=0.9.0",
"altair_tiles>=0.3.0"
]
2 changes: 1 addition & 1 deletion tests/utils/test_mimebundle.py
@@ -241,7 +241,7 @@ def check_pre_transformed_vega_spec(vega_spec):

# Check that the bin transform has been applied
row0 = data_0["values"][0]
-    assert row0 == {"a": "A", "b": 28, "b_end": 28.0, "b_start": 0.0}
+    assert row0 == {"a": "A", "b_end": 28.0, "b_start": 0.0}
Contributor Author comment: VegaFusion 1.6.6 strips out the "description" encoding field (which isn't used in the canvas renderer), so it's able to do a better job dropping unused columns.


# And no transforms remain
assert len(data_0.get("transform", [])) == 0
49 changes: 48 additions & 1 deletion tests/utils/test_utils.py
@@ -6,7 +6,7 @@
import pandas as pd
import pytest

-from altair.utils import infer_vegalite_type, sanitize_dataframe
+from altair.utils import infer_vegalite_type, sanitize_dataframe, sanitize_arrow_table

try:
import pyarrow as pa
@@ -120,6 +120,53 @@ def test_sanitize_dataframe_arrow_columns():
json.dumps(records)


+@pytest.mark.skipif(pa is None, reason="pyarrow not installed")
+def test_sanitize_pyarrow_table_columns():
+    # create a dataframe with various types
+    df = pd.DataFrame(
+        {
+            "s": list("abcde"),
+            "f": np.arange(5, dtype=float),
+            "i": np.arange(5, dtype=int),
+            "b": np.array([True, False, True, True, False]),
+            "d": pd.date_range("2012-01-01", periods=5, freq="H"),
+            "c": pd.Series(list("ababc"), dtype="category"),
+            "p": pd.date_range("2012-01-01", periods=5, freq="H").tz_localize("UTC"),
+        }
+    )
+
+    # Create pyarrow table with explicit schema so that date32 type is preserved
+    pa_table = pa.Table.from_pandas(
+        df,
+        pa.schema(
+            [
+                pa.field("s", pa.string()),
+                pa.field("f", pa.float64()),
+                pa.field("i", pa.int64()),
+                pa.field("b", pa.bool_()),
+                pa.field("d", pa.date32()),
+                pa.field("c", pa.dictionary(pa.int8(), pa.string())),
+                pa.field("p", pa.timestamp("ns", tz="UTC")),
+            ]
+        ),
+    )
+    sanitized = sanitize_arrow_table(pa_table)
+    values = sanitized.to_pylist()
+
+    assert values[0] == {
+        "s": "a",
+        "f": 0.0,
+        "i": 0,
+        "b": True,
+        "d": "2012-01-01T00:00:00",
+        "c": "a",
+        "p": "2012-01-01T00:00:00.000000000",
+    }
+
+    # Make sure we can serialize to JSON without error
+    json.dumps(values)


def test_sanitize_dataframe_colnames():
df = pd.DataFrame(np.arange(12).reshape(4, 3))
