Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: project rel to and from substrait to include pass through columns #135

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions src/from_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
interval_t interval {};
interval.months = 0;
interval.days = literal.interval_day_to_second().days();
interval.micros = literal.interval_day_to_second().microseconds();

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 165 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'microseconds' is deprecated [-Wdeprecated-declarations]
return Value::INTERVAL(interval);
}
default:
Expand Down Expand Up @@ -492,22 +492,59 @@
return make_shared_ptr<FilterRelation>(TransformOp(sfilter.input()), TransformExpr(sfilter.condition()));
}

//! Returns the emit output mapping stored in the RelCommon of a join or project relation.
//! When the relation carries no emit, a reference to a shared empty mapping is returned instead.
//! Throws InternalException for any relation type other than join or project.
const google::protobuf::RepeatedField<int32_t>& GetOutputMapping(const substrait::Rel &sop) {
	const auto rel_type = sop.rel_type_case();

	// Pick the RelCommon of the concrete relation; only join and project are handled here.
	const substrait::RelCommon *rel_common = nullptr;
	if (rel_type == substrait::Rel::RelTypeCase::kJoin) {
		rel_common = &sop.join().common();
	} else if (rel_type == substrait::Rel::RelTypeCase::kProject) {
		rel_common = &sop.project().common();
	} else {
		throw InternalException("Unsupported relation type " + to_string(rel_type));
	}

	if (rel_common->has_emit()) {
		return rel_common->emit().output_mapping();
	}

	// A reference is returned, so the empty result needs static storage.
	static google::protobuf::RepeatedField<int32_t> empty_mapping;
	return empty_mapping;
}

//! Converts a Substrait ProjectRel into a DuckDB ProjectionRelation.
//!
//! Without an emit output mapping the projection consists of every input column (as a 1-based
//! positional reference) followed by the project expressions. With an output mapping, each entry
//! selects one output column: values below the number of input columns are pass-through input
//! columns, larger values index into the project expressions.
//!
//! \param sop   Relation holding the ProjectRel to convert.
//! \param names Root output names handed to the RootNameIterator used while transforming expressions.
//! \throws InvalidInputException if an output mapping entry is negative or points past the last
//!         project expression.
shared_ptr<Relation>
SubstraitToDuckDB::TransformProjectOp(const substrait::Rel &sop,
                                      const google::protobuf::RepeatedPtrField<std::string> *names) {
	vector<unique_ptr<ParsedExpression>> expressions;
	RootNameIterator iterator(names);

	// Transform the input once; it is needed both for its column count and as the projection child.
	auto input_rel = TransformOp(sop.project().input());

	// GetOutputMapping returns a reference, bind to it instead of copying the repeated field.
	const auto &mapping = GetOutputMapping(sop);
	const idx_t num_input_columns = input_rel->Columns().size();
	const idx_t num_project_expressions = static_cast<idx_t>(sop.project().expressions_size());

	if (mapping.empty()) {
		// Default emit: all input columns first, then every expression.
		expressions.reserve(num_input_columns + num_project_expressions);
		for (idx_t column = 1; column <= num_input_columns; column++) {
			expressions.push_back(make_uniq<PositionalReferenceExpression>(column));
		}
		for (auto &sexpr : sop.project().expressions()) {
			expressions.push_back(TransformExpr(sexpr, &iterator));
		}
	} else {
		expressions.reserve(static_cast<idx_t>(mapping.size()));
		for (const auto output_idx : mapping) {
			// The mapping comes from the plan, so reject indexes that address neither an input
			// column nor a project expression instead of reading out of bounds.
			if (output_idx < 0 || static_cast<idx_t>(output_idx) >= num_input_columns + num_project_expressions) {
				throw InvalidInputException("Project output mapping index " + to_string(output_idx) +
				                            " is out of range for " + to_string(num_input_columns) +
				                            " input columns and " + to_string(num_project_expressions) +
				                            " expressions");
			}
			const auto column = static_cast<idx_t>(output_idx);
			if (column < num_input_columns) {
				// Positional references are 1-based, the mapping is 0-based.
				expressions.push_back(make_uniq<PositionalReferenceExpression>(column + 1));
			} else {
				const auto expression_idx = static_cast<int>(column - num_input_columns);
				expressions.push_back(TransformExpr(sop.project().expressions(expression_idx), &iterator));
			}
		}
	}

	// Placeholder aliases, one per projected expression.
	vector<string> mock_aliases;
	mock_aliases.reserve(expressions.size());
	for (size_t i = 0; i < expressions.size(); i++) {
		mock_aliases.push_back("expr_" + to_string(i));
	}
	return make_shared_ptr<ProjectionRelation>(input_rel, std::move(expressions), std::move(mock_aliases));
}

