Skip to content

Commit

Permalink
WIP Add query condition and schema evolution enum functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenv committed Aug 23, 2023
1 parent 57591d3 commit 6ee967e
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 4 deletions.
5 changes: 5 additions & 0 deletions tiledb/cc/enumeration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ void init_enumeration(py::module &m) {

.def(py::init<const Context &, py::capsule>(), py::keep_alive<1, 2>())

// Expose the wrapped enumeration handle to Python as a capsule named "enmr".
// The handle is the raw pointer obtained from enmr.ptr().get(); nullptr is
// passed as the capsule's destructor argument.
// NOTE(review): presumably the capsule is therefore non-owning and the
// Enumeration object must outlive it -- confirm.
.def("__capsule__",
     [](Enumeration &enmr) {
       auto *handle = enmr.ptr().get();
       return py::capsule(handle, "enmr", nullptr);
     })

.def_property_readonly("name", &Enumeration::name)

.def_property_readonly("type", &Enumeration::type)
Expand Down
11 changes: 8 additions & 3 deletions tiledb/query_condition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ class PyQueryCondition {

// Wrap the address of the qc_ member in a capsule named "qc"; nullptr is
// passed as the capsule's destructor argument.
// NOTE(review): the capsule stores a pointer to the member itself, so it is
// presumably only valid while this PyQueryCondition is alive -- confirm.
py::capsule __capsule__() {
  auto *condition_slot = &qc_;
  return py::capsule(condition_slot, "qc", nullptr);
}


// Forward `use_enumeration`, together with this object's context (ctx_) and
// the query condition that qc_ points to, to
// QueryConditionExperimental::set_use_enumeration.
void set_use_enumeration(bool use_enumeration) {
  auto &condition = *qc_;
  QueryConditionExperimental::set_use_enumeration(ctx_, condition,
                                                  use_enumeration);
}

PyQueryCondition
combine(PyQueryCondition rhs,
tiledb_query_condition_combination_op_t combination_op) const {
Expand Down Expand Up @@ -150,9 +155,9 @@ void init_query_condition(py::module &m) {
tiledb_query_condition_op_t)>(
&PyQueryCondition::init))

.def("combine", &PyQueryCondition::combine)

.def("__capsule__", &PyQueryCondition::__capsule__);
.def("__capsule__", &PyQueryCondition::__capsule__)
.def("combine", &PyQueryCondition::combine);

py::enum_<tiledb_query_condition_op_t>(m, "tiledb_query_condition_op_t",
py::arithmetic())
Expand Down
22 changes: 21 additions & 1 deletion tiledb/schema_evolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,27 @@ void init_schema_evolution(py::module &m) {
if (rc != TILEDB_OK) {
TPY_ERROR_LOC(get_last_ctx_err_str(inst.ctx_, rc));
}
});
})
// add_enumeration(enum_py): register an enumeration with this schema
// evolution. The Python object must provide `__capsule__()`; the capsule it
// returns is converted to the tiledb_enumeration_t handle handed to the C API.
// Errors are reported through TPY_ERROR_LOC, both for a null handle and for
// a non-TILEDB_OK return code.
.def("add_enumeration",
     [](ArraySchemaEvolution &evolution, py::object enum_py) {
       py::capsule enmr_capsule(enum_py.attr("__capsule__")());
       tiledb_enumeration_t *enmr_handle = enmr_capsule;
       if (enmr_handle == nullptr) {
         TPY_ERROR_LOC("Invalid Enumeration!");
       }

       int status = tiledb_array_schema_evolution_add_enumeration(
           evolution.ctx_, evolution.evol_, enmr_handle);
       if (status != TILEDB_OK) {
         TPY_ERROR_LOC(get_last_ctx_err_str(evolution.ctx_, status));
       }
     })
// drop_enumeration(enumeration_name): pass the given name to
// tiledb_array_schema_evolution_drop_enumeration for this schema evolution.
// A return code other than TILEDB_OK is reported through TPY_ERROR_LOC using
// the context's last error string.
.def("drop_enumeration",
     [](ArraySchemaEvolution &evolution, const std::string &enumeration_name) {
       const char *name_cstr = enumeration_name.c_str();
       int status = tiledb_array_schema_evolution_drop_enumeration(
           evolution.ctx_, evolution.evol_, name_cstr);
       if (status == TILEDB_OK) {
         return;
       }
       TPY_ERROR_LOC(get_last_ctx_err_str(evolution.ctx_, status));
     });
}

}; // namespace tiledbpy
45 changes: 45 additions & 0 deletions tiledb/tests/test_enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,48 @@ def test_array_schema_enumeration(self):

