Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport release-1.5] [python] Support new set-membership query condition #1788

Merged
merged 1 commit into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 47 additions & 11 deletions apis/python/src/tiledbsoma/_query_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,9 @@ def init_query_condition(
config: Optional[dict],
timestamps: Optional[Tuple[OpenTimestamp, OpenTimestamp]],
):
ctx = tiledb.Ctx(config)
qctree = QueryConditionTree(
tiledb.open(uri, ctx=tiledb.Ctx(config), timestamp=timestamps), query_attrs
ctx, tiledb.open(uri, ctx=ctx, timestamp=timestamps), query_attrs
)
self.c_obj = qctree.visit(self.tree.body)

Expand All @@ -152,6 +153,7 @@ def init_query_condition(

@attrs.define
class QueryConditionTree(ast.NodeVisitor):
ctx: tiledb.Ctx
array: tiledb.Array
query_attrs: List[str]

Expand Down Expand Up @@ -188,6 +190,9 @@ def visit_NotEq(self, node):
def visit_In(self, node):
    """Return the ``in`` operator node as-is; no translation happens here."""
    return node

def visit_NotIn(self, node):
    """Return the ``not in`` operator node as-is; no translation happens here."""
    return node

def visit_List(self, node):
    """Return the element nodes of a list literal as a new Python list."""
    return [*node.elts]

Expand Down Expand Up @@ -216,23 +221,35 @@ def visit_Compare(self, node: ast.Compare) -> clib.PyQueryCondition:
self.visit(lhs), self.visit(op), self.visit(rhs)
)
result = result.combine(value, clib.TILEDB_AND)
elif isinstance(operator, ast.In):
elif isinstance(operator, (ast.In, ast.NotIn)):
rhs = node.comparators[0]
if not isinstance(rhs, ast.List):
raise tiledb.TileDBError(
"`in` operator syntax must be written as `attr in ['l', 'i', 's', 't']`"
)

consts = self.visit(rhs)
result = self.aux_visit_Compare(
self.visit(node.left), clib.TILEDB_EQ, consts[0]
)
variable = node.left.id
values = [self.get_val_from_node(val) for val in self.visit(rhs)]

for val in consts[1:]:
value = self.aux_visit_Compare(
self.visit(node.left), clib.TILEDB_EQ, val
)
result = result.combine(value, clib.TILEDB_OR)
if self.array.schema.has_attr(variable):
enum_label = self.array.attr(variable).enum_label
if enum_label is not None:
dt = self.array.enum(enum_label).dtype
else:
dt = self.array.attr(variable).dtype
else:
dt = self.array.schema.attr_or_dim_dtype(variable)

# sdf.read(column_names=["foo"], value_filter='bar == 999') should
# result in bar being added to the column names. See also
# https://github.com/single-cell-data/TileDB-SOMA/issues/755
att = self.get_att_from_node(node.left)
if att not in self.query_attrs:
self.query_attrs.append(att)

dtype = "string" if dt.kind in "SUa" else dt.name
op = clib.TILEDB_IN if isinstance(operator, ast.In) else clib.TILEDB_NOT_IN
result = self.create_pyqc(dtype)(self.ctx, node.left.id, values, op)

return result

Expand Down Expand Up @@ -338,6 +355,9 @@ def get_att_from_node(self, node: QueryConditionNodeElem) -> Any:
)
raise tiledb.TileDBError(f"Attribute `{att}` not found in schema.")

# sdf.read(column_names=["foo"], value_filter='bar == 999') should
# result in bar being added to the column names. See also
# https://github.com/single-cell-data/TileDB-SOMA/issues/755
if att not in self.query_attrs:
self.query_attrs.append(att)

Expand Down Expand Up @@ -405,6 +425,22 @@ def init_pyqc(self, pyqc: clib.PyQueryCondition, dtype: str) -> Callable:

return getattr(pyqc, init_fn_name)

def create_pyqc(self, dtype: str) -> Callable:
    """Look up the ``PyQueryCondition.create_<dtype>`` factory for ``dtype``.

    ``dtype`` is either the literal ``"string"`` or a NumPy dtype name.
    Datetime dtypes are routed to the ``int64`` factory and booleans to the
    ``uint8`` factory; every other name is used verbatim.

    Raises:
        tiledb.TileDBError: if ``clib.PyQueryCondition`` has no attribute
            named ``create_<dtype>``.
    """
    is_string = dtype == "string"
    if not is_string:
        # Only non-string names are handed to NumPy for classification.
        if np.issubdtype(dtype, np.datetime64):
            dtype = "int64"
        elif np.issubdtype(dtype, bool):
            dtype = "uint8"

    create_fn_name = f"create_{dtype}"

    try:
        factory = getattr(clib.PyQueryCondition, create_fn_name)
    except AttributeError as ae:
        raise tiledb.TileDBError(
            f"PyQueryCondition.{create_fn_name}() not found."
        ) from ae
    return factory

