diff --git a/apis/python/src/tiledbsoma/_query_condition.py b/apis/python/src/tiledbsoma/_query_condition.py index a75b0aa1b1..01ff847a4c 100644 --- a/apis/python/src/tiledbsoma/_query_condition.py +++ b/apis/python/src/tiledbsoma/_query_condition.py @@ -12,7 +12,6 @@ import attrs import numpy as np import pyarrow as pa -# import tiledb from . import pytiledbsoma as clib from ._exception import SOMAError @@ -132,12 +131,9 @@ def __attrs_post_init__(self): def init_query_condition( self, schema: pa.Schema, - enum_to_dtype: dict, query_attrs: Optional[List[str]], ): - print(schema) - - qctree = QueryConditionTree(schema, enum_to_dtype, query_attrs) + qctree = QueryConditionTree(schema, query_attrs) self.c_obj = qctree.visit(self.tree.body) if not isinstance(self.c_obj, clib.PyQueryCondition): @@ -152,7 +148,6 @@ def init_query_condition( @attrs.define class QueryConditionTree(ast.NodeVisitor): schema: pa.Schema - enum_to_dtype: dict query_attrs: List[str] def visit_BitOr(self, node): @@ -228,17 +223,15 @@ def visit_Compare(self, node: ast.Compare) -> clib.PyQueryCondition: variable = node.left.id values = [self.get_val_from_node(val) for val in self.visit(rhs)] - - # if self.schema.has_attr(variable): - # enum_label = self.schema.attr(variable).enum_label - # if enum_label is not None: - # dt = self.enum_to_dtype[enum_label] - # else: - # dt = self.schema.attr(variable).dtype - # else: - # dt = self.schema.attr_or_dim_dtype(variable) dt = self.schema.field(variable).type + if pa.types.is_dictionary(dt): + dt = dt.value_type + + if pa.types.is_string(dt) or pa.types.is_large_string(dt) or pa.types.is_binary(dt) or pa.types.is_large_binary(dt): + dtype = "string" + else: + dtype = np.dtype(dt.to_pandas_dtype()).name # sdf.read(column_names=["foo"], value_filter='bar == 999') should # result in bar being added to the column names. See also @@ -246,12 +239,8 @@ def visit_Compare(self, node: ast.Compare) -> clib.PyQueryCondition: att = self.get_att_from_node(node.left) if att not in self.query_attrs: self.query_attrs.append(att) - - if pa.types.is_string(dt) or pa.types.is_large_string(dt) or pa.types.is_binary(dt) or pa.types.is_large_binary(dt): - dtype = "string" - else: - dtype = dt - + + # dtype = "string" if dt.kind in "SUa" else dt.name op = clib.TILEDB_IN if isinstance(operator, ast.In) else clib.TILEDB_NOT_IN result = self.create_pyqc(dtype)(node.left.id, values, op) @@ -267,12 +256,15 @@ def aux_visit_Compare( att = self.get_att_from_node(att) val = self.get_val_from_node(val) - enum_label = self.schema.attr(att).enum_label - if enum_label is not None: - dt = self.enum_to_dtype[enum_label] + + dt = self.schema.field(att).type + if pa.types.is_dictionary(dt): + dt = dt.value_type + + if pa.types.is_string(dt) or pa.types.is_large_string(dt) or pa.types.is_binary(dt) or pa.types.is_large_binary(dt): + dtype = "string" else: - dt = self.schema.attr(att).dtype - dtype = "string" if dt.kind in "SUa" else dt.name + dtype = np.dtype(dt.to_pandas_dtype()).name val = self.cast_val_to_dtype(val, dtype) pyqc = clib.PyQueryCondition() diff --git a/apis/python/src/tiledbsoma/soma_array.cc b/apis/python/src/tiledbsoma/soma_array.cc index 2917e97936..4851951489 100644 --- a/apis/python/src/tiledbsoma/soma_array.cc +++ b/apis/python/src/tiledbsoma/soma_array.cc @@ -86,54 +86,13 @@ bool get_enum_is_ordered(SOMAArray& sr, std::string attr_name){ return attr_to_enmrs.at(attr_name).ordered(); } -/** - * @brief pybind11 bindings - * - */ void load_soma_array(py::module &m) { - m.doc() = "SOMA acceleration library"; - - m.def("version", []() { return tiledbsoma::version::as_string(); }); - - m.def( - "config_logging", - [](const std::string& level, const std::string& logfile) { - LOG_CONFIG(level, logfile); - }, - "level"_a, - "logfile"_a = ""); - - m.def("info", &LOG_INFO, "message"_a = ""); - m.def("debug", &LOG_DEBUG, "message"_a = ""); - - m.def( - "tiledbsoma_stats_enable", - []() { tiledbsoma::stats::enable(); }, - "Enable TileDB internal statistics. Lifecycle: experimental."); - m.def( - "tiledbsoma_stats_disable", - []() { tiledbsoma::stats::disable(); }, - "Disable TileDB internal statistics. Lifecycle: experimental."); - m.def( - "tiledbsoma_stats_reset", - []() { tiledbsoma::stats::reset(); }, - "Reset all TileDB internal statistics to 0. Lifecycle: experimental."); - m.def( - "tiledbsoma_stats_dump", - []() { - py::print(tiledbsoma::version::as_string()); - std::string stats = tiledbsoma::stats::dump(); - py::print(stats); - }, - "Print TileDB internal statistics. Lifecycle: experimental."); - py::class_(m, "SOMAArray") .def( py::init( [](std::string_view uri, std::string_view name, std::optional> column_names_in, - py::object py_query_condition, std::string_view batch_size, ResultOrder result_order, std::map platform_config, @@ -144,41 +103,7 @@ void load_soma_array(py::module &m) { column_names = *column_names_in; } - // Handle query condition based on - // TileDB-Py::PyQuery::set_attr_cond() - QueryCondition* qc = nullptr; - if (!py_query_condition.is(py::none())) { - py::object init_pyqc = py_query_condition.attr( - "init_query_condition"); - - try { - // Column names will be updated with columns present - // in the query condition - auto new_column_names = - init_pyqc(uri, column_names, platform_config, timestamp) - .cast>(); - - // Update the column_names list if it was not empty, - // otherwise continue selecting all columns with an - // empty column_names list - if (!column_names.empty()) { - column_names = new_column_names; - } - } catch (const std::exception& e) { - throw TileDBSOMAError(e.what()); - } - - qc = py_query_condition.attr("c_obj") - .cast() - .ptr() - .get(); - } - - // Release python GIL after we're done accessing python - // objects - py::gil_scoped_release release; - - auto reader = SOMAArray::open( + return SOMAArray::open( OpenMode::read, uri, name, @@ -187,29 +112,65 @@ void load_soma_array(py::module &m) { batch_size, result_order, timestamp); - - // Set query condition if present - if (qc) { - reader->set_condition(*qc); - } - - return reader; }), "uri"_a, py::kw_only(), "name"_a = "unnamed", "column_names"_a = py::none(), - "query_condition"_a = py::none(), "batch_size"_a = "auto", "result_order"_a = ResultOrder::automatic, "platform_config"_a = py::dict(), "timestamp"_a = py::none()) + .def( + "set_condition", + [](SOMAArray& reader, + py::object py_query_condition, + py::object py_schema){ + auto column_names = reader.column_names(); + // Handle query condition based on + // TileDB-Py::PyQuery::set_attr_cond() + QueryCondition* qc = nullptr; + if (!py_query_condition.is(py::none())) { + py::object init_pyqc = py_query_condition.attr( + "init_query_condition"); + try { + // Column names will be updated with columns present + // in the query condition + auto new_column_names = + init_pyqc(py_schema, column_names) + .cast>(); + // Update the column_names list if it was not empty, + // otherwise continue selecting all columns with an + // empty column_names list + if (!column_names.empty()) { + column_names = new_column_names; + } + } catch (const std::exception& e) { + throw TileDBSOMAError(e.what()); + } + qc = py_query_condition.attr("c_obj") + .cast() + .ptr() + .get(); + } + reader.reset(column_names); + + // Release python GIL after we're done accessing python + // objects + py::gil_scoped_release release; + // Set query condition if present + if (qc) { + reader.set_condition(*qc); + } + }, + "py_query_condition"_a, + "py_schema"_a) + .def( "reset", [](SOMAArray& reader, std::optional> column_names_in, - py::object py_query_condition, std::string_view batch_size, ResultOrder result_order) { // Handle optional args @@ -218,55 +179,11 @@ void load_soma_array(py::module &m) { column_names = *column_names_in; } - // Handle query condition based on - // TileDB-Py::PyQuery::set_attr_cond() - QueryCondition* qc = nullptr; - if (!py_query_condition.is(py::none())) { - py::object init_pyqc = py_query_condition.attr( - "init_query_condition"); - - try { - // Convert TileDB::Config to std::unordered map for pybind11 passing - std::unordered_map cfg; - for (const auto& it : reader.ctx()->config()) { - cfg[it.first] = it.second; - } - // Column names will be updated with columns present in - // the query condition - auto new_column_names = - init_pyqc(reader.uri(), column_names, cfg, reader.timestamp()) - .cast>(); - - // Update the column_names list if it was not empty, - // otherwise continue selecting all columns with an - // empty column_names list - if (!column_names.empty()) { - column_names = new_column_names; - } - } catch (const std::exception& e) { - throw TileDBSOMAError(e.what()); - } - - qc = py_query_condition.attr("c_obj") - .cast() - .ptr() - .get(); - } - - // Release python GIL after we're done accessing python objects - py::gil_scoped_release release; - // Reset state of the existing SOMAArray object reader.reset(column_names, batch_size, result_order); - - // Set query condition if present - if (qc) { - reader.set_condition(*qc); - } }, py::kw_only(), "column_names"_a = py::none(), - "query_condition"_a = py::none(), "batch_size"_a = "auto", "result_order"_a = ResultOrder::automatic) diff --git a/apis/python/src/tiledbsoma/soma_dataframe.cc b/apis/python/src/tiledbsoma/soma_dataframe.cc index 17a121e18b..7208d9b98b 100644 --- a/apis/python/src/tiledbsoma/soma_dataframe.cc +++ b/apis/python/src/tiledbsoma/soma_dataframe.cc @@ -62,13 +62,7 @@ void load_soma_dataframe(py::module &m) { .def("set_condition", [](SOMADataFrame& reader, py::object py_query_condition, - py::object pa_schema){ - auto attr_to_enum = reader.get_attr_to_enum_mapping(); - std::map enum_to_dtype; - for(auto const& [attr, enmr] : attr_to_enum){ - enum_to_dtype[attr] = tdb_to_np_dtype( - enmr.type(), enmr.cell_val_num()); - } + py::object pa_schema){ auto column_names = reader.column_names(); // Handle query condition based on // TileDB-Py::PyQuery::set_attr_cond() @@ -80,7 +74,7 @@ void load_soma_dataframe(py::module &m) { // Column names will be updated with columns present // in the query condition auto new_column_names = - init_pyqc(pa_schema, enum_to_dtype, column_names) + init_pyqc(pa_schema, column_names) .cast>(); // Update the column_names list if it was not empty, // otherwise continue selecting all columns with an diff --git a/libtiledbsoma/test/test_query_condition.py b/libtiledbsoma/test/test_query_condition.py index 2a384777b1..da1ff68ad4 100644 --- a/libtiledbsoma/test/test_query_condition.py +++ b/libtiledbsoma/test/test_query_condition.py @@ -3,10 +3,12 @@ import os import pytest +import tiledb import tiledbsoma.pytiledbsoma as clib from tiledbsoma._exception import SOMAError from tiledbsoma._query_condition import QueryCondition +from tiledbsoma._arrow_types import tiledb_schema_to_arrow VERBOSE = False @@ -27,7 +29,10 @@ def pandas_query(uri, condition): def soma_query(uri, condition): qc = QueryCondition(condition) - sr = clib.SOMAArray(uri, query_condition=qc) + sr = clib.SOMAArray(uri) + schema = tiledb_schema_to_arrow( + tiledb.open(uri).schema, uri, tiledb.default_ctx()) + sr.set_condition(qc, schema) arrow_table = sr.read_next() assert sr.results_complete() @@ -105,8 +110,11 @@ def test_query_condition_select_columns(): condition = "percent_mito > 0.02" qc = QueryCondition(condition) + schema = tiledb_schema_to_arrow( + tiledb.open(uri).schema, uri, tiledb.default_ctx()) - sr = clib.SOMAArray(uri, query_condition=qc, column_names=["n_genes"]) + sr = clib.SOMAArray(uri, column_names=["n_genes"]) + sr.set_condition(qc, schema) arrow_table = sr.read_next() assert sr.results_complete() @@ -119,8 +127,11 @@ def test_query_condition_all_columns(): condition = "percent_mito > 0.02" qc = QueryCondition(condition) + schema = tiledb_schema_to_arrow( + tiledb.open(uri).schema, uri, tiledb.default_ctx()) - sr = clib.SOMAArray(uri, query_condition=qc) + sr = clib.SOMAArray(uri) + sr.set_condition(qc, schema) arrow_table = sr.read_next() assert sr.results_complete() @@ -133,8 +144,11 @@ def test_query_condition_reset(): condition = "percent_mito > 0.02" qc = QueryCondition(condition) + schema = tiledb_schema_to_arrow( + tiledb.open(uri).schema, uri, tiledb.default_ctx()) - sr = clib.SOMAArray(uri, query_condition=qc) + sr = clib.SOMAArray(uri) + sr.set_condition(qc, schema) arrow_table = sr.read_next() assert sr.results_complete() @@ -145,7 +159,8 @@ def test_query_condition_reset(): # --------------------------------------------------------------- condition = "percent_mito < 0.02" qc = QueryCondition(condition) - sr.reset(column_names=["percent_mito"], query_condition=qc) + sr.reset(column_names=["percent_mito"]) + sr.set_condition(qc, schema) arrow_table = sr.read_next() @@ -213,7 +228,9 @@ def test_eval_error_conditions(malformed_condition): # with pytest.raises(RuntimeError): qc = QueryCondition(malformed_condition) - sr = clib.SOMAArray(uri, query_condition=qc) + schema = tiledb.open(uri).schema + sr = clib.SOMAArray(uri) + sr.set_condition(qc, schema) sr.read_next()