Skip to content

Commit

Permalink
support member of function
Browse files Browse the repository at this point in the history
  • Loading branch information
xiongjiwei committed Dec 14, 2022
1 parent b41be06 commit a91d22f
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 23 deletions.
1 change: 1 addition & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,7 @@ var funcs = map[string]functionClass{
ast.JSONMerge: &jsonMergeFunctionClass{baseFunctionClass{ast.JSONMerge, 2, -1}},
ast.JSONObject: &jsonObjectFunctionClass{baseFunctionClass{ast.JSONObject, 0, -1}},
ast.JSONArray: &jsonArrayFunctionClass{baseFunctionClass{ast.JSONArray, 0, -1}},
ast.JSONMemberOf: &jsonMemberOfFunctionClass{baseFunctionClass{ast.JSONMemberOf, 2, 2}},
ast.JSONContains: &jsonContainsFunctionClass{baseFunctionClass{ast.JSONContains, 2, 3}},
ast.JSONOverlaps: &jsonOverlapsFunctionClass{baseFunctionClass{ast.JSONOverlaps, 2, 2}},
ast.JSONContainsPath: &jsonContainsPathFunctionClass{baseFunctionClass{ast.JSONContainsPath, 3, -1}},
Expand Down
74 changes: 74 additions & 0 deletions expression/builtin_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ var (
_ functionClass = &jsonMergeFunctionClass{}
_ functionClass = &jsonObjectFunctionClass{}
_ functionClass = &jsonArrayFunctionClass{}
_ functionClass = &jsonMemberOfFunctionClass{}
_ functionClass = &jsonContainsFunctionClass{}
_ functionClass = &jsonOverlapsFunctionClass{}
_ functionClass = &jsonContainsPathFunctionClass{}
Expand Down Expand Up @@ -72,6 +73,7 @@ var (
_ builtinFunc = &builtinJSONReplaceSig{}
_ builtinFunc = &builtinJSONRemoveSig{}
_ builtinFunc = &builtinJSONMergeSig{}
_ builtinFunc = &builtinJSONMemberOfSig{}
_ builtinFunc = &builtinJSONContainsSig{}
_ builtinFunc = &builtinJSONOverlapsSig{}
_ builtinFunc = &builtinJSONStorageSizeSig{}
Expand Down Expand Up @@ -742,6 +744,78 @@ func jsonModify(ctx sessionctx.Context, args []Expression, row chunk.Row, mt typ
return res, false, nil
}

type jsonMemberOfFunctionClass struct {
baseFunctionClass
}

type builtinJSONMemberOfSig struct {
baseBuiltinFunc
}

func (b *builtinJSONMemberOfSig) Clone() builtinFunc {
newSig := &builtinJSONMemberOfSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (c *jsonMemberOfFunctionClass) verifyArgs(args []Expression) error {
if err := c.baseFunctionClass.verifyArgs(args); err != nil {
return err
}
if evalType := args[1].GetType().EvalType(); evalType != types.ETJson && evalType != types.ETString {
return types.ErrInvalidJSONData.GenWithStackByArgs(2, "member of")
}
return nil
}

func (c *jsonMemberOfFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, err
}

argTps := []types.EvalType{args[0].GetType().EvalType(), types.ETJson}
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...)
if err != nil {
return nil, err
}
sig := &builtinJSONMemberOfSig{bf}
return sig, nil
}

func (b *builtinJSONMemberOfSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) {
var target types.BinaryJSON
if b.args[0].GetType().EvalType() != types.ETJson {
eval, err := b.args[0].Eval(row)
if err != nil || eval.IsNull() {
return 0, eval.IsNull(), err
}
target = types.CreateBinaryJSON(eval.GetValue())
} else {
target, isNull, err = b.args[0].EvalJSON(b.ctx, row)
}

if isNull || err != nil {
return res, isNull, err
}
obj, isNull, err := b.args[1].EvalJSON(b.ctx, row)
if isNull || err != nil {
return res, isNull, err
}

if obj.TypeCode != types.JSONTypeCodeArray {
return boolToInt64(types.CompareBinaryJSON(obj, target) == 0), false, nil
}

elemCount := obj.GetElemCount()
for i := 0; i < elemCount; i++ {
if types.CompareBinaryJSON(obj.ArrayGetElem(i), target) == 0 {
return 1, false, nil
}
}

return 0, false, nil
}

type jsonContainsFunctionClass struct {
baseFunctionClass
}
Expand Down
42 changes: 42 additions & 0 deletions expression/builtin_json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,48 @@ func TestJSONRemove(t *testing.T) {
}
}

func TestJSONMemberOf(t *testing.T) {
ctx := createContext(t)
fc := funcs[ast.JSONMemberOf]
tbl := []struct {
input []interface{}
expected interface{}
err error
}{
{[]interface{}{`1`, `a:1`}, 1, types.ErrInvalidJSONText},

{[]interface{}{1, `[1, 2]`}, 1, nil},
{[]interface{}{1, `[1]`}, 1, nil},
{[]interface{}{1, `[0]`}, 0, nil},
{[]interface{}{1, `[1]`}, 1, nil},
{[]interface{}{1, `[[1]]`}, 0, nil},
{[]interface{}{"1", `[1]`}, 0, nil},
{[]interface{}{"1", `["1"]`}, 1, nil},
{[]interface{}{`{"a":1}`, `{"a":1}`}, 0, nil},
{[]interface{}{`{"a":1}`, `[{"a":1}]`}, 0, nil},
{[]interface{}{`{"a":1}`, `[{"a":1}, 1]`}, 0, nil},
{[]interface{}{`{"a":1}`, `["{\"a\":1}"]`}, 1, nil},
{[]interface{}{`{"a":1}`, `["{\"a\":1}", 1]`}, 1, nil},
}
for _, tt := range tbl {
args := types.MakeDatums(tt.input...)
f, err := fc.getFunction(ctx, datumsToConstants(args))
require.NoError(t, err, tt.input)
d, err := evalBuiltinFunc(f, chunk.Row{})
if tt.err == nil {
require.NoError(t, err, tt.input)
if tt.expected == nil {
require.True(t, d.IsNull(), tt.input)
} else {
require.Equal(t, int64(tt.expected.(int)), d.GetInt64(), tt.input)
}
} else {
require.True(t, tt.err.(*terror.Error).Equal(err), tt.input)
}
}
}


func TestJSONContains(t *testing.T) {
ctx := createContext(t)
fc := funcs[ast.JSONContains]
Expand Down
7 changes: 7 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2734,6 +2734,13 @@ func TestFuncJSON(t *testing.T) {
tk.MustExec("insert into tx1 values (1, 0.1, 0.2, 0.3, 0.0)")
tk.MustQuery("select a+b, c from tx1").Check(testkit.Rows("0.30000000000000004 0.3"))
tk.MustQuery("select json_array(a+b) = json_array(c) from tx1").Check(testkit.Rows("0"))

tk.MustQuery("SELECT '{\"a\":1}' MEMBER OF('{\"a\":1}');").Check(testkit.Rows("0"))
tk.MustQuery("SELECT '{\"a\":1}' MEMBER OF('[{\"a\":1}]');").Check(testkit.Rows("0"))
tk.MustQuery("SELECT 1 MEMBER OF('1');").Check(testkit.Rows("1"))
tk.MustQuery("SELECT '{\"a\":1}' MEMBER OF('{\"a\":1}');").Check(testkit.Rows("0"))
tk.MustQuery("SELECT '[4,5]' MEMBER OF('[[3,4],[4,5]]');").Check(testkit.Rows("0"))
tk.MustQuery("SELECT '[4,5]' MEMBER OF('[[3,4],\"[4,5]\"]');").Check(testkit.Rows("1"))
}

func TestColumnInfoModified(t *testing.T) {
Expand Down
7 changes: 4 additions & 3 deletions types/json_binary.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ func (bj BinaryJSON) GetElemCount() int {
return int(jsonEndian.Uint32(bj.Value))
}

func (bj BinaryJSON) arrayGetElem(idx int) BinaryJSON {
// ArrayGetElem gets the element of the index `idx`.
func (bj BinaryJSON) ArrayGetElem(idx int) BinaryJSON {
return bj.valEntryGet(headerSize + idx*valEntrySize)
}

Expand Down Expand Up @@ -355,7 +356,7 @@ func (bj BinaryJSON) marshalArrayTo(buf []byte) ([]byte, error) {
buf = append(buf, ", "...)
}
var err error
buf, err = bj.arrayGetElem(i).marshalTo(buf)
buf, err = bj.ArrayGetElem(i).marshalTo(buf)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -557,7 +558,7 @@ func (bj BinaryJSON) HashValue(buf []byte) []byte {
elemCount := int(jsonEndian.Uint32(bj.Value))
buf = append(buf, bj.Value[0:dataSizeOff]...)
for i := 0; i < elemCount; i++ {
buf = bj.arrayGetElem(i).HashValue(buf)
buf = bj.ArrayGetElem(i).HashValue(buf)
}
case JSONTypeCodeObject:
// this hash value is bidirectional, because you can get the key using the json
Expand Down
40 changes: 20 additions & 20 deletions types/json_binary_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ func (bj BinaryJSON) extractTo(buf []BinaryJSON, pathExpr JSONPathExpression, du
start, end := currentLeg.arraySelection.getIndexRange(bj)
if start >= 0 && start <= end {
for i := start; i <= end; i++ {
buf = bj.arrayGetElem(i).extractTo(buf, subPathExpr, dup, one)
buf = bj.ArrayGetElem(i).extractTo(buf, subPathExpr, dup, one)
}
}
} else if currentLeg.typ == jsonPathLegKey && bj.TypeCode == JSONTypeCodeObject {
Expand All @@ -314,7 +314,7 @@ func (bj BinaryJSON) extractTo(buf []BinaryJSON, pathExpr JSONPathExpression, du
if bj.TypeCode == JSONTypeCodeArray {
elemCount := bj.GetElemCount()
for i := 0; i < elemCount && !jsonFinished(buf, one); i++ {
buf = bj.arrayGetElem(i).extractTo(buf, pathExpr, dup, one)
buf = bj.ArrayGetElem(i).extractTo(buf, pathExpr, dup, one)
}
} else if bj.TypeCode == JSONTypeCodeObject {
elemCount := bj.GetElemCount()
Expand Down Expand Up @@ -459,12 +459,12 @@ func (bj BinaryJSON) ArrayInsert(pathExpr JSONPathExpression, value BinaryJSON)
// Insert into the array
newArray := make([]BinaryJSON, 0, count+1)
for i := 0; i < idx; i++ {
elem := obj.arrayGetElem(i)
elem := obj.ArrayGetElem(i)
newArray = append(newArray, elem)
}
newArray = append(newArray, value)
for i := idx; i < count; i++ {
elem := obj.arrayGetElem(i)
elem := obj.ArrayGetElem(i)
newArray = append(newArray, elem)
}
obj = buildBinaryJSONArray(newArray)
Expand Down Expand Up @@ -556,7 +556,7 @@ func (bm *binaryModifier) doInsert(path JSONPathExpression, newBj BinaryJSON) {
elemCount := parentBj.GetElemCount()
elems := make([]BinaryJSON, 0, elemCount+1)
for i := 0; i < elemCount; i++ {
elems = append(elems, parentBj.arrayGetElem(i))
elems = append(elems, parentBj.ArrayGetElem(i))
}
elems = append(elems, newBj)
bm.modifyValue = buildBinaryJSONArray(elems)
Expand Down Expand Up @@ -622,7 +622,7 @@ func (bm *binaryModifier) doRemove(path JSONPathExpression) {
elems := make([]BinaryJSON, 0, elemCount-1)
for i := 0; i < elemCount; i++ {
if i != idx {
elems = append(elems, parentBj.arrayGetElem(i))
elems = append(elems, parentBj.ArrayGetElem(i))
}
}
bm.modifyValue = buildBinaryJSONArray(elems)
Expand Down Expand Up @@ -809,8 +809,8 @@ func CompareBinaryJSON(left, right BinaryJSON) int {
leftCount := left.GetElemCount()
rightCount := right.GetElemCount()
for i := 0; i < leftCount && i < rightCount; i++ {
elem1 := left.arrayGetElem(i)
elem2 := right.arrayGetElem(i)
elem1 := left.ArrayGetElem(i)
elem2 := right.ArrayGetElem(i)
cmp = CompareBinaryJSON(elem1, elem2)
if cmp != 0 {
return cmp
Expand Down Expand Up @@ -993,7 +993,7 @@ func mergeBinaryArray(elems []BinaryJSON) BinaryJSON {
} else {
childCount := elem.GetElemCount()
for j := 0; j < childCount; j++ {
buf = append(buf, elem.arrayGetElem(j))
buf = append(buf, elem.ArrayGetElem(j))
}
}
}
Expand Down Expand Up @@ -1088,15 +1088,15 @@ func ContainsBinaryJSON(obj, target BinaryJSON) bool {
if target.TypeCode == JSONTypeCodeArray {
elemCount := target.GetElemCount()
for i := 0; i < elemCount; i++ {
if !ContainsBinaryJSON(obj, target.arrayGetElem(i)) {
if !ContainsBinaryJSON(obj, target.ArrayGetElem(i)) {
return false
}
}
return true
}
elemCount := obj.GetElemCount()
for i := 0; i < elemCount; i++ {
if ContainsBinaryJSON(obj.arrayGetElem(i), target) {
if ContainsBinaryJSON(obj.ArrayGetElem(i), target) {
return true
}
}
Expand Down Expand Up @@ -1175,7 +1175,7 @@ func (bj BinaryJSON) GetElemDepth() int {
elemCount := bj.GetElemCount()
maxDepth := 0
for i := 0; i < elemCount; i++ {
obj := bj.arrayGetElem(i)
obj := bj.ArrayGetElem(i)
depth := obj.GetElemDepth()
if depth > maxDepth {
maxDepth = depth
Expand Down Expand Up @@ -1246,19 +1246,19 @@ func (bj BinaryJSON) extractToCallback(pathExpr JSONPathExpression, callbackFn e
switch selection := currentLeg.arraySelection.(type) {
case jsonPathArraySelectionAsterisk:
for i := 0; i < elemCount; i++ {
// buf = bj.arrayGetElem(i).extractTo(buf, subPathExpr)
// buf = bj.ArrayGetElem(i).extractTo(buf, subPathExpr)
path := fullpath.pushBackOneArraySelectionLeg(jsonPathArraySelectionIndex{jsonPathArrayIndexFromStart(i)})
stop, err = bj.arrayGetElem(i).extractToCallback(subPathExpr, callbackFn, path)
stop, err = bj.ArrayGetElem(i).extractToCallback(subPathExpr, callbackFn, path)
if stop || err != nil {
return
}
}
case jsonPathArraySelectionIndex:
idx := selection.index.getIndexFromStart(bj)
if idx < elemCount && idx >= 0 {
// buf = bj.arrayGetElem(currentLeg.arraySelection).extractTo(buf, subPathExpr)
// buf = bj.ArrayGetElem(currentLeg.arraySelection).extractTo(buf, subPathExpr)
path := fullpath.pushBackOneArraySelectionLeg(currentLeg.arraySelection)
stop, err = bj.arrayGetElem(idx).extractToCallback(subPathExpr, callbackFn, path)
stop, err = bj.ArrayGetElem(idx).extractToCallback(subPathExpr, callbackFn, path)
if stop || err != nil {
return
}
Expand All @@ -1272,7 +1272,7 @@ func (bj BinaryJSON) extractToCallback(pathExpr JSONPathExpression, callbackFn e
if start <= end && start >= 0 {
for i := start; i <= end; i++ {
path := fullpath.pushBackOneArraySelectionLeg(jsonPathArraySelectionIndex{jsonPathArrayIndexFromStart(i)})
stop, err = bj.arrayGetElem(i).extractToCallback(subPathExpr, callbackFn, path)
stop, err = bj.ArrayGetElem(i).extractToCallback(subPathExpr, callbackFn, path)
if stop || err != nil {
return
}
Expand Down Expand Up @@ -1311,9 +1311,9 @@ func (bj BinaryJSON) extractToCallback(pathExpr JSONPathExpression, callbackFn e
if bj.TypeCode == JSONTypeCodeArray {
elemCount := bj.GetElemCount()
for i := 0; i < elemCount; i++ {
// buf = bj.arrayGetElem(i).extractTo(buf, pathExpr)
// buf = bj.ArrayGetElem(i).extractTo(buf, pathExpr)
path := fullpath.pushBackOneArraySelectionLeg(jsonPathArraySelectionIndex{jsonPathArrayIndexFromStart(i)})
stop, err = bj.arrayGetElem(i).extractToCallback(pathExpr, callbackFn, path)
stop, err = bj.ArrayGetElem(i).extractToCallback(pathExpr, callbackFn, path)
if stop || err != nil {
return
}
Expand Down Expand Up @@ -1357,7 +1357,7 @@ func (bj BinaryJSON) Walk(walkFn BinaryJSONWalkFunc, pathExprList ...JSONPathExp
elemCount := bj.GetElemCount()
for i := 0; i < elemCount; i++ {
path := fullpath.pushBackOneArraySelectionLeg(jsonPathArraySelectionIndex{jsonPathArrayIndexFromStart(i)})
stop, err = doWalk(path, bj.arrayGetElem(i))
stop, err = doWalk(path, bj.ArrayGetElem(i))
if stop || err != nil {
return
}
Expand Down

0 comments on commit a91d22f

Please sign in to comment.