From a745aa954a0c9c98b93259aebda10df4be6471b0 Mon Sep 17 00:00:00 2001 From: Yunkon Kim Date: Thu, 23 Nov 2023 21:33:40 +0900 Subject: [PATCH] Add a custom middleware to check the list of trusted proxies --- websrc/serve/serve.go | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/websrc/serve/serve.go b/websrc/serve/serve.go index 13c1318..8338c4a 100644 --- a/websrc/serve/serve.go +++ b/websrc/serve/serve.go @@ -5,14 +5,13 @@ import ( "html/template" "io" "net/http" + "strings" "github.com/cloud-barista/cm-data-mold/websrc/routes" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" ) -// var router *gin.Engine - // TemplateRenderer is a custom html/template renderer for Echo framework type TemplateRenderer struct { templates *template.Template @@ -29,6 +28,25 @@ func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c return t.templates.ExecuteTemplate(w, name, data) } +// Custom middleware to check the list of trusted proxies +func TrustedProxiesMiddleware(trustedProxies []string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + clientIP := c.RealIP() // Echo gets the real IP of the client + + for _, proxy := range trustedProxies { + if strings.HasPrefix(clientIP, proxy) { + // Request is from a trusted proxy + return next(c) + } + } + + // Handling requests from untrusted sources + return echo.NewHTTPError(http.StatusForbidden, "Access denied") + } + } +} + func InitServer() *echo.Echo { // router = gin.New() // router.Use(gin.Logger()) @@ -41,17 +59,15 @@ func InitServer() *echo.Echo { // router.ForwardedByClientIP = true // router.SetTrustedProxies([]string{"127.0.0.1"}) - - // help needed + e.Use(TrustedProxiesMiddleware([]string{"127.0.0.1"})) // router.Static("/res", "./web") // router.LoadHTMLGlob("./web/templates/*") // router.StaticFile("/favicon.ico", "./web/assets/favicon.ico") - e.Static("/res", "./web") e.File("/favicon.ico", "./web/assets/favicon.ico") renderer := &TemplateRenderer{ - templates: template.Must(template.ParseGlob("*.html")), + templates: template.Must(template.ParseGlob("./web/templates/*.html")), } e.Renderer = renderer @@ -78,6 +94,7 @@ func Run(rt *echo.Echo, port string) { // rt.Run(":" + port) port = fmt.Sprintf(":%s", port) if err := rt.Start(port); err != nil && err != http.ErrServerClosed { + rt.Logger.Error(err) rt.Logger.Panic("shuttig down the server") } }