Skip to content

Commit

Permalink
Add failing test case for tampered pkce verifier
Browse files Browse the repository at this point in the history
  • Loading branch information
Rorical committed Dec 22, 2024
1 parent 668ad16 commit 4310d11
Show file tree
Hide file tree
Showing 2 changed files with 262 additions and 4 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
- TestOIDCExpireNodesBasedOnTokenExpiry
- TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE
- TestOIDCAuthenticationWithPKCEVerifierTampering
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin
- TestUserCommand
Expand Down
265 changes: 261 additions & 4 deletions integration/auth_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/netip"
"sort"
"strconv"
"strings"
"testing"
"time"

Expand All @@ -34,9 +35,13 @@ const (
dockerContextPath = "../."
hsicOIDCMockHashLength = 6
defaultAccessTTL = 10 * time.Minute
nodeStateRunning = "Running"
)

var errStatusCodeNotOK = errors.New("status code not OK")
var (
errStatusCodeNotOK = errors.New("status code not OK")
ErrOIDCClientCount = errors.New("client count must be 1 for OIDC scenario")
)

type AuthOIDCScenario struct {
*Scenario
Expand Down Expand Up @@ -617,12 +622,128 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
for _, client := range allClients {
status, err := client.Status()
assertNoErr(t, err)
if status.BackendState != "Running" {
if status.BackendState != nodeStateRunning {
t.Errorf("client %s is not running: %s", client.Hostname(), status.BackendState)
}
}
}

type tamperVerifierTransport struct {
base http.RoundTripper
}

func (t *tamperVerifierTransport) RoundTrip(req *http.Request) (*http.Response, error) {
log.Printf("RoundTrip: %s %s", req.Method, req.URL.String())

// For POST requests, tamper with form data
if req.Method == http.MethodPost {
log.Printf("Processing POST request")
err := req.ParseForm()
if err != nil {
log.Printf("Error parsing form: %v", err)
return nil, err
}
if verifier := req.Form.Get("code_challenge"); verifier != "" {
log.Printf("Found POST verifier: %s", verifier)
// Tamper with the verifier
req.Form.Set("code_challenge", verifier+"_tampered")
log.Printf("Modified POST verifier to: %s", req.Form.Get("code_challenge"))
// Update request body with modified form
req.Body = io.NopCloser(strings.NewReader(req.Form.Encode()))
req.ContentLength = int64(len(req.Form.Encode()))
} else {
log.Printf("No code_challenge found in POST form data")
}
}

// For GET requests, tamper with URL query parameters
if req.Method == http.MethodGet {
log.Printf("Processing GET request")
q := req.URL.Query()
if verifier := q.Get("code_challenge"); verifier != "" {
log.Printf("Found GET verifier: %s", verifier)
q.Set("code_challenge", verifier+"_tampered")
req.URL.RawQuery = q.Encode()
log.Printf("Modified URL to: %s", req.URL.String())
} else {
log.Printf("No code_challenge found in GET query params")
}
}

// Forward the request with the tampered verifier
resp, err := t.base.RoundTrip(req)
if err != nil {
log.Printf("RoundTrip error: %v", err)

return nil, err
}
log.Printf("Response status: %s", resp.Status)

return resp, err
}

func TestOIDCAuthenticationWithPKCEVerifierTampering(t *testing.T) {
IntegrationSkip(t)
t.Parallel()

baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)

scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.ShutdownAssertNoPanics(t)

// Single user with one node for testing PKCE flow
spec := map[string]int{
"user1": 1,
}

mockusers := []mockoidc.MockUser{
oidcMockUser("user1", true),
}

oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
defer scenario.mockOIDC.Close()

oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
}

