Skip to content

Commit

Permalink
expression: fix the arg verification for json functions. (#54145)
Browse files Browse the repository at this point in the history
close #54029, close #54044
  • Loading branch information
YangKeao authored Nov 20, 2024
1 parent 865b283 commit db4d19b
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 65 deletions.
123 changes: 61 additions & 62 deletions pkg/expression/builtin_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"context"
goJSON "encoding/json"
"strconv"
"strings"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -108,7 +109,7 @@ func (b *builtinJSONTypeSig) Clone() builtinFunc {
}

func (c *jsonTypeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil {
return nil, err
}
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETJson)
Expand All @@ -125,6 +126,51 @@ func (c *jsonTypeFunctionClass) getFunction(ctx BuildContext, args []Expression)
return sig, nil
}

func (c *jsonTypeFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error {
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
return verifyJSONArgsType(ctx, c.funcName, true, args, 0)
}

// verifyJSONArgsType verifies that all args specified in `jsonArgsIndex` are JSON or non-binary string or NULL.
// the `useJSONErr` specifies to use `ErrIncorrectType` or `ErrInvalidTypeForJSON`. If it's true, the error will be `ErrInvalidTypeForJSON`
func verifyJSONArgsType(ctx EvalContext, funcName string, useJSONErr bool, args []Expression, jsonArgsIndex ...int) error {
if jsonArgsIndex == nil {
// if no index is specified, verify all args
jsonArgsIndex = make([]int, len(args))
for i := 0; i < len(args); i++ {
jsonArgsIndex[i] = i
}
}
for _, argIndex := range jsonArgsIndex {
arg := args[argIndex]

typ := arg.GetType(ctx)
if typ.GetType() == mysql.TypeNull {
continue
}

evalType := typ.EvalType()
switch evalType {
case types.ETString:
cs := typ.GetCharset()
if cs == charset.CharsetBin {
return types.ErrInvalidJSONCharset.GenWithStackByArgs(cs)
}
continue
case types.ETJson:
continue
default:
if useJSONErr {
return ErrInvalidTypeForJSON.GenWithStackByArgs(argIndex+1, funcName)
}
return ErrIncorrectType.GenWithStackByArgs(strconv.Itoa(argIndex+1), funcName)
}
}
return nil
}

func (b *builtinJSONTypeSig) evalString(ctx EvalContext, row chunk.Row) (val string, isNull bool, err error) {
var j types.BinaryJSON
j, isNull, err = b.args[0].EvalJSON(ctx, row)
Expand Down Expand Up @@ -155,10 +201,7 @@ func (c *jsonExtractFunctionClass) verifyArgs(ctx EvalContext, args []Expression
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_extract")
}
return nil
return verifyJSONArgsType(ctx, c.funcName, true, args, 0)
}

func (c *jsonExtractFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -225,10 +268,7 @@ func (c *jsonUnquoteFunctionClass) verifyArgs(ctx EvalContext, args []Expression
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrIncorrectType.GenWithStackByArgs("1", "json_unquote")
}
return nil
return verifyJSONArgsType(ctx, c.funcName, false, args, 0)
}

func (c *jsonUnquoteFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -469,12 +509,7 @@ func (c *jsonMergeFunctionClass) verifyArgs(ctx EvalContext, args []Expression)
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
for i, arg := range args {
if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge")
}
}
return nil
return verifyJSONArgsType(ctx, c.funcName, true, args)
}

type builtinJSONMergeSig struct {
Expand Down Expand Up @@ -682,10 +717,7 @@ func (c *jsonContainsPathFunctionClass) verifyArgs(ctx EvalContext, args []Expre
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_contains_path")
}
return nil
return verifyJSONArgsType(ctx, c.funcName, true, args, 0)
}

func (c *jsonContainsPathFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -801,10 +833,7 @@ func (c *jsonMemberOfFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString {
return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "member of")
}
return nil
return verifyJSONArgsType(ctx, "member of", true, args, 1)
}

func (c *jsonMemberOfFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -867,13 +896,7 @@ func (c *jsonContainsFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString {
return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_contains")
}
if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString {
return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_contains")
}
return nil
return verifyJSONArgsType(ctx, c.funcName, true, args, 0, 1)
}

func (c *jsonContainsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -950,13 +973,7 @@ func (c *jsonOverlapsFunctionClass) verifyArgs(ctx EvalContext, args []Expressio
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString {
return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_overlaps")
}
if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString {
return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_overlaps")
}
return nil
return verifyJSONArgsType(ctx, c.funcName, true, args, 0, 1)
}

func (c *jsonOverlapsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -1284,12 +1301,7 @@ func (c *jsonMergePatchFunctionClass) verifyArgs(ctx EvalContext, args []Express
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
for i, arg := range args {
if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge_patch")
}
}
return nil
return verifyJSONArgsType(ctx, c.funcName, true, args)
}

func (c *jsonMergePatchFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -1356,12 +1368,7 @@ func (c *jsonMergePreserveFunctionClass) verifyArgs(ctx EvalContext, args []Expr
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
for i, arg := range args {
if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge_preserve")
}
}
return nil
return verifyJSONArgsType(ctx, c.funcName, true, args)
}

func (c *jsonMergePreserveFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -1510,10 +1517,7 @@ func (c *jsonSearchFunctionClass) verifyArgs(ctx EvalContext, args []Expression)
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_search")
}
return nil
return verifyJSONArgsType(ctx, c.funcName, true, args, 0)
}

func (c *jsonSearchFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -1728,10 +1732,7 @@ func (c *jsonKeysFunctionClass) verifyArgs(ctx EvalContext, args []Expression) e
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_keys")
}
return nil
return verifyJSONArgsType(ctx, c.funcName, true, args, 0)
}

