diff --git a/scripts/setup-ubuntu.sh b/scripts/setup-ubuntu.sh index c09ad710b948..725b96c322ee 100755 --- a/scripts/setup-ubuntu.sh +++ b/scripts/setup-ubuntu.sh @@ -78,8 +78,8 @@ function install_protobuf { cd protobuf-21.4 ./configure --prefix=/usr make "-j$(nproc)" - make install - ldconfig + sudo make install + sudo ldconfig } function install_velox_deps { diff --git a/velox/substrait/CMakeLists.txt b/velox/substrait/CMakeLists.txt index 93adbca51f0f..63e114b82c17 100644 --- a/velox/substrait/CMakeLists.txt +++ b/velox/substrait/CMakeLists.txt @@ -33,7 +33,7 @@ get_filename_component(PROTO_DIR ${substrait_proto_directory}/, DIRECTORY) # Generate Substrait hearders add_custom_command( OUTPUT ${PROTO_OUTPUT_FILES} - COMMAND protoc --proto_path ${CMAKE_SOURCE_DIR}/ --cpp_out ${CMAKE_SOURCE_DIR} + COMMAND protoc --proto_path ${proto_directory}/ --cpp_out ${PROTO_OUTPUT_DIR} ${PROTO_FILES} DEPENDS ${PROTO_DIR} COMMENT "Running PROTO compiler" @@ -47,7 +47,7 @@ set(SRCS SubstraitParser.cpp SubstraitToVeloxExpr.cpp SubstraitToVeloxPlan.cpp - TypeUtils.cpp_out + TypeUtils.cpp SubstraitExtensionCollector.cpp VeloxToSubstraitExpr.cpp VeloxToSubstraitPlan.cpp diff --git a/velox/substrait/SubstraitParser.cpp b/velox/substrait/SubstraitParser.cpp index c7102999e044..382a8cc387b9 100644 --- a/velox/substrait/SubstraitParser.cpp +++ b/velox/substrait/SubstraitParser.cpp @@ -254,9 +254,9 @@ void SubstraitParser::getSubFunctionTypes( std::string SubstraitParser::findVeloxFunction( const std::unordered_map& functionMap, uint64_t id) const { - std::string funcSpec = findFunctionSpec(functionMap, id); - std::string_view funcName = getNameBeforeDelimiter(funcSpec, ":"); - return mapToVeloxFunction({funcName.begin(), funcName.end()}); + std::string funcSpec = findSubstraitFuncSpec(functionMap, id); + std::string funcName = getSubFunctionName(funcSpec); + return mapToVeloxFunction(funcName); } std::string SubstraitParser::mapToVeloxFunction( diff --git a/velox/substrait/SubstraitToVeloxExpr.cpp b/velox/substrait/SubstraitToVeloxExpr.cpp index a805e134ec5f..d3f446c2ffc5 100644 --- a/velox/substrait/SubstraitToVeloxExpr.cpp +++ b/velox/substrait/SubstraitToVeloxExpr.cpp @@ -122,8 +122,8 @@ SubstraitVeloxExprConverter::toVeloxExpr( const ::substrait::Expression::ScalarFunction& sFunc, const RowTypePtr& inputType) { std::vector params; - params.reserve(substraitFunc.arguments().size()); - for (const auto& sArg : substraitFunc.arguments()) { + params.reserve(sFunc.arguments().size()); + for (const auto& sArg : sFunc.arguments()) { params.emplace_back(toVeloxExpr(sArg.value(), inputType)); } const auto& veloxFunction = @@ -142,18 +142,50 @@ SubstraitVeloxExprConverter::toVeloxExpr( toVeloxType(typeName), std::move(params), veloxFunction); } +std::shared_ptr +SubstraitVeloxExprConverter::literalsToConstantExpr( + const std::vector<::substrait::Expression::Literal>& literals) { + std::vector variants; + variants.reserve(literals.size()); + VELOX_CHECK(literals.size() > 0, "List should have at least one item."); + std::optional literalType = std::nullopt; + for (const auto& literal : literals) { + auto veloxVariant = toVeloxExpr(literal)->value(); + if (!literalType.has_value()) { + literalType = veloxVariant.inferType(); + } + variants.emplace_back(veloxVariant); + } + VELOX_CHECK(literalType.has_value(), "Type expected."); + // Create flat vector from the variants. + VectorPtr vector = + setVectorFromVariants(literalType.value(), variants, pool_); + // Create array vector from the flat vector. + ArrayVectorPtr arrayVector = + toArrayVector(literalType.value(), vector, pool_); + // Wrap the array vector into constant vector. + auto constantVector = BaseVector::wrapInConstant(1, 0, arrayVector); + return std::make_shared(constantVector); +} + core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( const ::substrait::Expression::SingularOrList& singularOrList, const RowTypePtr& inputType) { + VELOX_CHECK( + singularOrList.options_size() > 0, "At least one option is expected."); + auto options = singularOrList.options(); + std::vector<::substrait::Expression::Literal> literals; + literals.reserve(options.size()); + for (const auto& option : options) { + VELOX_CHECK(option.has_literal(), "Literal is expected as option."); + literals.emplace_back(option.literal()); + } + std::vector> params; - // TODO: other options? - auto inLists = singularOrList.options(); - VELOX_CHECK(inLists.size() > 0, "At least one option is needed."); params.reserve(2); - // first is the value, second is the list + // First param is the value, second param is the list. params.emplace_back(toVeloxExpr(singularOrList.value(), inputType)); - // TODO: is this the correct way to use SingularOrList? - params.emplace_back(toVeloxExpr(inLists[0], inputType)); + params.emplace_back(literalsToConstantExpr(literals)); return std::make_shared( BOOLEAN(), std::move(params), "in"); } @@ -196,33 +228,13 @@ SubstraitVeloxExprConverter::toVeloxExpr( veloxType, variant::null(veloxType->kind())); } case ::substrait::Expression_Literal::LiteralTypeCase::kList: { - // List is used in 'in' expression. Will wrap a constant - // vector with an array vector inside to create the constant expression. - std::vector variants; - variants.reserve(substraitLit.list().values().size()); - VELOX_CHECK( - substraitLit.list().values().size() > 0, - "List should have at least one item."); - std::optional literalType = std::nullopt; + // Literals in List are put in a constant vector. + std::vector<::substrait::Expression::Literal> literals; + literals.reserve(substraitLit.list().values().size()); for (const auto& literal : substraitLit.list().values()) { - auto typedVariant = toTypedVariant(literal); - if (!literalType.has_value()) { - literalType = typedVariant->variantType; - } - variants.emplace_back(typedVariant->veloxVariant); + literals.emplace_back(literal); } - VELOX_CHECK(literalType.has_value(), "Type expected."); - // Create flat vector from the variants. - VectorPtr vector = - setVectorFromVariants(literalType.value(), variants, pool_); - // Create array vector from the flat vector. - ArrayVectorPtr arrayVector = - toArrayVector(literalType.value(), vector, pool_); - // Wrap the array vector into constant vector. - auto constantVector = BaseVector::wrapInConstant(1, 0, arrayVector); - auto constantExpr = - std::make_shared(constantVector); - return constantExpr; + return literalsToConstantExpr(literals); } case ::substrait::Expression_Literal::LiteralTypeCase::kVarChar: return std::make_shared( @@ -248,31 +260,6 @@ SubstraitVeloxExprConverter::toVeloxExpr( return std::make_shared(type, inputs, nullOnFailure); } -std::shared_ptr -SubstraitVeloxExprConverter::toVeloxExpr( - const ::substrait::Expression& sExpr, - const RowTypePtr& inputType) { - std::shared_ptr veloxExpr; - auto typeCase = sExpr.rex_type_case(); - switch (typeCase) { - case ::substrait::Expression::RexTypeCase::kLiteral: - return toVeloxExpr(sExpr.literal()); - case ::substrait::Expression::RexTypeCase::kScalarFunction: - return toVeloxExpr(sExpr.scalar_function(), inputType); - case ::substrait::Expression::RexTypeCase::kSelection: - return toVeloxExpr(sExpr.selection(), inputType); - case ::substrait::Expression::RexTypeCase::kCast: - return toVeloxExpr(sExpr.cast(), inputType); - case ::substrait::Expression::RexTypeCase::kIfThen: - return toVeloxExpr(sExpr.if_then(), inputType); - case ::substrait::Expression::RexTypeCase::kSingularOrList: - return toVeloxExpr(sExpr.singular_or_list(), inputType); - default: - VELOX_NYI( - "Substrait conversion not supported for Expression '{}'", typeCase); - } -} - std::shared_ptr SubstraitVeloxExprConverter::toVeloxExpr( const ::substrait::Expression::IfThen& ifThenExpr, @@ -312,25 +299,26 @@ SubstraitVeloxExprConverter::toVeloxExpr( std::shared_ptr SubstraitVeloxExprConverter::toVeloxExpr( - const ::substrait::Expression_IfThen& substraitIfThen, + const ::substrait::Expression& sExpr, const RowTypePtr& inputType) { - std::vector inputs; - if (substraitIfThen.has_else_()) { - inputs.reserve(substraitIfThen.ifs_size() * 2 + 1); - } else { - inputs.reserve(substraitIfThen.ifs_size() * 2); - } - - TypePtr resultType; - for (auto& ifExpr : substraitIfThen.ifs()) { - auto ifClauseExpr = toVeloxExpr(ifExpr.if_(), inputType); - inputs.emplace_back(ifClauseExpr); - auto thenClauseExpr = toVeloxExpr(ifExpr.then(), inputType); - inputs.emplace_back(thenClauseExpr); - - if (!thenClauseExpr->type()->containsUnknown()) { - resultType = thenClauseExpr->type(); - } + std::shared_ptr veloxExpr; + auto typeCase = sExpr.rex_type_case(); + switch (typeCase) { + case ::substrait::Expression::RexTypeCase::kLiteral: + return toVeloxExpr(sExpr.literal()); + case ::substrait::Expression::RexTypeCase::kScalarFunction: + return toVeloxExpr(sExpr.scalar_function(), inputType); + case ::substrait::Expression::RexTypeCase::kSelection: + return toVeloxExpr(sExpr.selection(), inputType); + case ::substrait::Expression::RexTypeCase::kCast: + return toVeloxExpr(sExpr.cast(), inputType); + case ::substrait::Expression::RexTypeCase::kIfThen: + return toVeloxExpr(sExpr.if_then(), inputType); + case ::substrait::Expression::RexTypeCase::kSingularOrList: + return toVeloxExpr(sExpr.singular_or_list(), inputType); + default: + VELOX_NYI( + "Substrait conversion not supported for Expression '{}'", typeCase); } } diff --git a/velox/substrait/SubstraitToVeloxExpr.h b/velox/substrait/SubstraitToVeloxExpr.h index aee610908725..bf5d85aa5598 100644 --- a/velox/substrait/SubstraitToVeloxExpr.h +++ b/velox/substrait/SubstraitToVeloxExpr.h @@ -88,6 +88,11 @@ class SubstraitVeloxExprConverter { const ::substrait::Expression::IfThen& ifThenExpr, const RowTypePtr& inputType); + /// Wrap a constant vector from literals with an array vector inside to create + /// the constant expression. + std::shared_ptr literalsToConstantExpr( + const std::vector<::substrait::Expression::Literal>& literals); + private: /// Memory pool. memory::MemoryPool* pool_; diff --git a/velox/substrait/SubstraitToVeloxPlan.cpp b/velox/substrait/SubstraitToVeloxPlan.cpp index 73ed0dd0e310..e276ca870472 100644 --- a/velox/substrait/SubstraitToVeloxPlan.cpp +++ b/velox/substrait/SubstraitToVeloxPlan.cpp @@ -206,12 +206,13 @@ std::shared_ptr SubstraitVeloxPlanConverter::toVeloxAgg( // Each measure represents one aggregate expression. std::vector aggExprs; aggExprs.reserve(sAgg.measures().size()); - + std::vector aggregateMasks; + aggregateMasks.reserve(sAgg.measures().size()); for (const auto& smea : sAgg.measures()) { core::FieldAccessTypedExprPtr aggregateMask; - ::substrait::Expression substraitAggMask = measure.filter(); + ::substrait::Expression substraitAggMask = smea.filter(); // Get Aggregation Masks. - if (measure.has_filter()) { + if (smea.has_filter()) { if (substraitAggMask.ByteSizeLong() == 0) { aggregateMask = {}; } else { @@ -227,8 +228,8 @@ std::shared_ptr SubstraitVeloxPlanConverter::toVeloxAgg( std::vector> aggParams; aggParams.reserve(aggFunction.arguments().size()); for (const auto& arg : aggFunction.arguments()) { - aggParams.emplace_back(exprConverter_->toVeloxExpr( - getExprFromFunctionArgument(arg), inputType)); + aggParams.emplace_back( + exprConverter_->toVeloxExpr(arg.value(), inputType)); } auto aggVeloxType = toVeloxType(subParser_->parseType(aggFunction.output_type())->type); @@ -238,8 +239,6 @@ std::shared_ptr SubstraitVeloxPlanConverter::toVeloxAgg( } bool ignoreNullKeys = false; - std::vector> aggregateMasks( - sAgg.measures().size()); std::vector> preGroupingExprs = {}; @@ -267,10 +266,10 @@ std::shared_ptr SubstraitVeloxPlanConverter::toVeloxAgg( } core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( - const ::substrait::ProjectRel& projectRel) { + const ::substrait::ProjectRel& sProject) { core::PlanNodePtr childNode; - if (projectRel.has_input()) { - childNode = toVeloxPlan(projectRel.input()); + if (sProject.has_input()) { + childNode = toVeloxPlan(sProject.input()); } else { VELOX_FAIL("Child Rel is expected in ProjectRel."); } @@ -286,8 +285,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( int colIdx = 0; for (const auto& expr : projectExprs) { expressions.emplace_back(exprConverter_->toVeloxExpr(expr, inputType)); - projectNames.emplace_back( - substraitParser_->makeNodeName(planNodeId_, colIdx)); + projectNames.emplace_back(subParser_->makeNodeName(planNodeId_, colIdx)); colIdx += 1; } @@ -309,13 +307,13 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( const auto& inputType = childNode->outputType(); const auto& sExpr = filterRel.condition(); - std::shared_ptr - return std::make_shared( - nextPlanNodeId(), - exprConverter_->toVeloxExpr(sExpr, inputType), - childNode); + return std::make_shared( + nextPlanNodeId(), + exprConverter_->toVeloxExpr(sExpr, inputType), + childNode); } + bool isPushDownSupportedByFormat( const dwio::common::FileFormat& format, connector::hive::SubfieldFilters& subfieldFilters) { @@ -565,12 +563,12 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( } core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( - const ::substrait::RelRoot& root) { + const ::substrait::RelRoot& sRoot) { // TODO: Use the names as the output names for the whole computing. - const auto& names = root.names(); - if (root.has_input()) { - const auto& rel = root.input(); - return toVeloxPlan(rel); + const auto& sNames = sRoot.names(); + if (sRoot.has_input()) { + const auto& sRel = sRoot.input(); + return toVeloxPlan(sRel); } VELOX_FAIL("Input is expected in RelRoot."); } @@ -776,9 +774,7 @@ void SubstraitVeloxPlanConverter::flattenConditions( if (subParser_->getSubFunctionName(filterNameSpec) == "and") { for (const auto& sCondition : sFunc.arguments()) { flattenConditions( - getExprFromFunctionArgument(sCondition), - scalarFunctions, - singularOrLists); + sCondition.value(), scalarFunctions, singularOrLists); } } else { scalarFunctions.emplace_back(sFunc); @@ -798,33 +794,35 @@ std::string SubstraitVeloxPlanConverter::findFuncSpec(uint64_t id) { return subParser_->findSubstraitFuncSpec(functionMap_, id); } -core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( - const ::substrait::Plan& substraitPlan, - memory::MemoryPool* pool) { - VELOX_CHECK( - checkTypeExtension(substraitPlan), - "The type extension only have unknown type.") - // Construct the function map based on the Substrait representation. - constructFunctionMap(substraitPlan); +int32_t SubstraitVeloxPlanConverter::streamIsInput( + const ::substrait::ReadRel& sRead) { + if (sRead.has_local_files()) { + const auto& fileList = sRead.local_files().items(); + if (fileList.size() == 0) { + VELOX_FAIL("At least one file path is expected."); + } - // In fact, only one RelRoot or Rel is expected here. - VELOX_CHECK_EQ(substraitPlan.relations_size(), 1); - const auto& rel = substraitPlan.relations(0); - if (rel.has_root()) { - return toVeloxPlan(rel.root(), pool); + // The stream input will be specified with the format of + // "iterator:${index}". + std::string filePath = fileList[0].uri_file(); + std::string prefix = "iterator:"; + std::size_t pos = filePath.find(prefix); + if (pos == std::string::npos) { + return -1; + } + + // Get the index. + std::string idxStr = filePath.substr(pos + prefix.size(), filePath.size()); + try { + return stoi(idxStr); + } catch (const std::exception& err) { + VELOX_FAIL(err.what()); + } } - if (rel.has_rel()) { - return toVeloxPlan(rel.rel(), pool); + if (validationMode_) { + return -1; } - VELOX_FAIL("Input is expected in RelRoot."); -} - -core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( - const ::substrait::Plan& substraitPlan, - memory::MemoryPool* pool) { - // Construct the function map based on the Substrait representation. - constructFuncMap(sPlan); - VELOX_FAIL("RelRoot or Rel is expected in Plan."); + VELOX_FAIL("Local file is expected."); } void SubstraitVeloxPlanConverter::extractJoinKeys( @@ -843,17 +841,17 @@ void SubstraitVeloxPlanConverter::extractJoinKeys( functionMap_, visited->scalar_function().function_reference())); const auto& args = visited->scalar_function().arguments(); if (funcName == "and") { - expressions.push_back(&getExprFromFunctionArgument(args[0])); - expressions.push_back(&getExprFromFunctionArgument(args[1])); + expressions.push_back(&args[0].value()); + expressions.push_back(&args[1].value()); } else if (funcName == "eq") { VELOX_CHECK(std::all_of( args.cbegin(), args.cend(), [](const ::substrait::FunctionArgument& arg) { - return getExprFromFunctionArgument(arg).has_selection(); + return arg.value().has_selection(); })); - leftExprs.push_back(&getExprFromFunctionArgument(args[0]).selection()); - rightExprs.push_back(&getExprFromFunctionArgument(args[1]).selection()); + leftExprs.push_back(&args[0].value().selection()); + rightExprs.push_back(&args[1].value().selection()); } else { VELOX_NYI("Join condition {} not supported.", funcName); } @@ -884,12 +882,11 @@ connector::hive::SubfieldFilters SubstraitVeloxPlanConverter::toSubfieldFilters( auto filterName = subParser_->getSubFunctionName(filterNameSpec); if (filterName == sNot) { VELOX_CHECK(scalarFunction.arguments().size() == 1); - auto expr = getExprFromFunctionArgument(scalarFunction.arguments()[0]); + auto expr = scalarFunction.arguments()[0].value(); if (expr.has_scalar_function()) { // Set its chid to filter info with reverse enabled. setFilterMap( - getExprFromFunctionArgument(scalarFunction.arguments()[0]) - .scalar_function(), + scalarFunction.arguments()[0].value().scalar_function(), inputTypeList, colInfoMap, true); @@ -906,18 +903,16 @@ connector::hive::SubfieldFilters SubstraitVeloxPlanConverter::toSubfieldFilters( scalarFunction.arguments().cbegin(), scalarFunction.arguments().cend(), [](const ::substrait::FunctionArgument& arg) { - return getExprFromFunctionArgument(arg).has_scalar_function() || - getExprFromFunctionArgument(arg).has_singular_or_list(); + return arg.value().has_scalar_function() || + arg.value().has_singular_or_list(); })); // Set the chidren functions to filter info. They should be // effective to the same field. for (const auto& arg : scalarFunction.arguments()) { - auto expr = getExprFromFunctionArgument(arg); + auto expr = arg.value(); if (expr.has_scalar_function()) { setFilterMap( - getExprFromFunctionArgument(arg).scalar_function(), - inputTypeList, - colInfoMap); + arg.value().scalar_function(), inputTypeList, colInfoMap); } else if (expr.has_singular_or_list()) { setSingularListValues(expr.singular_or_list(), colInfoMap); } else { @@ -941,12 +936,10 @@ bool SubstraitVeloxPlanConverter::fieldOrWithLiteral( arguments, uint32_t& fieldIndex) { if (arguments.size() == 1) { - if (getExprFromFunctionArgument(arguments[0]).has_selection()) { + if (arguments[0].value().has_selection()) { // Only field exists. fieldIndex = subParser_->parseReferenceSegment( - getExprFromFunctionArgument(arguments[0]) - .selection() - .direct_reference()); + arguments[0].value().selection().direct_reference()); return true; } else { return false; @@ -960,11 +953,11 @@ bool SubstraitVeloxPlanConverter::fieldOrWithLiteral( bool fieldExists = false; bool literalExists = false; for (const auto& param : arguments) { - auto typeCase = getExprFromFunctionArgument(param).rex_type_case(); + auto typeCase = param.value().rex_type_case(); switch (typeCase) { case ::substrait::Expression::RexTypeCase::kSelection: fieldIndex = subParser_->parseReferenceSegment( - getExprFromFunctionArgument(param).selection().direct_reference()); + param.value().selection().direct_reference()); fieldExists = true; break; case ::substrait::Expression::RexTypeCase::kLiteral: @@ -983,19 +976,19 @@ bool SubstraitVeloxPlanConverter::chidrenFunctionsOnSameField( // Get the column indices of the chidren functions. std::vector colIndices; for (const auto& arg : function.arguments()) { - if (getExprFromFunctionArgument(arg).has_scalar_function()) { - auto scalarFunction = getExprFromFunctionArgument(arg).scalar_function(); + if (arg.value().has_scalar_function()) { + auto scalarFunction = arg.value().scalar_function(); for (const auto& param : scalarFunction.arguments()) { - if (getExprFromFunctionArgument(param).has_selection()) { - auto field = getExprFromFunctionArgument(param).selection(); + if (param.value().has_selection()) { + auto field = param.value().selection(); VELOX_CHECK(field.has_direct_reference()); int32_t colIdx = subParser_->parseReferenceSegment(field.direct_reference()); colIndices.emplace_back(colIdx); } } - } else if (getExprFromFunctionArgument(arg).has_singular_or_list()) { - auto singularOrList = getExprFromFunctionArgument(arg).singular_or_list(); + } else if (arg.value().has_singular_or_list()) { + auto singularOrList = arg.value().singular_or_list(); int32_t colIdx = getColumnIndexFromSingularOrList(singularOrList); colIndices.emplace_back(colIdx); } else { @@ -1038,17 +1031,14 @@ bool SubstraitVeloxPlanConverter::canPushdownNot( scalarFunction.arguments().size() == 1, "Only one arg is expected for Not."); auto notArg = scalarFunction.arguments()[0]; - if (!getExprFromFunctionArgument(notArg).has_scalar_function()) { + if (!notArg.value().has_scalar_function()) { // Not for a Boolean Literal or Or List is not supported curretly. // It can be pushed down with an AlwaysTrue or AlwaysFalse Range. return false; } auto argFunction = subParser_->findSubstraitFuncSpec( - functionMap_, - getExprFromFunctionArgument(notArg) - .scalar_function() - .function_reference()); + functionMap_, notArg.value().scalar_function().function_reference()); auto functionName = subParser_->getSubFunctionName(argFunction); std::unordered_set supportedNotFunctions = { @@ -1056,8 +1046,7 @@ bool SubstraitVeloxPlanConverter::canPushdownNot( uint32_t fieldIdx; bool isFieldOrWithLiteral = fieldOrWithLiteral( - getExprFromFunctionArgument(notArg).scalar_function().arguments(), - fieldIdx); + notArg.value().scalar_function().arguments(), fieldIdx); if (supportedNotFunctions.find(functionName) != supportedNotFunctions.end() && isFieldOrWithLiteral && @@ -1082,18 +1071,14 @@ bool SubstraitVeloxPlanConverter::canPushdownOr( sIsNotNull, sGte, sGt, sLte, sLt, sEqual}; for (const auto& arg : scalarFunction.arguments()) { - if (getExprFromFunctionArgument(arg).has_scalar_function()) { + if (arg.value().has_scalar_function()) { auto nameSpec = subParser_->findSubstraitFuncSpec( - functionMap_, - getExprFromFunctionArgument(arg) - .scalar_function() - .function_reference()); + functionMap_, arg.value().scalar_function().function_reference()); auto functionName = subParser_->getSubFunctionName(nameSpec); uint32_t fieldIdx; bool isFieldOrWithLiteral = fieldOrWithLiteral( - getExprFromFunctionArgument(arg).scalar_function().arguments(), - fieldIdx); + arg.value().scalar_function().arguments(), fieldIdx); if (supportedOrFunctions.find(functionName) == supportedOrFunctions.end() || !isFieldOrWithLiteral || @@ -1102,8 +1087,8 @@ bool SubstraitVeloxPlanConverter::canPushdownOr( // The arg should be field or field with literal. return false; } - } else if (getExprFromFunctionArgument(arg).has_singular_or_list()) { - auto singularOrList = getExprFromFunctionArgument(arg).singular_or_list(); + } else if (arg.value().has_singular_or_list()) { + auto singularOrList = arg.value().singular_or_list(); if (!canPushdownSingularOrList(singularOrList, true)) { return false; } @@ -1298,14 +1283,14 @@ void SubstraitVeloxPlanConverter::setFilterMap( std::optional colIdx; std::optional<::substrait::Expression_Literal> substraitLit; for (const auto& param : scalarFunction.arguments()) { - auto typeCase = getExprFromFunctionArgument(param).rex_type_case(); + auto typeCase = param.value().rex_type_case(); switch (typeCase) { case ::substrait::Expression::RexTypeCase::kSelection: colIdx = subParser_->parseReferenceSegment( - getExprFromFunctionArgument(param).selection().direct_reference()); + param.value().selection().direct_reference()); break; case ::substrait::Expression::RexTypeCase::kLiteral: - substraitLit = getExprFromFunctionArgument(param).literal(); + substraitLit = param.value().literal(); break; default: VELOX_NYI( @@ -1573,21 +1558,6 @@ void SubstraitVeloxPlanConverter::constructSubfieldFilters( std::move(colFilters), inputName, filterInfo->nullAllowed_, filters); } -void SubstraitVeloxPlanConverter::constructFunctionMap( - const ::substrait::Plan& substraitPlan) { - // Construct the function map based on the Substrait representation. - for (const auto& sExtension : substraitPlan.extensions()) { - if (!sExtension.has_extension_function()) { - continue; - } - const auto& sFmap = sExtension.extension_function(); - auto id = sFmap.function_anchor(); - auto name = sFmap.name(); - functionMap_[id] = name; - } - exprConverter_ = - std::make_shared(pool_, functionMap_); -} bool SubstraitVeloxPlanConverter::checkTypeExtension( const ::substrait::Plan& substraitPlan) { for (const auto& sExtension : substraitPlan.extensions()) { @@ -1603,49 +1573,6 @@ bool SubstraitVeloxPlanConverter::checkTypeExtension( return true; } -const std::string& SubstraitVeloxPlanConverter::findFunction( - uint64_t id) const { - return substraitParser_->findFunctionSpec(functionMap_, id); -} - -void SubstraitVeloxPlanConverter::extractJoinKeys( - const ::substrait::Expression& joinExpression, - std::vector& leftExprs, - std::vector& rightExprs) { - std::vector expressions; - expressions.push_back(&joinExpression); - while (!expressions.empty()) { - auto visited = expressions.back(); - expressions.pop_back(); - if (visited->rex_type_case() == - ::substrait::Expression::RexTypeCase::kScalarFunction) { - const auto& funcName = - subParser_->getSubFunctionName(subParser_->findVeloxFunction( - functionMap_, visited->scalar_function().function_reference())); - const auto& args = visited->scalar_function().arguments(); - if (funcName == "and") { - expressions.push_back(&getExprFromFunctionArgument(args[0])); - expressions.push_back(&getExprFromFunctionArgument(args[1])); - } else if (funcName == "eq") { - VELOX_CHECK(std::all_of( - args.cbegin(), - args.cend(), - [](const ::substrait::FunctionArgument& arg) { - return getExprFromFunctionArgument(arg).has_selection(); - })); - leftExprs.push_back(&getExprFromFunctionArgument(args[0]).selection()); - rightExprs.push_back(&getExprFromFunctionArgument(args[1]).selection()); - } else { - VELOX_NYI("Join condition {} not supported.", funcName); - } - } else { - VELOX_FAIL( - "Unable to parse from join expression: {}", - joinExpression.DebugString()); - } - } -} - connector::hive::SubfieldFilters SubstraitVeloxPlanConverter::mapToFilters( const std::vector& inputNameList, const std::vector& inputTypeList, @@ -1742,15 +1669,13 @@ bool SubstraitVeloxPlanConverter::canPushdownSingularOrList( const ::substrait::Expression_SingularOrList& singularOrList, bool disableIntLike) { VELOX_CHECK( - singularOrList.options_size() == 1, - "Only one options list is expected in SingularOrList expression."); + singularOrList.options_size() > 0, "At least one option is expected."); // Check whether the value is field. bool hasField = singularOrList.value().has_selection(); - // TODO: improve the logic here. - auto literals = singularOrList.options()[0].literal().list().values(); - std::vector types; - for (auto& literal : literals) { - auto type = literal.literal_type_case(); + auto options = singularOrList.options(); + for (const auto& option : options) { + VELOX_CHECK(option.has_literal(), "Literal is expected as option."); + auto type = option.literal().literal_type_case(); // Only BigintValues and BytesValues are supported. if (type != ::substrait::Expression_Literal::LiteralTypeCase::kI32 && type != ::substrait::Expression_Literal::LiteralTypeCase::kI64 && @@ -1773,8 +1698,10 @@ uint32_t SubstraitVeloxPlanConverter::getColumnIndexFromSingularOrList( // Get the column index. ::substrait::Expression_FieldReference selection; if (singularOrList.value().has_scalar_function()) { - selection = getExprFromFunctionArgument( - singularOrList.value().scalar_function().arguments()[0]) + selection = singularOrList.value() + .scalar_function() + .arguments()[0] + .value() .selection(); } else if (singularOrList.value().has_selection()) { selection = singularOrList.value().selection(); @@ -1787,23 +1714,19 @@ uint32_t SubstraitVeloxPlanConverter::getColumnIndexFromSingularOrList( void SubstraitVeloxPlanConverter::setSingularListValues( const ::substrait::Expression_SingularOrList& singularOrList, std::unordered_map>& colInfoMap) { + VELOX_CHECK( + singularOrList.options_size() > 0, "At least one option is expected."); // Get the column index. uint32_t colIdx = getColumnIndexFromSingularOrList(singularOrList); // Get the value list. + auto options = singularOrList.options(); std::vector variants; - VELOX_CHECK( - singularOrList.options_size() == 1, - "Options list size 1 expected in SingularOrList expression."); - auto option = singularOrList.options()[0]; - VELOX_CHECK( - option.has_literal(), - "Options list has literal expected in SingularOrList expression."); - auto valueList = option.literal().list(); - variants.reserve(valueList.values().size()); - for (const auto& literal : valueList.values()) { + variants.reserve(options.size()); + for (const auto& option : options) { + VELOX_CHECK(option.has_literal(), "Literal is expected as option."); variants.emplace_back( - exprConverter_->toTypedVariant(literal)->veloxVariant); + exprConverter_->toVeloxExpr(option.literal())->value()); } // Set the value list to filter info. colInfoMap[colIdx]->setValues(variants); diff --git a/velox/substrait/SubstraitToVeloxPlan.h b/velox/substrait/SubstraitToVeloxPlan.h index e1b50adf512b..9000e0472db5 100644 --- a/velox/substrait/SubstraitToVeloxPlan.h +++ b/velox/substrait/SubstraitToVeloxPlan.h @@ -87,20 +87,16 @@ class SubstraitVeloxPlanConverter { /// Index: the index of the partition this item belongs to. /// Starts: the start positions in byte to read from the items. /// Lengths: the lengths in byte to read from the items. - std::shared_ptr toVeloxPlan( - const ::substrait::ReadRel& sRead); + core::PlanNodePtr toVeloxPlan(const ::substrait::ReadRel& sRead); /// Used to convert Substrait Rel into Velox PlanNode. - std::shared_ptr toVeloxPlan( - const ::substrait::Rel& sRel); + core::PlanNodePtr toVeloxPlan(const ::substrait::Rel& sRel); /// Used to convert Substrait RelRoot into Velox PlanNode. - std::shared_ptr toVeloxPlan( - const ::substrait::RelRoot& sRoot); + core::PlanNodePtr toVeloxPlan(const ::substrait::RelRoot& sRoot); /// Used to convert Substrait Plan into Velox PlanNode. - std::shared_ptr toVeloxPlan( - const ::substrait::Plan& substraitPlan); + core::PlanNodePtr toVeloxPlan(const ::substrait::Plan& substraitPlan); /// Used to construct the function map between the index /// and the Substrait function name. Initialize the expression diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.cpp b/velox/substrait/SubstraitToVeloxPlanValidator.cpp index 6f4ae0bb92d3..3bd768f0a7c6 100644 --- a/velox/substrait/SubstraitToVeloxPlanValidator.cpp +++ b/velox/substrait/SubstraitToVeloxPlanValidator.cpp @@ -265,7 +265,7 @@ bool SubstraitToVeloxPlanValidator::validate( planConverter_->findFuncSpec(aggFunction.function_reference())); toVeloxType(subParser_->parseType(aggFunction.output_type())->type); for (const auto& arg : aggFunction.arguments()) { - auto typeCase = getExprFromFunctionArgument(arg).rex_type_case(); + auto typeCase = arg.value().rex_type_case(); switch (typeCase) { case ::substrait::Expression::RexTypeCase::kSelection: case ::substrait::Expression::RexTypeCase::kLiteral: diff --git a/velox/substrait/TypeUtils.cpp b/velox/substrait/TypeUtils.cpp index 0b86d6a50e9c..a7517140037e 100644 --- a/velox/substrait/TypeUtils.cpp +++ b/velox/substrait/TypeUtils.cpp @@ -123,12 +123,5 @@ TypePtr toVeloxType(const std::string& typeName) { VELOX_NYI("Velox type conversion not supported for type {}.", typeName); } } -const ::substrait::Expression& getExprFromFunctionArgument( - const ::substrait::FunctionArgument& arg) { - if (arg.has_value()) { - return arg.value(); - } else { - VELOX_NYI("FunctionArgument arg must has value."); - } -} + } // namespace facebook::velox::substrait diff --git a/velox/substrait/tests/FunctionTest.cpp b/velox/substrait/tests/FunctionTest.cpp index cd9856ee3f77..a8f673c354fe 100644 --- a/velox/substrait/tests/FunctionTest.cpp +++ b/velox/substrait/tests/FunctionTest.cpp @@ -33,10 +33,6 @@ namespace vestrait = facebook::velox::substrait; class FunctionTest : public ::testing::Test { protected: std::shared_ptr queryCtx_ = core::QueryCtx::createForTest(); - - std::unique_ptr pool_ = - memory::getDefaultScopedMemoryPool(); - std::shared_ptr substraitParser_ = std::make_shared(); @@ -94,7 +90,7 @@ TEST_F(FunctionTest, constructFunctionMap) { auto functionMap = planConverter_->getFunctionMap(); ASSERT_EQ(functionMap.size(), 9); - std::string function = planConverter_->findFunction(1); + std::string function = planConverter_->findFuncSpec(1); ASSERT_EQ(function, "lte:fp64_fp64"); function = planConverter_->findFuncSpec(2); @@ -112,13 +108,13 @@ TEST_F(FunctionTest, constructFunctionMap) { function = planConverter_->findFuncSpec(6); ASSERT_EQ(function, "sum:opt_fp64"); - function = planConverter_->findFunction(7); + function = planConverter_->findFuncSpec(7); ASSERT_EQ(function, "count:opt_fp64"); function = planConverter_->findFuncSpec(8); ASSERT_EQ(function, "count:opt_i32"); - function = planConverter_->findFunction(9); + function = planConverter_->findFuncSpec(9); ASSERT_EQ(function, "is_not_null:fp64"); } diff --git a/velox/substrait/tests/Substrait2VeloxPlanConversionTest.cpp b/velox/substrait/tests/Substrait2VeloxPlanConversionTest.cpp index e043ccd25ee4..69fab2909a36 100644 --- a/velox/substrait/tests/Substrait2VeloxPlanConversionTest.cpp +++ b/velox/substrait/tests/Substrait2VeloxPlanConversionTest.cpp @@ -29,6 +29,7 @@ using namespace facebook::velox; using namespace facebook::velox::test; using namespace facebook::velox::connector::hive; using namespace facebook::velox::exec; +namespace vestrait = facebook::velox::substrait; class Substrait2VeloxPlanConversionTest : public exec::test::HiveConnectorTestBase { @@ -68,6 +69,13 @@ class Substrait2VeloxPlanConversionTest std::shared_ptr tmpDir_{ exec::test::TempDirectoryPath::create()}; + std::shared_ptr planConverter_ = + std::make_shared( + memoryPool_.get()); + + private: + std::unique_ptr memoryPool_{ + memory::getDefaultScopedMemoryPool()}; }; // This test will firstly generate mock TPC-H lineitem ORC file. Then, Velox's @@ -272,7 +280,7 @@ TEST_F(Substrait2VeloxPlanConversionTest, q6) { // Read q6_first_stage.json and resume the Substrait plan. ::substrait::Plan substraitPlan; - JsonToProtoConverter::readFromFile(planPath, substraitPlan); + JsonToProtoConverter::readFromFile(subPlanPath, substraitPlan); // Convert to Velox PlanNode. facebook::velox::substrait::SubstraitVeloxPlanConverter planConverter( @@ -284,6 +292,6 @@ TEST_F(Substrait2VeloxPlanConversionTest, q6) { }); exec::test::AssertQueryBuilder(planNode) - .splits(makeSplits(planConverter, planNode)) + .splits(makeSplits(*planConverter_, planNode)) .assertResults(expectedResult); } \ No newline at end of file diff --git a/velox/substrait/tests/VeloxSubstraitRoundTripTest.cpp b/velox/substrait/tests/VeloxSubstraitRoundTripTest.cpp index 902ea57c1465..e1ccdbe4221b 100644 --- a/velox/substrait/tests/VeloxSubstraitRoundTripTest.cpp +++ b/velox/substrait/tests/VeloxSubstraitRoundTripTest.cpp @@ -69,7 +69,6 @@ class VeloxSubstraitRoundTripTest : public OperatorTestBase { // Assert velox again. assertQuery(samePlan, duckDbSql); } - std::shared_ptr veloxConvertor_ = std::make_shared(); std::shared_ptr substraitConverter_ =