def visit_BinOp(self, node: ast.BinOp) -> clib.PyQueryCondition:
try:
op = self.visit(node.op)
Expand Down
78 changes: 76 additions & 2 deletions apis/python/src/tiledbsoma/query_condition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,24 @@ class PyQueryCondition {

py::capsule __capsule__() { return py::capsule(&qc_, "qc"); }

/// Build a PyQueryCondition holding a set-membership condition.
///
/// @param pyctx      Python context object used to construct the wrapper.
/// @param field_name Name of the field the condition applies to.
/// @param values     Values forming the membership set.
/// @param op         Condition operator forwarded to
///                   QueryConditionExperimental::create (the Python caller
///                   passes TILEDB_IN or TILEDB_NOT_IN).
/// @return A PyQueryCondition whose qc_ owns the newly created condition.
template <typename T>
static PyQueryCondition
create(py::object pyctx, const std::string &field_name,
       const std::vector<T> &values, tiledb_query_condition_op_t op) {
  PyQueryCondition wrapped(pyctx);

  // Take a const copy of the wrapper's context for the factory call.
  const Context context = std::as_const(wrapped.ctx_);

  wrapped.qc_ = std::make_shared<QueryCondition>(
      QueryConditionExperimental::create(context, field_name, values, op));

  return wrapped;
}

PyQueryCondition
combine(PyQueryCondition rhs,
combine(PyQueryCondition qc,
tiledb_query_condition_combination_op_t combination_op) const {

auto pyqc = PyQueryCondition(nullptr, ctx_.ptr().get());
Expand All @@ -113,7 +129,7 @@ class PyQueryCondition {
tiledb_query_condition_alloc(ctx_.ptr().get(), &combined_qc));

ctx_.handle_error(tiledb_query_condition_combine(
ctx_.ptr().get(), qc_->ptr().get(), rhs.qc_->ptr().get(),
ctx_.ptr().get(), qc_->ptr().get(), qc.qc_->ptr().get(),
combination_op, &combined_qc));

pyqc.qc_ = std::shared_ptr<QueryCondition>(
Expand Down Expand Up @@ -199,6 +215,62 @@ void init_query_condition(py::module &m) {

.def("combine", &PyQueryCondition::combine)

.def_static(
"create_string",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<std::string> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint64",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint64_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int64",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int64_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint32",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint32_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int32",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int32_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_uint16",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<uint16_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
// 16-bit signed and 8-bit unsigned factories. These two slots previously
// re-registered "create_int8" and "create_uint16", leaving "create_int16"
// and "create_uint8" unavailable to the Python-side `create_<dtype>` lookup
// (which maps bool to "uint8").
.def_static(
    "create_int16",
    static_cast<PyQueryCondition (*)(
        py::object, const std::string &, const std::vector<int16_t> &,
        tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
    "create_uint8",
    static_cast<PyQueryCondition (*)(
        py::object, const std::string &, const std::vector<uint8_t> &,
        tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_int8",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<int8_t> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_float32",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<float> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))
.def_static(
"create_float64",
static_cast<PyQueryCondition (*)(
py::object, const std::string &, const std::vector<double> &,
tiledb_query_condition_op_t)>(&PyQueryCondition::create))

.def("__capsule__", &PyQueryCondition::__capsule__);

py::enum_<tiledb_query_condition_op_t>(m, "tiledb_query_condition_op_t",
Expand All @@ -209,6 +281,8 @@ void init_query_condition(py::module &m) {
.value("TILEDB_GE", TILEDB_GE)
.value("TILEDB_EQ", TILEDB_EQ)
.value("TILEDB_NE", TILEDB_NE)
.value("TILEDB_IN", TILEDB_IN)
.value("TILEDB_NOT_IN", TILEDB_NOT_IN)
.export_values();

py::enum_<tiledb_query_condition_combination_op_t>(
Expand Down
31 changes: 31 additions & 0 deletions apis/python/tests/test_experiment_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,37 @@ def test_experiment_query_value_filter(soma_experiment):
assert query.var().concat()["label"].to_pylist() == var_label_values


@pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
def test_experiment_query_value_filter2(soma_experiment):
    """Check that a ``not in`` value filter drops exactly the listed labels."""
    obs_label_values = ["3", "7", "38", "99"]
    var_label_values = ["18", "34", "67"]

    def label_set(dataframe):
        # Distinct values of the "label" column read from the given dataframe.
        return set(
            dataframe.read(column_names=["label"]).concat().to_pandas()["label"]
        )

    with soma_experiment.axis_query(
        "RNA",
        obs_query=soma.AxisQuery(value_filter=f"label not in {obs_label_values}"),
        var_query=soma.AxisQuery(value_filter=f"label not in {var_label_values}"),
    ) as query:
        # Row counts shrink by exactly the number of excluded labels.
        assert query.n_obs == soma_experiment.obs.count - len(obs_label_values)
        assert query.n_vars == soma_experiment.ms["RNA"].var.count - len(
            var_label_values
        )

        all_obs_values = label_set(soma_experiment.obs)
        all_var_values = label_set(soma_experiment.ms["RNA"].var)

        qry_obs_values = set(query.obs().concat()["label"].to_pylist())
        qry_var_values = set(query.var().concat()["label"].to_pylist())

        # The surviving labels are the full label set minus the excluded ones.
        assert qry_obs_values == all_obs_values - set(obs_label_values)
        assert qry_var_values == all_var_values - set(var_label_values)


@pytest.mark.parametrize("n_obs,n_vars", [(1001, 99)])
def test_experiment_query_combo(soma_experiment):
"""Test query by combinations of coords and value_filter"""
Expand Down
Loading