shared_ptr<Relation> SubstraitToDuckDB::TransformAggregateOp(const substrait::Rel &sop) {
Expand All @@ -515,7 +552,7 @@

if (sop.aggregate().groupings_size() > 0) {
for (auto &sgrp : sop.aggregate().groupings()) {
for (auto &sgrpexpr : sgrp.grouping_expressions()) {

Check warning on line 555 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 555 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'grouping_expressions' is deprecated [-Wdeprecated-declarations]
groups.push_back(TransformExpr(sgrpexpr));
expressions.push_back(TransformExpr(sgrpexpr));
}
Expand Down Expand Up @@ -615,8 +652,8 @@
scan = rel->Alias(name);
} else if (sget.has_virtual_table()) {
// We need to handle a virtual table as a LogicalExpressionGet
if (!sget.virtual_table().values().empty()) {

Check warning on line 655 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 655 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
auto literal_values = sget.virtual_table().values();

Check warning on line 656 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'values' is deprecated [-Wdeprecated-declarations]

Check warning on line 656 in src/from_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'values' is deprecated [-Wdeprecated-declarations]
vector<vector<Value>> expression_rows;
for (auto &row : literal_values) {
auto values = row.fields();
Expand Down Expand Up @@ -822,6 +859,9 @@
if (first_projection_or_table) {
vector<ColumnDefinition> *column_definitions = &first_projection_or_table->Cast<ProjectionRelation>().columns;
int32_t i = 0;
if (column_definitions->size() > column_names.size()) {
throw InvalidInputException("Number of column names less than number of column definitions");
}
for (auto &column : *column_definitions) {
aliases.push_back(column_names[i++]);
auto column_type = column.GetType();
Expand Down
4 changes: 3 additions & 1 deletion src/include/to_substrait.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ class DuckDBToSubstrait {
substrait::Rel *TransformInsertTable(LogicalOperator &dop);
substrait::Rel *TransformDeleteTable(LogicalOperator &dop);
static substrait::Rel *TransformDummyScan();
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
static substrait::RelCommon *CreateOutputMapping(vector<int32_t> vector);
//! Methods to transform different LogicalGet Types (e.g., Table, Parquet)
//! To Substrait;
void TransformTableScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget) const;
void TransformParquetScanToSubstrait(LogicalGet &dget, substrait::ReadRel *sget, BindInfo &bind_info,
Expand Down
70 changes: 69 additions & 1 deletion src/to_substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
} else {
auto interval_day = make_uniq<substrait::Expression_Literal_IntervalDayToSecond>();
interval_day->set_days(dval.GetValue<interval_t>().days);
interval_day->set_microseconds(static_cast<int32_t>(dval.GetValue<interval_t>().micros));

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]

Check warning on line 204 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'set_microseconds' is deprecated [-Wdeprecated-declarations]
sval.set_allocated_interval_day_to_second(interval_day.release());
}
}
Expand Down Expand Up @@ -856,14 +856,75 @@
return res;
}

//! Builds a RelCommon whose emit output mapping lists the given column indexes in order.
//! The emit message is created even when the list is empty.
//! The returned object is heap-allocated and handed over to the caller as a raw pointer.
substrait::RelCommon *DuckDBToSubstrait::CreateOutputMapping(vector<int32_t> vector) {
	// Keep the message in a smart pointer while it is filled so it is not leaked if an
	// allocation below throws; ownership is released only on success.
	auto rel_common = make_uniq<substrait::RelCommon>();
	auto output_mapping = rel_common->mutable_emit()->mutable_output_mapping();
	output_mapping->Reserve(static_cast<int>(vector.size()));
	for (const auto col_idx : vector) {
		output_mapping->Add(col_idx);
	}
	return rel_common.release();
}

//! Converts a DuckDB LogicalProjection into a Substrait ProjectRel.
//!
//! Bound column references are not added as project expressions; they are expressed through the
//! emit output mapping, where an index below the child column count denotes an input column and
//! larger indexes denote the project expressions in the order they were added.
//!
//!  - If the projection forwards every child column unchanged and in order (and nothing else),
//!    no ProjectRel is created and the child's relation is returned directly.
//!  - If the output is exactly "all child columns in order, followed by all expressions in
//!    order", no output mapping is written.
//!  - Otherwise the output mapping is attached to the ProjectRel's RelCommon.
//!
//! The caller owns the returned relation.
substrait::Rel *DuckDBToSubstrait::TransformProjection(LogicalOperator &dop) {
	auto &dproj = dop.Cast<LogicalProjection>();
	const idx_t child_column_count = dop.children[0]->types.size();

	// First pass: compute the output mapping without emitting anything yet.
	vector<int32_t> output_mapping;
	output_mapping.reserve(dproj.expressions.size());
	idx_t num_transformations = 0;
	bool in_default_order = true;
	for (auto &dexpr : dproj.expressions) {
		idx_t emitted_idx;
		if (dexpr->type == ExpressionType::BOUND_REF) {
			emitted_idx = dexpr->Cast<BoundReferenceExpression>().index;
		} else {
			emitted_idx = child_column_count + num_transformations;
			num_transformations++;
		}
		// In the default order, output position i carries index i.
		in_default_order = in_default_order && emitted_idx == output_mapping.size();
		output_mapping.push_back(static_cast<int32_t>(emitted_idx));
	}
	// The default order additionally requires that no child column or expression is dropped.
	// Checking every entry (not just the leading references) keeps trailing duplicate column
	// references from being silently lost.
	const bool need_output_mapping =
	    !in_default_order || output_mapping.size() != child_column_count + num_transformations;

	if (!need_output_mapping && num_transformations == 0) {
		// Pure pass-through of the child: skip the projection. Nothing has been allocated yet,
		// so nothing is leaked by returning here.
		return TransformOp(*dop.children[0]);
	}

	auto res = make_uniq<substrait::Rel>();
	auto sproj = res->mutable_project();
	// The child is transformed before the expressions, as in the original ordering.
	sproj->set_allocated_input(TransformOp(*dop.children[0]));

	// Second pass: only non-reference expressions become project expressions.
	for (auto &dexpr : dproj.expressions) {
		if (dexpr->type != ExpressionType::BOUND_REF) {
			TransformExpr(*dexpr, *sproj->add_expressions());
		}
	}

	if (need_output_mapping) {
		if (sproj->expressions_size() == 0) {
			// At least one expression should be there; add the zeroth column as a dummy
			// expression. It is not referenced by the output mapping.
			CreateFieldRef(sproj->add_expressions(), 0);
		}
		auto rel_common = CreateOutputMapping(output_mapping);
		sproj->set_allocated_common(rel_common);
	}
	return res.release();
}
Expand Down Expand Up @@ -998,6 +1059,13 @@
}
}

auto child_column_count = dop.children[0]->types.size() + dop.children[1]->types.size();
vector<int32_t> output_mapping;
for (idx_t i = 0; i < projection->expressions_size(); i++) {
output_mapping.push_back(child_column_count + i);
}
auto rel_common = CreateOutputMapping(output_mapping);
projection->set_allocated_common(rel_common);
projection->set_allocated_input(res);
return proj_rel;
}
Expand All @@ -1014,7 +1082,7 @@
// TODO push projection or push substrait to allow expressions here
throw NotImplementedException("No expressions in groupings yet");
}
TransformExpr(*dgrp, *sgrp->add_grouping_expressions());

