diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index f21060a9..296403c9 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "sort" "github.com/getkin/kin-openapi/openapi3" @@ -103,6 +104,28 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) (err er return } +// appendToQueryValues adds to query parameters each value in the provided slice +func appendToQueryValues[T any](q url.Values, parameterName string, v []T) { + for _, i := range v { + q.Add(parameterName, fmt.Sprintf("%v", i)) + } +} + +// populateDefaultQueryParameters populates default values inside query parameters, while ensuring types are respected +func populateDefaultQueryParameters(q url.Values, parameterName string, value any) { + switch t := value.(type) { + case []string: + appendToQueryValues(q, parameterName, t) + case []float64: + appendToQueryValues(q, parameterName, t) + case []int: + appendToQueryValues(q, parameterName, t) + default: + q.Add(parameterName, fmt.Sprintf("%v", value)) + } + +} + // ValidateParameter validates a parameter's value by JSON schema. // The function returns RequestError with a ParseError cause when unable to parse a value. // The function returns RequestError with ErrInvalidRequired cause when a value of a required parameter is not defined. @@ -156,7 +179,7 @@ func ValidateParameter(ctx context.Context, input *RequestValidationInput, param // Next check `parameter.Required && !found` will catch this. case openapi3.ParameterInQuery: q := req.URL.Query() - q.Add(parameter.Name, fmt.Sprintf("%v", value)) + populateDefaultQueryParameters(q, parameter.Name, value) req.URL.RawQuery = q.Encode() case openapi3.ParameterInHeader: req.Header.Add(parameter.Name, fmt.Sprintf("%v", value)) diff --git a/openapi3filter/validate_request_test.go b/openapi3filter/validate_request_test.go index 984af4c2..505e5cb9 100644 --- a/openapi3filter/validate_request_test.go +++ b/openapi3filter/validate_request_test.go @@ -431,8 +431,6 @@ func TestValidateQueryParams(t *testing.T) { }, }, }, - // - // } for _, tc := range testCases { @@ -569,3 +567,94 @@ paths: }) require.Error(t, err) } + +var ( + StringArraySchemaWithDefault = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: &openapi3.Types{"array"}, + Items: stringSchema, + Default: []string{"A", "B", "C"}, + }, + } + FloatArraySchemaWithDefault = &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: &openapi3.Types{"array"}, + Items: numberSchema, + Default: []float64{1.5, 2.5, 3.5}, + }, + } +) + +func TestValidateRequestDefault(t *testing.T) { + type testCase struct { + name string + param *openapi3.Parameter + query string + wantQuery map[string][]string + wantHeader map[string]any + } + + testCases := []testCase{ + { + name: "String Array In Query", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "form", Explode: explode, + Schema: StringArraySchemaWithDefault, + }, + wantQuery: map[string][]string{ + "param": { + "A", + "B", + "C", + }, + }, + }, + { + name: "Float Array In Query", + param: &openapi3.Parameter{ + Name: "param", In: "query", Style: "form", Explode: explode, + Schema: FloatArraySchemaWithDefault, + }, + wantQuery: map[string][]string{ + "param": { + "1.5", + "2.5", + "3.5", + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + info := &openapi3.Info{ + Title: "MyAPI", + Version: "0.1", + } + doc := &openapi3.T{OpenAPI: "3.0.0", Info: info, Paths: openapi3.NewPaths()} + op := &openapi3.Operation{ + OperationID: "test", + Parameters: []*openapi3.ParameterRef{{Value: tc.param}}, + Responses: openapi3.NewResponses(), + } + doc.AddOperation("/test", http.MethodGet, op) + err := doc.Validate(context.Background()) + require.NoError(t, err) + router, err := legacyrouter.NewRouter(doc) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://test.org/test?"+tc.query, nil) + route, pathParams, err := router.FindRoute(req) + require.NoError(t, err) + + input := &RequestValidationInput{Request: req, PathParams: pathParams, Route: route} + + err = ValidateParameter(context.Background(), input, tc.param) + require.NoError(t, err) + + for k, v := range tc.wantQuery { + require.Equal(t, v, input.Request.URL.Query()[k]) + } + }) + } +}