Skip to content

Commit

Permalink
Merge branch 'main' into revert
Browse files Browse the repository at this point in the history
  • Loading branch information
badboynt1 authored Jul 30, 2024
2 parents 6a1c28f + 5a98c53 commit 2ae7c14
Show file tree
Hide file tree
Showing 6 changed files with 341 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pkg/pb/statsinfo/statsinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (sc *StatsInfo) NeedUpdate(currentApproxObjNum int64) bool {
}

func (sc *StatsInfo) Merge(newInfo *StatsInfo) {
if sc == nil {
if sc == nil || newInfo == nil {
return
}
// TODO: do not handle ShuffleRange for now.
Expand Down
85 changes: 85 additions & 0 deletions pkg/sql/plan/function/func_unary.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ package function

import (
"context"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
Expand Down Expand Up @@ -1118,6 +1121,88 @@ func octFloat[T constraints.Float](xs T) (types.Decimal128, error) {
return res, nil
}

func generateSHAKey(key []byte) []byte {
// return 32 bytes SHA256 checksum of the key
hash := sha256.Sum256(key)
return hash[:]
}

func generateInitializationVector(key []byte, length int) []byte {
data := append(key, byte(length))
hash := sha256.Sum256(data)
return hash[:aes.BlockSize]
}

// encode function encrypts a string, returns a binary string of the same length of the original string.
// https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_encode
func encodeByAES(plaintext []byte, key []byte, null bool, rs *vector.FunctionResult[types.Varlena]) error {
if null {
return rs.AppendMustNullForBytesResult()
}
fixedKey := generateSHAKey(key)
block, err := aes.NewCipher(fixedKey)
if err != nil {
return err
}
initializationVector := generateInitializationVector(key, len(plaintext))
ciphertext := make([]byte, len(plaintext))
stream := cipher.NewCTR(block, initializationVector)
stream.XORKeyStream(ciphertext, plaintext)
return rs.AppendMustBytesValue(ciphertext)
}

func Encode(parameters []*vector.Vector, result vector.FunctionResultWrapper, proc *process.Process, length int, selectList *FunctionSelectList) error {
source := vector.GenerateFunctionStrParameter(parameters[0])
key := vector.GenerateFunctionStrParameter(parameters[1])
rs := vector.MustFunctionResult[types.Varlena](result)

rowCount := uint64(length)
for i := uint64(0); i < rowCount; i++ {
data, nullData := source.GetStrValue(i)
keyData, nullKey := key.GetStrValue(i)
if err := encodeByAES(data, keyData, nullData || nullKey, rs); err != nil {
return err
}
}

return nil
}

// decode function decodes an encoded string and returns the original string
// https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_decode
func decodeByAES(ciphertext []byte, key []byte, null bool, rs *vector.FunctionResult[types.Varlena]) error {
if null {
return rs.AppendMustNullForBytesResult()
}
fixedKey := generateSHAKey(key)
block, err := aes.NewCipher(fixedKey)
if err != nil {
return err
}
iv := generateInitializationVector(key, len(ciphertext))
plaintext := make([]byte, len(ciphertext))
stream := cipher.NewCTR(block, iv)
stream.XORKeyStream(plaintext, ciphertext)
return rs.AppendMustBytesValue(plaintext)
}

func Decode(parameters []*vector.Vector, result vector.FunctionResultWrapper, proc *process.Process, length int, selectList *FunctionSelectList) error {
source := vector.GenerateFunctionStrParameter(parameters[0])
key := vector.GenerateFunctionStrParameter(parameters[1])
rs := vector.MustFunctionResult[types.Varlena](result)

rowCount := uint64(length)
for i := uint64(0); i < rowCount; i++ {
data, nullData := source.GetStrValue(i)
keyData, nullKey := key.GetStrValue(i)
if err := decodeByAES(data, keyData, nullData || nullKey, rs); err != nil {
return err
}
}

return nil
}

func DateToMonth(ivecs []*vector.Vector, result vector.FunctionResultWrapper, proc *process.Process, length int, selectList *FunctionSelectList) error {
return opUnaryFixedToFixed[types.Date, uint8](ivecs, result, proc, length, func(v types.Date) uint8 {
return v.Month()
Expand Down
161 changes: 161 additions & 0 deletions pkg/sql/plan/function/func_unary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package function

import (
"context"
"encoding/hex"
"fmt"
"math"
"testing"
Expand Down Expand Up @@ -2876,6 +2877,166 @@ func TestOctInt64(t *testing.T) {
//TODO: Previous OctFloat didn't have testcase. Should we add new testcases?
}

func TestDecode(t *testing.T) {
testCases := initDecodeTestCase()

proc := testutil.NewProcess()
for _, tc := range testCases {
fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, Decode)
s, info := fcTC.Run()
require.True(t, s, fmt.Sprintf("case is '%s', err info is '%s'", tc.info, info))
}
}

func initDecodeTestCase() []tcTemp {
regularCases := []struct {
info string
data []string
keys []string
wants []string
}{
{
info: "test decode - simple text",
data: []string{
"",
"MatrixOne",
"MatrixOne",
"MatrixOne123",
"MatrixOne#%$%^",
"MatrixOne",
"分布式データベース",
"MatrixOne",
"MatrixOne数据库",
},
keys: []string{
"",
"1234567890123456",
"asdfjasfwefjfjkj",
"123456789012345678901234",
"*^%YTu1234567",
"",
"pass1234@#$%%^^&",
"密匙",
"数据库passwd12345667",
},
wants: []string{
"",
"973F9E44B6330489C7",
"BDE957D76C42800E16",
"928248DD2211D7DB886AD0FE",
"A5A0BE100EB06512E4422A51DC9C",
"549D65E48BD9A29CE9",
"D1D6913ED82E228022A08CD2DCB8869118819FECFE2008176625BB",
"6B406CBF644FCB9BCA",
"34B8B67B8C4EDF31009142BC6346E3C32B0C",
},
},
}

var testInputs = make([]tcTemp, 0, len(regularCases))
for _, c := range regularCases {
realWants := make([]string, len(c.wants))
for i, want := range c.wants {
bytes, err := hex.DecodeString(want)
if err != nil {
fmt.Printf("decode string error: %v", err)
}

realWants[i] = string(bytes)
}
testInputs = append(testInputs, tcTemp{
info: c.info,
inputs: []FunctionTestInput{
NewFunctionTestInput(types.T_blob.ToType(), c.data, []bool{}),
NewFunctionTestInput(types.T_varchar.ToType(), c.keys, []bool{}),
},
expect: NewFunctionTestResult(types.T_blob.ToType(), false, realWants, []bool{}),
})
}

return testInputs
}

func TestEncode(t *testing.T) {
testCases := initEncodeTestCase()

proc := testutil.NewProcess()
for _, tc := range testCases {
fcTC := NewFunctionTestCase(proc, tc.inputs, tc.expect, Encode)
s, info := fcTC.Run()
require.True(t, s, fmt.Sprintf("case is '%s', err info is '%s'", tc.info, info))
}
}

func initEncodeTestCase() []tcTemp {
regularCases := []struct {
info string
data []string
keys []string
wants []string
}{
{
info: "test encode - simple text",
data: []string{
"",
"MatrixOne",
"MatrixOne",
"MatrixOne123",
"MatrixOne#%$%^",
"MatrixOne",
"分布式データベース",
"MatrixOne",
"MatrixOne数据库",
},
keys: []string{
"",
"1234567890123456",
"asdfjasfwefjfjkj",
"123456789012345678901234",
"*^%YTu1234567",
"",
"pass1234@#$%%^^&",
"密匙",
"数据库passwd12345667",
},
wants: []string{
"",
"973F9E44B6330489C7",
"BDE957D76C42800E16",
"928248DD2211D7DB886AD0FE",
"A5A0BE100EB06512E4422A51DC9C",
"549D65E48BD9A29CE9",
"D1D6913ED82E228022A08CD2DCB8869118819FECFE2008176625BB",
"6B406CBF644FCB9BCA",
"34B8B67B8C4EDF31009142BC6346E3C32B0C",
},
},
}

var testInputs = make([]tcTemp, 0, len(regularCases))
for _, c := range regularCases {
realWants := make([]string, len(c.wants))
for i, want := range c.wants {
bytes, err := hex.DecodeString(want)
if err != nil {
fmt.Printf("decode string error: %v", err)
}

realWants[i] = string(bytes)
}
testInputs = append(testInputs, tcTemp{
info: c.info,
inputs: []FunctionTestInput{
NewFunctionTestInput(types.T_varchar.ToType(), c.data, []bool{}),
NewFunctionTestInput(types.T_varchar.ToType(), c.keys, []bool{}),
},
expect: NewFunctionTestResult(types.T_blob.ToType(), false, realWants, []bool{}),
})
}

return testInputs
}

// Month

func initDateToMonthTestCase() []tcTemp {
Expand Down
42 changes: 42 additions & 0 deletions pkg/sql/plan/function/list_builtIn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,48 @@ var supportedStringBuiltIns = []FuncNew{
},
},

// function `encode`
{
functionId: ENCODE,
class: plan.Function_STRICT,
layout: STANDARD_FUNCTION,
checkFn: fixedTypeMatch,

Overloads: []overload{
{
overloadId: 0,
args: []types.T{types.T_varchar, types.T_varchar},
retType: func(parameters []types.Type) types.Type {
return types.T_blob.ToType()
},
newOp: func() executeLogicOfOverload {
return Encode
},
},
},
},

// function `decode`
{
functionId: DECODE,
class: plan.Function_STRICT,
layout: STANDARD_FUNCTION,
checkFn: fixedTypeMatch,

Overloads: []overload{
{
overloadId: 0,
args: []types.T{types.T_blob, types.T_varchar},
retType: func(parameters []types.Type) types.Type {
return types.T_varchar.ToType()
},
newOp: func() executeLogicOfOverload {
return Decode
},
},
},
},

// function `trim`
{
functionId: TRIM,
Expand Down
39 changes: 39 additions & 0 deletions test/distributed/cases/function/func_decode_encode.result
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
SELECT DECODE(ENCODE('Hello, World!', 'mysecretkey'), 'mysecretkey');
DECODE(ENCODE(Hello, World!, mysecretkey), mysecretkey)
Hello, World!
SELECT DECODE(ENCODE('', ''), '');
DECODE(ENCODE(, ), )

SELECT DECODE(ENCODE('MatrixOne', '1234567890123456'), '1234567890123456');
DECODE(ENCODE(MatrixOne, 1234567890123456), 1234567890123456)
MatrixOne
SELECT DECODE(ENCODE('MatrixOne', 'asdfjasfwefjfjkj'), 'asdfjasfwefjfjkj');
DECODE(ENCODE(MatrixOne, asdfjasfwefjfjkj), asdfjasfwefjfjkj)
MatrixOne
SELECT DECODE(ENCODE('MatrixOne123', '123456789012345678901234'), '123456789012345678901234');
DECODE(ENCODE(MatrixOne123, 123456789012345678901234), 123456789012345678901234)
MatrixOne123
SELECT DECODE(ENCODE('MatrixOne#%$%^', '*^%YTu1234567'), '*^%YTu1234567');
DECODE(ENCODE(MatrixOne#%$%^, *^%YTu1234567), *^%YTu1234567)
MatrixOne#%$%^
SELECT DECODE(ENCODE('MatrixOne', ''), '');
DECODE(ENCODE(MatrixOne, ), )
MatrixOne
SELECT DECODE(ENCODE('分布式データベース', 'pass1234@#$%%^^&'), 'pass1234@#$%%^^&');
DECODE(ENCODE(分布式データベース, pass1234@#$%%^^&), pass1234@#$%%^^&)
分布式データベース
SELECT DECODE(ENCODE('分布式データベース', '分布式7782734adgwy1242'), '分布式7782734adgwy1242');
DECODE(ENCODE(分布式データベース, 分布式7782734adgwy1242), 分布式7782734adgwy1242)
分布式データベース
SELECT DECODE(ENCODE('MatrixOne', '密匙'), '密匙');
DECODE(ENCODE(MatrixOne, 密匙), 密匙)
MatrixOne
SELECT DECODE(ENCODE('MatrixOne数据库', '数据库passwd12345667'), '数据库passwd12345667');
DECODE(ENCODE(MatrixOne数据库, 数据库passwd12345667), 数据库passwd12345667)
MatrixOne数据库
SELECT HEX(ENCODE('MatrixOne数据库', '数据库passwd12345667'));
HEX(ENCODE(MatrixOne数据库, 数据库passwd12345667))
34B8B67B8C4EDF31009142BC6346E3C32B0C
SELECT HEX(ENCODE('mytext','mykeystring'));
HEX(ENCODE(mytext, mykeystring))
562C778A6367
13 changes: 13 additions & 0 deletions test/distributed/cases/function/func_decode_encode.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
SELECT DECODE(ENCODE('Hello, World!', 'mysecretkey'), 'mysecretkey');
SELECT DECODE(ENCODE('', ''), '');
SELECT DECODE(ENCODE('MatrixOne', '1234567890123456'), '1234567890123456');
SELECT DECODE(ENCODE('MatrixOne', 'asdfjasfwefjfjkj'), 'asdfjasfwefjfjkj');
SELECT DECODE(ENCODE('MatrixOne123', '123456789012345678901234'), '123456789012345678901234');
SELECT DECODE(ENCODE('MatrixOne#%$%^', '*^%YTu1234567'), '*^%YTu1234567');
SELECT DECODE(ENCODE('MatrixOne', ''), '');
SELECT DECODE(ENCODE('分布式データベース', 'pass1234@#$%%^^&'), 'pass1234@#$%%^^&');
SELECT DECODE(ENCODE('分布式データベース', '分布式7782734adgwy1242'), '分布式7782734adgwy1242');
SELECT DECODE(ENCODE('MatrixOne', '密匙'), '密匙');
SELECT DECODE(ENCODE('MatrixOne数据库', '数据库passwd12345667'), '数据库passwd12345667');
SELECT HEX(ENCODE('MatrixOne数据库', '数据库passwd12345667'));
SELECT HEX(ENCODE('mytext','mykeystring'));

0 comments on commit 2ae7c14

Please sign in to comment.