Skip to content

Commit

Permalink
implement body.Close checker (#6646)
Browse files Browse the repository at this point in the history
* implement body.Close

* Add some missing asserts, simplify MakeUser

* add missing error check
  • Loading branch information
torcolvin authored Jan 19, 2024
1 parent 87cca79 commit d00d704
Show file tree
Hide file tree
Showing 19 changed files with 261 additions and 245 deletions.
3 changes: 1 addition & 2 deletions .golangci-strict.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ run:

linters:
enable:
#- bodyclose # checks whether HTTP response body is closed successfully
- bodyclose # checks whether HTTP response body is closed successfully
#- dupl # Tool for code clone detection
- errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
#- goconst # Finds repeated strings that could be replaced by a constant
Expand Down Expand Up @@ -59,7 +59,6 @@ linters:
- whitespace # Tool for detection of leading and trailing whitespace
- wsl # Whitespace Linter - Forces you to use empty lines!
# Once fixed, should enable
- bodyclose # checks whether HTTP response body is closed successfully
- deadcode # Finds unused code
- dupl # Tool for code clone detection
- goconst # Finds repeated strings that could be replaced by a constant
Expand Down
7 changes: 0 additions & 7 deletions base/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -555,13 +555,6 @@ func retrievePurgeInterval(ctx context.Context, bucket CouchbaseBucketStore, uri
return time.Duration(purgeIntervalHours) * time.Hour, nil
}

func ensureBodyClosed(ctx context.Context, body io.ReadCloser) {
err := body.Close()
if err != nil {
DebugfCtx(ctx, KeyBucket, "Failed to close socket: %v", err)
}
}

// AsViewStore returns a ViewStore if the underlying dataStore implements ViewStore.
func AsViewStore(ds DataStore) (sgbucket.ViewStore, bool) {
viewStore, ok := ds.(sgbucket.ViewStore)
Expand Down
7 changes: 6 additions & 1 deletion base/bucket_gocb.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ func putDDocForTombstones(ctx context.Context, name string, payload []byte, capi
return err
}

defer ensureBodyClosed(ctx, resp.Body)
defer func() {
err := resp.Body.Close()
if err != nil {
DebugfCtx(ctx, KeyBucket, "Failed to close socket: %v", err)
}
}()
if resp.StatusCode != 201 {
data, err := io.ReadAll(resp.Body)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions rest/adminapitest/admin_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3865,6 +3865,7 @@ func setServerPurgeInterval(t *testing.T, rt *rest.RestTester, newPurgeInterval

resp, err := httpClient.Do(req)
require.NoError(t, err)
assert.NoError(t, resp.Body.Close())

require.Equal(t, resp.StatusCode, http.StatusOK)
}
Expand Down
1 change: 1 addition & 0 deletions rest/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func TestPublicRESTStatCount(t *testing.T) {
// test metrics endpoint
response, err := http.Get(srv.URL + "/_metrics")
require.NoError(t, err)
require.NoError(t, response.Body.Close())
assert.Equal(t, http.StatusOK, response.StatusCode)
// assert the stat doesn't increment
base.RequireWaitForStat(t, func() int64 {
Expand Down
1 change: 1 addition & 0 deletions rest/bytes_read_public_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func TestBytesReadDocOperations(t *testing.T) {
response, err := http.Get(srv.URL + "/_metrics")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, response.StatusCode)
require.NoError(t, response.Body.Close())

base.RequireWaitForStat(t, func() int64 {
return rt.GetDatabase().DbStats.DatabaseStats.PublicRestBytesRead.Value()
Expand Down
101 changes: 51 additions & 50 deletions rest/functionsapitest/graphql_admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package functionsapitest
import (
"encoding/json"
"fmt"
"net/http"
"os"
"testing"

Expand Down Expand Up @@ -178,56 +179,56 @@ func TestFunctionsConfigGetWithoutFeatureFlagGraphQL(t *testing.T) {

t.Run("GraphQL, Non-Admin", func(t *testing.T) {
response := rt.SendRequest("GET", "/db/_config/graphql", "")
assert.Equal(t, 404, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusNotFound)
})
t.Run("GraphQL", func(t *testing.T) {
response := rt.SendAdminRequest("GET", "/db/_config/graphql", "")
assert.Equal(t, 404, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusNotFound)
})
}

// This will be used both by functions and graphQL
func runTestFunctionsConfigMVCC(t *testing.T, rt *rest.RestTester, uri string, newValue string) {
// Get initial etag:
response := rt.SendAdminRequest("GET", uri, "")
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)
etag := response.HeaderMap.Get("Etag")
assert.Regexp(t, `"[^"]+"`, etag)

// Update config, just to change its etag:
response = rt.SendAdminRequest("PUT", uri, newValue)
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)
newEtag := response.HeaderMap.Get("Etag")
assert.Regexp(t, `"[^"]+"`, newEtag)
assert.NotEqual(t, etag, newEtag)

// A GET should also return the new etag:
response = rt.SendAdminRequest("GET", uri, "")
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)
assert.Equal(t, newEtag, response.HeaderMap.Get("Etag"))

// Try to update using If-Match with the old etag:
headers := map[string]string{"If-Match": etag}
response = rt.SendAdminRequestWithHeaders("PUT", uri, newValue, headers)
assert.Equal(t, 412, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusPreconditionFailed)

// Now update successfully using the current etag:
headers["If-Match"] = newEtag
response = rt.SendAdminRequestWithHeaders("PUT", uri, newValue, headers)
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)
newestEtag := response.HeaderMap.Get("Etag")
assert.Regexp(t, `"[^"]+"`, newestEtag)
assert.NotEqual(t, etag, newestEtag)
assert.NotEqual(t, newEtag, newestEtag)

// Try to delete using If-Match with the previous etag:
response = rt.SendAdminRequestWithHeaders("DELETE", uri, newValue, headers)
assert.Equal(t, 412, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusPreconditionFailed)

// Now delete successfully using the current etag:
headers["If-Match"] = newestEtag
response = rt.SendAdminRequestWithHeaders("DELETE", uri, newValue, headers)
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)
}

// Test use of "Etag" and "If-Match" headers to safely update graphql config.
Expand Down Expand Up @@ -259,11 +260,11 @@ func TestFunctionsConfigGraphQLGetEmpty(t *testing.T) {

t.Run("Non-Admin", func(t *testing.T) {
response := rt.SendRequest("GET", "/db/_config/graphql", "")
assert.Equal(t, 404, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusNotFound)
})
t.Run("All", func(t *testing.T) {
response := rt.SendAdminRequest("GET", "/db/_config/graphql", "")
assert.Equal(t, 404, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusNotFound)
})
}

Expand All @@ -281,7 +282,7 @@ func TestFunctionsConfigGraphQLGet(t *testing.T) {

t.Run("Non-Admin", func(t *testing.T) {
response := rt.SendRequest("GET", "/db/_config/graphql", "")
assert.Equal(t, 404, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusNotFound)
})
t.Run("All", func(t *testing.T) {
response := rt.SendAdminRequest("GET", "/db/_config/graphql", "")
Expand All @@ -304,15 +305,15 @@ func TestFunctionsConfigGraphQLPut(t *testing.T) {

t.Run("Non-Admin", func(t *testing.T) {
response := rt.SendRequest("PUT", "/db/_config/graphql", "{}")
assert.Equal(t, 404, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusNotFound)
response = rt.SendRequest("DELETE", "/db/_config/graphql", "{}")
assert.Equal(t, 404, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusNotFound)
})
t.Run("ReplaceBogus", func(t *testing.T) {
response := rt.SendAdminRequest("PUT", "/db/_config/graphql", `{
"schema": "obviously not a valid schema ^_^"
}`)
assert.Equal(t, 400, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusBadRequest)
})
t.Run("Replace", func(t *testing.T) {
response := rt.SendAdminRequest("PUT", "/db/_config/graphql", `{
Expand All @@ -326,20 +327,20 @@ func TestFunctionsConfigGraphQLPut(t *testing.T) {
}
}
}`)
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)

response = rt.SendAdminRequest("POST", "/db/_graphql", `{"query": "query{ sum(n:3) }"}`)
assert.Equal(t, 200, response.Result().StatusCode)
assert.Equal(t, `{"data":{"sum":6}}`, string(response.BodyBytes()))
rest.RequireStatus(t, response, http.StatusOK)
assert.JSONEq(t, `{"data":{"sum":6}}`, response.BodyString())
})
t.Run("Delete", func(t *testing.T) {
response := rt.SendAdminRequest("DELETE", "/db/_config/graphql", "")
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)

assert.Nil(t, rt.GetDatabase().Options.GraphQL)

response = rt.SendAdminRequest("POST", "/db/_graphql", `{"query": "query{ sum(n:3) }"}`)
assert.Equal(t, 503, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusServiceUnavailable)
})
}

Expand Down Expand Up @@ -371,11 +372,11 @@ func TestValidGraphQLConfigurationValues(t *testing.T) {
},
"max_schema_size" : %d
}`, len(schema)))
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)

response = rt.SendAdminRequest("GET", "/db/_config/graphql", "")
assert.Equal(t, 200, response.Result().StatusCode)
assert.Contains(t, string(response.BodyBytes()), `"max_schema_size":32`)
rest.RequireStatus(t, response, http.StatusOK)
assert.Contains(t, response.BodyString(), `"max_schema_size":32`)
})

//If max_resolver_count >= given number of resolvers then Valid
Expand All @@ -397,11 +398,11 @@ func TestValidGraphQLConfigurationValues(t *testing.T) {
},
"max_resolver_count" : 2
}`)
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)

response = rt.SendAdminRequest("GET", "/db/_config/graphql", "")
assert.Equal(t, 200, response.Result().StatusCode)
assert.Contains(t, string(response.BodyBytes()), `"max_resolver_count":2`)
rest.RequireStatus(t, response, http.StatusOK)
assert.Contains(t, response.BodyString(), `"max_resolver_count":2`)
})

