From 0551b17c53fd01743dce8d7a7b48f358933b7187 Mon Sep 17 00:00:00 2001
From: John Kerl
Date: Wed, 4 Oct 2023 17:11:54 -0400
Subject: [PATCH] [python] Support new set-membership query condition

Translate `attr in [...]` and `attr not in [...]` value filters into a
single TILEDB_IN / TILEDB_NOT_IN query condition, replacing the previous
chain of TILEDB_EQ conditions combined with TILEDB_OR.

Python (_query_condition.py):
* QueryConditionTree gains a `ctx` field; init_query_condition() builds
  one tiledb.Ctx and uses it both to open the array and for the tree.
* visit_NotIn is added and visit_Compare accepts ast.In and ast.NotIn.
  It picks the value dtype from the attribute (or from its enumeration
  when `enum_label` is set, or from attr_or_dim_dtype() otherwise),
  appends the attribute to query_attrs when missing, and calls the
  matching PyQueryCondition.create_<dtype> factory.
* create_pyqc() resolves that factory by name; datetime64 dtypes map to
  "int64" and bool maps to "uint8". A missing factory raises TileDBError.

C++ (query_condition.cc):
* PyQueryCondition::create<T>() wraps QueryConditionExperimental::create.
* Static factories are bound for string, the signed and unsigned 8-, 16-,
  32- and 64-bit integers, float32 and float64. Each of the 16- and 8-bit
  names is bound exactly once (create_uint16, create_int16, create_uint8,
  create_int8), so the "uint8" name produced by create_pyqc() for bool
  columns resolves to a factory.
* TILEDB_IN and TILEDB_NOT_IN are exported on tiledb_query_condition_op_t.
* The `rhs` parameter of combine() is renamed to `qc`.

---
 .../python/src/tiledbsoma/_query_condition.py | 52 ++++++++++---
 apis/python/src/tiledbsoma/query_condition.cc | 78 ++++++++++++++++++-
 2 files changed, 117 insertions(+), 13 deletions(-)

diff --git a/apis/python/src/tiledbsoma/_query_condition.py b/apis/python/src/tiledbsoma/_query_condition.py
index 9ce7e58fe5..012d8ca436 100644
--- a/apis/python/src/tiledbsoma/_query_condition.py
+++ b/apis/python/src/tiledbsoma/_query_condition.py
@@ -136,8 +136,9 @@ def init_query_condition(
         config: Optional[dict],
         timestamps: Optional[Tuple[OpenTimestamp, OpenTimestamp]],
     ):
+        ctx = tiledb.Ctx(config)
         qctree = QueryConditionTree(
-            tiledb.open(uri, ctx=tiledb.Ctx(config), timestamp=timestamps), query_attrs
+            ctx, tiledb.open(uri, ctx=ctx, timestamp=timestamps), query_attrs
         )
         self.c_obj = qctree.visit(self.tree.body)
 
@@ -152,6 +153,7 @@ def init_query_condition(
 
 @attrs.define
 class QueryConditionTree(ast.NodeVisitor):
+    ctx: tiledb.Ctx
     array: tiledb.Array
     query_attrs: List[str]
 
@@ -188,6 +190,9 @@ def visit_NotEq(self, node):
     def visit_In(self, node):
         return node
 
+    def visit_NotIn(self, node):
+        return node
+
     def visit_List(self, node):
         return list(node.elts)
 
@@ -216,23 +221,34 @@ def visit_Compare(self, node: ast.Compare) -> clib.PyQueryCondition:
                     self.visit(lhs), self.visit(op), self.visit(rhs)
                 )
                 result = result.combine(value, clib.TILEDB_AND)
-        elif isinstance(operator, ast.In):
+        elif isinstance(operator, (ast.In, ast.NotIn)):
             rhs = node.comparators[0]
             if not isinstance(rhs, ast.List):
                 raise tiledb.TileDBError(
                     "`in` operator syntax must be written as `attr in ['l', 'i', 's', 't']`"
                 )
 
-            consts = self.visit(rhs)
-            result = self.aux_visit_Compare(
-                self.visit(node.left), clib.TILEDB_EQ, consts[0]
-            )
+            variable = node.left.id
+            values = [self.get_val_from_node(val) for val in self.visit(rhs)]
 
-            for val in consts[1:]:
-                value = self.aux_visit_Compare(
-                    self.visit(node.left), clib.TILEDB_EQ, val
-                )
-                result = result.combine(value, clib.TILEDB_OR)
+            if self.array.schema.has_attr(variable):
+                enum_label = self.array.attr(variable).enum_label
+                if enum_label is not None:
+                    dt = self.array.enum(enum_label).dtype
+                else:
+                    dt = self.array.attr(variable).dtype
+            else:
+                dt = self.array.schema.attr_or_dim_dtype(variable)
+
+            # sdf.read(column_names=["foo"], value_filter='bar == 999') should result in bar being
+            # added to the column names. See also https://github.com/single-cell-data/TileDB-SOMA/issues/755
+            att = self.get_att_from_node(node.left)
+            if att not in self.query_attrs:
+                self.query_attrs.append(att)
+
+            dtype = "string" if dt.kind in "SUa" else dt.name
+            op = clib.TILEDB_IN if isinstance(operator, ast.In) else clib.TILEDB_NOT_IN
+            result = self.create_pyqc(dtype)(self.ctx, node.left.id, values, op)
 
         return result
 
