From 6eed01221b939e8e29fc0ad94f4b0cc6ff2ba7a5 Mon Sep 17 00:00:00 2001 From: Ryan Date: Fri, 16 Jun 2023 11:48:52 +0200 Subject: [PATCH] feat(gateway): adding CORS to gateway --- api/gateway/middleware.go | 8 ++++++++ api/gateway/server_test.go | 31 +++++++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/api/gateway/middleware.go b/api/gateway/middleware.go index 498b9c5d64..2c88b34185 100644 --- a/api/gateway/middleware.go +++ b/api/gateway/middleware.go @@ -18,9 +18,17 @@ func (h *Handler) RegisterMiddleware(srv *Server) { setContentType, checkPostDisabled(h.state), wrapRequestContext, + enableCors, ) } +func enableCors(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + next.ServeHTTP(w, r) + }) +} + func setContentType(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/json") diff --git a/api/gateway/server_test.go b/api/gateway/server_test.go index e98d7a6091..cb8e3d17ae 100644 --- a/api/gateway/server_test.go +++ b/api/gateway/server_test.go @@ -12,8 +12,12 @@ import ( "github.com/stretchr/testify/require" ) +const ( + address = "localhost" + port = "0" +) + func TestServer(t *testing.T) { - address, port := "localhost", "0" server := NewServer(address, port) ctx, cancel := context.WithCancel(context.Background()) @@ -42,10 +46,33 @@ func TestServer(t *testing.T) { require.NoError(t, err) } +func TestCorsEnabled(t *testing.T) { + server := NewServer(address, port) + server.RegisterMiddleware(enableCors) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + err := server.Start(ctx) + require.NoError(t, err) + + // register ping handler + ping := new(ping) + server.RegisterHandlerFunc("/ping", ping.ServeHTTP, http.MethodGet) + + url := fmt.Sprintf("http://%s/ping", server.ListenAddr()) + + resp, err := http.Get(url) + require.NoError(t, err) + defer resp.Body.Close() + + require.NoError(t, err) + require.Equal(t, resp.Header.Get("Access-Control-Allow-Origin"), "*") +} + // TestServer_contextLeakProtection tests to ensure a context // deadline was added by the context wrapper middleware server-side. func TestServer_contextLeakProtection(t *testing.T) { - address, port := "localhost", "0" server := NewServer(address, port) server.RegisterMiddleware(wrapRequestContext)