diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp index 0b36f292668..10f2b24c9fa 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.cpp @@ -281,27 +281,30 @@ String DAGExpressionAnalyzerHelper::buildCastFunction( return buildCastFunctionInternal(analyzer, {name, type_expr_name}, false, expr.field_type(), actions); } -String DAGExpressionAnalyzerHelper::buildCastAsJsonWithInputTiDBField( +String DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions( DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, const ExpressionActionsPtr & actions) { + auto func_name = getFunctionName(expr); if unlikely (expr.children_size() != 1) - throw TiFlashException("Cast function only support one argument", Errors::Coprocessor::BadRequest); + throw TiFlashException( + fmt::format("{} function only support one argument", func_name), + Errors::Coprocessor::BadRequest); if unlikely (!exprHasValidFieldType(expr)) - throw TiFlashException("CAST function without valid field type", Errors::Coprocessor::BadRequest); + throw TiFlashException( + fmt::format("{} function without valid field type", func_name), + Errors::Coprocessor::BadRequest); const auto & input_expr = expr.children(0); - auto func_name = getFunctionName(expr); - String arg = analyzer->getActions(input_expr, actions); const auto & collator = getCollatorFromExpr(expr); String result_name = genFuncString(func_name, {arg}, {collator}); if (actions->getSampleBlock().has(result_name)) return result_name; - const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(func_name, analyzer->getContext()); - auto * function_build_ptr = function_builder.get(); + const FunctionBuilderPtr & ifunction_builder = FunctionFactory::instance().get(func_name, analyzer->getContext()); + auto * function_build_ptr = ifunction_builder.get(); if (auto * function_builder = dynamic_cast(function_build_ptr); function_builder) { auto * function_impl = function_builder->getFunctionImpl().get(); @@ -321,17 +324,29 @@ String DAGExpressionAnalyzerHelper::buildCastAsJsonWithInputTiDBField( { function_cast_time_as_json->setInputTiDBFieldType(input_expr.field_type()); } + else if (auto * function_json_unquote = dynamic_cast(function_impl); + function_json_unquote) + { + bool valid_check + = !(isScalarFunctionExpr(input_expr) && input_expr.sig() == tipb::ScalarFuncSig::CastJsonAsString); + function_json_unquote->setNeedValidCheck(valid_check); + } + else if (auto * function_cast_json_as_string = dynamic_cast(function_impl); + function_cast_json_as_string) + { + function_cast_json_as_string->setOutputTiDBFieldType(expr.field_type()); + } else { - throw Exception(fmt::format("Unexpected func {} in buildCastAsJsonWithInputTiDBField", func_name)); + throw Exception(fmt::format("Unexpected func {} in buildSingleParamJsonRelatedFunctions", func_name)); } } else { - throw Exception(fmt::format("Unexpected func {} in buildCastAsJsonWithInputTiDBField", func_name)); + throw Exception(fmt::format("Unexpected func {} in buildSingleParamJsonRelatedFunctions", func_name)); } - const ExpressionAction & action = ExpressionAction::applyFunction(function_builder, {arg}, result_name, collator); + const ExpressionAction & action = ExpressionAction::applyFunction(ifunction_builder, {arg}, result_name, collator); actions->add(action); return result_name; } @@ -534,9 +549,11 @@ 
DAGExpressionAnalyzerHelper::FunctionBuilderMap DAGExpressionAnalyzerHelper::fun {"ifNull", DAGExpressionAnalyzerHelper::buildIfNullFunction}, {"multiIf", DAGExpressionAnalyzerHelper::buildMultiIfFunction}, {"tidb_cast", DAGExpressionAnalyzerHelper::buildCastFunction}, - {"cast_int_as_json", DAGExpressionAnalyzerHelper::buildCastAsJsonWithInputTiDBField}, - {"cast_string_as_json", DAGExpressionAnalyzerHelper::buildCastAsJsonWithInputTiDBField}, - {"cast_time_as_json", DAGExpressionAnalyzerHelper::buildCastAsJsonWithInputTiDBField}, + {"cast_int_as_json", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions}, + {"cast_string_as_json", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions}, + {"cast_time_as_json", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions}, + {"cast_json_as_string", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions}, + {"json_unquote", DAGExpressionAnalyzerHelper::buildSingleParamJsonRelatedFunctions}, {"and", DAGExpressionAnalyzerHelper::buildLogicalFunction}, {"or", DAGExpressionAnalyzerHelper::buildLogicalFunction}, {"xor", DAGExpressionAnalyzerHelper::buildLogicalFunction}, diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.h b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.h index 5313ad6b530..529dc204ace 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.h +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzerHelper.h @@ -71,7 +71,7 @@ class DAGExpressionAnalyzerHelper const tipb::Expr & expr, const ExpressionActionsPtr & actions); - static String buildCastAsJsonWithInputTiDBField( + static String buildSingleParamJsonRelatedFunctions( DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, const ExpressionActionsPtr & actions); diff --git a/dbms/src/Functions/FunctionsJson.h b/dbms/src/Functions/FunctionsJson.h index c8b7c80e586..153aa3c481a 100644 --- a/dbms/src/Functions/FunctionsJson.h +++ b/dbms/src/Functions/FunctionsJson.h @@ -27,11 +27,14 @@ #include #include #include +#include #include #include #include +#include #include #include +#include #include #include #include @@ -301,6 +304,7 @@ class FunctionJsonUnquote : public IFunction size_t getNumberOfArguments() const override { return 1; } + void setNeedValidCheck(bool need_valid_check_) { need_valid_check = need_valid_check_; } bool useDefaultImplementationForConstants() const override { return true; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { @@ -327,16 +331,10 @@ class FunctionJsonUnquote : public IFunction offsets_to.resize(rows); ColumnUInt8::MutablePtr col_null_map = ColumnUInt8::create(rows, 0); JsonBinary::JsonBinaryWriteBuffer write_buffer(data_to); - size_t current_offset = 0; - for (size_t i = 0; i < block.rows(); ++i) - { - size_t next_offset = offsets_from[i]; - size_t data_length = next_offset - current_offset - 1; - JsonBinary::unquoteStringInBuffer(StringRef(&data_from[current_offset], data_length), write_buffer); - writeChar(0, write_buffer); - offsets_to[i] = write_buffer.count(); - current_offset = next_offset; - } + if (need_valid_check) + doUnquote(block, data_from, offsets_from, offsets_to, write_buffer); + else + doUnquote(block, data_from, offsets_from, offsets_to, write_buffer); data_to.resize(write_buffer.count()); block.getByPosition(result).column = ColumnNullable::create(std::move(col_to), std::move(col_null_map)); } @@ -345,6 +343,41 @@ class FunctionJsonUnquote : public IFunction fmt::format("Illegal column {} of argument of function {}", 
column->getName(), getName()), ErrorCodes::ILLEGAL_COLUMN); } + + template + void doUnquote( + const Block & block, + const ColumnString::Chars_t & data_from, + const IColumn::Offsets & offsets_from, + IColumn::Offsets & offsets_to, + JsonBinary::JsonBinaryWriteBuffer & write_buffer) const + { + size_t current_offset = 0; + for (size_t i = 0; i < block.rows(); ++i) + { + size_t next_offset = offsets_from[i]; + size_t data_length = next_offset - current_offset - 1; + if constexpr (validCheck) + { + // TODO(hyb): use SIMDJson to check when SIMDJson is proved in practice + if (data_length >= 2 && data_from[current_offset] == '"' && data_from[next_offset - 2] == '"' + && unlikely( + !checkJsonValid(reinterpret_cast(&data_from[current_offset]), data_length))) + { + throw Exception( + "Invalid JSON text: The document root must not be followed by other values.", + ErrorCodes::ILLEGAL_COLUMN); + } + } + JsonBinary::unquoteStringInBuffer(StringRef(&data_from[current_offset], data_length), write_buffer); + writeChar(0, write_buffer); + offsets_to[i] = write_buffer.count(); + current_offset = next_offset; + } + } + +private: + bool need_valid_check = false; }; @@ -352,7 +385,18 @@ class FunctionCastJsonAsString : public IFunction { public: static constexpr auto name = "cast_json_as_string"; - static FunctionPtr create(const Context &) { return std::make_shared(); } + static FunctionPtr create(const Context & context) + { + if (!context.getDAGContext()) + { + throw Exception("DAGContext should not be nullptr.", ErrorCodes::LOGICAL_ERROR); + } + return std::make_shared(context); + } + + explicit FunctionCastJsonAsString(const Context & context) + : context(context) + {} String getName() const override { return name; } @@ -360,6 +404,8 @@ class FunctionCastJsonAsString : public IFunction bool useDefaultImplementationForConstants() const override { return true; } + void setOutputTiDBFieldType(const tipb::FieldType & tidb_tp_) { tidb_tp = &tidb_tp_; } + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { if unlikely (!arguments[0]->isString()) @@ -386,25 +432,58 @@ class FunctionCastJsonAsString : public IFunction ColumnUInt8::MutablePtr col_null_map = ColumnUInt8::create(rows, 0); ColumnUInt8::Container & vec_null_map = col_null_map->getData(); JsonBinary::JsonBinaryWriteBuffer write_buffer(data_to); - size_t current_offset = 0; - for (size_t i = 0; i < block.rows(); ++i) + if likely (tidb_tp->flen() < 0) { - size_t next_offset = offsets_from[i]; - size_t json_length = next_offset - current_offset - 1; - if unlikely (isNullJsonBinary(json_length)) + size_t current_offset = 0; + for (size_t i = 0; i < block.rows(); ++i) { - vec_null_map[i] = 1; + size_t next_offset = offsets_from[i]; + size_t json_length = next_offset - current_offset - 1; + if unlikely (isNullJsonBinary(json_length)) + vec_null_map[i] = 1; + else + { + JsonBinary json_binary( + data_from[current_offset], + StringRef(&data_from[current_offset + 1], json_length - 1)); + json_binary.toStringInBuffer(write_buffer); + } + writeChar(0, write_buffer); + offsets_to[i] = write_buffer.count(); + current_offset = next_offset; } - else + } + else + { + ColumnString::Chars_t container_per_element; + size_t current_offset = 0; + for (size_t i = 0; i < block.rows(); ++i) { - JsonBinary json_binary( - data_from[current_offset], - StringRef(&data_from[current_offset + 1], json_length - 1)); - json_binary.toStringInBuffer(write_buffer); + size_t next_offset = offsets_from[i]; + size_t json_length = next_offset - current_offset - 1; 
+ if unlikely (isNullJsonBinary(json_length)) + vec_null_map[i] = 1; + else + { + JsonBinary::JsonBinaryWriteBuffer element_write_buffer(container_per_element); + JsonBinary json_binary( + data_from[current_offset], + StringRef(&data_from[current_offset + 1], json_length - 1)); + json_binary.toStringInBuffer(element_write_buffer); + size_t orig_length = element_write_buffer.count(); + auto byte_length = charLengthToByteLengthFromUTF8( + reinterpret_cast(container_per_element.data()), + orig_length, + tidb_tp->flen()); + if (byte_length < element_write_buffer.count()) + context.getDAGContext()->handleTruncateError("Data Too Long"); + write_buffer.write(reinterpret_cast(container_per_element.data()), byte_length); + } + + writeChar(0, write_buffer); + offsets_to[i] = write_buffer.count(); + current_offset = next_offset; } - writeChar(0, write_buffer); - offsets_to[i] = write_buffer.count(); - current_offset = next_offset; } data_to.resize(write_buffer.count()); block.getByPosition(result).column = ColumnNullable::create(std::move(col_to), std::move(col_null_map)); @@ -414,8 +493,11 @@ class FunctionCastJsonAsString : public IFunction fmt::format("Illegal column {} of argument of function {}", column->getName(), getName()), ErrorCodes::ILLEGAL_COLUMN); } -}; +private: + const tipb::FieldType * tidb_tp; + const Context & context; +}; class FunctionJsonLength : public IFunction { diff --git a/dbms/src/Functions/FunctionsTiDBConversion.h b/dbms/src/Functions/FunctionsTiDBConversion.h index e251396f0d7..d5a212ef762 100644 --- a/dbms/src/Functions/FunctionsTiDBConversion.h +++ b/dbms/src/Functions/FunctionsTiDBConversion.h @@ -77,30 +77,40 @@ namespace constexpr static Int64 pow10[] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000}; } +ALWAYS_INLINE inline size_t charLengthToByteLengthFromUTF8(const char * data, size_t length, size_t char_length) +{ + size_t ret = 0; + for (size_t char_index = 0; char_index < char_length && ret < length; ++char_index) + { + uint8_t c = data[ret]; + if (c < 0x80) + ret += 1; + else if (c < 0xE0) + ret += 2; + else if (c < 0xF0) + ret += 3; + else + ret += 4; + } + if unlikely (ret > length) + { + throw Exception( + fmt::format( + "Illegal utf8 byte sequence bytes: {} result_length: {} char_length: {}", + length, + ret, + char_length), + ErrorCodes::ILLEGAL_COLUMN); + } + return ret; +} + /// cast int/real/decimal/time as string template struct TiDBConvertToString { using FromFieldType = typename FromDataType::FieldType; - static size_t charLengthToByteLengthFromUTF8(const char * data, size_t length, size_t char_length) - { - size_t ret = 0; - for (size_t char_index = 0; char_index < char_length && ret < length; char_index++) - { - uint8_t c = data[ret]; - if (c < 0x80) - ret += 1; - else if (c < 0xE0) - ret += 2; - else if (c < 0xF0) - ret += 3; - else - ret += 4; - } - return ret; - } - static void execute( Block & block, const ColumnNumbers & arguments, @@ -148,7 +158,7 @@ struct TiDBConvertToString size_t next_offset = (*offsets_from)[i]; size_t org_length = next_offset - current_offset - 1; size_t byte_length = org_length; - if (tp.flen() > 0) + if (tp.flen() >= 0) { byte_length = tp.flen(); if (tp.charset() == "utf8" || tp.charset() == "utf8mb4") @@ -189,7 +199,7 @@ struct TiDBConvertToString WriteBufferFromVector element_write_buffer(container_per_element); FormatImpl::execute(vec_from[i], element_write_buffer, &type, nullptr); size_t byte_length = element_write_buffer.count(); - if (tp.flen() > 0) + if (tp.flen() >= 0) byte_length 
= std::min(byte_length, tp.flen()); if (byte_length < element_write_buffer.count()) context.getDAGContext()->handleTruncateError("Data Too Long"); @@ -235,7 +245,7 @@ struct TiDBConvertToString WriteBufferFromVector element_write_buffer(container_per_element); FormatImpl::execute(vec_from[i], element_write_buffer, &type, nullptr); size_t byte_length = element_write_buffer.count(); - if (tp.flen() > 0) + if (tp.flen() >= 0) byte_length = std::min(byte_length, tp.flen()); if (byte_length < element_write_buffer.count()) context.getDAGContext()->handleTruncateError("Data Too Long"); diff --git a/dbms/src/Functions/tests/gtest_cast_as_json.cpp b/dbms/src/Functions/tests/gtest_cast_as_json.cpp index 3e8149aa6e6..cf97a45d2f1 100644 --- a/dbms/src/Functions/tests/gtest_cast_as_json.cpp +++ b/dbms/src/Functions/tests/gtest_cast_as_json.cpp @@ -51,7 +51,11 @@ class TestCastAsJson : public DB::tests::FunctionTest json_column = executeFunction(func_name, columns); } // The `json_binary` should be cast as a string to improve readability. - return executeFunction("cast_json_as_string", {json_column}); + tipb::FieldType field_type; + field_type.set_flen(-1); + field_type.set_collate(TiDB::ITiDBCollator::BINARY); + field_type.set_tp(TiDB::TypeString); + return executeCastJsonAsStringFunction(json_column, field_type); } template diff --git a/dbms/src/Functions/tests/gtest_cast_json_as_string.cpp b/dbms/src/Functions/tests/gtest_cast_json_as_string.cpp index 21e439d35e1..c9e124c008b 100644 --- a/dbms/src/Functions/tests/gtest_cast_json_as_string.cpp +++ b/dbms/src/Functions/tests/gtest_cast_json_as_string.cpp @@ -14,8 +14,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -67,7 +69,11 @@ try auto output_col = createColumn>( {R"([{"a": 1, "b": true}, 3, 3.5, "hello, world", null, true])", {}, "[[0, 1], [2, 3], [4, [5, 6]]]"}); - auto res = executeFunction(func_name, input_col); + tipb::FieldType field_type; + field_type.set_flen(-1); + field_type.set_collate(TiDB::ITiDBCollator::BINARY); + field_type.set_tp(TiDB::TypeString); + auto res = executeCastJsonAsStringFunction(input_col, field_type); ASSERT_COLUMN_EQ(res, output_col); /// ColumnVector(null) @@ -80,13 +86,13 @@ try input_col = ColumnWithTypeAndName(std::move(json_col), nullable_string_type_ptr, "input0"); output_col = createColumn>({{}, {}, {}}); - res = executeFunction(func_name, input_col); + res = executeCastJsonAsStringFunction(input_col, field_type); ASSERT_COLUMN_EQ(res, output_col); /// ColumnConst(null) auto null_input_col = createConstColumn>(3, {}); output_col = createConstColumn>(3, {}); - res = executeFunction(func_name, null_input_col); + res = executeCastJsonAsStringFunction(null_input_col, field_type); ASSERT_COLUMN_EQ(res, output_col); /// ColumnVector(non-null) @@ -95,7 +101,7 @@ try non_null_str_col->insertData(reinterpret_cast(bj2), sizeof(bj2) / sizeof(UInt8)); non_null_str_col->insertData(reinterpret_cast(bj9), sizeof(bj9) / sizeof(UInt8)); auto non_null_input_col = ColumnWithTypeAndName(std::move(non_null_str_col), string_type_ptr, "input0"); - res = executeFunction(func_name, non_null_input_col); + res = executeCastJsonAsStringFunction(non_null_input_col, field_type); output_col = createColumn>( {R"([{"a": 1, "b": true}, 3, 3.5, "hello, world", null, true])", R"([{"a": 1, "b": true}, 3, 3.5, "hello, world", null, true])", @@ -106,7 +112,7 @@ try non_null_str_col = ColumnString::create(); non_null_str_col->insertData(reinterpret_cast(bj2), sizeof(bj2) / sizeof(UInt8)); auto 
const_non_null_input_col = ColumnConst::create(std::move(non_null_str_col), 3); - res = executeFunction(func_name, {std::move(const_non_null_input_col), string_type_ptr, ""}); + res = executeCastJsonAsStringFunction({std::move(const_non_null_input_col), string_type_ptr, ""}, field_type); output_col = createConstColumn>(3, {R"([{"a": 1, "b": true}, 3, 3.5, "hello, world", null, true])"}); ASSERT_COLUMN_EQ(res, output_col); @@ -119,9 +125,70 @@ try auto const_json_col = ColumnConst::create(std::move(json_col), 3); auto const_nullable_input_col = ColumnWithTypeAndName(std::move(const_json_col), nullable_string_type_ptr, "input0"); - res = executeFunction(func_name, const_nullable_input_col); + res = executeCastJsonAsStringFunction(const_nullable_input_col, field_type); output_col = createConstColumn>(3, {"[[0, 1], [2, 3], [4, [5, 6]]]"}); ASSERT_COLUMN_EQ(res, output_col); + + /// Limit string length + context->getDAGContext()->addFlag(TiDBSQLFlags::IGNORE_TRUNCATE); + str_col = ColumnString::create(); + str_col->insertData(reinterpret_cast(bj2), sizeof(bj2) / sizeof(UInt8)); + str_col->insertData("", 0); + str_col->insertData(reinterpret_cast(bj9), sizeof(bj9) / sizeof(UInt8)); + col_null_map = ColumnUInt8::create(3, 0); + json_col = ColumnNullable::create(std::move(str_col), std::move(col_null_map)); + input_col = ColumnWithTypeAndName(std::move(json_col), nullable_string_type_ptr, "input0"); + + output_col = createColumn>({R"([{"a")", {}, "[[0, "}); + field_type.set_flen(5); + res = executeCastJsonAsStringFunction(input_col, field_type); + ASSERT_COLUMN_EQ(res, output_col); + ASSERT_TRUE(context->getDAGContext()->getWarningCount() == 2); + + // multiple-bytes utf characters "你好" + // clang-format off + UInt8 bj3[] = {0xc, 0x6, 0xe4, 0xbd, 0xa0, 0xe5, 0xa5, 0xbd}; + // clang-format on + str_col = ColumnString::create(); + str_col->insertData(reinterpret_cast(bj3), sizeof(bj3) / sizeof(UInt8)); + col_null_map = ColumnUInt8::create(1, 0); + json_col = ColumnNullable::create(std::move(str_col), std::move(col_null_map)); + input_col = ColumnWithTypeAndName(std::move(json_col), nullable_string_type_ptr, "input0"); + + output_col = createColumn>({R"("你)"}); + field_type.set_flen(2); + context->getDAGContext()->clearWarnings(); + res = executeCastJsonAsStringFunction(input_col, field_type); + ASSERT_TRUE(context->getDAGContext()->getWarningCount() == 1); + ASSERT_COLUMN_EQ(res, output_col); + + output_col = createColumn>({R"()"}); + field_type.set_flen(0); + context->getDAGContext()->clearWarnings(); + res = executeCastJsonAsStringFunction(input_col, field_type); + ASSERT_TRUE(context->getDAGContext()->getWarningCount() == 1); + ASSERT_COLUMN_EQ(res, output_col); + + output_col = createColumn>({R"("你好")"}); + field_type.set_flen(-1); + context->getDAGContext()->clearWarnings(); + res = executeCastJsonAsStringFunction(input_col, field_type); + ASSERT_COLUMN_EQ(res, output_col); + ASSERT_TRUE(context->getDAGContext()->getWarningCount() == 0); + + output_col = createColumn>({R"("你好")"}); + field_type.set_flen(4); + context->getDAGContext()->clearWarnings(); + res = executeCastJsonAsStringFunction(input_col, field_type); + ASSERT_COLUMN_EQ(res, output_col); + ASSERT_TRUE(context->getDAGContext()->getWarningCount() == 0); + + output_col = createColumn>({R"("你好")"}); + field_type.set_flen(10); + context->getDAGContext()->clearWarnings(); + res = executeCastJsonAsStringFunction(input_col, field_type); + ASSERT_COLUMN_EQ(res, output_col); + ASSERT_TRUE(context->getDAGContext()->getWarningCount() == 0); } 
CATCH diff --git a/dbms/src/Functions/tests/gtest_json_array.cpp b/dbms/src/Functions/tests/gtest_json_array.cpp index 2ad33f0c309..a1c9a68a16f 100644 --- a/dbms/src/Functions/tests/gtest_json_array.cpp +++ b/dbms/src/Functions/tests/gtest_json_array.cpp @@ -33,7 +33,10 @@ class TestJsonArray : public DB::tests::FunctionTest static auto json_array_return_type = std::make_shared(); assert(json_array_return_type->equals(*json_column.type)); // The `json_binary` should be cast as a string to improve readability. - return executeFunction("cast_json_as_string", {json_column}); + tipb::FieldType field_type; + field_type.set_flen(-1); + field_type.set_tp(TiDB::TypeString); + return executeCastJsonAsStringFunction({json_column}, field_type); } }; diff --git a/dbms/src/Functions/tests/gtest_json_unquote.cpp b/dbms/src/Functions/tests/gtest_json_unquote.cpp index 2c91c4dee20..817f9807e90 100644 --- a/dbms/src/Functions/tests/gtest_json_unquote.cpp +++ b/dbms/src/Functions/tests/gtest_json_unquote.cpp @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include +#include #include #include -#include +#include #include #include @@ -24,6 +25,49 @@ namespace DB::tests { class TestJsonUnquote : public DB::tests::FunctionTest { +public: + ColumnWithTypeAndName executeJsonUnquoteFunction( + Context & context, + const ColumnWithTypeAndName & input_column, + bool valid_check) + { + auto & factory = FunctionFactory::instance(); + ColumnsWithTypeAndName columns({input_column}); + ColumnNumbers argument_column_numbers; + for (size_t i = 0; i < columns.size(); ++i) + argument_column_numbers.push_back(i); + + ColumnsWithTypeAndName arguments; + for (const auto argument_column_number : argument_column_numbers) + arguments.push_back(columns.at(argument_column_number)); + + const String func_name = "json_unquote"; + auto builder = factory.tryGet(func_name, context); + if (!builder) + throw TiFlashTestException(fmt::format("Function {} not found!", func_name)); + auto func = builder->build(arguments, nullptr); + auto * function_build_ptr = builder.get(); + if (auto * default_function_builder = dynamic_cast(function_build_ptr); + default_function_builder) + { + auto * function_impl = default_function_builder->getFunctionImpl().get(); + if (auto * function_json_unquote = dynamic_cast(function_impl); + function_json_unquote) + { + function_json_unquote->setNeedValidCheck(valid_check); + } + else + { + throw TiFlashTestException(fmt::format("Function {} not found!", func_name)); + } + } + + Block block(columns); + block.insert({nullptr, func->getReturnType(), "res"}); + func->execute(block, argument_column_numbers, columns.size()); + + return block.getByPosition(columns.size()); + } }; TEST_F(TestJsonUnquote, TestAll) @@ -76,4 +120,158 @@ try } CATCH +TEST_F(TestJsonUnquote, TestCheckValid) +try +{ + /// Normal case: ColumnVector(nullable) + const String func_name = "json_unquote"; + static auto const nullable_string_type_ptr = makeNullable(std::make_shared()); + static auto const string_type_ptr = std::make_shared(); + String bj2("\"hello, \\\"你好, \\u554A world, null, true]\""); + String bj4("[[0, 1], [2, 3], [4, [5, 6]]]"); + auto input_col = createColumn>({bj2, {}, bj4}); + auto output_col + = createColumn>({"hello, \"你好, 啊 world, null, true]", {}, "[[0, 1], [2, 3], [4, [5, 6]]]"}); + auto res = executeJsonUnquoteFunction(*context, input_col, false); + ASSERT_COLUMN_EQ(res, output_col); + + res = executeJsonUnquoteFunction(*context, 
input_col, true); + ASSERT_COLUMN_EQ(res, output_col); + + try + { + String bj3(R"("hello world \ ")"); + input_col = createColumn>({bj3}); + res = executeJsonUnquoteFunction(*context, input_col, true); + GTEST_FAIL(); + } + catch (Exception & e) + {} +} +CATCH + +TEST_F(TestJsonUnquote, TestCheckValidRaw) +try +{ + // Valid Literals + std::vector valid_literals = {R"(true)", R"(false)", R"(null)"}; + for (const auto & str : valid_literals) + { + ASSERT_TRUE(checkJsonValid(str.c_str(), str.size())); + } + + // Invalid Literals + std::vector invalid_literals + = {R"(tRue)", R"(tru)", R"(trued)", R"(False)", R"(fale)", R"(falses)", R"(nulL)", R"(nul)", R"(nulll)"}; + for (const auto & str : invalid_literals) + { + ASSERT_TRUE(!checkJsonValid(str.c_str(), str.size())); + } + + // Valid Numbers + std::vector valid_numbers = {R"(3)", R"(-100)", R"(231.0123)", R"(3.14e0)", R"(-3.14e-1)", R"(3.14e100)"}; + for (const auto & str : valid_numbers) + { + ASSERT_TRUE(checkJsonValid(str.c_str(), str.size())); + } + + // Invalid Numbers + std::vector invalid_numbers + = {R"(3.3.3)", R"(3.3t)", R"(--100)", R"(e231.0123)", R"(+3.14e)", R"(-+3.23)", R"(-+341a)"}; + for (const auto & str : invalid_numbers) + { + ASSERT_TRUE(!checkJsonValid(str.c_str(), str.size())); + } + + // Valid Strings + std::vector valid_strings + = {R"("foo")", + R"("hello world!\n")", + R"("hello \"name\"")", + R"("你好 朋友")", + R"("\u554A world")", + R"("{\"foo\":\"bar\",\"bar\":{\"baz\":[\"qux\"]}}")", + R"("\"hello world\"")"}; + for (const auto & str : valid_strings) + { + ASSERT_TRUE(checkJsonValid(str.c_str(), str.size())); + } + + // Invalid Strings + std::vector invalid_strings + = {R"(""hello world"")", R"("hello world"ef)", R"("hello world)", R"("hello world \ ")"}; + for (const auto & str : invalid_strings) + { + ASSERT_TRUE(!checkJsonValid(str.c_str(), str.size())); + } + + // Valid Objects + std::vector valid_objects + = {R"({})", + R"({"a":3.0, "b":-4, "c":"hello world", "d":true})", + R"({"a":3.0, "b":{"c":{"d":"hello world"}}})", + R"({"a":3.0, "b":{"name":"Tom", "experience":{"current":10, "last":30}}, "c":"hello world", "d":true})"}; + for (const auto & str : valid_objects) + { + ASSERT_TRUE(checkJsonValid(str.c_str(), str.size())); + } + + // Invalid Objects + std::vector invalid_objects + = {R"({"a")", + R"({"a"})", + R"({"a":})", + R"({32:"a"})", + R"({"a":32:})", + R"({"a":32}})", + R"({"a":32,:})", + R"({"a":32,"dd":{"d","e"}})"}; + for (const auto & str : invalid_objects) + { + ASSERT_TRUE(!checkJsonValid(str.c_str(), str.size())); + } + + // Valid Arrays + std::vector valid_arrays + = {R"([])", + R"([true, null, false, "hello world", 3.0, -1e4])", + R"([1,2,[ 3, "hello", ["world", 32]]])", + R"([[],[]])"}; + for (const auto & str : valid_arrays) + { + ASSERT_TRUE(checkJsonValid(str.c_str(), str.size())); + } + + // Invalid Arrays + std::vector invalid_arrays + = {R"([32)", R"([32]])", R"([32,])", R"([[],23)", R"([32], 23)", R"(["hello", ["world"])"}; + for (const auto & str : invalid_arrays) + { + ASSERT_TRUE(!checkJsonValid(str.c_str(), str.size())); + } + + // Valid Mixtures + std::vector valid_mixtures + = {R"([{}])", + R"([3, {"name":3}, {"experince":[3,6,9]}])", + R"({"age":3, "value":[-32, true, {"exr":[23,-12,true]}]})"}; + for (const auto & str : valid_mixtures) + { + ASSERT_TRUE(checkJsonValid(str.c_str(), str.size())); + } + + // Invalid Mixtures + std::vector invalid_mixtures + = {R"({[]})", + R"({"name":"tome", [3,2]})", + R"([{"name":"tome"})", + R"([{"name":"tome"])", + R"([{"name":"tome", 
"age":[}])"}; + for (const auto & str : invalid_mixtures) + { + ASSERT_TRUE(!checkJsonValid(str.c_str(), str.size())); + } +} +CATCH + } // namespace DB::tests \ No newline at end of file diff --git a/dbms/src/TestUtils/ColumnsToTiPBExpr.cpp b/dbms/src/TestUtils/ColumnsToTiPBExpr.cpp index 294ae5924ce..2e2adbb23c0 100644 --- a/dbms/src/TestUtils/ColumnsToTiPBExpr.cpp +++ b/dbms/src/TestUtils/ColumnsToTiPBExpr.cpp @@ -114,6 +114,10 @@ void columnsToTiPBExprForTiDBCast( auto * argument_expr = expr->add_children(); columnToTiPBExpr(argument_expr, columns[argument_column_number[0]], 0); ColumnInfo ci = reverseGetColumnInfo({type_string, target_type}, 0, Field(), true); + if (ci.tp == TiDB::TypeString) + { + ci.flen = -1; + } *(expr->mutable_field_type()) = columnInfoToFieldType(ci); if (collator != nullptr) expr->mutable_field_type()->set_collate(-collator->getCollatorId()); diff --git a/dbms/src/TestUtils/FunctionTestUtils.cpp b/dbms/src/TestUtils/FunctionTestUtils.cpp index 5f345c0d306..260d7938a0b 100644 --- a/dbms/src/TestUtils/FunctionTestUtils.cpp +++ b/dbms/src/TestUtils/FunctionTestUtils.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -599,5 +600,47 @@ ColumnWithTypeAndName FunctionTest::executeFunctionWithMetaData( return DB::tests::executeFunction(*context, func_name, argument_column_numbers, columns, collator, meta.val, false); } +ColumnWithTypeAndName FunctionTest::executeCastJsonAsStringFunction( + const ColumnWithTypeAndName & input_column, + const tipb::FieldType & field_type) +{ + auto & factory = FunctionFactory::instance(); + ColumnsWithTypeAndName columns({input_column}); + ColumnNumbers argument_column_numbers; + for (size_t i = 0; i < columns.size(); ++i) + argument_column_numbers.push_back(i); + + ColumnsWithTypeAndName arguments; + for (const auto argument_column_number : argument_column_numbers) + arguments.push_back(columns.at(argument_column_number)); + + const String func_name = "cast_json_as_string"; + auto builder = factory.tryGet(func_name, *context); + if (!builder) + throw TiFlashTestException(fmt::format("Function {} not found!", func_name)); + auto func = builder->build(arguments, nullptr); + auto * function_build_ptr = builder.get(); + if (auto * default_function_builder = dynamic_cast(function_build_ptr); + default_function_builder) + { + auto * function_impl = default_function_builder->getFunctionImpl().get(); + if (auto * function_cast_json_as_string = dynamic_cast(function_impl); + function_cast_json_as_string) + { + function_cast_json_as_string->setOutputTiDBFieldType(field_type); + } + else + { + throw TiFlashTestException(fmt::format("Function {} not found!", func_name)); + } + } + + Block block(columns); + block.insert({nullptr, func->getReturnType(), "res"}); + func->execute(block, argument_column_numbers, columns.size()); + + return block.getByPosition(columns.size()); +} + } // namespace tests } // namespace DB diff --git a/dbms/src/TestUtils/FunctionTestUtils.h b/dbms/src/TestUtils/FunctionTestUtils.h index 87e065bbe5a..33a49b94854 100644 --- a/dbms/src/TestUtils/FunctionTestUtils.h +++ b/dbms/src/TestUtils/FunctionTestUtils.h @@ -839,6 +839,10 @@ class FunctionTest : public ::testing::Test const FuncMetaData & meta, const TiDB::TiDBCollatorPtr & collator = nullptr); + ColumnWithTypeAndName executeCastJsonAsStringFunction( + const ColumnWithTypeAndName & input_column, + const tipb::FieldType & field_type); + DAGContext & getDAGContext() { RUNTIME_ASSERT(dag_context_ptr != nullptr); diff --git 
a/dbms/src/TiDB/Decode/JsonScanner.cpp b/dbms/src/TiDB/Decode/JsonScanner.cpp new file mode 100644 index 00000000000..fd55317ae53 --- /dev/null +++ b/dbms/src/TiDB/Decode/JsonScanner.cpp @@ -0,0 +1,514 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace DB +{ +bool JsonScanIsSpace(char c) +{ + return (static_cast(c) <= static_cast(' ')) + && (c == ' ' || c == '\t' || c == '\r' || c == '\n'); +} + +bool checkJsonValid(const char * data, size_t length) +{ + JsonScanner scanner; + for (size_t i = 0; i < length; ++i) + { + ++scanner.bytes; + if (scanner.stepFunc(scanner, data[i]) == JsonScanError) + { + return false; + } + } + return (scanner.eof() != JsonScanError); +} + +// stateBeginValueOrEmpty is the state after reading `[`. +JsonScanAction stateBeginValueOrEmpty(JsonScanner & scanner, char c) +{ + if (JsonScanIsSpace(c)) + { + return JsonScanSkipSpace; + } + if (c == ']') + { + return stateEndValue(scanner, c); + } + return stateBeginValue(scanner, c); +} + +// stateBeginValue is the state at the beginning of the input. +JsonScanAction stateBeginValue(JsonScanner & scanner, char c) +{ + if (JsonScanIsSpace(c)) + { + return JsonScanSkipSpace; + } + switch (c) + { + case '{': + scanner.stepFunc = stateBeginStringOrEmpty; + return scanner.pushParseState(c, JsonScanParseObjectKey, JsonScanBeginObject); + case '[': + scanner.stepFunc = stateBeginValueOrEmpty; + return scanner.pushParseState(c, JsonScanParseArrayValue, JsonScanBeginArray); + case '"': + scanner.stepFunc = stateInString; + return JsonScanBeginLiteral; + case '-': + scanner.stepFunc = stateNeg; + return JsonScanBeginLiteral; + case '0': // beginning of 0.123 + scanner.stepFunc = state0; + return JsonScanBeginLiteral; + case 't': // beginning of true + scanner.stepFunc = stateT; + return JsonScanBeginLiteral; + case 'f': // beginning of false + scanner.stepFunc = stateF; + return JsonScanBeginLiteral; + case 'n': // beginning of null + scanner.stepFunc = stateN; + return JsonScanBeginLiteral; + default: + break; + } + if ('1' <= c && c <= '9') + { // beginning of 1234.5 + scanner.stepFunc = state1; + return JsonScanBeginLiteral; + } + return scanner.genError("looking for beginning of value", c); +} + +// stateBeginStringOrEmpty is the state after reading `{`. +JsonScanAction stateBeginStringOrEmpty(JsonScanner & scanner, char c) +{ + if (JsonScanIsSpace(c)) + { + return JsonScanSkipSpace; + } + if (c == '}') + { + scanner.parse_state_stack.top() = JsonScanParseObjectValue; + return stateEndValue(scanner, c); + } + return stateBeginString(scanner, c); +} + +// stateBeginString is the state after reading `{"key": value,`. 
+JsonScanAction stateBeginString(JsonScanner & scanner, char c) +{ + if (JsonScanIsSpace(c)) + { + return JsonScanSkipSpace; + } + if (c == '"') + { + scanner.stepFunc = stateInString; + return JsonScanBeginLiteral; + } + return scanner.genError("looking for beginning of object key string", c); +} + +// stateEndValue is the state after completing a value, +// such as after reading `{}` or `true` or `["x"`. +JsonScanAction stateEndValue(JsonScanner & scanner, char c) +{ + size_t n = scanner.parse_state_stack.size(); + if (n == 0) + { + // Completed top-level before the current byte. + scanner.stepFunc = stateEndTop; + scanner.end_top = true; + return stateEndTop(scanner, c); + } + if (JsonScanIsSpace(c)) + { + scanner.stepFunc = stateEndValue; + return JsonScanSkipSpace; + } + switch (scanner.parse_state_stack.top()) + { + case JsonScanParseObjectKey: + if (c == ':') + { + scanner.parse_state_stack.top() = JsonScanParseObjectValue; + scanner.stepFunc = stateBeginValue; + return JsonScanObjectKey; + } + return scanner.genError("after object key", c); + case JsonScanParseObjectValue: + if (c == ',') + { + scanner.parse_state_stack.top() = JsonScanParseObjectKey; + scanner.stepFunc = stateBeginString; + return JsonScanObjectValue; + } + if (c == '}') + { + scanner.popParseState(); + return JsonScanEndObject; + } + return scanner.genError("after object key:value pair", c); + case JsonScanParseArrayValue: + if (c == ',') + { + scanner.stepFunc = stateBeginValue; + return JsonScanArrayValue; + } + if (c == ']') + { + scanner.popParseState(); + return JsonScanEndArray; + } + return scanner.genError("after array element", c); + default: + break; + } + return scanner.genError("", c); +} + +// stateEndTop is the state after finishing the top-level value, +// such as after reading `{}` or `[1,2,3]`. +// Only space characters should be seen now. +JsonScanAction stateEndTop(JsonScanner & scanner, char c) +{ + if (!JsonScanIsSpace(c)) + { + // Complain about non-space byte on next call. + scanner.genError("after top-level value", c); + } + return JsonScanEnd; +} + +// stateInString is the state after reading `"`. +JsonScanAction stateInString(JsonScanner & scanner, char c) +{ + if (c == '"') + { + scanner.stepFunc = stateEndValue; + return JsonScanContinue; + } + if (c == '\\') + { + scanner.stepFunc = stateInStringEsc; + return JsonScanContinue; + } + if (static_cast(c) < 0x20) + { + return scanner.genError("in string literal", c); + } + return JsonScanContinue; +} + +// stateInStringEsc is the state after reading `"\` during a quoted string. +JsonScanAction stateInStringEsc(JsonScanner & scanner, char c) +{ + switch (c) + { + case 'b': + case 'f': + case 'n': + case 'r': + case 't': + case '\\': + case '/': + case '"': + scanner.stepFunc = stateInString; + return JsonScanContinue; + case 'u': + scanner.stepFunc = stateInStringEscU; + return JsonScanContinue; + default: + break; + } + return scanner.genError("in string escape code", c); +} + +// stateInStringEscU is the state after reading `"\u` during a quoted string. +JsonScanAction stateInStringEscU(JsonScanner & scanner, char c) +{ + if (('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')) + { + scanner.stepFunc = stateInStringEscU1; + return JsonScanContinue; + } + // numbers + return scanner.genError("in \\u hexadecimal character escape", c); +} + +// stateInStringEscU1 is the state after reading `"\u1` during a quoted string. 
+JsonScanAction stateInStringEscU1(JsonScanner & scanner, char c) +{ + if (('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')) + { + scanner.stepFunc = stateInStringEscU12; + return JsonScanContinue; + } + // numbers + return scanner.genError("in \\u hexadecimal character escape", c); +} + +// stateInStringEscU12 is the state after reading `"\u12` during a quoted string. +JsonScanAction stateInStringEscU12(JsonScanner & scanner, char c) +{ + if (('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')) + { + scanner.stepFunc = stateInStringEscU123; + return JsonScanContinue; + } + // numbers + return scanner.genError("in \\u hexadecimal character escape", c); +} + +// stateInStringEscU123 is the state after reading `"\u123` during a quoted string. +JsonScanAction stateInStringEscU123(JsonScanner & scanner, char c) +{ + if (('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || ('A' <= c && c <= 'F')) + { + scanner.stepFunc = stateInString; + return JsonScanContinue; + } + // numbers + return scanner.genError("in \\u hexadecimal character escape", c); +} + +// stateNeg is the state after reading `-` during a number. +JsonScanAction stateNeg(JsonScanner & scanner, char c) +{ + if (c == '0') + { + scanner.stepFunc = state0; + return JsonScanContinue; + } + if ('1' <= c && c <= '9') + { + scanner.stepFunc = state1; + return JsonScanContinue; + } + return scanner.genError("in numeric literal", c); +} + +// state1 is the state after reading a non-zero integer during a number, +// such as after reading `1` or `100` but not `0`. +JsonScanAction state1(JsonScanner & scanner, char c) +{ + if ('0' <= c && c <= '9') + { + scanner.stepFunc = state1; + return JsonScanContinue; + } + return state0(scanner, c); +} + +// state0 is the state after reading `0` during a number. +JsonScanAction state0(JsonScanner & scanner, char c) +{ + if (c == '.') + { + scanner.stepFunc = stateDot; + return JsonScanContinue; + } + if (c == 'e' || c == 'E') + { + scanner.stepFunc = stateE; + return JsonScanContinue; + } + return stateEndValue(scanner, c); +} + +// stateDot is the state after reading the integer and decimal point in a number, +// such as after reading `1.`. +JsonScanAction stateDot(JsonScanner & scanner, char c) +{ + if ('0' <= c && c <= '9') + { + scanner.stepFunc = stateDot0; + return JsonScanContinue; + } + return scanner.genError("after decimal point in numeric literal", c); +} + +// stateDot0 is the state after reading the integer, decimal point, and subsequent +// digits of a number, such as after reading `3.14`. +JsonScanAction stateDot0(JsonScanner & scanner, char c) +{ + if ('0' <= c && c <= '9') + { + return JsonScanContinue; + } + if (c == 'e' || c == 'E') + { + scanner.stepFunc = stateE; + return JsonScanContinue; + } + return stateEndValue(scanner, c); +} + +// stateE is the state after reading the mantissa and e in a number, +// such as after reading `314e` or `0.314e`. +JsonScanAction stateE(JsonScanner & scanner, char c) +{ + if (c == '+' || c == '-') + { + scanner.stepFunc = stateESign; + return JsonScanContinue; + } + return stateESign(scanner, c); +} + +// stateESign is the state after reading the mantissa, e, and sign in a number, +// such as after reading `314e-` or `0.314e+`. 
+JsonScanAction stateESign(JsonScanner & scanner, char c) +{ + if ('0' <= c && c <= '9') + { + scanner.stepFunc = stateE0; + return JsonScanContinue; + } + return scanner.genError("in exponent of numeric literal", c); +} + +// stateE0 is the state after reading the mantissa, e, optional sign, +// and at least one digit of the exponent in a number, +// such as after reading `314e-2` or `0.314e+1` or `3.14e0`. +JsonScanAction stateE0(JsonScanner & scanner, char c) +{ + if ('0' <= c && c <= '9') + { + return JsonScanContinue; + } + return stateEndValue(scanner, c); +} + +// stateT is the state after reading `t`. +JsonScanAction stateT(JsonScanner & scanner, char c) +{ + if (c == 'r') + { + scanner.stepFunc = stateTr; + return JsonScanContinue; + } + return scanner.genError("in literal true (expecting 'r')", c); +} + +// stateTr is the state after reading `tr`. +JsonScanAction stateTr(JsonScanner & scanner, char c) +{ + if (c == 'u') + { + scanner.stepFunc = stateTru; + return JsonScanContinue; + } + return scanner.genError("in literal true (expecting 'u')", c); +} + +// stateTru is the state after reading `tru`. +JsonScanAction stateTru(JsonScanner & scanner, char c) +{ + if (c == 'e') + { + scanner.stepFunc = stateEndValue; + return JsonScanContinue; + } + return scanner.genError("in literal true (expecting 'e')", c); +} + +// stateF is the state after reading `f`. +JsonScanAction stateF(JsonScanner & scanner, char c) +{ + if (c == 'a') + { + scanner.stepFunc = stateFa; + return JsonScanContinue; + } + return scanner.genError("in literal false (expecting 'a')", c); +} + +// stateFa is the state after reading `fa`. +JsonScanAction stateFa(JsonScanner & scanner, char c) +{ + if (c == 'l') + { + scanner.stepFunc = stateFal; + return JsonScanContinue; + } + return scanner.genError("in literal false (expecting 'l')", c); +} + +// stateFal is the state after reading `fal`. +JsonScanAction stateFal(JsonScanner & scanner, char c) +{ + if (c == 's') + { + scanner.stepFunc = stateFals; + return JsonScanContinue; + } + return scanner.genError("in literal false (expecting 's')", c); +} + +// stateFals is the state after reading `fals`. +JsonScanAction stateFals(JsonScanner & scanner, char c) +{ + if (c == 'e') + { + scanner.stepFunc = stateEndValue; + return JsonScanContinue; + } + return scanner.genError("in literal false (expecting 'e')", c); +} + +// stateN is the state after reading `n`. +JsonScanAction stateN(JsonScanner & scanner, char c) +{ + if (c == 'u') + { + scanner.stepFunc = stateNu; + return JsonScanContinue; + } + return scanner.genError("in literal null (expecting 'u')", c); +} + +// stateNu is the state after reading `nu`. +JsonScanAction stateNu(JsonScanner & scanner, char c) +{ + if (c == 'l') + { + scanner.stepFunc = stateNul; + return JsonScanContinue; + } + return scanner.genError("in literal null (expecting 'l')", c); +} + +// stateNul is the state after reading `nul`. +JsonScanAction stateNul(JsonScanner & scanner, char c) +{ + if (c == 'l') + { + scanner.stepFunc = stateEndValue; + return JsonScanContinue; + } + return scanner.genError("in literal null (expecting 'l')", c); +} + +// stateError is the state after reaching a syntax error, +// such as after reading `[1}` or `5.1.2`. 
+JsonScanAction stateError(JsonScanner &, char) +{ + return JsonScanEnd; +} +} // namespace DB \ No newline at end of file diff --git a/dbms/src/TiDB/Decode/JsonScanner.h b/dbms/src/TiDB/Decode/JsonScanner.h new file mode 100644 index 00000000000..a18b5f3c937 --- /dev/null +++ b/dbms/src/TiDB/Decode/JsonScanner.h @@ -0,0 +1,203 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include + +namespace DB +{ +// C++ implementation of https://github.com/golang/go/blob/master/src/encoding/json/scanner.go +// Most of the comments are copied from go-lib directly. + +// These values are returned by the state transition functions +// assigned to scanner.state and the method scanner.eof. +// They give details about the current state of the scan that +// callers might be interested to know about. +// It is okay to ignore the return value of any particular +// call to scanner.state: if one call returns scanError, +// every subsequent call will return scanError too. +enum JsonScanAction +{ + JsonScanContinue, // uninteresting byte + JsonScanBeginLiteral, // end implied by next result != scanContinue + JsonScanBeginObject, // begin object + JsonScanObjectKey, // just finished object key (string) + JsonScanObjectValue, // just finished non-last object value + JsonScanEndObject, // end object (implies scanObjectValue if possible) + JsonScanBeginArray, // begin array + JsonScanArrayValue, // just finished array value + JsonScanEndArray, // end array (implies scanArrayValue if possible) + JsonScanSkipSpace, // space byte; can skip; known to be last "continue" result + // Stop. + JsonScanEnd, // top-level value ended *before* this byte; known to be first "stop" result + JsonScanError, // hit an error, scanner.err. +}; + +// These values are stored in the parseState stack. +// They give the current state of a composite value +// being scanned. If the parser is inside a nested value +// the parseState describes the nested state, outermost at entry 0. +enum JsonScanState +{ + JsonScanParseObjectKey, // parsing object key (before colon) + JsonScanParseObjectValue, // parsing object value (after colon) + JsonScanParseArrayValue, // parsing array value +}; + +struct JsonSyntaxError +{ + String msg; + Int64 offset; +}; + +struct JsonScanner; + +// This limits the max nesting depth to prevent stack overflow. 
+// This is permitted by https://tools.ietf.org/html/rfc7159#section-9 +const static Int64 kJsonMaxNestingDepth = 10000; + +JsonScanAction stateBeginValueOrEmpty(JsonScanner & scanner, char c); +JsonScanAction stateBeginValue(JsonScanner & scanner, char c); +JsonScanAction stateBeginStringOrEmpty(JsonScanner & scanner, char c); +JsonScanAction stateBeginString(JsonScanner & scanner, char c); +JsonScanAction stateEndValue(JsonScanner & scanner, char c); +JsonScanAction stateEndTop(JsonScanner & scanner, char c); +JsonScanAction stateInString(JsonScanner & scanner, char c); +JsonScanAction stateInStringEsc(JsonScanner & scanner, char c); +JsonScanAction stateInStringEscU(JsonScanner & scanner, char c); +JsonScanAction stateInStringEscU1(JsonScanner & scanner, char c); +JsonScanAction stateInStringEscU12(JsonScanner & scanner, char c); +JsonScanAction stateInStringEscU123(JsonScanner & scanner, char c); +JsonScanAction stateNeg(JsonScanner & scanner, char c); +JsonScanAction state1(JsonScanner & scanner, char c); +JsonScanAction state0(JsonScanner & scanner, char c); +JsonScanAction stateDot(JsonScanner & scanner, char c); +JsonScanAction stateDot0(JsonScanner & scanner, char c); +JsonScanAction stateE(JsonScanner & scanner, char c); +JsonScanAction stateESign(JsonScanner & scanner, char c); +JsonScanAction stateE0(JsonScanner & scanner, char c); +JsonScanAction stateT(JsonScanner & scanner, char c); +JsonScanAction stateTr(JsonScanner & scanner, char c); +JsonScanAction stateTru(JsonScanner & scanner, char c); +JsonScanAction stateF(JsonScanner & scanner, char c); +JsonScanAction stateFa(JsonScanner & scanner, char c); +JsonScanAction stateFal(JsonScanner & scanner, char c); +JsonScanAction stateFals(JsonScanner & scanner, char c); +JsonScanAction stateN(JsonScanner & scanner, char c); +JsonScanAction stateNu(JsonScanner & scanner, char c); +JsonScanAction stateNul(JsonScanner & scanner, char c); +JsonScanAction stateError(JsonScanner & scanner, char c); + +// A scanner is a JSON scanning state machine. +// Callers call scan.reset and then pass bytes in one at a time +// by calling scan.step(&scan, c) for each byte. +// The return value, referred to as an opcode, tells the +// caller about significant parsing events like beginning +// and ending literals, objects, and arrays, so that the +// caller can follow along if it wishes. +// The return value scanEnd indicates that a single top-level +// JSON value has been completed, *before* the byte that +// just got passed in. (The indication must be delayed in order +// to recognize the end of numbers: is 123 a whole value or +// the beginning of 12345e+6?). +struct JsonScanner +{ + // error records an error and switches to the error state. + JsonScanAction genError(String msg, char c) + { + stepFunc = &stateError; + error.msg = fmt::format("invalid character {} {}", c, msg); + error.offset = bytes; + return JsonScanError; + } + + // eof tells the scanner that the end of input has been reached. + // It returns a scan status just as s.step does. + JsonScanAction eof() + { + if (!error.msg.empty()) + { + return JsonScanError; + } + if (end_top) + { + return JsonScanEnd; + } + stepFunc(*this, ' '); + if (end_top) + { + return JsonScanEnd; + } + if (!error.msg.empty()) + { + error.msg = "unexpected end of JSON input"; + error.offset = bytes; + } + return JsonScanError; + } + + // pushParseState pushes a new parse state p onto the parse stack. + // an error state is returned if maxNestingDepth was exceeded, otherwise successState is returned. 
+ JsonScanAction pushParseState(char c, JsonScanState new_parse_state, JsonScanAction success_action) + { + parse_state_stack.push(new_parse_state); + if (parse_state_stack.size() <= kJsonMaxNestingDepth) + { + return success_action; + } + return genError("exceeded max depth", c); + } + + // popParseState pops a parse state (already obtained) off the stack + // and updates s.step accordingly. + void popParseState() + { + parse_state_stack.pop(); + if (parse_state_stack.empty()) + { + stepFunc = &stateEndTop; + end_top = true; + } + else + { + stepFunc = stateEndValue; + } + } + + // The step is a func to be called to execute the next transition. + // Also tried using an integer constant and a single func + // with a switch, but using the func directly was 10% faster + // on a 64-bit Mac Mini, and it's nicer to read. + JsonScanAction (*stepFunc)(JsonScanner &, char) = stateBeginValue; + + // Reached end of top-level value. + bool end_top = false; + // Stack of what we're in the middle of - array values, object keys, object values. + std::stack parse_state_stack; + // Error that happened, if any. + JsonSyntaxError error; + // total bytes consumed, updated by decoder.Decode (and deliberately + // not set to zero by scan.reset) + Int64 bytes = 0; +}; + +// checkJsonValid verifies that data is valid JSON-encoded data. +bool checkJsonValid(const char * data, size_t length); +} // namespace DB \ No newline at end of file diff --git a/tests/fullstack-test/expr/cast_json_as_string.test b/tests/fullstack-test/expr/cast_json_as_string.test index d124e7debc9..728739f4eb0 100644 --- a/tests/fullstack-test/expr/cast_json_as_string.test +++ b/tests/fullstack-test/expr/cast_json_as_string.test @@ -33,3 +33,38 @@ mysql> set tidb_enforce_mpp=1; select cast(a as char) from test.t; | 3.01 | | "2020-01-02" | +-----------------+ + +mysql> set tidb_enforce_mpp=1; select cast(a as char(3)) from test.t; ++--------------------+ +| cast(a as char(3)) | ++--------------------+ +| NULL | +| [1, | +| {"a | +| 3.0 | +| "20 | ++--------------------+ + +mysql> set tidb_enforce_mpp=1; select cast(a as char(30)) as col from test.t; ++--------------+ +| col | ++--------------+ +| NULL | +| [1, 2] | +| {"a": "b"} | +| 3.01 | +| "2020-01-02" | ++--------------+ + +mysql> set tidb_enforce_mpp=1; select char_length(cast(a as char(0))) as col from test.t; ++------+ +| col | ++------+ +| NULL | +| 0 | +| 0 | +| 0 | +| 0 | ++------+ + +mysql> drop table if exists test.t \ No newline at end of file diff --git a/tests/fullstack-test/expr/json_extract.test b/tests/fullstack-test/expr/json_extract.test index 3693f2e1bd4..1514718273e 100644 --- a/tests/fullstack-test/expr/json_extract.test +++ b/tests/fullstack-test/expr/json_extract.test @@ -46,3 +46,5 @@ mysql> set tidb_enforce_mpp=1; select json_extract(d, '\$[0]', '\$[1]', '\$[2].a +-------------+ | [1, 2, "b"] | +-------------+ + +mysql> drop table if exists test.t diff --git a/tests/fullstack-test/expr/json_unquote.test b/tests/fullstack-test/expr/json_unquote.test new file mode 100644 index 00000000000..22d440e14cf --- /dev/null +++ b/tests/fullstack-test/expr/json_unquote.test @@ -0,0 +1,36 @@ +# Copyright 2023 PingCAP, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +mysql> drop table if exists test.t +mysql> create table test.t(a varchar(20)) +mysql> alter table test.t set tiflash replica 1 +mysql> insert into test.t values(null) +mysql> insert into test.t values('\"hello world\"') #NO_UNESCAPE + +func> wait_table test t + +mysql> set tidb_enforce_mpp=1; select json_unquote(a) as col from test.t ++-------------+ +| col | ++-------------+ +| NULL | +| hello world | ++-------------+ + +mysql> insert into test.t values('\"hello\\\\ \"') #NO_UNESCAPE + +mysql> set tidb_enforce_mpp=1; select json_unquote(a) as col from test.t +{#REGEXP}.*Invalid JSON text: The document root must not be followed by other values.* + +mysql> drop table if exists test.t \ No newline at end of file diff --git a/tests/fullstack-test/expr/json_unquote_extract.test b/tests/fullstack-test/expr/json_unquote_extract.test index 330699c6ba9..ec2e5de2ff1 100644 --- a/tests/fullstack-test/expr/json_unquote_extract.test +++ b/tests/fullstack-test/expr/json_unquote_extract.test @@ -32,3 +32,5 @@ mysql> set tidb_enforce_mpp=1; select b->>'\$.*' as col_a, c->>'\$[*]' as col_b, +-------------+----------------+-------------+ | ["b", "bb"] | [1, 2, [3, 4]] | ["b"] | +-------------+----------------+-------------+ + +mysql> drop table if exists test.t
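
The new scanner added in `dbms/src/TiDB/Decode/JsonScanner.h` exposes a single entry point, `DB::checkJsonValid`, which feeds the input one byte at a time through the state-machine `stepFunc` and then consults `eof()`; this is the check behind the validity branch of `FunctionJsonUnquote::doUnquote`. Below is a minimal, self-contained usage sketch, not part of the diff itself: the include path and the standalone `main()` are assumptions for illustration, and the sample strings are taken from the new gtest cases.

```cpp
// Minimal sketch of driving DB::checkJsonValid (assumed include path; the
// standalone main() is illustrative only and not part of this change).
#include <TiDB/Decode/JsonScanner.h>

#include <cassert>
#include <string>

int main()
{
    // A single top-level JSON value scans cleanly, so the check passes.
    std::string valid = R"({"a":3.0, "b":-4, "c":"hello world", "d":true})";
    assert(DB::checkJsonValid(valid.data(), valid.size()));

    // Trailing bytes after the root value trip the scanner; json_unquote
    // reports this as "The document root must not be followed by other values".
    std::string invalid = R"("hello world"ef)";
    assert(!DB::checkJsonValid(invalid.data(), invalid.size()));
    return 0;
}
```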