diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml index f74dcac145..83db1c33b0 100644 --- a/.github/workflows/test-integration.yaml +++ b/.github/workflows/test-integration.yaml @@ -25,6 +25,7 @@ jobs: - TestOIDCAuthenticationPingAll - TestOIDCExpireNodesBasedOnTokenExpiry - TestOIDC024UserCreation + - TestOIDCAuthenticationWithPKCE - TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowLogoutAndRelogin - TestUserCommand diff --git a/CHANGELOG.md b/CHANGELOG.md index ffa0b10497..ce3e10e772 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -172,6 +172,7 @@ This will also affect the way you [#2261](https://github.com/juanfont/headscale/pull/2261) - Add `dns.extra_records_path` configuration option [#2262](https://github.com/juanfont/headscale/issues/2262) - Support client verify for DERP [#2046](https://github.com/juanfont/headscale/pull/2046) +- Add PKCE Verifier for OIDC [#2314](https://github.com/juanfont/headscale/pull/2314) ## 0.23.0 (2024-09-18) diff --git a/config-example.yaml b/config-example.yaml index cb7bf4da0a..581d997dd4 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -364,6 +364,18 @@ unix_socket_permission: "0770" # allowed_users: # - alice@example.com # +# # Optional: PKCE (Proof Key for Code Exchange) configuration +# # PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow +# # by preventing authorization code interception attacks +# # See https://datatracker.ietf.org/doc/html/rfc7636 +# pkce: +# # Enable or disable PKCE support (default: false) +# enabled: false +# # PKCE method to use: +# # - plain: Use plain code verifier +# # - S256: Use SHA256 hashed code verifier (default, recommended) +# method: S256 +# # # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users # # by taking the username from the legacy user and matching it with the username # # provided by the OIDC. This is useful when migrating from legacy users to OIDC diff --git a/docs/ref/oidc.md b/docs/ref/oidc.md index 6bc4557206..9f8c3e596d 100644 --- a/docs/ref/oidc.md +++ b/docs/ref/oidc.md @@ -45,6 +45,18 @@ oidc: allowed_users: - alice@example.com + # Optional: PKCE (Proof Key for Code Exchange) configuration + # PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow + # by preventing authorization code interception attacks + # See https://datatracker.ietf.org/doc/html/rfc7636 + pkce: + # Enable or disable PKCE support (default: false) + enabled: false + # PKCE method to use: + # - plain: Use plain code verifier + # - S256: Use SHA256 hashed code verifier (default, recommended) + method: S256 + # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. # This will transform `first-name.last-name@example.com` to the user `first-name.last-name` # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 14191d23d0..35e3c77880 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -28,12 +28,14 @@ import ( ) const ( - randomByteSize = 16 + randomByteSize = 16 + defaultOAuthOptionsCount = 3 ) var ( errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params") errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback") + errNoOIDCRegistrationInfo = errors.New("could not get registration info from cache") errOIDCAllowedDomains = errors.New( "authenticated principal does not match any allowed domain", ) @@ -47,11 +49,17 @@ var ( errOIDCNodeKeyMissing = errors.New("could not get node key from cache") ) +// RegistrationInfo contains both machine key and verifier information for OIDC validation. +type RegistrationInfo struct { + MachineKey key.MachinePublic + Verifier *string +} + type AuthProviderOIDC struct { serverURL string cfg *types.OIDCConfig db *db.HSDatabase - registrationCache *zcache.Cache[string, key.MachinePublic] + registrationCache *zcache.Cache[string, RegistrationInfo] notifier *notifier.Notifier ipAlloc *db.IPAllocator polMan policy.PolicyManager @@ -87,7 +95,7 @@ func NewAuthProviderOIDC( Scopes: cfg.Scope, } - registrationCache := zcache.New[string, key.MachinePublic]( + registrationCache := zcache.New[string, RegistrationInfo]( registerCacheExpiration, registerCacheCleanup, ) @@ -157,19 +165,36 @@ func (a *AuthProviderOIDC) RegisterHandler( stateStr := hex.EncodeToString(randomBlob)[:32] - // place the node key into the state cache, so it can be retrieved later - a.registrationCache.Set( - stateStr, - machineKey, - ) + // Initialize registration info with machine key + registrationInfo := RegistrationInfo{ + MachineKey: machineKey, + } - // Add any extra parameter provided in the configuration to the Authorize Endpoint request - extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)) + extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) + // Add PKCE verification if enabled + if a.cfg.PKCE.Enabled { + verifier := oauth2.GenerateVerifier() + registrationInfo.Verifier = &verifier + + extras = append(extras, oauth2.AccessTypeOffline) + + switch a.cfg.PKCE.Method { + case types.PKCEMethodS256: + extras = append(extras, oauth2.S256ChallengeOption(verifier)) + case types.PKCEMethodPlain: + // oauth2 does not have a plain challenge option, so we add it manually + extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier)) + } + } + // Add any extra parameters from configuration for k, v := range a.cfg.ExtraParams { extras = append(extras, oauth2.SetAuthURLParam(k, v)) } + // Cache the registration info + a.registrationCache.Set(stateStr, registrationInfo) + authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...) log.Debug().Msgf("Redirecting to %s for authentication", authURL) @@ -203,7 +228,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } - idToken, err := a.extractIDToken(req.Context(), code) + idToken, err := a.extractIDToken(req.Context(), code, state) if err != nil { http.Error(writer, err.Error(), http.StatusBadRequest) return @@ -318,8 +343,21 @@ func extractCodeAndStateParamFromRequest( func (a *AuthProviderOIDC) extractIDToken( ctx context.Context, code string, + state string, ) (*oidc.IDToken, error) { - oauth2Token, err := a.oauth2Config.Exchange(ctx, code) + var exchangeOpts []oauth2.AuthCodeOption + + if a.cfg.PKCE.Enabled { + regInfo, ok := a.registrationCache.Get(state) + if !ok { + return nil, errNoOIDCRegistrationInfo + } + if regInfo.Verifier != nil { + exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)} + } + } + + oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...) if err != nil { return nil, fmt.Errorf("could not exchange code for token: %w", err) } @@ -394,7 +432,7 @@ func validateOIDCAllowedUsers( // cache. If the machine key is found, it will try retrieve the // node information from the database. func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) { - machineKey, ok := a.registrationCache.Get(state) + regInfo, ok := a.registrationCache.Get(state) if !ok { return nil, nil } @@ -403,9 +441,9 @@ func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *k // The error is not important, because if it does not // exist, then this is a new node and we will move // on to registration. - node, _ := a.db.GetNodeByMachineKey(machineKey) + node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey) - return node, &machineKey + return node, ®Info.MachineKey } // reauthenticateNode updates the node expiry in the database diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index f6c5c48a29..b462b8e928 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -26,11 +26,14 @@ import ( const ( defaultOIDCExpiryTime = 180 * 24 * time.Hour // 180 Days maxDuration time.Duration = 1<<63 - 1 + PKCEMethodPlain string = "plain" + PKCEMethodS256 string = "S256" ) var ( errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") + errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") ) type IPAllocationStrategy string @@ -162,6 +165,11 @@ type LetsEncryptConfig struct { ChallengeType string } +type PKCEConfig struct { + Enabled bool + Method string +} + type OIDCConfig struct { OnlyStartIfOIDCIsAvailable bool Issuer string @@ -176,6 +184,7 @@ type OIDCConfig struct { Expiry time.Duration UseExpiryFromToken bool MapLegacyUsers bool + PKCE PKCEConfig } type DERPConfig struct { @@ -226,6 +235,13 @@ type Tuning struct { NodeMapSessionBufferedChanSize int } +func validatePKCEMethod(method string) error { + if method != PKCEMethodPlain && method != PKCEMethodS256 { + return errInvalidPKCEMethod + } + return nil +} + // LoadConfig prepares and loads the Headscale configuration into Viper. // This means it sets the default values, reads the configuration file and // environment variables, and handles deprecated configuration options. @@ -293,6 +309,8 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.use_expiry_from_token", false) viper.SetDefault("oidc.map_legacy_users", true) + viper.SetDefault("oidc.pkce.enabled", false) + viper.SetDefault("oidc.pkce.method", "S256") viper.SetDefault("logtail.enabled", false) viper.SetDefault("randomize_client_port", false) @@ -340,6 +358,12 @@ func validateServerConfig() error { // after #2170 is cleaned up // depr.fatal("oidc.strip_email_domain") + if viper.GetBool("oidc.enabled") { + if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil { + return err + } + } + depr.Log() for _, removed := range []string{ @@ -928,6 +952,10 @@ func LoadServerConfig() (*Config, error) { // after #2170 is cleaned up StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"), + PKCE: PKCEConfig{ + Enabled: viper.GetBool("oidc.pkce.enabled"), + Method: viper.GetString("oidc.pkce.method"), + }, }, LogTail: logTailConfig, diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 2245987682..e8b4999189 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -534,6 +534,86 @@ func TestOIDC024UserCreation(t *testing.T) { } } +func TestOIDCAuthenticationWithPKCE(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + baseScenario, err := NewScenario(dockertestMaxWait()) + assertNoErr(t, err) + + scenario := AuthOIDCScenario{ + Scenario: baseScenario, + } + defer scenario.ShutdownAssertNoPanics(t) + + // Single user with one node for testing PKCE flow + spec := map[string]int{ + "user1": 1, + } + + mockusers := []mockoidc.MockUser{ + oidcMockUser("user1", true), + } + + oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) + assertNoErrf(t, "failed to run mock OIDC server: %s", err) + defer scenario.mockOIDC.Close() + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, + "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE + "HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", + "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", + } + + err = scenario.CreateHeadscaleEnv( + spec, + hsic.WithTestName("oidcauthpkce"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithHostnameAsServerURL(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + ) + assertNoErrHeadscaleEnv(t, err) + + // Get all clients and verify they can connect + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + assertNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + // Verify PKCE was used in authentication + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + var listUsers []v1.User + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listUsers, + ) + assertNoErr(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) +} + func (s *AuthOIDCScenario) CreateHeadscaleEnv( users map[string]int, opts ...hsic.Option,