func (c *jsonKeysFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -1903,11 +1904,9 @@ func (c *jsonSchemaValidFunctionClass) verifyArgs(ctx EvalContext, args []Expres
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_schema_valid")
}
if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson {
return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_schema_valid")

if err := verifyJSONArgsType(ctx, c.funcName, true, args, 0, 1); err != nil {
return err
}
if c, ok := args[0].(*Constant); ok {
// If args[0] is NULL, then don't check the length of *both* arguments.
Expand Down
1 change: 1 addition & 0 deletions pkg/parser/mysql/errcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ const (
ErrInvalidJSONText = 3140
ErrInvalidJSONTextInParam = 3141
ErrInvalidJSONPath = 3143
ErrInvalidJSONCharset = 3144
ErrInvalidTypeForJSON = 3146
ErrInvalidJSONPathWildcard = 3149
ErrInvalidJSONContainsPathType = 3150
Expand Down
1 change: 1 addition & 0 deletions pkg/parser/mysql/errname.go
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,7 @@ var MySQLErrName = map[uint16]*ErrMessage{
ErrInvalidJSONText: Message("Invalid JSON text: %-.192s", nil),
ErrInvalidJSONTextInParam: Message("Invalid JSON text in argument %d to function %s: \"%s\" at position %d.", nil),
ErrInvalidJSONPath: Message("Invalid JSON path expression %s.", nil),
ErrInvalidJSONCharset: Message("Cannot create a JSON value from a string with CHARACTER SET '%s'.", nil),
ErrInvalidTypeForJSON: Message("Invalid data type for JSON data in argument %d to function %s; a JSON string or JSON type is required.", nil),
ErrInvalidJSONPathWildcard: Message("In this situation, path expressions may not contain the * and ** tokens or an array range.", nil),
ErrInvalidJSONContainsPathType: Message("The second argument can only be either 'one' or 'all'.", nil),
Expand Down
1 change: 1 addition & 0 deletions pkg/parser/mysql/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ var MySQLState = map[uint16]string{
ErrInvalidJSONText: "22032",
ErrInvalidJSONTextInParam: "22032",
ErrInvalidJSONPath: "42000",
ErrInvalidJSONCharset: "22032",
ErrInvalidJSONData: "22032",
ErrInvalidJSONPathWildcard: "42000",
ErrJSONUsedAsKey: "42000",
Expand Down
4 changes: 4 additions & 0 deletions pkg/types/field_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ func InferParamTypeFromUnderlyingValue(value any, tp *FieldType) {
tp.SetType(mysql.TypeVarString)
tp.SetFlen(UnspecifiedLength)
tp.SetDecimal(UnspecifiedLength)
// Also set the `charset` and `collation` for it, because some function (e.g. `json_object`) will return error
// if the argument collation is `binary`.
tp.SetCharset(mysql.DefaultCharset)
tp.SetCollate(mysql.DefaultCollationName)
default:
DefaultTypeForValue(value, tp, mysql.DefaultCharset, mysql.DefaultCollationName)
if hasVariantFieldLength(tp) {
Expand Down
50 changes: 48 additions & 2 deletions tests/integrationtest/r/expression/json.result
Original file line number Diff line number Diff line change
Expand Up @@ -518,9 +518,9 @@ select json_objectagg(a, b) from t;
json_objectagg(a, b)
{"a string": "base64:type252:YSBiaW5hcnkgc3RyaW5n"}
select json_object(b, a) from t;
Error 3144 (HY000): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
select json_objectagg(b, a) from t;
Error 3144 (HY000): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
select cast(cast(b'010101' as json) as signed);
cast(cast(b'010101' as json) as signed)
0
Expand Down Expand Up @@ -771,3 +771,49 @@ insert into t values (NULL, NULL, NULL);
select json_valid(j), json_valid(str), json_valid(other) from t;
json_valid(j) json_valid(str) json_valid(other)
NULL NULL NULL
DROP TABLE IF EXISTS t1;
CREATE TABLE t1(id INT PRIMARY KEY, d1 DATE, d2 DATETIME, t1 TIME, t2 TIMESTAMP, b1 BIT, b2 BINARY);
INSERT INTO t1 VALUES (1, '2024-06-14', '2024-06-14 09:37:00', '09:37:00', '2024-06-14 09:37:00', b'0', 0x41);
SELECT JSON_TYPE(d1) FROM t1;
Error 3146 (22032): Invalid data type for JSON data in argument 1 to function json_type; a JSON string or JSON type is required.
SELECT JSON_TYPE(d2) FROM t1;
Error 3146 (22032): Invalid data type for JSON data in argument 1 to function json_type; a JSON string or JSON type is required.
SELECT JSON_TYPE(t1) FROM t1;
Error 3146 (22032): Invalid data type for JSON data in argument 1 to function json_type; a JSON string or JSON type is required.
SELECT JSON_TYPE(t2) FROM t1;
Error 3146 (22032): Invalid data type for JSON data in argument 1 to function json_type; a JSON string or JSON type is required.
SELECT JSON_TYPE(b1) FROM t1;
Error 3146 (22032): Invalid data type for JSON data in argument 1 to function json_type; a JSON string or JSON type is required.
SELECT JSON_TYPE(b2) FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_EXTRACT(b2, '$') FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_MERGE(b2, '{a:"b"}') FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_CONTAINS_PATH(b2, 'one', '$.a') FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT '1' member of(b2) FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_CONTAINS(b2, '{a:"b"}') FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_OVERLAPS(b2, '{a:"b"}') FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_MERGE_PATCH(b2, '{a:"b"}') FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_MERGE_PATCH('{a:"b"}', b2) FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_MERGE_PRESERVE(b2, '{a:"b"}') FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_MERGE_PRESERVE('{a:"b"}', b2) FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_SEARCH(b2, 'one', '1') FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_KEYS(b2) FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
SELECT JSON_SCHEMA_VALID(b2, '{}') FROM t1;
Error 3144 (22032): Cannot create a JSON value from a string with CHARACTER SET 'binary'.
prepare stmt from 'select json_object(?, ?)';
set @a=1;
execute stmt using @a, @a;
json_object(?, ?)
{"1": 1}
50 changes: 49 additions & 1 deletion tests/integrationtest/t/expression/json.test
Original file line number Diff line number Diff line change
Expand Up @@ -489,4 +489,52 @@ select json_type(cast(cast('2024' as year) as json));
drop table if exists t;
create table t(j json, str varchar(255), other int);
insert into t values (NULL, NULL, NULL);
select json_valid(j), json_valid(str), json_valid(other) from t;
select json_valid(j), json_valid(str), json_valid(other) from t;

# TestIssue54029
DROP TABLE IF EXISTS t1;
CREATE TABLE t1(id INT PRIMARY KEY, d1 DATE, d2 DATETIME, t1 TIME, t2 TIMESTAMP, b1 BIT, b2 BINARY);
INSERT INTO t1 VALUES (1, '2024-06-14', '2024-06-14 09:37:00', '09:37:00', '2024-06-14 09:37:00', b'0', 0x41);
-- error 3146
SELECT JSON_TYPE(d1) FROM t1;
-- error 3146
SELECT JSON_TYPE(d2) FROM t1;
-- error 3146
SELECT JSON_TYPE(t1) FROM t1;
-- error 3146
SELECT JSON_TYPE(t2) FROM t1;
-- error 3146
SELECT JSON_TYPE(b1) FROM t1;
-- error 3144
SELECT JSON_TYPE(b2) FROM t1;
-- error 3144
SELECT JSON_EXTRACT(b2, '$') FROM t1;
-- error 3144
SELECT JSON_MERGE(b2, '{a:"b"}') FROM t1;
-- error 3144
SELECT JSON_CONTAINS_PATH(b2, 'one', '$.a') FROM t1;
-- error 3144
SELECT '1' member of(b2) FROM t1;
-- error 3144
SELECT JSON_CONTAINS(b2, '{a:"b"}') FROM t1;
-- error 3144
SELECT JSON_OVERLAPS(b2, '{a:"b"}') FROM t1;
-- error 3144
SELECT JSON_MERGE_PATCH(b2, '{a:"b"}') FROM t1;
-- error 3144
SELECT JSON_MERGE_PATCH('{a:"b"}', b2) FROM t1;
-- error 3144
SELECT JSON_MERGE_PRESERVE(b2, '{a:"b"}') FROM t1;
-- error 3144
SELECT JSON_MERGE_PRESERVE('{a:"b"}', b2) FROM t1;
-- error 3144
SELECT JSON_SEARCH(b2, 'one', '1') FROM t1;
-- error 3144
SELECT JSON_KEYS(b2) FROM t1;
-- error 3144
SELECT JSON_SCHEMA_VALID(b2, '{}') FROM t1;

# TestIssue54044
prepare stmt from 'select json_object(?, ?)';
set @a=1;
execute stmt using @a, @a;

0 comments on commit db4d19b

Please sign in to comment.