diff --git a/lambda/handler.go b/lambda/handler.go index e4cfaf7a..cd55f7af 100644 --- a/lambda/handler.go +++ b/lambda/handler.go @@ -22,12 +22,14 @@ type Handler interface { type handlerOptions struct { handlerFunc - baseContext context.Context - jsonResponseEscapeHTML bool - jsonResponseIndentPrefix string - jsonResponseIndentValue string - enableSIGTERM bool - sigtermCallbacks []func() + baseContext context.Context + jsonRequestUseNumber bool + jsonRequestDisallowUnknownFields bool + jsonResponseEscapeHTML bool + jsonResponseIndentPrefix string + jsonResponseIndentValue string + enableSIGTERM bool + sigtermCallbacks []func() } type Option func(*handlerOptions) @@ -81,6 +83,38 @@ func WithSetIndent(prefix, indent string) Option { }) } +// WithUseNumber sets the UseNumber option on the underlying json decoder +// +// Usage: +// +// lambda.StartWithOptions( +// func (event any) (any, error) { +// return event, nil +// }, +// lambda.WithUseNumber(true) +// ) +func WithUseNumber(useNumber bool) Option { + return Option(func(h *handlerOptions) { + h.jsonRequestUseNumber = useNumber + }) +} + +// WithUseNumber sets the DisallowUnknownFields option on the underlying json decoder +// +// Usage: +// +// lambda.StartWithOptions( +// func (event any) (any, error) { +// return event, nil +// }, +// lambda.WithDisallowUnknownFields(true) +// ) +func WithDisallowUnknownFields(disallowUnknownFields bool) Option { + return Option(func(h *handlerOptions) { + h.jsonRequestDisallowUnknownFields = disallowUnknownFields + }) +} + // WithEnableSIGTERM enables SIGTERM behavior within the Lambda platform on container spindown. // SIGKILL will occur ~500ms after SIGTERM. // Optionally, an array of callback functions to run on SIGTERM may be provided. @@ -267,6 +301,12 @@ func reflectHandler(f interface{}, h *handlerOptions) handlerFunc { out.Reset() in := bytes.NewBuffer(payload) decoder := json.NewDecoder(in) + if h.jsonRequestUseNumber { + decoder.UseNumber() + } + if h.jsonRequestDisallowUnknownFields { + decoder.DisallowUnknownFields() + } encoder := json.NewEncoder(out) encoder.SetEscapeHTML(h.jsonResponseEscapeHTML) encoder.SetIndent(h.jsonResponseIndentPrefix, h.jsonResponseIndentValue) diff --git a/lambda/handler_test.go b/lambda/handler_test.go index 87942900..6eba8798 100644 --- a/lambda/handler_test.go +++ b/lambda/handler_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "io/ioutil" //nolint: staticcheck + "reflect" "strings" "testing" "time" @@ -309,6 +310,54 @@ func TestInvokes(t *testing.T) { }, options: []Option{WithSetIndent(">>", " ")}, }, + { + name: "WithUseNumber(true) results in json.Number instead of float64 when decoding to an interface{}", + input: `19.99`, + expected: expected{`"Number"`, nil}, + handler: func(event interface{}) (string, error) { + return reflect.TypeOf(event).Name(), nil + }, + options: []Option{WithUseNumber(true)}, + }, + { + name: "WithUseNumber(false)", + input: `19.99`, + expected: expected{`"float64"`, nil}, + handler: func(event interface{}) (string, error) { + return reflect.TypeOf(event).Name(), nil + }, + options: []Option{WithUseNumber(false)}, + }, + { + name: "No decoder options provided is the same as WithUseNumber(false)", + input: `19.99`, + expected: expected{`"float64"`, nil}, + handler: func(event interface{}) (string, error) { + return reflect.TypeOf(event).Name(), nil + }, + options: []Option{}, + }, + { + name: "WithDisallowUnknownFields(true)", + input: `{"Hello": "World"}`, + expected: expected{"", errors.New(`json: unknown field "Hello"`)}, + handler: func(_ struct{}) {}, + options: []Option{WithDisallowUnknownFields(true)}, + }, + { + name: "WithDisallowUnknownFields(false)", + input: `{"Hello": "World"}`, + expected: expected{`null`, nil}, + handler: func(_ struct{}) {}, + options: []Option{WithDisallowUnknownFields(false)}, + }, + { + name: "No decoder options provided is the same as WithDisallowUnknownFields(false)", + input: `{"Hello": "World"}`, + expected: expected{`null`, nil}, + handler: func(_ struct{}) {}, + options: []Option{}, + }, { name: "bytes are base64 encoded strings", input: `"aGVsbG8="`,