//If max_request_size >= length of JSON-encoded arguments passed to a function then Valid
Expand All @@ -419,29 +420,29 @@ func TestValidGraphQLConfigurationValues(t *testing.T) {
},
"max_request_size" : %d
}`, len(requestQuery)))
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)

response = rt.SendAdminRequest("POST", "/db/_graphql", requestQuery)
assert.Equal(t, 200, response.Result().StatusCode)
assert.Equal(t, `{"data":{"square":16}}`, string(response.BodyBytes()))
rest.RequireStatus(t, response, http.StatusOK)
assert.Equal(t, `{"data":{"square":16}}`, response.BodyString())

response = rt.SendAdminRequest("GET", "/db/_config/graphql", "")
assert.Equal(t, 200, response.Result().StatusCode)
assert.Contains(t, string(response.BodyBytes()), fmt.Sprintf(`"max_request_size":%d`, len(requestQuery)))
rest.RequireStatus(t, response, http.StatusOK)
assert.Contains(t, response.BodyString(), fmt.Sprintf(`"max_request_size":%d`, len(requestQuery)))

headerMap := map[string]string{
"Content-type": "application/graphql",
}
response = rt.SendAdminRequestWithHeaders("POST", "/db/_graphql", `query{square(n:4)}`, headerMap)
assert.Equal(t, 200, response.Result().StatusCode)
assert.Equal(t, `{"data":{"square":16}}`, string(response.BodyBytes()))
rest.RequireStatus(t, response, http.StatusOK)
assert.Equal(t, response.BodyString(), `{"data":{"square":16}}`)

queryParam := `query($numberToBeSquared:Int!){ square(n:$numberToBeSquared) }`
variableParam := `{"numberToBeSquared": 4}`
getRequestUrl := fmt.Sprintf("/db/_graphql?query=%s&variables=%s", queryParam, variableParam)
response = rt.SendAdminRequest("GET", getRequestUrl, "")
assert.Equal(t, 200, response.Result().StatusCode)
assert.Equal(t, `{"data":{"square":16}}`, string(response.BodyBytes()))
rest.RequireStatus(t, response, http.StatusOK)
assert.Equal(t, response.BodyString(), `{"data":{"square":16}}`)

})

Expand All @@ -462,13 +463,13 @@ func TestValidGraphQLConfigurationValues(t *testing.T) {
}
}
}`)
assert.Equal(t, 200, response.Result().StatusCode)
rest.AssertStatus(t, response, http.StatusOK)
err = os.Remove("schema.graphql")
assert.NoError(t, err)