Check warning on line 1085 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]

Check warning on line 1085 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_grouping_expressions' is deprecated [-Wdeprecated-declarations]
}
for (auto &dmeas : daggr.expressions) {
auto smeas = saggr->add_measures()->mutable_measure();
Expand Down Expand Up @@ -1282,7 +1350,7 @@
auto virtual_table = sget->mutable_virtual_table();

// Add a dummy value to emit one row
auto dummy_value = virtual_table->add_values();

Check warning on line 1353 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_amd64, x86_64, x64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]

Check warning on line 1353 in src/to_substrait.cpp

View workflow job for this annotation

GitHub Actions / Build extension binaries / MacOS (osx_arm64, arm64, arm64-osx)

'add_values' is deprecated [-Wdeprecated-declarations]
dummy_value->add_fields()->set_i32(42);
return get_rel;
}
Expand Down
2 changes: 1 addition & 1 deletion test/c/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include_directories(../../duckdb/src/include)
include_directories(../../duckdb/test/include)
include_directories(../../duckdb/third_party/catch)

set(ALL_SOURCES test_substrait_c_api.cpp)
set(ALL_SOURCES test_substrait_c_api.cpp test_substrait_c_utils.cpp test_projection.cpp)


add_library_unity(test_substrait OBJECT ${ALL_SOURCES})
Expand Down
159 changes: 159 additions & 0 deletions test/c/test_projection.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#include "catch.hpp"
#include "test_helpers.hpp"
#include "duckdb/main/connection_manager.hpp"
#include "test_substrait_c_utils.hpp"

