diff --git a/auth/auth.go b/auth/auth.go index 425e678..67f5f2c 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,12 +1,12 @@ package auth import ( - "crypto/rand" - "encoding/hex" "errors" + "fmt" "github.com/Fesaa/Media-Provider/config" "github.com/Fesaa/Media-Provider/payload" "github.com/gofiber/fiber/v2" + "github.com/golang-jwt/jwt/v5" "time" ) @@ -20,7 +20,7 @@ var ( authProvider Provider ) -func Init(cfg *config.Config) { +func Init() { authProvider = newAuth() } @@ -29,14 +29,12 @@ func I() Provider { } type authImpl struct { - tokens map[string]time.Time - pass func() string + pass func() string } func newAuth() Provider { return &authImpl{ - tokens: make(map[string]time.Time), - pass: func() string { return config.OrDefault(config.I().Password, "admin") }, + pass: func() string { return config.OrDefault(config.I().Password, "admin") }, } } @@ -57,12 +55,19 @@ func (v *authImpl) IsAuthenticated(ctx *fiber.Ctx) (bool, error) { if err != nil { return false, err } - t, ok := v.tokens[key] - if !ok { - return false, nil + + token, err := jwt.ParseWithClaims(key, &MpClaims{}, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + + return []byte(config.I().Secret), nil + }) + if err != nil { + return false, err } - return time.Since(t) < time.Hour*24*7, nil + return token.Valid, nil } func (v *authImpl) Login(ctx *fiber.Ctx) (*payload.LoginResponse, error) { @@ -81,9 +86,25 @@ func (v *authImpl) Login(ctx *fiber.Ctx) (*payload.LoginResponse, error) { return nil, badRequest("Invalid password") } - token := generateSecureToken(32) - v.tokens[token] = time.Now().Add(time.Hour * 24 * 7) - return &payload.LoginResponse{Token: token}, nil + claims := MpClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(func() time.Time { + if body.Remember { + return time.Now().Add(7 * 24 * time.Hour) + } + return time.Now().Add(24 * time.Hour) + }()), + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + t, err := token.SignedString([]byte(config.I().Secret)) + if err != nil { + return nil, err + } + + return &payload.LoginResponse{Token: t}, nil } func badRequest(msg string) error { @@ -92,11 +113,3 @@ func badRequest(msg string) error { Message: msg, } } - -func generateSecureToken(length int) string { - b := make([]byte, length) - if _, err := rand.Read(b); err != nil { - return "" - } - return hex.EncodeToString(b) -} diff --git a/auth/types.go b/auth/types.go index e5fe938..dc13d29 100644 --- a/auth/types.go +++ b/auth/types.go @@ -3,8 +3,13 @@ package auth import ( "github.com/Fesaa/Media-Provider/payload" "github.com/gofiber/fiber/v2" + "github.com/golang-jwt/jwt/v5" ) +type MpClaims struct { + jwt.RegisteredClaims +} + type Provider interface { // IsAuthenticated checks the current request for authentication. This should be handled by the middleware IsAuthenticated(ctx *fiber.Ctx) (bool, error) diff --git a/config/config.go b/config/config.go index 535c58a..c63b34b 100644 --- a/config/config.go +++ b/config/config.go @@ -9,6 +9,7 @@ type Config struct { Password string `json:"password" validate:"required"` RootDir string `json:"root_dir"` BaseUrl string `json:"base_url"` + Secret string `json:"secret"` Logging Logging `json:"logging"` Downloader Downloader `json:"downloader"` diff --git a/config/default.go b/config/default.go index bd4d4d7..17775de 100644 --- a/config/default.go +++ b/config/default.go @@ -1,17 +1,34 @@ package config import ( + "crypto/rand" + "encoding/base64" "log/slog" "os" "path" ) +func GenerateSecret(length int) (string, error) { + secret := make([]byte, length) + _, err := rand.Read(secret) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(secret), nil +} + func defaultConfig() *Config { + secret, err := GenerateSecret(64) + if err != nil { + panic(err) + } + return &Config{ SyncId: 0, Password: "admin", RootDir: path.Join(OrDefault(os.Getenv("CONFIG_DIR"), "."), "temp"), BaseUrl: "", + Secret: secret, Logging: Logging{ Level: slog.LevelInfo, Source: true, diff --git a/go.mod b/go.mod index 84aed2b..6eae98b 100644 --- a/go.mod +++ b/go.mod @@ -56,6 +56,7 @@ require ( github.com/gocolly/colly v1.2.0 // indirect github.com/gofiber/template v1.8.3 // indirect github.com/gofiber/utils v1.1.0 // indirect + github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/btree v1.1.2 // indirect diff --git a/go.sum b/go.sum index a247c69..05ac1a1 100644 --- a/go.sum +++ b/go.sum @@ -189,6 +189,8 @@ github.com/gofiber/utils v1.1.0 h1:vdEBpn7AzIUJRhe+CiTOJdUcTg4Q9RK+pEa0KPbLdrM= github.com/gofiber/utils v1.1.0/go.mod h1:poZpsnhBykfnY1Mc0KeEa6mSHrS3dV0+oBWyeQmb2e0= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/main.go b/main.go index 9022f9e..944329f 100644 --- a/main.go +++ b/main.go @@ -25,7 +25,7 @@ func init() { validateConfig(cfg) UpdateBaseUrlInIndex(cfg.BaseUrl) - auth.Init(cfg) + auth.Init() yoitsu.Init(cfg) mangadex.Init(cfg) } diff --git a/pre_startup.go b/pre_startup.go index 494e979..7179a91 100644 --- a/pre_startup.go +++ b/pre_startup.go @@ -68,8 +68,17 @@ func validateRootConfig(c *config.Config) error { c.BaseUrl += "/" } + if c.Secret == "" { + secret, err := config.GenerateSecret(64) + if err != nil { + return err + } + c.Secret = secret + changed = true + } + if changed { - log.Warn("BaseUrl was forcefully changed, saving config", "baseUrl", c.BaseUrl) + log.Warn("Config was changed by validateRootConfig, saving...") return c.Save() }