response = rt.SendAdminRequest("GET", "/db/_config/graphql", "")
assert.Equal(t, 200, response.Result().StatusCode)
assert.Contains(t, string(response.BodyBytes()), `"schemaFile":"schema.graphql"`)
rest.RequireStatus(t, response, http.StatusOK)
assert.Contains(t, response.BodyString(), `"schemaFile":"schema.graphql"`)
})
}

Expand Down Expand Up @@ -506,7 +507,7 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) {
err := json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap)
assert.NoError(t, err)

assert.Equal(t, 400, response.Result().StatusCode)
rest.AssertStatus(t, response, http.StatusBadRequest)
assert.Contains(t, responseMap["reason"], "GraphQL schema too large")
assert.Contains(t, responseMap["error"], "Bad Request")
})
Expand Down Expand Up @@ -535,7 +536,7 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) {
err := json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap)
assert.NoError(t, err)

assert.Equal(t, 400, response.Result().StatusCode)
rest.AssertStatus(t, response, http.StatusBadRequest)
assert.Contains(t, responseMap["reason"], "too many GraphQL resolvers")
assert.Contains(t, responseMap["error"], "Bad Request")

Expand All @@ -557,24 +558,24 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) {
},
"max_request_size" : 5
}`)
assert.Equal(t, 200, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusOK)

response = rt.SendAdminRequest("POST", "/db/_graphql", `{"query": "query($numberToBeSquared:Int!){ square(n:$numberToBeSquared) }", "variables": {"numberToBeSquared": 4}}`)

var responseMap map[string]interface{}
err := json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap)
assert.NoError(t, err)

assert.Equal(t, 413, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusRequestEntityTooLarge)
assert.Contains(t, responseMap["reason"], "Arguments too large")
assert.Contains(t, responseMap["error"], "Request Entity Too Large")

headerMap := map[string]string{
"Content-type": "application/graphql",
}
response = rt.SendAdminRequestWithHeaders("POST", "/db/_graphql", `query{square(n:4)}`, headerMap)
assert.Equal(t, 413, response.Result().StatusCode)
err = json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap)
rest.RequireStatus(t, response, http.StatusRequestEntityTooLarge)
err = json.Unmarshal(response.BodyBytes(), &responseMap)
assert.NoError(t, err)
assert.Contains(t, responseMap["reason"], "Arguments too large")
assert.Contains(t, responseMap["error"], "Request Entity Too Large")
Expand All @@ -584,8 +585,8 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) {
getRequestUrl := fmt.Sprintf("/db/_graphql?query=%s&variables=%s", queryParam, variableParam)

response = rt.SendAdminRequest("GET", getRequestUrl, "")
assert.Equal(t, 200, response.Result().StatusCode)
err = json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap)
rest.RequireStatus(t, response, http.StatusOK)
err = json.Unmarshal(response.BodyBytes(), &responseMap)
assert.NoError(t, err)
assert.Contains(t, responseMap["reason"], "Arguments too large")
assert.Contains(t, responseMap["error"], "Request Entity Too Large")
Expand All @@ -611,7 +612,7 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) {
err := json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap)
assert.NoError(t, err)

assert.Equal(t, 400, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusBadRequest)
assert.Contains(t, responseMap["reason"], "GraphQL config: only one of `schema` and `schemaFile` may be used")
assert.Contains(t, responseMap["error"], "Bad Request")
})
Expand All @@ -637,7 +638,7 @@ func TestInvalidGraphQLConfigurationValues(t *testing.T) {
err = json.Unmarshal([]byte(string(response.BodyBytes())), &responseMap)
assert.NoError(t, err)

assert.Equal(t, 400, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusBadRequest)
assert.Contains(t, responseMap["reason"], "Syntax Error GraphQL")
assert.Contains(t, responseMap["reason"], "Unexpected Name")
assert.Contains(t, responseMap["error"], "Bad Request")
Expand All @@ -659,8 +660,8 @@ func TestSchemaSyntax(t *testing.T) {

t.Run("Non-Admin", func(t *testing.T) {
response := rt.SendRequest("PUT", "/db/_config/graphql", "{}")
assert.Equal(t, 404, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusNotFound)
response = rt.SendRequest("DELETE", "/db/_config/graphql", "{}")
assert.Equal(t, 404, response.Result().StatusCode)
rest.RequireStatus(t, response, http.StatusNotFound)
})
}
Loading

0 comments on commit d00d704

Please sign in to comment.