diff --git a/src/parser/CMakeLists.txt b/src/parser/CMakeLists.txt
index 478f2f82144..3f727575c95 100644
--- a/src/parser/CMakeLists.txt
+++ b/src/parser/CMakeLists.txt
@@ -14,6 +14,7 @@ add_library(
     ${FLEX_Scanner_OUTPUTS}
     ${BISON_Parser_OUTPUTS}
     Expressions.cpp
+    FunctionManager.cpp
     Clauses.cpp
     SequentialSentences.cpp
     MaintainSentences.cpp
diff --git a/src/parser/Expressions.cpp b/src/parser/Expressions.cpp
index e229afd9343..5ebdb25c75b 100644
--- a/src/parser/Expressions.cpp
+++ b/src/parser/Expressions.cpp
@@ -7,6 +7,7 @@
 #include "base/Base.h"
 #include "base/Cord.h"
 #include "parser/Expressions.h"
+#include "parser/FunctionManager.h"
 
 
 #define THROW_IF_NO_SPACE(POS, END, REQUIRE)                                        \
@@ -71,6 +72,8 @@ std::unique_ptr<Expression> Expression::makeExpr(uint8_t kind) {
     switch (intToKind(kind)) {
         case kPrimary:
             return std::make_unique<PrimaryExpression>();
+        case kFunctionCall:
+            return std::make_unique<FunctionCallExpression>();
         case kUnary:
             return std::make_unique<UnaryExpression>();
         case kTypeCasting:
@@ -643,6 +646,100 @@ const char* PrimaryExpression::decode(const char *pos, const char *end) {
 }
 
 
+// Renders the call as `name(arg1,arg2,...)'.
+std::string FunctionCallExpression::toString() const {
+    std::string buf;
+    buf.reserve(256);
+    buf += *name_;
+    buf += "(";
+    for (auto &arg : args_) {
+        buf += arg->toString();
+        buf += ",";
+    }
+    if (!args_.empty()) {
+        // Drop the trailing comma
+        buf.resize(buf.size() - 1);
+    }
+    buf += ")";
+    return buf;
+}
+
+
+// Evaluates every argument in order and feeds the values to the function
+// that prepare() resolved.
+VariantType FunctionCallExpression::eval() const {
+    std::vector<VariantType> args;
+    args.resize(args_.size());
+    auto eval = [] (auto &expr) {
+        return expr->eval();
+    };
+    std::transform(args_.begin(), args_.end(), args.begin(), eval);
+    return function_(args);
+}
+
+
+// Looks the function up by name and arity, then prepares each argument,
+// stopping at the first failure.
+Status FunctionCallExpression::prepare() {
+    auto result = FunctionManager::get(*name_, args_.size());
+    if (!result.ok()) {
+        return std::move(result).status();
+    }
+
+    function_ = std::move(result).value();
+
+    auto status = Status::OK();
+    for (auto &arg : args_) {
+        status = arg->prepare();
+        if (!status.ok()) {
+            break;
+        }
+    }
+    return status;
+}
+
+
+// Layout: kind | name size (2 bytes) | name | argument count (2 bytes) | arguments
+void FunctionCallExpression::encode(Cord &cord) const {
+    cord << kindToInt(kind());
+
+    cord << static_cast<uint16_t>(name_->size());
+    cord << *name_;
+
+    cord << static_cast<uint16_t>(args_.size());
+    for (auto &arg : args_) {
+        arg->encode(cord);
+    }
+}
+
+
+// Inverse of encode(). `pos' points just past the kind byte, which the caller
+// has already consumed (see the argument loop below).
+const char* FunctionCallExpression::decode(const char *pos, const char *end) {
+    THROW_IF_NO_SPACE(pos, end, 2UL);
+    auto size = *reinterpret_cast<const uint16_t*>(pos);
+    pos += 2;
+
+    THROW_IF_NO_SPACE(pos, end, size);
+    name_ = std::make_unique<std::string>(pos, size);
+    pos += size;
+
+    // The argument count needs its own bounds check before being read
+    THROW_IF_NO_SPACE(pos, end, 2UL);
+    auto count = *reinterpret_cast<const uint16_t*>(pos);
+    pos += 2;
+
+    args_.reserve(count);
+    for (auto i = 0u; i < count; i++) {
+        THROW_IF_NO_SPACE(pos, end, 1UL);
+        auto arg = makeExpr(*reinterpret_cast<const uint8_t*>(pos++));
+        pos = arg->decode(pos, end);
+        args_.emplace_back(std::move(arg));
+    }
+    return pos;
+}
+
+
 std::string UnaryExpression::toString() const {
     std::string buf;
     buf.reserve(256);
diff --git a/src/parser/Expressions.h b/src/parser/Expressions.h
index eafcc117f7e..05efa0df79c 100644
--- a/src/parser/Expressions.h
+++ b/src/parser/Expressions.h
@@ -212,6 +212,7 @@ class Expression {
         kUnknown = 0,
 
         kPrimary,
+        kFunctionCall,
         kUnary,
         kTypeCasting,
         kArithmetic,
@@ -251,6 +252,7 @@ class Expression {
     // to allow them to call private encode/decode on each other.
     friend class PrimaryExpression;
     friend class UnaryExpression;
+    friend class FunctionCallExpression;
     friend class TypeCastingExpression;
     friend class ArithmeticExpression;
     friend class RelationalExpression;
@@ -603,6 +605,66 @@ class PrimaryExpression final : public Expression {
 };
 
 
+// Collects the arguments of a function call while the call is being parsed.
+class ArgumentList final {
+public:
+    void addArgument(Expression *arg) {
+        args_.emplace_back(arg);
+    }
+
+    // Moves the collected arguments out of this list.
+    auto args() {
+        return std::move(args_);
+    }
+
+private:
+    std::vector<std::unique_ptr<Expression>>    args_;
+};
+
+
+// name(arg1, arg2, ...)
+class FunctionCallExpression final : public Expression {
+public:
+    FunctionCallExpression() {
+        kind_ = kFunctionCall;
+    }
+
+    // Takes ownership of `name' and `args'.
+    // `args' may be null, which means a call without arguments.
+    FunctionCallExpression(std::string *name, ArgumentList *args) {
+        kind_ = kFunctionCall;
+        name_.reset(name);
+        if (args != nullptr) {
+            args_ = args->args();
+            delete args;
+        }
+    }
+
+    std::string toString() const override;
+
+    VariantType eval() const override;
+
+    Status MUST_USE_RESULT prepare() override;
+
+    void setContext(ExpressionContext *ctx) {
+        context_ = ctx;
+        for (auto &arg : args_) {
+            arg->setContext(ctx);
+        }
+    }
+
+private:
+    void encode(Cord &cord) const override;
+
+    const char* decode(const char *pos, const char *end) override;
+
+private:
+    std::unique_ptr<std::string>                name_;
+    std::vector<std::unique_ptr<Expression>>    args_;
+    std::function<VariantType(const std::vector<VariantType>&)> function_;
+};
+
+
 // +expr, -expr, !expr
 class UnaryExpression final : public Expression {
 public:
diff --git a/src/parser/FunctionManager.cpp b/src/parser/FunctionManager.cpp
new file mode 100644
index 00000000000..9e6a209777d
--- /dev/null
+++ b/src/parser/FunctionManager.cpp
@@ -0,0 +1,305 @@
+/* Copyright (c) 2018 - present, VE Software Inc. All rights reserved
+ *
+ * This source code is licensed under Apache 2.0 License
+ *  (found in the LICENSE.Apache file in the root directory)
+ */
+
+#include "base/Base.h"
+#include "parser/FunctionManager.h"
+#include "time/TimeUtils.h"
+
+namespace nebula {
+
+// static
+FunctionManager& FunctionManager::instance() {
+    static FunctionManager instance;
+    return instance;
+}
+
+
+// Registers every builtin together with the arity range it accepts.
+FunctionManager::FunctionManager() {
+    {
+        // absolute value
+        auto &attr = functions_["abs"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::abs(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // to nearest integer value not greater than x
+        auto &attr = functions_["floor"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::floor(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // to nearest integer value not less than x
+        auto &attr = functions_["ceil"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::ceil(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // to nearest integer value
+        auto &attr = functions_["round"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::round(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // square root
+        auto &attr = functions_["sqrt"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::sqrt(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // cubic root
+        auto &attr = functions_["cbrt"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::cbrt(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // sqrt(x^2 + y^2)
+        auto &attr = functions_["hypot"];
+        attr.minArity_ = 2;
+        attr.maxArity_ = 2;
+        attr.body_ = [] (const auto &args) {
+            auto x = Expression::asDouble(args[0]);
+            auto y = Expression::asDouble(args[1]);
+            return std::hypot(x, y);
+        };
+    }
+    {
+        // base^exp
+        auto &attr = functions_["pow"];
+        attr.minArity_ = 2;
+        attr.maxArity_ = 2;
+        attr.body_ = [] (const auto &args) {
+            auto base = Expression::asDouble(args[0]);
+            auto exp = Expression::asDouble(args[1]);
+            return std::pow(base, exp);
+        };
+    }
+    {
+        // e^x
+        auto &attr = functions_["exp"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::exp(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // 2^x
+        auto &attr = functions_["exp2"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::exp2(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // e-based logarithm
+        auto &attr = functions_["log"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::log(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // 2-based logarithm
+        auto &attr = functions_["log2"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::log2(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // 10-based logarithm
+        auto &attr = functions_["log10"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::log10(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        auto &attr = functions_["sin"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::sin(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        auto &attr = functions_["asin"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::asin(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        auto &attr = functions_["cos"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::cos(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        auto &attr = functions_["acos"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::acos(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        auto &attr = functions_["tan"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::tan(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        auto &attr = functions_["atan"];
+        attr.minArity_ = 1;
+        attr.maxArity_ = 1;
+        attr.body_ = [] (const auto &args) {
+            return std::atan(Expression::asDouble(args[0]));
+        };
+    }
+    {
+        // rand32(), rand32(max), rand32(min, max)
+        auto &attr = functions_["rand32"];
+        attr.minArity_ = 0;
+        attr.maxArity_ = 2;
+        attr.body_ = [] (const auto &args) {
+            if (args.empty()) {
+                return static_cast<int64_t>(folly::Random::rand32());
+            } else if (args.size() == 1UL) {
+                auto max = Expression::asInt(args[0]);
+                return static_cast<int64_t>(folly::Random::rand32(max));
+            }
+            DCHECK_EQ(2UL, args.size());
+            auto min = Expression::asInt(args[0]);
+            auto max = Expression::asInt(args[1]);
+            return static_cast<int64_t>(folly::Random::rand32(min, max));
+        };
+    }
+    {
+        // rand64(), rand64(max), rand64(min, max)
+        auto &attr = functions_["rand64"];
+        attr.minArity_ = 0;
+        attr.maxArity_ = 2;
+        attr.body_ = [] (const auto &args) {
+            if (args.empty()) {
+                return static_cast<int64_t>(folly::Random::rand64());
+            } else if (args.size() == 1UL) {
+                auto max = Expression::asInt(args[0]);
+                return static_cast<int64_t>(folly::Random::rand64(max));
+            }
+            DCHECK_EQ(2UL, args.size());
+            auto min = Expression::asInt(args[0]);
+            auto max = Expression::asInt(args[1]);
+            return static_cast<int64_t>(folly::Random::rand64(min, max));
+        };
+    }
+    {
+        // unix timestamp
+        auto &attr = functions_["now"];
+        attr.minArity_ = 0;
+        attr.maxArity_ = 0;
+        attr.body_ = [] (const auto &args) {
+            UNUSED(args);
+            return time::TimeUtils::nowInSeconds();
+        };
+    }
+    {
+        auto &attr = functions_["strcasecmp"];
+        attr.minArity_ = 2;
+        attr.maxArity_ = 2;
+        attr.body_ = [] (const auto &args) {
+            auto &left = Expression::asString(args[0]);
+            auto &right = Expression::asString(args[1]);
+            return static_cast<int64_t>(::strcasecmp(left.c_str(), right.c_str()));
+        };
+    }
+}
+
+
+// static
+StatusOr<FunctionManager::Function>
+FunctionManager::get(const std::string &func, size_t arity) {
+    return instance().getInternal(func, arity);
+}
+
+
+StatusOr<FunctionManager::Function>
+FunctionManager::getInternal(const std::string &func, size_t arity) const {
+    folly::RWSpinLock::ReadHolder holder(lock_);
+    // check existence
+    auto iter = functions_.find(func);
+    if (iter == functions_.end()) {
+        return Status::Error("Function `%s' not defined", func.c_str());
+    }
+    // check arity
+    auto minArity = iter->second.minArity_;
+    auto maxArity = iter->second.maxArity_;
+    if (arity < minArity || arity > maxArity) {
+        return Status::Error("Arity not match for function `%s': %lu vs. [%lu-%lu]",
+                              func.c_str(), arity, minArity, maxArity);
+    }
+    return iter->second.body_;
+}
+
+
+// static
+Status FunctionManager::load(const std::string &name,
+                             const std::vector<std::string> &funcs) {
+    return instance().loadInternal(name, funcs);
+}
+
+
+Status FunctionManager::loadInternal(const std::string &name,
+                                     const std::vector<std::string> &funcs) {
+    UNUSED(name);
+    UNUSED(funcs);
+    return Status::Error("Dynamic function loading not supported yet");
+}
+
+
+// static
+Status FunctionManager::unload(const std::string &name,
+                               const std::vector<std::string> &funcs) {
+    return instance().unloadInternal(name, funcs);
+}
+
+
+Status FunctionManager::unloadInternal(const std::string &name,
+                                       const std::vector<std::string> &funcs) {
+    UNUSED(name);
+    UNUSED(funcs);
+    return Status::Error("Dynamic function unloading not supported yet");
+}
+
+}   // namespace nebula
diff --git a/src/parser/FunctionManager.h b/src/parser/FunctionManager.h
new file mode 100644
index 00000000000..fced623b1dc
--- /dev/null
+++ b/src/parser/FunctionManager.h
@@ -0,0 +1,69 @@
+/* Copyright (c) 2018 - present, VE Software Inc. All rights reserved
+ *
+ * This source code is licensed under Apache 2.0 License
+ *  (found in the LICENSE.Apache file in the root directory)
+ */
+
+#ifndef PARSER_FUNCTIONMANAGER_H_
+#define PARSER_FUNCTIONMANAGER_H_
+
+#include "base/Base.h"
+#include "base/StatusOr.h"
+#include "base/Status.h"
+#include "parser/Expressions.h"
+
+/**
+ * FunctionManager is for managing builtin and dynamic-loaded functions,
+ * which users could use as function call expressions.
+ *
+ * TODO(dutor) To implement dynamic loading.
+ */
+
+namespace nebula {
+
+class FunctionManager final {
+public:
+    using Function = std::function<VariantType(const std::vector<VariantType>&)>;
+
+    /**
+     * To obtain a function named `func', with the actual arity.
+     */
+    static StatusOr<Function> get(const std::string &func, size_t arity);
+
+    /**
+     * To load a set of functions from a shared object dynamically.
+     */
+    static Status load(const std::string &soname, const std::vector<std::string> &funcs);
+
+    /**
+     * To unload a shared object.
+     */
+    static Status unload(const std::string &soname, const std::vector<std::string> &funcs);
+
+private:
+    /**
+     * FunctionManager functions as a singleton, since the dynamic loading is process-wide.
+     */
+    FunctionManager();
+
+    static FunctionManager& instance();
+
+    StatusOr<Function> getInternal(const std::string &func, size_t arity) const;
+
+    Status loadInternal(const std::string &soname, const std::vector<std::string> &funcs);
+
+    Status unloadInternal(const std::string &soname, const std::vector<std::string> &funcs);
+
+    struct FunctionAttributes final {
+        size_t                                  minArity_{0};
+        size_t                                  maxArity_{0};
+        Function                                body_;
+    };
+
+    mutable folly::RWSpinLock                   lock_;
+    std::unordered_map<std::string, FunctionAttributes> functions_;
+};
+
+}   // namespace nebula
+
+#endif  // PARSER_FUNCTIONMANAGER_H_
diff --git a/src/parser/parser.yy b/src/parser/parser.yy
index 86522762ecb..c43269f64b7 100644
--- a/src/parser/parser.yy
+++ b/src/parser/parser.yy
@@ -57,6 +57,7 @@ class GraphScanner;
     nebula::UpdateList                     *update_list;
     nebula::UpdateItem                     *update_item;
     nebula::EdgeList                       *edge_list;
+    nebula::ArgumentList                   *argument_list;
 }
 
 /* destructors */
@@ -91,6 +92,8 @@ class GraphScanner;
 %type <expr> input_ref_expression
 %type <expr> var_ref_expression
 %type <expr> alias_ref_expression
+%type <expr> function_call_expression
+%type <argument_list> argument_list opt_argument_list
 %type <type> type_spec
 %type <step_clause> step_clause
 %type <from_clause> from_clause
@@ -166,6 +169,9 @@ primary_expression
     | L_PAREN expression R_PAREN {
         $$ = $2;
     }
+    | function_call_expression {
+        $$ = $1;
+    }
     ;
 
 input_ref_expression
@@ -210,6 +216,33 @@ alias_ref_expression
     }
     ;
 
+function_call_expression
+    : LABEL L_PAREN opt_argument_list R_PAREN {
+        $$ = new FunctionCallExpression($1, $3);
+    }
+    ;
+
+/* Only the whole list may be empty, so that `f(, x)' is a syntax error. */
+opt_argument_list
+    : %empty {
+        $$ = nullptr;
+    }
+    | argument_list {
+        $$ = $1;
+    }
+    ;
+
+argument_list
+    : expression {
+        $$ = new ArgumentList();
+        $$->addArgument($1);
+    }
+    | argument_list COMMA expression {
+        $$ = $1;
+        $$->addArgument($3);
+    }
+    ;
+
 unary_expression
     : primary_expression { $$ = $1; }
     | ADD primary_expression {
diff --git a/src/parser/scanner.lex b/src/parser/scanner.lex
index b6da72a7856..0076eaaa7ad 100644
--- a/src/parser/scanner.lex
+++ b/src/parser/scanner.lex
@@ -184,8 +184,8 @@ OCT                         ([0-7])
                                 return TokenType::INTEGER;
                             }
 [+-]?{DEC}+                 { yylval->intval = ::atoll(yytext); return TokenType::INTEGER; }
-{DEC}+\.{DEC}*              { yylval->doubleval = ::atof(yytext); return TokenType::DOUBLE; }
-{DEC}*\.{DEC}+              { yylval->doubleval = ::atof(yytext); return TokenType::DOUBLE; }
+[+-]?{DEC}+\.{DEC}*         { yylval->doubleval = ::atof(yytext); return TokenType::DOUBLE; }
+[+-]?{DEC}*\.{DEC}+         { yylval->doubleval = ::atof(yytext); return TokenType::DOUBLE; }
 
 \${LABEL}                   { yylval->strval = new std::string(yytext + 1, yyleng - 1); return TokenType::VARIABLE; }
 
diff --git a/src/parser/test/ExpressionTest.cpp b/src/parser/test/ExpressionTest.cpp
index 44e3978fc0b..a2dca8ac2f5 100644
--- a/src/parser/test/ExpressionTest.cpp
+++ b/src/parser/test/ExpressionTest.cpp
@@ -8,6 +8,7 @@
 #include <gtest/gtest.h>
 #include "parser/GQLParser.h"
 #include "parser/SequentialSentences.h"
+#include "parser/FunctionManager.h"
 
 namespace nebula {
 
@@ -493,4 +494,82 @@ TEST_F(ExpressionTest, EdgeReference) {
     }
 }
 
+
+TEST_F(ExpressionTest, FunctionCall) {
+    GQLParser parser;
+#define TEST_EXPR(expected, op, expr_arg, type)                         \
+    do {                                                                \
+        std::string query = "GO FROM 1 OVER follow WHERE " #expr_arg;   \
+        auto parsed = parser.parse(query);                              \
+        ASSERT_TRUE(parsed.ok()) << parsed.status();                    \
+        auto *expr = getFilterExpr(parsed.value().get());               \
+        ASSERT_NE(nullptr, expr);                                       \
+        auto decoded = Expression::decode(Expression::encode(expr));    \
+        ASSERT_TRUE(decoded.ok()) << decoded.status();                  \
+        auto ctx = std::make_unique<ExpressionContext>();               \
+        decoded.value()->setContext(ctx.get());                         \
+        auto status = decoded.value()->prepare();                       \
+        ASSERT_TRUE(status.ok()) << status;                             \
+        auto value = decoded.value()->eval();                           \
+        ASSERT_TRUE(Expression::is##type(value));                       \
+        if (#type == std::string("Double")) {                           \
+            if (#op != std::string("EQ")) {                             \
+                ASSERT_##op(expected, Expression::as##type(value));     \
+            } else {                                                    \
+                ASSERT_DOUBLE_EQ(expected, Expression::as##type(value));\
+            }                                                           \
+        } else {                                                        \
+            ASSERT_##op(expected, Expression::as##type(value));         \
+        }                                                               \
+    } while (false)
+
+    TEST_EXPR(5.0, EQ, abs(5), Double);
+    TEST_EXPR(5.0, EQ, abs(-5), Double);
+
+    TEST_EXPR(3.0, EQ, floor(3.14), Double);
+    TEST_EXPR(-4.0, EQ, floor(-3.14), Double);
+
+    TEST_EXPR(4.0, EQ, ceil(3.14), Double);
+    TEST_EXPR(-3.0, EQ, ceil(-3.14), Double);
+
+    TEST_EXPR(3.0, EQ, round(3.14), Double);
+    TEST_EXPR(-3.0, EQ, round(-3.14), Double);
+    TEST_EXPR(4.0, EQ, round(3.5), Double);
+    TEST_EXPR(-4.0, EQ, round(-3.5), Double);
+
+    TEST_EXPR(3.0, EQ, cbrt(27), Double);
+
+    constexpr auto euler = 2.7182818284590451;
+    TEST_EXPR(euler, EQ, exp(1), Double);
+    TEST_EXPR(1024, EQ, exp2(10), Double);
+
+    TEST_EXPR(2, EQ, log(pow(2.7182818284590451, 2)), Double);
+    TEST_EXPR(10, EQ, log2(1024), Double);
+    TEST_EXPR(3, EQ, log10(1000), Double);
+
+    TEST_EXPR(5.0, EQ, hypot(3, 4), Double);
+    TEST_EXPR(5.0, EQ, sqrt(pow(3, 2) + pow(4, 2)), Double);
+
+    TEST_EXPR(1.0, EQ, hypot(sin(0.5), cos(0.5)), Double);
+    TEST_EXPR(1.0, EQ, (sin(0.5) / cos(0.5)) / tan(0.5), Double);
+
+    TEST_EXPR(0.3, EQ, sin(asin(0.3)), Double);
+    TEST_EXPR(0.3, EQ, cos(acos(0.3)), Double);
+    TEST_EXPR(0.3, EQ, tan(atan(0.3)), Double);
+
+    TEST_EXPR(1024, GT, rand32(1024), Int);
+    TEST_EXPR(0, LE, rand32(1024), Int);
+    TEST_EXPR(4096, GT, rand64(1024, 4096), Int);
+    TEST_EXPR(1024, LE, rand64(1024, 4096), Int);
+
+    TEST_EXPR(1554716753, LT, now(), Int);
+    TEST_EXPR(4773548753, GE, now(), Int);    // will start failing in about 102 years
+
+    TEST_EXPR(0, EQ, strcasecmp("HelLo", "hello"), Int);
+    TEST_EXPR(0, LT, strcasecmp("HelLo", "hell"), Int);
+    TEST_EXPR(0, GT, strcasecmp("HelLo", "World"), Int);
+
+#undef TEST_EXPR
+}
+
 }   // namespace nebula
diff --git a/src/parser/test/ScannerTest.cpp b/src/parser/test/ScannerTest.cpp
index a8c31785750..6b50fb94415 100644
--- a/src/parser/test/ScannerTest.cpp
+++ b/src/parser/test/ScannerTest.cpp
@@ -245,6 +245,8 @@ TEST(Scanner, Basic) {
         CHECK_SEMANTIC_VALUE("123.", TokenType::DOUBLE, 123.),
         CHECK_SEMANTIC_VALUE(".123", TokenType::DOUBLE, 0.123),
         CHECK_SEMANTIC_VALUE("123.456", TokenType::DOUBLE, 123.456),
+        CHECK_SEMANTIC_VALUE("+123.456", TokenType::DOUBLE, 123.456),
+        CHECK_SEMANTIC_VALUE("-123.456", TokenType::DOUBLE, -123.456),
 
         CHECK_SEMANTIC_VALUE("\"Hello\"", TokenType::STRING, "Hello"),
         CHECK_SEMANTIC_VALUE("\"Hello\\\\\"", TokenType::STRING, "Hello\\"),