Skip to content

Commit

Permalink
[python] Support new set-membership query condition (#1756)
Browse files Browse the repository at this point in the history
* [python] Support new set-membership query condition

* unit-test coverage for not-in

* code-review feedback

* lint

* attempt to work around anndata 0.10.0 [WIP]

* undo anndata 0.10.0 attempted workaround; code-review feedback
  • Loading branch information
johnkerl authored and github-actions[bot] committed Oct 12, 2023
1 parent 09a61e9 commit fbb3607
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 13 deletions.
58 changes: 47 additions & 11 deletions apis/python/src/tiledbsoma/_query_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,9 @@ def init_query_condition(
config: Optional[dict],
timestamps: Optional[Tuple[OpenTimestamp, OpenTimestamp]],
):
ctx = tiledb.Ctx(config)
qctree = QueryConditionTree(
tiledb.open(uri, ctx=tiledb.Ctx(config), timestamp=timestamps), query_attrs
ctx, tiledb.open(uri, ctx=ctx, timestamp=timestamps), query_attrs
)
self.c_obj = qctree.visit(self.tree.body)

Expand All @@ -152,6 +153,7 @@ def init_query_condition(

@attrs.define
class QueryConditionTree(ast.NodeVisitor):
ctx: tiledb.Ctx
array: tiledb.Array
query_attrs: List[str]

Expand Down Expand Up @@ -188,6 +190,9 @@ def visit_NotEq(self, node):
def visit_In(self, node):
return node

def visit_NotIn(self, node):
return node

def visit_List(self, node):
return list(node.elts)

Expand Down Expand Up @@ -216,23 +221,35 @@ def visit_Compare(self, node: ast.Compare) -> clib.PyQueryCondition:
self.visit(lhs), self.visit(op), self.visit(rhs)
)
result = result.combine(value, clib.TILEDB_AND)
elif isinstance(operator, ast.In):
elif isinstance(operator, (ast.In, ast.NotIn)):
rhs = node.comparators[0]
if not isinstance(rhs, ast.List):
raise tiledb.TileDBError(
"`in` operator syntax must be written as `attr in ['l', 'i', 's', 't']`"
)

consts = self.visit(rhs)
result = self.aux_visit_Compare(
self.visit(node.left), clib.TILEDB_EQ, consts[0]
)
variable = node.left.id
values = [self.get_val_from_node(val) for val in self.visit(rhs)]

for val in consts[1:]:
value = self.aux_visit_Compare(
self.visit(node.left), clib.TILEDB_EQ, val
)
result = result.combine(value, clib.TILEDB_OR)
if self.array.schema.has_attr(variable):
enum_label = self.array.attr(variable).enum_label
if enum_label is not None:
dt = self.array.enum(enum_label).dtype
else:
dt = self.array.attr(variable).dtype
else:
dt = self.array.schema.attr_or_dim_dtype(variable)

# sdf.read(column_names=["foo"], value_filter='bar == 999') should
# result in bar being added to the column names. See also
# https://github.com/single-cell-data/TileDB-SOMA/issues/755
att = self.get_att_from_node(node.left)
if att not in self.query_attrs:
self.query_attrs.append(att)

dtype = "string" if dt.kind in "SUa" else dt.name
op = clib.TILEDB_IN if isinstance(operator, ast.In) else clib.TILEDB_NOT_IN
result = self.create_pyqc(dtype)(self.ctx, node.left.id, values, op)

return result

Expand Down Expand Up @@ -338,6 +355,9 @@ def get_att_from_node(self, node: QueryConditionNodeElem) -> Any:
)
raise tiledb.TileDBError(f"Attribute `{att}` not found in schema.")

# sdf.read(column_names=["foo"], value_filter='bar == 999') should
# result in bar being added to the column names. See also
# https://github.com/single-cell-data/TileDB-SOMA/issues/755
if att not in self.query_attrs:
self.query_attrs.append(att)

Expand Down Expand Up @@ -405,6 +425,22 @@ def init_pyqc(self, pyqc: clib.PyQueryCondition, dtype: str) -> Callable:

return getattr(pyqc, init_fn_name)

def create_pyqc(self, dtype: str) -> Callable:
if dtype != "string":
if np.issubdtype(dtype, np.datetime64):
dtype = "int64"
elif np.issubdtype(dtype, bool):
dtype = "uint8"

create_fn_name = f"create_{dtype}"

