diff --git a/src/common/function/FunctionManager.cpp b/src/common/function/FunctionManager.cpp index 686891529b2..2a38b6736f3 100644 --- a/src/common/function/FunctionManager.cpp +++ b/src/common/function/FunctionManager.cpp @@ -55,7 +55,9 @@ std::unordered_map> FunctionManager::typ TypeSignature({Value::Type::FLOAT}, Value::Type::FLOAT)}}, {"round", {TypeSignature({Value::Type::INT}, Value::Type::FLOAT), - TypeSignature({Value::Type::FLOAT}, Value::Type::FLOAT)}}, + TypeSignature({Value::Type::INT, Value::Type::INT}, Value::Type::FLOAT), + TypeSignature({Value::Type::FLOAT}, Value::Type::FLOAT), + TypeSignature({Value::Type::FLOAT, Value::Type::INT}, Value::Type::FLOAT)}}, {"sqrt", {TypeSignature({Value::Type::INT}, Value::Type::FLOAT), TypeSignature({Value::Type::FLOAT}, Value::Type::FLOAT)}}, @@ -539,17 +541,23 @@ FunctionManager::FunctionManager() { // to nearest integral (as a floating-point value) auto &attr = functions_["round"]; attr.minArity_ = 1; - attr.maxArity_ = 1; + attr.maxArity_ = 2; attr.isPure_ = true; attr.body_ = [](const auto &args) -> Value { switch (args[0].get().type()) { case Value::Type::NULLVALUE: { return Value::kNullValue; } - case Value::Type::INT: { - return std::round(args[0].get().getInt()); - } + case Value::Type::INT: case Value::Type::FLOAT: { + if (args.size() == 2) { + if (args[1].get().type() == Value::Type::INT) { + auto decimal = args[1].get().getInt(); + return std::round(args[0].get().getFloat() * pow(10, decimal)) / pow(10, decimal); + } else { + return Value::kNullBadType; + } + } return std::round(args[0].get().getFloat()); } default: { diff --git a/src/common/function/test/FunctionManagerTest.cpp b/src/common/function/test/FunctionManagerTest.cpp index e0383f1106d..9305c788e59 100644 --- a/src/common/function/test/FunctionManagerTest.cpp +++ b/src/common/function/test/FunctionManagerTest.cpp @@ -103,6 +103,9 @@ std::unordered_map> FunctionManagerTest::args_ = {"one", {-1.2}}, {"two", {2, 4}}, {"pow", {2, 3}}, + {"round1", {11111.11111, 2}}, + {"round2", {11111.11111, -1}}, + {"round3", {11111.11111, -5}}, {"radians", {180}}, {"range1", {1, 5}}, {"range2", {1, 5, 2}}, @@ -268,6 +271,11 @@ TEST_F(FunctionManagerTest, functionCall) { TEST_FUNCTION(log, args_["int"], std::log(4)); TEST_FUNCTION(log2, args_["int"], 2.0); } + { + TEST_FUNCTION(round, args_["round1"], 11111.11); + TEST_FUNCTION(round, args_["round2"], 11110.0); + TEST_FUNCTION(round, args_["round3"], 0.0); + } { TEST_FUNCTION(range, args_["range1"], Value(List({1, 2, 3, 4, 5}))); TEST_FUNCTION(range, args_["range2"], Value(List({1, 3, 5}))); @@ -916,11 +924,21 @@ TEST_F(FunctionManagerTest, returnType) { ASSERT_TRUE(result.ok()); EXPECT_EQ(result.value(), Value::Type::FLOAT); } + { + auto result = FunctionManager::getReturnType("round", {Value::Type::INT, Value::Type::INT}); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(result.value(), Value::Type::FLOAT); + } { auto result = FunctionManager::getReturnType("round", {Value::Type::FLOAT}); ASSERT_TRUE(result.ok()); EXPECT_EQ(result.value(), Value::Type::FLOAT); } + { + auto result = FunctionManager::getReturnType("round", {Value::Type::FLOAT, Value::Type::INT}); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(result.value(), Value::Type::FLOAT); + } { auto result = FunctionManager::getReturnType("cbrt", {Value::Type::INT}); ASSERT_TRUE(result.ok()); diff --git a/tests/tck/features/expression/FunctionCall.feature b/tests/tck/features/expression/FunctionCall.feature index cd71235419e..5c5a6a00a46 100644 --- a/tests/tck/features/expression/FunctionCall.feature +++ b/tests/tck/features/expression/FunctionCall.feature @@ -121,6 +121,50 @@ Feature: Function Call Expression | result | | NULL | + Scenario: round + When executing query: + """ + YIELD round(3.1415926, 9) as result + """ + Then the result should be, in any order: + | result | + | 3.1415926 | + When executing query: + """ + YIELD round(3.1415926, 2) as result + """ + Then the result should be, in any order: + | result | + | 3.14 | + When executing query: + """ + YIELD round(3.1415926, 3) as result + """ + Then the result should be, in any order: + | result | + | 3.142 | + When executing query: + """ + YIELD round(3.14159265359, 0) as result + """ + Then the result should be, in any order: + | result | + | 3.0 | + When executing query: + """ + YIELD round(35543.14159265359, -3) as result + """ + Then the result should be, in any order: + | result | + | 36000.0 | + When executing query: + """ + YIELD round(35543.14159265359, -5) as result + """ + Then the result should be, in any order: + | result | + | 0.0 | + Scenario: error check When executing query: """