diff --git a/app/cmd/root.go b/app/cmd/root.go index 8e5c874..afdd7a6 100644 --- a/app/cmd/root.go +++ b/app/cmd/root.go @@ -68,13 +68,30 @@ var rootCmd = &cobra.Command{ logrus.Fatal(err) } + // read basic auth flags + basicAuthUsers := make(map[string]string) + authString, err := cmd.Flags().GetString("http-basic-auth-users") + if err != nil { + logrus.Fatal(err) + } + if authString != "" { + for _, user := range strings.Split(authString, ",") { + userParts := strings.Split(user, ":") + if len(userParts) != 2 { + logrus.Fatalf("invalid basic auth user: %s", user) + } + basicAuthUsers[userParts[0]] = userParts[1] + } + } + // build config struct cfg := &mopsos.Config{ DBProvider: provider, DBDSN: dsn, DBMigrate: migrate, - HttpListener: listener, + HttpListener: listener, + BasicAuthUsers: basicAuthUsers, EnableTracing: enableTracing, TracingTarget: tracingTarget, @@ -113,6 +130,7 @@ func Execute() { // webserver flags rootCmd.Flags().String("http-listener", ":8080", "HTTP listener") + rootCmd.Flags().String("http-basic-auth-users", "", "Comma-separated list of clusters and tokens, e.g. 'cluster1:token1,cluster2:token2'") // otel flags rootCmd.Flags().Bool("otel", false, "Enable OpenTelemetry tracing") diff --git a/app/config.go b/app/config.go index 8a6454f..b941632 100644 --- a/app/config.go +++ b/app/config.go @@ -6,7 +6,8 @@ type Config struct { DBDSN string DBMigrate bool - HttpListener string + HttpListener string + BasicAuthUsers map[string]string EnableTracing bool TracingTarget string diff --git a/app/server.go b/app/server.go index 08a047c..6db155e 100644 --- a/app/server.go +++ b/app/server.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" + "github.com/adfinis-sygroup/mopsos/app/models" otelObs "github.com/cloudevents/sdk-go/observability/opentelemetry/v2/client" cloudevents "github.com/cloudevents/sdk-go/v2" "github.com/cloudevents/sdk-go/v2/binding" @@ -41,6 +42,18 @@ func (s *Server) Start() { }) mux.Handle("/webhook", otelhttp.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + + // get basic auth credentials + username, password, ok := r.BasicAuth() + if !ok { + http.Error(w, "missing Authorization header", http.StatusUnauthorized) + return + } + if !s.checkAuth(username, password) { + http.Error(w, "invalid credentials", http.StatusUnauthorized) + return + } + // get event message := httproto.NewMessageFromHttpRequest(r) event, err := binding.ToEvent(context.TODO(), message) @@ -48,6 +61,21 @@ func (s *Server) Start() { logrus.WithError(err).Error("failed to decode event") return } + + // TODO consider how to harmonise this with what the handler does later on + record := &models.Record{} + if err := event.DataAs(record); err != nil { + logrus.WithError(err).Errorf("failed to unmarshal event data") + http.Error(w, "failed to unmarshal event data", http.StatusInternalServerError) + return + } + + // reject record that have not been sent from the right auth + if record.ClusterName != username { + http.Error(w, "event data does not match username", http.StatusUnauthorized) + return + } + err = s.HandleReceivedEvent(ctx, *event) if err != nil { logrus.WithError(err).Error("failed to handle event") @@ -85,3 +113,11 @@ func (s *Server) HandleReceivedEvent(ctx context.Context, event cloudevents.Even return nil } + +// checkAuth checks if the username and password are correct +func (s *Server) checkAuth(username, password string) bool { + logrus.WithFields(logrus.Fields{ + "username": username, + }).Debug("checking credentials") + return s.config.BasicAuthUsers[username] == password +}