From 83dea5f92061977ae33e0f8fb7d68807469a5901 Mon Sep 17 00:00:00 2001 From: rot1024 Date: Wed, 31 Aug 2022 21:56:26 +0900 Subject: [PATCH] refactor: use reearthx.authserver (#335) --- server/go.mod | 10 +- server/go.sum | 10 +- server/internal/app/app.go | 5 +- server/internal/app/auth_client.go | 2 +- server/internal/app/auth_server.go | 303 +++---------- server/internal/app/auth_server_test.go | 247 +++++++++++ server/internal/app/config.go | 27 +- server/internal/app/config_test.go | 3 +- .../infrastructure/memory/auth_request.go | 74 ---- .../internal/infrastructure/memory/config.go | 10 + .../infrastructure/memory/container.go | 3 +- server/internal/infrastructure/memory/user.go | 2 +- .../infrastructure/mongo/auth_request.go | 58 --- .../infrastructure/mongo/auth_request_test.go | 73 --- .../internal/infrastructure/mongo/config.go | 39 +- .../infrastructure/mongo/container.go | 3 +- .../mongo/mongodoc/auth_request.go | 101 ----- .../infrastructure/mongo/mongodoc/config.go | 18 +- .../infrastructure/mongo/mongodoc/user.go | 6 +- server/internal/usecase/interactor/auth.go | 414 ------------------ server/internal/usecase/interactor/user.go | 17 +- .../usecase/interactor/user_signup.go | 4 +- .../usecase/interactor/user_signup_test.go | 36 +- server/internal/usecase/repo/auth_request.go | 16 - server/internal/usecase/repo/config.go | 1 + server/internal/usecase/repo/container.go | 3 +- server/pkg/auth/builder.go | 102 ----- server/pkg/auth/client.go | 117 ----- server/pkg/auth/request.go | 143 ------ server/pkg/user/auth.go | 80 +++- server/pkg/user/auth_test.go | 68 +-- server/pkg/user/builder.go | 3 + server/pkg/user/builder_test.go | 50 ++- server/pkg/user/user.go | 83 +--- server/pkg/user/user_test.go | 329 +++----------- server/pkg/user/userops/initializer.go | 7 +- 36 files changed, 651 insertions(+), 1816 deletions(-) create mode 100644 server/internal/app/auth_server_test.go delete mode 100644 server/internal/infrastructure/memory/auth_request.go delete mode 100644 server/internal/infrastructure/mongo/auth_request.go delete mode 100644 server/internal/infrastructure/mongo/auth_request_test.go delete mode 100644 server/internal/infrastructure/mongo/mongodoc/auth_request.go delete mode 100644 server/internal/usecase/interactor/auth.go delete mode 100644 server/internal/usecase/repo/auth_request.go delete mode 100644 server/pkg/auth/builder.go delete mode 100644 server/pkg/auth/client.go delete mode 100644 server/pkg/auth/request.go diff --git a/server/go.mod b/server/go.mod index 10a3cb03cf..989d1a4bbc 100644 --- a/server/go.mod +++ b/server/go.mod @@ -8,11 +8,8 @@ require ( github.com/auth0/go-jwt-middleware/v2 v2.0.1 github.com/avast/retry-go/v4 v4.0.4 github.com/blang/semver v3.5.1+incompatible - github.com/caos/oidc v1.2.0 github.com/goccy/go-yaml v1.9.5 - github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f github.com/google/uuid v1.3.0 - github.com/gorilla/mux v1.8.0 github.com/iancoleman/strcase v0.2.0 github.com/idubinskiy/schematyper v0.0.0-20190118213059-f71b40dac30d github.com/jarcoal/httpmock v1.2.0 @@ -26,7 +23,7 @@ require ( github.com/paulmach/go.geojson v1.4.0 github.com/pkg/errors v0.9.1 github.com/ravilushqa/otelgqlgen v0.8.0 - github.com/reearth/reearthx v0.0.0-20220830052647-7562f433f71a + github.com/reearth/reearthx v0.0.0-20220831124713-1b1373700421 github.com/samber/lo v1.27.0 github.com/sendgrid/sendgrid-go v3.11.1+incompatible github.com/sirupsen/logrus v1.8.1 @@ -38,6 +35,7 @@ require ( github.com/uber/jaeger-lib v2.4.1+incompatible github.com/vektah/dataloaden v0.3.0 github.com/vektah/gqlparser/v2 v2.4.6 + github.com/zitadel/oidc v1.5.1 go.mongodb.org/mongo-driver v1.10.1 go.opentelemetry.io/contrib/instrumentation/github.com/labstack/echo/otelecho v0.32.0 go.opentelemetry.io/contrib/instrumentation/go.mongodb.org/mongo-driver/mongo/otelmongo v0.32.0 @@ -63,7 +61,6 @@ require ( github.com/agnivade/levenshtein v1.1.1 // indirect github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20210927113745-59d0afb8317a // indirect - github.com/caos/logging v0.0.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/trifles v0.0.0-20200705224438-cafc02a1ee2b // indirect @@ -73,6 +70,7 @@ require ( github.com/go-logr/logr v1.2.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + github.com/golang/gddo v0.0.0-20210115222349-20d68f94ee1f // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.3 // indirect @@ -82,6 +80,7 @@ require ( github.com/googleapis/go-type-adapters v1.0.0 // indirect github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect github.com/gorilla/handlers v1.5.1 // indirect + github.com/gorilla/mux v1.8.0 // indirect github.com/gorilla/schema v1.2.0 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/websocket v1.5.0 // indirect @@ -109,7 +108,6 @@ require ( github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect github.com/zitadel/logging v0.3.3 // indirect - github.com/zitadel/oidc v1.5.1 // indirect go.opencensus.io v0.23.0 // indirect go.opentelemetry.io/contrib v1.7.0 // indirect go.opentelemetry.io/otel/trace v1.7.0 // indirect diff --git a/server/go.sum b/server/go.sum index ea8868a5a3..fc7b95124c 100644 --- a/server/go.sum +++ b/server/go.sum @@ -101,10 +101,6 @@ github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZx github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/bradfitz/gomemcache v0.0.0-20170208213004-1952afaa557d/go.mod h1:PmM6Mmwb0LSuEubjR8N7PtNe1KxZLtOUHtbeikc5h60= -github.com/caos/logging v0.0.2 h1:ebg5C/HN0ludYR+WkvnFjwSExF4wvyiWPyWGcKMYsoo= -github.com/caos/logging v0.0.2/go.mod h1:9LKiDE2ChuGv6CHYif/kiugrfEXu9AwDiFWSreX7Wp0= -github.com/caos/oidc v1.2.0 h1:dTy5bcT2WQbwPgytEZiG8SV1bCgHUXyDdaPDCNtRdEU= -github.com/caos/oidc v1.2.0/go.mod h1:4l0PPwdc6BbrdCFhNrRTUddsG292uHGa7gE2DSEIqoU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -420,8 +416,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/ravilushqa/otelgqlgen v0.8.0 h1:x48k+D1GMgm87xhMO2Lekrr9YGzFbpG3yijn9GpxuAY= github.com/ravilushqa/otelgqlgen v0.8.0/go.mod h1:6JO5YO2iY4POC7R6yB/L/RKXCcyISL8qQt5NnHOhh0o= -github.com/reearth/reearthx v0.0.0-20220830052647-7562f433f71a h1:SKuK6FmWfiYCUFkS7h5weNAzrSwiGs/nxhz7bmtb57E= -github.com/reearth/reearthx v0.0.0-20220830052647-7562f433f71a/go.mod h1:YZMXO1RhQ5fFL0GIOFvJq2GNskW7w+xoW4Zfu2QUXhw= +github.com/reearth/reearthx v0.0.0-20220831124713-1b1373700421 h1:fQ/f3Vmcv3BMIArGKLDV+AcXaJZOLJu8DZDsVzczrmg= +github.com/reearth/reearthx v0.0.0-20220831124713-1b1373700421/go.mod h1:YZMXO1RhQ5fFL0GIOFvJq2GNskW7w+xoW4Zfu2QUXhw= github.com/robertkrimen/godocdown v0.0.0-20130622164427-0bfa04905481/go.mod h1:C9WhFzY47SzYBIvzFqSvHIR6ROgDo4TtdTuRaOMjF/s= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -716,7 +712,6 @@ golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191206220618-eeba5f6aabab/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1053,7 +1048,6 @@ gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76 gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/server/internal/app/app.go b/server/internal/app/app.go index 1f1ee690e4..a5926065d4 100644 --- a/server/internal/app/app.go +++ b/server/internal/app/app.go @@ -86,10 +86,7 @@ func initEcho(ctx context.Context, cfg *ServerConfig) *echo.Echo { })) // auth srv - if !cfg.Config.AuthSrv.Disabled { - auth := e.Group("") - authEndPoints(ctx, e, auth, cfg) - } + authServer(ctx, e, &cfg.Config.AuthSrv, cfg.Repos) // apis api := e.Group("/api") diff --git a/server/internal/app/auth_client.go b/server/internal/app/auth_client.go index a784b2273b..ab3898f5a8 100644 --- a/server/internal/app/auth_client.go +++ b/server/internal/app/auth_client.go @@ -62,7 +62,7 @@ func authMiddleware(cfg *ServerConfig) echo.MiddlewareFunc { // save a new sub if u != nil && au != nil { - if err := addAuth0SubToUser(ctx, u, user.AuthFromAuth0Sub(au.Sub), cfg); err != nil { + if err := addAuth0SubToUser(ctx, u, user.AuthFrom(au.Sub), cfg); err != nil { return err } } diff --git a/server/internal/app/auth_server.go b/server/internal/app/auth_server.go index 8ca41cc76f..8bcfc3eb0b 100644 --- a/server/internal/app/auth_server.go +++ b/server/internal/app/auth_server.go @@ -2,274 +2,111 @@ package app import ( "context" - "crypto/sha256" - "encoding/json" "errors" - "net/http" - "net/url" - "os" - "strconv" - "github.com/caos/oidc/pkg/op" - "github.com/golang/gddo/httputil/header" - "github.com/gorilla/mux" "github.com/labstack/echo/v4" - "github.com/reearth/reearth/server/internal/usecase/interactor" - "github.com/reearth/reearth/server/internal/usecase/interfaces" + "github.com/reearth/reearth/server/internal/usecase/repo" + "github.com/reearth/reearth/server/pkg/config" "github.com/reearth/reearth/server/pkg/user" - "github.com/reearth/reearthx/log" + "github.com/reearth/reearthx/authserver" + "github.com/reearth/reearthx/rerror" + "github.com/zitadel/oidc/pkg/oidc" ) -const ( - loginEndpoint = "api/login" - logoutEndpoint = "api/logout" - jwksEndpoint = ".well-known/jwks.json" - authProvider = "reearth" -) +const authServerDefaultClientID = "reearth-authsrv-client-default" -func authEndPoints(ctx context.Context, e *echo.Echo, r *echo.Group, cfg *ServerConfig) { - userUsecase := interactor.NewUser(cfg.Repos, cfg.Gateways, cfg.Config.SignupSecret, cfg.Config.Host_Web) +var ErrInvalidEmailORPassword = errors.New("wrong email or password") - domain := cfg.Config.AuthServeDomainURL() - if domain == nil || domain.String() == "" { - log.Panicf("auth: not valid auth domain: %s", domain) +func authServer(ctx context.Context, e *echo.Echo, cfg *AuthSrvConfig, repos *repo.Container) { + if cfg.Disabled { + return } - domain.Path = "/" - - uidomain := cfg.Config.AuthServeUIDomainURL() - config := &op.Config{ - Issuer: domain.String(), - CryptoKey: sha256.Sum256([]byte(cfg.Config.AuthSrv.Key)), - GrantTypeRefreshToken: true, - } + authserver.Endpoint(ctx, authserver.EndpointConfig{ + Issuer: cfg.Issuer, + URL: cfg.DomainURL(), + WebURL: cfg.UIDomainURL(), + DefaultClientID: authServerDefaultClientID, + Dev: cfg.Dev, + Key: cfg.Key, + DN: cfg.DN.AuthServerDNConfig(), + UserRepo: &authServerUser{User: repos.User}, + ConfigRepo: &authServerConfig{Config: repos.Config}, + RequestRepo: repos.AuthRequest, + }, e.Group("")) +} - dn := (*interactor.AuthDNConfig)(cfg.Config.AuthSrv.DN.AuthServerDNConfig()) +type authServerUser struct { + User repo.User +} - storage, err := interactor.NewAuthStorage( - ctx, - &interactor.StorageConfig{ - Domain: domain.String(), - ClientDomain: cfg.Config.Host_Web, - Debug: cfg.Debug, - DN: dn, - }, - cfg.Repos.AuthRequest, - cfg.Repos.Config, - userUsecase.GetUserBySubject, - ) +func (r *authServerUser) Sub(ctx context.Context, email, password, authRequestID string) (string, error) { + u, err := r.User.FindByNameOrEmail(ctx, email) if err != nil { - log.Fatalf("auth: init failed: %s\n", err) + if errors.Is(rerror.ErrNotFound, err) { + return "", ErrInvalidEmailORPassword + } + return "", err } - handler, err := op.NewOpenIDProvider( - ctx, - config, - storage, - op.WithHttpInterceptors(jsonToFormHandler()), - op.WithHttpInterceptors(setURLVarsHandler()), - op.WithCustomEndSessionEndpoint(op.NewEndpoint(logoutEndpoint)), - op.WithCustomKeysEndpoint(op.NewEndpoint(jwksEndpoint)), - ) + ok, err := u.MatchPassword(password) if err != nil { - log.Fatalf("auth: init failed: %s\n", err) + return "", err } - router := handler.HttpHandler().(*mux.Router) - - if err := router.Walk(muxToEchoMapper(r)); err != nil { - log.Fatalf("auth: walk failed: %s\n", err) + if !ok { + return "", ErrInvalidEmailORPassword } - // Actual login endpoint - r.POST(loginEndpoint, login(ctx, domain, uidomain, storage, userUsecase)) - - r.GET(logoutEndpoint, logout()) - - // used for auth0/auth0-react; the logout endpoint URL is hard-coded - // can be removed when the mentioned issue is solved - // https://github.com/auth0/auth0-spa-js/issues/845 - r.GET("v2/logout", logout()) - - debugMsg := "" - if dev, ok := os.LookupEnv(op.OidcDevMode); ok { - if isDev, _ := strconv.ParseBool(dev); isDev { - debugMsg = " with debug mode" - } + a := u.Auths().GetByProvider(user.ProviderReearth) + if a == nil || a.Sub == "" { + return "", ErrInvalidEmailORPassword } - log.Infof("auth: oidc server started%s at %s", debugMsg, domain.String()) -} -func setURLVarsHandler() func(handler http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/authorize/callback" { - handler.ServeHTTP(w, r) - return - } - - r2 := mux.SetURLVars(r, map[string]string{"id": r.URL.Query().Get("id")}) - handler.ServeHTTP(w, r2) - }) - } + return a.Sub, nil } -func jsonToFormHandler() func(handler http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/oauth/token" { - handler.ServeHTTP(w, r) - return - } - - if r.Header.Get("Content-Type") != "" { - value, _ := header.ParseValueAndParams(r.Header, "Content-Type") - if value != "application/json" { - // Content-Type header is not application/json - handler.ServeHTTP(w, r) - return - } - } - - if err := r.ParseForm(); err != nil { - return - } - - var result map[string]string - - if err := json.NewDecoder(r.Body).Decode(&result); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - for key, value := range result { - r.Form.Set(key, value) - } - - handler.ServeHTTP(w, r) - }) +func (r *authServerUser) Info(ctx context.Context, sub string, scopes []string, ui oidc.UserInfoSetter) error { + u, err := r.User.FindByAuth0Sub(ctx, sub) + if err != nil { + return err } -} - -func muxToEchoMapper(r *echo.Group) func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { - return func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { - path, err := route.GetPathTemplate() - if err != nil { - return err - } - methods, err := route.GetMethods() - if err != nil { - r.Any(path, echo.WrapHandler(route.GetHandler())) - return nil - } - - for _, method := range methods { - r.Add(method, path, echo.WrapHandler(route.GetHandler())) - } - - return nil - } + ui.SetEmail(u.Email(), u.Verification().IsVerified()) + ui.SetLocale(u.Lang()) + ui.SetName(u.Name()) + return nil } -type loginForm struct { - Email string `json:"username" form:"username"` - Password string `json:"password" form:"password"` - AuthRequestID string `json:"id" form:"id"` +type authServerConfig struct { + Config repo.Config } -func login(ctx context.Context, url, uiurl *url.URL, storage op.Storage, userUsecase interfaces.User) func(ctx echo.Context) error { - return func(ec echo.Context) error { - request := new(loginForm) - err := ec.Bind(request) - if err != nil { - log.Errorln("auth: filed to parse login request") - return ec.Redirect( - http.StatusFound, - redirectURL(uiurl, "/login", "", "Bad request!"), - ) - } - - if _, err := storage.AuthRequestByID(ctx, request.AuthRequestID); err != nil { - log.Errorf("auth: filed to parse login request: %s\n", err) - return ec.Redirect( - http.StatusFound, - redirectURL(uiurl, "/login", "", "Bad request!"), - ) - } - - if len(request.Email) == 0 || len(request.Password) == 0 { - log.Errorln("auth: one of credentials are not provided") - return ec.Redirect( - http.StatusFound, - redirectURL(uiurl, "/login", request.AuthRequestID, "Bad request!"), - ) - } - - // check user credentials from db - u, err := userUsecase.GetUserByCredentials(ctx, interfaces.GetUserByCredentials{ - Email: request.Email, - Password: request.Password, - }) - var auth *user.Auth - if err == nil { - auth = u.GetAuthByProvider(authProvider) - if auth == nil { - err = errors.New("The account is not signed up with Re:Earth") - } - } - if err != nil { - log.Errorf("auth: wrong credentials: %s\n", err) - return ec.Redirect( - http.StatusFound, - redirectURL(uiurl, "/login", request.AuthRequestID, "Login failed; Invalid user ID or password."), - ) - } - - // Complete the auth request && set the subject - err = storage.(*interactor.AuthStorage).CompleteAuthRequest(ctx, request.AuthRequestID, auth.Sub) - if err != nil { - log.Errorf("auth: failed to complete the auth request: %s\n", err) - return ec.Redirect( - http.StatusFound, - redirectURL(uiurl, "/login", request.AuthRequestID, "Bad request!"), - ) - } - - return ec.Redirect( - http.StatusFound, - redirectURL(url, "/authorize/callback", request.AuthRequestID, ""), - ) +func (c *authServerConfig) Load(ctx context.Context) (*authserver.Config, error) { + cfg, err := c.Config.LockAndLoad(ctx) + if err != nil { + return nil, err } -} - -func logout() func(ec echo.Context) error { - return func(ec echo.Context) error { - u := ec.QueryParam("returnTo") - return ec.Redirect(http.StatusTemporaryRedirect, u) + if cfg.Auth == nil { + return nil, nil } + + return &authserver.Config{ + Cert: cfg.Auth.Cert, + Key: cfg.Auth.Key, + }, nil } -func redirectURL(u *url.URL, p string, requestID, err string) string { - v := cloneURL(u) - if p != "" { - v.Path = p - } - queryValues := u.Query() - queryValues.Set("id", requestID) - if err != "" { - queryValues.Set("error", err) +func (c *authServerConfig) Save(ctx context.Context, cfg *authserver.Config) error { + if cfg == nil { + return nil } - v.RawQuery = queryValues.Encode() - return v.String() + return c.Config.SaveAuth(ctx, &config.Auth{ + Cert: cfg.Cert, + Key: cfg.Key, + }) } -func cloneURL(u *url.URL) *url.URL { - return &url.URL{ - Scheme: u.Scheme, - Opaque: u.Opaque, - User: u.User, - Host: u.Host, - Path: u.Path, - } +func (c *authServerConfig) Unlock(ctx context.Context) error { + return c.Config.Unlock(ctx) } diff --git a/server/internal/app/auth_server_test.go b/server/internal/app/auth_server_test.go new file mode 100644 index 0000000000..dc51218175 --- /dev/null +++ b/server/internal/app/auth_server_test.go @@ -0,0 +1,247 @@ +package app + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/reearth/reearth/server/internal/infrastructure/mongo" + "github.com/reearth/reearth/server/internal/usecase/repo" + "github.com/reearth/reearth/server/pkg/user" + "github.com/reearth/reearthx/authserver" + "github.com/reearth/reearthx/mongox" + "github.com/reearth/reearthx/mongox/mongotest" + "github.com/reearth/reearthx/util" + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +func init() { + mongotest.Env = "REEARTH_DB" +} + +func TestEndpoint(t *testing.T) { + ctx := context.Background() + db := mongotest.Connect(t)(t) + e := echo.New() + lr := lo.Must(mongo.NewLock(db.Collection("locks"))) + cr := mongo.NewConfig(db.Collection("config"), lr) + ur := mongo.NewUser(mongox.NewClientWithDatabase(db)) + rr := authserver.NewMongo(mongox.NewClientCollection(db.Collection("authRequest"))) + + uid := user.NewID() + usr := user.New().ID(uid). + Name("aaa"). + Workspace(user.NewWorkspaceID()). + Email("aaa@example.com"). + Auths(user.Auths{user.NewReearthAuth("subsub")}). + PasswordPlainText("Xyzxyz123"). + Verification(user.VerificationFrom("", time.Time{}, true)). + MustBuild() + lo.Must0(ur.Save(ctx, usr)) + + ts := httptest.NewServer(e) + defer ts.Close() + + authServer(ctx, e, &AuthSrvConfig{ + Domain: "https://example.com", + UIDomain: "https://web.example.com", + }, &repo.Container{ + AuthRequest: rr, + Config: cr, + User: ur, + }) + + // step 1 + verifier, challenge := randomCodeChallenge() + res := send(http.MethodGet, ts.URL+"/authorize", false, map[string]string{ + "response_type": "code", + "client_id": authServerDefaultClientID, + "redirect_uri": "https://web.example.com", + "scope": "openid email profile", + "state": "hogestate", + "code_challenge": challenge, + "code_challenge_method": "S256", + }, nil) + assert.Equal(t, http.StatusFound, res.StatusCode) + loc := res.Header.Get("Location") + assert.Contains(t, loc, "https://web.example.com/login?id=") + reqID := lo.Must(url.Parse(loc)).Query().Get("id") + + // step 2 + res = send(http.MethodPost, ts.URL+"/api/login", true, map[string]string{ + "username": "aaa@example.com", + "password": "xyzxyz123", + "id": reqID, + }, nil) + assert.Equal(t, http.StatusFound, res.StatusCode) + assert.Equal(t, "https://web.example.com/login?error=Login+failed%3B+Invalid+s+ID+or+password.&id="+reqID, res.Header.Get("Location")) + + res = send(http.MethodPost, ts.URL+"/api/login", true, map[string]string{ + "username": "aaa@example.com", + "password": "Xyzxyz123", + "id": reqID, + }, nil) + assert.Equal(t, http.StatusFound, res.StatusCode) + assert.Equal(t, "https://example.com/authorize/callback?id="+reqID, res.Header.Get("Location")) + + // step 3 + res = send(http.MethodGet, ts.URL+"/authorize/callback?id="+reqID, false, nil, nil) + assert.Equal(t, http.StatusFound, res.StatusCode) + loc = res.Header.Get("Location") + assert.Contains(t, loc, "https://web.example.com?code=") + locu := lo.Must(url.Parse(loc)) + assert.Equal(t, "hogestate", locu.Query().Get("state")) + code := locu.Query().Get("code") + + // step 4 + res2 := send(http.MethodPost, ts.URL+"/oauth/token", true, map[string]string{ + "grant_type": "authorization_code", + "redirect_uri": "https://web.example.com", + "client_id": authServerDefaultClientID, + "code": code, + "code_verifier": verifier, + }, nil) + var r map[string]any + util.Must(json.Unmarshal(lo.Must(io.ReadAll(res2.Body)), &r)) + assert.Equal(t, map[string]any{ + "id_token": r["id_token"], + "access_token": r["access_token"], + "expires_in": r["expires_in"], + "token_type": "Bearer", + "state": "hogestate", + }, r) + accessToken := r["access_token"].(string) + idToken := r["id_token"].(string) + + // userinfo + res3 := send(http.MethodGet, ts.URL+"/userinfo", false, nil, map[string]string{ + "Authorization": "Bearer " + accessToken, + }) + var r2 map[string]any + util.Must(json.Unmarshal(lo.Must(io.ReadAll(res3.Body)), &r2)) + assert.Equal(t, map[string]any{ + "sub": "reearth|subsub", + "email": "aaa@example.com", + "name": "aaa", + "email_verified": true, + }, r2) + + // openid-configuration + res4 := send(http.MethodGet, ts.URL+"/.well-known/openid-configuration", false, nil, nil) + var r3 map[string]any + util.Must(json.Unmarshal(lo.Must(io.ReadAll(res4.Body)), &r3)) + assert.Equal(t, "https://example.com/jwks.json", r3["jwks_uri"]) + + // jwks + res5 := send(http.MethodGet, ts.URL+"/jwks.json", false, nil, nil) + var jwks jose.JSONWebKeySet + util.Must(json.Unmarshal(lo.Must(io.ReadAll(res5.Body)), &jwks)) + + // validate access_token + token := lo.Must(jwt.ParseSigned(accessToken)) + header, _ := lo.Find(token.Headers, func(h jose.Header) bool { + return h.Algorithm == string(jose.RS256) + }) + key := jwks.Key(header.KeyID)[0] + claims := map[string]any{} + util.Must(token.Claims(key.Key, &claims)) + assert.Equal(t, map[string]any{ + "iss": "https://example.com/", + "sub": "reearth|subsub", + "aud": []any{"https://example.com"}, + "jti": claims["jti"], + "exp": claims["exp"], + "nbf": claims["nbf"], + "iat": claims["iat"], + }, claims) + + // validate id_token + token2 := lo.Must(jwt.ParseSigned(idToken)) + header2, _ := lo.Find(token2.Headers, func(h jose.Header) bool { + return h.Algorithm == string(jose.RS256) + }) + key2 := jwks.Key(header2.KeyID)[0] + claims2 := map[string]any{} + util.Must(token.Claims(key2.Key, &claims2)) + assert.Equal(t, map[string]any{ + "iss": "https://example.com/", + "sub": "reearth|subsub", + "aud": []any{"https://example.com"}, + "jti": claims["jti"], + "exp": claims["exp"], + "nbf": claims["nbf"], + "iat": claims["iat"], + }, claims2) +} + +var httpClient = &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, +} + +func send(method, u string, form bool, body any, headers map[string]string) *http.Response { + var b io.Reader + if body != nil { + if method == http.MethodPost || method == http.MethodPatch || method == http.MethodPut { + if form { + values := url.Values{} + for k, v := range body.(map[string]string) { + values.Set(k, v) + } + b = strings.NewReader(values.Encode()) + } else { + j := lo.Must(json.Marshal(body)) + b = bytes.NewReader(j) + } + } else if b, ok := body.(map[string]string); ok { + u2 := lo.Must(url.Parse(u)) + q := u2.Query() + for k, v := range b { + q.Set(k, v) + } + u2.RawQuery = q.Encode() + u = u2.String() + } + } + + req := lo.Must(http.NewRequest(method, u, b)) + if b != nil { + if !form { + req.Header.Set("Content-Type", "application/json") + } else { + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + } + } + for k, v := range headers { + req.Header.Set(k, v) + } + return lo.Must(httpClient.Do(req)) +} + +func codeChallenge(seed []byte) (string, string) { + verifier := base64.RawURLEncoding.EncodeToString(seed) + challengeSum := sha256.Sum256([]byte(verifier)) + challenge := strings.ReplaceAll(strings.ReplaceAll(strings.ReplaceAll(base64.StdEncoding.EncodeToString(challengeSum[:]), "+", "-"), "/", "_"), "=", "") + return verifier, challenge +} + +func randomCodeChallenge() (string, string) { + seed := make([]byte, 32) + _, _ = rand.Read(seed) + return codeChallenge(seed) +} diff --git a/server/internal/app/config.go b/server/internal/app/config.go index 2b559cdf34..0e702f2513 100644 --- a/server/internal/app/config.go +++ b/server/internal/app/config.go @@ -7,13 +7,12 @@ import ( "os" "strings" - "github.com/caos/oidc/pkg/op" "github.com/joho/godotenv" "github.com/kelseyhightower/envconfig" - "github.com/reearth/reearth/server/pkg/auth" "github.com/reearth/reearth/server/pkg/workspace" "github.com/reearth/reearthx/authserver" "github.com/reearth/reearthx/log" + "github.com/samber/lo" "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" ) @@ -66,6 +65,7 @@ type Auth0Config struct { type AuthSrvConfig struct { Dev bool Disabled bool + Issuer string Domain string UIDomain string Key string @@ -89,12 +89,10 @@ func (c AuthSrvConfig) AuthConfig(debug bool, host string) *AuthConfig { aud = []string{domain} } - clientID := auth.ClientID - return &AuthConfig{ ISS: domain, AUD: aud, - ClientID: &clientID, + ClientID: lo.ToPtr(authServerDefaultClientID), } } @@ -163,31 +161,28 @@ func ReadConfig(debug bool) (*Config, error) { var c Config err := envconfig.Process(configPrefix, &c) - // overwrite env vars - if !c.AuthSrv.Disabled && (c.Dev || c.AuthSrv.Dev || c.AuthSrv.Domain == "") { - if _, ok := os.LookupEnv(op.OidcDevMode); !ok { - _ = os.Setenv(op.OidcDevMode, "1") - } - } - // default values if debug { c.Dev = true } + c.Host = addHTTPScheme(c.Host) if c.Host_Web == "" { c.Host_Web = c.Host } else { c.Host_Web = addHTTPScheme(c.Host_Web) } + if c.AuthSrv.Domain == "" { c.AuthSrv.Domain = c.Host } else { c.AuthSrv.Domain = addHTTPScheme(c.AuthSrv.Domain) } + if c.Host_Web == "" { c.Host_Web = c.Host } + if c.AuthSrv.UIDomain == "" { c.AuthSrv.UIDomain = c.Host_Web } else { @@ -298,16 +293,16 @@ func (c Config) HostWebURL() *url.URL { return u } -func (c Config) AuthServeDomainURL() *url.URL { - u, err := url.Parse(c.AuthSrv.Domain) +func (c AuthSrvConfig) DomainURL() *url.URL { + u, err := url.Parse(c.Domain) if err != nil { u = nil } return u } -func (c Config) AuthServeUIDomainURL() *url.URL { - u, err := url.Parse(c.AuthSrv.UIDomain) +func (c AuthSrvConfig) UIDomainURL() *url.URL { + u, err := url.Parse(c.UIDomain) if err != nil { u = nil } diff --git a/server/internal/app/config_test.go b/server/internal/app/config_test.go index f4c567b5c4..583a636948 100644 --- a/server/internal/app/config_test.go +++ b/server/internal/app/config_test.go @@ -3,7 +3,6 @@ package app import ( "testing" - "github.com/reearth/reearth/server/pkg/auth" "github.com/stretchr/testify/assert" ) @@ -24,7 +23,7 @@ func TestAuth0Config_AuthConfig(t *testing.T) { } func TestReadConfig(t *testing.T) { - clientID := auth.ClientID + clientID := authServerDefaultClientID localAuth := AuthConfig{ ISS: "http://localhost:8080", AUD: []string{"http://localhost:8080"}, diff --git a/server/internal/infrastructure/memory/auth_request.go b/server/internal/infrastructure/memory/auth_request.go deleted file mode 100644 index 473b9dd8b0..0000000000 --- a/server/internal/infrastructure/memory/auth_request.go +++ /dev/null @@ -1,74 +0,0 @@ -package memory - -import ( - "context" - "sync" - - "github.com/reearth/reearth/server/pkg/auth" - "github.com/reearth/reearth/server/pkg/id" - "github.com/reearth/reearthx/rerror" -) - -type AuthRequest struct { - lock sync.Mutex - data map[id.AuthRequestID]*auth.Request -} - -func NewAuthRequest() *AuthRequest { - return &AuthRequest{ - data: map[id.AuthRequestID]*auth.Request{}, - } -} - -func (r *AuthRequest) FindByID(_ context.Context, id id.AuthRequestID) (*auth.Request, error) { - r.lock.Lock() - defer r.lock.Unlock() - - d, ok := r.data[id] - if ok { - return d, nil - } - return nil, rerror.ErrNotFound -} - -func (r *AuthRequest) FindByCode(_ context.Context, s string) (*auth.Request, error) { - r.lock.Lock() - defer r.lock.Unlock() - - for _, ar := range r.data { - if ar.GetCode() == s { - return ar, nil - } - } - - return nil, rerror.ErrNotFound -} - -func (r *AuthRequest) FindBySubject(_ context.Context, s string) (*auth.Request, error) { - r.lock.Lock() - defer r.lock.Unlock() - - for _, ar := range r.data { - if ar.GetSubject() == s { - return ar, nil - } - } - - return nil, rerror.ErrNotFound -} - -func (r *AuthRequest) Save(_ context.Context, request *auth.Request) error { - r.lock.Lock() - defer r.lock.Unlock() - - r.data[request.ID()] = request - return nil -} - -func (r *AuthRequest) Remove(_ context.Context, requestID id.AuthRequestID) error { - r.lock.Lock() - defer r.lock.Unlock() - - delete(r.data, requestID) - return nil -} diff --git a/server/internal/infrastructure/memory/config.go b/server/internal/infrastructure/memory/config.go index 23d58ccf34..d4edb8ae35 100644 --- a/server/internal/infrastructure/memory/config.go +++ b/server/internal/infrastructure/memory/config.go @@ -31,6 +31,16 @@ func (r *Config) Save(ctx context.Context, c *config.Config) error { return nil } +func (r *Config) SaveAuth(ctx context.Context, c *config.Auth) error { + if c != nil { + if r.data == nil { + r.data = &config.Config{} + } + r.data.Auth = c + } + return nil +} + func (r *Config) SaveAndUnlock(ctx context.Context, c *config.Config) error { _ = r.Save(ctx, c) return r.Unlock(ctx) diff --git a/server/internal/infrastructure/memory/container.go b/server/internal/infrastructure/memory/container.go index 3354076f5b..b645a1eb77 100644 --- a/server/internal/infrastructure/memory/container.go +++ b/server/internal/infrastructure/memory/container.go @@ -2,6 +2,7 @@ package memory import ( "github.com/reearth/reearth/server/internal/usecase/repo" + "github.com/reearth/reearthx/authserver" "github.com/reearth/reearthx/usecasex" ) @@ -21,7 +22,7 @@ func New() *repo.Container { Workspace: NewWorkspace(), User: NewUser(), SceneLock: NewSceneLock(), - AuthRequest: NewAuthRequest(), + AuthRequest: authserver.NewMemory(), Policy: NewPolicy(), Lock: NewLock(), Transaction: &usecasex.NopTransaction{}, diff --git a/server/internal/infrastructure/memory/user.go b/server/internal/infrastructure/memory/user.go index 791ccca7d4..15cfee1bc4 100644 --- a/server/internal/infrastructure/memory/user.go +++ b/server/internal/infrastructure/memory/user.go @@ -71,7 +71,7 @@ func (r *User) FindByAuth0Sub(ctx context.Context, auth0sub string) (*user.User, } for _, u := range r.data { - if u.ContainAuth(user.AuthFromAuth0Sub(auth0sub)) { + if u.Auths().Has(auth0sub) { return u, nil } } diff --git a/server/internal/infrastructure/mongo/auth_request.go b/server/internal/infrastructure/mongo/auth_request.go deleted file mode 100644 index 417ef4791a..0000000000 --- a/server/internal/infrastructure/mongo/auth_request.go +++ /dev/null @@ -1,58 +0,0 @@ -package mongo - -import ( - "context" - - "github.com/reearth/reearth/server/internal/infrastructure/mongo/mongodoc" - "github.com/reearth/reearth/server/pkg/auth" - "github.com/reearth/reearth/server/pkg/id" - "github.com/reearth/reearthx/log" - "github.com/reearth/reearthx/mongox" - "go.mongodb.org/mongo-driver/bson" -) - -type AuthRequest struct { - client *mongox.ClientCollection -} - -func NewAuthRequest(client *mongox.Client) *AuthRequest { - r := &AuthRequest{client: client.WithCollection("authRequest")} - r.init() - return r -} - -func (r *AuthRequest) init() { - i := r.client.CreateIndex(context.Background(), nil, []string{"id", "code", "subject"}) - if len(i) > 0 { - log.Infof("mongo: %s: index created: %s", "authRequest", i) - } -} - -func (r *AuthRequest) FindByID(ctx context.Context, id2 id.AuthRequestID) (*auth.Request, error) { - return r.findOne(ctx, bson.M{"id": id2.String()}) -} - -func (r *AuthRequest) FindByCode(ctx context.Context, s string) (*auth.Request, error) { - return r.findOne(ctx, bson.M{"code": s}) -} - -func (r *AuthRequest) FindBySubject(ctx context.Context, s string) (*auth.Request, error) { - return r.findOne(ctx, bson.M{"subject": s}) -} - -func (r *AuthRequest) Save(ctx context.Context, request *auth.Request) error { - doc, id1 := mongodoc.NewAuthRequest(request) - return r.client.SaveOne(ctx, id1, doc) -} - -func (r *AuthRequest) Remove(ctx context.Context, requestID id.AuthRequestID) error { - return r.client.RemoveOne(ctx, bson.M{"id": requestID.String()}) -} - -func (r *AuthRequest) findOne(ctx context.Context, filter any) (*auth.Request, error) { - c := mongodoc.NewAuthRequestConsumer() - if err := r.client.FindOne(ctx, filter, c); err != nil { - return nil, err - } - return c.Result[0], nil -} diff --git a/server/internal/infrastructure/mongo/auth_request_test.go b/server/internal/infrastructure/mongo/auth_request_test.go deleted file mode 100644 index be49d6aa26..0000000000 --- a/server/internal/infrastructure/mongo/auth_request_test.go +++ /dev/null @@ -1,73 +0,0 @@ -package mongo - -import ( - "context" - "testing" - - "github.com/caos/oidc/pkg/oidc" - "github.com/reearth/reearth/server/pkg/auth" - "github.com/reearth/reearthx/mongox" - "github.com/reearth/reearthx/mongox/mongotest" - "github.com/stretchr/testify/assert" -) - -func TestAuthRequestRepo(t *testing.T) { - tests := []struct { - Name string - Expected struct { - Name string - AuthRequest *auth.Request - } - }{ - { - Expected: struct { - Name string - AuthRequest *auth.Request - }{ - AuthRequest: auth.NewRequest(). - NewID(). - ClientID("client id"). - State("state"). - ResponseType("response type"). - Scopes([]string{"scope"}). - Audiences([]string{"audience"}). - RedirectURI("redirect uri"). - Nonce("nonce"). - CodeChallenge(&oidc.CodeChallenge{ - Challenge: "challenge", - Method: "S256", - }). - AuthorizedAt(nil). - MustBuild(), - }, - }, - } - - init := mongotest.Connect(t) - - for _, tt := range tests { - t.Run(tt.Name, func(t *testing.T) { - t.Parallel() - - client := init(t) - repo := NewAuthRequest(mongox.NewClientWithDatabase(client)) - - ctx := context.Background() - err := repo.Save(ctx, tt.Expected.AuthRequest) - assert.NoError(t, err) - - got, err := repo.FindByID(ctx, tt.Expected.AuthRequest.ID()) - assert.NoError(t, err) - assert.Equal(t, tt.Expected.AuthRequest.ID(), got.ID()) - assert.Equal(t, tt.Expected.AuthRequest.GetClientID(), got.GetClientID()) - assert.Equal(t, tt.Expected.AuthRequest.GetState(), got.GetState()) - assert.Equal(t, tt.Expected.AuthRequest.GetResponseType(), got.GetResponseType()) - assert.Equal(t, tt.Expected.AuthRequest.GetScopes(), got.GetScopes()) - assert.Equal(t, tt.Expected.AuthRequest.GetAudience(), got.GetAudience()) - assert.Equal(t, tt.Expected.AuthRequest.GetRedirectURI(), got.GetRedirectURI()) - assert.Equal(t, tt.Expected.AuthRequest.GetNonce(), got.GetNonce()) - assert.Equal(t, tt.Expected.AuthRequest.GetCodeChallenge(), got.GetCodeChallenge()) - assert.Equal(t, tt.Expected.AuthRequest.AuthorizedAt(), got.AuthorizedAt()) - }) - } -} diff --git a/server/internal/infrastructure/mongo/config.go b/server/internal/infrastructure/mongo/config.go index b24bd316b9..bfad7ff218 100644 --- a/server/internal/infrastructure/mongo/config.go +++ b/server/internal/infrastructure/mongo/config.go @@ -39,15 +39,36 @@ func (r *Config) LockAndLoad(ctx context.Context) (cfg *config.Config, err error } func (r *Config) Save(ctx context.Context, cfg *config.Config) error { - if cfg != nil { - if _, err := r.client.UpdateOne( - ctx, - bson.M{}, - bson.M{"$set": mongodoc.NewConfig(*cfg)}, - (&options.UpdateOptions{}).SetUpsert(true), - ); err != nil { - return rerror.ErrInternalBy(err) - } + if cfg == nil { + return nil + } + + if _, err := r.client.UpdateOne( + ctx, + bson.M{}, + bson.M{"$set": mongodoc.NewConfig(*cfg)}, + (&options.UpdateOptions{}).SetUpsert(true), + ); err != nil { + return rerror.ErrInternalBy(err) + } + + return nil +} + +func (r *Config) SaveAuth(ctx context.Context, cfg *config.Auth) error { + if cfg == nil { + return nil + } + + if _, err := r.client.UpdateOne( + ctx, + bson.M{}, + bson.M{"$set": bson.M{ + "auth": mongodoc.NewConfigAuth(cfg), + }}, + (&options.UpdateOptions{}).SetUpsert(true), + ); err != nil { + return rerror.ErrInternalBy(err) } return nil diff --git a/server/internal/infrastructure/mongo/container.go b/server/internal/infrastructure/mongo/container.go index f9201cb981..a6ffd987f2 100644 --- a/server/internal/infrastructure/mongo/container.go +++ b/server/internal/infrastructure/mongo/container.go @@ -7,6 +7,7 @@ import ( "github.com/reearth/reearth/server/internal/usecase/repo" "github.com/reearth/reearth/server/pkg/scene" "github.com/reearth/reearth/server/pkg/user" + "github.com/reearth/reearthx/authserver" "github.com/reearth/reearthx/mongox" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -26,7 +27,7 @@ func New(ctx context.Context, mc *mongo.Client, databaseName string) (*repo.Cont client := mongox.NewClientWithDatabase(db) c := &repo.Container{ Asset: NewAsset(client), - AuthRequest: NewAuthRequest(client), + AuthRequest: authserver.NewMongo(client.WithCollection("authRequest")), Config: NewConfig(db.Collection("config"), lock), DatasetSchema: NewDatasetSchema(client), Dataset: NewDataset(client), diff --git a/server/internal/infrastructure/mongo/mongodoc/auth_request.go b/server/internal/infrastructure/mongo/mongodoc/auth_request.go deleted file mode 100644 index f2a5547114..0000000000 --- a/server/internal/infrastructure/mongo/mongodoc/auth_request.go +++ /dev/null @@ -1,101 +0,0 @@ -package mongodoc - -import ( - "time" - - "github.com/caos/oidc/pkg/oidc" - "github.com/reearth/reearth/server/pkg/auth" - "github.com/reearth/reearth/server/pkg/id" - "github.com/reearth/reearthx/mongox" -) - -type AuthRequestDocument struct { - ID string - ClientID string - Subject string - Code string - State string - ResponseType string - Scopes []string - Audiences []string - RedirectURI string - Nonce string - CodeChallenge *CodeChallengeDocument - CreatedAt time.Time - AuthorizedAt *time.Time -} - -type CodeChallengeDocument struct { - Challenge string - Method string -} - -type AuthRequestConsumer = mongox.SliceFuncConsumer[*AuthRequestDocument, *auth.Request] - -func NewAuthRequestConsumer() *AuthRequestConsumer { - return NewComsumer[*AuthRequestDocument, *auth.Request]() -} - -func NewAuthRequest(req *auth.Request) (*AuthRequestDocument, string) { - if req == nil { - return nil, "" - } - reqID := req.GetID() - var cc *CodeChallengeDocument - if req.GetCodeChallenge() != nil { - cc = &CodeChallengeDocument{ - Challenge: req.GetCodeChallenge().Challenge, - Method: string(req.GetCodeChallenge().Method), - } - } - return &AuthRequestDocument{ - ID: reqID, - ClientID: req.GetClientID(), - Subject: req.GetSubject(), - Code: req.GetCode(), - State: req.GetState(), - ResponseType: string(req.GetResponseType()), - Scopes: req.GetScopes(), - Audiences: req.GetAudience(), - RedirectURI: req.GetRedirectURI(), - Nonce: req.GetNonce(), - CodeChallenge: cc, - CreatedAt: req.CreatedAt(), - AuthorizedAt: req.AuthorizedAt(), - }, reqID -} - -func (d *AuthRequestDocument) Model() (*auth.Request, error) { - if d == nil { - return nil, nil - } - - ulid, err := id.AuthRequestIDFrom(d.ID) - if err != nil { - return nil, err - } - - var cc *oidc.CodeChallenge - if d.CodeChallenge != nil { - cc = &oidc.CodeChallenge{ - Challenge: d.CodeChallenge.Challenge, - Method: oidc.CodeChallengeMethod(d.CodeChallenge.Method), - } - } - var req = auth.NewRequest(). - ID(ulid). - ClientID(d.ClientID). - Subject(d.Subject). - Code(d.Code). - State(d.State). - ResponseType(oidc.ResponseType(d.ResponseType)). - Scopes(d.Scopes). - Audiences(d.Audiences). - RedirectURI(d.RedirectURI). - Nonce(d.Nonce). - CodeChallenge(cc). - CreatedAt(d.CreatedAt). - AuthorizedAt(d.AuthorizedAt). - MustBuild() - return req, nil -} diff --git a/server/internal/infrastructure/mongo/mongodoc/config.go b/server/internal/infrastructure/mongo/mongodoc/config.go index dea95a394f..a14452fe7a 100644 --- a/server/internal/infrastructure/mongo/mongodoc/config.go +++ b/server/internal/infrastructure/mongo/mongodoc/config.go @@ -17,18 +17,20 @@ type Auth struct { } func NewConfig(c config.Config) ConfigDocument { - d := ConfigDocument{ + return ConfigDocument{ Migration: c.Migration, + Auth: NewConfigAuth(c.Auth), } +} - if c.Auth != nil { - d.Auth = &Auth{ - Cert: c.Auth.Cert, - Key: c.Auth.Key, - } +func NewConfigAuth(c *config.Auth) *Auth { + if c == nil { + return nil + } + return &Auth{ + Cert: c.Cert, + Key: c.Key, } - - return d } func (c *ConfigDocument) Model() *config.Config { diff --git a/server/internal/infrastructure/mongo/mongodoc/user.go b/server/internal/infrastructure/mongo/mongodoc/user.go index 858c6a9b02..7f5b480c7d 100644 --- a/server/internal/infrastructure/mongo/mongodoc/user.go +++ b/server/internal/infrastructure/mongo/mongodoc/user.go @@ -19,7 +19,7 @@ type UserDocument struct { ID string Name string Email string - Auth0Sub string + Auth0Sub string `bson:"auth0sub,omitempty"` Auth0SubList []string Workspace string `bson:"team"` // DON'T CHANGE NAME Lang string @@ -90,9 +90,9 @@ func (d *UserDocument) Model() (*user1.User, error) { return nil, err } - auths := util.Map(d.Auth0SubList, func(s string) user.Auth { return user.AuthFromAuth0Sub(s) }) + auths := util.Map(d.Auth0SubList, func(s string) user.Auth { return user.AuthFrom(s) }) if d.Auth0Sub != "" { - auths = append(auths, user.AuthFromAuth0Sub(d.Auth0Sub)) + auths = append(auths, user.AuthFrom(d.Auth0Sub)) } var v *user.Verification if d.Verification != nil { diff --git a/server/internal/usecase/interactor/auth.go b/server/internal/usecase/interactor/auth.go deleted file mode 100644 index d6fee7fee9..0000000000 --- a/server/internal/usecase/interactor/auth.go +++ /dev/null @@ -1,414 +0,0 @@ -package interactor - -import ( - "context" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "errors" - "fmt" - "math/big" - "time" - - "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/op" - "github.com/reearth/reearth/server/internal/usecase/repo" - "github.com/reearth/reearth/server/pkg/auth" - config2 "github.com/reearth/reearth/server/pkg/config" - "github.com/reearth/reearth/server/pkg/id" - "github.com/reearth/reearth/server/pkg/user" - "github.com/reearth/reearthx/log" - "gopkg.in/square/go-jose.v2" -) - -type AuthStorage struct { - appConfig *StorageConfig - getUserBySubject func(context.Context, string) (*user.User, error) - clients map[string]op.Client - requests repo.AuthRequest - keySet jose.JSONWebKeySet - key *rsa.PrivateKey - sigKey jose.SigningKey -} - -type StorageConfig struct { - Domain string `default:"http://localhost:8080"` - ClientDomain string `default:"http://localhost:8080"` - Debug bool - DN *AuthDNConfig -} - -type AuthDNConfig struct { - CommonName string - Organization []string - OrganizationalUnit []string - Country []string - Province []string - Locality []string - StreetAddress []string - PostalCode []string -} - -var dummyName = pkix.Name{ - CommonName: "Dummy company, INC.", - Organization: []string{"Dummy company, INC."}, - OrganizationalUnit: []string{"Dummy OU"}, - Country: []string{"US"}, - Province: []string{"Dummy"}, - Locality: []string{"Dummy locality"}, - StreetAddress: []string{"Dummy street"}, - PostalCode: []string{"1"}, -} - -func NewAuthStorage(ctx context.Context, cfg *StorageConfig, request repo.AuthRequest, config repo.Config, getUserBySubject func(context.Context, string) (*user.User, error)) (op.Storage, error) { - client := auth.NewLocalClient(cfg.Debug, cfg.ClientDomain) - - name := dummyName - if cfg.DN != nil { - name = pkix.Name{ - CommonName: cfg.DN.CommonName, - Organization: cfg.DN.Organization, - OrganizationalUnit: cfg.DN.OrganizationalUnit, - Country: cfg.DN.Country, - Province: cfg.DN.Province, - Locality: cfg.DN.Locality, - StreetAddress: cfg.DN.StreetAddress, - PostalCode: cfg.DN.PostalCode, - } - } - c, err := config.LockAndLoad(ctx) - if err != nil { - return nil, fmt.Errorf("could not load auth config: %w\n", err) - } - defer func() { - if err := config.Unlock(ctx); err != nil { - log.Errorf("auth: could not release config lock: %s\n", err) - } - }() - - var keyBytes, certBytes []byte - if c.Auth != nil { - keyBytes = []byte(c.Auth.Key) - certBytes = []byte(c.Auth.Cert) - } else { - keyBytes, certBytes, err = generateCert(name) - if err != nil { - return nil, fmt.Errorf("could not generate raw cert: %w\n", err) - } - c.Auth = &config2.Auth{ - Key: string(keyBytes), - Cert: string(certBytes), - } - - if err := config.Save(ctx, c); err != nil { - return nil, fmt.Errorf("could not save raw cert: %w\n", err) - } - log.Info("auth: init a new private key and certificate") - } - - key, sigKey, keySet, err := initKeys(keyBytes, certBytes) - if err != nil { - return nil, fmt.Errorf("could not init keys: %w\n", err) - } - - return &AuthStorage{ - appConfig: cfg, - getUserBySubject: getUserBySubject, - requests: request, - key: key, - sigKey: *sigKey, - keySet: *keySet, - clients: map[string]op.Client{ - client.GetID(): client, - }, - }, nil -} - -func initKeys(keyBytes, certBytes []byte) (*rsa.PrivateKey, *jose.SigningKey, *jose.JSONWebKeySet, error) { - keyBlock, _ := pem.Decode(keyBytes) - if keyBlock == nil { - return nil, nil, nil, fmt.Errorf("failed to decode the key bytes") - } - key, err := x509.ParsePKCS1PrivateKey(keyBlock.Bytes) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to parse the private key bytes: %w\n", err) - } - - var certActualBytes []byte - certBlock, _ := pem.Decode(certBytes) - if certBlock == nil { - certActualBytes = certBytes // backwards compatibility - } else { - certActualBytes = certBlock.Bytes - } - - var cert *x509.Certificate - cert, err = x509.ParseCertificate(certActualBytes) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to parse the cert bytes: %w\n", err) - } - - keyID := "RE01" - sk := jose.SigningKey{ - Algorithm: jose.RS256, - Key: jose.JSONWebKey{Key: key, Use: "sig", Algorithm: string(jose.RS256), KeyID: keyID, Certificates: []*x509.Certificate{cert}}, - } - - return key, &sk, &jose.JSONWebKeySet{ - Keys: []jose.JSONWebKey{ - {Key: key.Public(), Use: "sig", Algorithm: string(jose.RS256), KeyID: keyID, Certificates: []*x509.Certificate{cert}}, - }, - }, nil -} - -func generateCert(name pkix.Name) (keyPem, certPem []byte, err error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - err = fmt.Errorf("failed to generate key: %w\n", err) - return - } - - keyPem = pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(key), - }) - - cert := &x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: name, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(100, 0, 0), - IsCA: true, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageCRLSign, - } - - certBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, key.Public(), key) - if err != nil { - err = fmt.Errorf("failed to create the cert: %w\n", err) - } - - certPem = pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }) - return -} - -func (s *AuthStorage) Health(_ context.Context) error { - return nil -} - -func (s *AuthStorage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, _ string) (op.AuthRequest, error) { - audiences := []string{ - s.appConfig.Domain, - } - if s.appConfig.Debug { - audiences = append(audiences, "http://localhost:8080") - } - - var cc *oidc.CodeChallenge - if authReq.CodeChallenge != "" { - cc = &oidc.CodeChallenge{ - Challenge: authReq.CodeChallenge, - Method: authReq.CodeChallengeMethod, - } - } - var request = auth.NewRequest(). - NewID(). - ClientID(authReq.ClientID). - State(authReq.State). - ResponseType(authReq.ResponseType). - Scopes(authReq.Scopes). - Audiences(audiences). - RedirectURI(authReq.RedirectURI). - Nonce(authReq.Nonce). - CodeChallenge(cc). - CreatedAt(time.Now().UTC()). - AuthorizedAt(nil). - MustBuild() - - if err := s.requests.Save(ctx, request); err != nil { - return nil, err - } - return request, nil -} - -func (s *AuthStorage) AuthRequestByID(ctx context.Context, requestID string) (op.AuthRequest, error) { - if requestID == "" { - return nil, errors.New("invalid id") - } - reqId, err := id.AuthRequestIDFrom(requestID) - if err != nil { - return nil, err - } - request, err := s.requests.FindByID(ctx, reqId) - if err != nil { - return nil, err - } - return request, nil -} - -func (s *AuthStorage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) { - if code == "" { - return nil, errors.New("invalid code") - } - return s.requests.FindByCode(ctx, code) -} - -func (s *AuthStorage) AuthRequestBySubject(ctx context.Context, subject string) (op.AuthRequest, error) { - if subject == "" { - return nil, errors.New("invalid subject") - } - - return s.requests.FindBySubject(ctx, subject) -} - -func (s *AuthStorage) SaveAuthCode(ctx context.Context, requestID, code string) error { - request, err := s.AuthRequestByID(ctx, requestID) - if err != nil { - return err - } - request2 := request.(*auth.Request) - request2.SetCode(code) - err = s.updateRequest(ctx, requestID, *request2) - return err -} - -func (s *AuthStorage) DeleteAuthRequest(_ context.Context, requestID string) error { - delete(s.clients, requestID) - return nil -} - -func (s *AuthStorage) CreateAccessToken(_ context.Context, _ op.TokenRequest) (string, time.Time, error) { - return "id", time.Now().UTC().Add(5 * time.Hour), nil -} - -func (s *AuthStorage) CreateAccessAndRefreshTokens(_ context.Context, request op.TokenRequest, _ string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) { - authReq := request.(*auth.Request) - return "id", authReq.GetID(), time.Now().UTC().Add(5 * time.Minute), nil -} - -func (s *AuthStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) { - r, err := s.AuthRequestByID(ctx, refreshToken) - if err != nil { - return nil, err - } - return r.(op.RefreshTokenRequest), err -} - -func (s *AuthStorage) TerminateSession(_ context.Context, _, _ string) error { - return errors.New("not implemented") -} - -func (s *AuthStorage) GetSigningKey(_ context.Context, keyCh chan<- jose.SigningKey) { - keyCh <- s.sigKey -} - -func (s *AuthStorage) GetKeySet(_ context.Context) (*jose.JSONWebKeySet, error) { - return &s.keySet, nil -} - -func (s *AuthStorage) GetKeyByIDAndUserID(_ context.Context, kid, _ string) (*jose.JSONWebKey, error) { - return &s.keySet.Key(kid)[0], nil -} - -func (s *AuthStorage) GetClientByClientID(_ context.Context, clientID string) (op.Client, error) { - - if clientID == "" { - return nil, errors.New("invalid client id") - } - - client, exists := s.clients[clientID] - if !exists { - return nil, errors.New("not found") - } - - return client, nil -} - -func (s *AuthStorage) AuthorizeClientIDSecret(_ context.Context, _ string, _ string) error { - return nil -} - -func (s *AuthStorage) SetUserinfoFromToken(ctx context.Context, userinfo oidc.UserInfoSetter, _, _, _ string) error { - return s.SetUserinfoFromScopes(ctx, userinfo, "", "", []string{}) -} - -func (s *AuthStorage) SetUserinfoFromScopes(ctx context.Context, userinfo oidc.UserInfoSetter, subject, _ string, _ []string) error { - - request, err := s.AuthRequestBySubject(ctx, subject) - if err != nil { - return err - } - - u, err := s.getUserBySubject(ctx, subject) - if err != nil { - return err - } - - userinfo.SetSubject(request.GetSubject()) - userinfo.SetEmail(u.Email(), true) - userinfo.SetName(u.Name()) - userinfo.AppendClaims("lang", u.Lang()) - userinfo.AppendClaims("theme", u.Theme()) - - return nil -} - -func (s *AuthStorage) GetPrivateClaimsFromScopes(_ context.Context, _, _ string, _ []string) (map[string]interface{}, error) { - return map[string]interface{}{"private_claim": "test"}, nil -} - -func (s *AuthStorage) SetIntrospectionFromToken(ctx context.Context, introspect oidc.IntrospectionResponse, _, subject, clientID string) error { - if err := s.SetUserinfoFromScopes(ctx, introspect, subject, clientID, []string{}); err != nil { - return err - } - request, err := s.AuthRequestBySubject(ctx, subject) - if err != nil { - return err - } - introspect.SetClientID(request.GetClientID()) - return nil -} - -func (s *AuthStorage) ValidateJWTProfileScopes(_ context.Context, _ string, scope []string) ([]string, error) { - return scope, nil -} - -func (s *AuthStorage) RevokeToken(_ context.Context, _ string, _ string, _ string) *oidc.Error { - // TODO implement me - panic("implement me") -} - -func (s *AuthStorage) CompleteAuthRequest(ctx context.Context, requestId, sub string) error { - request, err := s.AuthRequestByID(ctx, requestId) - if err != nil { - return err - } - req := request.(*auth.Request) - req.Complete(sub) - err = s.updateRequest(ctx, requestId, *req) - return err -} - -func (s *AuthStorage) updateRequest(ctx context.Context, requestID string, req auth.Request) error { - if requestID == "" { - return errors.New("invalid id") - } - reqId, err := id.AuthRequestIDFrom(requestID) - if err != nil { - return err - } - - if _, err := s.requests.FindByID(ctx, reqId); err != nil { - return err - } - - if err := s.requests.Save(ctx, &req); err != nil { - return err - } - - return nil -} diff --git a/server/internal/usecase/interactor/user.go b/server/internal/usecase/interactor/user.go index 26b8b8252b..99e250403c 100644 --- a/server/internal/usecase/interactor/user.go +++ b/server/internal/usecase/interactor/user.go @@ -58,7 +58,11 @@ var ( authTextTMPL *textTmpl.Template authHTMLTMPL *htmlTmpl.Template - passwordResetMailContent mailContent + passwordResetMailContent = mailContent{ + Message: "Thank you for using Re:Earth. We've received a request to reset your password. If this was you, please click the link below to confirm and change your password.", + Suffix: "If you did not mean to reset your password, then you can ignore this email.", + ActionLabel: "Confirm to reset your password", + } ) func init() { @@ -71,12 +75,6 @@ func init() { if err != nil { log.Panicf("password reset email template parse error: %s\n", err) } - - passwordResetMailContent = mailContent{ - Message: "Thank you for using Re:Earth. We’ve received a request to reset your password. If this was you, please click the link below to confirm and change your password.", - Suffix: "If you did not mean to reset your password, then you can ignore this email.", - ActionLabel: "Confirm to reset your password", - } } func NewUser(r *repo.Container, g *gateway.Container, signupSecret, authSrcUIDomain string) interfaces.User { @@ -304,7 +302,7 @@ func (i *User) UpdateMe(ctx context.Context, p interfaces.UpdateMeParam, operato u.UpdateTheme(*p.Theme) } - if p.Password != nil && u.HasAuthProvider("reearth") { + if p.Password != nil && u.Auths().HasProvider(user.ProviderReearth) { if err := u.SetPassword(*p.Password); err != nil { return nil, err } @@ -365,8 +363,7 @@ func (i *User) RemoveMyAuth(ctx context.Context, authProvider string, operator * u.RemoveAuthByProvider(authProvider) - err = i.userRepo.Save(ctx, u) - if err != nil { + if err = i.userRepo.Save(ctx, u); err != nil { return nil, err } diff --git a/server/internal/usecase/interactor/user_signup.go b/server/internal/usecase/interactor/user_signup.go index 5409a6a5af..d578f3f530 100644 --- a/server/internal/usecase/interactor/user_signup.go +++ b/server/internal/usecase/interactor/user_signup.go @@ -59,7 +59,7 @@ func (i *User) Signup(ctx context.Context, inp interfaces.SignupParam) (*user.Us // Initialize user and workspace var auth *user.Auth if inp.Sub != nil { - auth = user.AuthFromAuth0Sub(*inp.Sub).Ref() + auth = user.AuthFrom(*inp.Sub).Ref() } u, ws, err := userops.Init(userops.InitParams{ Email: inp.Email, @@ -137,7 +137,7 @@ func (i *User) SignupOIDC(ctx context.Context, inp interfaces.SignupOIDCParam) ( u, ws, err := userops.Init(userops.InitParams{ Email: email, Name: name, - Sub: user.AuthFromAuth0Sub(sub).Ref(), + Sub: user.AuthFrom(sub).Ref(), Lang: inp.User.Lang, Theme: inp.User.Theme, UserID: inp.User.UserID, diff --git a/server/internal/usecase/interactor/user_signup_test.go b/server/internal/usecase/interactor/user_signup_test.go index 8702b2a1d2..d7a22bf778 100644 --- a/server/internal/usecase/interactor/user_signup_test.go +++ b/server/internal/usecase/interactor/user_signup_test.go @@ -24,7 +24,7 @@ import ( func TestUser_Signup(t *testing.T) { user.DefaultPasswordEncoder = &user.NoopPasswordEncoder{} uid := id.NewUserID() - tid := id.NewWorkspaceID() + wid := id.NewWorkspaceID() mocktime := time.Time{} mockcode := "CODECODE" @@ -55,20 +55,20 @@ func TestUser_Signup(t *testing.T) { Password: lo.ToPtr("PAss00!!"), User: interfaces.SignupUserParam{ UserID: &uid, - WorkspaceID: &tid, + WorkspaceID: &wid, }, }, wantUser: user.New(). ID(uid). - Workspace(tid). + Workspace(wid). Name("NAME"). - Auths([]user.Auth{{Provider: "", Sub: "SUB"}}). + Auths([]user.Auth{{Sub: "SUB"}}). Email("aaa@bbb.com"). PasswordPlainText("PAss00!!"). Verification(user.VerificationFrom(mockcode, mocktime.Add(24*time.Hour), false)). MustBuild(), wantWorkspace: workspace.New(). - ID(tid). + ID(wid). Name("NAME"). Members(map[id.UserID]workspace.Role{uid: workspace.RoleOwner}). Personal(true). @@ -84,7 +84,7 @@ func TestUser_Signup(t *testing.T) { authSrvUIDomain: "", createUserBefore: user.New(). ID(uid). - Workspace(tid). + Workspace(wid). Email("aaa@bbb.com"). MustBuild(), args: interfaces.SignupParam{ @@ -93,12 +93,12 @@ func TestUser_Signup(t *testing.T) { Password: lo.ToPtr("PAss00!!"), User: interfaces.SignupUserParam{ UserID: &uid, - WorkspaceID: &tid, + WorkspaceID: &wid, }, }, wantUser: user.New(). ID(uid). - Workspace(tid). + Workspace(wid). Email("aaa@bbb.com"). Verification(user.VerificationFrom(mockcode, mocktime.Add(24*time.Hour), false)). MustBuild(), @@ -114,7 +114,7 @@ func TestUser_Signup(t *testing.T) { authSrvUIDomain: "", createUserBefore: user.New(). ID(uid). - Workspace(tid). + Workspace(wid). Email("aaa@bbb.com"). Verification(user.VerificationFrom(mockcode, mocktime, true)). MustBuild(), @@ -125,7 +125,7 @@ func TestUser_Signup(t *testing.T) { Password: lo.ToPtr("PAss00!!"), User: interfaces.SignupUserParam{ UserID: &uid, - WorkspaceID: &tid, + WorkspaceID: &wid, }, }, wantUser: nil, @@ -144,12 +144,12 @@ func TestUser_Signup(t *testing.T) { Secret: lo.ToPtr("hogehoge"), User: interfaces.SignupUserParam{ UserID: &uid, - WorkspaceID: &tid, + WorkspaceID: &wid, }, }, wantUser: user.New(). ID(uid). - Workspace(tid). + Workspace(wid). Name("NAME"). Auths([]user.Auth{{Provider: "", Sub: "SUB"}}). Email("aaa@bbb.com"). @@ -157,7 +157,7 @@ func TestUser_Signup(t *testing.T) { Verification(user.VerificationFrom(mockcode, mocktime.Add(24*time.Hour), false)). MustBuild(), wantWorkspace: workspace.New(). - ID(tid). + ID(wid). Name("NAME"). Members(map[id.UserID]workspace.Role{uid: workspace.RoleOwner}). Personal(true). @@ -179,14 +179,14 @@ func TestUser_Signup(t *testing.T) { Secret: lo.ToPtr("SECRET"), User: interfaces.SignupUserParam{ UserID: &uid, - WorkspaceID: &tid, + WorkspaceID: &wid, Lang: &language.Japanese, Theme: user.ThemeDark.Ref(), }, }, wantUser: user.New(). ID(uid). - Workspace(tid). + Workspace(wid). Name("NAME"). Auths([]user.Auth{{Provider: "", Sub: "SUB"}}). Email("aaa@bbb.com"). @@ -196,7 +196,7 @@ func TestUser_Signup(t *testing.T) { Verification(user.VerificationFrom(mockcode, mocktime.Add(24*time.Hour), false)). MustBuild(), wantWorkspace: workspace.New(). - ID(tid). + ID(wid). Name("NAME"). Members(map[id.UserID]workspace.Role{uid: workspace.RoleOwner}). Personal(true). @@ -274,10 +274,12 @@ func TestUser_Signup(t *testing.T) { m := mailer.NewMock() g := &gateway.Container{Mailer: m} uc := NewUser(r, g, tt.signupSecret, tt.authSrvUIDomain) + user, ws, err := uc.Signup(context.Background(), tt.args) assert.Equal(t, tt.wantUser, user) assert.Equal(t, tt.wantWorkspace, ws) assert.Equal(t, tt.wantError, err) + mails := m.Mails() if tt.wantMailSubject == "" { assert.Empty(t, mails) @@ -430,6 +432,7 @@ func TestUser_SignupOIDC(t *testing.T) { createUserBefore: user.New(). ID(uid). Email("aaa@bbb.com"). + Workspace(user.NewWorkspaceID()). MustBuild(), args: interfaces.SignupOIDCParam{ AccessToken: "accesstoken", @@ -448,6 +451,7 @@ func TestUser_SignupOIDC(t *testing.T) { createUserBefore: user.New(). ID(uid). Email("aaa@bbb.com"). + Workspace(user.NewWorkspaceID()). Verification(user.VerificationFrom(mockcode, mocktime, true)). MustBuild(), args: interfaces.SignupOIDCParam{ diff --git a/server/internal/usecase/repo/auth_request.go b/server/internal/usecase/repo/auth_request.go deleted file mode 100644 index cc402eab91..0000000000 --- a/server/internal/usecase/repo/auth_request.go +++ /dev/null @@ -1,16 +0,0 @@ -package repo - -import ( - "context" - - "github.com/reearth/reearth/server/pkg/auth" - "github.com/reearth/reearth/server/pkg/id" -) - -type AuthRequest interface { - FindByID(context.Context, id.AuthRequestID) (*auth.Request, error) - FindByCode(context.Context, string) (*auth.Request, error) - FindBySubject(context.Context, string) (*auth.Request, error) - Save(context.Context, *auth.Request) error - Remove(context.Context, id.AuthRequestID) error -} diff --git a/server/internal/usecase/repo/config.go b/server/internal/usecase/repo/config.go index 19f6a8193d..8951f5fd7c 100644 --- a/server/internal/usecase/repo/config.go +++ b/server/internal/usecase/repo/config.go @@ -9,6 +9,7 @@ import ( type Config interface { LockAndLoad(context.Context) (*config.Config, error) Save(context.Context, *config.Config) error + SaveAuth(context.Context, *config.Auth) error SaveAndUnlock(context.Context, *config.Config) error Unlock(context.Context) error } diff --git a/server/internal/usecase/repo/container.go b/server/internal/usecase/repo/container.go index b0a5c756f9..428d057c7d 100644 --- a/server/internal/usecase/repo/container.go +++ b/server/internal/usecase/repo/container.go @@ -6,6 +6,7 @@ import ( "github.com/reearth/reearth/server/internal/usecase" "github.com/reearth/reearth/server/pkg/scene" "github.com/reearth/reearth/server/pkg/user" + "github.com/reearth/reearthx/authserver" "github.com/reearth/reearthx/usecasex" ) @@ -15,7 +16,7 @@ var ( type Container struct { Asset Asset - AuthRequest AuthRequest + AuthRequest authserver.RequestRepo Config Config DatasetSchema DatasetSchema Dataset Dataset diff --git a/server/pkg/auth/builder.go b/server/pkg/auth/builder.go deleted file mode 100644 index d79391cc5c..0000000000 --- a/server/pkg/auth/builder.go +++ /dev/null @@ -1,102 +0,0 @@ -package auth - -import ( - "time" - - "github.com/caos/oidc/pkg/oidc" - "github.com/reearth/reearth/server/pkg/id" -) - -type RequestBuilder struct { - r *Request -} - -func NewRequest() *RequestBuilder { - return &RequestBuilder{r: &Request{}} -} - -func (b *RequestBuilder) Build() (*Request, error) { - if b.r.id.IsNil() { - return nil, id.ErrInvalidID - } - b.r.createdAt = time.Now() - return b.r, nil -} - -func (b *RequestBuilder) MustBuild() *Request { - r, err := b.Build() - if err != nil { - panic(err) - } - return r -} - -func (b *RequestBuilder) ID(id id.AuthRequestID) *RequestBuilder { - b.r.id = id - return b -} - -func (b *RequestBuilder) NewID() *RequestBuilder { - b.r.id = id.NewAuthRequestID() - return b -} - -func (b *RequestBuilder) ClientID(id string) *RequestBuilder { - b.r.clientID = id - return b -} - -func (b *RequestBuilder) Subject(subject string) *RequestBuilder { - b.r.subject = subject - return b -} - -func (b *RequestBuilder) Code(code string) *RequestBuilder { - b.r.code = code - return b -} - -func (b *RequestBuilder) State(state string) *RequestBuilder { - b.r.state = state - return b -} - -func (b *RequestBuilder) ResponseType(rt oidc.ResponseType) *RequestBuilder { - b.r.responseType = rt - return b -} - -func (b *RequestBuilder) Scopes(scopes []string) *RequestBuilder { - b.r.scopes = scopes - return b -} - -func (b *RequestBuilder) Audiences(audiences []string) *RequestBuilder { - b.r.audiences = audiences - return b -} - -func (b *RequestBuilder) RedirectURI(redirectURI string) *RequestBuilder { - b.r.redirectURI = redirectURI - return b -} - -func (b *RequestBuilder) Nonce(nonce string) *RequestBuilder { - b.r.nonce = nonce - return b -} - -func (b *RequestBuilder) CodeChallenge(CodeChallenge *oidc.CodeChallenge) *RequestBuilder { - b.r.codeChallenge = CodeChallenge - return b -} - -func (b *RequestBuilder) CreatedAt(createdAt time.Time) *RequestBuilder { - b.r.createdAt = createdAt - return b -} - -func (b *RequestBuilder) AuthorizedAt(authorizedAt *time.Time) *RequestBuilder { - b.r.authorizedAt = authorizedAt - return b -} diff --git a/server/pkg/auth/client.go b/server/pkg/auth/client.go deleted file mode 100644 index 0c7f6b5b68..0000000000 --- a/server/pkg/auth/client.go +++ /dev/null @@ -1,117 +0,0 @@ -package auth - -import ( - "fmt" - "time" - - "github.com/caos/oidc/pkg/oidc" - "github.com/caos/oidc/pkg/op" -) - -const ClientID = "reearth-authsrv-client-default" - -type Client struct { - id string - applicationType op.ApplicationType - authMethod oidc.AuthMethod - accessTokenType op.AccessTokenType - responseTypes []oidc.ResponseType - grantTypes []oidc.GrantType - allowedScopes []string - redirectURIs []string - logoutRedirectURIs []string - loginURI string - idTokenLifetime time.Duration - clockSkew time.Duration - devMode bool -} - -func NewLocalClient(devMode bool, clientDomain string) op.Client { - return &Client{ - id: ClientID, - applicationType: op.ApplicationTypeWeb, - authMethod: oidc.AuthMethodNone, - accessTokenType: op.AccessTokenTypeJWT, - responseTypes: []oidc.ResponseType{oidc.ResponseTypeCode}, - grantTypes: []oidc.GrantType{oidc.GrantTypeCode, oidc.GrantTypeRefreshToken}, - redirectURIs: []string{clientDomain}, - allowedScopes: []string{"openid", "profile", "email"}, - loginURI: clientDomain + "/login?id=%s", - idTokenLifetime: 5 * time.Minute, - clockSkew: 0, - devMode: devMode, - } -} - -func (c *Client) GetID() string { - return c.id -} - -func (c *Client) RedirectURIs() []string { - return c.redirectURIs -} - -func (c *Client) PostLogoutRedirectURIs() []string { - return c.logoutRedirectURIs -} - -func (c *Client) LoginURL(id string) string { - return fmt.Sprintf(c.loginURI, id) -} - -func (c *Client) ApplicationType() op.ApplicationType { - return c.applicationType -} - -func (c *Client) AuthMethod() oidc.AuthMethod { - return c.authMethod -} - -func (c *Client) IDTokenLifetime() time.Duration { - return c.idTokenLifetime -} - -func (c *Client) AccessTokenType() op.AccessTokenType { - return c.accessTokenType -} - -func (c *Client) ResponseTypes() []oidc.ResponseType { - return c.responseTypes -} - -func (c *Client) GrantTypes() []oidc.GrantType { - return c.grantTypes -} - -func (c *Client) DevMode() bool { - return c.devMode -} - -func (c *Client) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { - return func(scopes []string) []string { - return scopes - } -} - -func (c *Client) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { - return func(scopes []string) []string { - return scopes - } -} - -func (c *Client) IsScopeAllowed(scope string) bool { - for _, clientScope := range c.allowedScopes { - if clientScope == scope { - return true - } - } - return false -} - -func (c *Client) IDTokenUserinfoClaimsAssertion() bool { - return false -} - -func (c *Client) ClockSkew() time.Duration { - return c.clockSkew -} diff --git a/server/pkg/auth/request.go b/server/pkg/auth/request.go deleted file mode 100644 index 9f51900bc2..0000000000 --- a/server/pkg/auth/request.go +++ /dev/null @@ -1,143 +0,0 @@ -package auth - -import ( - "time" - - "github.com/caos/oidc/pkg/oidc" - "github.com/reearth/reearth/server/pkg/id" -) - -var essentialScopes = []string{"openid", "profile", "email"} - -type Request struct { - id id.AuthRequestID - clientID string - subject string - code string - state string - responseType oidc.ResponseType - scopes []string - audiences []string - redirectURI string - nonce string - codeChallenge *oidc.CodeChallenge - createdAt time.Time - authorizedAt *time.Time -} - -func (a *Request) ID() id.AuthRequestID { - return a.id -} - -func (a *Request) GetID() string { - return a.id.String() -} - -func (a *Request) GetACR() string { - return "" -} - -func (a *Request) GetAMR() []string { - return []string{ - "password", - } -} - -func (a *Request) GetAudience() []string { - if a.audiences == nil { - return make([]string, 0) - } - - return a.audiences -} - -func (a *Request) GetAuthTime() time.Time { - return a.createdAt -} - -func (a *Request) GetClientID() string { - return a.clientID -} - -func (a *Request) GetResponseMode() oidc.ResponseMode { - // TODO make sure about this - return oidc.ResponseModeQuery -} - -func (a *Request) GetCode() string { - return a.code -} - -func (a *Request) GetState() string { - return a.state -} - -func (a *Request) GetCodeChallenge() *oidc.CodeChallenge { - return a.codeChallenge -} - -func (a *Request) GetNonce() string { - return a.nonce -} - -func (a *Request) GetRedirectURI() string { - return a.redirectURI -} - -func (a *Request) GetResponseType() oidc.ResponseType { - return a.responseType -} - -func (a *Request) GetScopes() []string { - return unique(append(a.scopes, essentialScopes...)) -} - -func (a *Request) SetCurrentScopes(scopes []string) { - a.scopes = unique(append(scopes, essentialScopes...)) -} - -func (a *Request) GetSubject() string { - return a.subject -} - -func (a *Request) CreatedAt() time.Time { - return a.createdAt -} - -func (a *Request) SetCreatedAt(createdAt time.Time) { - a.createdAt = createdAt -} - -func (a *Request) AuthorizedAt() *time.Time { - return a.authorizedAt -} - -func (a *Request) SetAuthorizedAt(authorizedAt *time.Time) { - a.authorizedAt = authorizedAt -} - -func (a *Request) Done() bool { - return a.authorizedAt != nil -} - -func (a *Request) Complete(sub string) { - a.subject = sub - now := time.Now() - a.authorizedAt = &now -} - -func (a *Request) SetCode(code string) { - a.code = code -} - -func unique(list []string) []string { - allKeys := make(map[string]struct{}) - var uniqueList []string - for _, item := range list { - if _, ok := allKeys[item]; !ok { - allKeys[item] = struct{}{} - uniqueList = append(uniqueList, item) - } - } - return uniqueList -} diff --git a/server/pkg/user/auth.go b/server/pkg/user/auth.go index a20ed8990a..5351c04104 100644 --- a/server/pkg/user/auth.go +++ b/server/pkg/user/auth.go @@ -2,6 +2,14 @@ package user import ( "strings" + + "github.com/samber/lo" + "golang.org/x/exp/slices" +) + +const ( + ProviderReearth = "reearth" + ProviderAuth0 = "auth0" ) type Auth struct { @@ -9,7 +17,7 @@ type Auth struct { Sub string } -func AuthFromAuth0Sub(sub string) Auth { +func AuthFrom(sub string) Auth { s := strings.SplitN(sub, "|", 2) if len(s) != 2 { return Auth{Provider: "", Sub: sub} @@ -17,18 +25,74 @@ func AuthFromAuth0Sub(sub string) Auth { return Auth{Provider: s[0], Sub: sub} } +func NewReearthAuth(sub string) Auth { + return Auth{ + Provider: ProviderReearth, + Sub: "reearth|" + sub, + } +} + +func (a Auth) IsReearth() bool { + return a.Provider == ProviderReearth +} + func (a Auth) IsAuth0() bool { - return a.Provider == "auth0" + return a.Provider == ProviderAuth0 } func (a Auth) Ref() *Auth { - a2 := a - return &a2 + return lo.ToPtr(a) +} + +func (a Auth) String() string { + return a.Sub +} + +type Auths []Auth + +func (a Auths) Has(sub string) bool { + return lo.ContainsBy(a, func(a Auth) bool { return a.Sub == sub }) +} + +func (a Auths) HasProvider(p string) bool { + return lo.ContainsBy(a, func(a Auth) bool { return a.Provider == p }) +} + +func (a Auths) GetByProvider(p string) *Auth { + _, i, ok := lo.FindIndexOf(a, func(a Auth) bool { return a.Provider == p }) + if !ok { + return nil + } + return &a[i] +} + +func (a Auths) Get(sub string) *Auth { + _, i, ok := lo.FindIndexOf(a, func(a Auth) bool { return a.Sub == sub }) + if !ok { + return nil + } + return &a[i] +} + +func (a Auths) Add(u Auth) Auths { + if a.Has(u.Sub) { + return a + } + return append(a, u) +} + +func (a Auths) Remove(sub string) Auths { + _, i, ok := lo.FindIndexOf(a, func(a Auth) bool { return a.Sub == sub }) + if !ok { + return a + } + return slices.Delete(a, i, 1) } -func GenReearthSub(userID string) *Auth { - return &Auth{ - Provider: "reearth", - Sub: "reearth|" + userID, +func (a Auths) RemoveByProvider(p string) Auths { + _, i, ok := lo.FindIndexOf(a, func(a Auth) bool { return a.Provider == p }) + if !ok { + return a } + return slices.Delete(a, i, 1) } diff --git a/server/pkg/user/auth_test.go b/server/pkg/user/auth_test.go index c85d252167..49a8e60271 100644 --- a/server/pkg/user/auth_test.go +++ b/server/pkg/user/auth_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/reearth/reearth/server/pkg/id" + "golang.org/x/exp/slices" "github.com/stretchr/testify/assert" ) @@ -40,56 +41,33 @@ func TestAuthFromAuth0Sub(t *testing.T) { tc := tc t.Run(tc.Name, func(t *testing.T) { t.Parallel() - assert.Equal(t, tc.Expected, AuthFromAuth0Sub(tc.Sub)) + assert.Equal(t, tc.Expected, AuthFrom(tc.Sub)) }) } } func TestAuth_IsAuth0(t *testing.T) { - tests := []struct { - Name string - Auth Auth - Expected bool - }{ - { - Name: "is Auth", - Auth: Auth{ - Provider: "auth0", - Sub: "xxx", - }, - Expected: true, - }, - { - Name: "is not Auth", - Auth: Auth{ - Provider: "foo", - Sub: "hoge", - }, - Expected: false, - }, - } + assert.True(t, Auth{Provider: "auth0", Sub: "xxx"}.IsAuth0()) + assert.False(t, Auth{Provider: "reearth", Sub: "xxx"}.IsAuth0()) +} - for _, tc := range tests { - tc := tc - t.Run(tc.Name, func(t *testing.T) { - t.Parallel() - assert.Equal(t, tc.Expected, tc.Auth.IsAuth0()) - }) - } +func TestAuth_IsReearth(t *testing.T) { + assert.True(t, Auth{Provider: "reearth", Sub: "xxx"}.IsReearth()) + assert.False(t, Auth{Provider: "auth0", Sub: "xxx"}.IsReearth()) } -func TestGenReearthSub(t *testing.T) { +func TestNewReearthAuth(t *testing.T) { uid := id.NewUserID() tests := []struct { name string input string - want *Auth + want Auth }{ { name: "should return reearth sub", input: uid.String(), - want: &Auth{ + want: Auth{ Provider: "reearth", Sub: "reearth|" + uid.String(), }, @@ -98,8 +76,30 @@ func TestGenReearthSub(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := GenReearthSub(tt.input) + got := NewReearthAuth(tt.input) assert.Equal(t, tt.want, got) }) } } + +func TestAuths(t *testing.T) { + auths := Auths{ + {Provider: "x", Sub: "a"}, + {Provider: "y", Sub: "b"}, + } + + assert.True(t, auths.Has("a")) + assert.False(t, auths.Has("y")) + assert.True(t, auths.HasProvider("x")) + assert.False(t, auths.HasProvider("b")) + assert.Equal(t, &Auth{Provider: "y", Sub: "b"}, auths.Get("b")) + assert.Nil(t, auths.Get("x")) + assert.Equal(t, &Auth{Provider: "x", Sub: "a"}, auths.GetByProvider("x")) + assert.Nil(t, auths.GetByProvider("b")) + assert.Equal(t, append(auths, Auth{Provider: "z", Sub: "c"}), auths.Add(Auth{Provider: "z", Sub: "c"})) + assert.Equal(t, auths, auths.Add(Auth{Provider: "z", Sub: "a"})) + assert.Equal(t, Auths{{Provider: "y", Sub: "b"}}, slices.Clone(auths).Remove("a")) + assert.Equal(t, auths, auths.Remove("c")) + assert.Equal(t, Auths{{Provider: "y", Sub: "b"}}, auths.RemoveByProvider("x")) + assert.Equal(t, auths, auths.RemoveByProvider("z")) +} diff --git a/server/pkg/user/builder.go b/server/pkg/user/builder.go index 60ed3cc449..80704aa983 100644 --- a/server/pkg/user/builder.go +++ b/server/pkg/user/builder.go @@ -18,6 +18,9 @@ func (b *Builder) Build() (*User, error) { if b.u.id.IsNil() { return nil, ErrInvalidID } + if b.u.workspace.IsNil() { + return nil, ErrInvalidID + } if b.u.theme == "" { b.u.theme = ThemeDefault } diff --git a/server/pkg/user/builder_test.go b/server/pkg/user/builder_test.go index dab75679c5..452f90ce9c 100644 --- a/server/pkg/user/builder_test.go +++ b/server/pkg/user/builder_test.go @@ -10,35 +10,35 @@ import ( func TestBuilder_ID(t *testing.T) { uid := NewID() - b := New().ID(uid).Email("aaa@bbb.com").MustBuild() + b := New().ID(uid).Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild() assert.Equal(t, uid, b.ID()) assert.Nil(t, b.passwordReset) } func TestBuilder_Name(t *testing.T) { - b := New().NewID().Name("xxx").Email("aaa@bbb.com").MustBuild() + b := New().NewID().Name("xxx").Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild() assert.Equal(t, "xxx", b.Name()) } func TestBuilder_NewID(t *testing.T) { - b := New().NewID().Email("aaa@bbb.com").MustBuild() + b := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild() assert.NotNil(t, b.ID()) } func TestBuilder_Workspace(t *testing.T) { - tid := NewWorkspaceID() - b := New().NewID().Email("aaa@bbb.com").Workspace(tid).MustBuild() - assert.Equal(t, tid, b.Workspace()) + wid := NewWorkspaceID() + b := New().NewID().Email("aaa@bbb.com").Workspace(wid).MustBuild() + assert.Equal(t, wid, b.Workspace()) } func TestBuilder_Auths(t *testing.T) { - b := New().NewID().Email("aaa@bbb.com").Auths([]Auth{ + b := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).Auths([]Auth{ { Provider: "xxx", Sub: "aaa", }, }).MustBuild() - assert.Equal(t, []Auth{ + assert.Equal(t, Auths{ { Provider: "xxx", Sub: "aaa", @@ -47,13 +47,13 @@ func TestBuilder_Auths(t *testing.T) { } func TestBuilder_Email(t *testing.T) { - b := New().NewID().Email("xx@yy.zz").MustBuild() + b := New().NewID().Email("xx@yy.zz").Workspace(NewWorkspaceID()).MustBuild() assert.Equal(t, "xx@yy.zz", b.Email()) } func TestBuilder_Lang(t *testing.T) { l := language.Make("en") - b := New().NewID().Email("aaa@bbb.com").Lang(l).MustBuild() + b := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).Lang(l).MustBuild() assert.Equal(t, l, b.Lang()) } @@ -83,7 +83,7 @@ func TestBuilder_LangFrom(t *testing.T) { tc := tc t.Run(tc.Name, func(t *testing.T) { t.Parallel() - b := New().NewID().Email("aaa@bbb.com").LangFrom(tc.Lang).MustBuild() + b := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).LangFrom(tc.Lang).MustBuild() assert.Equal(t, tc.Expected, b.Lang()) }) } @@ -100,7 +100,7 @@ func TestBuilder_Build(t *testing.T) { DefaultPasswordEncoder = &NoopPasswordEncoder{} uid := NewID() - tid := NewWorkspaceID() + wid := NewWorkspaceID() pass := MustEncodedPassword("abcDEF0!") type args struct { @@ -124,7 +124,7 @@ func TestBuilder_Build(t *testing.T) { Email: "xx@yy.zz", Lang: "en", ID: uid, - Workspace: tid, + Workspace: wid, PasswordBin: pass, Auths: []Auth{ { @@ -135,7 +135,7 @@ func TestBuilder_Build(t *testing.T) { }, Expected: &User{ id: uid, - workspace: tid, + workspace: wid, email: "xx@yy.zz", name: "xxx", password: pass, @@ -143,10 +143,24 @@ func TestBuilder_Build(t *testing.T) { lang: language.English, theme: ThemeDefault, }, - }, { - Name: "failed invalid id", - Expected: nil, - Err: ErrInvalidID, + }, + { + Name: "failed invalid id", + Err: ErrInvalidID, + }, + { + Name: "ID missing", + Args: args{ + Workspace: wid, + }, + Err: ErrInvalidID, + }, + { + Name: "workspace ID missing", + Args: args{ + ID: uid, + }, + Err: ErrInvalidID, }, } diff --git a/server/pkg/user/user.go b/server/pkg/user/user.go index bd4729b977..66618530c2 100644 --- a/server/pkg/user/user.go +++ b/server/pkg/user/user.go @@ -4,6 +4,7 @@ import ( "errors" "net/mail" + "golang.org/x/exp/slices" "golang.org/x/text/language" ) @@ -17,7 +18,7 @@ type User struct { email string password EncodedPassword workspace WorkspaceID - auths []Auth + auths Auths lang language.Tag theme Theme verification *Verification @@ -80,90 +81,44 @@ func (u *User) Verification() *Verification { return u.verification } -func (u *User) Auths() []Auth { +func (u *User) Auths() Auths { if u == nil { return nil } - return append([]Auth{}, u.auths...) + return slices.Clone(u.auths) } -func (u *User) ContainAuth(a Auth) bool { - if u == nil { - return false - } - for _, b := range u.auths { - if a == b || a.Provider == b.Provider { - return true - } - } - return false -} - -func (u *User) HasAuthProvider(p string) bool { - if u == nil { - return false - } - for _, b := range u.auths { - if b.Provider == p { - return true - } - } - return false +func (u *User) SetAuths(a Auths) { + u.auths = slices.Clone(a) } func (u *User) AddAuth(a Auth) bool { - if u == nil { - return false - } - if !u.ContainAuth(a) { - u.auths = append(u.auths, a) + auths := u.auths.Add(a) + if len(auths) != len(u.auths) { + u.auths = auths return true } return false } -func (u *User) RemoveAuth(a Auth) bool { - if u == nil || a.IsAuth0() { - return false - } - for i, b := range u.auths { - if a == b { - u.auths = append(u.auths[:i], u.auths[i+1:]...) - return true - } +func (u *User) RemoveAuth(sub string) bool { + auths := u.auths.Remove(sub) + if len(auths) != len(u.auths) { + u.auths = auths + return true } return false } -func (u *User) GetAuthByProvider(provider string) *Auth { - if u == nil || u.auths == nil { - return nil - } - for _, b := range u.auths { - if provider == b.Provider { - return &b - } - } - return nil -} - -func (u *User) RemoveAuthByProvider(provider string) bool { - if u == nil || provider == "auth0" { - return false - } - for i, b := range u.auths { - if provider == b.Provider { - u.auths = append(u.auths[:i], u.auths[i+1:]...) - return true - } +func (u *User) RemoveAuthByProvider(p string) bool { + auths := u.auths.RemoveByProvider(p) + if len(auths) != len(u.auths) { + u.auths = auths + return true } return false } -func (u *User) ClearAuths() { - u.auths = []Auth{} -} - func (u *User) SetPassword(pass string) error { p, err := NewEncodedPassword(pass) if err != nil { diff --git a/server/pkg/user/user_test.go b/server/pkg/user/user_test.go index e0ae90e52a..050075baae 100644 --- a/server/pkg/user/user_test.go +++ b/server/pkg/user/user_test.go @@ -20,7 +20,7 @@ func TestUser(t *testing.T) { Name string Email string Workspace WorkspaceID - Auths []Auth + Auths Auths Lang language.Tag } }{ @@ -31,7 +31,7 @@ func TestUser(t *testing.T) { Name("xxx"). LangFrom("en"). Email("ff@xx.zz"). - Auths([]Auth{{ + Auths(Auths{{ Provider: "aaa", Sub: "sss", }}).MustBuild(), @@ -40,14 +40,14 @@ func TestUser(t *testing.T) { Name string Email string Workspace WorkspaceID - Auths []Auth + Auths Auths Lang language.Tag }{ Id: uid, Name: "xxx", Email: "ff@xx.zz", Workspace: tid, - Auths: []Auth{{ + Auths: Auths{{ Provider: "aaa", Sub: "sss", }}, @@ -70,298 +70,89 @@ func TestUser(t *testing.T) { } } -func TestUser_AddAuth(t *testing.T) { - tests := []struct { - Name string - User *User - A Auth - Expected bool - }{ - { - Name: "nil user", - User: nil, - Expected: false, - }, - { - Name: "add new auth", - User: New().NewID().Email("aaa@bbb.com").MustBuild(), - A: Auth{ - Provider: "xxx", - Sub: "zzz", - }, - Expected: true, - }, - { - Name: "existing auth", - User: New().NewID().Email("aaa@bbb.com").Auths([]Auth{{ - Provider: "xxx", - Sub: "zzz", - }}).MustBuild(), - A: Auth{ - Provider: "xxx", - Sub: "zzz", - }, - Expected: false, - }, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.Name, func(t *testing.T) { - t.Parallel() - res := tc.User.AddAuth(tc.A) - assert.Equal(t, tc.Expected, res) - }) - } -} - -func TestUser_RemoveAuth(t *testing.T) { - tests := []struct { - Name string - User *User - A Auth - Expected bool - }{ - { - Name: "nil user", - User: nil, - Expected: false, - }, - { - Name: "remove auth0", - User: New().NewID().Email("aaa@bbb.com").MustBuild(), - A: Auth{ - Provider: "auth0", - Sub: "zzz", - }, - Expected: false, - }, - { - Name: "existing auth", - User: New().NewID().Email("aaa@bbb.com").Auths([]Auth{{ - Provider: "xxx", - Sub: "zzz", - }}).MustBuild(), - A: Auth{ - Provider: "xxx", - Sub: "zzz", - }, - Expected: true, - }, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.Name, func(t *testing.T) { - t.Parallel() - res := tc.User.RemoveAuth(tc.A) - assert.Equal(t, tc.Expected, res) - }) - } -} +func TestUser_Auths(t *testing.T) { + var u *User + assert.Equal(t, Auths(nil), u.Auths()) -func TestUser_ContainAuth(t *testing.T) { - tests := []struct { - Name string - User *User - A Auth - Expected bool - }{ - { - Name: "nil user", - User: nil, - Expected: false, - }, - { - Name: "not existing auth", - User: New().NewID().Email("aaa@bbb.com").MustBuild(), - A: Auth{ - Provider: "auth0", - Sub: "zzz", - }, - Expected: false, - }, - { - Name: "existing auth", - User: New().NewID().Email("aaa@bbb.com").Auths([]Auth{{ - Provider: "xxx", - Sub: "zzz", - }}).MustBuild(), - A: Auth{ - Provider: "xxx", - Sub: "zzz", - }, - Expected: true, - }, - } + u = New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).Auths(Auths{{ + Provider: "xxx", + Sub: "zzz", + }}).MustBuild() + assert.Equal(t, Auths{{ + Provider: "xxx", + Sub: "zzz", + }}, u.Auths()) - for _, tc := range tests { - tc := tc - t.Run(tc.Name, func(t *testing.T) { - t.Parallel() - res := tc.User.ContainAuth(tc.A) - assert.Equal(t, tc.Expected, res) - }) - } + u.Auths().Get("zzz").Provider = "yyy" + // should not change + assert.Equal(t, Auths{{ + Provider: "xxx", + Sub: "zzz", + }}, u.Auths()) } -func TestUser_HasAuthProvider(t *testing.T) { - tests := []struct { - Name string - User *User - P string - Expected bool - }{ - { - Name: "nil user", - User: nil, - Expected: false, - }, - { - Name: "not existing auth", - User: New().NewID().Email("aaa@bbb.com").MustBuild(), - P: "auth0", - Expected: false, - }, - { - Name: "existing auth", - User: New().NewID().Email("aaa@bbb.com").Auths([]Auth{{ - Provider: "xxx", - Sub: "zzz", - }}).MustBuild(), - P: "xxx", - Expected: true, - }, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.Name, func(t *testing.T) { - t.Parallel() - res := tc.User.HasAuthProvider(tc.P) - assert.Equal(t, tc.Expected, res) - }) - } +func TestUser_SetAuths(t *testing.T) { + u := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).Auths(Auths{{ + Provider: "xxx", + Sub: "zzz", + }}).MustBuild() + u.SetAuths(nil) + assert.Equal(t, 0, len(u.Auths())) } -func TestUser_RemoveAuthByProvider(t *testing.T) { - tests := []struct { - Name string - User *User - Provider string - Expected bool - }{ - { - Name: "nil user", - User: nil, - Expected: false, - }, - { - Name: "remove auth0", - User: New().NewID().Email("aaa@bbb.com").MustBuild(), - Provider: "auth0", - Expected: false, - }, - { - Name: "existing auth", - User: New().NewID().Email("aaa@bbb.com").Auths([]Auth{{ - Provider: "xxx", - Sub: "zzz", - }}).MustBuild(), - Provider: "xxx", - Expected: true, - }, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.Name, func(t *testing.T) { - t.Parallel() - res := tc.User.RemoveAuthByProvider(tc.Provider) - assert.Equal(t, tc.Expected, res) - }) - } +func TestUser_AddAuth(t *testing.T) { + u := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild() + assert.True(t, u.AddAuth(Auth{Provider: "a", Sub: "zzz"})) + assert.False(t, u.AddAuth(Auth{Provider: "b", Sub: "zzz"})) + assert.Equal(t, Auths{{Provider: "a", Sub: "zzz"}}, u.auths) } -func TestUser_ClearAuths(t *testing.T) { - u := New().NewID().Email("aaa@bbb.com").Auths([]Auth{{ +func TestUser_RemoveAuth(t *testing.T) { + u := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).Auths(Auths{{ Provider: "xxx", Sub: "zzz", }}).MustBuild() - u.ClearAuths() - assert.Equal(t, 0, len(u.Auths())) + assert.True(t, u.RemoveAuth("zzz")) + assert.False(t, u.RemoveAuth("aaa")) + assert.Equal(t, Auths{}, u.auths) } -func TestUser_Auths(t *testing.T) { - var u *User - assert.Equal(t, []Auth(nil), u.Auths()) +func TestUser_RemoveAuthByProvider(t *testing.T) { + u := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).Auths(Auths{{ + Provider: "xxx", + Sub: "zzz", + }}).MustBuild() + assert.True(t, u.RemoveAuthByProvider("xxx")) + assert.False(t, u.RemoveAuthByProvider("xxx")) + assert.Equal(t, Auths{}, u.auths) } func TestUser_UpdateEmail(t *testing.T) { - u := New().NewID().Email("abc@abc.com").MustBuild() + u := New().NewID().Email("abc@abc.com").Workspace(NewWorkspaceID()).MustBuild() assert.NoError(t, u.UpdateEmail("abc@xyz.com")) assert.Equal(t, "abc@xyz.com", u.Email()) assert.Error(t, u.UpdateEmail("abcxyz")) } func TestUser_UpdateLang(t *testing.T) { - u := New().NewID().Email("aaa@bbb.com").MustBuild() + u := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild() u.UpdateLang(language.Make("en")) assert.Equal(t, language.Make("en"), u.Lang()) } func TestUser_UpdateWorkspace(t *testing.T) { tid := NewWorkspaceID() - u := New().NewID().Email("aaa@bbb.com").MustBuild() + u := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild() u.UpdateWorkspace(tid) assert.Equal(t, tid, u.Workspace()) } func TestUser_UpdateName(t *testing.T) { - u := New().NewID().Email("aaa@bbb.com").MustBuild() + u := New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild() u.UpdateName("xxx") assert.Equal(t, "xxx", u.Name()) } -func TestUser_GetAuthByProvider(t *testing.T) { - testCases := []struct { - Name string - User *User - Provider string - Expected *Auth - }{ - { - Name: "existing auth", - User: New().NewID().Email("aaa@bbb.com").Auths([]Auth{{ - Provider: "xxx", - Sub: "zzz", - }}).MustBuild(), - Provider: "xxx", - Expected: &Auth{ - Provider: "xxx", - Sub: "zzz", - }, - }, - { - Name: "not existing auth", - User: New().NewID().Email("aaa@bbb.com").Auths([]Auth{{ - Provider: "xxx", - Sub: "zzz", - }}).MustBuild(), - Provider: "yyy", - Expected: nil, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.Name, func(tt *testing.T) { - tt.Parallel() - res := tc.User.GetAuthByProvider(tc.Provider) - assert.Equal(tt, tc.Expected, res) - }) - } -} - func TestUser_MatchPassword(t *testing.T) { // bcrypt is not suitable for unit tests as it requires heavy computation DefaultPasswordEncoder = &NoopPasswordEncoder{} @@ -457,19 +248,19 @@ func TestUser_SetPassword(t *testing.T) { } func TestUser_PasswordReset(t *testing.T) { - testCases := []struct { + tests := []struct { Name string User *User Expected *PasswordReset }{ { Name: "not password request", - User: New().NewID().Email("aaa@bbb.com").MustBuild(), + User: New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild(), Expected: nil, }, { Name: "create new password request over existing one", - User: New().NewID().Email("aaa@bbb.com").PasswordReset(&PasswordReset{"xzy", time.Unix(0, 0)}).MustBuild(), + User: New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).PasswordReset(&PasswordReset{"xzy", time.Unix(0, 0)}).MustBuild(), Expected: &PasswordReset{ Token: "xzy", CreatedAt: time.Unix(0, 0), @@ -477,7 +268,7 @@ func TestUser_PasswordReset(t *testing.T) { }, } - for _, tc := range testCases { + for _, tc := range tests { tc := tc t.Run(tc.Name, func(tt *testing.T) { tt.Parallel() @@ -495,13 +286,13 @@ func TestUser_SetPasswordReset(t *testing.T) { }{ { Name: "nil", - User: New().NewID().Email("aaa@bbb.com").MustBuild(), + User: New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild(), Pr: nil, Expected: nil, }, { Name: "nil", - User: New().NewID().Email("aaa@bbb.com").MustBuild(), + User: New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild(), Pr: &PasswordReset{ Token: "xyz", CreatedAt: time.Unix(1, 1), @@ -513,7 +304,7 @@ func TestUser_SetPasswordReset(t *testing.T) { }, { Name: "create new password request", - User: New().NewID().Email("aaa@bbb.com").MustBuild(), + User: New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild(), Pr: &PasswordReset{ Token: "xyz", CreatedAt: time.Unix(1, 1), @@ -525,7 +316,7 @@ func TestUser_SetPasswordReset(t *testing.T) { }, { Name: "create new password request over existing one", - User: New().NewID().Email("aaa@bbb.com").PasswordReset(&PasswordReset{"xzy", time.Now()}).MustBuild(), + User: New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).PasswordReset(&PasswordReset{"xzy", time.Now()}).MustBuild(), Pr: &PasswordReset{ Token: "xyz", CreatedAt: time.Unix(1, 1), @@ -537,13 +328,13 @@ func TestUser_SetPasswordReset(t *testing.T) { }, { Name: "remove none existing password request", - User: New().NewID().Email("aaa@bbb.com").MustBuild(), + User: New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).MustBuild(), Pr: nil, Expected: nil, }, { Name: "remove existing password request", - User: New().NewID().Email("aaa@bbb.com").PasswordReset(&PasswordReset{"xzy", time.Now()}).MustBuild(), + User: New().NewID().Email("aaa@bbb.com").Workspace(NewWorkspaceID()).PasswordReset(&PasswordReset{"xzy", time.Now()}).MustBuild(), Pr: nil, Expected: nil, }, diff --git a/server/pkg/user/userops/initializer.go b/server/pkg/user/userops/initializer.go index 6dbca32918..6c7418ef1e 100644 --- a/server/pkg/user/userops/initializer.go +++ b/server/pkg/user/userops/initializer.go @@ -3,6 +3,7 @@ package userops import ( "github.com/reearth/reearth/server/pkg/user" "github.com/reearth/reearth/server/pkg/workspace" + "github.com/samber/lo" "golang.org/x/text/language" ) @@ -32,7 +33,7 @@ func Init(p InitParams) (*user.User, *workspace.Workspace, error) { p.Theme = &t } if p.Sub == nil { - p.Sub = user.GenReearthSub(p.UserID.String()) + p.Sub = lo.ToPtr(user.NewReearthAuth(p.UserID.String())) } b := user.New(). @@ -41,7 +42,8 @@ func Init(p InitParams) (*user.User, *workspace.Workspace, error) { Email(p.Email). Auths([]user.Auth{*p.Sub}). Lang(*p.Lang). - Theme(*p.Theme) + Theme(*p.Theme). + Workspace(*p.WorkspaceID) if p.Password != nil { b = b.PasswordPlainText(*p.Password) } @@ -60,7 +62,6 @@ func Init(p InitParams) (*user.User, *workspace.Workspace, error) { if err != nil { return nil, nil, err } - u.UpdateWorkspace(t.ID()) return u, t, err }