try:
return getattr(clib.PyQueryCondition, create_fn_name)
except AttributeError as ae:
raise tiledb.TileDBError(
f"PyQueryCondition.{create_fn_name}() not found."
) from ae

def visit_BinOp(self, node: ast.BinOp) -> clib.PyQueryCondition:
try:
op = self.visit(node.op)
Expand Down
78 changes: 76 additions & 2 deletions apis/python/src/tiledbsoma/query_condition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,24 @@ class PyQueryCondition {

py::capsule __capsule__() { return py::capsule(&qc_, "qc"); }

template <typename T>
static PyQueryCondition
create(py::object pyctx, const std::string &field_name,
const std::vector<T> &values, tiledb_query_condition_op_t op) {
auto pyqc = PyQueryCondition(pyctx);

const Context ctx = std::as_const(pyqc.ctx_);

auto set_membership_qc =
QueryConditionExperimental::create(ctx, field_name, values, op);

pyqc.qc_ = std::make_shared<QueryCondition>(std::move(set_membership_qc));

return pyqc;
}

PyQueryCondition
combine(PyQueryCondition rhs,
combine(PyQueryCondition qc,
tiledb_query_condition_combination_op_t combination_op) const {

auto pyqc = PyQueryCondition(nullptr, ctx_.ptr().get());
Expand All @@ -113,7 +129,7 @@ class PyQueryCondition {
tiledb_query_condition_alloc(ctx_.ptr().get(), &combined_qc));

ctx_.handle_error(tiledb_query_condition_combine(
ctx_.ptr().get(), qc_->ptr().get(), rhs.qc_->ptr().get(),
ctx_.ptr().get(), qc_->ptr().get(), qc.qc_->ptr().get(),
combination_op, &combined_qc));

pyqc.qc_ = std::shared_ptr<QueryCondition>(
Expand Down Expand Up @@ -199,6 +215,62 @@ void init_query_condition(py::module &m) {

.def("combine", &PyQueryCondition::combine)

.def_static(
"create_string",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<std::string> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint64",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint64_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int64",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int64_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint32",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint32_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int32",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int32_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint16",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint16_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int8",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int8_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint16",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint16_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int8",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int8_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_float32",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<float> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_float64",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<double> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))

.def("__capsule__", &PyQueryCondition::__capsule__);

py::enum_<tiledb_query_condition_op_t>(m, "tiledb_query_condition_op_t",
Expand All @@ -209,6 +281,8 @@ void init_query_condition(py::module &m) {
.value("TILEDB_GE", TILEDB_GE)
.value("TILEDB_EQ", TILEDB_EQ)
.value("TILEDB_NE", TILEDB_NE)
.value("TILEDB_IN", TILEDB_IN)
.value("TILEDB_NOT_IN", TILEDB_NOT_IN)
.export_values();

py::enum_<tiledb_query_condition_combination_op_t>(
Expand Down
31 changes: 31 additions & 0 deletions apis/python/tests/test_experiment_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,37 @@ def test_experiment_query_value_filter(soma_experiment):
assert query.var().concat()["label"].to_pylist() == var_label_values


@pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
def test_experiment_query_value_filter2(soma_experiment):
"""Test query by value filter"""
obs_label_values = ["3", "7", "38", "99"]
var_label_values = ["18", "34", "67"]
with soma_experiment.axis_query(
"RNA",
obs_query=soma.AxisQuery(value_filter=f"label not in {obs_label_values}"),
var_query=soma.AxisQuery(value_filter=f"label not in {var_label_values}"),
) as query:
assert query.n_obs == soma_experiment.obs.count - len(obs_label_values)
assert query.n_vars == soma_experiment.ms["RNA"].var.count - len(
var_label_values
)
all_obs_values = set(
soma_experiment.obs.read(column_names=["label"])
.concat()
.to_pandas()["label"]
)
all_var_values = set(
soma_experiment.ms["RNA"]
.var.read(column_names=["label"])
.concat()
.to_pandas()["label"]
)
qry_obs_values = set(query.obs().concat()["label"].to_pylist())
qry_var_values = set(query.var().concat()["label"].to_pylist())
assert qry_obs_values == all_obs_values.difference(set(obs_label_values))
assert qry_var_values == all_var_values.difference(set(var_label_values))


@pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
def test_experiment_query_combo(soma_experiment):
"""Test query by combinations of coords and value_filter"""
Expand Down

0 comments on commit fbb3607

Please sign in to comment.