From f9272373207edb68b27f0642e4b9d7a086108963 Mon Sep 17 00:00:00 2001 From: Zach Brown Date: Tue, 2 Jan 2018 22:15:01 -0500 Subject: [PATCH] Add support for password grant #926 --- cmd/dex/config.go | 2 + cmd/dex/serve.go | 4 + examples/config-dev.yaml | 4 +- server/handlers.go | 247 +++++++++++++++++++++++++++++++++++++++ server/oauth2.go | 1 + server/server.go | 7 ++ 6 files changed, 264 insertions(+), 1 deletion(-) diff --git a/cmd/dex/config.go b/cmd/dex/config.go index dde369783e..7f467a353a 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -94,6 +94,8 @@ type OAuth2 struct { // If specified, do not prompt the user to approve client authorization. The // act of logging in implies authorization. SkipApprovalScreen bool `json:"skipApprovalScreen"` + // This is the connector that can be used for password grant + PasswordConnector string `json:"passwordConnector"` } // Web is the config format for the HTTP server. diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index dcc0c35239..ee34c25d8a 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -208,6 +208,9 @@ func serve(cmd *cobra.Command, args []string) error { if c.OAuth2.SkipApprovalScreen { logger.Infof("config skipping approval screen") } + if c.OAuth2.PasswordConnector != "" { + logger.Infof("config using password grant connector: %s", c.OAuth2.PasswordConnector) + } if len(c.Web.AllowedOrigins) > 0 { logger.Infof("config allowed origins: %s", c.Web.AllowedOrigins) } @@ -218,6 +221,7 @@ func serve(cmd *cobra.Command, args []string) error { serverConfig := server.Config{ SupportedResponseTypes: c.OAuth2.ResponseTypes, SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, + PasswordConnector: c.OAuth2.PasswordConnector, AllowedOrigins: c.Web.AllowedOrigins, Issuer: c.Issuer, Storage: s, diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 8778391709..cd5332c93f 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -45,7 +45,9 @@ telemetry: # Uncomment this block to control which response types dex supports. For example # the following response types enable the implicit flow for web-only clients. # Defaults to ["code"], the code flow. -# oauth2: +# Uncommend the passwordConnector to use a specific connector for password grants +#oauth2: +# passwordConnector: local # responseTypes: ["code", "token", "id_token"] # Instead of reading from an external storage, use this list of clients. diff --git a/server/handlers.go b/server/handlers.go index acbd19bf35..ccd6f8c583 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -634,6 +634,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) { s.handleAuthCode(w, r, client) case grantTypeRefreshToken: s.handleRefreshToken(w, r, client) + case grantTypePassword: + s.handlePasswordGrant(w, r, client) default: s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) } @@ -970,6 +972,251 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry) } +func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) { + + // Parse the fields + if err := r.ParseForm(); err != nil { + s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest) + return + } + q := r.Form + + // Get the clientID and secret from basic auth or form variables + clientID, clientSecret, ok := r.BasicAuth() + if ok { + var err error + if clientID, err = url.QueryUnescape(clientID); err != nil { + s.tokenErrHelper(w, errInvalidRequest, "client_id improperly encoded", http.StatusBadRequest) + return + } + if clientSecret, err = url.QueryUnescape(clientSecret); err != nil { + s.tokenErrHelper(w, errInvalidRequest, "client_secret improperly encoded", http.StatusBadRequest) + return + } + } else { + clientID = q.Get("client_id") + clientSecret = q.Get("client_secret") + } + + nonce := q.Get("nonce") + // Some clients, like the old go-oidc, provide extra whitespace. Tolerate this. + scopes := strings.Fields(q.Get("scope")) + + // Get the client from the database + client, err := s.storage.GetClient(clientID) + if err != nil { + if err == storage.ErrNotFound { + s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Invalid client_id (%q).", clientID), http.StatusBadRequest) + return + } + s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Failed to get client %v.", err), http.StatusBadRequest) + return + } + + // Parse the scopes if they are passed + var ( + unrecognized []string + invalidScopes []string + ) + hasOpenIDScope := false + for _, scope := range scopes { + switch scope { + case scopeOpenID: + hasOpenIDScope = true + case scopeOfflineAccess, scopeEmail, scopeProfile, scopeGroups, scopeFederatedID: + default: + peerID, ok := parseCrossClientScope(scope) + if !ok { + unrecognized = append(unrecognized, scope) + continue + } + + isTrusted, err := s.validateCrossClientTrust(clientID, peerID) + if err != nil { + s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Error validating cross client trust %v.", err), http.StatusBadRequest) + return + } + if !isTrusted { + invalidScopes = append(invalidScopes, scope) + } + } + } + if !hasOpenIDScope { + s.tokenErrHelper(w, errInvalidRequest, `Missing required scope(s) ["openid"].`, http.StatusBadRequest) + return + } + if len(unrecognized) > 0 { + s.tokenErrHelper(w, errInvalidRequest, fmt.Sprintf("Unrecognized scope(s) %q", unrecognized), http.StatusBadRequest) + return + } + if len(invalidScopes) > 0 { + s.tokenErrHelper(w, errInvalidRequest, fmt.Sprintf("Client can't request scope(s) %q", invalidScopes), http.StatusBadRequest) + return + } + + // Which connector + connID := s.passwordConnector + conn, err := s.getConnector(connID) + if err != nil { + s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) + return + } + + passwordConnector, ok := conn.Connector.(connector.PasswordConnector) + if !ok { + s.tokenErrHelper(w, errInvalidRequest, "Requested password connector does not correct type.", http.StatusBadRequest) + return + } + + // Login + username := q.Get("username") + password := q.Get("password") + identity, ok, err := passwordConnector.Login(r.Context(), parseScopes(scopes), username, password) + if err != nil { + s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest) + return + } + if !ok { + s.tokenErrHelper(w, errAccessDenied, "Invalid username or password", http.StatusUnauthorized) + return + } + + // Build the claims to send the id token + claims := storage.Claims{ + UserID: identity.UserID, + Username: identity.Username, + Name: identity.Name, + Email: identity.Email, + EmailVerified: identity.EmailVerified, + Groups: identity.Groups, + } + + accessToken := storage.NewID() + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, connID) + if err != nil { + s.tokenErrHelper(w, errServerError, fmt.Sprintf("failed to create ID token: %v", err), http.StatusInternalServerError) + return + } + + reqRefresh := func() bool { + // Ensure the connector supports refresh tokens. + // + // Connectors like `saml` do not implement RefreshConnector. + _, ok := conn.Connector.(connector.RefreshConnector) + if !ok { + return false + } + + for _, scope := range scopes { + if scope == scopeOfflineAccess { + return true + } + } + return false + }() + var refreshToken string + if reqRefresh { + refresh := storage.RefreshToken{ + ID: storage.NewID(), + Token: storage.NewID(), + ClientID: clientID, + ConnectorID: connID, + Scopes: scopes, + Claims: claims, + Nonce: nonce, + // ConnectorData: authCode.ConnectorData, + CreatedAt: s.now(), + LastUsed: s.now(), + } + token := &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: refresh.Token, + } + if refreshToken, err = internal.Marshal(token); err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + if err := s.storage.CreateRefresh(refresh); err != nil { + s.logger.Errorf("failed to create refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + // deleteToken determines if we need to delete the newly created refresh token + // due to a failure in updating/creating the OfflineSession object for the + // corresponding user. + var deleteToken bool + defer func() { + if deleteToken { + // Delete newly created refresh token from storage. + if err := s.storage.DeleteRefresh(refresh.ID); err != nil { + s.logger.Errorf("failed to delete refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + } + }() + + tokenRef := storage.RefreshTokenRef{ + ID: refresh.ID, + ClientID: refresh.ClientID, + CreatedAt: refresh.CreatedAt, + LastUsed: refresh.LastUsed, + } + + // Try to retrieve an existing OfflineSession object for the corresponding user. + if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + offlineSessions := storage.OfflineSessions{ + UserID: refresh.Claims.UserID, + ConnID: refresh.ConnectorID, + Refresh: make(map[string]*storage.RefreshTokenRef), + } + offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef + + // Create a new OfflineSession object for the user and add a reference object for + // the newly received refreshtoken. + if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil { + s.logger.Errorf("failed to create offline session: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + } else { + if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { + // Delete old refresh token from storage. + if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil { + s.logger.Errorf("failed to delete refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + } + + // Update existing OfflineSession obj with new RefreshTokenRef. + if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + old.Refresh[tokenRef.ClientID] = &tokenRef + return old, nil + }); err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + + } + } + + s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) +} + func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) { // TODO(ericchiang): figure out an access token story and support the user info // endpoint. For now use a random value so no one depends on the access_token diff --git a/server/oauth2.go b/server/oauth2.go index 8c9494f514..ee0046420c 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -118,6 +118,7 @@ const ( const ( grantTypeAuthorizationCode = "authorization_code" grantTypeRefreshToken = "refresh_token" + grantTypePassword = "password" ) const ( diff --git a/server/server.go b/server/server.go index a1bfcfe192..b94aad9455 100644 --- a/server/server.go +++ b/server/server.go @@ -68,6 +68,9 @@ type Config struct { // Logging in implies approval. SkipApprovalScreen bool + // If set, the server will use this connector to handle password grants + PasswordConnector string + RotateKeysAfter time.Duration // Defaults to 6 hours. IDTokensValidFor time.Duration // Defaults to 24 hours @@ -133,6 +136,9 @@ type Server struct { // If enabled, don't prompt user for approval after logging in through connector. skipApproval bool + // Used for password grant + passwordConnector string + supportedResponseTypes map[string]bool now func() time.Time @@ -191,6 +197,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) skipApproval: c.SkipApprovalScreen, now: c.Now, templates: templates, + passwordConnector: c.PasswordConnector, logger: c.Logger, }