Skip to content

Commit

Permalink
[python] Support new set-membership query condition
Browse files Browse the repository at this point in the history
  • Loading branch information
johnkerl committed Oct 5, 2023
1 parent e9c5929 commit 0551b17
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 13 deletions.
52 changes: 41 additions & 11 deletions apis/python/src/tiledbsoma/_query_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,9 @@ def init_query_condition(
config: Optional[dict],
timestamps: Optional[Tuple[OpenTimestamp, OpenTimestamp]],
):
ctx = tiledb.Ctx(config)
qctree = QueryConditionTree(
tiledb.open(uri, ctx=tiledb.Ctx(config), timestamp=timestamps), query_attrs
ctx, tiledb.open(uri, ctx=ctx, timestamp=timestamps), query_attrs
)
self.c_obj = qctree.visit(self.tree.body)

Expand All @@ -152,6 +153,7 @@ def init_query_condition(

@attrs.define
class QueryConditionTree(ast.NodeVisitor):
ctx: tiledb.Ctx
array: tiledb.Array
query_attrs: List[str]

Expand Down Expand Up @@ -188,6 +190,9 @@ def visit_NotEq(self, node):
def visit_In(self, node):
return node

def visit_NotIn(self, node):
return node

def visit_List(self, node):
return list(node.elts)

Expand Down Expand Up @@ -216,23 +221,34 @@ def visit_Compare(self, node: ast.Compare) -> clib.PyQueryCondition:
self.visit(lhs), self.visit(op), self.visit(rhs)
)
result = result.combine(value, clib.TILEDB_AND)
elif isinstance(operator, ast.In):
elif isinstance(operator, (ast.In, ast.NotIn)):
rhs = node.comparators[0]
if not isinstance(rhs, ast.List):
raise tiledb.TileDBError(
"`in` operator syntax must be written as `attr in ['l', 'i', 's', 't']`"
)

consts = self.visit(rhs)
result = self.aux_visit_Compare(
self.visit(node.left), clib.TILEDB_EQ, consts[0]
)
variable = node.left.id
values = [self.get_val_from_node(val) for val in self.visit(rhs)]

for val in consts[1:]:
value = self.aux_visit_Compare(
self.visit(node.left), clib.TILEDB_EQ, val
)
result = result.combine(value, clib.TILEDB_OR)
if self.array.schema.has_attr(variable):
enum_label = self.array.attr(variable).enum_label
if enum_label is not None:
dt = self.array.enum(enum_label).dtype
else:
dt = self.array.attr(variable).dtype
else:
dt = self.array.schema.attr_or_dim_dtype(variable)

# sdf.read(column_names=["foo"], value_filter='bar == 999') should result in bar being
# added to the column names. See also https://github.com/single-cell-data/TileDB-SOMA/issues/755
att = self.get_att_from_node(node.left)
if att not in self.query_attrs:
self.query_attrs.append(att)

dtype = "string" if dt.kind in "SUa" else dt.name
op = clib.TILEDB_IN if isinstance(operator, ast.In) else clib.TILEDB_NOT_IN
result = self.create_pyqc(dtype)(self.ctx, node.left.id, values, op)

return result

Expand Down Expand Up @@ -405,6 +421,20 @@ def init_pyqc(self, pyqc: clib.PyQueryCondition, dtype: str) -> Callable:

return getattr(pyqc, init_fn_name)

def create_pyqc(self, dtype: str) -> Callable:
if dtype != "string":
if np.issubdtype(dtype, np.datetime64):
dtype = "int64"
elif np.issubdtype(dtype, bool):
dtype = "uint8"

create_fn_name = f"create_{dtype}"

if not hasattr(clib.PyQueryCondition, create_fn_name):
raise tiledb.TileDBError(f"PyQueryCondition.{create_fn_name}() not found.")

return getattr(clib.PyQueryCondition, create_fn_name)

def visit_BinOp(self, node: ast.BinOp) -> clib.PyQueryCondition:
try:
op = self.visit(node.op)
Expand Down
78 changes: 76 additions & 2 deletions apis/python/src/tiledbsoma/query_condition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,24 @@ class PyQueryCondition {

py::capsule __capsule__() { return py::capsule(&qc_, "qc"); }

template <typename T>
static PyQueryCondition
create(py::object pyctx, const std::string &field_name,
const std::vector<T> &values, tiledb_query_condition_op_t op) {
auto pyqc = PyQueryCondition(pyctx);

const Context ctx = std::as_const(pyqc.ctx_);

auto set_membership_qc =
QueryConditionExperimental::create(ctx, field_name, values, op);

pyqc.qc_ = std::make_shared<QueryCondition>(std::move(set_membership_qc));

return pyqc;
}

PyQueryCondition
combine(PyQueryCondition rhs,
combine(PyQueryCondition qc,
tiledb_query_condition_combination_op_t combination_op) const {

auto pyqc = PyQueryCondition(nullptr, ctx_.ptr().get());
Expand All @@ -113,7 +129,7 @@ class PyQueryCondition {
tiledb_query_condition_alloc(ctx_.ptr().get(), &combined_qc));

ctx_.handle_error(tiledb_query_condition_combine(
ctx_.ptr().get(), qc_->ptr().get(), rhs.qc_->ptr().get(),
ctx_.ptr().get(), qc_->ptr().get(), qc.qc_->ptr().get(),
combination_op, &combined_qc));

pyqc.qc_ = std::shared_ptr<QueryCondition>(
Expand Down Expand Up @@ -199,6 +215,62 @@ void init_query_condition(py::module &m) {

.def("combine", &PyQueryCondition::combine)

.def_static(
"create_string",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<std::string> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint64",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint64_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int64",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int64_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint32",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint32_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int32",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int32_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint16",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint16_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int8",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int8_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint16",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint16_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int8",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int8_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_float32",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<float> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_float64",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<double> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))

.def("__capsule__", &PyQueryCondition::__capsule__);

py::enum_<tiledb_query_condition_op_t>(m, "tiledb_query_condition_op_t",
Expand All @@ -209,6 +281,8 @@ void init_query_condition(py::module &m) {
.value("TILEDB_GE", TILEDB_GE)
.value("TILEDB_EQ", TILEDB_EQ)
.value("TILEDB_NE", TILEDB_NE)
.value("TILEDB_IN", TILEDB_IN)
.value("TILEDB_NOT_IN", TILEDB_NOT_IN)
.export_values();

py::enum_<tiledb_query_condition_combination_op_t>(
Expand Down

0 comments on commit 0551b17

Please sign in to comment.