Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Resource Owner Password Credentials Grant #1163

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/dex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ type OAuth2 struct {
// If specified, do not prompt the user to approve client authorization. The
// act of logging in implies authorization.
SkipApprovalScreen bool `json:"skipApprovalScreen"`
// This is the connector that can be used for password grant
PasswordConnector string `json:"passwordConnector"`
}

// Web is the config format for the HTTP server.
Expand Down
4 changes: 4 additions & 0 deletions cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ func serve(cmd *cobra.Command, args []string) error {
if c.OAuth2.SkipApprovalScreen {
logger.Infof("config skipping approval screen")
}
if c.OAuth2.PasswordConnector != "" {
logger.Infof("config using password grant connector: %s", c.OAuth2.PasswordConnector)
}
if len(c.Web.AllowedOrigins) > 0 {
logger.Infof("config allowed origins: %s", c.Web.AllowedOrigins)
}
Expand All @@ -218,6 +221,7 @@ func serve(cmd *cobra.Command, args []string) error {
serverConfig := server.Config{
SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
PasswordConnector: c.OAuth2.PasswordConnector,
AllowedOrigins: c.Web.AllowedOrigins,
Issuer: c.Issuer,
Storage: s,
Expand Down
4 changes: 3 additions & 1 deletion examples/config-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ telemetry:
# Uncomment this block to control which response types dex supports. For example
# the following response types enable the implicit flow for web-only clients.
# Defaults to ["code"], the code flow.
# oauth2:
# Uncommend the passwordConnector to use a specific connector for password grants
#oauth2:
# passwordConnector: local
# responseTypes: ["code", "token", "id_token"]

# Instead of reading from an external storage, use this list of clients.
Expand Down
246 changes: 246 additions & 0 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
s.handleAuthCode(w, r, client)
case grantTypeRefreshToken:
s.handleRefreshToken(w, r, client)
case grantTypePassword:
s.handlePasswordGrant(w, r, client)
default:
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
}
Expand Down Expand Up @@ -967,6 +969,250 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
}

func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) {

// Parse the fields
if err := r.ParseForm(); err != nil {
s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest)
return
}
q := r.Form

// Get the clientID and secret from basic auth or form variables
clientID, clientSecret, ok := r.BasicAuth()
if ok {
var err error
if clientID, err = url.QueryUnescape(clientID); err != nil {
s.tokenErrHelper(w, errInvalidRequest, "client_id improperly encoded", http.StatusBadRequest)
return
}
if clientSecret, err = url.QueryUnescape(clientSecret); err != nil {
s.tokenErrHelper(w, errInvalidRequest, "client_secret improperly encoded", http.StatusBadRequest)
return
}
} else {
clientID = q.Get("client_id")
clientSecret = q.Get("client_secret")
}

nonce := q.Get("nonce")
// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
scopes := strings.Fields(q.Get("scope"))

// Get the client from the database
client, err := s.storage.GetClient(clientID)
if err != nil {
if err == storage.ErrNotFound {
s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Invalid client_id (%q).", clientID), http.StatusBadRequest)
return
}
s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Failed to get client %v.", err), http.StatusBadRequest)
return
}

// Parse the scopes if they are passed
var (
unrecognized []string
invalidScopes []string
)
hasOpenIDScope := false
for _, scope := range scopes {
switch scope {
case scopeOpenID:
hasOpenIDScope = true
case scopeOfflineAccess, scopeEmail, scopeProfile, scopeGroups:
default:
peerID, ok := parseCrossClientScope(scope)
if !ok {
unrecognized = append(unrecognized, scope)
continue
}

isTrusted, err := s.validateCrossClientTrust(clientID, peerID)
if err != nil {
s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Error validating cross client trust %v.", err), http.StatusBadRequest)
return
}
if !isTrusted {
invalidScopes = append(invalidScopes, scope)
}
}
}
if !hasOpenIDScope {
s.tokenErrHelper(w, errInvalidRequest, `Missing required scope(s) ["openid"].`, http.StatusBadRequest)
return
}
if len(unrecognized) > 0 {
s.tokenErrHelper(w, errInvalidRequest, fmt.Sprintf("Unrecognized scope(s) %q", unrecognized), http.StatusBadRequest)
return
}
if len(invalidScopes) > 0 {
s.tokenErrHelper(w, errInvalidRequest, fmt.Sprintf("Client can't request scope(s) %q", invalidScopes), http.StatusBadRequest)
return
}