#include <chrono>
#include <thread>
#include <iostream>

using namespace duckdb;
using namespace std;

// Selecting the only column of a table: the generated plan is compared against a golden JSON
// string whose root input is a plain read relation, then the plan is executed again from JSON.
TEST_CASE("Test C Project input columns with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (10), (20), (30)"));
// NOTE(review): the query below only reads from "integers"; this call looks unnecessary here.
CreateEmployeeTable(con);

// Golden plan: exact string comparison, so any change in serialization breaks this test.
auto expected_json_str = R"({"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"names":["i"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
auto json_str = con.GetSubstraitJSON("SELECT i FROM integers");
REQUIRE(json_str == expected_json_str);
// Round trip: run the produced plan and check the data.
auto result = con.FromSubstraitJSON(json_str);
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
}

// One pass-through column plus one computed column: the golden plan contains a project relation
// with a single multiply expression and no emit/output mapping.
TEST_CASE("Test C Project 1 input column 1 transformation with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (10), (20), (30)"));
// NOTE(review): the query below only reads from "integers"; this call looks unnecessary here.
CreateEmployeeTable(con);

// Golden plan: exact string comparison.
auto expected_json_str = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"multiply:i32_i32"}}],"relations":[{"root":{"input":{"project":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"expressions":[{"scalarFunction":{"functionReference":1,"outputType":{"i32":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{}},"rootReference":{}}}},{"value":{"selection":{"directReference":{"structField":{}},"rootReference":{}}}}]}}]}},"names":["i","isquare"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
auto json_str = con.GetSubstraitJSON("SELECT i, i *i as isquare FROM integers");
REQUIRE(json_str == expected_json_str);
// Round trip: column 0 is the pass-through column, column 1 the computed square.
auto result = con.FromSubstraitJSON(json_str);
REQUIRE(CHECK_COLUMN(result, 0, {10, 20, 30}));
REQUIRE(CHECK_COLUMN(result, 1, {100, 400, 900}));
}

// SELECT * over the employees table: the golden plan is a read relation selecting all four
// fields, and the round-tripped result is checked column by column.
TEST_CASE("Test C Project all columns with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

// This should not have a ProjectRel node
auto json_str = con.GetSubstraitJSON("SELECT * FROM employees");
// Golden plan: exact string comparison.
auto expected_json_str = R"({"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{},{"field":1},{"field":2},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"names":["employee_id","name","department_id","salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
REQUIRE(json_str == expected_json_str);
// Round trip: execute the plan and verify every column.
auto result = con.FromSubstraitJSON(json_str);
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3, 4, 5}));
REQUIRE(CHECK_COLUMN(result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 2, {1, 2, 1, 3, 2}));
REQUIRE(CHECK_COLUMN(result, 3, {120000, 80000, 50000, 95000, 60000}));
}