assert_array_equal(A.df[:]["attr1"], A[:]["attr1"])
assert_array_equal(A.df[:]["attr2"], A[:]["attr2"])

# Round-trip test for enumerations attached to an array schema: create an
# array with two enumerated attributes and one plain attribute, write random
# codes, then check enumeration lookup, attribute labels and (when pandas is
# available) the categorical codes returned by the dataframe API.
# NOTE(review): despite the `_nullable` suffix, nothing in this body passes a
# nullable option, and the URI reuses the name "test_array_schema_enumeration"
# -- confirm whether nullable coverage is still to be added (commit is WIP).
def test_array_schema_enumeration_nullable(self):
uri = self.path("test_array_schema_enumeration")
# Single dimension over the inclusive domain (1, 8), i.e. 8 cells.
dom = tiledb.Domain(tiledb.Dim(domain=(1, 8), tile=1))
# enmr1 holds the integers 0, 10, 20; enmr2 holds three strings.
# NOTE(review): the second positional argument (False) is presumably the
# "ordered" flag -- confirm against tiledb.Enumeration.
enum1 = tiledb.Enumeration("enmr1", False, np.arange(3) * 10)
enum2 = tiledb.Enumeration("enmr2", False, ["a", "bb", "ccc"])
# attr1/attr2 are int32 attributes labelled with an enumeration; attr3 has none.
attr1 = tiledb.Attr("attr1", dtype=np.int32, enum_label="enmr1")
attr2 = tiledb.Attr("attr2", dtype=np.int32, enum_label="enmr2")
attr3 = tiledb.Attr("attr3", dtype=np.int32)
schema = tiledb.ArraySchema(
domain=dom, attrs=(attr1, attr2, attr3), enums=(enum1, enum2)
)
tiledb.Array.create(uri, schema)

# Eight random integers in [0, 3) per attribute, matching the three values
# defined in each enumeration.
data1 = np.random.randint(0, 3, 8)
data2 = np.random.randint(0, 3, 8)
data3 = np.random.randint(0, 3, 8)

with tiledb.open(uri, "w") as A:
A[:] = {"attr1": data1, "attr2": data2, "attr3": data3}

with tiledb.open(uri, "r") as A:
# Enumerations read back from the array must equal the ones used at creation,
# and the label must be visible on both the local Attr and the stored one.
assert A.enum("enmr1") == enum1
assert attr1.enum_label == "enmr1"
assert A.attr("attr1").enum_label == "enmr1"

assert A.enum("enmr2") == enum2
assert attr2.enum_label == "enmr2"
assert A.attr("attr2").enum_label == "enmr2"

# "enmr3" was never defined, so the lookup is expected to raise TileDBError.
# NOTE(review): `.value` is the attribute pytest's ExceptionInfo exposes;
# unittest's assertRaises context uses `.exception` -- confirm which
# assertRaises this test class provides.
with self.assertRaises(tiledb.TileDBError) as excinfo:
assert A.enum("enmr3") == []
assert " No enumeration named 'enmr3'" in str(excinfo.value)
assert attr3.enum_label is None
assert A.attr("attr3").enum_label is None

if has_pandas():
# The categorical codes of the dataframe columns must equal the written data.
assert_array_equal(A.df[:]["attr1"].cat.codes, data1)
assert_array_equal(A.df[:]["attr2"].cat.codes, data2)

# The dataframe, multi_index and plain indexing reads must agree.
assert_array_equal(A.df[:]["attr1"], A.multi_index[:]["attr1"])
assert_array_equal(A.df[:]["attr2"], A.multi_index[:]["attr2"])

assert_array_equal(A.df[:]["attr1"], A[:]["attr1"])
assert_array_equal(A.df[:]["attr2"], A[:]["attr2"])

0 comments on commit 6ee967e

Please sign in to comment.