From db4d19bb84505c998a6ab52e9f3c3ae0b83009a2 Mon Sep 17 00:00:00 2001 From: YangKeao Date: Thu, 21 Nov 2024 00:00:41 +0800 Subject: [PATCH] expression: fix the arg verification for json functions. (#54145) close pingcap/tidb#54029, close pingcap/tidb#54044 --- pkg/expression/builtin_json.go | 123 +++++++++--------- pkg/parser/mysql/errcode.go | 1 + pkg/parser/mysql/errname.go | 1 + pkg/parser/mysql/state.go | 1 + pkg/types/field_type.go | 4 + .../integrationtest/r/expression/json.result | 50 ++++++- tests/integrationtest/t/expression/json.test | 50 ++++++- 7 files changed, 165 insertions(+), 65 deletions(-) diff --git a/pkg/expression/builtin_json.go b/pkg/expression/builtin_json.go index 0dbd252da8420..179c948baa98c 100644 --- a/pkg/expression/builtin_json.go +++ b/pkg/expression/builtin_json.go @@ -18,6 +18,7 @@ import ( "bytes" "context" goJSON "encoding/json" + "strconv" "strings" "github.com/pingcap/errors" @@ -108,7 +109,7 @@ func (b *builtinJSONTypeSig) Clone() builtinFunc { } func (c *jsonTypeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { return nil, err } bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETJson) @@ -125,6 +126,51 @@ func (c *jsonTypeFunctionClass) getFunction(ctx BuildContext, args []Expression) return sig, nil } +func (c *jsonTypeFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + return verifyJSONArgsType(ctx, c.funcName, true, args, 0) +} + +// verifyJSONArgsType verifies that all args specified in `jsonArgsIndex` are JSON or non-binary string or NULL. +// the `useJSONErr` specifies to use `ErrIncorrectType` or `ErrInvalidTypeForJSON`. If it's true, the error will be `ErrInvalidTypeForJSON` +func verifyJSONArgsType(ctx EvalContext, funcName string, useJSONErr bool, args []Expression, jsonArgsIndex ...int) error { + if jsonArgsIndex == nil { + // if no index is specified, verify all args + jsonArgsIndex = make([]int, len(args)) + for i := 0; i < len(args); i++ { + jsonArgsIndex[i] = i + } + } + for _, argIndex := range jsonArgsIndex { + arg := args[argIndex] + + typ := arg.GetType(ctx) + if typ.GetType() == mysql.TypeNull { + continue + } + + evalType := typ.EvalType() + switch evalType { + case types.ETString: + cs := typ.GetCharset() + if cs == charset.CharsetBin { + return types.ErrInvalidJSONCharset.GenWithStackByArgs(cs) + } + continue + case types.ETJson: + continue + default: + if useJSONErr { + return ErrInvalidTypeForJSON.GenWithStackByArgs(argIndex+1, funcName) + } + return ErrIncorrectType.GenWithStackByArgs(strconv.Itoa(argIndex+1), funcName) + } + } + return nil +} + func (b *builtinJSONTypeSig) evalString(ctx EvalContext, row chunk.Row) (val string, isNull bool, err error) { var j types.BinaryJSON j, isNull, err = b.args[0].EvalJSON(ctx, row) @@ -155,10 +201,7 @@ func (c *jsonExtractFunctionClass) verifyArgs(ctx EvalContext, args []Expression if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_extract") - } - return nil + return verifyJSONArgsType(ctx, c.funcName, true, args, 0) } func (c *jsonExtractFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -225,10 +268,7 @@ func (c *jsonUnquoteFunctionClass) verifyArgs(ctx EvalContext, args []Expression if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrIncorrectType.GenWithStackByArgs("1", "json_unquote") - } - return nil + return verifyJSONArgsType(ctx, c.funcName, false, args, 0) } func (c *jsonUnquoteFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -469,12 +509,7 @@ func (c *jsonMergeFunctionClass) verifyArgs(ctx EvalContext, args []Expression) if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - for i, arg := range args { - if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge") - } - } - return nil + return verifyJSONArgsType(ctx, c.funcName, true, args) } type builtinJSONMergeSig struct { @@ -682,10 +717,7 @@ func (c *jsonContainsPathFunctionClass) verifyArgs(ctx EvalContext, args []Expre if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_contains_path") - } - return nil + return verifyJSONArgsType(ctx, c.funcName, true, args, 0) } func (c *jsonContainsPathFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -801,10 +833,7 @@ func (c *jsonMemberOfFunctionClass) verifyArgs(ctx EvalContext, args []Expressio if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { - return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "member of") - } - return nil + return verifyJSONArgsType(ctx, "member of", true, args, 1) } func (c *jsonMemberOfFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -867,13 +896,7 @@ func (c *jsonContainsFunctionClass) verifyArgs(ctx EvalContext, args []Expressio if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_contains") - } - if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { - return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_contains") - } - return nil + return verifyJSONArgsType(ctx, c.funcName, true, args, 0, 1) } func (c *jsonContainsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -950,13 +973,7 @@ func (c *jsonOverlapsFunctionClass) verifyArgs(ctx EvalContext, args []Expressio if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_overlaps") - } - if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { - return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_overlaps") - } - return nil + return verifyJSONArgsType(ctx, c.funcName, true, args, 0, 1) } func (c *jsonOverlapsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -1284,12 +1301,7 @@ func (c *jsonMergePatchFunctionClass) verifyArgs(ctx EvalContext, args []Express if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - for i, arg := range args { - if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge_patch") - } - } - return nil + return verifyJSONArgsType(ctx, c.funcName, true, args) } func (c *jsonMergePatchFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -1356,12 +1368,7 @@ func (c *jsonMergePreserveFunctionClass) verifyArgs(ctx EvalContext, args []Expr if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - for i, arg := range args { - if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge_preserve") - } - } - return nil + return verifyJSONArgsType(ctx, c.funcName, true, args) } func (c *jsonMergePreserveFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -1510,10 +1517,7 @@ func (c *jsonSearchFunctionClass) verifyArgs(ctx EvalContext, args []Expression) if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_search") - } - return nil + return verifyJSONArgsType(ctx, c.funcName, true, args, 0) } func (c *jsonSearchFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -1728,10 +1732,7 @@ func (c *jsonKeysFunctionClass) verifyArgs(ctx EvalContext, args []Expression) e if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_keys") - } - return nil + return verifyJSONArgsType(ctx, c.funcName, true, args, 0) } func (c *jsonKeysFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { @@ -1903,11 +1904,9 @@ func (c *jsonSchemaValidFunctionClass) verifyArgs(ctx EvalContext, args []Expres if err := c.baseFunctionClass.verifyArgs(args); err != nil { return err } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_schema_valid") - } - if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_schema_valid") + + if err := verifyJSONArgsType(ctx, c.funcName, true, args, 0, 1); err != nil { + return err } if c, ok := args[0].(*Constant); ok { // If args[0] is NULL, then don't check the length of *both* arguments. diff --git a/pkg/parser/mysql/errcode.go b/pkg/parser/mysql/errcode.go index 05d5dc3e69d9a..bfbd196138de9 100644 --- a/pkg/parser/mysql/errcode.go +++ b/pkg/parser/mysql/errcode.go @@ -898,6 +898,7 @@ const ( ErrInvalidJSONText = 3140 ErrInvalidJSONTextInParam = 3141 ErrInvalidJSONPath = 3143 + ErrInvalidJSONCharset = 3144 ErrInvalidTypeForJSON = 3146 ErrInvalidJSONPathWildcard = 3149 ErrInvalidJSONContainsPathType = 3150 diff --git a/pkg/parser/mysql/errname.go b/pkg/parser/mysql/errname.go index f757f825da43e..845c50a21bcec 100644 --- a/pkg/parser/mysql/errname.go +++ b/pkg/parser/mysql/errname.go @@ -907,6 +907,7 @@ var MySQLErrName = map[uint16]*ErrMessage{ ErrInvalidJSONText: Message("Invalid JSON text: %-.192s", nil), ErrInvalidJSONTextInParam: Message("Invalid JSON text in argument %d to function %s: \"%s\" at position %d.", nil), ErrInvalidJSONPath: Message("Invalid JSON path expression %s.", nil), + ErrInvalidJSONCharset: Message("Cannot create a JSON value from a string with CHARACTER SET '%s'.", nil), ErrInvalidTypeForJSON: Message("Invalid data type for JSON data in argument %d to function %s; a JSON string or JSON type is required.", nil), ErrInvalidJSONPathWildcard: Message("In this situation, path expressions may not contain the * and ** tokens or an array range.", nil), ErrInvalidJSONContainsPathType: Message("The second argument can only be either 'one' or 'all'.", nil), diff --git a/pkg/parser/mysql/state.go b/pkg/parser/mysql/state.go index 2cbc6f1d2b168..307965d08838a 100644 --- a/pkg/parser/mysql/state.go +++ b/pkg/parser/mysql/state.go @@ -254,6 +254,7 @@ var MySQLState = map[uint16]string{ ErrInvalidJSONText: "22032", ErrInvalidJSONTextInParam: "22032", ErrInvalidJSONPath: "42000", + ErrInvalidJSONCharset: "22032", ErrInvalidJSONData: "22032", ErrInvalidJSONPathWildcard: "42000", ErrJSONUsedAsKey: "42000", diff --git a/pkg/types/field_type.go b/pkg/types/field_type.go index bfa84118d8c93..6a2e562c0b699 100644 --- a/pkg/types/field_type.go +++ b/pkg/types/field_type.go @@ -197,6 +197,10 @@ func InferParamTypeFromUnderlyingValue(value any, tp *FieldType) { tp.SetType(mysql.TypeVarString) tp.SetFlen(UnspecifiedLength) tp.SetDecimal(UnspecifiedLength) + // Also set the `charset` and `collation` for it, because some function (e.g. `json_object`) will return error + // if the argument collation is `binary`. + tp.SetCharset(mysql.DefaultCharset) + tp.SetCollate(mysql.DefaultCollationName) default: DefaultTypeForValue(value, tp, mysql.DefaultCharset, mysql.DefaultCollationName) if hasVariantFieldLength(tp) { diff --git a/tests/integrationtest/r/expression/json.result b/tests/integrationtest/r/expression/json.result index 2bb62441b94ce..f240ddda5ed90 100644 --- a/tests/integrationtest/r/expression/json.result +++ b/tests/integrationtest/r/expression/json.result @@ -518,9 +518,9 @@ select json_objectagg(a, b) from t; json_objectagg(a, b) {"a string": "base64:type252:YSBiaW5hcnkgc3RyaW5n"} select json_object(b, a) from t; -Error 3144 (HY000): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. select json_objectagg(b, a) from t; -Error 3144 (HY000): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. select cast(cast(b'010101' as json) as signed); cast(cast(b'010101' as json) as signed) 0 @@ -771,3 +771,49 @@ insert into t values (NULL, NULL, NULL); select json_valid(j), json_valid(str), json_valid(other) from t; json_valid(j) json_valid(str) json_valid(other) NULL NULL NULL +DROP TABLE IF EXISTS t1; +CREATE TABLE t1(id INT PRIMARY KEY, d1 DATE, d2 DATETIME, t1 TIME, t2 TIMESTAMP, b1 BIT, b2 BINARY); +INSERT INTO t1 VALUES (1, '2024-06-14', '2024-06-14 09:37:00', '09:37:00', '2024-06-14 09:37:00', b'0', 0x41); +SELECT JSON_TYPE(d1) FROM t1; +Error 3146 (22032): Invalid data type for JSON data in argument 1 to function json_type; a JSON string or JSON type is required. +SELECT JSON_TYPE(d2) FROM t1; +Error 3146 (22032): Invalid data type for JSON data in argument 1 to function json_type; a JSON string or JSON type is required. +SELECT JSON_TYPE(t1) FROM t1; +Error 3146 (22032): Invalid data type for JSON data in argument 1 to function json_type; a JSON string or JSON type is required. +SELECT JSON_TYPE(t2) FROM t1; +Error 3146 (22032): Invalid data type for JSON data in argument 1 to function json_type; a JSON string or JSON type is required. +SELECT JSON_TYPE(b1) FROM t1; +Error 3146 (22032): Invalid data type for JSON data in argument 1 to function json_type; a JSON string or JSON type is required. +SELECT JSON_TYPE(b2) FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_EXTRACT(b2, '$') FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_MERGE(b2, '{a:"b"}') FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_CONTAINS_PATH(b2, 'one', '$.a') FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT '1' member of(b2) FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_CONTAINS(b2, '{a:"b"}') FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_OVERLAPS(b2, '{a:"b"}') FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_MERGE_PATCH(b2, '{a:"b"}') FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_MERGE_PATCH('{a:"b"}', b2) FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_MERGE_PRESERVE(b2, '{a:"b"}') FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_MERGE_PRESERVE('{a:"b"}', b2) FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_SEARCH(b2, 'one', '1') FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_KEYS(b2) FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +SELECT JSON_SCHEMA_VALID(b2, '{}') FROM t1; +Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'. +prepare stmt from 'select json_object(?, ?)'; +set @a=1; +execute stmt using @a, @a; +json_object(?, ?) +{"1": 1} diff --git a/tests/integrationtest/t/expression/json.test b/tests/integrationtest/t/expression/json.test index 15e255ce696a1..f2925c440fb85 100644 --- a/tests/integrationtest/t/expression/json.test +++ b/tests/integrationtest/t/expression/json.test @@ -489,4 +489,52 @@ select json_type(cast(cast('2024' as year) as json)); drop table if exists t; create table t(j json, str varchar(255), other int); insert into t values (NULL, NULL, NULL); -select json_valid(j), json_valid(str), json_valid(other) from t; \ No newline at end of file +select json_valid(j), json_valid(str), json_valid(other) from t; + +# TestIssue54029 +DROP TABLE IF EXISTS t1; +CREATE TABLE t1(id INT PRIMARY KEY, d1 DATE, d2 DATETIME, t1 TIME, t2 TIMESTAMP, b1 BIT, b2 BINARY); +INSERT INTO t1 VALUES (1, '2024-06-14', '2024-06-14 09:37:00', '09:37:00', '2024-06-14 09:37:00', b'0', 0x41); +-- error 3146 +SELECT JSON_TYPE(d1) FROM t1; +-- error 3146 +SELECT JSON_TYPE(d2) FROM t1; +-- error 3146 +SELECT JSON_TYPE(t1) FROM t1; +-- error 3146 +SELECT JSON_TYPE(t2) FROM t1; +-- error 3146 +SELECT JSON_TYPE(b1) FROM t1; +-- error 3144 +SELECT JSON_TYPE(b2) FROM t1; +-- error 3144 +SELECT JSON_EXTRACT(b2, '$') FROM t1; +-- error 3144 +SELECT JSON_MERGE(b2, '{a:"b"}') FROM t1; +-- error 3144 +SELECT JSON_CONTAINS_PATH(b2, 'one', '$.a') FROM t1; +-- error 3144 +SELECT '1' member of(b2) FROM t1; +-- error 3144 +SELECT JSON_CONTAINS(b2, '{a:"b"}') FROM t1; +-- error 3144 +SELECT JSON_OVERLAPS(b2, '{a:"b"}') FROM t1; +-- error 3144 +SELECT JSON_MERGE_PATCH(b2, '{a:"b"}') FROM t1; +-- error 3144 +SELECT JSON_MERGE_PATCH('{a:"b"}', b2) FROM t1; +-- error 3144 +SELECT JSON_MERGE_PRESERVE(b2, '{a:"b"}') FROM t1; +-- error 3144 +SELECT JSON_MERGE_PRESERVE('{a:"b"}', b2) FROM t1; +-- error 3144 +SELECT JSON_SEARCH(b2, 'one', '1') FROM t1; +-- error 3144 +SELECT JSON_KEYS(b2) FROM t1; +-- error 3144 +SELECT JSON_SCHEMA_VALID(b2, '{}') FROM t1; + +# TestIssue54044 +prepare stmt from 'select json_object(?, ?)'; +set @a=1; +execute stmt using @a, @a; \ No newline at end of file