Skip to content

Commit

Permalink
feat: context auth (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
huaxiabuluo authored and hetao92 committed May 26, 2022
1 parent 23567a9 commit f553a98
Show file tree
Hide file tree
Showing 11 changed files with 81 additions and 100 deletions.
1 change: 1 addition & 0 deletions app/utils/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ service.interceptors.request.use(config => {

service.interceptors.response.use(
(response: any) => {

// const isExecReq = /api-nebula\/db\/(exec|batchExec)$/.test(response.config?.url);
if (response.data?.data?.data) {
response.data.data = response.data.data.data;
Expand Down
5 changes: 2 additions & 3 deletions server-v2/api/studio/internal/common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ import (

var (
ReserveRequestRoutes = []string{
"/api-nebula/db/disconnect",
"/api-nebula/db/",
"/api/files",
"/api/import-tasks",
}
ReserveResponseRoutes = []string{
"/api-nebula/db/connect",
"/api-nebula/db/disconnect",
"/api-nebula/db/",
"/api/import-tasks",
}
IgnoreHandlerBodyPatterns = []*regexp.Regexp{
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions server-v2/api/studio/internal/logic/gateway/disonnectlogic.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ func NewDisonnectLogic(ctx context.Context, svcCtx *svc.ServiceContext) Disonnec
}
}

func (l *DisonnectLogic) Disonnect(req types.DisconnectDBParams) (*types.AnyResponse, error) {
return service.NewGatewayService(l.ctx, l.svcCtx).DisconnectDB(&req)
func (l *DisonnectLogic) Disonnect() (*types.AnyResponse, error) {
return service.NewGatewayService(l.ctx, l.svcCtx).DisconnectDB()
}
37 changes: 31 additions & 6 deletions server-v2/api/studio/internal/service/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type (
ExecNGQL(request *types.ExecNGQLParams) (*types.AnyResponse, error)
BatchExecNGQL(request *types.BatchExecNGQLParams) (*types.AnyResponse, error)
ConnectDB(request *types.ConnectDBParams) (*types.ConnectDBResult, error)
DisconnectDB(request *types.DisconnectDBParams) (*types.AnyResponse, error)
DisconnectDB() (*types.AnyResponse, error)
}

gatewayService struct {
Expand Down Expand Up @@ -74,7 +74,7 @@ func (s *gatewayService) ConnectDB(request *types.ConnectDBParams) (*types.Conne
}, nil
}

func (s *gatewayService) DisconnectDB(request *types.DisconnectDBParams) (*types.AnyResponse, error) {
func (s *gatewayService) DisconnectDB() (*types.AnyResponse, error) {
httpReq, _ := middleware.GetRequest(s.ctx)
httpRes, _ := middleware.GetResponseWriter(s.ctx)

Expand All @@ -90,20 +90,45 @@ func (s *gatewayService) DisconnectDB(request *types.DisconnectDBParams) (*types
}

func (s *gatewayService) ExecNGQL(request *types.ExecNGQLParams) (*types.AnyResponse, error) {
execute, _, err := dao.Execute(request.NSID, request.Gql, request.ParamList)
httpReq, _ := middleware.GetRequest(s.ctx)
NSIDCookie, NSIDErr := httpReq.Cookie(auth.NSIDName)
if NSIDErr != nil {
return nil, ecode.WithSessionMessage(NSIDErr)
}

execute, _, err := dao.Execute(NSIDCookie.Value, request.Gql, request.ParamList)
if err != nil {
return nil, ecode.WithCode(ecode.ErrInternalServer, err, "exec failed")
// TODO: common middleware should handle this
subErrMsgStr := []string{
"session expired",
"connection refused",
"broken pipe",
"an existing connection was forcibly closed",
"Token is expired",
}
for _, subErrMsg := range subErrMsgStr {
if strings.Contains(err.Error(), subErrMsg) {
return nil, ecode.WithSessionMessage(err)
}
}
return nil, ecode.WithErrorMessage(ecode.ErrInternalServer, err, "execute failed")
}

return &types.AnyResponse{Data: execute}, nil
}

func (s *gatewayService) BatchExecNGQL(request *types.BatchExecNGQLParams) (*types.AnyResponse, error) {
data := make([]map[string]interface{}, 0)
httpReq, _ := middleware.GetRequest(s.ctx)
NSIDCookie, NSIDErr := httpReq.Cookie(auth.NSIDName)
if NSIDErr != nil {
return nil, ecode.WithSessionMessage(NSIDErr)
}

NSID := request.NSID
NSID := NSIDCookie.Value
gqls := request.Gqls
paramList := request.ParamList

data := make([]map[string]interface{}, 0)
for _, gql := range gqls {
execute, _, err := dao.Execute(NSID, gql, make([]string, 0))
gqlRes := map[string]interface{}{"gql": gql, "data": execute}
Expand Down
17 changes: 13 additions & 4 deletions server-v2/api/studio/internal/service/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/vesoft-inc/nebula-studio/server-v2/api/studio/internal/service/importer"
"github.com/vesoft-inc/nebula-studio/server-v2/api/studio/internal/svc"
"github.com/vesoft-inc/nebula-studio/server-v2/api/studio/internal/types"
"github.com/vesoft-inc/nebula-studio/server-v2/api/studio/pkg/auth"
"github.com/vesoft-inc/nebula-studio/server-v2/api/studio/pkg/ecode"
"github.com/vesoft-inc/nebula-studio/server-v2/api/studio/pkg/utils"
"github.com/zeromicro/go-zero/core/logx"
Expand Down Expand Up @@ -151,7 +152,9 @@ func (i *importService) CreateImportTask(req *types.CreateImportTaskRequest) (*t
}

func (i *importService) StopImportTask(req *types.StopImportTaskRequest) error {
return importer.StopImportTask(req.Id, req.Address+":"+req.Port, req.Username)
auth := i.ctx.Value(auth.CtxKeyUserInfo{}).(*auth.AuthData)
host := fmt.Sprintf("%s:%d", auth.Address, auth.Port)
return importer.StopImportTask(req.Id, host, auth.Username)
}

func (i *importService) DownloadConfig(req *types.DownloadConfigsRequest) error {
Expand Down Expand Up @@ -201,15 +204,21 @@ func (i *importService) DownloadLogs(req *types.DownloadLogsRequest) error {
}

func (i *importService) DeleteImportTask(req *types.DeleteImportTaskRequest) error {
return importer.DeleteImportTask(i.svcCtx.Config.File.TasksDir, req.Id, req.Address+":"+req.Port, req.Username)
auth := i.ctx.Value(auth.CtxKeyUserInfo{}).(*auth.AuthData)
host := fmt.Sprintf("%s:%d", auth.Address, auth.Port)
return importer.DeleteImportTask(i.svcCtx.Config.File.TasksDir, req.Id, host, auth.Username)
}

func (i *importService) GetImportTask(req *types.GetImportTaskRequest) (*types.GetImportTaskData, error) {
return importer.GetImportTask(i.svcCtx.Config.File.TasksDir, req.Id, req.Address+":"+req.Port, req.Username)
auth := i.ctx.Value(auth.CtxKeyUserInfo{}).(*auth.AuthData)
host := fmt.Sprintf("%s:%d", auth.Address, auth.Port)
return importer.GetImportTask(i.svcCtx.Config.File.TasksDir, req.Id, host, auth.Username)
}

func (i *importService) GetManyImportTask(req *types.GetManyImportTaskRequest) (*types.GetManyImportTaskData, error) {
return importer.GetManyImportTask(i.svcCtx.Config.File.TasksDir, req.Address+":"+req.Port, req.Username, req.Page, req.PageSize)
auth := i.ctx.Value(auth.CtxKeyUserInfo{}).(*auth.AuthData)
host := fmt.Sprintf("%s:%d", auth.Address, auth.Port)
return importer.GetManyImportTask(i.svcCtx.Config.File.TasksDir, host, auth.Username, req.Page, req.PageSize)
}

// GetImportTaskLogNames :Get all log file's name of a task
Expand Down
37 changes: 9 additions & 28 deletions server-v2/api/studio/internal/types/types.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 11 additions & 16 deletions server-v2/api/studio/pkg/auth/authorize.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"context"
"encoding/base64"
"errors"
"fmt"
Expand All @@ -26,6 +27,7 @@ type (
Address string `json:"address"`
Port int `json:"port"`
Username string `json:"username"`
Password string `json:"password"`
}

authClaims struct {
Expand Down Expand Up @@ -110,6 +112,7 @@ func ParseConnectDBParams(params *types.ConnectDBParams, config *config.Config)
Address: params.Address,
Port: params.Port,
Username: username,
Password: password,
},
config,
)
Expand All @@ -124,17 +127,8 @@ func AuthMiddlewareWithCtx(svcCtx *svc.ServiceContext) rest.Middleware {
return
}

NSIDCookie, NSIDErr := r.Cookie(NSIDName)
if NSIDErr == nil {
// Add NSID to request query
utils.AddQueryParams(r, map[string]string{"NSID": NSIDCookie.Value})
}

tokenCookie, tokenErr := r.Cookie(TokenName)
if NSIDErr != nil {
svcCtx.ResponseHandler.Handle(w, r, nil, ecode.WithSessionMessage(NSIDErr))
return
} else if tokenErr != nil {
if tokenErr != nil {
svcCtx.ResponseHandler.Handle(w, r, nil, ecode.WithSessionMessage(tokenErr))
return
}
Expand All @@ -145,12 +139,13 @@ func AuthMiddlewareWithCtx(svcCtx *svc.ServiceContext) rest.Middleware {
return
}

// Add address|port|username to request query
utils.AddQueryParams(r, map[string]string{
"address": auth.Address,
"port": fmt.Sprintf("%d", auth.Port),
"username": auth.Username,
})
/**
* Add auth to request context
*
* Get auth from context:
* auth := s.ctx.Value(auth.CtxKeyUserInfo{}).(*auth.AuthData)
*/
r = r.WithContext(context.WithValue(r.Context(), CtxKeyUserInfo{}, auth))

next(w, r)
}
Expand Down
5 changes: 5 additions & 0 deletions server-v2/api/studio/pkg/ecode/codes.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ func WithSessionMessage(err error, formatWithArgs ...interface{}) error {
return WithCode(ErrSessionWithMessage, err, formatWithArgs...)
}

func WithErrorMessage(c *ErrCode, err error, formatWithArgs ...interface{}) error {
ErrWithMessage := newErrCode(c.GetCode(), PlatformCode, 1, fmt.Sprintf("%s::%s", c.GetMessage(), err.Error()))
return WithCode(ErrWithMessage, err, formatWithArgs...)
}

func WithForbidden(err error, formatWithArgs ...interface{}) error {
return WithCode(ErrForbidden, err, formatWithArgs...)
}
Expand Down
8 changes: 1 addition & 7 deletions server-v2/api/studio/restapi/gateway.api
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,19 @@ type (
ExecNGQLParams {
Gql string `json:"gql"`
ParamList []string `json:"paramList,optional"`
NSID string `form:"NSID"`
}
BatchExecNGQLParams {
Gqls []string `json:"gqls"`
ParamList []string `json:"paramList,optional"`
NSID string `form:"NSID"`
}
ConnectDBParams {
Address string `json:"address"`
Port int `json:"port"`
NebulaVersion string `form:"nebulaVersion,optional"`
Authorization string `header:"Authorization"`
}
ConnectDBResult {
Version string `json:"version"`
}
DisconnectDBParams {
NSID string `form:"NSID,optional"`
}
AnyResponse {
Data interface{} `json:"data"`
}
Expand Down Expand Up @@ -55,5 +49,5 @@ service studio-api {

@doc "Disonnect DB"
@handler Disonnect
post /disconnect(DisconnectDBParams) returns (AnyResponse)
post /disconnect returns (AnyResponse)
}
Loading

0 comments on commit f553a98

Please sign in to comment.