diff --git a/providers/oidc.go b/providers/oidc.go index 58adca077..ccd1bbdfd 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -35,7 +35,59 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er if err != nil { return nil, fmt.Errorf("token exchange: %v", err) } + s, err = p.createSessionState(token, ctx) + if err != nil { + return nil, fmt.Errorf("unable to update session: %v", err) + } + return +} + +func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { + if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { + return false, nil + } + + origExpiration := s.ExpiresOn + err := p.redeemRefreshToken(s) + if err != nil { + return false, fmt.Errorf("unable to redeem refresh token: %v", err) + } + + fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) + return true, nil +} + +func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { + c := oauth2.Config{ + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + Endpoint: oauth2.Endpoint{ + TokenURL: p.RedeemURL.String(), + }, + } + ctx := context.Background() + t := &oauth2.Token{ + RefreshToken: s.RefreshToken, + Expiry: time.Now().Add(-time.Hour), + } + token, err := c.TokenSource(ctx, t).Token() + if err != nil { + return fmt.Errorf("failed to get token: %v", err) + } + newSession, err := p.createSessionState(token, ctx) + if err != nil { + return fmt.Errorf("unable to update session: %v", err) + } + s.AccessToken = newSession.AccessToken + s.IdToken = newSession.IdToken + s.RefreshToken = newSession.RefreshToken + s.ExpiresOn = newSession.ExpiresOn + s.Email = newSession.Email + return +} + +func (p *OIDCProvider) createSessionState(token *oauth2.Token, ctx context.Context) (*SessionState, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok { return nil, fmt.Errorf("token response did not contain an id_token") @@ -63,24 +115,11 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) } - s = &SessionState{ + return &SessionState{ AccessToken: token.AccessToken, IdToken: rawIDToken, RefreshToken: token.RefreshToken, ExpiresOn: token.Expiry, Email: claims.Email, - } - - return -} - -func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { - if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { - return false, nil - } - - origExpiration := s.ExpiresOn - s.ExpiresOn = time.Now().Add(time.Second).Truncate(time.Second) - fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration) - return false, nil + }, nil }