From 3efe1571652ead17a84876dff7500db5ff6cba11 Mon Sep 17 00:00:00 2001 From: Rivery Date: Wed, 15 Sep 2021 16:18:30 +0800 Subject: [PATCH 1/3] feat:support cypher parameter refactor: remove vscode file update: fix typo value mod: update local cmd mod: update params mod: support params mod: add type --- .gitignore | 1 - common/common.go | 3 + controllers/db.go | 17 ++- service/dao/dao.go | 29 ++--- service/pool/pool.go | 258 ++++++++++++++++++++++++++++++++++++++++++- 5 files changed, 283 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 9401776..7b5cbe4 100644 --- a/.gitignore +++ b/.gitignore @@ -14,7 +14,6 @@ nebula-httpd .idea .vscode/ vendor/ - # Dependency directories (remove the comment below to include it) tmp/ diff --git a/common/common.go b/common/common.go index 4638a78..af7b746 100644 --- a/common/common.go +++ b/common/common.go @@ -1,3 +1,6 @@ package common type Any interface{} + +type ParameterList []string +type ParameterMap map[string]interface{} diff --git a/controllers/db.go b/controllers/db.go index 3140419..d4bf560 100644 --- a/controllers/db.go +++ b/controllers/db.go @@ -13,9 +13,10 @@ type DatabaseController struct { } type Response struct { - Code int `json:"code"` - Data common.Any `json:"data"` - Message string `json:"message"` + Code int `json:"code"` + Data common.Any `json:"data"` + Message string `json:"message"` + Params common.ParameterMap `json:"params"` } type Request struct { @@ -26,7 +27,8 @@ type Request struct { } type ExecuteRequest struct { - Gql string `json:"gql"` + Gql string `json:"gql"` + ParamList common.ParameterList `json:"paramList"` } type Data map[string]interface{} @@ -84,7 +86,7 @@ func (this *DatabaseController) Execute() { res.Message = "connection refused for lack of session" } else { json.Unmarshal(this.Ctx.Input.RequestBody, ¶ms) - result, err := dao.Execute(nsid.(string), params.Gql) + result, paramsMap, err := dao.Execute(nsid.(string), params.Gql, params.ParamList) if err == nil { res.Code = 0 res.Data = &result @@ -92,6 +94,11 @@ func (this *DatabaseController) Execute() { res.Code = -1 res.Message = err.Error() } + if len(paramsMap) == 0 { + res.Params = nil + } else { + res.Params = paramsMap + } } this.Data["json"] = &res this.ServeJSON() diff --git a/service/dao/dao.go b/service/dao/dao.go index d311edf..e32e719 100644 --- a/service/dao/dao.go +++ b/service/dao/dao.go @@ -287,26 +287,30 @@ func Disconnect(nsid string) { pool.Disconnect(nsid) } -func Execute(nsid string, gql string) (result ExecuteResult, err error) { +func Execute(nsid string, gql string, paramList common.ParameterList) (result ExecuteResult, params common.ParameterMap, err error) { result = ExecuteResult{ Headers: make([]string, 0), Tables: make([]map[string]common.Any, 0), } connection, err := pool.GetConnection(nsid) if err != nil { - return result, err + return result, nil, err } - responseChannel := make(chan pool.ChannelResponse) connection.RequestChannel <- pool.ChannelRequest{ Gql: gql, ResponseChannel: responseChannel, + ParamList: paramList, } response := <-responseChannel + paramsMap := response.Params if response.Error != nil { - return result, response.Error + return result, paramsMap, response.Error } resp := response.Result + if response.Result == nil { + return result, paramsMap, nil + } if resp.IsSetPlanDesc() { format := string(resp.GetPlanDesc().GetFormat()) if format == "row" { @@ -321,7 +325,7 @@ func Execute(nsid string, gql string) (result ExecuteResult, err error) { rowValue["operator info"] = rows[i][4] result.Tables = append(result.Tables, rowValue) } - return result, err + return result, paramsMap, err } else { var rowValue = make(map[string]common.Any) result.Headers = append(result.Headers, "format") @@ -331,13 +335,12 @@ func Execute(nsid string, gql string) (result ExecuteResult, err error) { rowValue["format"] = resp.MakeDotGraphByStruct() } result.Tables = append(result.Tables, rowValue) - return result, err + return result, paramsMap, err } } - if !resp.IsSucceed() { logs.Info("ErrorCode: %v, ErrorMsg: %s", resp.GetErrorCode(), resp.GetErrorMsg()) - return result, errors.New(string(resp.GetErrorMsg())) + return result, paramsMap, errors.New(string(resp.GetErrorMsg())) } if !resp.IsEmpty() { rowSize := resp.GetRowSize() @@ -351,16 +354,16 @@ func Execute(nsid string, gql string) (result ExecuteResult, err error) { var _edgesParsedList = make(list, 0) var _pathsParsedList = make(list, 0) if err != nil { - return result, err + return result, paramsMap, err } for j := 0; j < colSize; j++ { rowData, err := record.GetValueByIndex(j) if err != nil { - return result, err + return result, paramsMap, err } value, err := getValue(rowData) if err != nil { - return result, err + return result, paramsMap, err } rowValue[result.Headers[j]] = value valueType := rowData.GetType() @@ -396,12 +399,12 @@ func Execute(nsid string, gql string) (result ExecuteResult, err error) { rowValue["_pathsParsedList"] = _pathsParsedList } if err != nil { - return result, err + return result, paramsMap, err } } result.Tables = append(result.Tables, rowValue) } } result.TimeCost = resp.GetLatency() - return result, nil + return result, paramsMap, nil } diff --git a/service/pool/pool.go b/service/pool/pool.go index bb8bc5c..1d7418c 100644 --- a/service/pool/pool.go +++ b/service/pool/pool.go @@ -1,7 +1,10 @@ package pool import ( + "encoding/json" "errors" + "fmt" + "regexp" "strings" "sync" "time" @@ -12,11 +15,20 @@ import ( uuid "github.com/satori/go.uuid" nebula "github.com/vesoft-inc/nebula-go/v2" + nebulaType "github.com/vesoft-inc/nebula-go/v2/nebula" ) var ( ConnectionClosedError = errors.New("an existing connection was forcibly closed, please check your network") SessionLostError = errors.New("the connection session was lost, please connect again") + InterruptError = errors.New("Other statements was not executed due to this error.") +) + +// Console side commands +const ( + Unknown = -1 + Param = 1 + Params = 2 ) type Account struct { @@ -26,18 +38,21 @@ type Account struct { type ChannelResponse struct { Result *nebula.ResultSet + Params common.ParameterMap Error error } type ChannelRequest struct { Gql string ResponseChannel chan ChannelResponse + ParamList common.ParameterList } type Connection struct { RequestChannel chan ChannelRequest CloseChannel chan bool updateTime int64 + parameterMap common.ParameterMap account *Account session *nebula.Session } @@ -76,6 +91,203 @@ func isThriftTransportError(err error) bool { return false } +// construct Slice to nebula.NList +func Slice2Nlist(list []interface{}) (*nebulaType.NList, error) { + sv := []*nebulaType.Value{} + var ret nebulaType.NList + for _, item := range list { + nv, er := Base2Value(item) + if er != nil { + return nil, er + } + sv = append(sv, nv) + } + ret.Values = sv + return &ret, nil +} + +// construct map to nebula.NMap +func Map2Nmap(m map[string]interface{}) (*nebulaType.NMap, error) { + var ret nebulaType.NMap + kvs := map[string]*nebulaType.Value{} + for k, v := range m { + nv, err := Base2Value(v) + if err != nil { + return nil, err + } + kvs[k] = nv + } + ret.Kvs = kvs + return &ret, nil +} + +// construct go-type to nebula.Value +func Base2Value(any interface{}) (value *nebulaType.Value, err error) { + value = nebulaType.NewValue() + if v, ok := any.(bool); ok { + value.BVal = &v + } else if v, ok := any.(int); ok { + ival := int64(v) + value.IVal = &ival + } else if v, ok := any.(float64); ok { + if v == float64(int64(v)) { + iv := int64(v) + value.IVal = &iv + } else { + value.FVal = &v + } + } else if v, ok := any.(float32); ok { + if v == float32(int64(v)) { + iv := int64(v) + value.IVal = &iv + } else { + fval := float64(v) + value.FVal = &fval + } + } else if v, ok := any.(string); ok { + value.SVal = []byte(v) + } else if any == nil { + nval := nebulaType.NullType___NULL__ + value.NVal = &nval + } else if v, ok := any.([]interface{}); ok { + nv, er := Slice2Nlist([]interface{}(v)) + if er != nil { + err = er + } + value.LVal = nv + } else if v, ok := any.(map[string]interface{}); ok { + nv, er := Map2Nmap(map[string]interface{}(v)) + if er != nil { + err = er + } + value.MVal = nv + } else { + // unsupport other Value type, use this function carefully + err = fmt.Errorf("Only support convert boolean/float/int/string/map/list to nebula.Value but %T", any) + } + return +} + +func isCmd(query string) (isLocal bool, localCmd int, args []string) { + isLocal = false + localCmd = Unknown + plain := strings.TrimSpace(query) + if len(plain) < 1 || plain[0] != ':' { + return + } + isLocal = true + words := strings.Fields(plain[1:]) + localCmdName := words[0] + switch strings.ToLower(localCmdName) { + case "param": + { + localCmd = Param + args = []string{plain} + } + case "params": + { + localCmd = Params + args = []string{plain} + } + } + return +} + +func executeCmd(parameterList common.ParameterList, parameterMap *common.ParameterMap) (showMap common.ParameterMap, err error) { + for _, v := range parameterList { + // convert interface{} to nebula.Value + if isLocal, cmd, args := isCmd(v); isLocal { + switch cmd { + case Param: + if len(args) == 1 { + err = defineParams(args[0], parameterMap) + } + if err != nil { + return nil, err + } + case Params: + if len(args) == 1 { + showMap, err = ListParams(args[0], parameterMap) + } + if err != nil { + return nil, err + } + } + } + } + return showMap, nil +} + +func defineParams(args string, parameterMap *common.ParameterMap) (err error) { + argsRewritten := strings.Replace(args, "'", "\"", -1) + reg := regexp.MustCompile(`^\s*:param\s+(\S+)\s*=>(.*)$`) + if reg == nil { + err = errors.New("invalid regular expression") + return + } + matchResult := reg.FindAllStringSubmatch(argsRewritten, -1) + if len(matchResult) != 1 || len(matchResult[0]) != 3 { + err = errors.New("Set params failed. Wrong local command format (" + reg.String() + ") ") + return + } + /* + * :param p1=> -> [":param p1=>",":p1",""] + * :param p2=>3 -> [":param p2=>3",":p2","3"] + */ + paramKey := matchResult[0][1] + paramValue := matchResult[0][2] + if len(paramValue) == 0 { + delete((*parameterMap), paramKey) + } else { + paramsWithGoType := make(common.ParameterMap) + param := "{\"" + paramKey + "\"" + ":" + paramValue + "}" + err = json.Unmarshal([]byte(param), ¶msWithGoType) + if err != nil { + return + } + for k, v := range paramsWithGoType { + (*parameterMap)[k] = v + } + } + return nil +} + +func ListParams(args string, parameterMap *common.ParameterMap) (showMap common.ParameterMap, err error) { + reg := regexp.MustCompile(`^\s*:params\s*(\S*)\s*$`) + paramsWithGoType := make(common.ParameterMap) + if reg == nil { + err = errors.New("invalid regular expression") + return + } + matchResult := reg.FindAllStringSubmatch(args, -1) + if len(matchResult) != 1 { + err = errors.New("Set params failed. Wrong local command format " + reg.String() + ") ") + return + } + res := matchResult[0] + /* + * :params -> [":params",""] + * :params p1 -> ["params","p1"] + */ + if len(res) != 2 { + return + } else { + paramKey := matchResult[0][1] + if len(paramKey) == 0 { + for k, v := range *parameterMap { + paramsWithGoType[k] = v + } + } else { + if paramValue, ok := (*parameterMap)[paramKey]; ok { + paramsWithGoType[paramKey] = paramValue + } else { + err = errors.New("Unknown parameter: " + paramKey) + } + } + } + return paramsWithGoType, nil +} + func NewConnection(address string, port int, username string, password string) (nsid string, err error) { connectLock.Lock() defer connectLock.Unlock() @@ -103,6 +315,7 @@ func NewConnection(address string, port int, username string, password string) ( CloseChannel: make(chan bool), updateTime: time.Now().Unix(), session: session, + parameterMap: make(common.ParameterMap), account: &Account{ username: username, password: password, @@ -126,13 +339,46 @@ func NewConnection(address string, port int, username string, password string) ( } } }() - response, err := connection.session.Execute(request.Gql) - if err != nil && (isThriftProtoError(err) || isThriftTransportError(err)) { - err = ConnectionClosedError + showMap := make(common.ParameterMap) + if len(request.ParamList) > 0 { + showMap, err = executeCmd(request.ParamList, &connection.parameterMap) + if err != nil { + if len(request.Gql) > 0 { + err = errors.New(err.Error() + InterruptError.Error()) + } + request.ResponseChannel <- ChannelResponse{ + Result: nil, + Params: showMap, + Error: err, + } + return + } } - request.ResponseChannel <- ChannelResponse{ - Result: response, - Error: err, + + if len(request.Gql) > 0 { + params := make(map[string]*nebulaType.Value) + for k, v := range connection.parameterMap { + value, paramError := Base2Value(v) + if paramError != nil { + err = paramError + } + params[k] = value + } + response, err := connection.session.ExecuteWithParameter(request.Gql, params) + if err != nil && (isThriftProtoError(err) || isThriftTransportError(err)) { + err = ConnectionClosedError + } + request.ResponseChannel <- ChannelResponse{ + Result: response, + Params: showMap, + Error: err, + } + } else { + request.ResponseChannel <- ChannelResponse{ + Result: nil, + Params: showMap, + Error: nil, + } } }() case <-connection.CloseChannel: From b4ae35876454bbc412b2e2053fa5ee7cb25b4d19 Mon Sep 17 00:00:00 2001 From: Nut He <18328704+hetao92@users.noreply.github.com> Date: Wed, 29 Dec 2021 20:25:25 +0800 Subject: [PATCH 2/3] mod: return param in data struc --- controllers/db.go | 14 ++++---------- service/dao/dao.go | 39 ++++++++++++++++++++++----------------- service/pool/pool.go | 22 +++++++++++----------- 3 files changed, 37 insertions(+), 38 deletions(-) diff --git a/controllers/db.go b/controllers/db.go index d4bf560..4cfd806 100644 --- a/controllers/db.go +++ b/controllers/db.go @@ -13,10 +13,9 @@ type DatabaseController struct { } type Response struct { - Code int `json:"code"` - Data common.Any `json:"data"` - Message string `json:"message"` - Params common.ParameterMap `json:"params"` + Code int `json:"code"` + Data common.Any `json:"data"` + Message string `json:"message"` } type Request struct { @@ -86,7 +85,7 @@ func (this *DatabaseController) Execute() { res.Message = "connection refused for lack of session" } else { json.Unmarshal(this.Ctx.Input.RequestBody, ¶ms) - result, paramsMap, err := dao.Execute(nsid.(string), params.Gql, params.ParamList) + result, err := dao.Execute(nsid.(string), params.Gql, params.ParamList) if err == nil { res.Code = 0 res.Data = &result @@ -94,11 +93,6 @@ func (this *DatabaseController) Execute() { res.Code = -1 res.Message = err.Error() } - if len(paramsMap) == 0 { - res.Params = nil - } else { - res.Params = paramsMap - } } this.Data["json"] = &res this.ServeJSON() diff --git a/service/dao/dao.go b/service/dao/dao.go index e32e719..1127d39 100644 --- a/service/dao/dao.go +++ b/service/dao/dao.go @@ -12,9 +12,10 @@ import ( ) type ExecuteResult struct { - Headers []string `json:"headers"` - Tables []map[string]common.Any `json:"tables"` - TimeCost int64 `json:"timeCost"` + Headers []string `json:"headers"` + Tables []map[string]common.Any `json:"tables"` + TimeCost int64 `json:"timeCost"` + LocalParams common.ParameterMap `json:"localParams"` } type list []common.Any @@ -287,14 +288,15 @@ func Disconnect(nsid string) { pool.Disconnect(nsid) } -func Execute(nsid string, gql string, paramList common.ParameterList) (result ExecuteResult, params common.ParameterMap, err error) { +func Execute(nsid string, gql string, paramList common.ParameterList) (result ExecuteResult, err error) { result = ExecuteResult{ - Headers: make([]string, 0), - Tables: make([]map[string]common.Any, 0), + Headers: make([]string, 0), + Tables: make([]map[string]common.Any, 0), + LocalParams: nil, } connection, err := pool.GetConnection(nsid) if err != nil { - return result, nil, err + return result, err } responseChannel := make(chan pool.ChannelResponse) connection.RequestChannel <- pool.ChannelRequest{ @@ -304,12 +306,15 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex } response := <-responseChannel paramsMap := response.Params + if len(paramsMap) > 0 { + result.LocalParams = paramsMap + } if response.Error != nil { - return result, paramsMap, response.Error + return result, response.Error } resp := response.Result if response.Result == nil { - return result, paramsMap, nil + return result, nil } if resp.IsSetPlanDesc() { format := string(resp.GetPlanDesc().GetFormat()) @@ -325,7 +330,7 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex rowValue["operator info"] = rows[i][4] result.Tables = append(result.Tables, rowValue) } - return result, paramsMap, err + return result, err } else { var rowValue = make(map[string]common.Any) result.Headers = append(result.Headers, "format") @@ -335,12 +340,12 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex rowValue["format"] = resp.MakeDotGraphByStruct() } result.Tables = append(result.Tables, rowValue) - return result, paramsMap, err + return result, err } } if !resp.IsSucceed() { logs.Info("ErrorCode: %v, ErrorMsg: %s", resp.GetErrorCode(), resp.GetErrorMsg()) - return result, paramsMap, errors.New(string(resp.GetErrorMsg())) + return result, errors.New(string(resp.GetErrorMsg())) } if !resp.IsEmpty() { rowSize := resp.GetRowSize() @@ -354,16 +359,16 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex var _edgesParsedList = make(list, 0) var _pathsParsedList = make(list, 0) if err != nil { - return result, paramsMap, err + return result, err } for j := 0; j < colSize; j++ { rowData, err := record.GetValueByIndex(j) if err != nil { - return result, paramsMap, err + return result, err } value, err := getValue(rowData) if err != nil { - return result, paramsMap, err + return result, err } rowValue[result.Headers[j]] = value valueType := rowData.GetType() @@ -399,12 +404,12 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex rowValue["_pathsParsedList"] = _pathsParsedList } if err != nil { - return result, paramsMap, err + return result, err } } result.Tables = append(result.Tables, rowValue) } } result.TimeCost = resp.GetLatency() - return result, paramsMap, nil + return result, nil } diff --git a/service/pool/pool.go b/service/pool/pool.go index 1d7418c..7c38263 100644 --- a/service/pool/pool.go +++ b/service/pool/pool.go @@ -194,6 +194,7 @@ func isCmd(query string) (isLocal bool, localCmd int, args []string) { } func executeCmd(parameterList common.ParameterList, parameterMap *common.ParameterMap) (showMap common.ParameterMap, err error) { + tempMap := make(common.ParameterMap) for _, v := range parameterList { // convert interface{} to nebula.Value if isLocal, cmd, args := isCmd(v); isLocal { @@ -207,7 +208,7 @@ func executeCmd(parameterList common.ParameterList, parameterMap *common.Paramet } case Params: if len(args) == 1 { - showMap, err = ListParams(args[0], parameterMap) + err = ListParams(args[0], &tempMap, parameterMap) } if err != nil { return nil, err @@ -215,12 +216,12 @@ func executeCmd(parameterList common.ParameterList, parameterMap *common.Paramet } } } - return showMap, nil + return tempMap, nil } func defineParams(args string, parameterMap *common.ParameterMap) (err error) { argsRewritten := strings.Replace(args, "'", "\"", -1) - reg := regexp.MustCompile(`^\s*:param\s+(\S+)\s*=>(.*)$`) + reg := regexp.MustCompile(`(?i)^\s*:param\s+(\S+)\s*=>(.*)$`) if reg == nil { err = errors.New("invalid regular expression") return @@ -252,9 +253,8 @@ func defineParams(args string, parameterMap *common.ParameterMap) (err error) { return nil } -func ListParams(args string, parameterMap *common.ParameterMap) (showMap common.ParameterMap, err error) { - reg := regexp.MustCompile(`^\s*:params\s*(\S*)\s*$`) - paramsWithGoType := make(common.ParameterMap) +func ListParams(args string, tmpParameter *common.ParameterMap, sessionMap *common.ParameterMap) (err error) { + reg := regexp.MustCompile(`(?i)^\s*:params\s*(\S*)\s*$`) if reg == nil { err = errors.New("invalid regular expression") return @@ -274,18 +274,18 @@ func ListParams(args string, parameterMap *common.ParameterMap) (showMap common. } else { paramKey := matchResult[0][1] if len(paramKey) == 0 { - for k, v := range *parameterMap { - paramsWithGoType[k] = v + for k, v := range *sessionMap { + (*tmpParameter)[k] = v } } else { - if paramValue, ok := (*parameterMap)[paramKey]; ok { - paramsWithGoType[paramKey] = paramValue + if paramValue, ok := (*sessionMap)[paramKey]; ok { + (*tmpParameter)[paramKey] = paramValue } else { err = errors.New("Unknown parameter: " + paramKey) } } } - return paramsWithGoType, nil + return nil } func NewConnection(address string, port int, username string, password string) (nsid string, err error) { From 6f743c9b36a39b819aaa6804552810c45a338d72 Mon Sep 17 00:00:00 2001 From: Nut He <18328704+hetao92@users.noreply.github.com> Date: Thu, 30 Dec 2021 10:11:58 +0800 Subject: [PATCH 3/3] mod: code review --- service/dao/dao.go | 18 ++++++------- service/pool/pool.go | 61 +++++++++++++++++--------------------------- 2 files changed, 33 insertions(+), 46 deletions(-) diff --git a/service/dao/dao.go b/service/dao/dao.go index 1127d39..11f122b 100644 --- a/service/dao/dao.go +++ b/service/dao/dao.go @@ -8,7 +8,7 @@ import ( "github.com/vesoft-inc/nebula-http-gateway/service/pool" nebula "github.com/vesoft-inc/nebula-go/v2" - nebulaType "github.com/vesoft-inc/nebula-go/v2/nebula" + nebulatype "github.com/vesoft-inc/nebula-go/v2/nebula" ) type ExecuteResult struct { @@ -45,21 +45,21 @@ func getBasicValue(valWarp *nebula.ValueWrapper) (common.Any, error) { if valType == "null" { value, err := valWarp.AsNull() switch value { - case nebulaType.NullType___NULL__: + case nebulatype.NullType___NULL__: return "NULL", err - case nebulaType.NullType_NaN: + case nebulatype.NullType_NaN: return "NaN", err - case nebulaType.NullType_BAD_DATA: + case nebulatype.NullType_BAD_DATA: return "BAD_DATA", err - case nebulaType.NullType_BAD_TYPE: + case nebulatype.NullType_BAD_TYPE: return "BAD_TYPE", err - case nebulaType.NullType_OUT_OF_RANGE: + case nebulatype.NullType_OUT_OF_RANGE: return "OUT_OF_RANGE", err - case nebulaType.NullType_DIV_BY_ZERO: + case nebulatype.NullType_DIV_BY_ZERO: return "DIV_BY_ZERO", err - case nebulaType.NullType_UNKNOWN_PROP: + case nebulatype.NullType_UNKNOWN_PROP: return "UNKNOWN_PROP", err - case nebulaType.NullType_ERR_OVERFLOW: + case nebulatype.NullType_ERR_OVERFLOW: return "ERR_OVERFLOW", err } return "NULL", err diff --git a/service/pool/pool.go b/service/pool/pool.go index 7c38263..92ac19c 100644 --- a/service/pool/pool.go +++ b/service/pool/pool.go @@ -15,7 +15,7 @@ import ( uuid "github.com/satori/go.uuid" nebula "github.com/vesoft-inc/nebula-go/v2" - nebulaType "github.com/vesoft-inc/nebula-go/v2/nebula" + nebulatype "github.com/vesoft-inc/nebula-go/v2/nebula" ) var ( @@ -92,9 +92,9 @@ func isThriftTransportError(err error) bool { } // construct Slice to nebula.NList -func Slice2Nlist(list []interface{}) (*nebulaType.NList, error) { - sv := []*nebulaType.Value{} - var ret nebulaType.NList +func Slice2Nlist(list []interface{}) (*nebulatype.NList, error) { + sv := []*nebulatype.Value{} + var ret nebulatype.NList for _, item := range list { nv, er := Base2Value(item) if er != nil { @@ -107,9 +107,9 @@ func Slice2Nlist(list []interface{}) (*nebulaType.NList, error) { } // construct map to nebula.NMap -func Map2Nmap(m map[string]interface{}) (*nebulaType.NMap, error) { - var ret nebulaType.NMap - kvs := map[string]*nebulaType.Value{} +func Map2Nmap(m map[string]interface{}) (*nebulatype.NMap, error) { + var ret nebulatype.NMap + kvs := map[string]*nebulatype.Value{} for k, v := range m { nv, err := Base2Value(v) if err != nil { @@ -122,8 +122,8 @@ func Map2Nmap(m map[string]interface{}) (*nebulaType.NMap, error) { } // construct go-type to nebula.Value -func Base2Value(any interface{}) (value *nebulaType.Value, err error) { - value = nebulaType.NewValue() +func Base2Value(any interface{}) (value *nebulatype.Value, err error) { + value = nebulatype.NewValue() if v, ok := any.(bool); ok { value.BVal = &v } else if v, ok := any.(int); ok { @@ -147,7 +147,7 @@ func Base2Value(any interface{}) (value *nebulaType.Value, err error) { } else if v, ok := any.(string); ok { value.SVal = []byte(v) } else if any == nil { - nval := nebulaType.NullType___NULL__ + nval := nebulatype.NullType___NULL__ value.NVal = &nval } else if v, ok := any.([]interface{}); ok { nv, er := Slice2Nlist([]interface{}(v)) @@ -180,15 +180,11 @@ func isCmd(query string) (isLocal bool, localCmd int, args []string) { localCmdName := words[0] switch strings.ToLower(localCmdName) { case "param": - { - localCmd = Param - args = []string{plain} - } + localCmd = Param + args = []string{plain} case "params": - { - localCmd = Params - args = []string{plain} - } + localCmd = Params + args = []string{plain} } return } @@ -222,10 +218,6 @@ func executeCmd(parameterList common.ParameterList, parameterMap *common.Paramet func defineParams(args string, parameterMap *common.ParameterMap) (err error) { argsRewritten := strings.Replace(args, "'", "\"", -1) reg := regexp.MustCompile(`(?i)^\s*:param\s+(\S+)\s*=>(.*)$`) - if reg == nil { - err = errors.New("invalid regular expression") - return - } matchResult := reg.FindAllStringSubmatch(argsRewritten, -1) if len(matchResult) != 1 || len(matchResult[0]) != 3 { err = errors.New("Set params failed. Wrong local command format (" + reg.String() + ") ") @@ -255,10 +247,6 @@ func defineParams(args string, parameterMap *common.ParameterMap) (err error) { func ListParams(args string, tmpParameter *common.ParameterMap, sessionMap *common.ParameterMap) (err error) { reg := regexp.MustCompile(`(?i)^\s*:params\s*(\S*)\s*$`) - if reg == nil { - err = errors.New("invalid regular expression") - return - } matchResult := reg.FindAllStringSubmatch(args, -1) if len(matchResult) != 1 { err = errors.New("Set params failed. Wrong local command format " + reg.String() + ") ") @@ -271,18 +259,17 @@ func ListParams(args string, tmpParameter *common.ParameterMap, sessionMap *comm */ if len(res) != 2 { return + } + paramKey := matchResult[0][1] + if len(paramKey) == 0 { + for k, v := range *sessionMap { + (*tmpParameter)[k] = v + } } else { - paramKey := matchResult[0][1] - if len(paramKey) == 0 { - for k, v := range *sessionMap { - (*tmpParameter)[k] = v - } + if paramValue, ok := (*sessionMap)[paramKey]; ok { + (*tmpParameter)[paramKey] = paramValue } else { - if paramValue, ok := (*sessionMap)[paramKey]; ok { - (*tmpParameter)[paramKey] = paramValue - } else { - err = errors.New("Unknown parameter: " + paramKey) - } + err = errors.New("Unknown parameter: " + paramKey) } } return nil @@ -356,7 +343,7 @@ func NewConnection(address string, port int, username string, password string) ( } if len(request.Gql) > 0 { - params := make(map[string]*nebulaType.Value) + params := make(map[string]*nebulatype.Value) for k, v := range connection.parameterMap { value, paramError := Base2Value(v) if paramError != nil {