// Which connector
connID := s.passwordConnector
conn, err := s.getConnector(connID)
if err != nil {
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return
}

passwordConnector, ok := conn.Connector.(connector.PasswordConnector)
if !ok {
s.tokenErrHelper(w, errInvalidRequest, "Requested password connector does not correct type.", http.StatusBadRequest)
return
}

// Login
username := q.Get("username")
password := q.Get("password")
identity, ok, err := passwordConnector.Login(r.Context(), parseScopes(scopes), username, password)
if err != nil {
s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest)
return
}
if !ok {
s.tokenErrHelper(w, errAccessDenied, "Invalid username or password", http.StatusUnauthorized)
return
}

// Build the claims to send the id token
claims := storage.Claims{
UserID: identity.UserID,
Username: identity.Username,
Email: identity.Email,
EmailVerified: identity.EmailVerified,
Groups: identity.Groups,
}

accessToken := storage.NewID()
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, connID)
if err != nil {
s.tokenErrHelper(w, errServerError, fmt.Sprintf("failed to create ID token: %v", err), http.StatusInternalServerError)
return
}

reqRefresh := func() bool {
// Ensure the connector supports refresh tokens.
//
// Connectors like `saml` do not implement RefreshConnector.
_, ok := conn.Connector.(connector.RefreshConnector)
if !ok {
return false
}

for _, scope := range scopes {
if scope == scopeOfflineAccess {
return true
}
}
return false
}()
var refreshToken string
if reqRefresh {
refresh := storage.RefreshToken{
ID: storage.NewID(),
Token: storage.NewID(),
ClientID: clientID,
ConnectorID: connID,
Scopes: scopes,
Claims: claims,
Nonce: nonce,
// ConnectorData: authCode.ConnectorData,
CreatedAt: s.now(),
LastUsed: s.now(),
}
token := &internal.RefreshToken{
RefreshId: refresh.ID,
Token: refresh.Token,
}
if refreshToken, err = internal.Marshal(token); err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

if err := s.storage.CreateRefresh(refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

// deleteToken determines if we need to delete the newly created refresh token
// due to a failure in updating/creating the OfflineSession object for the
// corresponding user.
var deleteToken bool
defer func() {
if deleteToken {
// Delete newly created refresh token from storage.
if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
}
}()

tokenRef := storage.RefreshTokenRef{
ID: refresh.ID,
ClientID: refresh.ClientID,
CreatedAt: refresh.CreatedAt,
LastUsed: refresh.LastUsed,
}

// Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
offlineSessions := storage.OfflineSessions{
UserID: refresh.Claims.UserID,
ConnID: refresh.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
}
offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef

// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
} else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}
}

// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
old.Refresh[tokenRef.ClientID] = &tokenRef
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
return
}

}
}

s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
}

func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
// TODO(ericchiang): figure out an access token story and support the user info
// endpoint. For now use a random value so no one depends on the access_token
Expand Down
1 change: 1 addition & 0 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ const (
const (
grantTypeAuthorizationCode = "authorization_code"
grantTypeRefreshToken = "refresh_token"
grantTypePassword = "password"
)

const (
Expand Down
7 changes: 7 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type Config struct {
// Logging in implies approval.
SkipApprovalScreen bool

// If set, the server will use this connector to handle password grants
PasswordConnector string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, the connector should be initialized in cmd so this would be of type connector.PasswordConnector


RotateKeysAfter time.Duration // Defaults to 6 hours.
IDTokensValidFor time.Duration // Defaults to 24 hours

Expand Down Expand Up @@ -132,6 +135,9 @@ type Server struct {
// If enabled, don't prompt user for approval after logging in through connector.
skipApproval bool

// Used for password grant
passwordConnector string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, this should be a connector.PasswordConnector


supportedResponseTypes map[string]bool

now func() time.Time
Expand Down Expand Up @@ -197,6 +203,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
supportedResponseTypes: supported,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
skipApproval: c.SkipApprovalScreen,
passwordConnector: c.PasswordConnector,
now: now,
templates: tmpls,
logger: c.Logger,
Expand Down