diff --git a/src/substrait_extension.cpp b/src/substrait_extension.cpp index 0ff7de8..c6bbb7f 100644 --- a/src/substrait_extension.cpp +++ b/src/substrait_extension.cpp @@ -22,8 +22,8 @@ namespace duckdb { -void do_nothing(ClientContext *) { -} +//! This is a no-op deleter for creating a shared pointer to a reference. +void deleter_noop(ClientContext *) {} struct ToSubstraitFunctionData : public TableFunctionData { ToSubstraitFunctionData() = default; @@ -264,7 +264,7 @@ static unique_ptr SubstraitBind(ClientContext &context, TableFunctionB throw BinderException("from_substrait cannot be called with a NULL parameter"); } string serialized = input.inputs[0].GetValueUnsafe(); - shared_ptr c_ptr(&context, do_nothing); + shared_ptr c_ptr(&context, deleter_noop); auto plan = SubstraitPlanToDuckDBRel(c_ptr, serialized, is_json); return plan->GetTableRef(); } @@ -277,6 +277,46 @@ static unique_ptr FromSubstraitBindJSON(ClientContext &context, TableF return SubstraitBind(context, input, true); } +//! Container for TableFnExplainSubstrait to get data from BindFnExplainSubstrait +struct FromSubstraitFunctionData : public TableFunctionData { + FromSubstraitFunctionData() = default; + shared_ptr plan; + unique_ptr res; + unique_ptr conn; +}; + +static unique_ptr BindFnExplainSubstrait(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + if (input.inputs[0].IsNull()) { + throw BinderException("explain_substrait cannot be called with a NULL parameter"); + } + + // Prep args to `SubstraitPlanToDuckDBRel` + constexpr bool is_json = false; + string serialized = input.inputs[0].GetValueUnsafe(); + shared_ptr c_ptr(&context, deleter_noop); + + auto result = make_uniq(); + result->conn = make_uniq(*context.db); + result->plan = SubstraitPlanToDuckDBRel(c_ptr, serialized, is_json); + + // return schema is a single string attribute (column) + return_types.emplace_back(LogicalType::VARCHAR); + names.emplace_back("Explain Plan"); + + return std::move(result); +} + +static void TableFnExplainSubstrait(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = data_p.bind_data->CastNoConst(); + if (!data.res) { data.res = data.plan->Explain(); } + + auto result_chunk = data.res->Fetch(); + if (!result_chunk) { return; } + + output.Move(*result_chunk); +} + void InitializeGetSubstrait(const Connection &con) { auto &catalog = Catalog::GetSystemCatalog(*con.context); // create the get_substrait table function that allows us to get a substrait @@ -299,22 +339,39 @@ void InitializeGetSubstraitJSON(const Connection &con) { catalog.CreateTableFunction(*con.context, get_substrait_json_info); } +//! Define and register a TableFunction ("from_substrait") that returns a TableRef void InitializeFromSubstrait(const Connection &con) { - auto &catalog = Catalog::GetSystemCatalog(*con.context); - - // create the from_substrait table function that allows us to get a query - // result from a substrait plan - TableFunction from_sub_func("from_substrait", {LogicalType::BLOB}, nullptr, nullptr); + // `FromSubstraitBind` translates a substrait plan and returns a `TableRef` + // to return a `TableRef` we use `bind_replace` instead of `bind` + TableFunction from_sub_func("from_substrait", {LogicalType::BLOB}, nullptr); from_sub_func.bind_replace = FromSubstraitBind; + + // register the TableFunction in the system catalog + auto &catalog = Catalog::GetSystemCatalog(*con.context); CreateTableFunctionInfo from_sub_info(from_sub_func); catalog.CreateTableFunction(*con.context, from_sub_info); } +//! Define and register a TableFunction ("explain_substrait") that returns a QueryResult +void InitializeExplainSubstrait(const Connection &con) { + TableFunction explain_sub_func( + "explain_substrait" + ,{LogicalType::BLOB} + ,/*function=*/TableFnExplainSubstrait // Translates the plan then converts to a string + ,/*bind=*/BindFnExplainSubstrait // Sets return schema to a single string + ); + + // register the TableFunction in the system catalog + auto &catalog = Catalog::GetSystemCatalog(*con.context); + CreateTableFunctionInfo explain_sub_info(explain_sub_func); + catalog.CreateTableFunction(*con.context, explain_sub_info); +} + void InitializeFromSubstraitJSON(const Connection &con) { auto &catalog = Catalog::GetSystemCatalog(*con.context); // create the from_substrait table function that allows us to get a query // result from a substrait plan - TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, nullptr, nullptr); + TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, nullptr); from_sub_func_json.bind_replace = FromSubstraitBindJSON; CreateTableFunctionInfo from_sub_info_json(from_sub_func_json); catalog.CreateTableFunction(*con.context, from_sub_info_json); @@ -329,6 +386,7 @@ void SubstraitExtension::Load(DuckDB &db) { InitializeFromSubstrait(con); InitializeFromSubstraitJSON(con); + InitializeExplainSubstrait(con); con.Commit(); } diff --git a/test/python/test_substrait_explain.py b/test/python/test_substrait_explain.py new file mode 100644 index 0000000..b4ef39d --- /dev/null +++ b/test/python/test_substrait_explain.py @@ -0,0 +1,26 @@ +import pandas as pd +import duckdb + +EXPECTED_RESULT = ''' +┌───────────────┬──────────────────────────────────────────────────────────────────────────────────────────────────────┐ +│ explain_key │ explain_value │ +│ varchar │ varchar │ +├───────────────┼──────────────────────────────────────────────────────────────────────────────────────────────────────┤ +│ physical_plan │ ┌───────────────────────────┐\n│ STREAMING_LIMIT │\n└─────────────┬─────────────┘\n┌────… │ +└───────────────┴──────────────────────────────────────────────────────────────────────────────────────────────────────┘ + +''' + +def test_roundtrip_substrait(require): + connection = require('substrait') + connection.execute('CREATE TABLE integers (i integer)') + connection.execute('INSERT INTO integers VALUES (0),(1),(2),(3),(4),(5),(6),(7),(8),(9),(NULL)') + + translate_result = connection.get_substrait('SELECT * FROM integers LIMIT 5') + proto_bytes = translate_result.fetchone()[0] + + expected = pd.Series([EXPECTED_RESULT], name='Explain Plan', dtype='str') + actual = connection.table_function('explain_substrait', proto_bytes).execute() + + pd.testing.assert_series_equal(actual.df()['Explain Plan'], expected) +