diff --git a/controllers/db.go b/controllers/db.go index d4bf560..4cfd806 100644 --- a/controllers/db.go +++ b/controllers/db.go @@ -13,10 +13,9 @@ type DatabaseController struct { } type Response struct { - Code int `json:"code"` - Data common.Any `json:"data"` - Message string `json:"message"` - Params common.ParameterMap `json:"params"` + Code int `json:"code"` + Data common.Any `json:"data"` + Message string `json:"message"` } type Request struct { @@ -86,7 +85,7 @@ func (this *DatabaseController) Execute() { res.Message = "connection refused for lack of session" } else { json.Unmarshal(this.Ctx.Input.RequestBody, ¶ms) - result, paramsMap, err := dao.Execute(nsid.(string), params.Gql, params.ParamList) + result, err := dao.Execute(nsid.(string), params.Gql, params.ParamList) if err == nil { res.Code = 0 res.Data = &result @@ -94,11 +93,6 @@ func (this *DatabaseController) Execute() { res.Code = -1 res.Message = err.Error() } - if len(paramsMap) == 0 { - res.Params = nil - } else { - res.Params = paramsMap - } } this.Data["json"] = &res this.ServeJSON() diff --git a/service/dao/dao.go b/service/dao/dao.go index e32e719..1127d39 100644 --- a/service/dao/dao.go +++ b/service/dao/dao.go @@ -12,9 +12,10 @@ import ( ) type ExecuteResult struct { - Headers []string `json:"headers"` - Tables []map[string]common.Any `json:"tables"` - TimeCost int64 `json:"timeCost"` + Headers []string `json:"headers"` + Tables []map[string]common.Any `json:"tables"` + TimeCost int64 `json:"timeCost"` + LocalParams common.ParameterMap `json:"localParams"` } type list []common.Any @@ -287,14 +288,15 @@ func Disconnect(nsid string) { pool.Disconnect(nsid) } -func Execute(nsid string, gql string, paramList common.ParameterList) (result ExecuteResult, params common.ParameterMap, err error) { +func Execute(nsid string, gql string, paramList common.ParameterList) (result ExecuteResult, err error) { result = ExecuteResult{ - Headers: make([]string, 0), - Tables: make([]map[string]common.Any, 0), + Headers: make([]string, 0), + Tables: make([]map[string]common.Any, 0), + LocalParams: nil, } connection, err := pool.GetConnection(nsid) if err != nil { - return result, nil, err + return result, err } responseChannel := make(chan pool.ChannelResponse) connection.RequestChannel <- pool.ChannelRequest{ @@ -304,12 +306,15 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex } response := <-responseChannel paramsMap := response.Params + if len(paramsMap) > 0 { + result.LocalParams = paramsMap + } if response.Error != nil { - return result, paramsMap, response.Error + return result, response.Error } resp := response.Result if response.Result == nil { - return result, paramsMap, nil + return result, nil } if resp.IsSetPlanDesc() { format := string(resp.GetPlanDesc().GetFormat()) @@ -325,7 +330,7 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex rowValue["operator info"] = rows[i][4] result.Tables = append(result.Tables, rowValue) } - return result, paramsMap, err + return result, err } else { var rowValue = make(map[string]common.Any) result.Headers = append(result.Headers, "format") @@ -335,12 +340,12 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex rowValue["format"] = resp.MakeDotGraphByStruct() } result.Tables = append(result.Tables, rowValue) - return result, paramsMap, err + return result, err } } if !resp.IsSucceed() { logs.Info("ErrorCode: %v, ErrorMsg: %s", resp.GetErrorCode(), resp.GetErrorMsg()) - return result, paramsMap, errors.New(string(resp.GetErrorMsg())) + return result, errors.New(string(resp.GetErrorMsg())) } if !resp.IsEmpty() { rowSize := resp.GetRowSize() @@ -354,16 +359,16 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex var _edgesParsedList = make(list, 0) var _pathsParsedList = make(list, 0) if err != nil { - return result, paramsMap, err + return result, err } for j := 0; j < colSize; j++ { rowData, err := record.GetValueByIndex(j) if err != nil { - return result, paramsMap, err + return result, err } value, err := getValue(rowData) if err != nil { - return result, paramsMap, err + return result, err } rowValue[result.Headers[j]] = value valueType := rowData.GetType() @@ -399,12 +404,12 @@ func Execute(nsid string, gql string, paramList common.ParameterList) (result Ex rowValue["_pathsParsedList"] = _pathsParsedList } if err != nil { - return result, paramsMap, err + return result, err } } result.Tables = append(result.Tables, rowValue) } } result.TimeCost = resp.GetLatency() - return result, paramsMap, nil + return result, nil } diff --git a/service/pool/pool.go b/service/pool/pool.go index 1d7418c..7c38263 100644 --- a/service/pool/pool.go +++ b/service/pool/pool.go @@ -194,6 +194,7 @@ func isCmd(query string) (isLocal bool, localCmd int, args []string) { } func executeCmd(parameterList common.ParameterList, parameterMap *common.ParameterMap) (showMap common.ParameterMap, err error) { + tempMap := make(common.ParameterMap) for _, v := range parameterList { // convert interface{} to nebula.Value if isLocal, cmd, args := isCmd(v); isLocal { @@ -207,7 +208,7 @@ func executeCmd(parameterList common.ParameterList, parameterMap *common.Paramet } case Params: if len(args) == 1 { - showMap, err = ListParams(args[0], parameterMap) + err = ListParams(args[0], &tempMap, parameterMap) } if err != nil { return nil, err @@ -215,12 +216,12 @@ func executeCmd(parameterList common.ParameterList, parameterMap *common.Paramet } } } - return showMap, nil + return tempMap, nil } func defineParams(args string, parameterMap *common.ParameterMap) (err error) { argsRewritten := strings.Replace(args, "'", "\"", -1) - reg := regexp.MustCompile(`^\s*:param\s+(\S+)\s*=>(.*)$`) + reg := regexp.MustCompile(`(?i)^\s*:param\s+(\S+)\s*=>(.*)$`) if reg == nil { err = errors.New("invalid regular expression") return @@ -252,9 +253,8 @@ func defineParams(args string, parameterMap *common.ParameterMap) (err error) { return nil } -func ListParams(args string, parameterMap *common.ParameterMap) (showMap common.ParameterMap, err error) { - reg := regexp.MustCompile(`^\s*:params\s*(\S*)\s*$`) - paramsWithGoType := make(common.ParameterMap) +func ListParams(args string, tmpParameter *common.ParameterMap, sessionMap *common.ParameterMap) (err error) { + reg := regexp.MustCompile(`(?i)^\s*:params\s*(\S*)\s*$`) if reg == nil { err = errors.New("invalid regular expression") return @@ -274,18 +274,18 @@ func ListParams(args string, parameterMap *common.ParameterMap) (showMap common. } else { paramKey := matchResult[0][1] if len(paramKey) == 0 { - for k, v := range *parameterMap { - paramsWithGoType[k] = v + for k, v := range *sessionMap { + (*tmpParameter)[k] = v } } else { - if paramValue, ok := (*parameterMap)[paramKey]; ok { - paramsWithGoType[paramKey] = paramValue + if paramValue, ok := (*sessionMap)[paramKey]; ok { + (*tmpParameter)[paramKey] = paramValue } else { err = errors.New("Unknown parameter: " + paramKey) } } } - return paramsWithGoType, nil + return nil } func NewConnection(address string, port int, username string, password string) (nsid string, err error) {