Skip to content

Commit

Permalink
Merge pull request #414 from Patrick0308/gorestful-filter
Browse files Browse the repository at this point in the history
Add: Support let user add global filter for go-restful server
  • Loading branch information
AlexStocks authored Mar 16, 2020
2 parents e8f7526 + 0852fd2 commit bde7db3
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
15 changes: 15 additions & 0 deletions protocol/rest/rest_invoker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
)

import (
"github.com/emicklei/go-restful/v3"
"github.com/stretchr/testify/assert"
)

Expand All @@ -35,12 +36,25 @@ import (
"github.com/apache/dubbo-go/protocol/rest/client"
"github.com/apache/dubbo-go/protocol/rest/client/client_impl"
rest_config "github.com/apache/dubbo-go/protocol/rest/config"
"github.com/apache/dubbo-go/protocol/rest/server/server_impl"
)

func TestRestInvoker_Invoke(t *testing.T) {
// Refer
proto := GetRestProtocol()
defer proto.Destroy()
var filterNum int
server_impl.AddGoRestfulServerFilter(func(request *restful.Request, response *restful.Response, chain *restful.FilterChain) {
println(request.SelectedRoutePath())
filterNum = filterNum + 1
chain.ProcessFilter(request, response)
})
server_impl.AddGoRestfulServerFilter(func(request *restful.Request, response *restful.Response, chain *restful.FilterChain) {
println("filter2")
filterNum = filterNum + 1
chain.ProcessFilter(request, response)
})

url, err := common.NewURL("rest://127.0.0.1:8877/com.ikurento.user.UserProvider?anyhost=true&" +
"application=BDTService&category=providers&default.timeout=10000&dubbo=dubbo-provider-golang-1.0.0&" +
"environment=dev&interface=com.ikurento.user.UserProvider&ip=192.168.56.1&methods=GetUser%2C&" +
Expand Down Expand Up @@ -191,6 +205,7 @@ func TestRestInvoker_Invoke(t *testing.T) {
res = invoker.Invoke(context.Background(), inv)
assert.Error(t, res.Error(), "test error")

assert.Equal(t, filterNum, 12)
err = common.ServiceMap.UnRegister(url.Protocol, "com.ikurento.user.UserProvider")
assert.NoError(t, err)
}
11 changes: 11 additions & 0 deletions protocol/rest/server/server_impl/go_restful_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ func init() {
extension.SetRestServer(constant.DEFAULT_REST_SERVER, GetNewGoRestfulServer)
}

var filterSlice []restful.FilterFunction

type GoRestfulServer struct {
srv *http.Server
container *restful.Container
Expand All @@ -59,6 +61,9 @@ func NewGoRestfulServer() *GoRestfulServer {

func (grs *GoRestfulServer) Start(url common.URL) {
grs.container = restful.NewContainer()
for _, filter := range filterSlice {
grs.container.Filter(filter)
}
grs.srv = &http.Server{
Handler: grs.container,
}
Expand Down Expand Up @@ -309,3 +314,9 @@ func getArgsFromRequest(req *restful.Request, argsTypes []reflect.Type, config *
func GetNewGoRestfulServer() server.RestServer {
return NewGoRestfulServer()
}

// Let user addFilter
// addFilter should before config.Load()
func AddGoRestfulServerFilter(filterFuc restful.FilterFunction) {
filterSlice = append(filterSlice, filterFuc)
}

0 comments on commit bde7db3

Please sign in to comment.