Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context to storage's Create endpoints #2935

Merged
merged 5 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*ap
Name: req.Client.Name,
LogoURL: req.Client.LogoUrl,
}
if err := d.s.CreateClient(c); err != nil {
if err := d.s.CreateClient(ctx, c); err != nil {
if err == storage.ErrAlreadyExists {
return &api.CreateClientResp{AlreadyExists: true}, nil
}
Expand Down Expand Up @@ -177,7 +177,7 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq)
Username: req.Password.Username,
UserID: req.Password.UserId,
}
if err := d.s.CreatePassword(p); err != nil {
if err := d.s.CreatePassword(ctx, p); err != nil {
if err == storage.ErrAlreadyExists {
return &api.CreatePasswordResp{AlreadyExists: true}, nil
}
Expand Down
4 changes: 2 additions & 2 deletions server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func TestRefreshToken(t *testing.T) {
ConnectorData: []byte(`{"some":"data"}`),
}

if err := s.CreateRefresh(r); err != nil {
if err := s.CreateRefresh(ctx, r); err != nil {
t.Fatalf("create refresh token: %v", err)
}

Expand All @@ -280,7 +280,7 @@ func TestRefreshToken(t *testing.T) {
}
session.Refresh[tokenRef.ClientID] = &tokenRef

if err := s.CreateOfflineSessions(session); err != nil {
if err := s.CreateOfflineSessions(ctx, session); err != nil {
t.Fatalf("create offline session: %v", err)
}

Expand Down
8 changes: 5 additions & 3 deletions server/deviceflowhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
pollIntervalSeconds := 5

switch r.Method {
Expand Down Expand Up @@ -106,7 +107,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
Expiry: expireTime,
}

if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, deviceReq); err != nil {
s.logger.Errorf("Failed to store device request; %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
Expand All @@ -125,7 +126,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
},
}

if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
if err := s.storage.CreateDeviceToken(ctx, deviceToken); err != nil {
s.logger.Errorf("Failed to store device token %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -280,6 +281,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
switch r.Method {
case http.MethodGet:
userCode := r.FormValue("state")
Expand Down Expand Up @@ -336,7 +338,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
return
}

resp, err := s.exchangeAuthCode(w, authCode, client)
resp, err := s.exchangeAuthCode(ctx, w, authCode, client)
if err != nil {
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
Expand Down
14 changes: 7 additions & 7 deletions server/deviceflowhandlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,15 +366,15 @@ func TestDeviceCallback(t *testing.T) {
})
defer httpServer.Close()

if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil {
if err := s.storage.CreateAuthCode(ctx, tc.testAuthCode); err != nil {
t.Fatalf("failed to create auth code: %v", err)
}

if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil {
t.Fatalf("failed to create device request: %v", err)
}

if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
if err := s.storage.CreateDeviceToken(ctx, tc.testDeviceToken); err != nil {
t.Fatalf("failed to create device token: %v", err)
}

Expand All @@ -383,7 +383,7 @@ func TestDeviceCallback(t *testing.T) {
Secret: "",
RedirectURIs: []string{deviceCallbackURI},
}
if err := s.storage.CreateClient(client); err != nil {
if err := s.storage.CreateClient(ctx, client); err != nil {
t.Fatalf("failed to create client: %v", err)
}

Expand Down Expand Up @@ -660,11 +660,11 @@ func TestDeviceTokenResponse(t *testing.T) {
})
defer httpServer.Close()

if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil {
t.Fatalf("Failed to store device token %v", err)
}

if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
if err := s.storage.CreateDeviceToken(ctx, tc.testDeviceToken); err != nil {
t.Fatalf("Failed to store device token %v", err)
}

Expand Down Expand Up @@ -794,7 +794,7 @@ func TestVerifyCodeResponse(t *testing.T) {
})
defer httpServer.Close()

if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
if err := s.storage.CreateDeviceRequest(ctx, tc.testDeviceRequest); err != nil {
t.Fatalf("Failed to store device token %v", err)
}

Expand Down
38 changes: 23 additions & 15 deletions server/handlers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
Expand Down Expand Up @@ -187,6 +188,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
authReq, err := s.parseAuthorizationRequest(r)
if err != nil {
s.logger.Errorf("Failed to parse authorization request: %v", err)
Expand Down Expand Up @@ -229,7 +231,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {

// Actually create the auth request
authReq.Expiry = s.now().Add(s.authRequestsValidFor)
if err := s.storage.CreateAuthRequest(*authReq); err != nil {
if err := s.storage.CreateAuthRequest(ctx, *authReq); err != nil {
s.logger.Errorf("Failed to create authorization request: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.")
return
Expand Down Expand Up @@ -305,6 +307,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
authID := r.URL.Query().Get("state")
if authID == "" {
s.renderError(r, w, http.StatusBadRequest, "User session error.")
Expand Down Expand Up @@ -360,7 +363,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
password := r.FormValue("password")
scopes := parseScopes(authReq.Scopes)

identity, ok, err := pwConn.Login(r.Context(), scopes, username, password)
identity, ok, err := pwConn.Login(ctx, scopes, username, password)
if err != nil {
s.logger.Errorf("Failed to login user: %v", err)
s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err))
Expand All @@ -372,7 +375,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
}
return
}
redirectURL, canSkipApproval, err := s.finalizeLogin(identity, authReq, conn.Connector)
redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector)
if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")
Expand All @@ -397,6 +400,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var authID string
switch r.Method {
case http.MethodGet: // OAuth2 callback
Expand Down Expand Up @@ -471,7 +475,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
return
}

redirectURL, canSkipApproval, err := s.finalizeLogin(identity, authReq, conn.Connector)
redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector)
if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.")
Expand All @@ -494,7 +498,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)

// finalizeLogin associates the user's identity with the current AuthRequest, then returns
// the approval page's path.
func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, bool, error) {
func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, bool, error) {
claims := storage.Claims{
UserID: identity.UserID,
Username: identity.Username,
Expand Down Expand Up @@ -566,7 +570,7 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth

// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
return "", false, err
}
Expand Down Expand Up @@ -649,6 +653,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
ctx := r.Context()
if s.now().After(authReq.Expiry) {
s.renderError(r, w, http.StatusBadRequest, "User session has expired.")
return
Expand Down Expand Up @@ -701,7 +706,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
ConnectorData: authReq.ConnectorData,
PKCE: authReq.PKCE,
}
if err := s.storage.CreateAuthCode(code); err != nil {
if err := s.storage.CreateAuthCode(ctx, code); err != nil {
s.logger.Errorf("Failed to create auth code: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
Expand Down Expand Up @@ -876,6 +881,7 @@ func (s *Server) calculateCodeChallenge(codeVerifier, codeChallengeMethod string

// handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
ctx := r.Context()
code := r.PostFormValue("code")
redirectURI := r.PostFormValue("redirect_uri")

Expand Down Expand Up @@ -926,15 +932,15 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return
}

tokenResponse, err := s.exchangeAuthCode(w, authCode, client)
tokenResponse, err := s.exchangeAuthCode(ctx, w, authCode, client)
if err != nil {
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
s.writeAccessToken(w, tokenResponse)
}

func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
Expand Down Expand Up @@ -1002,7 +1008,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
return nil, err
}

if err := s.storage.CreateRefresh(refresh); err != nil {
if err := s.storage.CreateRefresh(ctx, refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err
Expand Down Expand Up @@ -1047,7 +1053,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo

// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
Expand Down Expand Up @@ -1080,6 +1086,7 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
}

func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
const prefix = "Bearer "

auth := r.Header.Get("authorization")
Expand All @@ -1091,7 +1098,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
rawIDToken := auth[len(prefix):]

verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
idToken, err := verifier.Verify(r.Context(), rawIDToken)
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
return
Expand All @@ -1108,6 +1115,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) {
ctx := r.Context()
// Parse the fields
if err := r.ParseForm(); err != nil {
s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest)
Expand Down Expand Up @@ -1177,7 +1185,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
// Login
username := q.Get("username")
password := q.Get("password")
identity, ok, err := passwordConnector.Login(r.Context(), parseScopes(scopes), username, password)
identity, ok, err := passwordConnector.Login(ctx, parseScopes(scopes), username, password)
if err != nil {
s.logger.Errorf("Failed to login user: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest)
Expand Down Expand Up @@ -1252,7 +1260,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
return
}

if err := s.storage.CreateRefresh(refresh); err != nil {
if err := s.storage.CreateRefresh(ctx, refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -1298,7 +1306,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli

// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true
Expand Down
Loading