Skip to content

Commit

Permalink
support aliyun provider
Browse files Browse the repository at this point in the history
  • Loading branch information
Jing-ze committed Aug 22, 2024
1 parent 2af2ff8 commit fb4884d
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 36 deletions.
34 changes: 10 additions & 24 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ func buildSessionChain(opts *options.Options, provider providers.Provider, sessi
})
chain = chain.Append(loadSession)
provider.Data().StoredSession = ss
provider.Data().StoredSession.NeedsVerifier = provider.Data().NeedsVerifier
return chain
}

Expand Down Expand Up @@ -385,8 +386,12 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
util.SendError("Invalid authentication via OAuth2: unauthorized", rw, http.StatusForbidden)
}
}
if _, err := (*p.provider.Data().Verifier.GetKeySet()).VerifySignature(req.Context(), session.IDToken); err != nil {
(*p.provider.Data().Verifier.GetKeySet()).UpdateKeys(p.client, p.provider.Data().VerifierTimeout, updateKeysCallback)
if p.provider.Data().NeedsVerifier {
if _, err := (*p.provider.Data().Verifier.GetKeySet()).VerifySignature(req.Context(), session.IDToken); err != nil {
(*p.provider.Data().Verifier.GetKeySet()).UpdateKeys(p.client, p.provider.Data().VerifierTimeout, updateKeysCallback)
} else {
updateKeysCallback()
}
} else {
updateKeysCallback()
}
Expand All @@ -407,13 +412,12 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
case err == nil:
rw.WriteHeader(http.StatusOK)
if p.passAuthorization {
proxywasm.AddHttpRequestHeader("Authorization", fmt.Sprintf("%s %s", providers.TokenTypeBearer, session.IDToken))
proxywasm.AddHttpRequestHeader("Authorization", fmt.Sprintf("%s %s", providers.TokenTypeBearer, session.AccessToken))
}
if cookies, ok := rw.Header()[SetCookieHeader]; ok && len(cookies) > 0 {
newCookieValue := strings.Join(cookies, ",")
if p.ctx != nil {
p.ctx.SetContext(SetCookieHeader, newCookieValue)
modifyRequestCookie(req, p.CookieOptions.Name, newCookieValue)
util.Logger.Info("Authentication and session refresh successfully .")
} else {
util.Logger.Error("Set Cookie failed cause HttpContext is nil.")
Expand Down Expand Up @@ -493,7 +497,7 @@ func (p *OAuthProxy) IsAllowedRequest(req *http.Request) bool {
}

func (p *OAuthProxy) ValidateVerifier() error {
if p.provider.Data().Verifier == nil {
if p.provider.Data().Verifier == nil && p.provider.Data().NeedsVerifier {
return errors.New("Failed to obtain OpenID configuration, current OIDC plugin is not working properly.")
}
return nil
Expand All @@ -504,7 +508,7 @@ func (p *OAuthProxy) SetContext(ctx wrapper.HttpContext) {
}

func (p *OAuthProxy) SetVerifier(opts *options.Options) {
if p.provider.Data().Verifier == nil {
if p.provider.Data().Verifier == nil && oidcHandler.provider.Data().NeedsVerifier {
providers.NewVerifierFromConfig(opts.Providers[0], p.provider.Data(), p.client)
}
}
Expand Down Expand Up @@ -631,21 +635,3 @@ func redirectToLocation(rw http.ResponseWriter, location string) {
}
proxywasm.SendHttpResponse(http.StatusFound, headersMap, nil, -1)
}

func modifyRequestCookie(req *http.Request, cookieName, newValue string) {
var cookies []string
found := false
for _, cookie := range req.Cookies() {
// find specify cookie name
if cookie.Name == cookieName {
found = true
cookies = append(cookies, fmt.Sprintf("%s=%s", cookie.Name, newValue))
} else {
cookies = append(cookies, fmt.Sprintf("%s=%s", cookie.Name, cookie.Value))
}
}
if !found {
cookies = append(cookies, fmt.Sprintf("%s=%s", cookieName, newValue))
}
proxywasm.ReplaceHttpRequestHeader("Cookie", strings.Join(cookies, "; "))
}
6 changes: 1 addition & 5 deletions pkg/apis/options/legacy_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func legacyProviderDefaults() LegacyProvider {
ValidateURL: "",
Scope: "",
Prompt: "",
ApprovalPrompt: "force",
ApprovalPrompt: "",
UserIDClaim: OIDCEmailClaim,
AllowedGroups: nil,
AcrValues: "",
Expand Down Expand Up @@ -133,10 +133,6 @@ func (l *LegacyProvider) convert() (Providers, error) {
urlParams = append(urlParams, LoginURLParameter{Name: "prompt", Default: []string{l.Prompt}})
case l.ApprovalPrompt != "":
urlParams = append(urlParams, LoginURLParameter{Name: "approval_prompt", Default: []string{l.ApprovalPrompt}})
default:
// match legacy behaviour by default - if neither prompt nor approval_prompt
// specified, use approval_prompt=force
urlParams = append(urlParams, LoginURLParameter{Name: "approval_prompt", Default: []string{"force"}})
}

provider.LoginURLParameters = urlParams
Expand Down
2 changes: 2 additions & 0 deletions pkg/apis/options/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ type ProviderType string
const (
// OIDCProvider is the provider type for OIDC
OIDCProvider ProviderType = "oidc"

AliyunProvider ProviderType = "aliyun"
)

type OIDCOptions struct {
Expand Down
3 changes: 2 additions & 1 deletion pkg/middleware/stored_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type StoredSessionLoader struct {
refreshClient wrapper.HttpClient
refreshRequestTimeout uint32
RemoteKeySet *oidc.KeySet
NeedsVerifier bool
}

// loadSession attempts to load a session as identified by the request cookies.
Expand Down Expand Up @@ -100,7 +101,7 @@ func (s *StoredSessionLoader) loadSession(next http.Handler) http.Handler {
}
}
}
keysNeedsUpdate := (session != nil)
keysNeedsUpdate := (session != nil) && (s.NeedsVerifier)
if keysNeedsUpdate {
if _, err := (*s.RemoteKeySet).VerifySignature(req.Context(), session.IDToken); err == nil {
keysNeedsUpdate = false
Expand Down
134 changes: 134 additions & 0 deletions providers/aliyun.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package providers

import (
"context"
"fmt"
"net/http"
"net/url"

"github.com/Jing-ze/oauth2-proxy/pkg/apis/sessions"
"github.com/Jing-ze/oauth2-proxy/pkg/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
)

type AliyunProvider struct {
*ProviderData
}

const (
aliyunProviderName = "Aliyun"
aliyunDefaultScope = "openid"
)

var (
aliyunDefaultLoginURL = &url.URL{
Scheme: "https",
Host: "signin.aliyun.com",
Path: "/oauth2/v1/auth",
RawQuery: "access_type=offline",
}

aliyunDefaultRedeemURL = &url.URL{
Scheme: "https",
Host: "oauth.aliyun.com",
Path: "/v1/token",
}
)

func NewAliyunProvider(p *ProviderData) *AliyunProvider {
p.setProviderDefaults(providerDefaults{
name: aliyunProviderName,
loginURL: aliyunDefaultLoginURL,
redeemURL: aliyunDefaultRedeemURL,
profileURL: nil,
validateURL: nil,
scope: aliyunDefaultScope,
})

provider := &AliyunProvider{ProviderData: p}

return provider
}

var _ Provider = (*AliyunProvider)(nil)

func (p *AliyunProvider) Redeem(ctx context.Context, redirectURL, code, codeVerifier string, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) error {
clientSecret, err := p.GetClientSecret()
if err != nil {
return err
}
params := url.Values{}
params.Add("redirect_uri", redirectURL)
params.Add("client_id", p.ClientID)
params.Add("client_secret", clientSecret)
params.Add("code", code)
params.Add("grant_type", "authorization_code")

headers := [][2]string{{"Content-Type", "application/x-www-form-urlencoded"}}

client.Post(p.RedeemURL.String(), headers, []byte(params.Encode()), func(statusCode int, responseHeaders http.Header, responseBody []byte) {
token, err := util.UnmarshalToken(responseHeaders, responseBody)
if err != nil {
util.SendError(err.Error(), nil, http.StatusInternalServerError)
return
}
id_token, ok := token.Extra("id_token").(string)
if !ok {
util.SendError("id_token not found", nil, http.StatusInternalServerError)
return
}
session := &sessions.SessionState{
IDToken: id_token,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
}
session.CreatedAtNow()
session.SetExpiresOn(token.Expiry)

callback(session)
}, timeout)

return nil
}

func (p *AliyunProvider) RefreshSession(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) (bool, error) {
if s == nil || s.RefreshToken == "" {
return false, fmt.Errorf("refresh token is empty")
}

err := p.redeemRefreshToken(ctx, s, client, callback, timeout)
if err != nil {
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
}

return true, nil
}

func (p *AliyunProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState, client wrapper.HttpClient, callback func(args ...interface{}), timeout uint32) error {
clientSecret, err := p.GetClientSecret()
if err != nil {
return err
}
params := url.Values{}
params.Add("client_id", p.ClientID)
params.Add("client_secret", clientSecret)
params.Add("refresh_token", s.RefreshToken)
params.Add("grant_type", "refresh_token")

headers := [][2]string{{"Content-Type", "application/x-www-form-urlencoded"}}

client.Post(p.RedeemURL.String(), headers, []byte(params.Encode()), func(statusCode int, responseHeaders http.Header, responseBody []byte) {
token, err := util.UnmarshalToken(responseHeaders, responseBody)
if err != nil {
util.SendError(err.Error(), nil, http.StatusInternalServerError)
return
}
s.AccessToken = token.AccessToken
s.CreatedAtNow()
s.SetExpiresOn(token.Expiry)

callback(s, true)
}, timeout)

return nil
}
1 change: 1 addition & 0 deletions providers/provider_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type ProviderData struct {
EmailClaim string
GroupsClaim string
Verifier internaloidc.IDTokenVerifier
NeedsVerifier bool
SkipClaimsFromProfileURL bool

// Universal Group authorization data structure
Expand Down
17 changes: 11 additions & 6 deletions providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,15 @@ func NewProvider(providerConfig options.Provider) (Provider, error) {
switch providerConfig.Type {
case options.OIDCProvider:
return NewOIDCProvider(providerData, providerConfig.OIDCConfig), nil
case options.AliyunProvider:
return NewAliyunProvider(providerData), nil
default:
return nil, fmt.Errorf("unknown provider type %q", providerConfig.Type)
}
}

func NewVerifierFromConfig(providerConfig options.Provider, p *ProviderData, client wrapper.HttpClient) error {

needsVerifier, err := providerRequiresOIDCProviderVerifier(providerConfig.Type)
if err != nil {
return err
}
if needsVerifier {
if p.NeedsVerifier {
verifierOptions := internaloidc.ProviderVerifierOptions{
AudienceClaims: providerConfig.OIDCConfig.AudienceClaims,
ClientID: providerConfig.ClientID,
Expand Down Expand Up @@ -104,6 +101,12 @@ func newProviderDataFromConfig(providerConfig options.Provider) (*ProviderData,
VerifierTimeout: providerConfig.OIDCConfig.VerifierRequestTimeout,
}

needsVerifier, err := providerRequiresOIDCProviderVerifier(providerConfig.Type)
if err != nil {
return nil, err
}
p.NeedsVerifier = needsVerifier

errs := providerConfigInfoCheck(providerConfig, p)
// handle LoginURLParameters
errs = append(errs, p.compileLoginParams(providerConfig.LoginURLParameters)...)
Expand Down Expand Up @@ -155,6 +158,8 @@ func providerRequiresOIDCProviderVerifier(providerType options.ProviderType) (bo
switch providerType {
case options.OIDCProvider:
return true, nil
case options.AliyunProvider:
return false, nil
default:
return false, fmt.Errorf("unknown provider type: %s", providerType)
}
Expand Down

0 comments on commit fb4884d

Please sign in to comment.