Skip to content

Commit

Permalink
planner: provide a unified interface to set and get user variables an…
Browse files Browse the repository at this point in the history
…d types (pingcap#37046)

ref pingcap#36598
  • Loading branch information
fzzf678 authored Aug 11, 2022
1 parent fb1b850 commit 7d8c45a
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 106 deletions.
4 changes: 1 addition & 3 deletions executor/load_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,13 +625,11 @@ func (e *LoadDataInfo) colsToRow(ctx context.Context, cols []field) []types.Datu
row := make([]types.Datum, 0, len(e.insertColumns))
sessionVars := e.Ctx.GetSessionVars()
setVar := func(name string, col *field) {
sessionVars.UsersLock.Lock()
if col == nil || col.isNull() {
sessionVars.UnsetUserVar(name)
} else {
sessionVars.SetUserVar(name, string(col.str), mysql.DefaultCollationName)
sessionVars.SetStringUserVar(name, string(col.str), mysql.DefaultCollationName)
}
sessionVars.UsersLock.Unlock()
}

for i := 0; i < len(e.FieldMappings); i++ {
Expand Down
6 changes: 2 additions & 4 deletions executor/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,12 @@ func (e *SetExecutor) Next(ctx context.Context, req *chunk.Chunk) error {
if err != nil {
return err
}
sessionVars.UsersLock.Lock()
if value.IsNull() {
sessionVars.UnsetUserVar(name)
} else {
sessionVars.Users[name] = value
sessionVars.UserVarTypes[name] = v.Expr.GetType()
sessionVars.SetUserVarVal(name, value)
sessionVars.SetUserVarType(name, v.Expr.GetType())
}
sessionVars.UsersLock.Unlock()
continue
}

Expand Down
40 changes: 10 additions & 30 deletions expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,9 +776,7 @@ func (b *builtinSetStringVarSig) evalString(row chunk.Row) (res string, isNull b
if err != nil {
return "", isNull, err
}
sessionVars.UsersLock.Lock()
sessionVars.SetUserVar(varName, stringutil.Copy(res), datum.Collation())
sessionVars.UsersLock.Unlock()
sessionVars.SetStringUserVar(varName, stringutil.Copy(res), datum.Collation())
return res, false, nil
}

Expand Down Expand Up @@ -806,9 +804,7 @@ func (b *builtinSetRealVarSig) evalReal(row chunk.Row) (res float64, isNull bool
}
res = datum.GetFloat64()
varName = strings.ToLower(varName)
sessionVars.UsersLock.Lock()
sessionVars.Users[varName] = datum
sessionVars.UsersLock.Unlock()
sessionVars.SetUserVarVal(varName, datum)
return res, false, nil
}

Expand All @@ -835,9 +831,7 @@ func (b *builtinSetDecimalVarSig) evalDecimal(row chunk.Row) (*types.MyDecimal,
}
res := datum.GetMysqlDecimal()
varName = strings.ToLower(varName)
sessionVars.UsersLock.Lock()
sessionVars.Users[varName] = datum
sessionVars.UsersLock.Unlock()
sessionVars.SetUserVarVal(varName, datum)
return res, false, nil
}

Expand All @@ -864,9 +858,7 @@ func (b *builtinSetIntVarSig) evalInt(row chunk.Row) (int64, bool, error) {
}
res := datum.GetInt64()
varName = strings.ToLower(varName)
sessionVars.UsersLock.Lock()
sessionVars.Users[varName] = datum
sessionVars.UsersLock.Unlock()
sessionVars.SetUserVarVal(varName, datum)
return res, false, nil
}

Expand All @@ -892,9 +884,7 @@ func (b *builtinSetTimeVarSig) evalTime(row chunk.Row) (types.Time, bool, error)
}
res := datum.GetMysqlTime()
varName = strings.ToLower(varName)
sessionVars.UsersLock.Lock()
sessionVars.Users[varName] = datum
sessionVars.UsersLock.Unlock()
sessionVars.SetUserVarVal(varName, datum)
return res, false, nil
}

