diff --git a/velox/docs/functions/spark/json.rst b/velox/docs/functions/spark/json.rst index 07f4f3a75ace2..c2708b938dca1 100644 --- a/velox/docs/functions/spark/json.rst +++ b/velox/docs/functions/spark/json.rst @@ -22,6 +22,8 @@ JSON Functions .. spark:function:: get_json_object(json, path) -> varchar - Extracts a json object from path:: + Extracts a json object from ``path``. Returns NULL if it finds json string + is malformed. :: - SELECT get_json_object('{"a":"b"}', '$.a'); -- b \ No newline at end of file + SELECT get_json_object('{"a":"b"}', '$.a'); -- 'b' + SELECT get_json_object('{"a":{"b":"c"}}', '$.a'); -- '{"b":"c"}' \ No newline at end of file diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt index 0522cbfefab5f..352f7498fcaca 100644 --- a/velox/functions/sparksql/CMakeLists.txt +++ b/velox/functions/sparksql/CMakeLists.txt @@ -48,7 +48,8 @@ target_link_libraries( velox_functions_spark_specialforms velox_is_null_functions velox_functions_util - Folly::folly) + Folly::folly + simdjson) set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE high_memory_pool) diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 919d5bc0a51fe..62f8fc5210c8d 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -22,7 +22,6 @@ #include "velox/functions/lib/Re2Functions.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/prestosql/DateTimeFunctions.h" -#include "velox/functions/prestosql/JsonFunctions.h" #include "velox/functions/prestosql/StringFunctions.h" #include "velox/functions/sparksql/ArrayMinMaxFunction.h" #include "velox/functions/sparksql/ArraySort.h" @@ -35,6 +34,7 @@ #include "velox/functions/sparksql/RegexFunctions.h" #include "velox/functions/sparksql/RegisterArithmetic.h" #include "velox/functions/sparksql/RegisterCompare.h" +#include "velox/functions/sparksql/SIMDJsonFunctions.h" #include "velox/functions/sparksql/Size.h" #include "velox/functions/sparksql/String.h" #include "velox/functions/sparksql/StringToMap.h" @@ -124,7 +124,7 @@ void registerFunctions(const std::string& prefix) { // Register size functions registerSize(prefix + "size"); - registerFunction( + registerFunction( {prefix + "get_json_object"}); // Register string functions. diff --git a/velox/functions/sparksql/SIMDJsonFunctions.h b/velox/functions/sparksql/SIMDJsonFunctions.h new file mode 100644 index 0000000000000..fc751672bd60c --- /dev/null +++ b/velox/functions/sparksql/SIMDJsonFunctions.h @@ -0,0 +1,182 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/prestosql/SIMDJsonFunctions.h" + +using namespace simdjson; + +namespace facebook::velox::functions::sparksql { + +template +struct SIMDGetJsonObjectFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + std::optional formattedJsonPath_; + + // ASCII input always produces ASCII result. + static constexpr bool is_default_ascii_behavior = true; + + // Makes a conversion from spark's json path, e.g., converts + // "$.a.b" to "/a/b". + FOLLY_ALWAYS_INLINE std::string getFormattedJsonPath( + const arg_type& jsonPath) { + char formattedJsonPath[jsonPath.size() + 1]; + int j = 0; + for (int i = 0; i < jsonPath.size(); i++) { + if (jsonPath.data()[i] == '$' || jsonPath.data()[i] == ']' || + jsonPath.data()[i] == '\'') { + continue; + } else if (jsonPath.data()[i] == '[' || jsonPath.data()[i] == '.') { + formattedJsonPath[j] = '/'; + j++; + } else { + formattedJsonPath[j] = jsonPath.data()[i]; + j++; + } + } + formattedJsonPath[j] = '\0'; + return std::string(formattedJsonPath, j + 1); + } + + FOLLY_ALWAYS_INLINE void initialize( + const core::QueryConfig& config, + const arg_type* /*json*/, + const arg_type* jsonPath) { + if (jsonPath != nullptr) { + formattedJsonPath_ = getFormattedJsonPath(*jsonPath); + } + } + + FOLLY_ALWAYS_INLINE simdjson::error_code extractStringResult( + simdjson_result rawResult, + out_type& result) { + simdjson::error_code error; + std::stringstream ss; + switch (rawResult.type()) { + // For number and bool types, we need to explicitly get the value + // for specific types instead of using `ss << rawResult`. Thus, we + // can make simdjson's internal parsing position moved and then we + // can check the validity of ending character. + case ondemand::json_type::number: { + switch (rawResult.get_number_type()) { + case ondemand::number_type::unsigned_integer: { + uint64_t numberResult; + error = rawResult.get_uint64().get(numberResult); + if (!error) { + ss << numberResult; + result.append(ss.str()); + } + return error; + } + case ondemand::number_type::signed_integer: { + int64_t numberResult; + error = rawResult.get_int64().get(numberResult); + if (!error) { + ss << numberResult; + result.append(ss.str()); + } + return error; + } + case ondemand::number_type::floating_point_number: { + double numberResult; + error = rawResult.get_double().get(numberResult); + if (!error) { + ss << numberResult; + result.append(ss.str()); + } + return error; + } + default: + VELOX_UNREACHABLE(); + } + } + case ondemand::json_type::boolean: { + bool boolResult; + error = rawResult.get_bool().get(boolResult); + if (!error) { + result.append(boolResult ? "true" : "false"); + } + return error; + } + case ondemand::json_type::string: { + std::string_view stringResult; + error = rawResult.get_string().get(stringResult); + result.append(stringResult); + return error; + } + case ondemand::json_type::object: { + // For nested case, e.g., for "{"my": {"hello": 10}}", "$.my" will + // return an object type. + ss << rawResult; + result.append(ss.str()); + return SUCCESS; + } + case ondemand::json_type::array: { + ss << rawResult; + result.append(ss.str()); + return SUCCESS; + } + default: { + return UNSUPPORTED_ARCHITECTURE; + } + } + } + + // This is a simple validation by checking whether the obtained result is + // followed by valid char. Because ondemand parsing we are using ignores json + // format validation for characters following the current parsing position. + bool isValidEndingCharacter(const char* currentPos) { + char endingChar = *currentPos; + if (endingChar == ',' || endingChar == '}' || endingChar == ']') { + return true; + } + if (endingChar == ' ' || endingChar == '\r' || endingChar == '\n' || + endingChar == '\t') { + // These chars can be prior to a valid ending char. + return isValidEndingCharacter(currentPos++); + } + return false; + } + + FOLLY_ALWAYS_INLINE bool call( + out_type& result, + const arg_type& json, + const arg_type& jsonPath) { + ParserContext ctx(json.data(), json.size()); + try { + ctx.parseDocument(); + simdjson_result rawResult = + formattedJsonPath_.has_value() + ? ctx.jsonDoc.at_pointer(formattedJsonPath_.value().data()) + : ctx.jsonDoc.at_pointer(getFormattedJsonPath(jsonPath).data()); + // Field not found. + if (rawResult.error() == NO_SUCH_FIELD) { + return false; + } + auto error = extractStringResult(rawResult, result); + if (error) { + return false; + } + } catch (simdjson_error& e) { + return false; + } + + const char* currentPos; + ctx.jsonDoc.current_location().get(currentPos); + return isValidEndingCharacter(currentPos); + } +}; + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index 86d4d1fa5d52e..8e566fb108c15 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -29,6 +29,7 @@ add_executable( ElementAtTest.cpp HashTest.cpp InTest.cpp + JsonFunctionsTest.cpp LeastGreatestTest.cpp MapTest.cpp MightContainTest.cpp diff --git a/velox/functions/sparksql/tests/JsonFunctionsTest.cpp b/velox/functions/sparksql/tests/JsonFunctionsTest.cpp new file mode 100644 index 0000000000000..ecdb8a47d00cd --- /dev/null +++ b/velox/functions/sparksql/tests/JsonFunctionsTest.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/prestosql/types/JsonType.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" +#include "velox/type/Type.h" + +#include + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class JsonFunctionTest : public SparkFunctionBaseTest { + protected: + std::optional getJsonObject( + const std::optional& json, + const std::optional& jsonPath) { + return evaluateOnce("get_json_object(c0, c1)", json, jsonPath); + } +}; + +TEST_F(JsonFunctionTest, getJsonObject) { + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.hello"), "3.5"); + EXPECT_EQ(getJsonObject(R"({"hello": 3.5})", "$.hello"), "3.5"); + EXPECT_EQ(getJsonObject(R"({"hello": 292222730})", "$.hello"), "292222730"); + EXPECT_EQ(getJsonObject(R"({"hello": -292222730})", "$.hello"), "-292222730"); + EXPECT_EQ(getJsonObject(R"({"my": {"hello": 3.5}})", "$.my.hello"), "3.5"); + EXPECT_EQ(getJsonObject(R"({"my": {"hello": true}})", "$.my.hello"), "true"); + EXPECT_EQ(getJsonObject(R"({"hello": ""})", "$.hello"), ""); + EXPECT_EQ( + getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$.age"), + "5"); + EXPECT_EQ( + getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$.id"), + "001"); + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}}, {"other": "v1"}])", + "$[0]['my']['param']['age']"), + "5"); + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}}, {"other": "v1"}])", + "$[0].my.param.age"), + "5"); + + // Json object as result. + EXPECT_EQ( + getJsonObject( + R"({"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}})", + "$.my.param"), + R"({"name": "Alice", "age": "5", "id": "001"})"); + EXPECT_EQ( + getJsonObject( + R"({"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}})", + "$['my']['param']"), + R"({"name": "Alice", "age": "5", "id": "001"})"); + + // Array as result. + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])", + "$[1].other"), + R"(["v1", "v2"])"); + // Array element as result. + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])", + "$[1].other[0]"), + "v1"); + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])", + "$[1].other[1]"), + "v2"); + + // Field not found. + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.hi"), std::nullopt); + // Illegal json. + EXPECT_EQ(getJsonObject(R"({"hello"-3.5})", "$.hello"), std::nullopt); + // Illegal json path. + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$hello"), std::nullopt); + EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$."), std::nullopt); + // Invalid ending character. + EXPECT_EQ( + getJsonObject( + R"([{"my": {"param": {"name": "Alice"quoted""}}}, {"other": ["v1", "v2"]}])", + "$[0].my.param.name"), + std::nullopt); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test