From e63989ac7c62c8210ae153c8193917dd28b07c65 Mon Sep 17 00:00:00 2001
From: John Kerl
Date: Thu, 12 Oct 2023 17:33:54 -0400
Subject: [PATCH] [python] Support new set-membership query condition (#1756)

* [python] Support new set-membership query condition

* unit-test coverage for not-in

* code-review feedback

* lint

* attempt to work around anndata 0.10.0 [WIP]

* undo anndata 0.10.0 attempted workaround; code-review feedback
---
Reviewer notes (this section is commentary and is ignored by git-am):

- This copy was recovered from a whitespace-collapsed extraction that had
  also dropped every angle-bracketed span. Line breaks and blank lines were
  rebuilt so that each hunk matches its @@ header and the diffstat below.
  Indentation and the C++ template arguments are inferred from the
  surrounding names; check them against the tree before applying.
- The sender's e-mail address was lost in extraction and is not restored.
- Binding fix: the extracted text registered "create_uint16" and
  "create_int8" twice and never registered create_int16 / create_uint8,
  although create_pyqc() looks up create_{dtype} and maps bool to "uint8".
  The second and third of those four entries are now create_int16 and
  create_uint8. The blob id on the query_condition.cc index line predates
  this fix.
- Known gap left as-is: the "`in` operator syntax" error text is on a
  context line, so it still names only `in` although the branch now also
  handles `not in`.

 .../python/src/tiledbsoma/_query_condition.py | 58 +++++++++++---
 apis/python/src/tiledbsoma/query_condition.cc | 78 ++++++++++++++++++-
 apis/python/tests/test_experiment_query.py    | 31 ++++++++
 3 files changed, 154 insertions(+), 13 deletions(-)

diff --git a/apis/python/src/tiledbsoma/_query_condition.py b/apis/python/src/tiledbsoma/_query_condition.py
index 9ce7e58fe5..8ee0476580 100644
--- a/apis/python/src/tiledbsoma/_query_condition.py
+++ b/apis/python/src/tiledbsoma/_query_condition.py
@@ -136,8 +136,9 @@ def init_query_condition(
         config: Optional[dict],
         timestamps: Optional[Tuple[OpenTimestamp, OpenTimestamp]],
     ):
+        ctx = tiledb.Ctx(config)
         qctree = QueryConditionTree(
-            tiledb.open(uri, ctx=tiledb.Ctx(config), timestamp=timestamps), query_attrs
+            ctx, tiledb.open(uri, ctx=ctx, timestamp=timestamps), query_attrs
         )
         self.c_obj = qctree.visit(self.tree.body)
 
@@ -152,6 +153,7 @@ def init_query_condition(
 
 @attrs.define
 class QueryConditionTree(ast.NodeVisitor):
+    ctx: tiledb.Ctx
     array: tiledb.Array
     query_attrs: List[str]
 
@@ -188,6 +190,9 @@ def visit_NotEq(self, node):
     def visit_In(self, node):
         return node
 
+    def visit_NotIn(self, node):
+        return node
+
     def visit_List(self, node):
         return list(node.elts)
 
@@ -216,23 +221,35 @@ def visit_Compare(self, node: ast.Compare) -> clib.PyQueryCondition:
                     self.visit(lhs), self.visit(op), self.visit(rhs)
                 )
                 result = result.combine(value, clib.TILEDB_AND)
-        elif isinstance(operator, ast.In):
+        elif isinstance(operator, (ast.In, ast.NotIn)):
             rhs = node.comparators[0]
             if not isinstance(rhs, ast.List):
                 raise tiledb.TileDBError(
                     "`in` operator syntax must be written as `attr in ['l', 'i', 's', 't']`"
                 )
 
-            consts = self.visit(rhs)
-            result = self.aux_visit_Compare(
-                self.visit(node.left), clib.TILEDB_EQ, consts[0]
-            )
+            variable = node.left.id
+            values = [self.get_val_from_node(val) for val in self.visit(rhs)]
 
-            for val in consts[1:]:
-                value = self.aux_visit_Compare(
-                    self.visit(node.left), clib.TILEDB_EQ, val
-                )
-                result = result.combine(value, clib.TILEDB_OR)
+            if self.array.schema.has_attr(variable):
+                enum_label = self.array.attr(variable).enum_label
+                if enum_label is not None:
+                    dt = self.array.enum(enum_label).dtype
+                else:
+                    dt = self.array.attr(variable).dtype
+            else:
+                dt = self.array.schema.attr_or_dim_dtype(variable)
+
+            # sdf.read(column_names=["foo"], value_filter='bar == 999') should
+            # result in bar being added to the column names. See also
+            # https://github.com/single-cell-data/TileDB-SOMA/issues/755
+            att = self.get_att_from_node(node.left)
+            if att not in self.query_attrs:
+                self.query_attrs.append(att)
+
+            dtype = "string" if dt.kind in "SUa" else dt.name
+            op = clib.TILEDB_IN if isinstance(operator, ast.In) else clib.TILEDB_NOT_IN
+            result = self.create_pyqc(dtype)(self.ctx, node.left.id, values, op)
 
         return result
 
@@ -338,6 +355,9 @@ def get_att_from_node(self, node: QueryConditionNodeElem) -> Any:
                 )
             raise tiledb.TileDBError(f"Attribute `{att}` not found in schema.")
 
