Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] Protect against huge enum-of-strings input #3354

Merged
merged 6 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 36 additions & 33 deletions apis/python/src/tiledbsoma/io/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np
import pandas as pd
import pandas._typing as pdt
import pandas.api.types
import pyarrow as pa
import scipy.sparse as sp

Expand All @@ -25,19 +26,31 @@
_MT = TypeVar("_MT", NPNDArray, sp.spmatrix, PDSeries)
_str_to_type = {"boolean": bool, "string": str, "bytes": bytes}

# Heuristic cardinality cutoff for the categorical-to-plain-string conversion
# performed in to_tiledb_supported_array_type below.
STRING_DECAT_THRESHOLD: int = 4096
"""
For enum-of-string columns with a cardinality higher than this, we convert from
enum-of-string in the AnnData ``obs``/``var``, to plain string in TileDB-SOMA
``obs``/``var``. However, if we're appending to existing storage, we follow the
schema there.
"""

def obs_or_var_to_tiledb_supported_array_type(obs_or_var: pd.DataFrame) -> pd.DataFrame:
    """
    Performs a typecast into types that TileDB can persist.

    This includes, as a performance improvement, converting high-cardinality
    categorical-of-string columns (cardinality > ``STRING_DECAT_THRESHOLD``)
    to plain string.

    Args:
        obs_or_var: An AnnData ``obs`` or ``var`` dataframe.

    Returns:
        A new dataframe whose columns have each been passed through
        ``to_tiledb_supported_array_type``. The input is never modified;
        even the no-column case returns a copy.
    """
    # Nothing to convert; return a defensive copy so callers cannot mutate
    # the input through the return value.
    if len(obs_or_var.columns) == 0:
        return obs_or_var.copy()

    return pd.DataFrame.from_dict(
        {
            str(k): to_tiledb_supported_array_type(str(k), v)
            for k, v in obs_or_var.items()
        },
    )


@typeguard_ignore
def _to_tiledb_supported_dtype(dtype: _DT) -> _DT:
Expand All @@ -47,36 +60,26 @@ def _to_tiledb_supported_dtype(dtype: _DT) -> _DT:


def to_tiledb_supported_array_type(name: str, x: _MT) -> _MT:
    """Converts datatypes unrepresentable by TileDB into datatypes it can
    represent, e.g., float16 -> float32.

    As a performance protection, a categorical-of-string column whose
    cardinality exceeds ``STRING_DECAT_THRESHOLD`` is converted to plain
    string: such columns gain nothing from dictionary encoding and are
    expensive to process downstream.

    NOTE(review): this currently applies only to categoricals of *string*;
    categoricals of other primitive types (int, float, bool, ...) are passed
    through unchanged — see the follow-on discussed in the PR review.

    Args:
        name: Column name; currently used only for caller-side diagnostics.
        x: A numpy ndarray, scipy sparse matrix, or pandas Series.

    Returns:
        ``x`` itself when no conversion is needed, otherwise a converted copy.
    """
    # Non-categorical inputs: only the dtype may need widening
    # (e.g. float16 -> float32).
    if isinstance(x, (np.ndarray, sp.spmatrix)) or not isinstance(
        x.dtype, pd.CategoricalDtype
    ):
        target_dtype = _to_tiledb_supported_dtype(x.dtype)
        return x if target_dtype == x.dtype else x.astype(target_dtype)

    # If the column is categorical-of-string of high cardinality, we declare
    # this is likely a mistake, and it will definitely lead to performance
    # issues in subsequent processing.
    if isinstance(x, pd.Series) and isinstance(x.dtype, pd.CategoricalDtype):
        # STRING_DECAT_THRESHOLD is a heuristic number.
        if (
            pandas.api.types.is_string_dtype(x)
            and len(x.cat.categories) > STRING_DECAT_THRESHOLD
        ):
            return x.astype(str)

    return x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a warning or debugging message that even though the column is a category type, we are converting that to a string?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think @bkmartinjr ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not - I would instead just add it to the help/docstrings that this is the default behavior.