@@ -405,6 +421,20 @@ def init_pyqc(self, pyqc: clib.PyQueryCondition, dtype: str) -> Callable:
 
         return getattr(pyqc, init_fn_name)
 
+    def create_pyqc(self, dtype: str) -> Callable:
+        if dtype != "string":
+            if np.issubdtype(dtype, np.datetime64):
+                dtype = "int64"
+            elif np.issubdtype(dtype, bool):
+                dtype = "uint8"
+
+        create_fn_name = f"create_{dtype}"
+
+        if not hasattr(clib.PyQueryCondition, create_fn_name):
+            raise tiledb.TileDBError(f"PyQueryCondition.{create_fn_name}() not found.")
+
+        return getattr(clib.PyQueryCondition, create_fn_name)
+
     def visit_BinOp(self, node: ast.BinOp) -> clib.PyQueryCondition:
         try:
             op = self.visit(node.op)
diff --git a/apis/python/src/tiledbsoma/query_condition.cc b/apis/python/src/tiledbsoma/query_condition.cc
index b08dcaca7a..da597a6f75 100644
--- a/apis/python/src/tiledbsoma/query_condition.cc
+++ b/apis/python/src/tiledbsoma/query_condition.cc
@@ -102,8 +102,24 @@ class PyQueryCondition {
 
     py::capsule __capsule__() { return py::capsule(&qc_, "qc"); }
 
+    template <typename T>
+    static PyQueryCondition
+    create(py::object pyctx, const std::string &field_name,
+           const std::vector<T> &values, tiledb_query_condition_op_t op) {
+        auto pyqc = PyQueryCondition(pyctx);
+
+        const Context ctx = std::as_const(pyqc.ctx_);
+
+        auto set_membership_qc =
+            QueryConditionExperimental::create(ctx, field_name, values, op);
+
+        pyqc.qc_ = std::make_shared<QueryCondition>(std::move(set_membership_qc));
+
+        return pyqc;
+    }
+
     PyQueryCondition
-    combine(PyQueryCondition rhs,
+    combine(PyQueryCondition qc,
             tiledb_query_condition_combination_op_t combination_op) const {
         auto pyqc = PyQueryCondition(nullptr, ctx_.ptr().get());
 
@@ -113,7 +129,7 @@ class PyQueryCondition {
             tiledb_query_condition_alloc(ctx_.ptr().get(), &combined_qc));
 
         ctx_.handle_error(tiledb_query_condition_combine(
-            ctx_.ptr().get(), qc_->ptr().get(), rhs.qc_->ptr().get(),
+            ctx_.ptr().get(), qc_->ptr().get(), qc.qc_->ptr().get(),
             combination_op, &combined_qc));
 
         pyqc.qc_ = std::shared_ptr<QueryCondition>(
@@ -199,6 +215,62 @@ void init_query_condition(py::module &m) {
 
         .def("combine", &PyQueryCondition::combine)
 
+        .def_static(
+            "create_string",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<std::string> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<std::string>))
+        .def_static(
+            "create_uint64",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<uint64_t> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<uint64_t>))
+        .def_static(
+            "create_int64",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<int64_t> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<int64_t>))
+        .def_static(
+            "create_uint32",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<uint32_t> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<uint32_t>))
+        .def_static(
+            "create_int32",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<int32_t> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<int32_t>))
+        .def_static(
+            "create_uint16",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<uint16_t> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<uint16_t>))
+        .def_static(
+            "create_int16",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<int16_t> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<int16_t>))
+        .def_static(
+            "create_uint8",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<uint8_t> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<uint8_t>))
+        .def_static(
+            "create_int8",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<int8_t> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<int8_t>))
+        .def_static(
+            "create_float32",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<float> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<float>))
+        .def_static(
+            "create_float64",
+            static_cast<PyQueryCondition (*)(
+                py::object, const std::string &, const std::vector<double> &,
+                tiledb_query_condition_op_t)>(&PyQueryCondition::create<double>))
+
         .def("__capsule__", &PyQueryCondition::__capsule__);
 
     py::enum_<tiledb_query_condition_op_t>(m, "tiledb_query_condition_op_t",
@@ -209,6 +281,8 @@ void init_query_condition(py::module &m) {
         .value("TILEDB_GE", TILEDB_GE)
         .value("TILEDB_EQ", TILEDB_EQ)
         .value("TILEDB_NE", TILEDB_NE)
+        .value("TILEDB_IN", TILEDB_IN)
+        .value("TILEDB_NOT_IN", TILEDB_NOT_IN)
         .export_values();
 
     py::enum_<tiledb_query_condition_combination_op_t>(