+        # sdf.read(column_names=["foo"], value_filter='bar == 999') should
+        # result in bar being added to the column names. See also
+        # https://github.com/single-cell-data/TileDB-SOMA/issues/755
         if att not in self.query_attrs:
             self.query_attrs.append(att)
 
@@ -405,6 +425,22 @@ def init_pyqc(self, pyqc: clib.PyQueryCondition, dtype: str) -> Callable:
 
         return getattr(pyqc, init_fn_name)
 
+    def create_pyqc(self, dtype: str) -> Callable:
+        if dtype != "string":
+            if np.issubdtype(dtype, np.datetime64):
+                dtype = "int64"
+            elif np.issubdtype(dtype, bool):
+                dtype = "uint8"
+
+        create_fn_name = f"create_{dtype}"
+
+        try:
+            return getattr(clib.PyQueryCondition, create_fn_name)
+        except AttributeError as ae:
+            raise tiledb.TileDBError(
+                f"PyQueryCondition.{create_fn_name}() not found."
+            ) from ae
+
     def visit_BinOp(self, node: ast.BinOp) -> clib.PyQueryCondition:
         try:
             op = self.visit(node.op)
diff --git a/apis/python/src/tiledbsoma/query_condition.cc b/apis/python/src/tiledbsoma/query_condition.cc
index b08dcaca7a..da597a6f75 100644
--- a/apis/python/src/tiledbsoma/query_condition.cc
+++ b/apis/python/src/tiledbsoma/query_condition.cc
@@ -102,8 +102,24 @@ class PyQueryCondition {
 
   py::capsule __capsule__() { return py::capsule(&qc_, "qc"); }
 
+  template <typename T>
+  static PyQueryCondition
+  create(py::object pyctx, const std::string &field_name,
+         const std::vector<T> &values, tiledb_query_condition_op_t op) {
+    auto pyqc = PyQueryCondition(pyctx);
+
+    const Context ctx = std::as_const(pyqc.ctx_);
+
+    auto set_membership_qc =
+        QueryConditionExperimental::create(ctx, field_name, values, op);
+
+    pyqc.qc_ = std::make_shared<QueryCondition>(std::move(set_membership_qc));
+
+    return pyqc;
+  }
+
   PyQueryCondition
-  combine(PyQueryCondition rhs,
+  combine(PyQueryCondition qc,
           tiledb_query_condition_combination_op_t combination_op) const {
 
     auto pyqc = PyQueryCondition(nullptr, ctx_.ptr().get());
@@ -113,7 +129,7 @@ class PyQueryCondition {
         tiledb_query_condition_alloc(ctx_.ptr().get(), &combined_qc));
 
     ctx_.handle_error(tiledb_query_condition_combine(
-        ctx_.ptr().get(), qc_->ptr().get(), rhs.qc_->ptr().get(),
+        ctx_.ptr().get(), qc_->ptr().get(), qc.qc_->ptr().get(),
         combination_op, &combined_qc));
 
     pyqc.qc_ = std::shared_ptr<QueryCondition>(
@@ -199,6 +215,62 @@ void init_query_condition(py::module &m) {
 
       .def("combine", &PyQueryCondition::combine)
 
+      .def_static(
+          "create_string",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<std::string> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+      .def_static(
+          "create_uint64",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<uint64_t> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+      .def_static(
+          "create_int64",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<int64_t> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+      .def_static(
+          "create_uint32",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<uint32_t> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+      .def_static(
+          "create_int32",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<int32_t> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+      .def_static(
+          "create_uint16",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<uint16_t> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+      .def_static(
+          "create_int16",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<int16_t> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+      .def_static(
+          "create_uint8",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<uint8_t> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+      .def_static(
+          "create_int8",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<int8_t> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+      .def_static(
+          "create_float32",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<float> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+      .def_static(
+          "create_float64",
+          static_cast<PyQueryCondition (*)(
+              py::object, const std::string &, const std::vector<double> &,
+              tiledb_query_condition_op_t)>(&PyQueryCondition::create))
+
       .def("__capsule__", &PyQueryCondition::__capsule__);
 
   py::enum_<tiledb_query_condition_op_t>(m, "tiledb_query_condition_op_t",
@@ -209,6 +281,8 @@ void init_query_condition(py::module &m) {
       .value("TILEDB_GE", TILEDB_GE)
       .value("TILEDB_EQ", TILEDB_EQ)
       .value("TILEDB_NE", TILEDB_NE)
+      .value("TILEDB_IN", TILEDB_IN)
+      .value("TILEDB_NOT_IN", TILEDB_NOT_IN)
       .export_values();
 
   py::enum_<tiledb_query_condition_combination_op_t>(
diff --git a/apis/python/tests/test_experiment_query.py b/apis/python/tests/test_experiment_query.py
index fb20f20dee..09415ca11a 100644
--- a/apis/python/tests/test_experiment_query.py
+++ b/apis/python/tests/test_experiment_query.py
@@ -182,6 +182,37 @@ def test_experiment_query_value_filter(soma_experiment):
         assert query.var().concat()["label"].to_pylist() == var_label_values
 
 
+@pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
+def test_experiment_query_value_filter2(soma_experiment):
+    """Test query by value filter"""
+    obs_label_values = ["3", "7", "38", "99"]
+    var_label_values = ["18", "34", "67"]
+    with soma_experiment.axis_query(
+        "RNA",
+        obs_query=soma.AxisQuery(value_filter=f"label not in {obs_label_values}"),
+        var_query=soma.AxisQuery(value_filter=f"label not in {var_label_values}"),
+    ) as query:
+        assert query.n_obs == soma_experiment.obs.count - len(obs_label_values)
+        assert query.n_vars == soma_experiment.ms["RNA"].var.count - len(
+            var_label_values
+        )
+        all_obs_values = set(
+            soma_experiment.obs.read(column_names=["label"])
+            .concat()
+            .to_pandas()["label"]
+        )
+        all_var_values = set(
+            soma_experiment.ms["RNA"]
+            .var.read(column_names=["label"])
+            .concat()
+            .to_pandas()["label"]
+        )
+        qry_obs_values = set(query.obs().concat()["label"].to_pylist())
+        qry_var_values = set(query.var().concat()["label"].to_pylist())
+        assert qry_obs_values == all_obs_values.difference(set(obs_label_values))
+        assert qry_var_values == all_var_values.difference(set(var_label_values))
+
+
 @pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
 def test_experiment_query_combo(soma_experiment):
     """Test query by combinations of coords and value_filter"""