I don't find these warnings particularly useful, as most of this code runs in production pipelines. Better to document the behavior in the API docs, IMHO.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a bug - assumes all categoricals are of type str. They can be any primitive type, e.g., int, float, bool, etc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @bkmartinjr -- I'll make a follow-on PR -- this one is "Protect against huge enum-of-strings input" -- I'll generalize the follow-on to "Protect against huge enum-of-anything input"


return x


Expand Down
149 changes: 134 additions & 15 deletions apis/python/src/tiledbsoma/io/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@
df_uri = _util.uri_joinpath(experiment_uri, "obs")
with _write_dataframe(
df_uri,
conversions.decategoricalize_obs_or_var(anndata.obs),
conversions.obs_or_var_to_tiledb_supported_array_type(anndata.obs),
id_column_name=obs_id_name,
axis_mapping=jidmaps.obs_axis,
**ingest_platform_ctx,
Expand Down Expand Up @@ -566,7 +566,7 @@
# MS/meas/VAR
with _write_dataframe(
_util.uri_joinpath(measurement_uri, "var"),
conversions.decategoricalize_obs_or_var(anndata.var),
conversions.obs_or_var_to_tiledb_supported_array_type(anndata.var),
id_column_name=var_id_name,
# Layer existence is pre-checked in the registration phase
axis_mapping=jidmaps.var_axes[measurement_name],
Expand Down Expand Up @@ -701,7 +701,9 @@

with _write_dataframe(
_util.uri_joinpath(raw_uri, "var"),
conversions.decategoricalize_obs_or_var(anndata.raw.var),
conversions.obs_or_var_to_tiledb_supported_array_type(
anndata.raw.var
),
id_column_name=var_id_name,
axis_mapping=jidmaps.var_axes["raw"],
**ingest_platform_ctx,
Expand Down Expand Up @@ -794,7 +796,7 @@

with _write_dataframe(
exp.obs.uri,
conversions.decategoricalize_obs_or_var(new_obs),
conversions.obs_or_var_to_tiledb_supported_array_type(new_obs),
id_column_name=obs_id_name,
platform_config=platform_config,
context=context,
Expand Down Expand Up @@ -859,7 +861,7 @@

with _write_dataframe(
sdf.uri,
conversions.decategoricalize_obs_or_var(new_var),
conversions.obs_or_var_to_tiledb_supported_array_type(new_var),
id_column_name=var_id_name,
platform_config=platform_config,
context=context,
Expand Down Expand Up @@ -1068,6 +1070,130 @@
)


def _extract_new_values_for_append_aux(
    previous_soma_dataframe: DataFrame,
    arrow_table: pa.Table,
) -> pa.Table:
    """
    Helper function for ``_extract_new_values_for_append``.

    This does two things:

    * Retains only the 'new' rows compared to existing storage.
      Example is append-mode updates to the var dataframe for
      which it's likely that most/all gene IDs have already been seen.

    * String vs categorical of string:

      o If we're appending a plain-string column to existing
        categorical-of-string storage, convert the about-to-be-written data
        to categorical of string, to match.

      o If we're appending a categorical-of-string column to existing
        plain-string storage, convert the about-to-be-written data
        to plain string, to match.

    Context: https://github.com/single-cell-data/TileDB-SOMA/issues/3353.
    Namely, we find that AnnData's to_h5ad/from_h5ad can categoricalize (without
    the user's knowledge or intention) string columns. For example, even
    cell_id/barcode, for which there may be millions of distinct values, with no
    gain to be had from dictionary encoding, will be converted to categorical.
    We find that converting these high-cardinality enums to plain string is a
    significant performance win for subsequent accesses. When we do an initial
    ingest from AnnData to TileDB-SOMA, we convert from categorical-of-string to
    plain string if the cardinality exceeds some threshold.

    All well and good -- except for one more complication which is append mode.
    Namely, if the new column has high enough cardinality that we would
    downgrade to plain string, but the existing storage has
    categorical-of-string, we must write the new data as categorical-of-string.
    Likewise, if the new column has low enough cardinality that we would keep it
    as categorical-of-string, but the existing storage has plain string, we must
    write the new data as plain strings.
    """

    # Retain only the new rows: drop any row whose soma_joinid is already
    # present in storage.
    previous_sjids_table = previous_soma_dataframe.read(
        column_names=["soma_joinid"]
    ).concat()
    previous_join_ids = set(
        int(e)
        for e in get_dataframe_values(previous_sjids_table.to_pandas(), SOMA_JOINID)
    )
    mask = [e.as_py() not in previous_join_ids for e in arrow_table[SOMA_JOINID]]
    arrow_table = arrow_table.filter(mask)

    # This is a redundant, failsafe check. The append-mode registrar already
    # ensures schema homogeneity before we get here.
    old_schema = previous_soma_dataframe.schema
    new_schema = arrow_table.schema

    # Note: we may be doing an add-column that doesn't exist in tiledbsoma
    # storage but is present in the new AnnData. We don't need to change
    # anything in that case. Regardless, we can't assume that the old
    # and new schema have the same column names.

    # Helper predicates for classifying fields as plain-string vs
    # categorical-of-string.
    def is_str_type(typ: pa.DataType) -> bool:
        return cast(bool, typ == pa.string() or typ == pa.large_string())

    def is_str_col(field: pa.Field) -> bool:
        return is_str_type(field.type)

    def is_str_cat_col(field: pa.Field) -> bool:
        if not pa.types.is_dictionary(field.type):
            return False
        return is_str_type(field.type.value_type)

    # Make a quick check of the old and new schemas to see if any columns need
    # changing between plain string and categorical-of-string. We're about to
    # duplicate the new data -- and we must, since pyarrow.Table is immutable --
    # so let's only do that if we need to.
    any_to_change = False
    for name in new_schema.names:
        if name not in old_schema.names:
            continue
        old_field = old_schema.field(name)
        new_field = new_schema.field(name)
        if is_str_col(old_field) and is_str_cat_col(new_field):
            any_to_change = True
            break
        if is_str_cat_col(old_field) and is_str_col(new_field):
            any_to_change = True
            break

    if any_to_change:
        fields_dict = {}
        for name in new_schema.names:
            column = arrow_table.column(name)
            # A column absent from storage (an add-column) is carried through
            # unchanged. (Previously such columns were skipped entirely here,
            # which silently dropped them from the rebuilt table.)
            if name in old_schema.names:
                old_info = old_schema.field(name)
                new_info = new_schema.field(name)
                if is_str_col(old_info) and is_str_cat_col(new_info):
                    # Convert from categorical-of-string to plain string.
                    column = column.to_pylist()
                elif is_str_cat_col(old_info) and is_str_col(new_info):
                    # Convert from plain string to categorical-of-string. Note:
                    # libtiledbsoma already merges the enum mappings, e.g if the
                    # storage has red, yellow, & green, but our new data has some
                    # yellow, green, and orange.
                    column = pa.array(
                        column.to_pylist(),
                        pa.dictionary(
                            index_type=old_info.type.index_type,
                            value_type=old_info.type.value_type,
                            ordered=old_info.type.ordered,
                        ),
                    )
            fields_dict[name] = column
        arrow_table = pa.Table.from_pydict(fields_dict)

    return arrow_table


def _extract_new_values_for_append(
df_uri: str,
arrow_table: pa.Table,
Expand Down Expand Up @@ -1098,17 +1224,10 @@
with _factory.open(
df_uri, "r", soma_type=DataFrame, context=context
) as previous_soma_dataframe:
previous_sjids_table = previous_soma_dataframe.read().concat()
previous_join_ids = set(
int(e)
for e in get_dataframe_values(
previous_sjids_table.to_pandas(), SOMA_JOINID
)
return _extract_new_values_for_append_aux(
previous_soma_dataframe, arrow_table
)
mask = [
e.as_py() not in previous_join_ids for e in arrow_table[SOMA_JOINID]
]
return arrow_table.filter(mask)

except DoesNotExistError:
return arrow_table

Expand Down
Loading