// Selecting a subset of columns: in the golden plan the column selection is carried by the read
// relation's own projection (fields 1 and 3).
TEST_CASE("Test C Project two passthrough columns with Substrait API", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

// This should not have a ProjectRel node
auto json_str = con.GetSubstraitJSON("SELECT name, salary FROM employees");
// Golden plan: exact string comparison.
auto expected_json_str = R"({"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{"field":1},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"names":["name","salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
REQUIRE(json_str == expected_json_str);
// Round trip: execute the plan and verify both selected columns.
auto result = con.FromSubstraitJSON(json_str);
REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 80000, 50000, 95000, 60000}));
}

// Column subset combined with a WHERE clause: in the golden plan both the filter
// (department_id = 1) and the column selection sit inside the read relation.
TEST_CASE("Test C Project two passthrough columns with filter", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

// This should not have a ProjectRel node
auto json_str = con.GetSubstraitJSON("SELECT name, salary FROM employees where department_id = 1");
// Golden plan: exact string comparison.
auto expected_json_str = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"equal:i32_i32"}}],"relations":[{"root":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"filter":{"scalarFunction":{"functionReference":1,"outputType":{"bool":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{"field":2}},"rootReference":{}}}},{"value":{"literal":{"i32":1}}}]}},"projection":{"select":{"structItems":[{"field":1},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"names":["name","salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
REQUIRE(json_str == expected_json_str);
// Round trip: only the rows of department 1 are expected.
auto result = con.FromSubstraitJSON(json_str);
REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Alice Johnson" }));
REQUIRE(CHECK_COLUMN(result, 1, {120000, 50000 }));
}

// One pass-through column and one computed column over a two-column read: the golden plan
// contains a project relation with an emit output mapping of [0,2], i.e. input column 0 and the
// single project expression, dropping input column 1 (salary) from the output.
TEST_CASE("Test C Project 1 passthrough column, 1 transformation with column elimination", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

auto json_str = con.GetSubstraitJSON("SELECT name, salary * 1.2 as new_salary FROM employees");
// Golden plan: exact string comparison.
auto expected_json_str = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic_decimal.yaml"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"multiply:decimal_decimal"}}],"relations":[{"root":{"input":{"project":{"common":{"emit":{"outputMapping":[0,2]}},"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{"field":1},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"expressions":[{"scalarFunction":{"functionReference":1,"outputType":{"decimal":{"scale":3,"precision":12,"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{"field":1}},"rootReference":{}}}},{"value":{"literal":{"decimal":{"value":"DAAAAAAAAAAAAAAAAAAAAA==","precision":12,"scale":1}}}}]}}]}},"names":["name","new_salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
REQUIRE(json_str == expected_json_str);
// Round trip: names pass through unchanged, salaries are scaled by 1.2.
auto result = con.FromSubstraitJSON(json_str);
REQUIRE(CHECK_COLUMN(result, 0, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
REQUIRE(CHECK_COLUMN(result, 1, {144000, 96000, 60000, 114000, 72000}));
}

// Grouped aggregate: the golden plan is an aggregate relation directly over the read relation,
// with no project relation in between.
TEST_CASE("Test C Project 1 passthrough column and 1 aggregate transformation", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);

CreateEmployeeTable(con);

auto json_str = con.GetSubstraitJSON("SELECT department_id, AVG(salary) AS avg_salary FROM employees GROUP BY department_id");
// Golden plan: exact string comparison.
auto expected_json_str = R"({"extensionUris":[{"extensionUriAnchor":1,"uri":"https://github.com/substrait-io/substrait/blob/main/extensions/"}],"extensions":[{"extensionFunction":{"extensionUriReference":1,"functionAnchor":1,"name":"avg:decimal"}}],"relations":[{"root":{"input":{"aggregate":{"input":{"read":{"baseSchema":{"names":["employee_id","name","department_id","salary"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_REQUIRED"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"i32":{"nullability":"NULLABILITY_NULLABLE"}},{"decimal":{"scale":2,"precision":10,"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{"field":2},{"field":3}]},"maintainSingularStruct":true},"namedTable":{"names":["employees"]}}},"groupings":[{"groupingExpressions":[{"selection":{"directReference":{"structField":{}},"rootReference":{}}}]}],"measures":[{"measure":{"functionReference":1,"outputType":{"fp64":{"nullability":"NULLABILITY_NULLABLE"}},"arguments":[{"value":{"selection":{"directReference":{"structField":{"field":1}},"rootReference":{}}}}]}}]}},"names":["department_id","avg_salary"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
REQUIRE(json_str == expected_json_str);
// Round trip: one row per department with its average salary.
auto result = con.FromSubstraitJSON(json_str);
REQUIRE(CHECK_COLUMN(result, 0, {1, 2, 3}));
REQUIRE(CHECK_COLUMN(result, 1, {85000, 70000, 95000}));
}

// Projection on top of a join: the query is executed through the Substrait JSON round trip and
// only the resulting data is checked (no golden plan comparison).
TEST_CASE("Test C Project on Join with Substrait API", "[substrait-api]") {
	DuckDB db(nullptr);
	Connection con(db);

	// Both sides of the join.
	CreateEmployeeTable(con);
	CreateDepartmentsTable(con);

	const auto join_query = "SELECT e.employee_id, e.name, d.department_name "
	                        "FROM employees e "
	                        "JOIN departments d "
	                        "ON e.department_id = d.department_id";
	auto join_result = ExecuteViaSubstraitJSON(con, join_query);

	// Two columns from employees, one from departments.
	REQUIRE(CHECK_COLUMN(join_result, 0, {1, 2, 3, 4, 5}));
	REQUIRE(CHECK_COLUMN(join_result, 1, {"John Doe", "Jane Smith", "Alice Johnson", "Bob Brown", "Charlie Black"}));
	REQUIRE(CHECK_COLUMN(join_result, 2, {"HR", "Engineering", "HR", "Finance", "Engineering"}));
}

// Hand-written plan: a project (one field-reference expression, no emit) over a fetch over a
// read of "integers", with a single root name. Loading it is expected to throw.
// NOTE(review): presumably it throws because the project yields two columns (the pass-through
// input column plus the expression) while only one name is supplied — confirm.
TEST_CASE("Test Project with bad plan", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);
con.EnableQueryVerification();
REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (1), (2), (3), (NULL)"));

auto query_json = R"({"relations":[{"root":{"input":{"project":{"input":{"fetch":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"count":"5"}},"expressions":[{"selection":{"directReference":{"structField":{}},"rootReference":{}}}]}},"names":["i"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
REQUIRE_THROWS(con.FromSubstraitJSON(query_json));
}

// Hand-written plan: a project (one field-reference expression, no emit) over a fetch over a
// read of "integers", this time with two root names ("i" and "integers"). The plan is expected
// to load, and the first result column is checked against the table contents.
TEST_CASE("Test Project with duplicate columns", "[substrait-api]") {
DuckDB db(nullptr);
Connection con(db);
con.EnableQueryVerification();
REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
REQUIRE_NO_FAIL(con.Query("INSERT INTO integers VALUES (1), (2), (3), (NULL)"));

auto query_json = R"({"relations":[{"root":{"input":{"project":{"input":{"fetch":{"input":{"read":{"baseSchema":{"names":["i"],"struct":{"types":[{"i32":{"nullability":"NULLABILITY_NULLABLE"}}],"nullability":"NULLABILITY_REQUIRED"}},"projection":{"select":{"structItems":[{}]},"maintainSingularStruct":true},"namedTable":{"names":["integers"]}}},"count":"5"}},"expressions":[{"selection":{"directReference":{"structField":{}},"rootReference":{}}}]}},"names":["i", "integers"]}}],"version":{"minorNumber":53,"producer":"DuckDB"}})";
auto res1 = con.FromSubstraitJSON(query_json);
// Value() stands for the NULL row inserted above.
REQUIRE(CHECK_COLUMN(res1, 0, {1, 2, 3, Value()}));
}
Loading
Loading