diff --git a/auth/authzserver/authorize_test.go b/auth/authzserver/authorize_test.go index d481283df..6091f4fcb 100644 --- a/auth/authzserver/authorize_test.go +++ b/auth/authzserver/authorize_test.go @@ -34,7 +34,7 @@ func TestAuthEndpoint(t *testing.T) { authCtx.OnOAuth2Provider().Return(oauth2Provider) cookieManager := &mocks.CookieHandler{} - cookieManager.OnSetAuthCodeCookie(req.Context(), w, originalURL).Return(nil) + cookieManager.OnSetAuthCodeCookie(req.Context(), req, w, originalURL).Return(nil) authCtx.OnCookieManager().Return(cookieManager) authEndpoint(authCtx, w, req) @@ -57,7 +57,7 @@ func TestAuthEndpoint(t *testing.T) { authCtx.OnOAuth2Provider().Return(oauth2Provider) cookieManager := &mocks.CookieHandler{} - cookieManager.OnSetAuthCodeCookie(req.Context(), w, originalURL).Return(fmt.Errorf("failure injection")) + cookieManager.OnSetAuthCodeCookie(req.Context(), req, w, originalURL).Return(fmt.Errorf("failure injection")) authCtx.OnCookieManager().Return(cookieManager) authEndpoint(authCtx, w, req) diff --git a/auth/config/config.go b/auth/config/config.go index f131df036..03e81d27a 100644 --- a/auth/config/config.go +++ b/auth/config/config.go @@ -1,7 +1,6 @@ package config import ( - "net/http" "net/url" "time" @@ -74,9 +73,9 @@ var ( "profile", }, }, - CookieSetting: &CookieSettings{ - CoverSubdomains: false, - SameSite: http.SameSiteDefaultMode, + CookieSetting: CookieSettings{ + DomainMatchPolicy: DomainMatchExact, + SameSitePolicy: SameSiteDefaultMode, }, }, AppAuth: OAuth2Options{ @@ -217,14 +216,32 @@ type UserAuthConfig struct { // Possibly add basicAuth & SAML/p support. // Secret names, defaults are set in DefaultConfig variable above but are possible to override through configs. - CookieHashKeySecretName string `json:"cookieHashKeySecretName" pflag:",OPTIONAL: Secret name to use for cookie hash key."` - CookieBlockKeySecretName string `json:"cookieBlockKeySecretName" pflag:",OPTIONAL: Secret name to use for cookie block key."` - CookieSetting *CookieSettings `json:"cookieSetting" pflag:", settings used by cookies created for user auth"` + CookieHashKeySecretName string `json:"cookieHashKeySecretName" pflag:",OPTIONAL: Secret name to use for cookie hash key."` + CookieBlockKeySecretName string `json:"cookieBlockKeySecretName" pflag:",OPTIONAL: Secret name to use for cookie block key."` + CookieSetting CookieSettings `json:"cookieSetting" pflag:", settings used by cookies created for user auth"` } +//go:generate enumer --type=DomainMatch --trimprefix=DomainMatch -json +type DomainMatch int + +const ( + DomainMatchExact DomainMatch = iota + DomainMatchSubdomains +) + +//go:generate enumer --type=SameSite --trimprefix=SameSite -json +type SameSite int + +const ( + SameSiteDefaultMode SameSite = iota + SameSiteLaxMode + SameSiteStrictMode + SameSiteNoneMode +) + type CookieSettings struct { - SameSite http.SameSite `json:"sameSite" pflag:",OPTIONAL: Allows you to declare if your cookie should be restricted to a first-party or same-site context."` - CoverSubdomains bool `json:"coverSubDomains" pflag:",OPTIONAL: Allow subdomain access to the created cookies by setting the domain attribute."` + SameSitePolicy SameSite `json:"sameSitePolicy" pflag:",OPTIONAL: Allows you to declare if your cookie should be restricted to a first-party or same-site context.Wrapper around http.SameSite."` + DomainMatchPolicy DomainMatch `json:"domainMatchPolicy" pflag:",OPTIONAL: Allow subdomain access to the created cookies by setting the domain attribute or do an exact match on domain."` } type OpenIDOptions struct { diff --git a/auth/config/config_flags.go b/auth/config/config_flags.go index cd6223778..4be9d6a0a 100755 --- a/auth/config/config_flags.go +++ b/auth/config/config_flags.go @@ -62,6 +62,8 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.StringSlice(fmt.Sprintf("%v%v", prefix, "userAuth.openId.scopes"), DefaultConfig.UserAuth.OpenID.Scopes, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "userAuth.cookieHashKeySecretName"), DefaultConfig.UserAuth.CookieHashKeySecretName, "OPTIONAL: Secret name to use for cookie hash key.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "userAuth.cookieBlockKeySecretName"), DefaultConfig.UserAuth.CookieBlockKeySecretName, "OPTIONAL: Secret name to use for cookie block key.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "userAuth.cookieSetting.sameSitePolicy"), DefaultConfig.UserAuth.CookieSetting.SameSitePolicy.String(), "OPTIONAL: Allows you to declare if your cookie should be restricted to a first-party or same-site context.Wrapper around http.SameSite.") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "userAuth.cookieSetting.domainMatchPolicy"), DefaultConfig.UserAuth.CookieSetting.DomainMatchPolicy.String(), "OPTIONAL: Allow subdomain access to the created cookies by setting the domain attribute or do an exact match on domain.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.selfAuthServer.issuer"), DefaultConfig.AppAuth.SelfAuthServer.Issuer, "Defines the issuer to use when issuing and validating tokens. The default value is https:///") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.selfAuthServer.accessTokenLifespan"), DefaultConfig.AppAuth.SelfAuthServer.AccessTokenLifespan.String(), "Defines the lifespan of issued access tokens.") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "appAuth.selfAuthServer.refreshTokenLifespan"), DefaultConfig.AppAuth.SelfAuthServer.RefreshTokenLifespan.String(), "Defines the lifespan of issued access tokens.") diff --git a/auth/config/config_flags_test.go b/auth/config/config_flags_test.go index 8cd44ab68..b2e601e7a 100755 --- a/auth/config/config_flags_test.go +++ b/auth/config/config_flags_test.go @@ -267,6 +267,34 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_userAuth.cookieSetting.sameSitePolicy", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("userAuth.cookieSetting.sameSitePolicy", testValue) + if vString, err := cmdFlags.GetString("userAuth.cookieSetting.sameSitePolicy"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.UserAuth.CookieSetting.SameSitePolicy) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_userAuth.cookieSetting.domainMatchPolicy", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("userAuth.cookieSetting.domainMatchPolicy", testValue) + if vString, err := cmdFlags.GetString("userAuth.cookieSetting.domainMatchPolicy"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.UserAuth.CookieSetting.DomainMatchPolicy) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_appAuth.selfAuthServer.issuer", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/auth/config/domainmatch_enumer.go b/auth/config/domainmatch_enumer.go new file mode 100644 index 000000000..39087dc09 --- /dev/null +++ b/auth/config/domainmatch_enumer.go @@ -0,0 +1,68 @@ +// Code generated by "enumer --type=DomainMatch --trimprefix=DomainMatch -json"; DO NOT EDIT. + +// +package config + +import ( + "encoding/json" + "fmt" +) + +const _DomainMatchName = "ExactSubdomains" + +var _DomainMatchIndex = [...]uint8{0, 5, 15} + +func (i DomainMatch) String() string { + if i < 0 || i >= DomainMatch(len(_DomainMatchIndex)-1) { + return fmt.Sprintf("DomainMatch(%d)", i) + } + return _DomainMatchName[_DomainMatchIndex[i]:_DomainMatchIndex[i+1]] +} + +var _DomainMatchValues = []DomainMatch{0, 1} + +var _DomainMatchNameToValueMap = map[string]DomainMatch{ + _DomainMatchName[0:5]: 0, + _DomainMatchName[5:15]: 1, +} + +// DomainMatchString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func DomainMatchString(s string) (DomainMatch, error) { + if val, ok := _DomainMatchNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to DomainMatch values", s) +} + +// DomainMatchValues returns all values of the enum +func DomainMatchValues() []DomainMatch { + return _DomainMatchValues +} + +// IsADomainMatch returns "true" if the value is listed in the enum definition. "false" otherwise +func (i DomainMatch) IsADomainMatch() bool { + for _, v := range _DomainMatchValues { + if i == v { + return true + } + } + return false +} + +// MarshalJSON implements the json.Marshaler interface for DomainMatch +func (i DomainMatch) MarshalJSON() ([]byte, error) { + return json.Marshal(i.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for DomainMatch +func (i *DomainMatch) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("DomainMatch should be a string, got %s", data) + } + + var err error + *i, err = DomainMatchString(s) + return err +} diff --git a/auth/config/samesite_enumer.go b/auth/config/samesite_enumer.go new file mode 100644 index 000000000..af9bfdf6c --- /dev/null +++ b/auth/config/samesite_enumer.go @@ -0,0 +1,70 @@ +// Code generated by "enumer --type=SameSite --trimprefix=SameSite -json"; DO NOT EDIT. + +// +package config + +import ( + "encoding/json" + "fmt" +) + +const _SameSiteName = "DefaultModeLaxModeStrictModeNoneMode" + +var _SameSiteIndex = [...]uint8{0, 11, 18, 28, 36} + +func (i SameSite) String() string { + if i < 0 || i >= SameSite(len(_SameSiteIndex)-1) { + return fmt.Sprintf("SameSite(%d)", i) + } + return _SameSiteName[_SameSiteIndex[i]:_SameSiteIndex[i+1]] +} + +var _SameSiteValues = []SameSite{0, 1, 2, 3} + +var _SameSiteNameToValueMap = map[string]SameSite{ + _SameSiteName[0:11]: 0, + _SameSiteName[11:18]: 1, + _SameSiteName[18:28]: 2, + _SameSiteName[28:36]: 3, +} + +// SameSiteString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func SameSiteString(s string) (SameSite, error) { + if val, ok := _SameSiteNameToValueMap[s]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to SameSite values", s) +} + +// SameSiteValues returns all values of the enum +func SameSiteValues() []SameSite { + return _SameSiteValues +} + +// IsASameSite returns "true" if the value is listed in the enum definition. "false" otherwise +func (i SameSite) IsASameSite() bool { + for _, v := range _SameSiteValues { + if i == v { + return true + } + } + return false +} + +// MarshalJSON implements the json.Marshaler interface for SameSite +func (i SameSite) MarshalJSON() ([]byte, error) { + return json.Marshal(i.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for SameSite +func (i *SameSite) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("SameSite should be a string, got %s", data) + } + + var err error + *i, err = SameSiteString(s) + return err +} diff --git a/auth/cookie_manager.go b/auth/cookie_manager.go index 37681f7c7..32435e5a9 100644 --- a/auth/cookie_manager.go +++ b/auth/cookie_manager.go @@ -5,22 +5,22 @@ import ( "encoding/base64" "encoding/json" "fmt" - "github.com/flyteorg/flyteadmin/auth/config" "net/http" "time" + "github.com/flyteorg/flyteadmin/auth/config" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" - "github.com/flyteorg/flytestdlib/errors" "github.com/flyteorg/flytestdlib/logger" + "golang.org/x/oauth2" ) type CookieManager struct { - hashKey []byte - blockKey []byte - coverSubDomains bool - sameSite http.SameSite + hashKey []byte + blockKey []byte + domainMatchPolicy config.DomainMatch + sameSitePolicy config.SameSite } const ( @@ -31,7 +31,7 @@ const ( ErrNoIDToken errors.ErrorCode = "NO_ID_TOKEN_IN_RESPONSE" ) -func NewCookieManager(ctx context.Context, hashKeyEncoded, blockKeyEncoded string, cookieSettings *config.CookieSettings) (CookieManager, error) { +func NewCookieManager(ctx context.Context, hashKeyEncoded, blockKeyEncoded string, cookieSettings config.CookieSettings) (CookieManager, error) { logger.Infof(ctx, "Instantiating cookie manager") hashKey, err := base64.RawStdEncoding.DecodeString(hashKeyEncoded) @@ -45,10 +45,10 @@ func NewCookieManager(ctx context.Context, hashKeyEncoded, blockKeyEncoded strin } return CookieManager{ - hashKey: hashKey, - blockKey: blockKey, - coverSubDomains: cookieSettings.CoverSubdomains, - sameSite: cookieSettings.SameSite, + hashKey: hashKey, + blockKey: blockKey, + domainMatchPolicy: cookieSettings.DomainMatchPolicy, + sameSitePolicy: cookieSettings.SameSitePolicy, }, nil } @@ -86,7 +86,7 @@ func (c CookieManager) SetUserInfoCookie(ctx context.Context, request *http.Requ return fmt.Errorf("failed to marshal user info to store in a cookie. Error: %w", err) } - userInfoCookie, err := NewSecureCookie(userInfoCookieName, string(raw), c.hashKey, c.blockKey, c.getCookieDomain(request), c.sameSite) + userInfoCookie, err := NewSecureCookie(userInfoCookieName, string(raw), c.hashKey, c.blockKey, c.getCookieDomain(request), c.getHTTPSameSitePolicy()) if err != nil { logger.Errorf(ctx, "Error generating encrypted user info cookie %s", err) return err @@ -124,7 +124,7 @@ func (c CookieManager) RetrieveAuthCodeRequest(ctx context.Context, request *htt } func (c CookieManager) SetAuthCodeCookie(ctx context.Context, request *http.Request, writer http.ResponseWriter, authRequestURL string) error { - authCodeCookie, err := NewSecureCookie(authCodeCookieName, authRequestURL, c.hashKey, c.blockKey, c.getCookieDomain(request), c.sameSite) + authCodeCookie, err := NewSecureCookie(authCodeCookieName, authRequestURL, c.hashKey, c.blockKey, c.getCookieDomain(request), c.getHTTPSameSitePolicy()) if err != nil { logger.Errorf(ctx, "Error generating encrypted accesstoken cookie %s", err) return err @@ -141,7 +141,7 @@ func (c CookieManager) SetTokenCookies(ctx context.Context, request *http.Reques return errors.Errorf(ErrTokenNil, "Attempting to set cookies with nil token") } - atCookie, err := NewSecureCookie(accessTokenCookieName, token.AccessToken, c.hashKey, c.blockKey, c.getCookieDomain(request), c.sameSite) + atCookie, err := NewSecureCookie(accessTokenCookieName, token.AccessToken, c.hashKey, c.blockKey, c.getCookieDomain(request), c.getHTTPSameSitePolicy()) if err != nil { logger.Errorf(ctx, "Error generating encrypted accesstoken cookie %s", err) return err @@ -150,7 +150,7 @@ func (c CookieManager) SetTokenCookies(ctx context.Context, request *http.Reques http.SetCookie(writer, &atCookie) if idTokenRaw, converted := token.Extra(idTokenExtra).(string); converted { - idCookie, err := NewSecureCookie(idTokenCookieName, idTokenRaw, c.hashKey, c.blockKey, c.getCookieDomain(request), c.sameSite) + idCookie, err := NewSecureCookie(idTokenCookieName, idTokenRaw, c.hashKey, c.blockKey, c.getCookieDomain(request), c.getHTTPSameSitePolicy()) if err != nil { logger.Errorf(ctx, "Error generating encrypted id token cookie %s", err) return err @@ -164,7 +164,7 @@ func (c CookieManager) SetTokenCookies(ctx context.Context, request *http.Reques // Set the refresh cookie if there is a refresh token if token.RefreshToken != "" { - refreshCookie, err := NewSecureCookie(refreshTokenCookieName, token.RefreshToken, c.hashKey, c.blockKey, cookieDomain, c.sameSite) + refreshCookie, err := NewSecureCookie(refreshTokenCookieName, token.RefreshToken, c.hashKey, c.blockKey, c.getCookieDomain(request), c.getHTTPSameSitePolicy()) if err != nil { logger.Errorf(ctx, "Error generating encrypted refresh token cookie %s", err) return err @@ -200,8 +200,23 @@ func (c CookieManager) DeleteCookies(ctx context.Context, writer http.ResponseWr http.SetCookie(writer, getLogoutRefreshCookie()) } +func (c CookieManager) getHTTPSameSitePolicy() http.SameSite { + httpSameSite := http.SameSiteDefaultMode + switch c.sameSitePolicy { + case config.SameSiteDefaultMode: + httpSameSite = http.SameSiteDefaultMode + case config.SameSiteLaxMode: + httpSameSite = http.SameSiteLaxMode + case config.SameSiteStrictMode: + httpSameSite = http.SameSiteStrictMode + case config.SameSiteNoneMode: + httpSameSite = http.SameSiteNoneMode + } + return httpSameSite +} + func (c CookieManager) getCookieDomain(request *http.Request) string { - if !c.coverSubDomains { + if c.domainMatchPolicy == config.DomainMatchExact { return "" } return fmt.Sprintf(".%s", request.URL.Hostname()) diff --git a/auth/cookie_manager_test.go b/auth/cookie_manager_test.go index 28b377e07..b7f69ca64 100644 --- a/auth/cookie_manager_test.go +++ b/auth/cookie_manager_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "github.com/flyteorg/flyteadmin/auth/config" + "github.com/stretchr/testify/assert" "golang.org/x/oauth2" ) @@ -17,8 +19,11 @@ func TestCookieManager_SetTokenCookies(t *testing.T) { // These were generated for unit testing only. hashKeyEncoded := "wG4pE1ccdw/pHZ2ml8wrD5VJkOtLPmBpWbKHmezWXktGaFbRoAhXidWs8OpbA3y7N8vyZhz1B1E37+tShWC7gA" //nolint:goconst blockKeyEncoded := "afyABVgGOvWJFxVyOvCWCupoTn6BkNl4SOHmahho16Q" //nolint:goconst - - manager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded) + cookieSetting := config.CookieSettings{ + SameSitePolicy: config.SameSiteDefaultMode, + DomainMatchPolicy: config.DomainMatchSubdomains, + } + manager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded, cookieSetting) assert.NoError(t, err) token := &oauth2.Token{ @@ -31,7 +36,9 @@ func TestCookieManager_SetTokenCookies(t *testing.T) { }) w := httptest.NewRecorder() - err = manager.SetTokenCookies(ctx, w, token) + req, err := http.NewRequest("GET", "/api/v1/projects", nil) + assert.NoError(t, err) + err = manager.SetTokenCookies(ctx, req, w, token) assert.NoError(t, err) fmt.Println(w.Header().Get("Set-Cookie")) c := w.Result().Cookies() @@ -46,7 +53,12 @@ func TestCookieManager_RetrieveTokenValues(t *testing.T) { hashKeyEncoded := "wG4pE1ccdw/pHZ2ml8wrD5VJkOtLPmBpWbKHmezWXktGaFbRoAhXidWs8OpbA3y7N8vyZhz1B1E37+tShWC7gA" //nolint:goconst blockKeyEncoded := "afyABVgGOvWJFxVyOvCWCupoTn6BkNl4SOHmahho16Q" //nolint:goconst - manager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded) + cookieSetting := config.CookieSettings{ + SameSitePolicy: config.SameSiteDefaultMode, + DomainMatchPolicy: config.DomainMatchSubdomains, + } + + manager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded, cookieSetting) assert.NoError(t, err) token := &oauth2.Token{ @@ -59,11 +71,13 @@ func TestCookieManager_RetrieveTokenValues(t *testing.T) { }) w := httptest.NewRecorder() - err = manager.SetTokenCookies(ctx, w, token) + req, err := http.NewRequest("GET", "/api/v1/projects", nil) + assert.NoError(t, err) + err = manager.SetTokenCookies(ctx, req, w, token) assert.NoError(t, err) cookies := w.Result().Cookies() - req, err := http.NewRequest("GET", "/api/v1/projects", nil) + req, err = http.NewRequest("GET", "/api/v1/projects", nil) assert.NoError(t, err) for _, c := range cookies { req.AddCookie(c) @@ -92,8 +106,12 @@ func TestCookieManager_DeleteCookies(t *testing.T) { // These were generated for unit testing only. hashKeyEncoded := "wG4pE1ccdw/pHZ2ml8wrD5VJkOtLPmBpWbKHmezWXktGaFbRoAhXidWs8OpbA3y7N8vyZhz1B1E37+tShWC7gA" //nolint:goconst blockKeyEncoded := "afyABVgGOvWJFxVyOvCWCupoTn6BkNl4SOHmahho16Q" //nolint:goconst + cookieSetting := config.CookieSettings{ + SameSitePolicy: config.SameSiteDefaultMode, + DomainMatchPolicy: config.DomainMatchSubdomains, + } - manager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded) + manager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded, cookieSetting) assert.NoError(t, err) w := httptest.NewRecorder() diff --git a/auth/cookie_test.go b/auth/cookie_test.go index 64573c6b6..b525ba8b4 100644 --- a/auth/cookie_test.go +++ b/auth/cookie_test.go @@ -35,7 +35,7 @@ func TestSecureCookieLifecycle(t *testing.T) { fmt.Printf("Hash key: |%s| Block key: |%s|\n", base64.RawStdEncoding.EncodeToString(hashKey), base64.RawStdEncoding.EncodeToString(blockKey)) - cookie, err := NewSecureCookie("choc", "chip", hashKey, blockKey) + cookie, err := NewSecureCookie("choc", "chip", hashKey, blockKey, "localhost", http.SameSiteLaxMode) assert.NoError(t, err) value, err := ReadSecureCookie(context.Background(), cookie, hashKey, blockKey) diff --git a/auth/handlers_test.go b/auth/handlers_test.go index d86dee3a4..5a210b0a5 100644 --- a/auth/handlers_test.go +++ b/auth/handlers_test.go @@ -43,8 +43,8 @@ func setupMockedAuthContextAtEndpoint(endpoint string) *mocks.AuthenticationCont Scopes: []string{"openid", "other"}, } mockAuthCtx.OnCookieManagerMatch().Return(mockCookieHandler) - mockCookieHandler.OnSetTokenCookiesMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) - mockCookieHandler.OnSetUserInfoCookieMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockCookieHandler.OnSetTokenCookiesMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockCookieHandler.OnSetUserInfoCookieMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) mockAuthCtx.OnOAuth2ClientConfigMatch(mock.Anything).Return(&dummyOAuth2Config) return mockAuthCtx } @@ -212,7 +212,11 @@ func TestGetHTTPRequestCookieToMetadataHandler(t *testing.T) { // These were generated for unit testing only. hashKeyEncoded := "wG4pE1ccdw/pHZ2ml8wrD5VJkOtLPmBpWbKHmezWXktGaFbRoAhXidWs8OpbA3y7N8vyZhz1B1E37+tShWC7gA" //nolint:goconst blockKeyEncoded := "afyABVgGOvWJFxVyOvCWCupoTn6BkNl4SOHmahho16Q" //nolint:goconst - cookieManager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded) + cookieSetting := config.CookieSettings{ + SameSitePolicy: config.SameSiteDefaultMode, + DomainMatchPolicy: config.DomainMatchSubdomains, + } + cookieManager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded, cookieSetting) assert.NoError(t, err) mockAuthCtx := mocks.AuthenticationContext{} mockAuthCtx.OnCookieManager().Return(&cookieManager) @@ -221,11 +225,11 @@ func TestGetHTTPRequestCookieToMetadataHandler(t *testing.T) { req, err := http.NewRequest("GET", "/api/v1/projects", nil) assert.NoError(t, err) - accessTokenCookie, err := NewSecureCookie(accessTokenCookieName, "a.b.c", cookieManager.hashKey, cookieManager.blockKey) + accessTokenCookie, err := NewSecureCookie(accessTokenCookieName, "a.b.c", cookieManager.hashKey, cookieManager.blockKey, "localhost", http.SameSiteDefaultMode) assert.NoError(t, err) req.AddCookie(&accessTokenCookie) - idCookie, err := NewSecureCookie(idTokenCookieName, "a.b.c", cookieManager.hashKey, cookieManager.blockKey) + idCookie, err := NewSecureCookie(idTokenCookieName, "a.b.c", cookieManager.hashKey, cookieManager.blockKey, "localhost", http.SameSiteDefaultMode) assert.NoError(t, err) req.AddCookie(&idCookie) @@ -246,7 +250,11 @@ func TestGetHTTPRequestCookieToMetadataHandler_CustomHeader(t *testing.T) { // These were generated for unit testing only. hashKeyEncoded := "wG4pE1ccdw/pHZ2ml8wrD5VJkOtLPmBpWbKHmezWXktGaFbRoAhXidWs8OpbA3y7N8vyZhz1B1E37+tShWC7gA" //nolint:goconst blockKeyEncoded := "afyABVgGOvWJFxVyOvCWCupoTn6BkNl4SOHmahho16Q" //nolint:goconst - cookieManager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded) + cookieSetting := config.CookieSettings{ + SameSitePolicy: config.SameSiteDefaultMode, + DomainMatchPolicy: config.DomainMatchSubdomains, + } + cookieManager, err := NewCookieManager(ctx, hashKeyEncoded, blockKeyEncoded, cookieSetting) assert.NoError(t, err) mockAuthCtx := mocks.AuthenticationContext{} mockAuthCtx.On("CookieManager").Return(&cookieManager) diff --git a/auth/interfaces/cookie.go b/auth/interfaces/cookie.go index f682749e3..2ff86ab9a 100644 --- a/auth/interfaces/cookie.go +++ b/auth/interfaces/cookie.go @@ -9,6 +9,8 @@ import ( "golang.org/x/oauth2" ) +//go:generate mockery -name=CookieHandler -output=mocks/ -case=underscore + type CookieHandler interface { SetTokenCookies(ctx context.Context, request *http.Request, writer http.ResponseWriter, token *oauth2.Token) error RetrieveTokenValues(ctx context.Context, request *http.Request) (idToken, accessToken, refreshToken string, err error) diff --git a/auth/interfaces/mocks/cookie_handler.go b/auth/interfaces/mocks/cookie_handler.go index 74371b0ce..289c19123 100644 --- a/auth/interfaces/mocks/cookie_handler.go +++ b/auth/interfaces/mocks/cookie_handler.go @@ -164,8 +164,8 @@ func (_m CookieHandler_SetAuthCodeCookie) Return(_a0 error) *CookieHandler_SetAu return &CookieHandler_SetAuthCodeCookie{Call: _m.Call.Return(_a0)} } -func (_m *CookieHandler) OnSetAuthCodeCookie(ctx context.Context, writer http.ResponseWriter, authRequestURL string) *CookieHandler_SetAuthCodeCookie { - c_call := _m.On("SetAuthCodeCookie", ctx, writer, authRequestURL) +func (_m *CookieHandler) OnSetAuthCodeCookie(ctx context.Context, request *http.Request, writer http.ResponseWriter, authRequestURL string) *CookieHandler_SetAuthCodeCookie { + c_call := _m.On("SetAuthCodeCookie", ctx, request, writer, authRequestURL) return &CookieHandler_SetAuthCodeCookie{Call: c_call} } @@ -174,13 +174,13 @@ func (_m *CookieHandler) OnSetAuthCodeCookieMatch(matchers ...interface{}) *Cook return &CookieHandler_SetAuthCodeCookie{Call: c_call} } -// SetAuthCodeCookie provides a mock function with given fields: ctx, writer, authRequestURL -func (_m *CookieHandler) SetAuthCodeCookie(ctx context.Context, writer http.ResponseWriter, authRequestURL string) error { - ret := _m.Called(ctx, writer, authRequestURL) +// SetAuthCodeCookie provides a mock function with given fields: ctx, request, writer, authRequestURL +func (_m *CookieHandler) SetAuthCodeCookie(ctx context.Context, request *http.Request, writer http.ResponseWriter, authRequestURL string) error { + ret := _m.Called(ctx, request, writer, authRequestURL) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, http.ResponseWriter, string) error); ok { - r0 = rf(ctx, writer, authRequestURL) + if rf, ok := ret.Get(0).(func(context.Context, *http.Request, http.ResponseWriter, string) error); ok { + r0 = rf(ctx, request, writer, authRequestURL) } else { r0 = ret.Error(0) } @@ -196,8 +196,8 @@ func (_m CookieHandler_SetTokenCookies) Return(_a0 error) *CookieHandler_SetToke return &CookieHandler_SetTokenCookies{Call: _m.Call.Return(_a0)} } -func (_m *CookieHandler) OnSetTokenCookies(ctx context.Context, writer http.ResponseWriter, token *oauth2.Token) *CookieHandler_SetTokenCookies { - c_call := _m.On("SetTokenCookies", ctx, writer, token) +func (_m *CookieHandler) OnSetTokenCookies(ctx context.Context, request *http.Request, writer http.ResponseWriter, token *oauth2.Token) *CookieHandler_SetTokenCookies { + c_call := _m.On("SetTokenCookies", ctx, request, writer, token) return &CookieHandler_SetTokenCookies{Call: c_call} } @@ -206,13 +206,13 @@ func (_m *CookieHandler) OnSetTokenCookiesMatch(matchers ...interface{}) *Cookie return &CookieHandler_SetTokenCookies{Call: c_call} } -// SetTokenCookies provides a mock function with given fields: ctx, writer, token -func (_m *CookieHandler) SetTokenCookies(ctx context.Context, writer http.ResponseWriter, token *oauth2.Token) error { - ret := _m.Called(ctx, writer, token) +// SetTokenCookies provides a mock function with given fields: ctx, request, writer, token +func (_m *CookieHandler) SetTokenCookies(ctx context.Context, request *http.Request, writer http.ResponseWriter, token *oauth2.Token) error { + ret := _m.Called(ctx, request, writer, token) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, http.ResponseWriter, *oauth2.Token) error); ok { - r0 = rf(ctx, writer, token) + if rf, ok := ret.Get(0).(func(context.Context, *http.Request, http.ResponseWriter, *oauth2.Token) error); ok { + r0 = rf(ctx, request, writer, token) } else { r0 = ret.Error(0) } @@ -228,8 +228,8 @@ func (_m CookieHandler_SetUserInfoCookie) Return(_a0 error) *CookieHandler_SetUs return &CookieHandler_SetUserInfoCookie{Call: _m.Call.Return(_a0)} } -func (_m *CookieHandler) OnSetUserInfoCookie(ctx context.Context, writer http.ResponseWriter, userInfo *service.UserInfoResponse) *CookieHandler_SetUserInfoCookie { - c_call := _m.On("SetUserInfoCookie", ctx, writer, userInfo) +func (_m *CookieHandler) OnSetUserInfoCookie(ctx context.Context, request *http.Request, writer http.ResponseWriter, userInfo *service.UserInfoResponse) *CookieHandler_SetUserInfoCookie { + c_call := _m.On("SetUserInfoCookie", ctx, request, writer, userInfo) return &CookieHandler_SetUserInfoCookie{Call: c_call} } @@ -238,13 +238,13 @@ func (_m *CookieHandler) OnSetUserInfoCookieMatch(matchers ...interface{}) *Cook return &CookieHandler_SetUserInfoCookie{Call: c_call} } -// SetUserInfoCookie provides a mock function with given fields: ctx, writer, userInfo -func (_m *CookieHandler) SetUserInfoCookie(ctx context.Context, writer http.ResponseWriter, userInfo *service.UserInfoResponse) error { - ret := _m.Called(ctx, writer, userInfo) +// SetUserInfoCookie provides a mock function with given fields: ctx, request, writer, userInfo +func (_m *CookieHandler) SetUserInfoCookie(ctx context.Context, request *http.Request, writer http.ResponseWriter, userInfo *service.UserInfoResponse) error { + ret := _m.Called(ctx, request, writer, userInfo) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, http.ResponseWriter, *service.UserInfoResponse) error); ok { - r0 = rf(ctx, writer, userInfo) + if rf, ok := ret.Get(0).(func(context.Context, *http.Request, http.ResponseWriter, *service.UserInfoResponse) error); ok { + r0 = rf(ctx, request, writer, userInfo) } else { r0 = ret.Error(0) }