Skip to content

Commit

Permalink
add validate client in session loader
Browse files Browse the repository at this point in the history
  • Loading branch information
Jing-ze committed Dec 26, 2024
1 parent 64ee6ed commit 10767df
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
18 changes: 10 additions & 8 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func NewOAuthProxy(opts *options.Options) (*OAuthProxy, error) {
if err != nil {
return nil, fmt.Errorf("could not build pre-auth chain: %v", err)
}
sessionChain := buildSessionChain(opts, provider, sessionStore, serviceClient)
sessionChain := buildSessionChain(opts, provider, sessionStore, serviceClient, validateServiceClient)

redirectValidator := redirect.NewValidator(opts.WhitelistDomains)
appDirector := redirect.NewAppDirector(redirect.AppDirectorOpts{
Expand Down Expand Up @@ -195,16 +195,18 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) {
return chain, nil
}

func buildSessionChain(opts *options.Options, provider providers.Provider, sessionStore sessionsapi.SessionStore, serviceClient wrapper.HttpClient) alice.Chain {
func buildSessionChain(opts *options.Options, provider providers.Provider, sessionStore sessionsapi.SessionStore, serviceClient wrapper.HttpClient, validateClient wrapper.HttpClient) alice.Chain {
chain := alice.New()

ss, loadSession := middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: provider.RefreshSession,
ValidateSession: provider.ValidateSession,
RefreshClient: serviceClient,
RefreshRequestTimeout: provider.Data().RedeemTimeout,
SessionStore: sessionStore,
RefreshPeriod: opts.Cookie.Refresh,
RefreshSession: provider.RefreshSession,
ValidateSession: provider.ValidateSession,
RefreshClient: serviceClient,
ValidateClient: validateClient,
RefreshRequestTimeout: provider.Data().RedeemTimeout,
ValidateRequestTimeout: provider.Data().RedeemTimeout,
})
chain = chain.Append(loadSession)
provider.Data().StoredSession = ss
Expand Down
29 changes: 20 additions & 9 deletions pkg/middleware/stored_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ type StoredSessionLoaderOptions struct {
// Refresh request parameters
RefreshClient wrapper.HttpClient
RefreshRequestTimeout uint32

// Validate request parameters
ValidateClient wrapper.HttpClient
ValidateRequestTimeout uint32
}

// NewStoredSessionLoader creates a new StoredSessionLoader which loads
Expand All @@ -49,12 +53,14 @@ type StoredSessionLoaderOptions struct {
// If a session was loader by a previous handler, it will not be replaced.
func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) (*StoredSessionLoader, alice.Constructor) {
ss := &StoredSessionLoader{
store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod,
sessionRefresher: opts.RefreshSession,
sessionValidator: opts.ValidateSession,
refreshClient: opts.RefreshClient,
refreshRequestTimeout: opts.RefreshRequestTimeout,
store: opts.SessionStore,
refreshPeriod: opts.RefreshPeriod,
sessionRefresher: opts.RefreshSession,
sessionValidator: opts.ValidateSession,
refreshClient: opts.RefreshClient,
refreshRequestTimeout: opts.RefreshRequestTimeout,
validateClient: opts.ValidateClient,
validateRequestTimeout: opts.ValidateRequestTimeout,
}
return ss, ss.loadSession
}
Expand All @@ -70,8 +76,13 @@ type StoredSessionLoader struct {
// Refresh request parameters
refreshClient wrapper.HttpClient
refreshRequestTimeout uint32
RemoteKeySet *oidc.KeySet
NeedsVerifier bool

// Validate request parameters
validateClient wrapper.HttpClient
validateRequestTimeout uint32

RemoteKeySet *oidc.KeySet
NeedsVerifier bool
}

// loadSession attempts to load a session as identified by the request cookies.
Expand Down Expand Up @@ -222,7 +233,7 @@ func (s *StoredSessionLoader) validateSession(ctx context.Context, session *sess
if session.IsExpired() {
return errors.New("session is expired"), false
}
valid, isAsync := s.sessionValidator(ctx, session, s.refreshClient, callback, s.refreshRequestTimeout)
valid, isAsync := s.sessionValidator(ctx, session, s.validateClient, callback, s.validateRequestTimeout)
if !valid {
return errors.New("session is invalid"), isAsync
}
Expand Down

0 comments on commit 10767df

Please sign in to comment.