diff --git a/pkg/server/api_v2_auth.go b/pkg/server/api_v2_auth.go index 27d1d7940031..8badd88ff275 100644 --- a/pkg/server/api_v2_auth.go +++ b/pkg/server/api_v2_auth.go @@ -277,32 +277,67 @@ func newAuthenticationV2Mux(s *authenticationV2Server, inner http.Handler) *auth } } -// getSession decodes the cookie from the request, looks up the corresponding session, and -// returns the logged in user name. If there's an error, it returns an error value and -// also sends the error over http using w. +// apiV2UseCookieBasedAuth is a magic value of the auth header that +// tells us to look for the session in the cookie. This can be used by +// frontend code to maintain cookie-based auth while interacting with +// the API. +const apiV2UseCookieBasedAuth = "cookie" + +// getSession decodes the cookie from the request, looks up the corresponding +// session, and returns the logged-in username. The session can be looked up +// either from a session cookie as used in the non-v2 API server, or via the +// session header. In order for us to use the cookie as the session source, the +// header `"X-Cockroach-API-Session"` must be set to `"cookie"` (This is to +// guard against CSRF attacks in the browser since it forces the caller to use +// javascript to set the header). If there's an error, it returns an error value +// and also sends the error over http using w. func (a *authenticationV2Mux) getSession( w http.ResponseWriter, req *http.Request, ) (string, *serverpb.SessionCookie, error) { - // Validate the returned cookie. + ctx := req.Context() + // Validate the returned session header or cookie. rawSession := req.Header.Get(apiV2AuthHeader) if len(rawSession) == 0 { err := errors.New("invalid session header") http.Error(w, err.Error(), http.StatusUnauthorized) return "", nil, err } + + possibleSessions := []string{} + if rawSession == apiV2UseCookieBasedAuth { + cookies := req.Cookies() + for _, c := range cookies { + if c.Name != SessionCookieName { + continue + } + possibleSessions = append(possibleSessions, c.Value) + } + } else { + possibleSessions = append(possibleSessions, rawSession) + } + sessionCookie := &serverpb.SessionCookie{} - decoded, err := base64.StdEncoding.DecodeString(rawSession) - if err != nil { - err := errors.New("invalid session header") - http.Error(w, err.Error(), http.StatusBadRequest) - return "", nil, err + var decoded []byte + var err error + for i := range possibleSessions { + decoded, err = base64.StdEncoding.DecodeString(possibleSessions[i]) + if err != nil { + log.Warningf(ctx, "attempted to decode session but failed: %v", err) + continue + } + err = protoutil.Unmarshal(decoded, sessionCookie) + if err != nil { + log.Warningf(ctx, "attempted to unmarshal session but failed: %v", err) + continue + } + // We've successfully decoded a session from cookie or header. + break } - if err := protoutil.Unmarshal(decoded, sessionCookie); err != nil { + if err != nil { err := errors.New("invalid session header") http.Error(w, err.Error(), http.StatusBadRequest) return "", nil, err } - valid, username, err := a.s.authServer.verifySession(req.Context(), sessionCookie) if err != nil { apiV2InternalError(req.Context(), err, w) diff --git a/pkg/server/api_v2_test.go b/pkg/server/api_v2_test.go index 19d954615fb7..36b39daf588d 100644 --- a/pkg/server/api_v2_test.go +++ b/pkg/server/api_v2_test.go @@ -13,6 +13,7 @@ package server import ( "context" gosql "database/sql" + "encoding/base64" "encoding/json" "io/ioutil" "net/http" @@ -27,6 +28,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/metric" + "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" ) @@ -183,3 +185,82 @@ func TestRulesV2(t *testing.T) { require.NoError(t, yaml.NewDecoder(resp.Body).Decode(&ruleGroups)) require.NoError(t, resp.Body.Close()) } + +func TestAuthV2(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) + ctx := context.Background() + defer testCluster.Stopper().Stop(ctx) + + ts := testCluster.Server(0) + client, err := ts.GetHTTPClient() + require.NoError(t, err) + + session, err := ts.GetAuthSession(true) + require.NoError(t, err) + sessionBytes, err := protoutil.Marshal(session) + require.NoError(t, err) + sessionEncoded := base64.StdEncoding.EncodeToString(sessionBytes) + + for _, tc := range []struct { + name string + header string + cookie string + expectedStatus int + }{ + { + name: "no auth", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "session in header", + header: sessionEncoded, + expectedStatus: http.StatusOK, + }, + { + name: "cookie auth with correct magic header", + cookie: sessionEncoded, + header: apiV2UseCookieBasedAuth, + expectedStatus: http.StatusOK, + }, + { + name: "cookie auth but missing header", + cookie: sessionEncoded, + expectedStatus: http.StatusUnauthorized, + }, + { + name: "cookie auth but wrong magic header", + cookie: sessionEncoded, + header: "yes", + // Bad Request and not Unauthorized because the session cannot be decoded. + expectedStatus: http.StatusBadRequest, + }, + } { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("GET", ts.AdminURL()+apiV2Path+"sessions/", nil) + require.NoError(t, err) + if tc.header != "" { + req.Header.Set(apiV2AuthHeader, tc.header) + } + if tc.cookie != "" { + req.AddCookie(&http.Cookie{ + Name: SessionCookieName, + Value: tc.cookie, + }) + } + resp, err := client.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + + if tc.expectedStatus != resp.StatusCode { + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + t.Fatalf("expected status: %d but got: %d with body: %s", tc.expectedStatus, resp.StatusCode, string(body)) + } + }) + } + +} diff --git a/pkg/server/testserver_http.go b/pkg/server/testserver_http.go index b1e0488e18f5..20c463b667cb 100644 --- a/pkg/server/testserver_http.go +++ b/pkg/server/testserver_http.go @@ -71,6 +71,16 @@ func (ts *httpTestServer) GetAuthenticatedHTTPClient(isAdmin bool) (http.Client, return httpClient, err } +// GetAuthenticatedHTTPClient implements the TestServerInterface. +func (ts *httpTestServer) GetAuthSession(isAdmin bool) (*serverpb.SessionCookie, error) { + authUser := authenticatedUserName() + if !isAdmin { + authUser = authenticatedUserNameNoAdmin() + } + _, cookie, err := ts.getAuthenticatedHTTPClientAndCookie(authUser, isAdmin) + return cookie, err +} + func (ts *httpTestServer) getAuthenticatedHTTPClientAndCookie( authUser security.SQLUsername, isAdmin bool, ) (http.Client, *serverpb.SessionCookie, error) { diff --git a/pkg/testutils/serverutils/BUILD.bazel b/pkg/testutils/serverutils/BUILD.bazel index f8492f82c684..31c1f165f5d5 100644 --- a/pkg/testutils/serverutils/BUILD.bazel +++ b/pkg/testutils/serverutils/BUILD.bazel @@ -18,6 +18,7 @@ go_library( "//pkg/roachpb", "//pkg/rpc", "//pkg/security", + "//pkg/server/serverpb", "//pkg/server/status", "//pkg/settings/cluster", "//pkg/storage", diff --git a/pkg/testutils/serverutils/test_tenant_shim.go b/pkg/testutils/serverutils/test_tenant_shim.go index 1224a6b323f8..169cfe2bd91c 100644 --- a/pkg/testutils/serverutils/test_tenant_shim.go +++ b/pkg/testutils/serverutils/test_tenant_shim.go @@ -20,6 +20,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/config" "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/server/serverpb" "github.com/cockroachdb/cockroach/pkg/settings/cluster" "github.com/cockroachdb/cockroach/pkg/util/hlc" "github.com/cockroachdb/cockroach/pkg/util/log" @@ -127,6 +128,9 @@ type TestTenantInterface interface { // GetAuthenticatedHTTPClient returns an http client which has been // authenticated to access Admin API methods (via a cookie). GetAuthenticatedHTTPClient(isAdmin bool) (http.Client, error) + // GetEncodedSession returns a byte array containing a valid auth + // session. + GetAuthSession(isAdmin bool) (*serverpb.SessionCookie, error) // DrainClients shuts down client connections. DrainClients(ctx context.Context) error