Skip to content

Commit

Permalink
Add test for handler
Browse files Browse the repository at this point in the history
  • Loading branch information
ypatia committed Nov 14, 2023
1 parent 1e206cf commit 92a430e
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 8 deletions.
128 changes: 124 additions & 4 deletions test/src/unit-request-handlers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,17 @@

#include "test/support/tdb_catch.h"
#include "tiledb/api/c_api/buffer/buffer_api_internal.h"
#include "tiledb/api/c_api/string/string_api_internal.h"
#include "tiledb/sm/array_schema/enumeration.h"
#include "tiledb/sm/c_api/tiledb_serialization.h"
#include "tiledb/sm/c_api/tiledb_struct_def.h"
#include "tiledb/sm/cpp_api/tiledb"
#include "tiledb/sm/crypto/encryption_key.h"
#include "tiledb/sm/enums/array_type.h"
#include "tiledb/sm/enums/encryption_type.h"
#include "tiledb/sm/enums/serialization_type.h"
#include "tiledb/sm/serialization/array_schema.h"
#include "tiledb/sm/serialization/query_plan.h"
#include "tiledb/sm/storage_manager/context.h"

using namespace tiledb::sm;
Expand All @@ -57,9 +60,6 @@ struct RequestHandlerFx {

shared_ptr<Array> get_array(QueryType type);

shared_ptr<const Enumeration> create_string_enumeration(
std::string name, std::vector<std::string>& values);

URI uri_;
Config cfg_;
Context ctx_;
Expand All @@ -74,6 +74,18 @@ struct HandleLoadArraySchemaRequestFx : RequestHandlerFx {
virtual shared_ptr<ArraySchema> create_schema() override;
ArraySchema call_handler(
serialization::LoadArraySchemaRequest req, SerializationType stype);

shared_ptr<const Enumeration> create_string_enumeration(
std::string name, std::vector<std::string>& values);
};

struct HandleQueryPlanRequestFx : RequestHandlerFx {
HandleQueryPlanRequestFx()
: RequestHandlerFx("query_plan_handler") {
}

virtual shared_ptr<ArraySchema> create_schema() override;
std::string call_handler(SerializationType stype, Query& query);
};

/* ********************************* */
Expand Down Expand Up @@ -153,6 +165,60 @@ TEST_CASE_METHOD(
REQUIRE(rval != TILEDB_OK);
}

/* ******************************************** */
/* Testing Query Plan serialization */
/* ******************************************** */

TEST_CASE_METHOD(
HandleQueryPlanRequestFx,
"tiledb_handle_query_plan_request - default request",
"[request_handler][query_plan][default]") {
auto stype = GENERATE(SerializationType::JSON, SerializationType::CAPNP);

// Create and open array
create_array();
tiledb_ctx_t* ctx;
REQUIRE(tiledb_ctx_alloc(NULL, &ctx) == TILEDB_OK);
tiledb_array_t* array;
REQUIRE(tiledb_array_alloc(ctx, uri_.c_str(), &array) == TILEDB_OK);
REQUIRE(tiledb_array_open(ctx, array, TILEDB_READ) == TILEDB_OK);

// Create query
tiledb_query_t* query;
REQUIRE(tiledb_query_alloc(ctx, array, TILEDB_READ, &query) == TILEDB_OK);
REQUIRE(tiledb_query_set_layout(ctx, query, TILEDB_ROW_MAJOR) == TILEDB_OK);
int32_t dom[] = {1, 2, 1, 2};
REQUIRE(tiledb_query_set_subarray(ctx, query, &dom) == TILEDB_OK);

std::vector<int32_t> a1(2);
uint64_t size = 1;
REQUIRE(
tiledb_query_set_data_buffer(ctx, query, "attr1", a1.data(), &size) ==
TILEDB_OK);
std::vector<int64_t> a2(2);
REQUIRE(
tiledb_query_set_data_buffer(ctx, query, "attr2", a2.data(), &size) ==
TILEDB_OK);

// Use C API to get the query plan
tiledb_string_handle_t* query_plan;
REQUIRE(tiledb_query_get_plan(ctx, query, &query_plan) == TILEDB_OK);

// Call handler to get query plan via serialized req/deserialized response
auto query_plan_ser_deser = call_handler(stype, *query->query_);

// Compare the two query plans
// std::cout << query_plan->view();
// std::cout << query_plan_ser_deser;
REQUIRE(query_plan->view() == query_plan_ser_deser);

// Clean up
REQUIRE(tiledb_array_close(ctx, array) == TILEDB_OK);
tiledb_query_free(&query);
tiledb_array_free(&array);
tiledb_ctx_free(&ctx);
}

/* ********************************* */
/* Testing Support Code */
/* ********************************* */
Expand Down Expand Up @@ -187,7 +253,8 @@ shared_ptr<Array> RequestHandlerFx::get_array(QueryType type) {
return array;
}

shared_ptr<const Enumeration> RequestHandlerFx::create_string_enumeration(
shared_ptr<const Enumeration>
HandleLoadArraySchemaRequestFx::create_string_enumeration(
std::string name, std::vector<std::string>& values) {
uint64_t total_size = 0;
for (auto v : values) {
Expand Down Expand Up @@ -263,4 +330,57 @@ ArraySchema HandleLoadArraySchemaRequestFx::call_handler(
stype, resp_buf->buffer());
}

shared_ptr<ArraySchema> HandleQueryPlanRequestFx::create_schema() {
// Create a schema to serialize
auto schema = make_shared<ArraySchema>(HERE(), ArrayType::DENSE);
schema->set_capacity(10000);
throw_if_not_ok(schema->set_cell_order(Layout::ROW_MAJOR));
throw_if_not_ok(schema->set_tile_order(Layout::ROW_MAJOR));
uint32_t dim_domain[] = {1, 10, 1, 10};
// uint64_t extents[] = {5, 5};

auto dim1 = make_shared<Dimension>(HERE(), "dim1", Datatype::INT32);
throw_if_not_ok(dim1->set_domain(&dim_domain[0]));
auto dim2 = make_shared<Dimension>(HERE(), "dim2", Datatype::INT32);
throw_if_not_ok(dim2->set_domain(&dim_domain[2]));

auto dom = make_shared<Domain>(HERE());
throw_if_not_ok(dom->add_dimension(dim1));
throw_if_not_ok(dom->add_dimension(dim2));
throw_if_not_ok(schema->set_domain(dom));

auto attr1 = make_shared<Attribute>(HERE(), "attr1", Datatype::INT32);
throw_if_not_ok(schema->add_attribute(attr1));
auto attr2 = make_shared<Attribute>(HERE(), "attr2", Datatype::INT64);
throw_if_not_ok(schema->add_attribute(attr2));

return schema;
}

std::string HandleQueryPlanRequestFx::call_handler(
SerializationType stype, Query& query) {
auto ctx = tiledb::Context();
auto array = tiledb::Array(ctx, uri_.to_string(), TILEDB_READ);
// auto query = tiledb::Query(ctx, array);
auto req_buf = tiledb_buffer_handle_t::make_handle();
auto resp_buf = tiledb_buffer_handle_t::make_handle();

serialization::serialize_query_plan_request(
// cfg_, *query.ptr().get()->query_, stype, req_buf->buffer());
cfg_,
query,
stype,
req_buf->buffer());
auto rval = tiledb_handle_query_plan_request(
ctx.ptr().get(),
array.ptr().get(),
static_cast<tiledb_serialization_type_t>(stype),
req_buf,
resp_buf);
REQUIRE(rval == TILEDB_OK);

return serialization::deserialize_query_plan_response(
stype, resp_buf->buffer());
}

#endif // TILEDB_SERIALIZATION
3 changes: 1 addition & 2 deletions tiledb/sm/c_api/tiledb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4470,7 +4470,6 @@ capi_return_t tiledb_handle_query_plan_request(
plan.dump_json(),
static_cast<tiledb::sm::SerializationType>(serialization_type),
response->buffer());

return TILEDB_OK;
}

Expand Down Expand Up @@ -7385,7 +7384,7 @@ CAPI_INTERFACE(
}

CAPI_INTERFACE(
handle_get_query_plan_request,
handle_query_plan_request,
tiledb_ctx_t* ctx,
tiledb_array_t* array,
tiledb_serialization_type_t serialization_type,
Expand Down
17 changes: 17 additions & 0 deletions tiledb/sm/query_plan/query_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,22 @@ std::string QueryPlan::dump_json(uint32_t indent) {
return rv.dump(indent);
}

void QueryPlan::from_json(const std::string& json) {
auto j = nlohmann::json::parse(json);

j = j["TileDB Query Plan"];
array_uri_ = j["Array.URI"];
throw_if_not_ok(array_type_enum(j["Array.Type"], &array_type_));
vfs_backend_ = j["VFS.Backend"];
throw_if_not_ok(layout_enum(j["Query.Layout"], &query_layout_));
strategy_name_ = j["Query.Strategy.Name"];
for (auto& a : j["Query.Attributes"]) {
attributes_.push_back(a);
}
for (auto& d : j["Query.Dimensions"]) {
dimensions_.push_back(d);
}
}

} // namespace sm
} // namespace tiledb
9 changes: 9 additions & 0 deletions tiledb/sm/query_plan/query_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,15 @@ class QueryPlan {

/** A list of queried dimensions */
std::vector<std::string> dimensions_;

/**
* Populate query plan from a valid json representation.
* Only meant to be used during construction when a remote query
* plan comes as a json string from rest serialization.
*
* @param json a json representation of the query plan
*/
void from_json(const std::string& json);
};

} // namespace sm
Expand Down
7 changes: 5 additions & 2 deletions tiledb/sm/serialization/query_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,15 @@ void deserialize_query_plan_request(
switch (serialization_type) {
case SerializationType::JSON: {
::capnp::JsonCodec json;
json.handleByAnnotation<capnp::QueryPlanRequest>();
::capnp::MallocMessageBuilder message_builder;
capnp::QueryPlanRequest::Builder builder =
message_builder.initRoot<capnp::QueryPlanRequest>();
json.decode(
kj::StringPtr(static_cast<const char*>(request.data())), builder);
capnp::QueryPlanRequest::Reader reader = builder.asReader();
query_plan_request_from_capnp(reader, compute_tp, query);
break;
}
case SerializationType::CAPNP: {
const auto mBytes = reinterpret_cast<const kj::byte*>(request.data());
Expand All @@ -168,6 +170,7 @@ void deserialize_query_plan_request(
capnp::QueryPlanRequest::Reader reader =
array_reader.getRoot<capnp::QueryPlanRequest>();
query_plan_request_from_capnp(reader, compute_tp, query);
break;
}
default: {
throw Status_SerializationError(
Expand Down Expand Up @@ -280,13 +283,13 @@ std::string deserialize_query_plan_response(
#else

void serialize_query_plan_request(
const Config&, const Query&, const SerializationType, Buffer&) {
const Config&, Query&, const SerializationType, Buffer&) {
throw Status_SerializationError(
"Cannot serialize; serialization not enabled.");
}

void deserialize_query_plan_request(
const SerializationType, const Buffer&, const ThreadPool&, Query&) {
const SerializationType, const Buffer&, ThreadPool&, Query&) {
throw Status_SerializationError(
"Cannot serialize; serialization not enabled.");
}
Expand Down

0 comments on commit 92a430e

Please sign in to comment.