// Create a transport that modifies the PKCE verifier in transit
baseTransport := &http.Transport{
// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
tamperTransport := &tamperVerifierTransport{
base: baseTransport,
}

err = scenario.CreateHeadscaleEnvWithHTTPModifier(
spec,
func(cli *http.Client) {
cli.Transport = tamperTransport
},
hsic.WithTestName("oidcauthpkce"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithHostnameAsServerURL(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
)
if err == nil {
t.Error("expected authentication to fail due to PKCE verifier tampering, but it succeeded")
} else {
log.Printf("auth got error: %s", err)
}
}

func (s *AuthOIDCScenario) CreateHeadscaleEnv(
users map[string]int,
opts ...hsic.Option,
Expand All @@ -643,7 +764,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
// This is because the MockOIDC server can only serve login
// requests based on a queue it has been given on startup.
// We currently only populates it with one login request per user.
return fmt.Errorf("client count must be 1 for OIDC scenario.")
return ErrOIDCClientCount
}
log.Printf("creating user %s with %d clients", userName, clientCount)
err = s.CreateUser(userName)
Expand All @@ -665,6 +786,49 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv(
return nil
}

func (s *AuthOIDCScenario) CreateHeadscaleEnvWithHTTPModifier(
users map[string]int,
httpModifier func(*http.Client),
opts ...hsic.Option,
) error {
headscale, err := s.Headscale(opts...)
if err != nil {
return err
}

err = headscale.WaitForRunning()
if err != nil {
return err
}

for userName, clientCount := range users {
if clientCount != 1 {
// OIDC scenario only supports one client per user.
// This is because the MockOIDC server can only serve login
// requests based on a queue it has been given on startup.
// We currently only populates it with one login request per user.
return ErrOIDCClientCount
}
log.Printf("creating user %s with %d clients", userName, clientCount)
err = s.CreateUser(userName)
if err != nil {
return err
}

err = s.CreateTailscaleNodesInUser(userName, "all", clientCount)
if err != nil {
return err
}

err = s.runTailscaleUpWithModifier(userName, headscale.GetEndpoint(), httpModifier)
if err != nil {
return err
}
}

return nil
}

func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) {
port, err := dockertestutil.RandomFreeHostPort()
if err != nil {
Expand Down Expand Up @@ -774,14 +938,15 @@ func (s *AuthOIDCScenario) runTailscaleUp(
log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err)
}

loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP())
loginURL.Host = headscale.GetIP() + ":8080"
loginURL.Scheme = "http"

if len(headscale.GetCert()) > 0 {
loginURL.Scheme = "https"
}

insecureTransport := &http.Transport{
// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
}

Expand Down Expand Up @@ -848,6 +1013,98 @@ func (s *AuthOIDCScenario) runTailscaleUp(
return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
}

func (s *AuthOIDCScenario) runTailscaleUpWithModifier(
userStr string,
loginServer string,
httpClientModifier func(*http.Client),
) error {
headscale, err := s.Headscale()
if err != nil {
return err
}

log.Printf("running tailscale up for user %s", userStr)
if user, ok := s.users[userStr]; ok {
for _, client := range user.Clients {
c := client
err := func() error {
status, err := c.Status()
if err != nil {
log.Printf("%s failed to get status: %s", c.Hostname(), err)
return err
}

if status.BackendState == nodeStateRunning {
log.Printf("%s is already running", c.Hostname())
return nil
}

log.Printf("%s running tailscale up", c.Hostname())

loginURL, err := c.LoginWithURL(loginServer)
if err != nil {
log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err)
return err
}

loginURL.Host = headscale.GetIP() + ":8080"
loginURL.Scheme = "http"

if len(headscale.GetCert()) > 0 {
loginURL.Scheme = "https"
}

insecureTransport := &http.Transport{
// #nosec G402 -- This is a test-only code using mock OIDC server with self-signed certificates
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
}

log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String())

log.Printf("%s logging in with url", c.Hostname())
httpClient := &http.Client{Transport: insecureTransport}

// Allow the test to modify the HTTP client
if httpClientModifier != nil {
httpClientModifier(httpClient)
}

ctx := context.Background()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
resp, err := httpClient.Do(req)
if err != nil {
log.Printf(
"%s failed to login using url %s: %s",
c.Hostname(),
loginURL,
err,
)

return err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status)
body, _ := io.ReadAll(resp.Body)
log.Printf("body: %s", body)

return errStatusCodeNotOK
}

return nil
}()
if err != nil {
return err
}
}

return nil
}

return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable)
}

func (s *AuthOIDCScenario) Shutdown() {
err := s.pool.Purge(s.mockOIDC)
if err != nil {
Expand Down

0 comments on commit 4310d11

Please sign in to comment.