diff --git a/pkg/server/api_v2_auth.go b/pkg/server/api_v2_auth.go index 57f7239b63a9..4995a792c65f 100644 --- a/pkg/server/api_v2_auth.go +++ b/pkg/server/api_v2_auth.go @@ -266,14 +266,16 @@ func (a *authenticationV2Server) ServeHTTP(w http.ResponseWriter, r *http.Reques // and the request isn't routed through to the inner handler. On success, the // username is set on the request context for use in the inner handler. type authenticationV2Mux struct { - s *authenticationV2Server - inner http.Handler + s *authenticationV2Server + inner http.Handler + allowAnonymous bool } func newAuthenticationV2Mux(s *authenticationV2Server, inner http.Handler) *authenticationV2Mux { return &authenticationV2Mux{ - s: s, - inner: inner, + s: s, + inner: inner, + allowAnonymous: s.sqlServer.cfg.Insecure, } } @@ -293,14 +295,13 @@ const apiV2UseCookieBasedAuth = "cookie" // and also sends the error over http using w. func (a *authenticationV2Mux) getSession( w http.ResponseWriter, req *http.Request, -) (string, *serverpb.SessionCookie, error) { +) (string, *serverpb.SessionCookie, int, error) { ctx := req.Context() // Validate the returned session header or cookie. rawSession := req.Header.Get(apiV2AuthHeader) if len(rawSession) == 0 { err := errors.New("invalid session header") - http.Error(w, err.Error(), http.StatusUnauthorized) - return "", nil, err + return "", nil, http.StatusUnauthorized, err } possibleSessions := []string{} @@ -335,36 +336,40 @@ func (a *authenticationV2Mux) getSession( } if err != nil { err := errors.New("invalid session header") - http.Error(w, err.Error(), http.StatusBadRequest) - return "", nil, err + return "", nil, http.StatusBadRequest, err } valid, username, err := a.s.authServer.verifySession(req.Context(), sessionCookie) if err != nil { apiV2InternalError(req.Context(), err, w) - return "", nil, err + return "", nil, http.StatusInternalServerError, err } if !valid { err := errors.New("the provided authentication session could not be validated") - http.Error(w, err.Error(), http.StatusUnauthorized) - return "", nil, err + return "", nil, http.StatusUnauthorized, err } - return username, sessionCookie, nil + return username, sessionCookie, http.StatusOK, nil } func (a *authenticationV2Mux) ServeHTTP(w http.ResponseWriter, req *http.Request) { - username, cookie, err := a.getSession(w, req) - if err == nil { - // Valid session found. Set the username in the request context, so - // child http.Handlers can access it. - ctx := req.Context() - ctx = context.WithValue(ctx, webSessionUserKey{}, username) - ctx = context.WithValue(ctx, webSessionIDKey{}, cookie.ID) - req = req.WithContext(ctx) - } else { + u, cookie, errStatus, err := a.getSession(w, req) + if err != nil && !a.allowAnonymous { // getSession writes an error to w if err != nil. + http.Error(w, err.Error(), errStatus) return } + if a.allowAnonymous { + u = username.RootUser + } + // Valid session found, or insecure. Set the username in the request context, + // so child http.Handlers can access it. + ctx := req.Context() + ctx = context.WithValue(ctx, webSessionUserKey{}, u) + if cookie != nil { + ctx = context.WithValue(ctx, webSessionIDKey{}, cookie.ID) + } + req = req.WithContext(ctx) + a.inner.ServeHTTP(w, req) } diff --git a/pkg/server/api_v2_test.go b/pkg/server/api_v2_test.go index edac446b218e..d8781dea8052 100644 --- a/pkg/server/api_v2_test.go +++ b/pkg/server/api_v2_test.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/server/serverpb" + "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" @@ -191,77 +192,89 @@ func TestAuthV2(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{}) - ctx := context.Background() - defer testCluster.Stopper().Stop(ctx) - - ts := testCluster.Server(0) - client, err := ts.GetUnauthenticatedHTTPClient() - require.NoError(t, err) + testutils.RunTrueAndFalse(t, "insecure", func(t *testing.T, insecure bool) { + testCluster := serverutils.StartNewTestCluster(t, 3, base.TestClusterArgs{ + ServerArgs: base.TestServerArgs{ + Insecure: insecure, + }, + }) + ctx := context.Background() + defer testCluster.Stopper().Stop(ctx) - session, err := ts.GetAuthSession(true) - require.NoError(t, err) - sessionBytes, err := protoutil.Marshal(session) - require.NoError(t, err) - sessionEncoded := base64.StdEncoding.EncodeToString(sessionBytes) - - for _, tc := range []struct { - name string - header string - cookie string - expectedStatus int - }{ - { - name: "no auth", - expectedStatus: http.StatusUnauthorized, - }, - { - name: "session in header", - header: sessionEncoded, - expectedStatus: http.StatusOK, - }, - { - name: "cookie auth with correct magic header", - cookie: sessionEncoded, - header: apiV2UseCookieBasedAuth, - expectedStatus: http.StatusOK, - }, - { - name: "cookie auth but missing header", - cookie: sessionEncoded, - expectedStatus: http.StatusUnauthorized, - }, - { - name: "cookie auth but wrong magic header", - cookie: sessionEncoded, - header: "yes", - // Bad Request and not Unauthorized because the session cannot be decoded. - expectedStatus: http.StatusBadRequest, - }, - } { - t.Run(tc.name, func(t *testing.T) { - req, err := http.NewRequest("GET", ts.AdminURL()+apiV2Path+"sessions/", nil) - require.NoError(t, err) - if tc.header != "" { - req.Header.Set(apiV2AuthHeader, tc.header) - } - if tc.cookie != "" { - req.AddCookie(&http.Cookie{ - Name: SessionCookieName, - Value: tc.cookie, - }) - } - resp, err := client.Do(req) - require.NoError(t, err) - require.NotNil(t, resp) - defer resp.Body.Close() + ts := testCluster.Server(0) + client, err := ts.GetUnauthenticatedHTTPClient() + require.NoError(t, err) - if tc.expectedStatus != resp.StatusCode { - body, err := ioutil.ReadAll(resp.Body) + session, err := ts.GetAuthSession(true) + require.NoError(t, err) + sessionBytes, err := protoutil.Marshal(session) + require.NoError(t, err) + sessionEncoded := base64.StdEncoding.EncodeToString(sessionBytes) + + for _, tc := range []struct { + name string + header string + cookie string + expectedStatus int + }{ + { + name: "no auth", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "session in header", + header: sessionEncoded, + expectedStatus: http.StatusOK, + }, + { + name: "cookie auth with correct magic header", + cookie: sessionEncoded, + header: apiV2UseCookieBasedAuth, + expectedStatus: http.StatusOK, + }, + { + name: "cookie auth but missing header", + cookie: sessionEncoded, + expectedStatus: http.StatusUnauthorized, + }, + { + name: "cookie auth but wrong magic header", + cookie: sessionEncoded, + header: "yes", + // Bad Request and not Unauthorized because the session cannot be decoded. + expectedStatus: http.StatusBadRequest, + }, + } { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("GET", ts.AdminURL()+apiV2Path+"sessions/", nil) require.NoError(t, err) - t.Fatalf("expected status: %d but got: %d with body: %s", tc.expectedStatus, resp.StatusCode, string(body)) - } - }) - } + if tc.header != "" { + req.Header.Set(apiV2AuthHeader, tc.header) + } + if tc.cookie != "" { + req.AddCookie(&http.Cookie{ + Name: SessionCookieName, + Value: tc.cookie, + }) + } + resp, err := client.Do(req) + require.NoError(t, err) + require.NotNil(t, resp) + defer resp.Body.Close() + + if !insecure && tc.expectedStatus != resp.StatusCode { + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + t.Fatalf("expected status: %d but got: %d with body: %s", tc.expectedStatus, resp.StatusCode, string(body)) + } + if insecure && http.StatusOK != resp.StatusCode { + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + t.Fatalf("expected status: %d but got: %d with body: %s", http.StatusOK, resp.StatusCode, string(body)) + } + }) + } + + }) } diff --git a/pkg/util/log/eventpb/event_test.go b/pkg/util/log/eventpb/event_test.go index b7729ebcda1f..d8aa15eb0b15 100644 --- a/pkg/util/log/eventpb/event_test.go +++ b/pkg/util/log/eventpb/event_test.go @@ -52,6 +52,12 @@ func TestEventJSON(t *testing.T) { // Integer and boolean fields are not redactable in any case. {&UnsafeDeleteDescriptor{ParentID: 123, Force: true}, `"ParentID":123,"Force":true`}, + + // Primitive fields without an `includeempty` annotation will NOT emit their + // zero value. In this case, `SnapshotID` and `NumRecords` do not have the + // `includeempty` annotation, so nothing is emitted, despite the presence of + // zero values. + {&SchemaSnapshotMetadata{SnapshotID: "", NumRecords: 0}, ""}, } for _, tc := range testCases { diff --git a/pkg/util/log/eventpb/eventpbgen/gen.go b/pkg/util/log/eventpb/eventpbgen/gen.go index ebb5107f2083..dc68f18aa665 100644 --- a/pkg/util/log/eventpb/eventpbgen/gen.go +++ b/pkg/util/log/eventpb/eventpbgen/gen.go @@ -84,6 +84,7 @@ type fieldInfo struct { MixedRedactable bool Inherited bool IsEnum bool + AllowZeroValue bool } var ( @@ -380,6 +381,9 @@ func readInput( return errors.Newf("unknown field definition syntax: %q", line) } + // Allow zero values if the field is annotated with 'includeempty'. + allowZeroValue := strings.Contains(line, "includeempty") + typ := fieldDefRe.ReplaceAllString(line, "$typ") switch typ { case "google.protobuf.Timestamp": @@ -451,6 +455,7 @@ func readInput( ReportingSafeRe: safeReName, MixedRedactable: mixed, IsEnum: isEnum, + AllowZeroValue: allowZeroValue, } curMsg.Fields = append(curMsg.Fields, fi) curMsg.AllFields = append(curMsg.AllFields, fi) @@ -536,7 +541,9 @@ func (m *{{.GoType}}) AppendJSONFields(printComma bool, b redact.RedactableBytes {{if .Inherited -}} printComma, b = m.{{.FieldName}}.AppendJSONFields(printComma, b) {{- else if eq .FieldType "string" -}} + {{ if not .AllowZeroValue -}} if m.{{.FieldName}} != "" { + {{- end }} if printComma { b = append(b, ',')}; printComma = true b = append(b, "\"{{.FieldName}}\":\""...) {{ if .AlwaysReportingSafe -}} @@ -555,7 +562,9 @@ func (m *{{.GoType}}) AppendJSONFields(printComma bool, b redact.RedactableBytes b = append(b, redact.EndMarker()...) {{- end }} b = append(b, '"') + {{ if not .AllowZeroValue -}} } + {{- end }} {{- else if eq .FieldType "array_of_string" -}} if len(m.{{.FieldName}}) > 0 { if printComma { b = append(b, ',')}; printComma = true @@ -583,34 +592,54 @@ func (m *{{.GoType}}) AppendJSONFields(printComma bool, b redact.RedactableBytes b = append(b, ']') } {{- else if eq .FieldType "bool" -}} + {{ if not .AllowZeroValue -}} if m.{{.FieldName}} { + {{- end }} if printComma { b = append(b, ',')}; printComma = true b = append(b, "\"{{.FieldName}}\":true"...) + {{ if not .AllowZeroValue -}} } + {{- end }} {{- else if eq .FieldType "int16" "int32" "int64"}} + {{ if not .AllowZeroValue -}} if m.{{.FieldName}} != 0 { + {{- end }} if printComma { b = append(b, ',')}; printComma = true b = append(b, "\"{{.FieldName}}\":"...) b = strconv.AppendInt(b, int64(m.{{.FieldName}}), 10) + {{ if not .AllowZeroValue -}} } + {{- end }} {{- else if eq .FieldType "float"}} + {{ if not .AllowZeroValue -}} if m.{{.FieldName}} != 0 { + {{- end }} if printComma { b = append(b, ',')}; printComma = true b = append(b, "\"{{.FieldName}}\":"...) b = strconv.AppendFloat(b, float64(m.{{.FieldName}}), 'f', -1, 32) + {{ if not .AllowZeroValue -}} } + {{- end }} {{- else if eq .FieldType "double"}} + {{ if not .AllowZeroValue -}} if m.{{.FieldName}} != 0 { + {{- end }} if printComma { b = append(b, ',')}; printComma = true b = append(b, "\"{{.FieldName}}\":"...) b = strconv.AppendFloat(b, float64(m.{{.FieldName}}), 'f', -1, 64) + {{ if not .AllowZeroValue -}} } + {{- end }} {{- else if eq .FieldType "uint16" "uint32" "uint64"}} + {{ if not .AllowZeroValue -}} if m.{{.FieldName}} != 0 { + {{- end }} if printComma { b = append(b, ',')}; printComma = true b = append(b, "\"{{.FieldName}}\":"...) b = strconv.AppendUint(b, uint64(m.{{.FieldName}}), 10) + {{ if not .AllowZeroValue -}} } + {{- end }} {{- else if eq .FieldType "array_of_uint32" -}} if len(m.{{.FieldName}}) > 0 { if printComma { b = append(b, ',')}; printComma = true @@ -622,11 +651,15 @@ func (m *{{.GoType}}) AppendJSONFields(printComma bool, b redact.RedactableBytes b = append(b, ']') } {{- else if .IsEnum }} + {{ if not .AllowZeroValue -}} if m.{{.FieldName}} != 0 { + {{- end }} if printComma { b = append(b, ',')}; printComma = true b = append(b, "\"{{.FieldName}}\":"...) b = strconv.AppendInt(b, int64(m.{{.FieldName}}), 10) + {{ if not .AllowZeroValue -}} } + {{- end }} {{- else if eq .FieldType "protobuf"}} if m.{{.FieldName}} != nil { if printComma { b = append(b, ',')}; printComma = true