From b3f8a0e45a3d9258b9e487f470f7fb85cc70ddb7 Mon Sep 17 00:00:00 2001 From: John Kerl Date: Mon, 18 Sep 2023 12:26:11 -0400 Subject: [PATCH] [python] Flatten categorical `soma_joinid` if presented at `write` (#1698) --- apis/python/src/tiledbsoma/_dataframe.py | 21 +++++++++----- apis/python/tests/test_dataframe.py | 35 ++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/apis/python/src/tiledbsoma/_dataframe.py b/apis/python/src/tiledbsoma/_dataframe.py index cb118f9fd4..0133941ff5 100644 --- a/apis/python/src/tiledbsoma/_dataframe.py +++ b/apis/python/src/tiledbsoma/_dataframe.py @@ -409,18 +409,25 @@ def write( for name in values.schema.names: col = values.column(name) n = len(col) + cols_map = dim_cols_map if name in dim_names_set else attr_cols_map if pa.types.is_dictionary(col.type) and col.num_chunks != 0: - attr = self._handle.schema.attr(name) - if attr.enum_label is not None: - # Normal case: writing categorical data to categorical schema. - cols_map[name] = col.chunk(0).indices.to_pandas() - else: - # Schema is non-categorical but the user is writing categorical. - # Simply decategoricalize for them. + if name in dim_names_set: + # Dims are never categorical. Decategoricalize for them. cols_map[name] = pa.chunked_array( [chunk.dictionary_decode() for chunk in col.chunks] ) + else: + attr = self._handle.schema.attr(name) + if attr.enum_label is not None: + # Normal case: writing categorical data to categorical schema. + cols_map[name] = col.chunk(0).indices.to_pandas() + else: + # Schema is non-categorical but the user is writing categorical. + # Simply decategoricalize for them. + cols_map[name] = pa.chunked_array( + [chunk.dictionary_decode() for chunk in col.chunks] + ) else: cols_map[name] = col.to_pandas() diff --git a/apis/python/tests/test_dataframe.py b/apis/python/tests/test_dataframe.py index 8a817e8bf5..77376064f5 100644 --- a/apis/python/tests/test_dataframe.py +++ b/apis/python/tests/test_dataframe.py @@ -914,6 +914,41 @@ def test_write_categorical_types(tmp_path): assert (df == sdf.read().concat().to_pandas()).all().all() +def test_write_categorical_dims(tmp_path): + """ + Categories are not supported as dims. Here we test our handling of what we + do when we are given them as input. + """ + schema = pa.schema( + [ + ("soma_joinid", pa.int64()), + ("string", pa.dictionary(pa.int8(), pa.large_string())), + ] + ) + with soma.DataFrame.create( + tmp_path.as_posix(), + schema=schema, + index_column_names=["soma_joinid"], + enumerations={ + "enum-string": ["b", "a"], + }, + ordered_enumerations=[], + column_to_enumerations={ + "string": "enum-string", + }, + ) as sdf: + df = pd.DataFrame( + data={ + "soma_joinid": pd.Categorical([0, 1, 2, 3], categories=[0, 1, 2, 3]), + "string": pd.Categorical(["a", "b", "a", "b"], categories=["b", "a"]), + } + ) + sdf.write(pa.Table.from_pandas(df)) + + with soma.DataFrame.open(tmp_path.as_posix()) as sdf: + assert (df == sdf.read().concat().to_pandas()).all().all() + + def test_result_order(tmp_path): # cf. https://docs.tiledb.com/main/background/key-concepts-and-data-format#data-layout schema = pa.schema(