diff --git a/api-get-options.go b/api-get-options.go index bb86a5994..a0216e201 100644 --- a/api-get-options.go +++ b/api-get-options.go @@ -87,10 +87,10 @@ func (o *GetObjectOptions) Set(key, value string) { } // SetReqParam - set request query string parameter -// supported key: see supportedQueryValues. +// supported key: see supportedQueryValues and allowedCustomQueryPrefix. // If an unsupported key is passed in, it will be ignored and nothing will be done. func (o *GetObjectOptions) SetReqParam(key, value string) { - if !isStandardQueryValue(key) { + if !isCustomQueryValue(key) && !isStandardQueryValue(key) { // do nothing return } @@ -101,10 +101,10 @@ func (o *GetObjectOptions) SetReqParam(key, value string) { } // AddReqParam - add request query string parameter -// supported key: see supportedQueryValues. +// supported key: see supportedQueryValues and allowedCustomQueryPrefix. // If an unsupported key is passed in, it will be ignored and nothing will be done. func (o *GetObjectOptions) AddReqParam(key, value string) { - if !isStandardQueryValue(key) { + if !isCustomQueryValue(key) && !isStandardQueryValue(key) { // do nothing return } diff --git a/get-options_test.go b/get-options_test.go index 92d1835ea..eb492f676 100644 --- a/get-options_test.go +++ b/get-options_test.go @@ -57,3 +57,45 @@ func TestSetHeader(t *testing.T) { } } } + +func TestCustomQueryParameters(t *testing.T) { + var ( + paramKey = "x-test-param" + paramValue = "test-value" + + invalidParamKey = "invalid-test-param" + invalidParamValue = "invalid-test-param" + ) + + testCases := []struct { + setParamsFunc func(o *GetObjectOptions) + }{ + {func(o *GetObjectOptions) { + o.AddReqParam(paramKey, paramValue) + o.AddReqParam(invalidParamKey, invalidParamValue) + }}, + {func(o *GetObjectOptions) { + o.SetReqParam(paramKey, paramValue) + o.SetReqParam(invalidParamKey, invalidParamValue) + }}, + } + + for i, testCase := range testCases { + opts := GetObjectOptions{} + testCase.setParamsFunc(&opts) + + // This and the following checks indirectly ensure that only the expected + // valid header is added. + if len(opts.reqParams) != 1 { + t.Errorf("Test %d: Expected 1 kv-pair in query parameters, got %v", i+1, len(opts.reqParams)) + } + + if v, ok := opts.reqParams[paramKey]; !ok { + t.Errorf("Test %d: Expected query parameter with key %s missing", i+1, paramKey) + } else if len(v) != 1 { + t.Errorf("Test %d: Expected 1 value for query parameter with key %s, got %d values", i+1, paramKey, len(v)) + } else if v[0] != paramValue { + t.Errorf("Test %d: Expected query value %s for key %s, got %s", i+1, paramValue, paramKey, v[0]) + } + } +} diff --git a/utils.go b/utils.go index 6a93561ea..e39eba028 100644 --- a/utils.go +++ b/utils.go @@ -528,6 +528,14 @@ func isStandardQueryValue(qsKey string) bool { return supportedQueryValues[qsKey] } +// Per documentation at https://docs.aws.amazon.com/AmazonS3/latest/userguide/LogFormat.html#LogFormatCustom, the +// set of query params starting with "x-" are ignored by S3. +const allowedCustomQueryPrefix = "x-" + +func isCustomQueryValue(qsKey string) bool { + return strings.HasPrefix(qsKey, allowedCustomQueryPrefix) +} + var ( md5Pool = sync.Pool{New: func() interface{} { return md5.New() }} sha256Pool = sync.Pool{New: func() interface{} { return sha256.New() }} diff --git a/utils_test.go b/utils_test.go index 9b944ef84..117fdbb15 100644 --- a/utils_test.go +++ b/utils_test.go @@ -408,3 +408,24 @@ func TestIsAmzHeader(t *testing.T) { } } } + +// Tests if query parameter starts with "x-" and will be ignored by S3. +func TestIsCustomQueryValue(t *testing.T) { + testCases := []struct { + // Input. + queryParamKey string + // Expected result. + expectedValue bool + }{ + {"x-custom-key", true}, + {"xcustom-key", false}, + {"random-header", false}, + } + + for i, testCase := range testCases { + actual := isCustomQueryValue(testCase.queryParamKey) + if actual != testCase.expectedValue { + t.Errorf("Test %d: Expected to pass, but failed", i+1) + } + } +}