Skip to content

Commit

Permalink
Support Enumerations On Nullable Attributes and Query Conditions
Browse files: browse the repository at this point in the history
  • Loading branch information
nguyenv committed Aug 24, 2023
1 parent 8cffeb5 commit c2c81de
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 42 deletions.
23 changes: 12 additions & 11 deletions apis/python/src/tiledbsoma/_query_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def __attrs_post_init__(self):
"(Is this an empty expression?)"
)

def init_query_condition(self, schema: tiledb.ArraySchema, query_attrs: List[str]):
qctree = QueryConditionTree(schema, query_attrs)
def init_query_condition(self, uri: str, query_attrs: List[str]):
qctree = QueryConditionTree(tiledb.open(uri), query_attrs)
self.c_obj = qctree.visit(self.tree.body)

if not isinstance(self.c_obj, clib.PyQueryCondition):
Expand All @@ -143,7 +143,7 @@ def init_query_condition(self, schema: tiledb.ArraySchema, query_attrs: List[str

@attrs.define
class QueryConditionTree(ast.NodeVisitor):
schema: tiledb.ArraySchema
array: tiledb.Array
query_attrs: List[str]

def visit_BitOr(self, node):
Expand Down Expand Up @@ -237,14 +237,15 @@ def aux_visit_Compare(

att = self.get_att_from_node(att)
val = self.get_val_from_node(val)

dt = self.schema.attr(att).dtype

if self.schema.attr(att).enum_label is not None:
dtype = "string"
enum_label = self.array.attr(att).enum_label
if enum_label is not None:
dt = self.array.enum(enum_label).dtype
else:
dtype = "string" if dt.kind in "SUa" else dt.name
val = self.cast_val_to_dtype(val, dtype)
dt = self.array.attr(att).dtype

dtype = "string" if dt.kind in "SUa" else dt.name
val = self.cast_val_to_dtype(val, dtype)

pyqc = clib.PyQueryCondition()
self.init_pyqc(pyqc, dtype)(att, val, op)
Expand Down Expand Up @@ -322,8 +323,8 @@ def get_att_from_node(self, node: QueryConditionNodeElem) -> Any:
f"Incorrect type for attribute name: {ast.dump(node)}"
)

if not self.schema.has_attr(att):
if self.schema.domain.has_dim(att):
if not self.array.schema.has_attr(att):
if self.array.schema.domain.has_dim(att):
raise tiledb.TileDBError(
f"`{att}` is a dimension. QueryConditions currently only "
"work on attributes."
Expand Down
4 changes: 2 additions & 2 deletions apis/python/src/tiledbsoma/_tiledb_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def _soma_reader(
# Leave empty arguments out of kwargs to allow C++ constructor defaults to apply, as
# they're not all wrapped in std::optional<>.
kwargs: Dict[str, object] = {}
if schema:
kwargs["schema"] = schema
# if schema:
# kwargs["schema"] = schema
if column_names:
kwargs["column_names"] = column_names
if query_condition:
Expand Down
6 changes: 1 addition & 5 deletions apis/python/src/tiledbsoma/io/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ def _write_dataframe_impl(
else:
enums[att.name] = cat
col_to_enums[att.name] = att.name

soma_df = DataFrame.create(
df_uri,
schema=arrow_table.schema,
Expand Down Expand Up @@ -1026,10 +1026,6 @@ def _update_dataframe(
old_type = old_sig[key]
new_type = new_sig[key]

# if it is a pa.dictionary type, we need to check against the index type
if new_type.startswith("dictionary"):
new_type = new_type.split(",")[1].split("=")[-1]

if old_type != new_type:
msgs.append(f"{key} type {old_type} != {new_type}")
if msgs:
Expand Down
4 changes: 3 additions & 1 deletion apis/python/src/tiledbsoma/io/registration/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def _string_dict_from_arrow_schema(schema: pa.Schema) -> Dict[str, str]:
retval = {}
for name in schema.names:
arrow_type = schema.field(name).type
if pa.types.is_dictionary(arrow_type):
arrow_type = arrow_type.index_type
retval[name] = _stringify_type(arrow_type)

# The soma_joinid field is specific to SOMA data but does not exist in AnnData/H5AD. When we
Expand Down Expand Up @@ -243,7 +245,7 @@ def from_soma_experiment(
varm_dtypes[varm_layer_name] = str(
varm.schema.field("soma_data").type
)

return cls(
obs_schema=obs_schema,
var_schema=var_schema,
Expand Down
8 changes: 2 additions & 6 deletions apis/python/src/tiledbsoma/pytiledbsoma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ PYBIND11_MODULE(pytiledbsoma, m) {
std::string_view name,
std::optional<std::vector<std::string>> column_names_in,
py::object py_query_condition,
py::object py_schema,
std::string_view batch_size,
ResultOrder result_order,
std::map<std::string, std::string> platform_config,
Expand All @@ -222,7 +221,7 @@ PYBIND11_MODULE(pytiledbsoma, m) {
// Column names will be updated with columns present
// in the query condition
auto new_column_names =
init_pyqc(py_schema, column_names)
init_pyqc(uri, column_names)
.cast<std::vector<std::string>>();

// Update the column_names list if it was not empty,
Expand Down Expand Up @@ -267,7 +266,6 @@ PYBIND11_MODULE(pytiledbsoma, m) {
"name"_a = "unnamed",
"column_names"_a = py::none(),
"query_condition"_a = py::none(),
"schema"_a = py::none(),
"batch_size"_a = "auto",
"result_order"_a = ResultOrder::automatic,
"platform_config"_a = py::dict(),
Expand All @@ -278,7 +276,6 @@ PYBIND11_MODULE(pytiledbsoma, m) {
[](SOMAArray& reader,
std::optional<std::vector<std::string>> column_names_in,
py::object py_query_condition,
py::object py_schema,
std::string_view batch_size,
ResultOrder result_order) {
// Handle optional args
Expand All @@ -298,7 +295,7 @@ PYBIND11_MODULE(pytiledbsoma, m) {
// Column names will be updated with columns present in
// the query condition
auto new_column_names =
init_pyqc(py_schema, column_names)
init_pyqc(reader.uri(), column_names)
.cast<std::vector<std::string>>();

// Update the column_names list if it was not empty,
Expand Down Expand Up @@ -331,7 +328,6 @@ PYBIND11_MODULE(pytiledbsoma, m) {
py::kw_only(),
"column_names"_a = py::none(),
"query_condition"_a = py::none(),
"schema"_a = py::none(),
"batch_size"_a = "auto",
"result_order"_a = ResultOrder::automatic)

Expand Down
23 changes: 6 additions & 17 deletions libtiledbsoma/test/test_query_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,7 @@ def pandas_query(uri, condition):

def soma_query(uri, condition):
qc = QueryCondition(condition)
schema = tiledb.open(uri).schema

with tiledb.open(uri) as A:
schema = A.schema
print(schema)
print(A.enum("louvain").dtype)

sr = clib.SOMAArray(uri, query_condition=qc, schema=schema)
sr = clib.SOMAArray(uri, query_condition=qc)
sr.submit()
arrow_table = sr.read_next()
assert sr.results_complete()
Expand Down Expand Up @@ -115,10 +108,9 @@ def test_query_condition_select_columns():
condition = "percent_mito > 0.02"

qc = QueryCondition(condition)
schema = tiledb.open(uri).schema

sr = clib.SOMAArray(
uri, query_condition=qc, schema=schema, column_names=["n_genes"]
uri, query_condition=qc, column_names=["n_genes"]
)
sr.submit()
arrow_table = sr.read_next()
Expand All @@ -133,9 +125,8 @@ def test_query_condition_all_columns():
condition = "percent_mito > 0.02"

qc = QueryCondition(condition)
schema = tiledb.open(uri).schema

sr = clib.SOMAArray(uri, query_condition=qc, schema=schema)
sr = clib.SOMAArray(uri, query_condition=qc)
sr.submit()
arrow_table = sr.read_next()

Expand All @@ -149,9 +140,8 @@ def test_query_condition_reset():
condition = "percent_mito > 0.02"

qc = QueryCondition(condition)
schema = tiledb.open(uri).schema

sr = clib.SOMAArray(uri, query_condition=qc, schema=schema)
sr = clib.SOMAArray(uri, query_condition=qc)
sr.submit()
arrow_table = sr.read_next()

Expand All @@ -163,7 +153,7 @@ def test_query_condition_reset():
# ---------------------------------------------------------------
condition = "percent_mito < 0.02"
qc = QueryCondition(condition)
sr.reset(column_names=["percent_mito"], query_condition=qc, schema=schema)
sr.reset(column_names=["percent_mito"], query_condition=qc)

sr.submit()
arrow_table = sr.read_next()
Expand Down Expand Up @@ -226,14 +216,13 @@ def test_parsing_error_conditions(malformed_condition):
def test_eval_error_conditions(malformed_condition):
"""Conditions which should not evaluate (but WILL parse)"""
uri = os.path.join(SOMA_URI, "obs")
schema = tiledb.open(uri).schema

# TODO: these raise the wrong error - it should be SOMAError. Change the test
# when https://github.com/single-cell-data/TileDB-SOMA/issues/783 is fixed
#
with pytest.raises(RuntimeError):
qc = QueryCondition(malformed_condition)
sr = clib.SOMAArray(uri, query_condition=qc, schema=schema)
sr = clib.SOMAArray(uri, query_condition=qc)
sr.submit()
sr.read_next()

Expand Down

0 comments on commit c2c81de

Please sign in to comment.