Expand Down Expand Up @@ -971,9 +961,7 @@ func (b *builtinGetStringVarSig) evalString(row chunk.Row) (string, bool, error)
return "", isNull, err
}
varName = strings.ToLower(varName)
sessionVars.UsersLock.RLock()
defer sessionVars.UsersLock.RUnlock()
if v, ok := sessionVars.Users[varName]; ok {
if v, ok := sessionVars.GetUserVarVal(varName); ok {
// We cannot use v.GetString() here, because the datum may be in KindMysqlTime, which
// stores the data in datum.x.
// This seems controversial with https://dev.mysql.com/doc/refman/8.0/en/user-variables.html:
Expand Down Expand Up @@ -1025,9 +1013,7 @@ func (b *builtinGetIntVarSig) evalInt(row chunk.Row) (int64, bool, error) {
return 0, isNull, err
}
varName = strings.ToLower(varName)
sessionVars.UsersLock.RLock()
defer sessionVars.UsersLock.RUnlock()
if v, ok := sessionVars.Users[varName]; ok {
if v, ok := sessionVars.GetUserVarVal(varName); ok {
return v.GetInt64(), false, nil
}
return 0, true, nil
Expand Down Expand Up @@ -1067,9 +1053,7 @@ func (b *builtinGetRealVarSig) evalReal(row chunk.Row) (float64, bool, error) {
return 0, isNull, err
}
varName = strings.ToLower(varName)
sessionVars.UsersLock.RLock()
defer sessionVars.UsersLock.RUnlock()
if v, ok := sessionVars.Users[varName]; ok {
if v, ok := sessionVars.GetUserVarVal(varName); ok {
return v.GetFloat64(), false, nil
}
return 0, true, nil
Expand Down Expand Up @@ -1109,9 +1093,7 @@ func (b *builtinGetDecimalVarSig) evalDecimal(row chunk.Row) (*types.MyDecimal,
return nil, isNull, err
}
varName = strings.ToLower(varName)
sessionVars.UsersLock.RLock()
defer sessionVars.UsersLock.RUnlock()
if v, ok := sessionVars.Users[varName]; ok {
if v, ok := sessionVars.GetUserVarVal(varName); ok {
return v.GetMysqlDecimal(), false, nil
}
return nil, true, nil
Expand Down Expand Up @@ -1159,9 +1141,7 @@ func (b *builtinGetTimeVarSig) evalTime(row chunk.Row) (types.Time, bool, error)
return types.ZeroTime, isNull, err
}
varName = strings.ToLower(varName)
sessionVars.UsersLock.RLock()
defer sessionVars.UsersLock.RUnlock()
if v, ok := sessionVars.Users[varName]; ok {
if v, ok := sessionVars.GetUserVarVal(varName); ok {
return v.GetMysqlTime(), false, nil
}
return types.ZeroTime, true, nil
Expand Down
12 changes: 5 additions & 7 deletions expression/builtin_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func TestSetVar(t *testing.T) {
if tc.args[1] != nil {
key, ok := tc.args[0].(string)
require.Equal(t, true, ok)
sessionVar, ok := ctx.GetSessionVars().Users[key]
sessionVar, ok := ctx.GetSessionVars().GetUserVarVal(key)
require.Equal(t, true, ok)
require.Equal(t, tc.res, sessionVar.GetValue())
}
Expand All @@ -135,15 +135,15 @@ func TestGetVar(t *testing.T) {
{"h", timeDec},
}
for _, kv := range sessionVars {
ctx.GetSessionVars().Users[kv.key] = types.NewDatum(kv.val)
ctx.GetSessionVars().SetUserVarVal(kv.key, types.NewDatum(kv.val))
var tp *types.FieldType
if _, ok := kv.val.(types.Time); ok {
tp = types.NewFieldType(mysql.TypeDatetime)
} else {
tp = types.NewFieldType(mysql.TypeVarString)
}
types.DefaultParamTypeForValue(kv.val, tp)
ctx.GetSessionVars().UserVarTypes[kv.key] = tp
ctx.GetSessionVars().SetUserVarType(kv.key, tp)
}

testCases := []struct {
Expand All @@ -160,7 +160,7 @@ func TestGetVar(t *testing.T) {
{[]interface{}{"h"}, timeDec.String()},
}
for _, tc := range testCases {
tp, ok := ctx.GetSessionVars().UserVarTypes[tc.args[0].(string)]
tp, ok := ctx.GetSessionVars().GetUserVarType(tc.args[0].(string))
if !ok {
tp = types.NewFieldType(mysql.TypeVarString)
}
Expand Down Expand Up @@ -244,9 +244,7 @@ func TestSetVarFromColumn(t *testing.T) {

// Check whether the user variable changed.
sessionVars := ctx.GetSessionVars()
sessionVars.UsersLock.RLock()
defer sessionVars.UsersLock.RUnlock()
sessionVar, ok := sessionVars.Users["a"]
sessionVar, ok := sessionVars.GetUserVarVal("a")
require.Equal(t, true, ok)
require.Equal(t, "a", sessionVar.GetString())
}
Expand Down
32 changes: 8 additions & 24 deletions expression/builtin_other_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ func (b *builtinSetStringVarSig) vecEvalString(input *chunk.Chunk, result *chunk
}
result.ReserveString(n)
sessionVars := b.ctx.GetSessionVars()
sessionVars.UsersLock.Lock()
defer sessionVars.UsersLock.Unlock()
_, collation := sessionVars.GetCharsetInfo()
for i := 0; i < n; i++ {
if buf0.IsNull(i) || buf1.IsNull(i) {
Expand All @@ -189,7 +187,7 @@ func (b *builtinSetStringVarSig) vecEvalString(input *chunk.Chunk, result *chunk
}
varName := strings.ToLower(buf0.GetString(i))
res := buf1.GetString(i)
sessionVars.Users[varName] = types.NewCollationStringDatum(stringutil.Copy(res), collation)
sessionVars.SetUserVarVal(varName, types.NewCollationStringDatum(stringutil.Copy(res), collation))
result.AppendString(res)
}
return nil
Expand Down Expand Up @@ -220,16 +218,14 @@ func (b *builtinSetIntVarSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum
result.ResizeInt64(n, false)
i64s := result.Int64s()
sessionVars := b.ctx.GetSessionVars()
sessionVars.UsersLock.Lock()
defer sessionVars.UsersLock.Unlock()
for i := 0; i < n; i++ {
if buf0.IsNull(i) || buf1.IsNull(i) {
result.SetNull(i, true)
continue
}
varName := strings.ToLower(buf0.GetString(i))
res := buf1.GetInt64(i)
sessionVars.Users[varName] = types.NewIntDatum(res)
sessionVars.SetUserVarVal(varName, types.NewIntDatum(res))
i64s[i] = res
}
return nil
Expand Down Expand Up @@ -260,16 +256,14 @@ func (b *builtinSetRealVarSig) vecEvalReal(input *chunk.Chunk, result *chunk.Col
result.ResizeFloat64(n, false)
f64s := result.Float64s()
sessionVars := b.ctx.GetSessionVars()
sessionVars.UsersLock.Lock()
defer sessionVars.UsersLock.Unlock()
for i := 0; i < n; i++ {
if buf0.IsNull(i) || buf1.IsNull(i) {
result.SetNull(i, true)
continue
}
varName := strings.ToLower(buf0.GetString(i))
res := buf1.GetFloat64(i)
sessionVars.Users[varName] = types.NewFloat64Datum(res)
sessionVars.SetUserVarVal(varName, types.NewFloat64Datum(res))
f64s[i] = res
}
return nil
Expand Down Expand Up @@ -300,16 +294,14 @@ func (b *builtinSetDecimalVarSig) vecEvalDecimal(input *chunk.Chunk, result *chu
result.ResizeDecimal(n, false)
decs := result.Decimals()
sessionVars := b.ctx.GetSessionVars()
sessionVars.UsersLock.Lock()
defer sessionVars.UsersLock.Unlock()
for i := 0; i < n; i++ {
if buf0.IsNull(i) || buf1.IsNull(i) {
result.SetNull(i, true)
continue
}
varName := strings.ToLower(buf0.GetString(i))
res := buf1.GetDecimal(i)
sessionVars.Users[varName] = types.NewDecimalDatum(res)
sessionVars.SetUserVarVal(varName, types.NewDecimalDatum(res))
decs[i] = *res
}
return nil
Expand Down Expand Up @@ -339,15 +331,13 @@ func (b *builtinGetStringVarSig) vecEvalString(input *chunk.Chunk, result *chunk
}
result.ReserveString(n)
sessionVars := b.ctx.GetSessionVars()
sessionVars.UsersLock.Lock()
defer sessionVars.UsersLock.Unlock()
for i := 0; i < n; i++ {
if buf0.IsNull(i) {
result.AppendNull()
continue
}
varName := strings.ToLower(buf0.GetString(i))
if v, ok := sessionVars.Users[varName]; ok {
if v, ok := sessionVars.GetUserVarVal(varName); ok {
res, err := v.ToString()
if err != nil {
return err
Expand Down Expand Up @@ -378,14 +368,12 @@ func (b *builtinGetIntVarSig) vecEvalInt(input *chunk.Chunk, result *chunk.Colum
result.MergeNulls(buf0)
i64s := result.Int64s()
sessionVars := b.ctx.GetSessionVars()
sessionVars.UsersLock.Lock()
defer sessionVars.UsersLock.Unlock()
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
}
varName := strings.ToLower(buf0.GetString(i))
if v, ok := sessionVars.Users[varName]; ok {
if v, ok := sessionVars.GetUserVarVal(varName); ok {
i64s[i] = v.GetInt64()
continue
}
Expand All @@ -412,14 +400,12 @@ func (b *builtinGetRealVarSig) vecEvalReal(input *chunk.Chunk, result *chunk.Col
result.MergeNulls(buf0)
f64s := result.Float64s()
sessionVars := b.ctx.GetSessionVars()
sessionVars.UsersLock.Lock()
defer sessionVars.UsersLock.Unlock()
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
}
varName := strings.ToLower(buf0.GetString(i))
if v, ok := sessionVars.Users[varName]; ok {
if v, ok := sessionVars.GetUserVarVal(varName); ok {
f64s[i] = v.GetFloat64()
continue
}
Expand All @@ -446,14 +432,12 @@ func (b *builtinGetDecimalVarSig) vecEvalDecimal(input *chunk.Chunk, result *chu
result.MergeNulls(buf0)
decs := result.Decimals()
sessionVars := b.ctx.GetSessionVars()
sessionVars.UsersLock.Lock()
defer sessionVars.UsersLock.Unlock()
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
}
varName := strings.ToLower(buf0.GetString(i))
if v, ok := sessionVars.Users[varName]; ok {
if v, ok := sessionVars.GetUserVarVal(varName); ok {
decs[i] = *v.GetMysqlDecimal()
continue
}
Expand Down
2 changes: 1 addition & 1 deletion planner/core/common_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func isGetVarBinaryLiteral(sctx sessionctx.Context, expr expression.Expression)
name, isNull, err := scalarFunc.GetArgs()[0].EvalString(sctx, chunk.Row{})
if err != nil || isNull {
res = false
} else if dt, ok2 := sctx.GetSessionVars().Users[name]; ok2 {
} else if dt, ok2 := sctx.GetSessionVars().GetUserVarVal(name); ok2 {
res = dt.Kind() == types.KindBinaryLiteral
}
}
Expand Down
8 changes: 2 additions & 6 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1280,14 +1280,10 @@ func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) {
// Store the field type of the variable into SessionVars.UserVarTypes.
// Normally we can infer the type from SessionVars.User, but we need SessionVars.UserVarTypes when
// GetVar has not been executed to fill the SessionVars.Users.
sessionVars.UsersLock.Lock()
sessionVars.UserVarTypes[name] = tp
sessionVars.UsersLock.Unlock()
sessionVars.SetUserVarType(name, tp)
return
}
sessionVars.UsersLock.RLock()
tp, ok := sessionVars.UserVarTypes[name]
sessionVars.UsersLock.RUnlock()
tp, ok := sessionVars.GetUserVarType(name)
if !ok {
tp = types.NewFieldType(mysql.TypeVarString)
tp.SetFlen(mysql.MaxFieldVarCharLength)
Expand Down
4 changes: 2 additions & 2 deletions planner/core/plan_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ func parseParamTypes(sctx sessionctx.Context, params []expression.Expression) (p

// from text protocol, there must be a GetVar function
name := param.(*expression.ScalarFunction).GetArgs()[0].String()
tp := sctx.GetSessionVars().UserVarTypes[name]
if tp == nil {
tp, ok := sctx.GetSessionVars().GetUserVarType(name)
if !ok {
tp = types.NewFieldType(mysql.TypeNull)
}
paramTypes = append(paramTypes, tp)
Expand Down
2 changes: 1 addition & 1 deletion planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,7 @@ func (b *PlanBuilder) buildPrepare(x *ast.PrepareStmt) Plan {
Name: x.Name,
}
if x.SQLVar != nil {
if v, ok := b.ctx.GetSessionVars().Users[strings.ToLower(x.SQLVar.Name)]; ok {
if v, ok := b.ctx.GetSessionVars().GetUserVarVal(strings.ToLower(x.SQLVar.Name)); ok {
var err error
p.SQLText, err = v.ToString()
if err != nil {
Expand Down
Loading

0 comments on commit 7d8c45a

Please sign in to comment.