From 77df14174ab3480b5e51846909264e203f843de2 Mon Sep 17 00:00:00 2001 From: LiuBo Date: Mon, 29 Jul 2024 23:29:59 +0800 Subject: [PATCH 1/2] [bug] stats: fix nil pointer (#17767) fix nil pointer in Merge function of stats info Approved by: @zhangxu19830126 --- pkg/pb/statsinfo/statsinfo.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/pb/statsinfo/statsinfo.go b/pkg/pb/statsinfo/statsinfo.go index 5013df42dd188..1967584ed4517 100644 --- a/pkg/pb/statsinfo/statsinfo.go +++ b/pkg/pb/statsinfo/statsinfo.go @@ -34,7 +34,7 @@ func (sc *StatsInfo) NeedUpdate(currentApproxObjNum int64) bool { } func (sc *StatsInfo) Merge(newInfo *StatsInfo) { - if sc == nil { + if sc == nil || newInfo == nil { return } // TODO: do not handle ShuffleRange for now. From 5a98c53750d720c61b4c902003c515af25f21b92 Mon Sep 17 00:00:00 2001 From: Charles Chi Le Date: Tue, 30 Jul 2024 01:02:56 +0800 Subject: [PATCH 2/2] Implement MySQL Encode()/Decode() function by AES (#17568) Implement Encode() and Decode() functions by AES Approved by: @fengttt, @heni02, @m-schen --- pkg/sql/plan/function/func_unary.go | 85 +++++++++ pkg/sql/plan/function/func_unary_test.go | 161 ++++++++++++++++++ pkg/sql/plan/function/list_builtIn.go | 42 +++++ .../cases/function/func_decode_encode.result | 39 +++++ .../cases/function/func_decode_encode.sql | 13 ++ 5 files changed, 340 insertions(+) create mode 100644 test/distributed/cases/function/func_decode_encode.result create mode 100644 test/distributed/cases/function/func_decode_encode.sql diff --git a/pkg/sql/plan/function/func_unary.go b/pkg/sql/plan/function/func_unary.go index 13188c5a7aa73..47114d29e36bc 100644 --- a/pkg/sql/plan/function/func_unary.go +++ b/pkg/sql/plan/function/func_unary.go @@ -16,8 +16,11 @@ package function import ( "context" + "crypto/aes" + "crypto/cipher" "crypto/md5" "crypto/sha1" + "crypto/sha256" "encoding/base64" "encoding/hex" "fmt" @@ -1118,6 +1121,88 @@ func octFloat[T constraints.Float](xs T) 
(types.Decimal128, error) {
 	return res, nil
 }
 
+func generateSHAKey(key []byte) []byte {
+	// SHA-256 always yields 32 bytes, so a key of any length maps to a valid AES-256 key.
+	hash := sha256.Sum256(key)
+	return hash[:]
+}
+
+func generateInitializationVector(key []byte, length int) []byte {
+	data := append(key[:len(key):len(key)], byte(length)) // cap-limited so append copies instead of writing into the caller's backing array
+	hash := sha256.Sum256(data)
+	return hash[:aes.BlockSize]
+}
+
+// encodeByAES encrypts plaintext with AES-CTR (key = SHA-256 of key, IV derived from key and len(plaintext)) and appends the ciphertext, which has the same length as plaintext, to rs; when null is set it appends NULL instead.
+// https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_encode
+func encodeByAES(plaintext []byte, key []byte, null bool, rs *vector.FunctionResult[types.Varlena]) error {
+	if null {
+		return rs.AppendMustNullForBytesResult()
+	}
+	fixedKey := generateSHAKey(key)
+	block, err := aes.NewCipher(fixedKey)
+	if err != nil {
+		return err
+	}
+	initializationVector := generateInitializationVector(key, len(plaintext))
+	ciphertext := make([]byte, len(plaintext))
+	stream := cipher.NewCTR(block, initializationVector)
+	stream.XORKeyStream(ciphertext, plaintext)
+	return rs.AppendMustBytesValue(ciphertext)
+}
+
+func Encode(parameters []*vector.Vector, result vector.FunctionResultWrapper, proc *process.Process, length int, selectList *FunctionSelectList) error {
+	source := vector.GenerateFunctionStrParameter(parameters[0])
+	key := vector.GenerateFunctionStrParameter(parameters[1])
+	rs := vector.MustFunctionResult[types.Varlena](result)
+
+	rowCount := uint64(length)
+	for i := uint64(0); i < rowCount; i++ {
+		data, nullData := source.GetStrValue(i)
+		keyData, nullKey := key.GetStrValue(i)
+		if err := encodeByAES(data, keyData, nullData || nullKey, rs); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// decodeByAES reverses encodeByAES: it derives the same key and IV from key and len(ciphertext), XORs the CTR keystream over ciphertext and appends the plaintext to rs; when null is set it appends NULL instead.
+// https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_decode
+func decodeByAES(ciphertext []byte, key []byte, null bool, rs *vector.FunctionResult[types.Varlena]) error {
+	if null {
+		return rs.AppendMustNullForBytesResult()
+	}
+	fixedKey := generateSHAKey(key)
+	block, err := aes.NewCipher(fixedKey)
+	if err != nil {
+		return err
+	}
+	iv := generateInitializationVector(key, len(ciphertext))
+	plaintext := make([]byte, len(ciphertext))
+	stream := cipher.NewCTR(block, iv)
+	stream.XORKeyStream(plaintext, ciphertext)
+	return rs.AppendMustBytesValue(plaintext)
+}
+
+func Decode(parameters []*vector.Vector, result vector.FunctionResultWrapper, proc *process.Process, length int, selectList *FunctionSelectList) error {
+	source := vector.GenerateFunctionStrParameter(parameters[0])
+	key := vector.GenerateFunctionStrParameter(parameters[1])
+	rs := vector.MustFunctionResult[types.Varlena](result)
+
+	rowCount := uint64(length)
+	for i := uint64(0); i < rowCount; i++ {
+		data, nullData := source.GetStrValue(i)
+		keyData, nullKey := key.GetStrValue(i)
+		if err := decodeByAES(data, keyData, nullData || nullKey, rs); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
 func DateToMonth(ivecs []*vector.Vector, result vector.FunctionResultWrapper, proc *process.Process, length int, selectList *FunctionSelectList) error {
 	return opUnaryFixedToFixed[types.Date, uint8](ivecs, result, proc, length, func(v types.Date) uint8 {
 		return v.Month()
diff --git a/pkg/sql/plan/function/func_unary_test.go b/pkg/sql/plan/function/func_unary_test.go
index a25f5cc32380d..7156bdf2df982 100644
--- a/pkg/sql/plan/function/func_unary_test.go
+++ b/pkg/sql/plan/function/func_unary_test.go
@@ -16,6 +16,7 @@ package function
 
 import (
 	"context"
+	"encoding/hex"
 	"fmt"
 	"math"
 	"testing"
@@ -2876,6 +2877,166 @@ func TestOctInt64(t *testing.T) {
 	//TODO: Previous OctFloat didn't have testcase. Should we add new testcases?
}
 
+func TestDecode(t *testing.T) {
+	testCases := initDecodeTestCase()
+
+	proc := testutil.NewProcess()
+	for _, tc := range testCases {
+		fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, Decode)
+		s, info := fcTC.Run()
+		require.True(t, s, fmt.Sprintf("case is '%s', err info is '%s'", tc.info, info))
+	}
+}
+
+func initDecodeTestCase() []tcTemp {
+	regularCases := []struct {
+		info  string
+		data  []string
+		keys  []string
+		wants []string
+	}{
+		{
+			info: "test decode - simple text",
+			data: []string{
+				"",
+				"MatrixOne",
+				"MatrixOne",
+				"MatrixOne123",
+				"MatrixOne#%$%^",
+				"MatrixOne",
+				"分布式データベース",
+				"MatrixOne",
+				"MatrixOne数据库",
+			},
+			keys: []string{
+				"",
+				"1234567890123456",
+				"asdfjasfwefjfjkj",
+				"123456789012345678901234",
+				"*^%YTu1234567",
+				"",
+				"pass1234@#$%%^^&",
+				"密匙",
+				"数据库passwd12345667",
+			},
+			wants: []string{
+				"",
+				"973F9E44B6330489C7",
+				"BDE957D76C42800E16",
+				"928248DD2211D7DB886AD0FE",
+				"A5A0BE100EB06512E4422A51DC9C",
+				"549D65E48BD9A29CE9",
+				"D1D6913ED82E228022A08CD2DCB8869118819FECFE2008176625BB",
+				"6B406CBF644FCB9BCA",
+				"34B8B67B8C4EDF31009142BC6346E3C32B0C",
+			},
+		},
+	}
+
+	var testInputs = make([]tcTemp, 0, len(regularCases))
+	for _, c := range regularCases {
+		realWants := make([]string, len(c.wants))
+		for i, want := range c.wants {
+			bytes, err := hex.DecodeString(want)
+			if err != nil {
+				panic(fmt.Sprintf("invalid hex fixture %q: %v", want, err)) // a malformed fixture must fail loudly, not be compared as garbage
+			}
+
+			realWants[i] = string(bytes)
+		}
+		testInputs = append(testInputs, tcTemp{
+			info: c.info,
+			inputs: []FunctionTestInput{
+				NewFunctionTestInput(types.T_blob.ToType(), c.data, []bool{}),
+				NewFunctionTestInput(types.T_varchar.ToType(), c.keys, []bool{}),
+			},
+			expect: NewFunctionTestResult(types.T_blob.ToType(), false, realWants, []bool{}),
+		})
+	}
+
+	return testInputs
+}
+
+func TestEncode(t *testing.T) {
+	testCases := initEncodeTestCase()
+
+	proc := testutil.NewProcess()
+	for _, tc := range testCases {
+		fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, Encode)
+		s, info := fcTC.Run()
+		require.True(t, s, fmt.Sprintf("case is '%s', err info is '%s'", tc.info, info))
+	}
+}
+
+func initEncodeTestCase() []tcTemp {
+	regularCases := []struct {
+		info  string
+		data  []string
+		keys  []string
+		wants []string
+	}{
+		{
+			info: "test encode - simple text",
+			data: []string{
+				"",
+				"MatrixOne",
+				"MatrixOne",
+				"MatrixOne123",
+				"MatrixOne#%$%^",
+				"MatrixOne",
+				"分布式データベース",
+				"MatrixOne",
+				"MatrixOne数据库",
+			},
+			keys: []string{
+				"",
+				"1234567890123456",
+				"asdfjasfwefjfjkj",
+				"123456789012345678901234",
+				"*^%YTu1234567",
+				"",
+				"pass1234@#$%%^^&",
+				"密匙",
+				"数据库passwd12345667",
+			},
+			wants: []string{
+				"",
+				"973F9E44B6330489C7",
+				"BDE957D76C42800E16",
+				"928248DD2211D7DB886AD0FE",
+				"A5A0BE100EB06512E4422A51DC9C",
+				"549D65E48BD9A29CE9",
+				"D1D6913ED82E228022A08CD2DCB8869118819FECFE2008176625BB",
+				"6B406CBF644FCB9BCA",
+				"34B8B67B8C4EDF31009142BC6346E3C32B0C",
+			},
+		},
+	}
+
+	var testInputs = make([]tcTemp, 0, len(regularCases))
+	for _, c := range regularCases {
+		realWants := make([]string, len(c.wants))
+		for i, want := range c.wants {
+			bytes, err := hex.DecodeString(want)
+			if err != nil {
+				panic(fmt.Sprintf("invalid hex fixture %q: %v", want, err)) // a malformed fixture must fail loudly, not be compared as garbage
+			}
+
+			realWants[i] = string(bytes)
+		}
+		testInputs = append(testInputs, tcTemp{
+			info: c.info,
+			inputs: []FunctionTestInput{
+				NewFunctionTestInput(types.T_varchar.ToType(), c.data, []bool{}),
+				NewFunctionTestInput(types.T_varchar.ToType(), c.keys, []bool{}),
+			},
+			expect: NewFunctionTestResult(types.T_blob.ToType(), false, realWants, []bool{}),
+		})
+	}
+
+	return testInputs
+}
+
 // Month
 
 func initDateToMonthTestCase() []tcTemp {
diff --git a/pkg/sql/plan/function/list_builtIn.go b/pkg/sql/plan/function/list_builtIn.go
index b8aad62b66451..86ddf61b01015 100644
--- a/pkg/sql/plan/function/list_builtIn.go
+++ b/pkg/sql/plan/function/list_builtIn.go
@@ -1723,6 +1723,48 @@ var supportedStringBuiltIns = []FuncNew{
 		},
 	},
 
+	// function `encode`
+	{
+		functionId: ENCODE,
+		class:
plan.Function_STRICT, + layout: STANDARD_FUNCTION, + checkFn: fixedTypeMatch, + + Overloads: []overload{ + { + overloadId: 0, + args: []types.T{types.T_varchar, types.T_varchar}, + retType: func(parameters []types.Type) types.Type { + return types.T_blob.ToType() + }, + newOp: func() executeLogicOfOverload { + return Encode + }, + }, + }, + }, + + // function `decode` + { + functionId: DECODE, + class: plan.Function_STRICT, + layout: STANDARD_FUNCTION, + checkFn: fixedTypeMatch, + + Overloads: []overload{ + { + overloadId: 0, + args: []types.T{types.T_blob, types.T_varchar}, + retType: func(parameters []types.Type) types.Type { + return types.T_varchar.ToType() + }, + newOp: func() executeLogicOfOverload { + return Decode + }, + }, + }, + }, + // function `trim` { functionId: TRIM, diff --git a/test/distributed/cases/function/func_decode_encode.result b/test/distributed/cases/function/func_decode_encode.result new file mode 100644 index 0000000000000..fa71f8b89b0a8 --- /dev/null +++ b/test/distributed/cases/function/func_decode_encode.result @@ -0,0 +1,39 @@ +SELECT DECODE(ENCODE('Hello, World!', 'mysecretkey'), 'mysecretkey'); +DECODE(ENCODE(Hello, World!, mysecretkey), mysecretkey) +Hello, World! 
+SELECT DECODE(ENCODE('', ''), ''); +DECODE(ENCODE(, ), ) + +SELECT DECODE(ENCODE('MatrixOne', '1234567890123456'), '1234567890123456'); +DECODE(ENCODE(MatrixOne, 1234567890123456), 1234567890123456) +MatrixOne +SELECT DECODE(ENCODE('MatrixOne', 'asdfjasfwefjfjkj'), 'asdfjasfwefjfjkj'); +DECODE(ENCODE(MatrixOne, asdfjasfwefjfjkj), asdfjasfwefjfjkj) +MatrixOne +SELECT DECODE(ENCODE('MatrixOne123', '123456789012345678901234'), '123456789012345678901234'); +DECODE(ENCODE(MatrixOne123, 123456789012345678901234), 123456789012345678901234) +MatrixOne123 +SELECT DECODE(ENCODE('MatrixOne#%$%^', '*^%YTu1234567'), '*^%YTu1234567'); +DECODE(ENCODE(MatrixOne#%$%^, *^%YTu1234567), *^%YTu1234567) +MatrixOne#%$%^ +SELECT DECODE(ENCODE('MatrixOne', ''), ''); +DECODE(ENCODE(MatrixOne, ), ) +MatrixOne +SELECT DECODE(ENCODE('分布式データベース', 'pass1234@#$%%^^&'), 'pass1234@#$%%^^&'); +DECODE(ENCODE(分布式データベース, pass1234@#$%%^^&), pass1234@#$%%^^&) +分布式データベース +SELECT DECODE(ENCODE('分布式データベース', '分布式7782734adgwy1242'), '分布式7782734adgwy1242'); +DECODE(ENCODE(分布式データベース, 分布式7782734adgwy1242), 分布式7782734adgwy1242) +分布式データベース +SELECT DECODE(ENCODE('MatrixOne', '密匙'), '密匙'); +DECODE(ENCODE(MatrixOne, 密匙), 密匙) +MatrixOne +SELECT DECODE(ENCODE('MatrixOne数据库', '数据库passwd12345667'), '数据库passwd12345667'); +DECODE(ENCODE(MatrixOne数据库, 数据库passwd12345667), 数据库passwd12345667) +MatrixOne数据库 +SELECT HEX(ENCODE('MatrixOne数据库', '数据库passwd12345667')); +HEX(ENCODE(MatrixOne数据库, 数据库passwd12345667)) +34B8B67B8C4EDF31009142BC6346E3C32B0C +SELECT HEX(ENCODE('mytext','mykeystring')); +HEX(ENCODE(mytext, mykeystring)) +562C778A6367 \ No newline at end of file diff --git a/test/distributed/cases/function/func_decode_encode.sql b/test/distributed/cases/function/func_decode_encode.sql new file mode 100644 index 0000000000000..0966cf3dc4b42 --- /dev/null +++ b/test/distributed/cases/function/func_decode_encode.sql @@ -0,0 +1,13 @@ +SELECT DECODE(ENCODE('Hello, World!', 'mysecretkey'), 'mysecretkey'); +SELECT 
DECODE(ENCODE('', ''), ''); +SELECT DECODE(ENCODE('MatrixOne', '1234567890123456'), '1234567890123456'); +SELECT DECODE(ENCODE('MatrixOne', 'asdfjasfwefjfjkj'), 'asdfjasfwefjfjkj'); +SELECT DECODE(ENCODE('MatrixOne123', '123456789012345678901234'), '123456789012345678901234'); +SELECT DECODE(ENCODE('MatrixOne#%$%^', '*^%YTu1234567'), '*^%YTu1234567'); +SELECT DECODE(ENCODE('MatrixOne', ''), ''); +SELECT DECODE(ENCODE('分布式データベース', 'pass1234@#$%%^^&'), 'pass1234@#$%%^^&'); +SELECT DECODE(ENCODE('分布式データベース', '分布式7782734adgwy1242'), '分布式7782734adgwy1242'); +SELECT DECODE(ENCODE('MatrixOne', '密匙'), '密匙'); +SELECT DECODE(ENCODE('MatrixOne数据库', '数据库passwd12345667'), '数据库passwd12345667'); +SELECT HEX(ENCODE('MatrixOne数据库', '数据库passwd12345667')); +SELECT HEX(ENCODE('mytext','mykeystring')); \ No newline at end of file