diff --git a/.golangci-strict.yml b/.golangci-strict.yml index 55ed787ec6..d82b114fa3 100644 --- a/.golangci-strict.yml +++ b/.golangci-strict.yml @@ -13,7 +13,7 @@ run: linters: enable: - #- bodyclose # checks whether HTTP response body is closed successfully + - bodyclose # checks whether HTTP response body is closed successfully #- dupl # Tool for code clone detection - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases #- goconst # Finds repeated strings that could be replaced by a constant @@ -59,7 +59,6 @@ linters: - whitespace # Tool for detection of leading and trailing whitespace - wsl # Whitespace Linter - Forces you to use empty lines! # Once fixed, should enable - - bodyclose # checks whether HTTP response body is closed successfully - deadcode # Finds unused code - dupl # Tool for code clone detection - goconst # Finds repeated strings that could be replaced by a constant diff --git a/base/bucket.go b/base/bucket.go index 654a629956..d387cd2861 100644 --- a/base/bucket.go +++ b/base/bucket.go @@ -555,13 +555,6 @@ func retrievePurgeInterval(ctx context.Context, bucket CouchbaseBucketStore, uri return time.Duration(purgeIntervalHours) * time.Hour, nil } -func ensureBodyClosed(ctx context.Context, body io.ReadCloser) { - err := body.Close() - if err != nil { - DebugfCtx(ctx, KeyBucket, "Failed to close socket: %v", err) - } -} - // AsViewStore returns a ViewStore if the underlying dataStore implements ViewStore. func AsViewStore(ds DataStore) (sgbucket.ViewStore, bool) { viewStore, ok := ds.(sgbucket.ViewStore) diff --git a/base/bucket_gocb.go b/base/bucket_gocb.go index 6eaa13991b..4b111c67c9 100644 --- a/base/bucket_gocb.go +++ b/base/bucket_gocb.go @@ -83,7 +83,12 @@ func putDDocForTombstones(ctx context.Context, name string, payload []byte, capi return err } - defer ensureBodyClosed(ctx, resp.Body) + defer func() { + err := resp.Body.Close() + if err != nil { + DebugfCtx(ctx, KeyBucket, "Failed to close socket: %v", err) + } + }() if resp.StatusCode != 201 { data, err := io.ReadAll(resp.Body) if err != nil { diff --git a/rest/adminapitest/admin_api_test.go b/rest/adminapitest/admin_api_test.go index 1337446568..58561c050a 100644 --- a/rest/adminapitest/admin_api_test.go +++ b/rest/adminapitest/admin_api_test.go @@ -3865,6 +3865,7 @@ func setServerPurgeInterval(t *testing.T, rt *rest.RestTester, newPurgeInterval resp, err := httpClient.Do(req) require.NoError(t, err) + assert.NoError(t, resp.Body.Close()) require.Equal(t, resp.StatusCode, http.StatusOK) } diff --git a/rest/api_test.go b/rest/api_test.go index fad4bbe3db..ce4c9ea094 100644 --- a/rest/api_test.go +++ b/rest/api_test.go @@ -112,6 +112,7 @@ func TestPublicRESTStatCount(t *testing.T) { // test metrics endpoint response, err := http.Get(srv.URL + "/_metrics") require.NoError(t, err) + require.NoError(t, response.Body.Close()) assert.Equal(t, http.StatusOK, response.StatusCode) // assert the stat doesn't increment base.RequireWaitForStat(t, func() int64 { diff --git a/rest/bytes_read_public_api_test.go b/rest/bytes_read_public_api_test.go index c501a1477e..a46a083a21 100644 --- a/rest/bytes_read_public_api_test.go +++ b/rest/bytes_read_public_api_test.go @@ -71,6 +71,7 @@ func TestBytesReadDocOperations(t *testing.T) { response, err := http.Get(srv.URL + "/_metrics") require.NoError(t, err) assert.Equal(t, http.StatusOK, response.StatusCode) + require.NoError(t, response.Body.Close()) base.RequireWaitForStat(t, func() int64 { return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value() diff --git a/rest/functionsapitest/graphql_admin_test.go b/rest/functionsapitest/graphql_admin_test.go index 886afbc322..1c1be904b1 100644 --- a/rest/functionsapitest/graphql_admin_test.go +++ b/rest/functionsapitest/graphql_admin_test.go @@ -11,6 +11,7 @@ package functionsapitest import ( "encoding/json" "fmt" + "net/http" "os" "testing" @@ -178,11 +179,11 @@ func TestFunctionsConfigGetWithoutFeatureFlagGraphQL(t *testing.T) { t.Run("GraphQL, Non-Admin", func(t *testing.T) { response := rt.SendRequest("GET", "/db/_config/graphql", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("GraphQL", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_config/graphql", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } @@ -190,31 +191,31 @@ func TestFunctionsConfigGetWithoutFeatureFlagGraphQL(t *testing.T) { func runTestFunctionsConfigMVCC(t *testing.T, rt *rest.RestTester, uri string, newValue string) { // Get initial etag: response := rt.SendAdminRequest("GET", uri, "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) etag := response.HeaderMap.Get("Etag") assert.Regexp(t, `"[^"]+"`, etag) // Update config, just to change its etag: response = rt.SendAdminRequest("PUT", uri, newValue) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) newEtag := response.HeaderMap.Get("Etag") assert.Regexp(t, `"[^"]+"`, newEtag) assert.NotEqual(t, etag, newEtag) // A GET should also return the new etag: response = rt.SendAdminRequest("GET", uri, "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, newEtag, response.HeaderMap.Get("Etag")) // Try to update using If-Match with the old etag: headers := map[string]string{"If-Match": etag} response = rt.SendAdminRequestWithHeaders("PUT", uri, newValue, headers) - assert.Equal(t, 412, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusPreconditionFailed) // Now update successfully using the current etag: headers["If-Match"] = newEtag response = rt.SendAdminRequestWithHeaders("PUT", uri, newValue, headers) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) newestEtag := response.HeaderMap.Get("Etag") assert.Regexp(t, `"[^"]+"`, newestEtag) assert.NotEqual(t, etag, newestEtag) @@ -222,12 +223,12 @@ func runTestFunctionsConfigMVCC(t *testing.T, rt *rest.RestTester, uri string, n // Try to delete using If-Match with the previous etag: response = rt.SendAdminRequestWithHeaders("DELETE", uri, newValue, headers) - assert.Equal(t, 412, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusPreconditionFailed) // Now delete successfully using the current etag: headers["If-Match"] = newestEtag response = rt.SendAdminRequestWithHeaders("DELETE", uri, newValue, headers) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) } // Test use of "Etag" and "If-Match" headers to safely update graphql config. @@ -259,11 +260,11 @@ func TestFunctionsConfigGraphQLGetEmpty(t *testing.T) { t.Run("Non-Admin", func(t *testing.T) { response := rt.SendRequest("GET", "/db/_config/graphql", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("All", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_config/graphql", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } @@ -281,7 +282,7 @@ func TestFunctionsConfigGraphQLGet(t *testing.T) { t.Run("Non-Admin", func(t *testing.T) { response := rt.SendRequest("GET", "/db/_config/graphql", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("All", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_config/graphql", "") @@ -304,15 +305,15 @@ func TestFunctionsConfigGraphQLPut(t *testing.T) { t.Run("Non-Admin", func(t *testing.T) { response := rt.SendRequest("PUT", "/db/_config/graphql", "{}") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) response = rt.SendRequest("DELETE", "/db/_config/graphql", "{}") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("ReplaceBogus", func(t *testing.T) { response := rt.SendAdminRequest("PUT", "/db/_config/graphql", `{ "schema": "obviously not a valid schema ^_^" }`) - assert.Equal(t, 400, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusBadRequest) }) t.Run("Replace", func(t *testing.T) { response := rt.SendAdminRequest("PUT", "/db/_config/graphql", `{ @@ -326,20 +327,20 @@ func TestFunctionsConfigGraphQLPut(t *testing.T) { } } }`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) response = rt.SendAdminRequest("POST", "/db/_graphql", `{"query": "query{ sum(n:3) }"}`) - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, `{"data":{"sum":6}}`, string(response.BodyBytes())) + rest.RequireStatus(t, response, http.StatusOK) + assert.JSONEq(t, `{"data":{"sum":6}}`, response.BodyString()) }) t.Run("Delete", func(t *testing.T) { response := rt.SendAdminRequest("DELETE", "/db/_config/graphql", "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Nil(t, rt.GetDatabase().Options.GraphQL) response = rt.SendAdminRequest("POST", "/db/_graphql", `{"query": "query{ sum(n:3) }"}`) - assert.Equal(t, 503, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusServiceUnavailable) }) } @@ -371,11 +372,11 @@ func TestValidGraphQLConfigurationValues(t *testing.T) { }, "max_schema_size" : %d }`, len(schema))) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) response = rt.SendAdminRequest("GET", "/db/_config/graphql", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.Contains(t, string(response.BodyBytes()), `"max_schema_size":32`) + rest.RequireStatus(t, response, http.StatusOK) + assert.Contains(t, response.BodyString(), `"max_schema_size":32`) }) //If max_resolver_count >= given number of resolvers then Valid @@ -397,11 +398,11 @@ func TestValidGraphQLConfigurationValues(t *testing.T) { }, "max_resolver_count" : 2 }`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) response = rt.SendAdminRequest("GET", "/db/_config/graphql", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.Contains(t, string(response.BodyBytes()), `"max_resolver_count":2`) + rest.RequireStatus(t, response, http.StatusOK) + assert.Contains(t, response.BodyString(), `"max_resolver_count":2`) }) //If max_request_size >= length of JSON-encoded arguments passed to a function then Valid @@ -419,29 +420,29 @@ func TestValidGraphQLConfigurationValues(t *testing.T) { }, "max_request_size" : %d }`, len(requestQuery))) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) response = rt.SendAdminRequest("POST", "/db/_graphql", requestQuery) - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, `{"data":{"square":16}}`, string(response.BodyBytes())) + rest.RequireStatus(t, response, http.StatusOK) + assert.Equal(t, `{"data":{"square":16}}`, response.BodyString()) response = rt.SendAdminRequest("GET", "/db/_config/graphql", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.Contains(t, string(response.BodyBytes()), fmt.Sprintf(`"max_request_size":%d`, len(requestQuery))) + rest.RequireStatus(t, response, http.StatusOK) + assert.Contains(t, response.BodyString(), fmt.Sprintf(`"max_request_size":%d`, len(requestQuery))) headerMap := map[string]string{ "Content-type": "application/graphql", } response = rt.SendAdminRequestWithHeaders("POST", "/db/_graphql", `query{square(n:4)}`, headerMap) - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, `{"data":{"square":16}}`, string(response.BodyBytes())) + rest.RequireStatus(t, response, http.StatusOK) + assert.Equal(t, response.BodyString(), `{"data":{"square":16}}`) queryParam := `query($numberToBeSquared:Int!){ square(n:$numberToBeSquared) }` variableParam := `{"numberToBeSquared": 4}` getRequestUrl := fmt.Sprintf("/db/_graphql?query=%s&variables=%s", queryParam, variableParam) response = rt.SendAdminRequest("GET", getRequestUrl, "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, `{"data":{"square":16}}`, string(response.BodyBytes())) + rest.RequireStatus(t, response, http.StatusOK) + assert.Equal(t, response.BodyString(), `{"data":{"square":16}}`) }) @@ -462,13 +463,13 @@ func TestValidGraphQLConfigurationValues(t *testing.T) { } } }`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.AssertStatus(t, response, http.StatusOK) err = os.Remove("schema.graphql") assert.NoError(t, err) response = rt.SendAdminRequest("GET", "/db/_config/graphql", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.Contains(t, string(response.BodyBytes()), `"schemaFile":"schema.graphql"`) + rest.RequireStatus(t, response, http.StatusOK) + assert.Contains(t, response.BodyString(), `"schemaFile":"schema.graphql"`) }) } @@ -506,7 +507,7 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) { err := json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap) assert.NoError(t, err) - assert.Equal(t, 400, response.Result().StatusCode) + rest.AssertStatus(t, response, http.StatusBadRequest) assert.Contains(t, responseMap["reason"], "GraphQL schema too large") assert.Contains(t, responseMap["error"], "Bad Request") }) @@ -535,7 +536,7 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) { err := json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap) assert.NoError(t, err) - assert.Equal(t, 400, response.Result().StatusCode) + rest.AssertStatus(t, response, http.StatusBadRequest) assert.Contains(t, responseMap["reason"], "too many GraphQL resolvers") assert.Contains(t, responseMap["error"], "Bad Request") @@ -557,7 +558,7 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) { }, "max_request_size" : 5 }`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) response = rt.SendAdminRequest("POST", "/db/_graphql", `{"query": "query($numberToBeSquared:Int!){ square(n:$numberToBeSquared) }", "variables": {"numberToBeSquared": 4}}`) @@ -565,7 +566,7 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) { err := json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap) assert.NoError(t, err) - assert.Equal(t, 413, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusRequestEntityTooLarge) assert.Contains(t, responseMap["reason"], "Arguments too large") assert.Contains(t, responseMap["error"], "Request Entity Too Large") @@ -573,8 +574,8 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) { "Content-type": "application/graphql", } response = rt.SendAdminRequestWithHeaders("POST", "/db/_graphql", `query{square(n:4)}`, headerMap) - assert.Equal(t, 413, response.Result().StatusCode) - err = json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap) + rest.RequireStatus(t, response, http.StatusRequestEntityTooLarge) + err = json.Unmarshal(response.BodyBytes(), &responseMap) assert.NoError(t, err) assert.Contains(t, responseMap["reason"], "Arguments too large") assert.Contains(t, responseMap["error"], "Request Entity Too Large") @@ -584,8 +585,8 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) { getRequestUrl := fmt.Sprintf("/db/_graphql?query=%s&variables=%s", queryParam, variableParam) response = rt.SendAdminRequest("GET", getRequestUrl, "") - assert.Equal(t, 200, response.Result().StatusCode) - err = json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap) + rest.RequireStatus(t, response, http.StatusOK) + err = json.Unmarshal(response.BodyBytes(), &responseMap) assert.NoError(t, err) assert.Contains(t, responseMap["reason"], "Arguments too large") assert.Contains(t, responseMap["error"], "Request Entity Too Large") @@ -611,7 +612,7 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) { err := json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap) assert.NoError(t, err) - assert.Equal(t, 400, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusBadRequest) assert.Contains(t, responseMap["reason"], "GraphQL config: only one of `schema` and `schemaFile` may be used") assert.Contains(t, responseMap["error"], "Bad Request") }) @@ -637,7 +638,7 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) { err = json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap) assert.NoError(t, err) - assert.Equal(t, 400, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusBadRequest) assert.Contains(t, responseMap["reason"], "Syntax Error GraphQL") assert.Contains(t, responseMap["reason"], "Unexpected Name") assert.Contains(t, responseMap["error"], "Bad Request") @@ -659,8 +660,8 @@ func TestSchemaSyntax(t *testing.T) { t.Run("Non-Admin", func(t *testing.T) { response := rt.SendRequest("PUT", "/db/_config/graphql", "{}") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) response = rt.SendRequest("DELETE", "/db/_config/graphql", "{}") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } diff --git a/rest/functionsapitest/graphql_queries_test.go b/rest/functionsapitest/graphql_queries_test.go index 24b776bdd8..b1e317a08d 100644 --- a/rest/functionsapitest/graphql_queries_test.go +++ b/rest/functionsapitest/graphql_queries_test.go @@ -11,6 +11,7 @@ package functionsapitest import ( "encoding/json" "fmt" + "net/http" "strings" "testing" "time" @@ -37,8 +38,8 @@ func TestGraphQLQueryAdminOnly(t *testing.T) { t.Run("AsAdmin - getUser", func(t *testing.T) { t.Run("POST request", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_graphql", `{"query": "query($id:ID!){ getUser(id:$id) { id , name } }" , "variables": {"id": 1}}`) - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, `{"data":{"getUser":{"id":"1","name":"user1"}}}`, string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.Equal(t, response.BodyString(), `{"data":{"getUser":{"id":"1","name":"user1"}}}`) }) t.Run("GET request", func(t *testing.T) { @@ -46,8 +47,8 @@ func TestGraphQLQueryAdminOnly(t *testing.T) { variableParam := `{"id": 1}` getRequestUrl := fmt.Sprintf("/db/_graphql?query=%s&variables=%s", queryParam, variableParam) response := rt.SendAdminRequest("GET", getRequestUrl, "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, `{"data":{"getUser":{"id":"1","name":"user1"}}}`, string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.Equal(t, response.BodyString(), `{"data":{"getUser":{"id":"1","name":"user1"}}}`) }) t.Run("POST request with Headers", func(t *testing.T) { @@ -55,15 +56,15 @@ func TestGraphQLQueryAdminOnly(t *testing.T) { "Content-type": "application/graphql", } response := rt.SendAdminRequestWithHeaders("POST", "/db/_graphql", `query{getUser(id:1){id,name}}`, headerMap) - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, `{"data":{"getUser":{"id":"1","name":"user1"}}}`, string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.Equal(t, response.BodyString(), `{"data":{"getUser":{"id":"1","name":"user1"}}}`) }) }) t.Run("AsAdmin - getAllUsers", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_graphql", `{"query": "query{getAllUsers{name}}"}`) - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, `{"data":{"getAllUsers":[{"name":"user1"},{"name":"user2"},{"name":"user3"}]}}`, string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.Equal(t, response.BodyString(), `{"data":{"getAllUsers":[{"name":"user1"},{"name":"user2"},{"name":"user3"}]}}`) }) // Test multiple query operations in a single request @@ -80,7 +81,7 @@ func TestGraphQLQueryAdminOnly(t *testing.T) { "operationName": "%s" }`, queryParam, variableParam, operationParam) response := rt.SendAdminRequest("POST", "/db/_graphql", requestBody) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, expectedResponse, string(response.BodyBytes())) }) @@ -105,26 +106,26 @@ func TestGraphQLQueryCustomUser(t *testing.T) { t.Run("AsUser - getUser", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_user/", `{"name":"janhavi", "password":"password"}`) - assert.Equal(t, 201, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusCreated) response = rt.SendUserRequestWithHeaders("POST", "/db/_graphql", `{"query": "query($id:ID!){ getUser(id:$id) { id , name } }" , "variables": {"id": 3}}`, nil, "janhavi", "password") - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, `{"data":{"getUser":{"id":"3","name":"user3"}}}`, string(response.BodyBytes())) + rest.RequireStatus(t, response, http.StatusOK) + assert.Equal(t, `{"data":{"getUser":{"id":"3","name":"user3"}}}`, response.BodyString()) response = rt.SendAdminRequest("DELETE", "/db/_user/janhavi", "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) t.Run("AsUser - getAllUsers", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_user/", `{"name":"janhavi", "password":"password"}`) - assert.Equal(t, 201, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusCreated) response = rt.SendUserRequestWithHeaders("POST", "/db/_graphql", `{"query": "query{getAllUsers{name}}"}`, nil, "janhavi", "password") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, `{"data":{"getAllUsers":[{"name":"user1"},{"name":"user2"},{"name":"user3"}]}}`, string(response.BodyBytes())) response = rt.SendAdminRequest("DELETE", "/db/_user/janhavi", "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) } @@ -145,12 +146,12 @@ func TestGraphQLQueriesGuest(t *testing.T) { t.Run("AsGuest - getUser", func(t *testing.T) { response := rt.SendRequest("POST", "/db/_graphql", `{"query": "query($id:ID!){ getUser(id:$id) { id , name } }" , "variables": {"id": 1}}`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, `{"data":{"getUser":{"id":"1","name":"user1"}}}`, string(response.BodyBytes())) }) t.Run("AsGuest - getAllUsers", func(t *testing.T) { response := rt.SendRequest("POST", "/db/_graphql", `{"query": "query{getAllUsers{name}}"}`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, `{"data":{"getAllUsers":[{"name":"user1"},{"name":"user2"},{"name":"user3"}]}}`, string(response.BodyBytes())) }) @@ -169,13 +170,13 @@ func TestGraphQLMutationsAdminOnly(t *testing.T) { t.Run("AsAdmin - updateName", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_graphql", `{"query":"mutation($id: ID!, $name:String!){ updateName(id:$id,name:$name) {id,name} }", "variables" : {"id":1,"name":"newUser"}}`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, `{"data":{"updateName":{"id":"1","name":"newUser"}}}`, string(response.BodyBytes())) }) t.Run("AsAdmin - addEmail", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_graphql", `{"query": "mutation($id:ID!, $email: String!){ addEmail(id:$id, email:$email) {id,name,Emails} }" , "variables": {"id": 2, "email":"pqr@gmail.com"}}`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, `{"data":{"addEmail":{"Emails":["xyz@gmail.com","def@gmail.com","pqr@gmail.com"],"id":"2","name":"user2"}}}`, string(response.BodyBytes())) }) @@ -193,7 +194,7 @@ func TestGraphQLMutationsAdminOnly(t *testing.T) { }`, queryParam, variableParam, operationParam) response := rt.SendAdminRequest("POST", "/db/_graphql", requestBody) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, expectedResponse, string(response.BodyBytes())) }) } @@ -211,26 +212,26 @@ func TestGraphQLMutationsCustomUser(t *testing.T) { t.Run("AsUser - updateName", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_user/", `{"name":"jinesh", "password":"password"}`) - assert.Equal(t, 201, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusCreated) response = rt.SendUserRequestWithHeaders("POST", "/db/_graphql", `{"query":"mutation($id: ID!, $name:String!){ updateName(id:$id,name:$name) {id,name} }", "variables" : {"id":1,"name":"newUser"}}`, nil, "jinesh", "password") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, `{"data":{"updateName":{"id":"1","name":"newUser"}}}`, string(response.BodyBytes())) response = rt.SendAdminRequest("DELETE", "/db/_user/jinesh", "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) t.Run("AsUser - addEmail", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_user/", `{"name":"jinesh", "password":"password"}`) - assert.Equal(t, 201, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusCreated) response = rt.SendUserRequestWithHeaders("POST", "/db/_graphql", `{"query": "mutation($id:ID!, $email: String!){ addEmail(id:$id, email:$email) {id,name,Emails} }" , "variables": {"id": 2, "email":"pqr@gmail.com"}}`, nil, "jinesh", "password") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) testErrorMessage(t, response, "403 you are not allowed to call GraphQL resolver") response = rt.SendAdminRequest("DELETE", "/db/_user/jinesh", "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) } @@ -250,13 +251,13 @@ func TestGraphQLMutationsGuest(t *testing.T) { t.Run("AsGuest - updateName", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_graphql", `{"query":"mutation($id: ID!, $name:String!){ updateName(id:$id,name:$name) {id,name} }", "variables" : {"id":1,"name":"newUser"}}`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, `{"data":{"updateName":{"id":"1","name":"newUser"}}}`, string(response.BodyBytes())) }) t.Run("AsGuest - addEmail", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_graphql", `{"query": "mutation($id:ID!, $email: String!){ addEmail(id:$id, email:$email) {id,name,Emails} }" , "variables": {"id": 2, "email":"pqr@gmail.com"}}`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, `{"data":{"addEmail":{"Emails":["xyz@gmail.com","def@gmail.com","pqr@gmail.com"],"id":"2","name":"user2"}}}`, string(response.BodyBytes())) }) } @@ -289,14 +290,13 @@ func TestContextDeadline(t *testing.T) { t.Run("AsAdmin - exceedContextDeadline", func(t *testing.T) { requestQuery := fmt.Sprintf(`{"query": "query{ checkContextDeadline(Timeout:%d) }"}`, timeout.Milliseconds()*2) response := rt.SendAdminRequest("POST", "/db/_graphql", requestQuery) - - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) testErrorMessage(t, response, "context deadline exceeded") }) t.Run("AsAdmin - doNotExceedContextDeadline", func(t *testing.T) { requestQuery := `{"query": "query{ checkContextDeadline(Timeout:1) }"}` response := rt.SendAdminRequest("POST", "/db/_graphql", requestQuery) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Equal(t, `{"data":{"checkContextDeadline":0}}`, string(response.BodyBytes())) }) diff --git a/rest/functionsapitest/user_functions_admin_test.go b/rest/functionsapitest/user_functions_admin_test.go index f8673bf59b..016cb170fd 100644 --- a/rest/functionsapitest/user_functions_admin_test.go +++ b/rest/functionsapitest/user_functions_admin_test.go @@ -11,6 +11,7 @@ package functionsapitest import ( "encoding/json" "fmt" + "net/http" "testing" "github.com/couchbase/sync_gateway/base" @@ -29,15 +30,15 @@ func TestFunctionsConfigGetWithoutFeatureFlag(t *testing.T) { t.Run("Functions, Non-Admin", func(t *testing.T) { response := rt.SendRequest("GET", "/db/_config/functions", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("All Functions", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_config/functions", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("Single Function", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_config/functions/cube", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } @@ -99,15 +100,15 @@ func TestFunctionsConfigGetMissing(t *testing.T) { t.Run("Non-Admin", func(t *testing.T) { response := rt.SendRequest("GET", "/db/_config/functions", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("All", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_config/functions", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("Missing", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_config/functions/cube", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } func TestFunctionsConfigGet(t *testing.T) { @@ -130,7 +131,7 @@ func TestFunctionsConfigGet(t *testing.T) { t.Run("Non-Admin", func(t *testing.T) { response := rt.SendRequest("GET", "/db/_config/functions", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("All", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_config/functions", "") @@ -146,7 +147,7 @@ func TestFunctionsConfigGet(t *testing.T) { }) t.Run("Missing", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_config/functions/bogus", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } @@ -170,9 +171,9 @@ func TestFunctionsConfigPut(t *testing.T) { t.Run("Non-Admin", func(t *testing.T) { response := rt.SendRequest("PUT", "/db/_config/functions", "{}") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) response = rt.SendRequest("DELETE", "/db/_config/functions", "{}") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("ReplaceAll", func(t *testing.T) { response := rt.SendAdminRequest("PUT", "/db/_config/functions", `{ @@ -181,26 +182,26 @@ func TestFunctionsConfigPut(t *testing.T) { "code": "function(context,args){return args.numero + args.numero;}", "args": ["numero"], "allow": {"channels": ["*"]}} } }`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.NotNil(t, rt.GetDatabase().Options.UserFunctions.Definitions["sum"]) assert.Nil(t, rt.GetDatabase().Options.UserFunctions.Definitions["square"]) response = rt.SendAdminRequest("GET", "/db/_function/sum?numero=13", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, "26", string(response.BodyBytes())) + rest.RequireStatus(t, response, http.StatusOK) + assert.Equal(t, "26", response.BodyString()) response = rt.SendAdminRequest("GET", "/db/_function/square?numero=13", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("DeleteAll", func(t *testing.T) { response := rt.SendAdminRequest("DELETE", "/db/_config/functions", "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Nil(t, rt.GetDatabase().Options.UserFunctions) response = rt.SendAdminRequest("GET", "/db/_function/square?numero=13", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } @@ -224,15 +225,15 @@ func TestFunctionsConfigPutOne(t *testing.T) { t.Run("Non-Admin", func(t *testing.T) { response := rt.SendRequest("PUT", "/db/_config/functions/square", "{}") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) response = rt.SendRequest("DELETE", "/db/_config/function/square", "{}") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("Bogus", func(t *testing.T) { response := rt.SendAdminRequest("PUT", "/db/_config/functions/square", `[]`) - assert.Equal(t, 400, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusBadRequest) response = rt.SendAdminRequest("PUT", "/db/_config/functions/square", `{"ruby": "foo"}`) - assert.Equal(t, 400, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusBadRequest) }) t.Run("Add", func(t *testing.T) { response := rt.SendAdminRequest("PUT", "/db/_config/functions/sum", `{ @@ -241,13 +242,14 @@ func TestFunctionsConfigPutOne(t *testing.T) { "args": ["numero"], "allow": {"channels": ["*"]} }`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.NotNil(t, rt.GetDatabase().Options.UserFunctions.Definitions["sum"]) assert.NotNil(t, rt.GetDatabase().Options.UserFunctions.Definitions["square"]) response = rt.SendAdminRequest("GET", "/db/_function/sum?numero=13", "") - assert.Equal(t, "26", string(response.BodyBytes())) + rest.RequireStatus(t, response, http.StatusOK) + assert.Equal(t, "26", response.BodyString()) }) t.Run("ReplaceOne", func(t *testing.T) { response := rt.SendAdminRequest("PUT", "/db/_config/functions/square", `{ @@ -256,23 +258,24 @@ func TestFunctionsConfigPutOne(t *testing.T) { "args": ["n"], "allow": {"channels": ["*"]} }`) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.NotNil(t, rt.GetDatabase().Options.UserFunctions.Definitions["sum"]) assert.NotNil(t, rt.GetDatabase().Options.UserFunctions.Definitions["square"]) response = rt.SendAdminRequest("GET", "/db/_function/square?n=13", "") - assert.Equal(t, "-169", string(response.BodyBytes())) + rest.RequireStatus(t, response, http.StatusOK) + assert.Equal(t, "-169", response.BodyString()) }) t.Run("DeleteOne", func(t *testing.T) { response := rt.SendAdminRequest("DELETE", "/db/_config/functions/square", "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) assert.Nil(t, rt.GetDatabase().Options.UserFunctions.Definitions["square"]) assert.Equal(t, 1, len(rt.GetDatabase().Options.UserFunctions.Definitions)) response = rt.SendAdminRequest("GET", "/db/_function/square?n=13", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } @@ -301,7 +304,7 @@ func TestMaxRequestSize(t *testing.T) { assert.NoError(t, err) response := rt.SendAdminRequest("PUT", "/db/_config/functions", string(request)) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) response = rt.SendAdminRequest("GET", "/db/_config/functions", "") assert.NotNil(t, response) @@ -320,18 +323,18 @@ func TestMaxRequestSize(t *testing.T) { assert.NoError(t, err) response := rt.SendAdminRequest("PUT", "/db/_config/functions", string(request)) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) t.Run("GET req", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_function/multiply?first=1&second=2&third=3&fourth=4", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, "24", string(response.BodyBytes())) + rest.RequireStatus(t, response, http.StatusOK) + assert.Equal(t, "24", response.BodyString()) }) t.Run("POST req", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_function/multiply", `{"first":1,"second":2,"third":3,"fourth":4}`) - assert.Equal(t, 200, response.Result().StatusCode) - assert.Equal(t, "24", string(response.BodyBytes())) + rest.RequireStatus(t, response, http.StatusOK) + assert.Equal(t, "24", response.BodyString()) }) }) @@ -343,17 +346,17 @@ func TestMaxRequestSize(t *testing.T) { assert.NoError(t, err) response := rt.SendAdminRequest("PUT", "/db/_config/functions", string(request)) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) t.Run("GET req", func(t *testing.T) { response = rt.SendAdminRequest("GET", "/db/_function/multiply?first=1&second=2&third=3&fourth=4", "") - assert.Equal(t, 413, response.Result().StatusCode) + rest.AssertStatus(t, response, http.StatusRequestEntityTooLarge) assert.Contains(t, string(response.BodyBytes()), "Arguments too large") }) t.Run("POST req", func(t *testing.T) { response = rt.SendAdminRequest("POST", "/db/_function/multiply", `{"first":1,"second":2,"third":3,"fourth":4}`) - assert.Equal(t, 413, response.Result().StatusCode) + rest.AssertStatus(t, response, http.StatusRequestEntityTooLarge) assert.Contains(t, string(response.BodyBytes()), "Arguments too large") }) }) @@ -394,7 +397,7 @@ func TestSaveAndGet(t *testing.T) { t.Run("Save The Functions", func(t *testing.T) { response := rt.SendAdminRequest("PUT", "/db/_config/functions", string(request)) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) // Get The Function Definition and match with the one posted @@ -423,7 +426,7 @@ func TestSaveAndGet(t *testing.T) { // Check For Non-Existent Function response := rt.SendAdminRequest("GET", fmt.Sprintf("/db/_config/functions/%s", "nonExistent"), "") assert.NotNil(t, response) - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) // GET: Run a Function and check the value @@ -467,7 +470,7 @@ func TestSaveAndGet(t *testing.T) { t.Run("Test For Able to Run Function for Non-Admin Users", func(t *testing.T) { response := rt.SendAdminRequest("POST", "/db/_user/", `{"name":"ritik","email":"ritik.raj@couchbase.com", "password":"letmein", "admin_channels":["*"]}`) - assert.Equal(t, 201, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusCreated) response = rt.SendUserRequestWithHeaders("GET", fmt.Sprintf("/db/_function/%s?n=4", "square"), "", nil, "ritik", "letmein") assert.NotNil(t, response) @@ -500,7 +503,7 @@ func TestSaveAndUpdateAndGet(t *testing.T) { // Save The Function t.Run("Save The Functions", func(t *testing.T) { response := rt.SendAdminRequest("PUT", "/db/_config/functions", string(request)) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) // Get The Function Definition and match with the one posted @@ -524,7 +527,7 @@ func TestSaveAndUpdateAndGet(t *testing.T) { assert.NoError(t, err) response := rt.SendAdminRequest("PUT", fmt.Sprintf("/db/_config/functions/%s", functionName), string(requestBody)) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) functionName = "squareN1QL" @@ -534,7 +537,7 @@ func TestSaveAndUpdateAndGet(t *testing.T) { assert.NoError(t, err) response = rt.SendAdminRequest("PUT", fmt.Sprintf("/db/_config/functions/%s", functionName), string(requestBody)) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) // Get the Updated Function @@ -583,7 +586,7 @@ func TestSaveAndDeleteAndGet(t *testing.T) { t.Run("Save The Functions", func(t *testing.T) { response := rt.SendAdminRequest("PUT", "/db/_config/functions", string(request)) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) // Get The Function Definition and match with the one posted @@ -602,7 +605,7 @@ func TestSaveAndDeleteAndGet(t *testing.T) { t.Run("Delete A Specific Function", func(t *testing.T) { response := rt.SendAdminRequest("DELETE", fmt.Sprintf("/db/_config/functions/%s", functionNameToBeDeleted), "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) t.Run("Get remaining functions and check schema", func(t *testing.T) { @@ -630,13 +633,12 @@ func TestSaveAndDeleteAndGet(t *testing.T) { // Delete All functions t.Run("Delete all functions", func(t *testing.T) { response := rt.SendAdminRequest("DELETE", "/db/_config/functions", "") - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) t.Run("Get All Non-existing Functions And Check HTTP Status", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_config/functions", "") - assert.Equal(t, 404, response.Result().StatusCode) - + rest.RequireStatus(t, response, http.StatusNotFound) }) } func TestDeleteNonExisting(t *testing.T) { @@ -650,11 +652,11 @@ func TestDeleteNonExisting(t *testing.T) { // NEGATIVE CASES t.Run("Delete All Non-existing functions and check HTTP Status Code", func(t *testing.T) { response := rt.SendAdminRequest("DELETE", "/db/_config/functions", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) t.Run("Delete a non-existing function and check HTTP Status Code", func(t *testing.T) { response := rt.SendAdminRequest("DELETE", "/db/_config/functions/foo", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } diff --git a/rest/functionsapitest/user_functions_queries_test.go b/rest/functionsapitest/user_functions_queries_test.go index 21084e0761..9d951e6db9 100644 --- a/rest/functionsapitest/user_functions_queries_test.go +++ b/rest/functionsapitest/user_functions_queries_test.go @@ -20,6 +20,7 @@ import ( "github.com/couchbase/sync_gateway/db/functions" "github.com/couchbase/sync_gateway/rest" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) //////// FUNCTIONS EXECUTION API TESTS @@ -72,63 +73,62 @@ func TestUserFunctions(t *testing.T) { } func TestJSFunctionAsGuest(t *testing.T) { - rt := rest.NewRestTester(t, &rest.RestTesterConfig{GuestEnabled: true, EnableUserQueries: true}) - if rt == nil { - return - } - defer rt.Close() - - rt.DatabaseConfig = &rest.DatabaseConfig{ - DbConfig: rest.DbConfig{ - UserFunctions: &kUserFunctionAuthTestConfig, + rt := rest.NewRestTester(t, &rest.RestTesterConfig{ + GuestEnabled: true, + EnableUserQueries: true, + DatabaseConfig: &rest.DatabaseConfig{ + DbConfig: rest.DbConfig{ + UserFunctions: &kUserFunctionAuthTestConfig, + }, }, - } + }) + defer rt.Close() sendReqFn := rt.SendRequest t.Run("function not configured", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/xxxx", "") - assert.Equal(t, 401, response.Result().StatusCode) - assert.Contains(t, string(response.BodyBytes()), "login required") + rest.AssertStatus(t, response, http.StatusUnauthorized) + assert.Contains(t, response.BodyString(), "login required") }) t.Run("allow all", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/allow_all", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.EqualValues(t, `"OK"`, string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.EqualValues(t, `"OK"`, response.BodyString()) }) t.Run("user required", func(t *testing.T) { t.Skip("Does not work with SG_TEST_USE_DEFAULT_COLLECTION=true CBG-2702") response := sendReqFn("POST", "/db/_function/square", `{"numero": 42}`) - assert.Equal(t, 401, response.Result().StatusCode) - assert.Contains(t, string(response.BodyBytes()), "login required") + rest.AssertStatus(t, response, http.StatusUnauthorized) + assert.Contains(t, "login required", response.BodyString()) }) t.Run("admin-only", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/admin_only", "") - assert.Equal(t, 401, response.Result().StatusCode) - assert.Contains(t, string(response.BodyBytes()), "login required") + rest.AssertStatus(t, response, http.StatusUnauthorized) + assert.Contains(t, response.BodyString(), "login required") }) } func testUserFunctionsCommon(t *testing.T, rt *rest.RestTester, sendReqFn func(string, string, string) *rest.TestResponse) { t.Run("commons/passing a param", func(t *testing.T) { response := sendReqFn("POST", "/db/_function/square", `{"numero": 42}`) - assert.Equal(t, 200, response.Result().StatusCode) - assert.EqualValues(t, "1764", string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.EqualValues(t, "1764", response.BodyString()) }) t.Run("commons/passing a param through query params", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/square?numero=42", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.EqualValues(t, "1764", string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.EqualValues(t, "1764", response.BodyString()) }) t.Run("commons/allow all", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/allow_all", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.EqualValues(t, `"OK"`, string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.EqualValues(t, `"OK"`, response.BodyString()) }) } @@ -137,14 +137,14 @@ func testUserFunctionsAsAdmin(t *testing.T, rt *rest.RestTester) { t.Run("Admin-only", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_function/admin_only", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.EqualValues(t, "\"OK\"", string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.EqualValues(t, "\"OK\"", response.BodyString()) }) // negative cases: t.Run("function not configured", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_function/xxxx", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } @@ -158,12 +158,12 @@ func testUserFunctionsAsUser(t *testing.T, rt *rest.RestTester) { // negative cases t.Run("function not configured", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/xxxx", "") - assert.Equal(t, 403, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusForbidden) }) t.Run("admin-only", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/admin_only", "") - assert.Equal(t, 403, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusForbidden) }) } @@ -200,21 +200,16 @@ func TestUserN1QLQueries(t *testing.T) { request, err := json.Marshal(kUserN1QLFunctionsAuthTestConfig) assert.NoError(t, err) response := rt.SendAdminRequest("PUT", "/db/_config/functions", string(request)) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) t.Run("AsAdmin", func(t *testing.T) { testUserQueriesAsAdmin(t, rt) }) t.Run("AsUser", func(t *testing.T) { testUserQueriesAsUser(t, rt) }) } func TestN1QLFunctionAsGuest(t *testing.T) { - if base.UnitTestUrlIsWalrus() { - t.Skip("This test requires persistent configs") - } + TestRequireN1QLSupport(t) rt := rest.NewRestTester(t, &rest.RestTesterConfig{GuestEnabled: true, EnableUserQueries: true}) - if rt == nil { - return - } defer rt.Close() rt.DatabaseConfig = &rest.DatabaseConfig{ @@ -227,13 +222,9 @@ func TestN1QLFunctionAsGuest(t *testing.T) { t.Run("select user", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/user", "") - if !assert.Equal(t, 200, response.Result().StatusCode) { - return - } + rest.RequireStatus(t, response, http.StatusOK) var body []map[string]any - if !assert.NoError(t, json.Unmarshal(response.BodyBytes(), &body)) { - return - } + require.NoError(t, json.Unmarshal(response.BodyBytes(), &body)) user, ok := body[0]["user"].(map[string]any) assert.True(t, ok, "Result 'user' property is missing or not an object") assert.Equal(t, "", user["name"]) @@ -242,20 +233,20 @@ func TestN1QLFunctionAsGuest(t *testing.T) { t.Run("user required", func(t *testing.T) { t.Skip("Does not work with SG_TEST_USE_DEFAULT_COLLECTION=true CBG-2702") response := sendReqFn("POST", "/db/_function/square", `{"numero": 16}`) - assert.Equal(t, 401, response.Result().StatusCode) - assert.Contains(t, string(response.BodyBytes()), "login required") + rest.RequireStatus(t, response, http.StatusUnauthorized) + assert.Contains(t, response.BodyString(), "login required") }) t.Run("admin only", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/admin_only", "") - assert.Equal(t, 401, response.Result().StatusCode) - assert.Contains(t, string(response.BodyBytes()), "login required") + rest.RequireStatus(t, response, http.StatusUnauthorized) + assert.Contains(t, response.BodyString(), "login required") }) t.Run("unconfigured query", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/xxxx", "") - assert.Equal(t, 401, response.Result().StatusCode) - assert.Contains(t, string(response.BodyBytes()), "login required") + rest.RequireStatus(t, response, http.StatusUnauthorized) + assert.Contains(t, response.BodyString(), "login required") }) } @@ -263,8 +254,8 @@ func testUserQueriesCommon(t *testing.T, rt *rest.RestTester, sendReqFn func(str // positive cases: t.Run("commons/passing a param", func(t *testing.T) { response := sendReqFn("POST", "/db/_function/square", `{"numero": 16}`) - assert.Equal(t, 200, response.Result().StatusCode) - assert.EqualValues(t, "[{\"square\":256}\n]\n", string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.EqualValues(t, "[{\"square\":256}\n]\n", response.BodyString()) }) } @@ -274,20 +265,20 @@ func testUserQueriesAsAdmin(t *testing.T, rt *rest.RestTester) { // positive cases: t.Run("select user", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_function/user", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.EqualValues(t, "[{\"user\":{}}\n]\n", string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.EqualValues(t, "[{\"user\":{}}\n]\n", response.BodyString()) }) t.Run("admin only", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_function/admin_only", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.EqualValues(t, "[{\"status\":\"ok\"}\n]\n", string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.EqualValues(t, "[{\"status\":\"ok\"}\n]\n", response.BodyString()) }) //negative cases: t.Run("unconfigured query", func(t *testing.T) { response := rt.SendAdminRequest("GET", "/db/_function/xxxx", "") - assert.Equal(t, 404, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusNotFound) }) } @@ -301,20 +292,20 @@ func testUserQueriesAsUser(t *testing.T, rt *rest.RestTester) { // positive cases: t.Run("select user", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/user", "") - assert.Equal(t, 200, response.Result().StatusCode) - assert.True(t, strings.HasPrefix(string(response.BodyBytes()), `[{"user":{"channels":["`)) - assert.True(t, strings.HasSuffix(string(response.BodyBytes()), "\"],\"email\":\"\",\"name\":\"alice\"}}\n]\n")) + rest.AssertStatus(t, response, http.StatusOK) + assert.True(t, strings.HasPrefix(response.BodyString(), `[{"user":{"channels":["`)) + assert.True(t, strings.HasSuffix(response.BodyString(), "\"],\"email\":\"\",\"name\":\"alice\"}}\n]\n")) }) //negative cases: t.Run("admin only", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/admin_only", "") - assert.Equal(t, 403, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusForbidden) }) t.Run("unconfigured query", func(t *testing.T) { response := sendReqFn("GET", "/db/_function/xxxx", "") - assert.Equal(t, 403, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusForbidden) }) } @@ -377,28 +368,28 @@ func TestFunctionMutability(t *testing.T) { var callerFuncName string response := rt.SendAdminRequest("PUT", "/db/_config/functions", string(request)) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) //Negative Cases t.Run("Func with mutating True calls another function with a mutating value of False", func(t *testing.T) { putFuncName = "putDocMutabilityFalse" callerFuncName = "callerMutabilityTrue" response := rt.SendAdminRequest("POST", fmt.Sprintf("/db/_function/%s", callerFuncName), fmt.Sprintf(body, putFuncName)) - assert.Equal(t, http.StatusForbidden, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusForbidden) }) t.Run("Func with mutating False calls another function with a mutating value of True", func(t *testing.T) { putFuncName = "putDocMutabilityTrue" callerFuncName = "callerMutabilityFalse" response := rt.SendAdminRequest("POST", fmt.Sprintf("/db/_function/%s", callerFuncName), fmt.Sprintf(body, putFuncName)) - assert.Equal(t, http.StatusForbidden, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusForbidden) }) t.Run("Func with mutating False calls another function with a mutating value of False", func(t *testing.T) { putFuncName = "putDocMutabilityFalse" callerFuncName = "callerMutabilityFalse" response := rt.SendAdminRequest("POST", fmt.Sprintf("/db/_function/%s", callerFuncName), fmt.Sprintf(body, putFuncName)) - assert.Equal(t, http.StatusForbidden, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusForbidden) }) //Mutability of the function being called is false. Will fail as once you’ve lost the ability to mutate, you can’t get it back. @@ -406,7 +397,7 @@ func TestFunctionMutability(t *testing.T) { putFuncName = "putDocMutabilityFalse" callerFuncName = "callerOverride" response := rt.SendAdminRequest("POST", fmt.Sprintf("/db/_function/%s", callerFuncName), fmt.Sprintf(body, putFuncName)) - assert.Equal(t, http.StatusForbidden, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusForbidden) }) //Positive Cases @@ -414,8 +405,8 @@ func TestFunctionMutability(t *testing.T) { putFuncName = "putDocMutabilityTrue" callerFuncName = "callerMutabilityTrue" response := rt.SendAdminRequest("POST", fmt.Sprintf("/db/_function/%s", callerFuncName), fmt.Sprintf(body, putFuncName)) - assert.Equal(t, http.StatusOK, response.Result().StatusCode) - assert.EqualValues(t, "\"Test123\"", string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.EqualValues(t, "\"Test123\"", response.BodyString()) }) // using context.admin privilege overides its own mutatibility flag, it acts as though the REST API were called by an administrator. @@ -423,8 +414,8 @@ func TestFunctionMutability(t *testing.T) { putFuncName = "putDocMutabilityTrue" callerFuncName = "callerOverride" response := rt.SendAdminRequest("POST", fmt.Sprintf("/db/_function/%s", callerFuncName), fmt.Sprintf(body, putFuncName)) - assert.Equal(t, http.StatusOK, response.Result().StatusCode) - assert.EqualValues(t, "\"Test123\"", string(response.BodyBytes())) + rest.AssertStatus(t, response, http.StatusOK) + assert.EqualValues(t, "\"Test123\"", response.BodyString()) }) } @@ -459,13 +450,19 @@ func TestFunctionTimeout(t *testing.T) { t.Run("under time limit", func(t *testing.T) { reqBody := `{"ms": 1}` response := rt.SendAdminRequest("POST", "/db/_function/sleep", reqBody) - assert.Equal(t, 200, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusOK) }) // negative case: t.Run("over time limit", func(t *testing.T) { reqBody := fmt.Sprintf(`{"ms": %d}`, 2*timeout) response := rt.SendAdminRequest("POST", "/db/_function/sleep", reqBody) - assert.Equal(t, 500, response.Result().StatusCode) + rest.RequireStatus(t, response, http.StatusInternalServerError) }) } + +func TestRequireN1QLSupport(t *testing.T) { + if base.TestsDisableGSI() { + t.Skip("This test requires Couchbase Server backed N1QL") + } +} diff --git a/rest/importtest/import_test.go b/rest/importtest/import_test.go index 141fd59a5a..ddfe62333a 100644 --- a/rest/importtest/import_test.go +++ b/rest/importtest/import_test.go @@ -1958,7 +1958,6 @@ func TestDcpBackfill(t *testing.T) { newRt := rest.NewRestTester(t, &newRtConfig) defer newRt.Close() log.Printf("Poke the rest tester so it starts DCP processing:") - dataStore = newRt.GetSingleDataStore() backfillComplete := false var expectedBackfill, completedBackfill int diff --git a/rest/oidc_api_test.go b/rest/oidc_api_test.go index 58885176d8..ef4311f8e6 100644 --- a/rest/oidc_api_test.go +++ b/rest/oidc_api_test.go @@ -860,6 +860,7 @@ func TestOpenIDConnectAuthCodeFlow(t *testing.T) { client := &http.Client{Jar: jar} response, err := client.Do(request) require.NoError(t, err, "Error sending request") + defer func() { assert.NoError(t, response.Body.Close()) }() if (forceError{}) != tc.forceAuthError { assertHttpResponse(t, response, tc.forceAuthError) return @@ -868,7 +869,6 @@ func TestOpenIDConnectAuthCodeFlow(t *testing.T) { require.Equal(t, http.StatusOK, response.StatusCode) var authResponseActual OIDCTokenResponse require.NoError(t, err, json.NewDecoder(response.Body).Decode(&authResponseActual)) - require.NoError(t, response.Body.Close(), "Error closing response body") assert.NotEmpty(t, authResponseActual.SessionID, "session_id doesn't exist") assert.Equal(t, "foo_noah", authResponseActual.Username, "name mismatch") @@ -889,9 +889,9 @@ func TestOpenIDConnectAuthCodeFlow(t *testing.T) { request.Header.Add("Authorization", BearerToken+" "+authResponseActual.IDToken) response, err = client.Do(request) require.NoError(t, err, "Error sending request with bearer token") + defer func() { assert.NoError(t, response.Body.Close()) }() require.Equal(t, http.StatusOK, response.StatusCode) require.NoError(t, json.NewDecoder(response.Body).Decode(&responseBody)) - require.NoError(t, response.Body.Close(), "Error closing response body") assert.Equal(t, restTester.DatabaseConfig.Name, responseBody["db_name"]) // Refresh auth token using the refresh token received from OP. @@ -902,6 +902,7 @@ func TestOpenIDConnectAuthCodeFlow(t *testing.T) { require.NoError(t, err, "Error creating new request") response, err = client.Do(request) require.NoError(t, err, "Error sending request") + defer func() { assert.NoError(t, response.Body.Close()) }() if (forceError{}) != tc.forceRefreshError { assertHttpResponse(t, response, tc.forceRefreshError) return @@ -911,7 +912,6 @@ func TestOpenIDConnectAuthCodeFlow(t *testing.T) { // Validate received token refresh response. var refreshResponseActual OIDCTokenResponse require.NoError(t, err, json.NewDecoder(response.Body).Decode(&refreshResponseActual)) - require.NoError(t, response.Body.Close(), "Error closing response body") refreshResponseExpected := mockAuthServer.options.tokenResponse assert.NotEmpty(t, refreshResponseActual.SessionID, "session_id doesn't exist") assert.Equal(t, "foo_noah", refreshResponseActual.Username, "name mismatch") @@ -927,9 +927,9 @@ func TestOpenIDConnectAuthCodeFlow(t *testing.T) { request.Header.Add("Authorization", BearerToken+" "+refreshResponseActual.IDToken) response, err = client.Do(request) require.NoError(t, err, "Error sending request with bearer token") + defer func() { assert.NoError(t, response.Body.Close()) }() require.Equal(t, http.StatusOK, response.StatusCode) require.NoError(t, json.NewDecoder(response.Body).Decode(&responseBody)) - require.NoError(t, response.Body.Close(), "Error closing response body") assert.Equal(t, restTester.DatabaseConfig.Name, responseBody["db_name"]) // Make a keyspace-scoped request @@ -938,8 +938,8 @@ func TestOpenIDConnectAuthCodeFlow(t *testing.T) { request.Header.Add("Authorization", BearerToken+" "+refreshResponseActual.IDToken) response, err = client.Do(request) require.NoError(t, err, "Error sending request with bearer token") + defer func() { assert.NoError(t, response.Body.Close()) }() require.Equal(t, http.StatusCreated, response.StatusCode) - require.NoError(t, response.Body.Close(), "Error closing response body") }) } } @@ -1072,6 +1072,7 @@ func TestOpenIDConnectImplicitFlow(t *testing.T) { request := createOIDCRequest(t, sessionEndpoint, token) response, err := http.DefaultClient.Do(request) require.NoError(t, err, "Error sending request with bearer token") + defer func() { assert.NoError(t, response.Body.Close()) }() if (forceError{}) != tc.expectedError { assertHttpResponse(t, response, tc.expectedError) @@ -1290,6 +1291,7 @@ func TestOpenIDConnectImplicitFlowEdgeCases(t *testing.T) { runBadAuthTest := func(claimSet claimSet) { response, err := sendAuthRequest(claimSet) + defer func() { assert.NoError(t, response.Body.Close()) }() require.NoError(t, err, "Error sending request with bearer token") expectedAuthError := forceError{ expectedErrorCode: http.StatusUnauthorized, @@ -1300,6 +1302,7 @@ func TestOpenIDConnectImplicitFlowEdgeCases(t *testing.T) { runGoodAuthTest := func(claimSet claimSet, username string) { response, err := sendAuthRequest(claimSet) + defer func() { assert.NoError(t, response.Body.Close()) }() require.NoError(t, err, "Error sending request with bearer token") checkGoodAuthResponse(t, restTester, response, username) } @@ -1930,6 +1933,7 @@ func TestCallbackStateClientCookies(t *testing.T) { t.Run("unsuccessful auth when callback state enabled with no cookies support from client", func(t *testing.T) { response, err := http.DefaultClient.Do(request) require.NoError(t, err, "Error sending request") + defer func() { assert.NoError(t, response.Body.Close()) }() expectedAuthError := forceError{ expectedErrorCode: http.StatusBadRequest, expectedErrorMessage: ErrNoStateCookie.Message, @@ -1943,10 +1947,10 @@ func TestCallbackStateClientCookies(t *testing.T) { client := &http.Client{Jar: jar} response, err := client.Do(request) require.NoError(t, err, "Error sending request") + defer func() { assert.NoError(t, response.Body.Close()) }() require.Equal(t, http.StatusOK, response.StatusCode) var authResponseActual OIDCTokenResponse require.NoError(t, err, json.NewDecoder(response.Body).Decode(&authResponseActual)) - require.NoError(t, response.Body.Close(), "Error closing response body") assert.NotEmpty(t, authResponseActual.SessionID, "session_id doesn't exist") assert.Equal(t, "foo_noah", authResponseActual.Username, "name mismatch") }) @@ -1955,11 +1959,11 @@ func TestCallbackStateClientCookies(t *testing.T) { restTester.DatabaseConfig.OIDCConfig.Providers.GetDefaultProvider().DisableCallbackState = true response, err := http.DefaultClient.Do(request) require.NoError(t, err, "Error sending request") + defer func() { assert.NoError(t, response.Body.Close()) }() require.Equal(t, http.StatusOK, response.StatusCode) var authResponseActual OIDCTokenResponse require.NoError(t, err, json.NewDecoder(response.Body).Decode(&authResponseActual)) - require.NoError(t, response.Body.Close(), "Error closing response body") assert.NotEmpty(t, authResponseActual.SessionID, "session_id doesn't exist") assert.Equal(t, "foo_noah", authResponseActual.Username, "name mismatch") }) @@ -2187,6 +2191,7 @@ func TestOpenIDConnectAuthCodeFlowWithUsernameClaim(t *testing.T) { client := &http.Client{Jar: jar} response, err := client.Do(request) require.NoError(t, err, "Error sending request") + defer func() { assert.NoError(t, response.Body.Close()) }() if (forceError{}) != tc.authErrorExpected { assertHttpResponse(t, response, tc.authErrorExpected) return @@ -2195,7 +2200,6 @@ func TestOpenIDConnectAuthCodeFlowWithUsernameClaim(t *testing.T) { require.Equal(t, http.StatusOK, response.StatusCode) var authResponseActual OIDCTokenResponse require.NoError(t, err, json.NewDecoder(response.Body).Decode(&authResponseActual)) - require.NoError(t, response.Body.Close(), "Error closing response body") assert.NotEmpty(t, authResponseActual.SessionID, "session_id doesn't exist") expectedUsername := tc.usernameExpected if strings.Contains(expectedUsername, "$issuer") { @@ -2215,9 +2219,9 @@ func TestOpenIDConnectAuthCodeFlowWithUsernameClaim(t *testing.T) { request.Header.Add("Authorization", BearerToken+" "+authResponseActual.IDToken) response, err = client.Do(request) require.NoError(t, err, "Error sending request with bearer token") + defer func() { assert.NoError(t, response.Body.Close()) }() require.Equal(t, http.StatusOK, response.StatusCode) require.NoError(t, json.NewDecoder(response.Body).Decode(&responseBody)) - require.NoError(t, response.Body.Close(), "Error closing response body") assert.Equal(t, restTester.DatabaseConfig.Name, responseBody["db_name"]) }) } @@ -2287,6 +2291,7 @@ func TestEventuallyReachableOIDCClient(t *testing.T) { request := createOIDCRequest(t, sessionEndpoint, token) response, err := http.DefaultClient.Do(request) require.NoError(t, err, "Error sending request with bearer token") + defer func() { assert.NoError(t, response.Body.Close()) }() assert.Equal(t, http.StatusUnauthorized, response.StatusCode) // Status code when unreachable // Now reachable - success @@ -2294,6 +2299,7 @@ func TestEventuallyReachableOIDCClient(t *testing.T) { request = createOIDCRequest(t, sessionEndpoint, token) response, err = http.DefaultClient.Do(request) require.NoError(t, err, "Error sending request with bearer token") + defer func() { assert.NoError(t, response.Body.Close()) }() checkGoodAuthResponse(t, restTester, response, "foo_noah") // Unreachable again after being reachable - still success @@ -2301,6 +2307,7 @@ func TestEventuallyReachableOIDCClient(t *testing.T) { request = createOIDCRequest(t, sessionEndpoint, token) response, err = http.DefaultClient.Do(request) require.NoError(t, err, "Error sending request with bearer token") + defer func() { assert.NoError(t, response.Body.Close()) }() checkGoodAuthResponse(t, restTester, response, "foo_noah") }) } @@ -2415,11 +2422,11 @@ func TestOpenIDConnectRolesChannelsClaims(t *testing.T) { client := &http.Client{Jar: jar} response, err := client.Do(request) require.NoError(t, err, "Error sending request") + defer func() { assert.NoError(t, response.Body.Close()) }() // Validate received token response require.Equal(t, http.StatusOK, response.StatusCode) var authResponseActual OIDCTokenResponse require.NoError(t, err, json.NewDecoder(response.Body).Decode(&authResponseActual)) - require.NoError(t, response.Body.Close(), "Error closing response body") assert.NotEmpty(t, authResponseActual.SessionID, "session_id doesn't exist") authResponseExpected := mockAuthServer.options.tokenResponse diff --git a/rest/server_context.go b/rest/server_context.go index dfe80c45ee..dc1a112bb1 100644 --- a/rest/server_context.go +++ b/rest/server_context.go @@ -1979,7 +1979,7 @@ func doHTTPAuthRequest(ctx context.Context, httpClient *http.Client, username, p req.SetBasicAuth(username, password) - httpResponse, err = httpClient.Do(req) + httpResponse, err = httpClient.Do(req) // nolint:bodyclose // The body is closed outside of the worker loop if err == nil { return false, nil, httpResponse } @@ -1992,7 +1992,7 @@ func doHTTPAuthRequest(ctx context.Context, httpClient *http.Client, username, p return false, err, nil } - err, result := base.RetryLoop(ctx, "", worker, base.CreateSleeperFunc(10, 100)) + err, result := base.RetryLoop(ctx, "doHTTPAuthRequest", worker, base.CreateSleeperFunc(10, 100)) if err != nil { return 0, nil, err } diff --git a/rest/session_test.go b/rest/session_test.go index e5649b393f..7f54053686 100644 --- a/rest/session_test.go +++ b/rest/session_test.go @@ -280,7 +280,9 @@ func TestCustomCookieName(t *testing.T) { assert.Equal(t, 200, resp.Code) // Extract the cookie from the create session response to verify the "Set-Cookie" value returned by Sync Gateway - cookies := resp.Result().Cookies() + result := resp.Result() + defer func() { assert.NoError(t, result.Body.Close()) }() + cookies := result.Cookies() assert.True(t, len(cookies) == 1) cookie := cookies[0] assert.Equal(t, customCookieName, cookie.Name) @@ -290,13 +292,12 @@ func TestCustomCookieName(t *testing.T) { headers := map[string]string{} headers["Cookie"] = fmt.Sprintf("%s=%s", auth.DefaultCookieName, cookie.Value) resp = rt.SendRequestWithHeaders("GET", "/{{.keyspace}}/foo", `{}`, headers) - assert.Equal(t, 401, resp.Result().StatusCode) + RequireStatus(t, resp, http.StatusUnauthorized) // Attempt to use custom cookie name to authenticate headers["Cookie"] = fmt.Sprintf("%s=%s", customCookieName, cookie.Value) resp = rt.SendRequestWithHeaders("POST", "/{{.keyspace}}/", `{"_id": "foo", "key": "val"}`, headers) - assert.Equal(t, 200, resp.Result().StatusCode) - + RequireStatus(t, resp, http.StatusOK) } // Test that TTL values greater than the default max offset TTL 2592000 seconds are processed correctly diff --git a/rest/stats_context_test.go b/rest/stats_context_test.go index c09d00aa43..d28aa98657 100644 --- a/rest/stats_context_test.go +++ b/rest/stats_context_test.go @@ -52,6 +52,9 @@ func TestDescriptionPopulation(t *testing.T) { // Ensure metrics endpoint is accessible and that db database has entries resp, err := http.Get(srv.URL + "/_metrics") require.NoError(t, err) + defer func() { + assert.NoError(t, resp.Body.Close()) + }() assert.Equal(t, http.StatusOK, resp.StatusCode) bodyString, err := io.ReadAll(resp.Body) diff --git a/rest/utilities_testing.go b/rest/utilities_testing.go index e49349709c..e43c0c72ce 100644 --- a/rest/utilities_testing.go +++ b/rest/utilities_testing.go @@ -1093,11 +1093,18 @@ type TestResponse struct { // BodyBytes takes a copy of the bytes in the response buffer, and saves them for future callers. func (r TestResponse) BodyBytes() []byte { if r.bodyCache == nil { + // since we are reading the underlying write buffer here, we do not need to close. If call r.Result().Body, then this needs to be closed. r.bodyCache = r.ResponseRecorder.Body.Bytes() } return r.bodyCache } +// BodyString returns the string of the response body. This is cached once the first time it is called. +func (r TestResponse) BodyString() string { + return string(r.BodyBytes()) +} + +// DumpBody returns the byte array of the response body. This is cached once the first time it is called. func (r TestResponse) DumpBody() { log.Printf("%v", r.Body.String()) } diff --git a/rest/utilities_testing_functions_api_test.go b/rest/utilities_testing_functions_api_test.go index 76179e53ce..cb2e661a9a 100644 --- a/rest/utilities_testing_functions_api_test.go +++ b/rest/utilities_testing_functions_api_test.go @@ -10,6 +10,7 @@ package rest import ( "log" + "net/http" "sync" "testing" "time" @@ -100,8 +101,7 @@ func TestFunctions(t *testing.T) { response := rt.SendRequest("POST", "/{{.db}}/_graphql", `{"query": "query($number:Int!){ square(n:$number) }", "variables": {"number": 13}}`) - return assert.Equal(t, 200, response.Result().StatusCode) && - assert.Equal(t, "{\"data\":{\"square\":169}}", string(response.BodyBytes())) + return AssertStatus(t, response, http.StatusOK) && assert.Equal(t, "{\"data\":{\"square\":169}}", response.BodyString()) }) }) } @@ -113,15 +113,14 @@ func TestFunctionsConcurrently(t *testing.T) { t.Run("Function", func(t *testing.T) { testConcurrently(t, rt, func() bool { response := rt.SendRequest("GET", "/{{.db}}/_function/square?n=13", "") - return assert.Equal(t, 200, response.Result().StatusCode) && - assert.Equal(t, "169", string(response.BodyBytes())) + return AssertStatus(t, response, http.StatusOK) && assert.Equal(t, "169", response.BodyString()) }) }) t.Run("GraphQL", func(t *testing.T) { testConcurrently(t, rt, func() bool { response := rt.SendRequest("POST", "/{{.db}}/_graphql", `{"query":"query{ square(n:13) }"}`) - return assert.Equal(t, 200, response.Result().StatusCode) && + return AssertStatus(t, response, http.StatusOK) && assert.Equal(t, "{\"data\":{\"square\":169}}", string(response.BodyBytes())) }) }) @@ -132,7 +131,7 @@ func TestFunctionsConcurrently(t *testing.T) { } else { testConcurrently(t, rt, func() bool { response := rt.SendRequest("GET", "/{{.db}}/_function/squareN1QL?n=13", "") - return assert.Equal(t, 200, response.Result().StatusCode) && + return AssertStatus(t, response, http.StatusOK) && assert.Equal(t, "[{\"square\":169}\n]\n", string(response.BodyBytes())) }) } diff --git a/rest/utilities_testing_test.go b/rest/utilities_testing_test.go index dec063f872..55807de406 100644 --- a/rest/utilities_testing_test.go +++ b/rest/utilities_testing_test.go @@ -150,6 +150,7 @@ func TestCECheck(t *testing.T) { resp, err := http.DefaultClient.Do(req) require.NoError(t, err) + require.NoError(t, resp.Body.Close()) require.Equal(t, resp.StatusCode, http.StatusBadRequest) } diff --git a/rest/utilities_testing_user.go b/rest/utilities_testing_user.go index f880ed1b0f..5dcbad7144 100644 --- a/rest/utilities_testing_user.go +++ b/rest/utilities_testing_user.go @@ -35,22 +35,20 @@ func MakeUser(t *testing.T, httpClient *http.Client, serverURL, username, passwo resp, err := httpClient.Do(req) if err != nil { - return true, err, resp + return true, err, nil } - return false, err, resp + defer func() { assert.NoError(t, resp.Body.Close()) }() + if resp.StatusCode != http.StatusOK { + bodyResp, err := io.ReadAll(resp.Body) + assert.NoError(t, err, "Failed to create user: %s", bodyResp) + } + require.Equal(t, http.StatusOK, resp.StatusCode) + return false, err, nil } - err, resp := base.RetryLoop(base.TestCtx(t), "Admin Auth testing MakeUser", retryWorker, base.CreateSleeperFunc(10, 100)) + err, _ := base.RetryLoop(base.TestCtx(t), "Admin Auth testing MakeUser", retryWorker, base.CreateSleeperFunc(10, 100)) require.NoError(t, err) - if resp.(*http.Response).StatusCode != http.StatusOK { - bodyResp, err := io.ReadAll(resp.(*http.Response).Body) - assert.NoError(t, err) - fmt.Println(string(bodyResp)) - } - require.Equal(t, http.StatusOK, resp.(*http.Response).StatusCode) - - require.NoError(t, resp.(*http.Response).Body.Close(), "Error closing response body") } func DeleteUser(t *testing.T, httpClient *http.Client, serverURL, username string) { @@ -64,6 +62,7 @@ func DeleteUser(t *testing.T, httpClient *http.Client, serverURL, username strin if err != nil { return true, err, resp } + assert.NoError(t, resp.Body.Close()) return false, err, resp }