Skip to content

Commit

Permalink
Add support for to_substrait and from_substrait for virtual table exp…
Browse files Browse the repository at this point in the history
…ression (#130)

* Also updated the duckdb submodule to include the constructor for the latest ValueRelation
  • Loading branch information
anshuldata authored Nov 21, 2024
1 parent 728fbcd commit abc4b70
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main_distribution.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
name: Build extension binaries
uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@main
with:
duckdb_version: ca5af32c331f9d5ea49f7158d5c83a47f25b8b79
duckdb_version: c29c67bb971362cd1e9143305acffebb1bc9bd63
ci_tools_version: 5bdbe4d606d78dbd749f9578ba8ca639feece023
exclude_archs: "wasm_mvp;wasm_eh;wasm_threads;windows_amd64;windows_amd64_mingw;windows_amd64_rtools"
extension_name: substrait
Expand Down
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 359 files
51 changes: 38 additions & 13 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,23 +615,28 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
scan = rel->Alias(name);
} else if (sget.has_virtual_table()) {
// We need to handle a virtual table as a LogicalExpressionGet
auto literal_values = sget.virtual_table().values();
vector<vector<Value>> expression_rows;
for (auto &row : literal_values) {
auto values = row.fields();
vector<Value> expression_row;
for (const auto &value : values) {
expression_row.emplace_back(TransformLiteralToValue(value));
if (!sget.virtual_table().values().empty()) {

Check warning on line 618 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 618 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
auto literal_values = sget.virtual_table().values();

Check warning on line 619 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 619 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
vector<vector<Value>> expression_rows;
for (auto &row : literal_values) {
auto values = row.fields();
vector<Value> expression_row;
for (const auto &value : values) {
expression_row.emplace_back(TransformLiteralToValue(value));
}
expression_rows.emplace_back(expression_row);
}
expression_rows.emplace_back(expression_row);
}
vector<string> column_names;
if (acquire_lock) {
scan = make_shared_ptr<ValueRelation>(context, expression_rows, column_names);
vector<string> column_names;
if (acquire_lock) {
scan = make_shared_ptr<ValueRelation>(context, expression_rows, column_names);

} else {
scan = make_shared_ptr<ValueRelation>(context_wrapper, expression_rows, column_names);
}
} else {
scan = make_shared_ptr<ValueRelation>(context_wrapper, expression_rows, column_names);
scan = GetValuesExpression(sget.virtual_table().expressions());
}

} else {
throw NotImplementedException("Unsupported type of read operator for substrait");
}
Expand All @@ -656,6 +661,26 @@ shared_ptr<Relation> SubstraitToDuckDB::TransformReadOp(const substrait::Rel &so
return scan;
}

//! Builds a ValueRelation from virtual-table rows that are given as Substrait
//! expressions. Each field of each row is converted with TransformExpr, and the
//! relation is created with an empty list of column names.
//! When acquire_lock is set the relation is constructed on `context` directly;
//! otherwise `context` is first wrapped in a RelationContextWrapper.
shared_ptr<Relation> SubstraitToDuckDB::GetValuesExpression(
    const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows) {
	// Convert every Substrait row into a row of DuckDB parsed expressions
	vector<vector<unique_ptr<ParsedExpression>>> parsed_rows;
	for (auto &substrait_row : expression_rows) {
		vector<unique_ptr<ParsedExpression>> parsed_row;
		for (const auto &field : substrait_row.fields()) {
			parsed_row.emplace_back(TransformExpr(field));
		}
		parsed_rows.emplace_back(std::move(parsed_row));
	}

	// No explicit column names are supplied to the relation
	vector<string> no_names;
	if (!acquire_lock) {
		auto wrapped_context = make_shared_ptr<RelationContextWrapper>(context);
		return make_shared_ptr<ValueRelation>(wrapped_context, std::move(parsed_rows), no_names);
	}
	return make_shared_ptr<ValueRelation>(context, std::move(parsed_rows), no_names);
}

shared_ptr<Relation> SubstraitToDuckDB::TransformSortOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names) {
vector<OrderByNode> order_nodes;
Expand Down
1 change: 1 addition & 0 deletions src/include/from_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class SubstraitToDuckDB {
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformAggregateOp(const substrait::Rel &sop);
shared_ptr<Relation> TransformReadOp(const substrait::Rel &sop);
shared_ptr<Relation> GetValuesExpression(const google::protobuf::RepeatedPtrField<substrait::Expression_Nested_Struct> &expression_rows);
shared_ptr<Relation> TransformSortOp(const substrait::Rel &sop,
const google::protobuf::RepeatedPtrField<std::string> *names = nullptr);
shared_ptr<Relation> TransformSetOp(const substrait::Rel &sop,
Expand Down
1 change: 1 addition & 0 deletions src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class DuckDBToSubstrait {
substrait::Rel *TransformOrderBy(LogicalOperator &dop);
substrait::Rel *TransformComparisonJoin(LogicalOperator &dop);
substrait::Rel *TransformAggregateGroup(LogicalOperator &dop);
substrait::Rel *TransformExpressionGet(LogicalOperator &dop);
substrait::Rel *TransformGet(LogicalOperator &dop);
substrait::Rel *TransformCrossProduct(LogicalOperator &dop);
substrait::Rel *TransformUnion(LogicalOperator &dop);
Expand Down
21 changes: 21 additions & 0 deletions src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,25 @@ substrait::Rel *DuckDBToSubstrait::TransformGet(LogicalOperator &dop) {
return get_rel;
}

//! Converts a LogicalExpressionGet operator into a Substrait read relation
//! backed by a virtual table: every row of `dget.expressions` becomes one entry
//! in the virtual table's `expressions`, and every expression in the row becomes
//! one of that entry's `fields`.
//! \param dop The logical operator; it is cast to LogicalExpressionGet.
//! \return A Rel allocated with `new`; this function does not free it
//!         (presumably the caller takes ownership, as with the other Transform* functions).
substrait::Rel *DuckDBToSubstrait::TransformExpressionGet(LogicalOperator &dop) {
	// Cast before allocating so a failed cast cannot leak the Rel
	auto &dget = dop.Cast<LogicalExpressionGet>();
	auto get_rel = new substrait::Rel();

	auto sget = get_rel->mutable_read();
	auto virtual_table = sget->mutable_virtual_table();

	for (auto &row : dget.expressions) {
		auto row_item = virtual_table->add_expressions();
		for (auto &expr : row) {
			// Transform straight into the newly added field: this avoids a temporary
			// heap-allocated Expression, a deep copy of it, and a leak of that
			// temporary if TransformExpr throws.
			TransformExpr(*expr, *row_item->add_fields());
		}
	}
	return get_rel;
}

substrait::Rel *DuckDBToSubstrait::TransformCrossProduct(LogicalOperator &dop) {
auto rel = new substrait::Rel();
auto sub_cross_prod = rel->mutable_cross();
Expand Down Expand Up @@ -1537,6 +1556,8 @@ substrait::Rel *DuckDBToSubstrait::TransformOp(LogicalOperator &dop) {
return TransformAggregateGroup(dop);
case LogicalOperatorType::LOGICAL_GET:
return TransformGet(dop);
case LogicalOperatorType::LOGICAL_EXPRESSION_GET:
return TransformExpressionGet(dop);
case LogicalOperatorType::LOGICAL_CROSS_PRODUCT:
return TransformCrossProduct(dop);
case LogicalOperatorType::LOGICAL_UNION:
Expand Down
27 changes: 27 additions & 0 deletions test/c/test_substrait_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <chrono>
#include <thread>
#include <iostream>

using namespace duckdb;
using namespace std;
Expand Down Expand Up @@ -293,3 +294,29 @@ TEST_CASE("Test C DeleteRows with Substrait API", "[substrait-api]") {
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 3}));
REQUIRE(CHECK_COLUMN(result, 3, {120000, 80000, 95000}));
}

// Round-trips a VALUES list of integer literals: the query is serialized to
// Substrait JSON, printed, read back, and the resulting columns are checked.
TEST_CASE("Test C VirtualTable input Literal", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	// SQL -> Substrait JSON
	auto plan_json = con.GetSubstraitJSON("select * from (values (1, 2),(3, 4))");
	REQUIRE(!plan_json.empty());
	std::cout << plan_json << std::endl;

	// Substrait JSON -> query result
	auto roundtrip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(roundtrip, 0, {1, 3}));
	REQUIRE(CHECK_COLUMN(roundtrip, 1, {2, 4}));
}

// Round-trips a VALUES list whose cells are arithmetic expressions: the query
// is serialized to Substrait JSON, printed, read back, and the evaluated
// columns are checked.
TEST_CASE("Test C VirtualTable input Expression", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	// SQL -> Substrait JSON
	auto plan_json = con.GetSubstraitJSON("select * from (values (1+1,2+2),(3+3,4+4)) as temp(a,b)");
	REQUIRE(!plan_json.empty());
	std::cout << plan_json << std::endl;

	// Substrait JSON -> query result
	auto roundtrip = con.FromSubstraitJSON(plan_json);
	REQUIRE(CHECK_COLUMN(roundtrip, 0, {2, 6}));
	REQUIRE(CHECK_COLUMN(roundtrip, 1, {4, 8}));
}

0 comments on commit abc4b70

Please sign in to comment.