Skip to content

Commit

Permalink
feat: Add optional HTTP Middleware function to StartSettings for serv…
Browse files Browse the repository at this point in the history
…erimpl

Problem
---------------------

Traces for HTTP requests to the opamp-go server break, see #253

Solution
---------------------

- Add an HTTP Handler middleware function to `StartSettings`
- If this function is configured, apply it in serverimpl's `Start` where the HTTP Handler is set
- (add unit tests)

Code review notes
---------------------

- This is a step in addressing #253 but mostly just for HTTP clients and requests. There is likely more to do for maintaining trace linkage through requests that come over websocket connections
- I figured if users are using `Attach` instead of `Start`, they might have their own middleware configured for their HTTP server, so it makes more sense to hook this into `StartSettings` and the `Start` function
  • Loading branch information
gdfast committed Mar 15, 2024
1 parent ba70a24 commit bdd080a
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 1 deletion.
4 changes: 4 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ type StartSettings struct {

// Server's TLS configuration.
TLSConfig *tls.Config

// HTTPMiddlewareFunc specifies middleware for HTTP messages received by the server.
// This function is optional to set.
HTTPMiddlewareFunc func(handlerFunc http.HandlerFunc) http.HandlerFunc
}

type HTTPHandlerFunc func(http.ResponseWriter, *http.Request)
Expand Down
6 changes: 5 additions & 1 deletion server/serverimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@ func (s *server) Start(settings StartSettings) error {
path = defaultOpAMPPath
}

mux.HandleFunc(path, s.httpHandler)
if settings.HTTPMiddlewareFunc != nil {
mux.HandleFunc(path, settings.HTTPMiddlewareFunc(s.httpHandler))
} else {
mux.HandleFunc(path, s.httpHandler)
}

hs := &http.Server{
Handler: mux,
Expand Down
124 changes: 124 additions & 0 deletions server/serverimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ func TestServerStartStop(t *testing.T) {
assert.NoError(t, err)
}

func TestServerStartStopWithMiddleware(t *testing.T) {
var addedMiddleware atomic.Bool
assert.False(t, addedMiddleware.Load())

testHTTPMiddlewareFunc := func(handlerFunc http.HandlerFunc) http.HandlerFunc {
addedMiddleware.Store(true)
return func(writer http.ResponseWriter, request *http.Request) {
handlerFunc(writer, request)
}
}

startSettings := &StartSettings{
HTTPMiddlewareFunc: testHTTPMiddlewareFunc,
}

srv := startServer(t, startSettings)
assert.True(t, addedMiddleware.Load())

err := srv.Start(*startSettings)
assert.ErrorIs(t, err, errAlreadyStarted)

err = srv.Stop(context.Background())
assert.NoError(t, err)
}

func TestServerAddrWithNonZeroPort(t *testing.T) {
srv := New(&sharedinternal.NopLogger{})
require.NotNil(t, srv)
Expand Down Expand Up @@ -830,6 +855,105 @@ func TestConnectionAllowsConcurrentWrites(t *testing.T) {
}
}

func TestServerCallsHTTPMiddlewareOverWebsocket(t *testing.T) {
middlewareCalled := int32(0)

testHTTPMiddlewareFunc := func(handlerFunc http.HandlerFunc) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
atomic.AddInt32(&middlewareCalled, 1)
handlerFunc(writer, request)
}
}

callbacks := CallbacksStruct{
OnConnectingFunc: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{
Accept: true,
ConnectionCallbacks: ConnectionCallbacksStruct{},
}
},
}

// Start a Server
settings := &StartSettings{
HTTPMiddlewareFunc: testHTTPMiddlewareFunc,
Settings: Settings{Callbacks: callbacks},
}
srv := startServer(t, settings)
defer func() {
err := srv.Stop(context.Background())
assert.NoError(t, err)
}()

// Connect to the server, ensuring successful connection
conn, resp, err := dialClient(settings)
assert.NoError(t, err)
assert.NotNil(t, conn)
require.NotNil(t, resp)
assert.EqualValues(t, 101, resp.StatusCode)

// Verify middleware was called once for the websocket connection
eventually(t, func() bool { return atomic.LoadInt32(&middlewareCalled) == int32(1) })
assert.Equal(t, int32(1), atomic.LoadInt32(&middlewareCalled))
}

func TestServerCallsHTTPMiddlewareOverHTTP(t *testing.T) {
middlewareCalled := int32(0)

testHTTPMiddlewareFunc := func(handlerFunc http.HandlerFunc) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
atomic.AddInt32(&middlewareCalled, 1)
handlerFunc(writer, request)
}
}

callbacks := CallbacksStruct{
OnConnectingFunc: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{
Accept: true,
ConnectionCallbacks: ConnectionCallbacksStruct{},
}
},
}

// Start a Server
settings := &StartSettings{
HTTPMiddlewareFunc: testHTTPMiddlewareFunc,
Settings: Settings{Callbacks: callbacks},
}
srv := startServer(t, settings)
defer func() {
err := srv.Stop(context.Background())
assert.NoError(t, err)
}()

// Send an AgentToServer message to the Server
sendMsg1 := protobufs.AgentToServer{InstanceUid: "01BX5ZZKBKACTAV9WEVGEMMVS1"}
serializedProtoBytes1, err := proto.Marshal(&sendMsg1)
require.NoError(t, err)
_, err = http.Post(
"http://"+settings.ListenEndpoint+settings.ListenPath,
contentTypeProtobuf,
bytes.NewReader(serializedProtoBytes1),
)
require.NoError(t, err)

// Send another AgentToServer message to the Server
sendMsg2 := protobufs.AgentToServer{InstanceUid: "01BX5ZZKBKACTAV9WEVGEMMVRZ"}
serializedProtoBytes2, err := proto.Marshal(&sendMsg2)
require.NoError(t, err)
_, err = http.Post(
"http://"+settings.ListenEndpoint+settings.ListenPath,
contentTypeProtobuf,
bytes.NewReader(serializedProtoBytes2),
)
require.NoError(t, err)

// Verify middleware was triggered for each HTTP call
eventually(t, func() bool { return atomic.LoadInt32(&middlewareCalled) == int32(2) })
assert.Equal(t, int32(2), atomic.LoadInt32(&middlewareCalled))
}

func BenchmarkSendToClient(b *testing.B) {
clientConnections := []*websocket.Conn{}
serverConnections := []types.Connection{}
Expand Down

0 comments on commit bdd080a

Please sign in to comment.