diff --git a/config/config.go b/config/config.go index 470700b..3bb769a 100644 --- a/config/config.go +++ b/config/config.go @@ -1,13 +1,24 @@ package config import ( - "git.devminer.xyz/devminer/unitel" - "github.com/joho/godotenv" - "github.com/rs/zerolog/log" "net/url" "os" "regexp" + "slices" "strconv" + "strings" + + "git.devminer.xyz/devminer/unitel" + "github.com/joho/godotenv" + "github.com/rs/zerolog/log" +) + +type Mode string + +const ( + ModeCombined Mode = "combined" + ModeWeb Mode = "web" + ModeConsumer Mode = "consumer" ) type Config struct { @@ -25,6 +36,9 @@ type Config struct { NATSURI string NATSStreamName string + Mode Mode + Consumers []string + DatabaseURI string Telemetry unitel.Opts @@ -62,6 +76,13 @@ func Load() { Msg("Both VERSIA_TLS_KEY and VERSIA_TLS_CERT have to be set if you want to use in-process TLS termination.") } + mode := getEnvStrOneOf("VERSIA_MODE", ModeCombined, ModeCombined, ModeWeb, ModeConsumer) + + var consumers []string + if raw := optionalEnvStr("VERSIA_TQ_CUSTOMERS"); raw != nil { + consumers = strings.Split(*raw, ",") + } + C = Config{ Port: getEnvInt("VERSIA_PORT", 80), TLSCert: tlsCert, @@ -76,13 +97,15 @@ func Load() { NATSURI: os.Getenv("NATS_URI"), NATSStreamName: getEnvStr("NATS_STREAM_NAME", "versia-go"), - DatabaseURI: os.Getenv("DATABASE_URI"), + + Mode: mode, + Consumers: consumers, + + DatabaseURI: os.Getenv("DATABASE_URI"), ForwardTracesTo: forwardTracesTo, Telemetry: unitel.ParseOpts("versia-go"), } - - return } func optionalEnvStr(key string) *string { @@ -93,6 +116,18 @@ func optionalEnvStr(key string) *string { return &value } +func getEnvBool(key string, default_ bool) bool { + if value, ok := os.LookupEnv(key); ok { + b, err := strconv.ParseBool(value) + if err != nil { + panic(err) + } + return b + } + + return default_ +} + func getEnvStr(key, default_ string) string { if value, ok := os.LookupEnv(key); ok { return value @@ -113,3 +148,22 @@ func getEnvInt(key string, default_ int) int { return default_ } + +func getEnvStrOneOf[T ~string](key string, default_ T, enum ...T) T { + if value, ok := os.LookupEnv(key); ok { + if !slices.Contains(enum, T(value)) { + sb := strings.Builder{} + sb.WriteString(key) + sb.WriteString(" can only be one of ") + for _, v := range enum { + sb.WriteString(string(v)) + } + + panic(sb.String()) + } + + return T(value) + } + + return default_ +} diff --git a/internal/repository/repo_impls/manager.go b/internal/repository/repo_impls/manager.go index 36ddec5..9d248f4 100644 --- a/internal/repository/repo_impls/manager.go +++ b/internal/repository/repo_impls/manager.go @@ -89,6 +89,10 @@ func (i *ManagerImpl) Atomic(ctx context.Context, fn func(ctx context.Context, t return tx.Finish() } +func (i *ManagerImpl) Ping() error { + return i.db.Ping() +} + func (i *ManagerImpl) Users() repository.UserRepository { return i.users } diff --git a/internal/repository/repository.go b/internal/repository/repository.go index b275021..a03d74c 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -51,6 +51,7 @@ type InstanceMetadataRepository interface { type Manager interface { Atomic(ctx context.Context, fn func(ctx context.Context, tx Manager) error) error + Ping() error Users() UserRepository Notes() NoteRepository diff --git a/internal/service/service.go b/internal/service/service.go index 0b409a5..1304ddc 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -2,6 +2,7 @@ package service import ( "context" + "github.com/gofiber/fiber/v2" "github.com/versia-pub/versia-go/internal/repository" "github.com/versia-pub/versia-go/pkg/versia" @@ -57,7 +58,7 @@ type InstanceMetadataService interface { } type TaskService interface { - ScheduleTask(ctx context.Context, type_ string, data any) error + ScheduleNoteTask(ctx context.Context, type_ string, data any) error } type RequestSigner interface { diff --git a/internal/service/svc_impls/note_service_impl.go b/internal/service/svc_impls/note_service_impl.go index 8e1c4c4..dbefa96 100644 --- a/internal/service/svc_impls/note_service_impl.go +++ b/internal/service/svc_impls/note_service_impl.go @@ -2,17 +2,18 @@ package svc_impls import ( "context" + "slices" + "github.com/versia-pub/versia-go/internal/repository" "github.com/versia-pub/versia-go/internal/service" + task_dtos "github.com/versia-pub/versia-go/internal/task/dtos" "github.com/versia-pub/versia-go/pkg/versia" - "slices" "git.devminer.xyz/devminer/unitel" "github.com/go-logr/logr" "github.com/google/uuid" "github.com/versia-pub/versia-go/internal/api_schema" "github.com/versia-pub/versia-go/internal/entity" - "github.com/versia-pub/versia-go/internal/tasks" ) var _ service.NoteService = (*NoteServiceImpl)(nil) @@ -69,7 +70,7 @@ func (i NoteServiceImpl) CreateNote(ctx context.Context, req api_schema.CreateNo return err } - if err := i.taskService.ScheduleTask(ctx, tasks.FederateNote, tasks.FederateNoteData{NoteID: n.ID}); err != nil { + if err := i.taskService.ScheduleNoteTask(ctx, task_dtos.FederateNote, task_dtos.FederateNoteData{NoteID: n.ID}); err != nil { return err } diff --git a/internal/service/svc_impls/task_service_impl.go b/internal/service/svc_impls/task_service_impl.go index c142741..69f4125 100644 --- a/internal/service/svc_impls/task_service_impl.go +++ b/internal/service/svc_impls/task_service_impl.go @@ -2,7 +2,9 @@ package svc_impls import ( "context" + "github.com/versia-pub/versia-go/internal/service" + "github.com/versia-pub/versia-go/internal/task" "git.devminer.xyz/devminer/unitel" "github.com/go-logr/logr" @@ -12,22 +14,22 @@ import ( var _ service.TaskService = (*TaskServiceImpl)(nil) type TaskServiceImpl struct { - client *taskqueue.Client + manager task.Manager telemetry *unitel.Telemetry log logr.Logger } -func NewTaskServiceImpl(client *taskqueue.Client, telemetry *unitel.Telemetry, log logr.Logger) *TaskServiceImpl { +func NewTaskServiceImpl(manager task.Manager, telemetry *unitel.Telemetry, log logr.Logger) *TaskServiceImpl { return &TaskServiceImpl{ - client: client, + manager: manager, telemetry: telemetry, log: log, } } -func (i TaskServiceImpl) ScheduleTask(ctx context.Context, type_ string, data any) error { +func (i TaskServiceImpl) ScheduleNoteTask(ctx context.Context, type_ string, data any) error { s := i.telemetry.StartSpan(ctx, "function", "svc_impls/TaskServiceImpl.ScheduleTask") defer s.End() ctx = s.Context() @@ -38,7 +40,7 @@ func (i TaskServiceImpl) ScheduleTask(ctx context.Context, type_ string, data an return err } - if err := i.client.Submit(ctx, t); err != nil { + if err := i.manager.Notes().Submit(ctx, t); err != nil { i.log.Error(err, "Failed to schedule task", "type", type_, "taskID", t.ID) return err } diff --git a/internal/task/dtos/note_dtos.go b/internal/task/dtos/note_dtos.go new file mode 100644 index 0000000..c3636df --- /dev/null +++ b/internal/task/dtos/note_dtos.go @@ -0,0 +1,11 @@ +package task_dtos + +import "github.com/google/uuid" + +const ( + FederateNote = "federate_note" +) + +type FederateNoteData struct { + NoteID uuid.UUID `json:"noteID"` +} diff --git a/internal/task/handler.go b/internal/task/handler.go new file mode 100644 index 0000000..802c6e3 --- /dev/null +++ b/internal/task/handler.go @@ -0,0 +1,20 @@ +package task + +import ( + "context" + + "github.com/versia-pub/versia-go/pkg/taskqueue" +) + +type Manager interface { + Notes() NoteHandler +} + +type Handler interface { + Register(*taskqueue.Set) + Submit(context.Context, taskqueue.Task) error +} + +type NoteHandler interface { + Submit(context.Context, taskqueue.Task) error +} diff --git a/internal/task/task_impls/base.go b/internal/task/task_impls/base.go new file mode 100644 index 0000000..9c279fb --- /dev/null +++ b/internal/task/task_impls/base.go @@ -0,0 +1,11 @@ +package task_impls + +import "git.devminer.xyz/devminer/unitel" + +type baseHandler struct { + telemetry *unitel.Telemetry +} + +func newBaseHandler() *baseHandler { + return &baseHandler{} +} diff --git a/internal/task/task_impls/manager.go b/internal/task/task_impls/manager.go new file mode 100644 index 0000000..ce984bf --- /dev/null +++ b/internal/task/task_impls/manager.go @@ -0,0 +1,29 @@ +package task_impls + +import ( + "git.devminer.xyz/devminer/unitel" + "github.com/go-logr/logr" + "github.com/versia-pub/versia-go/internal/task" +) + +var _ task.Manager = (*Manager)(nil) + +type Manager struct { + notes *NoteHandler + + telemetry *unitel.Telemetry + log logr.Logger +} + +func NewManager(notes *NoteHandler, telemetry *unitel.Telemetry, log logr.Logger) *Manager { + return &Manager{ + notes: notes, + + telemetry: telemetry, + log: log, + } +} + +func (m *Manager) Notes() task.NoteHandler { + return m.notes +} diff --git a/internal/task/task_impls/note_handler.go b/internal/task/task_impls/note_handler.go new file mode 100644 index 0000000..edf32c3 --- /dev/null +++ b/internal/task/task_impls/note_handler.go @@ -0,0 +1,97 @@ +package task_impls + +import ( + "context" + + "github.com/versia-pub/versia-go/internal/entity" + "github.com/versia-pub/versia-go/internal/repository" + "github.com/versia-pub/versia-go/internal/service" + "github.com/versia-pub/versia-go/internal/task" + task_dtos "github.com/versia-pub/versia-go/internal/task/dtos" + "github.com/versia-pub/versia-go/internal/utils" + + "git.devminer.xyz/devminer/unitel" + "github.com/go-logr/logr" + "github.com/versia-pub/versia-go/pkg/taskqueue" +) + +var _ task.Handler = (*NoteHandler)(nil) + +type NoteHandler struct { + federationService service.FederationService + + repositories repository.Manager + + telemetry *unitel.Telemetry + log logr.Logger + set *taskqueue.Set +} + +func NewNoteHandler(federationService service.FederationService, repositories repository.Manager, telemetry *unitel.Telemetry, log logr.Logger) *NoteHandler { + return &NoteHandler{ + federationService: federationService, + + repositories: repositories, + + telemetry: telemetry, + log: log, + } +} + +func (t *NoteHandler) Start(ctx context.Context) error { + consumer := t.set.Consumer("note-handler") + + return consumer.Start(ctx) +} + +func (t *NoteHandler) Register(s *taskqueue.Set) { + t.set = s + s.RegisterHandler(task_dtos.FederateNote, utils.ParseTask(t.FederateNote)) +} + +func (t *NoteHandler) Submit(ctx context.Context, task taskqueue.Task) error { + s := t.telemetry.StartSpan(ctx, "function", "task_impls/NoteHandler.Submit") + defer s.End() + ctx = s.Context() + + return t.set.Submit(ctx, task) +} + +func (t *NoteHandler) FederateNote(ctx context.Context, data task_dtos.FederateNoteData) error { + s := t.telemetry.StartSpan(ctx, "function", "task_impls/NoteHandler.FederateNote") + defer s.End() + ctx = s.Context() + + var n *entity.Note + if err := t.repositories.Atomic(ctx, func(ctx context.Context, tx repository.Manager) error { + var err error + n, err = tx.Notes().GetByID(ctx, data.NoteID) + if err != nil { + return err + } + if n == nil { + t.log.V(-1).Info("Could not find note", "id", data.NoteID) + return nil + } + + for _, uu := range n.Mentions { + if !uu.IsRemote { + t.log.V(2).Info("User is not remote", "user", uu.ID) + continue + } + + res, err := t.federationService.SendToInbox(ctx, n.Author, &uu, n.ToVersia()) + if err != nil { + t.log.Error(err, "Failed to send note to remote user", "res", res, "user", uu.ID) + } else { + t.log.V(2).Info("Sent note to remote user", "res", res, "user", uu.ID) + } + } + + return nil + }); err != nil { + return err + } + + return nil +} diff --git a/internal/tasks/federate_follow.go b/internal/tasks/federate_follow.go deleted file mode 100644 index fe52fbc..0000000 --- a/internal/tasks/federate_follow.go +++ /dev/null @@ -1,11 +0,0 @@ -package tasks - -import "context" - -type FederateFollowData struct { - FollowID string `json:"followID"` -} - -func (t *Handler) FederateFollow(ctx context.Context, data FederateNoteData) error { - return nil -} diff --git a/internal/tasks/federate_note.go b/internal/tasks/federate_note.go deleted file mode 100644 index 7e0a9d3..0000000 --- a/internal/tasks/federate_note.go +++ /dev/null @@ -1,52 +0,0 @@ -package tasks - -import ( - "context" - "github.com/versia-pub/versia-go/internal/repository" - - "github.com/google/uuid" - "github.com/versia-pub/versia-go/internal/entity" -) - -type FederateNoteData struct { - NoteID uuid.UUID `json:"noteID"` -} - -func (t *Handler) FederateNote(ctx context.Context, data FederateNoteData) error { - s := t.telemetry.StartSpan(ctx, "function", "tasks/Handler.FederateNote") - defer s.End() - ctx = s.Context() - - var n *entity.Note - if err := t.repositories.Atomic(ctx, func(ctx context.Context, tx repository.Manager) error { - var err error - n, err = tx.Notes().GetByID(ctx, data.NoteID) - if err != nil { - return err - } - if n == nil { - t.log.V(-1).Info("Could not find note", "id", data.NoteID) - return nil - } - - for _, uu := range n.Mentions { - if !uu.IsRemote { - t.log.V(2).Info("User is not remote", "user", uu.ID) - continue - } - - res, err := t.federationService.SendToInbox(ctx, n.Author, &uu, n.ToVersia()) - if err != nil { - t.log.Error(err, "Failed to send note to remote user", "res", res, "user", uu.ID) - } else { - t.log.V(2).Info("Sent note to remote user", "res", res, "user", uu.ID) - } - } - - return nil - }); err != nil { - return err - } - - return nil -} diff --git a/internal/tasks/handler.go b/internal/tasks/handler.go deleted file mode 100644 index 0143760..0000000 --- a/internal/tasks/handler.go +++ /dev/null @@ -1,53 +0,0 @@ -package tasks - -import ( - "context" - "encoding/json" - "github.com/versia-pub/versia-go/internal/repository" - "github.com/versia-pub/versia-go/internal/service" - - "git.devminer.xyz/devminer/unitel" - "github.com/go-logr/logr" - "github.com/versia-pub/versia-go/pkg/taskqueue" -) - -const ( - FederateNote = "federate_note" - FederateFollow = "federate_follow" -) - -type Handler struct { - federationService service.FederationService - - repositories repository.Manager - - telemetry *unitel.Telemetry - log logr.Logger -} - -func NewHandler(federationService service.FederationService, repositories repository.Manager, telemetry *unitel.Telemetry, log logr.Logger) *Handler { - return &Handler{ - federationService: federationService, - - repositories: repositories, - - telemetry: telemetry, - log: log, - } -} - -func (t *Handler) Register(tq *taskqueue.Client) { - tq.RegisterHandler(FederateNote, parse(t.FederateNote)) - tq.RegisterHandler(FederateFollow, parse(t.FederateFollow)) -} - -func parse[T any](handler func(context.Context, T) error) func(context.Context, taskqueue.Task) error { - return func(ctx context.Context, task taskqueue.Task) error { - var data T - if err := json.Unmarshal(task.Payload, &data); err != nil { - return err - } - - return handler(ctx, data) - } -} diff --git a/internal/utils/tasks.go b/internal/utils/tasks.go new file mode 100644 index 0000000..b7697e0 --- /dev/null +++ b/internal/utils/tasks.go @@ -0,0 +1,19 @@ +package utils + +import ( + "context" + "encoding/json" + + "github.com/versia-pub/versia-go/pkg/taskqueue" +) + +func ParseTask[T any](handler func(context.Context, T) error) func(context.Context, taskqueue.Task) error { + return func(ctx context.Context, task taskqueue.Task) error { + var data T + if err := json.Unmarshal(task.Payload, &data); err != nil { + return err + } + + return handler(ctx, data) + } +} diff --git a/main.go b/main.go index 16d6b5a..dd87361 100644 --- a/main.go +++ b/main.go @@ -7,42 +7,35 @@ import ( "crypto/tls" "database/sql" "database/sql/driver" - "fmt" - "git.devminer.xyz/devminer/unitel/unitelhttp" - "git.devminer.xyz/devminer/unitel/unitelsql" - "github.com/versia-pub/versia-go/ent/instancemetadata" - "github.com/versia-pub/versia-go/internal/api_schema" - "github.com/versia-pub/versia-go/internal/handlers/follow_handler" - "github.com/versia-pub/versia-go/internal/handlers/meta_handler" - "github.com/versia-pub/versia-go/internal/handlers/note_handler" - "github.com/versia-pub/versia-go/internal/repository" - "github.com/versia-pub/versia-go/internal/repository/repo_impls" - "github.com/versia-pub/versia-go/internal/service/svc_impls" - "github.com/versia-pub/versia-go/internal/validators/val_impls" "net/http" "os" "os/signal" + "slices" "strings" "sync" - "time" "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" "git.devminer.xyz/devminer/unitel" + "git.devminer.xyz/devminer/unitel/unitelhttp" + "git.devminer.xyz/devminer/unitel/unitelsql" "github.com/go-logr/logr" "github.com/go-logr/zerologr" - "github.com/gofiber/fiber/v2" - "github.com/gofiber/fiber/v2/middleware/cors" pgx "github.com/jackc/pgx/v5/stdlib" "github.com/nats-io/nats.go" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/versia-pub/versia-go/config" "github.com/versia-pub/versia-go/ent" + "github.com/versia-pub/versia-go/ent/instancemetadata" "github.com/versia-pub/versia-go/internal/database" - "github.com/versia-pub/versia-go/internal/handlers/user_handler" - "github.com/versia-pub/versia-go/internal/tasks" + "github.com/versia-pub/versia-go/internal/repository" + "github.com/versia-pub/versia-go/internal/repository/repo_impls" + "github.com/versia-pub/versia-go/internal/service/svc_impls" + "github.com/versia-pub/versia-go/internal/task" + "github.com/versia-pub/versia-go/internal/task/task_impls" "github.com/versia-pub/versia-go/internal/utils" + "github.com/versia-pub/versia-go/internal/validators/val_impls" "github.com/versia-pub/versia-go/pkg/taskqueue" "modernc.org/sqlite" ) @@ -52,11 +45,9 @@ func init() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) } -func shouldPropagate(r *http.Request) bool { - return config.C.ForwardTracesTo.Match([]byte(r.URL.String())) -} - func main() { + rootCtx, cancelRoot := context.WithCancel(context.Background()) + zerolog.SetGlobalLevel(zerolog.TraceLevel) zerologr.NameFieldName = "logger" zerologr.NameSeparator = "/" @@ -98,11 +89,10 @@ func main() { } log.Debug().Msg("Starting taskqueue client") - tq, err := taskqueue.NewClient(context.Background(), config.C.NATSStreamName, nc, tel, zerologr.New(&log.Logger).WithName("taskqueue-client")) + tq, err := taskqueue.NewClient(config.C.NATSStreamName, nc, tel, zerologr.New(&log.Logger).WithName("taskqueue-client")) if err != nil { log.Fatal().Err(err).Msg("Failed to create taskqueue client") } - defer tq.Close() log.Debug().Msg("Running schema migration") if err := migrateDB(db, zerologr.New(&log.Logger).WithName("migrate-db"), tel); err != nil { @@ -113,9 +103,8 @@ func main() { requestSigner := svc_impls.NewRequestSignerImpl(tel, zerologr.New(&log.Logger).WithName("request-signer")) federationService := svc_impls.NewFederationServiceImpl(httpClient, tel, zerologr.New(&log.Logger).WithName("federation-service")) - taskService := svc_impls.NewTaskServiceImpl(tq, tel, zerologr.New(&log.Logger).WithName("task-service")) - // Manager + // Repositories repos := repo_impls.NewManagerImpl( db, tel, zerologr.New(&log.Logger).WithName("repositories"), @@ -134,103 +123,50 @@ func main() { bodyValidator := val_impls.NewBodyValidator(zerologr.New(&log.Logger).WithName("validation-service")) requestValidator := val_impls.NewRequestValidator(repos, tel, zerologr.New(&log.Logger).WithName("request-validator")) + // Task handlers + + notes := task_impls.NewNoteHandler(federationService, repos, tel, zerologr.New(&log.Logger).WithName("task-note-handler")) + notesSet := registerTaskHandler(rootCtx, "notes", tq, notes) + + taskManager := task_impls.NewManager(notes, tel, zerologr.New(&log.Logger).WithName("task-manager")) + // Services + taskService := svc_impls.NewTaskServiceImpl(taskManager, tel, zerologr.New(&log.Logger).WithName("task-service")) userService := svc_impls.NewUserServiceImpl(repos, federationService, tel, zerologr.New(&log.Logger).WithName("user-service")) noteService := svc_impls.NewNoteServiceImpl(federationService, taskService, repos, tel, zerologr.New(&log.Logger).WithName("note-service")) followService := svc_impls.NewFollowServiceImpl(federationService, repos, tel, zerologr.New(&log.Logger).WithName("follow-service")) inboxService := svc_impls.NewInboxService(federationService, repos, tel, zerologr.New(&log.Logger).WithName("inbox-service")) instanceMetadataService := svc_impls.NewInstanceMetadataServiceImpl(federationService, repos, tel, zerologr.New(&log.Logger).WithName("instance-metadata-service")) - // Handlers - - userHandler := user_handler.New(federationService, requestSigner, userService, inboxService, bodyValidator, requestValidator, zerologr.New(&log.Logger).WithName("user-handler")) - noteHandler := note_handler.New(noteService, bodyValidator, requestSigner, zerologr.New(&log.Logger).WithName("notes-handler")) - followHandler := follow_handler.New(followService, federationService, zerologr.New(&log.Logger).WithName("follow-handler")) - metaHandler := meta_handler.New(instanceMetadataService, zerologr.New(&log.Logger).WithName("meta-handler")) - - // Initialization - - if err := initServerActor(db, tel); err != nil { - log.Fatal().Err(err).Msg("Failed to initialize server actor") - } - - web := fiber.New(fiber.Config{ - ProxyHeader: "X-Forwarded-For", - ErrorHandler: fiberErrorHandler, - DisableStartupMessage: true, - AppName: "versia-go", - EnablePrintRoutes: true, - }) - - web.Use(cors.New(cors.Config{ - AllowOriginsFunc: func(origin string) bool { - return true - }, - AllowMethods: "GET,POST,PUT,DELETE,PATCH", - AllowHeaders: "Origin, Content-Type, Accept, Authorization, b3, traceparent, sentry-trace, baggage", - AllowCredentials: true, - ExposeHeaders: "", - MaxAge: 0, - })) - - web.Use(unitelhttp.FiberMiddleware(tel, unitelhttp.FiberMiddlewareConfig{ - Repanic: false, - WaitForDelivery: false, - Timeout: 5 * time.Second, - // host for incoming requests - TraceRequestHeaders: []string{"origin", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"}, - // origin for outgoing requests - TraceResponseHeaders: []string{"host", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"}, - IgnoredRoutes: []string{"/api/health"}, - Logger: zerologr.New(&log.Logger).WithName("http-server"), - TracePropagator: shouldPropagate, - })) - web.Use(unitelhttp.RequestLogger(zerologr.New(&log.Logger).WithName("http-server"), true, true)) - - log.Debug().Msg("Registering handlers") - - web.Get("/api/health", healthCheck(db, nc)) - - userHandler.Register(web.Group("/")) - noteHandler.Register(web.Group("/")) - followHandler.Register(web.Group("/")) - metaHandler.Register(web.Group("/")) - wg := sync.WaitGroup{} - wg.Add(2) - - // TODO: Run these in separate processes, if wanted - go func() { - defer wg.Done() - - log.Debug().Msg("Starting taskqueue consumer") - - tasks.NewHandler(federationService, repos, tel, zerologr.New(&log.Logger).WithName("task-handler")). - Register(tq) - if err := tq.StartConsumer(context.Background(), "consumer"); err != nil { - log.Fatal().Err(err).Msg("failed to start taskqueue client") - } - }() - - go func() { - defer wg.Done() - - log.Debug().Msg("Starting server") - - addr := fmt.Sprintf(":%d", config.C.Port) + if config.C.Mode == config.ModeWeb || config.C.Mode == config.ModeCombined { + wg.Add(1) + go func() { + defer wg.Done() + + if err := server( + rootCtx, + tel, + db, + nc, + federationService, + requestSigner, + bodyValidator, + requestValidator, + userService, + noteService, + followService, + instanceMetadataService, + inboxService, + ); err != nil { + log.Fatal().Err(err).Msg("Failed to start server") + } + }() + } - var err error - if config.C.TLSKey != nil { - err = web.ListenTLS(addr, *config.C.TLSCert, *config.C.TLSKey) - } else { - err = web.Listen(addr) - } - if err != nil { - log.Fatal().Err(err).Msg("Failed to start server") - } - }() + maybeRunTaskHandler(rootCtx, "notes", notesSet, &wg) signalCh := make(chan os.Signal, 1) signal.Notify(signalCh, os.Interrupt) @@ -238,10 +174,7 @@ func main() { log.Info().Msg("Shutting down") - tq.Close() - if err := web.Shutdown(); err != nil { - log.Error().Err(err).Msg("Failed to shutdown server") - } + cancelRoot() wg.Wait() } @@ -354,27 +287,47 @@ func initServerActor(db *ent.Client, telemetry *unitel.Telemetry) error { return tx.Finish() } -func healthCheck(db *ent.Client, nc *nats.Conn) fiber.Handler { - return func(c *fiber.Ctx) error { - dbWorking := true - if err := db.Ping(); err != nil { - log.Error().Err(err).Msg("Database healthcheck failed") - dbWorking = false - } +func registerTaskHandler[T task.Handler](ctx context.Context, name string, tq *taskqueue.Client, handler T) *taskqueue.Set { + s, err := tq.Set(ctx, name) + if err != nil { + log.Fatal().Err(err).Str("handler", name).Msg("Could not create taskset for task handler") + } - natsWorking := true - if status := nc.Status(); status != nats.CONNECTED { - log.Error().Str("status", status.String()).Msg("NATS healthcheck failed") - natsWorking = false - } + handler.Register(s) - if dbWorking && natsWorking { - return c.SendString("lookin' good") - } + return s +} + +func maybeRunTaskHandler(ctx context.Context, name string, set *taskqueue.Set, wg *sync.WaitGroup) { + l := log.With().Str("handler", name).Logger() + + if config.C.Mode == config.ModeWeb { + l.Warn().Strs("requested", config.C.Consumers).Msg("Not starting task handler, as this process is running in web mode") + return + } + + if config.C.Mode == config.ModeConsumer && !slices.Contains(config.C.Consumers, name) { + l.Warn().Strs("requested", config.C.Consumers).Msg("Not starting task handler, as it wasn't requested") + return + } + + wg.Add(1) - return api_schema.ErrInternalServerError(map[string]any{ - "database": dbWorking, - "nats": natsWorking, - }) + c := set.Consumer(name) + if err := c.Start(ctx); err != nil { + l.Fatal().Err(err).Msg("Could not start task handler") } + + l.Info().Msg("Started task handler") + + go func() { + defer wg.Done() + + <-ctx.Done() + l.Debug().Msg("Got signal to stop task handler") + + c.Close() + + l.Info().Msg("Stopped task handler") + }() } diff --git a/pkg/taskqueue/client.go b/pkg/taskqueue/client.go index 010e3bc..e367a85 100644 --- a/pkg/taskqueue/client.go +++ b/pkg/taskqueue/client.go @@ -4,8 +4,7 @@ import ( "context" "encoding/json" "errors" - "strings" - "sync" + "fmt" "time" "git.devminer.xyz/devminer/unitel" @@ -53,12 +52,9 @@ func NewTask(type_ string, payload any) (Task, error) { }, nil } -type Handler func(ctx context.Context, task Task) error - type Client struct { - name string - subject string - handlers map[string][]Handler + name string + subject string nc *nats.Conn js jetstream.JetStream @@ -71,15 +67,27 @@ type Client struct { log logr.Logger } -func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) { +func NewClient(streamName string, natsClient *nats.Conn, telemetry *unitel.Telemetry, log logr.Logger) (*Client, error) { js, err := jetstream.New(natsClient) if err != nil { return nil, err } - s, err := js.CreateStream(ctx, jetstream.StreamConfig{ + return &Client{ + name: streamName, + + js: js, + + telemetry: telemetry, + log: log, + }, nil +} + +func (c *Client) Set(ctx context.Context, name string) (*Set, error) { + streamName := fmt.Sprintf("%s_%s", c.name, name) + + s, err := c.js.CreateStream(ctx, jetstream.StreamConfig{ Name: streamName, - Subjects: []string{streamName + ".*"}, MaxConsumers: -1, MaxMsgs: -1, Discard: jetstream.DiscardOld, @@ -89,7 +97,7 @@ func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, te AllowDirect: true, }) if errors.Is(err, nats.ErrStreamNameAlreadyInUse) { - s, err = js.Stream(ctx, streamName) + s, err = c.js.Stream(ctx, streamName) if err != nil { return nil, err } @@ -97,190 +105,13 @@ func NewClient(ctx context.Context, streamName string, natsClient *nats.Conn, te return nil, err } - stopCh := make(chan struct{}) - - c := &Client{ - name: streamName, - subject: streamName + ".tasks", - - handlers: map[string][]Handler{}, - - stopCh: stopCh, - closeOnce: sync.OnceFunc(func() { - close(stopCh) - }), - - nc: natsClient, - js: js, - s: s, - - telemetry: telemetry, - log: log, - } - - return c, nil -} - -func (c *Client) Close() { - c.closeOnce() - c.nc.Close() -} - -func (c *Client) Submit(ctx context.Context, task Task) error { - s := c.telemetry.StartSpan(ctx, "queue.publish", "taskqueue/Client.Submit"). - AddAttribute("messaging.destination.name", c.subject) - defer s.End() - ctx = s.Context() - - s.AddAttribute("jobID", task.ID) - - data, err := json.Marshal(c.newTaskWrapper(ctx, task)) - if err != nil { - return err - } - - s.AddAttribute("messaging.message.body.size", len(data)) - - msg, err := c.js.PublishMsg(ctx, &nats.Msg{Subject: c.subject, Data: data}) - if err != nil { - return err - } - c.log.V(2).Info("Submitted task", "id", task.ID, "type", task.Type, "sequence", msg.Sequence) - - s.AddAttribute("messaging.message.id", msg.Sequence) - - return nil -} - -func (c *Client) RegisterHandler(type_ string, handler Handler) { - c.log.V(2).Info("Registering handler", "type", type_) - - if _, ok := c.handlers[type_]; !ok { - c.handlers[type_] = []Handler{} - } - c.handlers[type_] = append(c.handlers[type_], handler) -} - -func (c *Client) StartConsumer(ctx context.Context, consumerGroup string) error { - c.log.Info("Starting consumer") - - sub, err := c.js.CreateConsumer(ctx, c.name, jetstream.ConsumerConfig{ - Durable: consumerGroup, - DeliverPolicy: jetstream.DeliverAllPolicy, - ReplayPolicy: jetstream.ReplayInstantPolicy, - AckPolicy: jetstream.AckExplicitPolicy, - FilterSubject: c.subject, - MaxWaiting: 1, - MaxAckPending: 1, - HeadersOnly: false, - MemoryStorage: false, - }) - if err != nil { - return err - } - - m, err := sub.Messages(jetstream.PullMaxMessages(1)) - if err != nil { - return err - } - - go func() { - for { - msg, err := m.Next() - if err != nil { - if errors.Is(err, jetstream.ErrMsgIteratorClosed) { - c.log.Info("Stopping") - return - } - - c.log.Error(err, "Failed to get next message") - break - } - - if err := c.handleTask(ctx, msg); err != nil { - c.log.Error(err, "Failed to handle task") - break - } - } - }() - go func() { - <-c.stopCh - m.Drain() - }() - - return nil -} - -func (c *Client) handleTask(ctx context.Context, msg jetstream.Msg) error { - msgMeta, err := msg.Metadata() - if err != nil { - return err - } - - data := msg.Data() - - var w taskWrapper - if err := json.Unmarshal(data, &w); err != nil { - if err := msg.Nak(); err != nil { - c.log.Error(err, "Failed to nak message") - } + return &Set{ + handlers: make(map[string][]TaskHandler), - return err - } - - s := c.telemetry.StartSpan( - context.Background(), - "queue.process", - "taskqueue/Client.handleTask", - c.telemetry.ContinueFromMap(w.TraceInfo), - ). - AddAttribute("messaging.destination.name", c.subject). - AddAttribute("messaging.message.id", msgMeta.Sequence.Stream). - AddAttribute("messaging.message.retry.count", msgMeta.NumDelivered). - AddAttribute("messaging.message.body.size", len(data)). - AddAttribute("messaging.message.receive.latency", time.Since(w.EnqueuedAt).Milliseconds()) - defer s.End() - ctx = s.Context() - - handlers, ok := c.handlers[w.Task.Type] - if !ok { - c.log.V(2).Info("No handler for task", "type", w.Task.Type) - return msg.Nak() - } - - var errs CombinedError - for _, handler := range handlers { - if err := handler(ctx, w.Task); err != nil { - c.log.Error(err, "Handler failed", "type", w.Task.Type) - errs.Errors = append(errs.Errors, err) - } - } - - if len(errs.Errors) > 0 { - if err := msg.Nak(); err != nil { - c.log.Error(err, "Failed to nak message") - errs.Errors = append(errs.Errors, err) - } - - return errs - } - - return msg.Ack() -} - -type CombinedError struct { - Errors []error -} - -func (e CombinedError) Error() string { - sb := strings.Builder{} - sb.WriteRune('[') - for i, err := range e.Errors { - if i > 0 { - sb.WriteRune(',') - } - sb.WriteString(err.Error()) - } - sb.WriteRune(']') - return sb.String() + streamName: streamName, + c: c, + s: s, + log: c.log.WithName(fmt.Sprintf("taskset(%s)", name)), + telemetry: c.telemetry, + }, nil } diff --git a/pkg/taskqueue/errors.go b/pkg/taskqueue/errors.go new file mode 100644 index 0000000..427bc7d --- /dev/null +++ b/pkg/taskqueue/errors.go @@ -0,0 +1,20 @@ +package taskqueue + +import "strings" + +type CombinedError struct { + Errors []error +} + +func (e CombinedError) Error() string { + sb := strings.Builder{} + sb.WriteRune('[') + for i, err := range e.Errors { + if i > 0 { + sb.WriteRune(',') + } + sb.WriteString(err.Error()) + } + sb.WriteRune(']') + return sb.String() +} diff --git a/pkg/taskqueue/taskset.go b/pkg/taskqueue/taskset.go new file mode 100644 index 0000000..fcd881b --- /dev/null +++ b/pkg/taskqueue/taskset.go @@ -0,0 +1,210 @@ +package taskqueue + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "git.devminer.xyz/devminer/unitel" + "github.com/go-logr/logr" + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +type TaskHandler = func(ctx context.Context, task Task) error + +type Set struct { + handlers map[string][]TaskHandler + + streamName string + c *Client + s jetstream.Stream + log logr.Logger + telemetry *unitel.Telemetry +} + +func (t *Set) RegisterHandler(type_ string, handler TaskHandler) { + t.log.V(2).Info("Registering handler", "type", type_) + + if _, ok := t.handlers[type_]; !ok { + t.handlers[type_] = []TaskHandler{} + } + t.handlers[type_] = append(t.handlers[type_], handler) +} + +func (t *Set) Submit(ctx context.Context, task Task) error { + s := t.telemetry.StartSpan(ctx, "queue.publish", "taskqueue/TaskSet.Submit"). + AddAttribute("messaging.destination.name", t.streamName) + defer s.End() + ctx = s.Context() + + s.AddAttribute("jobID", task.ID) + + data, err := json.Marshal(t.c.newTaskWrapper(ctx, task)) + if err != nil { + return err + } + + s.AddAttribute("messaging.message.body.size", len(data)) + + // TODO: Refactor + msg, err := t.c.js.PublishMsg(ctx, &nats.Msg{Subject: t.streamName, Data: data}) + if err != nil { + return err + } + t.log.V(2).Info("Submitted task", "id", task.ID, "type", task.Type, "sequence", msg.Sequence) + + s.AddAttribute("messaging.message.id", msg.Sequence) + + return nil +} + +func (t *Set) Consumer(name string) *Consumer { + stopCh := make(chan struct{}) + stopOnce := sync.OnceFunc(func() { + close(stopCh) + }) + + return &Consumer{ + stopCh: stopCh, + stopOnce: stopOnce, + + name: name, + streamName: t.streamName, + telemetry: t.telemetry, + log: t.log.WithName(fmt.Sprintf("consumer(%s)", name)), + t: t, + } +} + +type Consumer struct { + stopCh chan struct{} + stopOnce func() + + name string + streamName string + telemetry *unitel.Telemetry + log logr.Logger + t *Set +} + +func (c *Consumer) Close() { + c.stopOnce() +} + +func (c *Consumer) Start(ctx context.Context) error { + c.log.Info("Starting consumer") + + sub, err := c.t.c.js.CreateConsumer(ctx, c.streamName, jetstream.ConsumerConfig{ + Durable: c.name, + DeliverPolicy: jetstream.DeliverAllPolicy, + ReplayPolicy: jetstream.ReplayInstantPolicy, + AckPolicy: jetstream.AckExplicitPolicy, + MaxWaiting: 1, + MaxAckPending: 1, + HeadersOnly: false, + MemoryStorage: false, + }) + if err != nil { + return err + } + + m, err := sub.Messages(jetstream.PullMaxMessages(1)) + if err != nil { + return err + } + + go c.handleMessages(m) + + go func() { + <-ctx.Done() + c.Close() + }() + + go func() { + <-c.stopCh + m.Drain() + }() + + return nil +} + +func (c *Consumer) handleMessages(m jetstream.MessagesContext) { + for { + msg, err := m.Next() + if err != nil { + if errors.Is(err, jetstream.ErrMsgIteratorClosed) { + c.log.Info("Stopping") + return + } + + c.log.Error(err, "Failed to get next message") + break + } + + if err := c.handleTask(msg); err != nil { + c.log.Error(err, "Failed to handle task") + break + } + } +} + +func (c *Consumer) handleTask(msg jetstream.Msg) error { + msgMeta, err := msg.Metadata() + if err != nil { + return err + } + + data := msg.Data() + + var w taskWrapper + if err := json.Unmarshal(data, &w); err != nil { + if err := msg.Nak(); err != nil { + c.log.Error(err, "Failed to nak message") + } + + return err + } + + s := c.telemetry.StartSpan( + context.Background(), + "queue.process", + "taskqueue/Consumer.handleTask", + c.telemetry.ContinueFromMap(w.TraceInfo), + ). + AddAttribute("messaging.destination.name", msg.Subject()). + AddAttribute("messaging.message.id", msgMeta.Sequence.Stream). + AddAttribute("messaging.message.retry.count", msgMeta.NumDelivered). + AddAttribute("messaging.message.body.size", len(data)). + AddAttribute("messaging.message.receive.latency", time.Since(w.EnqueuedAt).Milliseconds()) + defer s.End() + ctx := s.Context() + + handlers, ok := c.t.handlers[w.Task.Type] + if !ok { + c.log.V(2).Info("No handler for task", "type", w.Task.Type) + return msg.Nak() + } + + var errs CombinedError + for _, handler := range handlers { + if err := handler(ctx, w.Task); err != nil { + c.log.Error(err, "Handler failed", "type", w.Task.Type) + errs.Errors = append(errs.Errors, err) + } + } + + if len(errs.Errors) > 0 { + if err := msg.Nak(); err != nil { + c.log.Error(err, "Failed to nak message") + errs.Errors = append(errs.Errors, err) + } + + return errs + } + + return msg.Ack() +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..06a8cc2 --- /dev/null +++ b/server.go @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "fmt" + "git.devminer.xyz/devminer/unitel" + "git.devminer.xyz/devminer/unitel/unitelhttp" + "github.com/versia-pub/versia-go/internal/api_schema" + "github.com/versia-pub/versia-go/internal/handlers/follow_handler" + "github.com/versia-pub/versia-go/internal/handlers/meta_handler" + "github.com/versia-pub/versia-go/internal/handlers/note_handler" + "github.com/versia-pub/versia-go/internal/service" + "github.com/versia-pub/versia-go/internal/validators" + "net/http" + "sync" + "time" + + "github.com/go-logr/zerologr" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/cors" + "github.com/nats-io/nats.go" + "github.com/rs/zerolog/log" + "github.com/versia-pub/versia-go/config" + "github.com/versia-pub/versia-go/ent" + "github.com/versia-pub/versia-go/internal/handlers/user_handler" +) + +func shouldPropagate(r *http.Request) bool { + return config.C.ForwardTracesTo.Match([]byte(r.URL.String())) +} + +func server( + ctx context.Context, + telemetry *unitel.Telemetry, + database *ent.Client, + natsConn *nats.Conn, + federationService service.FederationService, + requestSigner service.RequestSigner, + bodyValidator validators.BodyValidator, + requestValidator validators.RequestValidator, + userService service.UserService, + noteService service.NoteService, + followService service.FollowService, + instanceMetadataService service.InstanceMetadataService, + inboxService service.InboxService, +) error { + // Handlers + + userHandler := user_handler.New(federationService, requestSigner, userService, inboxService, bodyValidator, requestValidator, zerologr.New(&log.Logger).WithName("user-handler")) + noteHandler := note_handler.New(noteService, bodyValidator, requestSigner, zerologr.New(&log.Logger).WithName("notes-handler")) + followHandler := follow_handler.New(followService, federationService, zerologr.New(&log.Logger).WithName("follow-handler")) + metaHandler := meta_handler.New(instanceMetadataService, zerologr.New(&log.Logger).WithName("meta-handler")) + + // Initialization + + web := fiber.New(fiber.Config{ + ProxyHeader: "X-Forwarded-For", + ErrorHandler: fiberErrorHandler, + DisableStartupMessage: true, + AppName: "versia-go", + EnablePrintRoutes: true, + }) + + web.Use(cors.New(cors.Config{ + AllowOriginsFunc: func(origin string) bool { + return true + }, + AllowMethods: "GET,POST,PUT,DELETE,PATCH", + AllowHeaders: "Origin, Content-Type, Accept, Authorization, b3, traceparent, sentry-trace, baggage", + AllowCredentials: true, + ExposeHeaders: "", + MaxAge: 0, + })) + + web.Use(unitelhttp.FiberMiddleware(telemetry, unitelhttp.FiberMiddlewareConfig{ + Repanic: false, + WaitForDelivery: false, + Timeout: 5 * time.Second, + // host for incoming requests + TraceRequestHeaders: []string{"origin", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"}, + // origin for outgoing requests + TraceResponseHeaders: []string{"host", "x-nonce", "x-signature", "x-signed-by", "sentry-trace", "sentry-baggage"}, + IgnoredRoutes: []string{"/api/health"}, + Logger: zerologr.New(&log.Logger).WithName("http-server"), + TracePropagator: shouldPropagate, + })) + web.Use(unitelhttp.RequestLogger(zerologr.New(&log.Logger).WithName("http-server"), true, true)) + + log.Debug().Msg("Registering handlers") + + web.Get("/api/health", healthCheck(database, natsConn)) + + userHandler.Register(web.Group("/")) + noteHandler.Register(web.Group("/")) + followHandler.Register(web.Group("/")) + metaHandler.Register(web.Group("/")) + + wg := sync.WaitGroup{} + wg.Add(2) + + log.Debug().Msg("Starting server") + + addr := fmt.Sprintf(":%d", config.C.Port) + + go func() { + <-ctx.Done() + + if err := web.Shutdown(); err != nil { + log.Error().Err(err).Msg("Failed to shutdown server") + } + }() + + var err error + if config.C.TLSKey != nil { + err = web.ListenTLS(addr, *config.C.TLSCert, *config.C.TLSKey) + } else { + err = web.Listen(addr) + } + + return err +} + +func healthCheck(db *ent.Client, nc *nats.Conn) fiber.Handler { + return func(c *fiber.Ctx) error { + dbWorking := true + if err := db.Ping(); err != nil { + log.Error().Err(err).Msg("Database healthcheck failed") + dbWorking = false + } + + natsWorking := true + if status := nc.Status(); status != nats.CONNECTED { + log.Error().Str("status", status.String()).Msg("NATS healthcheck failed") + natsWorking = false + } + + if dbWorking && natsWorking { + return c.SendString("lookin' good") + } + + return api_schema.ErrInternalServerError(map[string]any{ + "database": dbWorking, + "nats": natsWorking, + }) + } +}