From 0a0e8a914ea70a8cd912fee99e9f31804f190f51 Mon Sep 17 00:00:00 2001 From: Stein Fletcher Date: Sat, 25 Jan 2020 07:44:07 +0000 Subject: [PATCH] Add result to lambda to support step functions --- cloudwatch_handler.go | 12 ++++++------ cloudwatch_handler_test.go | 13 +++++++------ handlers.go | 2 ++ 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/cloudwatch_handler.go b/cloudwatch_handler.go index 29674f1..dec671f 100644 --- a/cloudwatch_handler.go +++ b/cloudwatch_handler.go @@ -19,10 +19,10 @@ type CloudWatchContext struct { CorrelationID string } -type CloudWatchHandlerFunc func(c *CloudWatchContext) error +type CloudWatchHandlerFunc func(c *CloudWatchContext) (LambdaResult, error) -func CloudWatchHandler(h CloudWatchHandlerFunc, conf HandlerConfig) func(context.Context, events.CloudWatchEvent) error { - return func(ctx context.Context, event events.CloudWatchEvent) error { +func CloudWatchHandler(h CloudWatchHandlerFunc, conf HandlerConfig) func(context.Context, events.CloudWatchEvent) (LambdaResult, error) { + return func(ctx context.Context, event events.CloudWatchEvent) (LambdaResult, error) { correlationID := uuid.New().String() // the resource that triggered the event, e.g. "arn:aws:events:us-east-1:123456789012:rule/MyScheduledRule" @@ -49,11 +49,11 @@ func CloudWatchHandler(h CloudWatchHandlerFunc, conf HandlerConfig) func(context c.AddNewRelicAttribute("correlationID", correlationID) c.AddNewRelicAttribute("cloudWatchResource", cloudWatchResource) - if err := h(c); err != nil { + result, err := h(c) + if err != nil { logUnhandledError(c.Logger, err) - return err } - return nil + return result, err } } diff --git a/cloudwatch_handler_test.go b/cloudwatch_handler_test.go index 1a13cc3..5bc9e38 100644 --- a/cloudwatch_handler_test.go +++ b/cloudwatch_handler_test.go @@ -15,34 +15,35 @@ import ( func TestCloudWatchHandler_SingleMessage(t *testing.T) { timesCalled := 0 resourceArn := "arn:aws:events:us-east-1:123456789012:rule/MyScheduledRule" - h := g8.CloudWatchHandler(func(c *g8.CloudWatchContext) error { + h := g8.CloudWatchHandler(func(c *g8.CloudWatchContext) (g8.LambdaResult, error) { timesCalled++ assert.Equal(t, resourceArn, c.Event.Resources[0]) assert.NotEmpty(t, c.CorrelationID) - return nil + return "finished", nil }, g8.HandlerConfig{Logger: zerolog.New(ioutil.Discard)}) - err := h(context.Background(), events.CloudWatchEvent{ + result, err := h(context.Background(), events.CloudWatchEvent{ Resources: []string{resourceArn}, }) assert.Nil(t, err) + assert.Equal(t, "finished", result) assert.Equal(t, 1, timesCalled) } func TestCloudWatchHandler_HandlerError(t *testing.T) { timesCalled := 0 - handlerFunc := func(c *g8.CloudWatchContext) error { + handlerFunc := func(c *g8.CloudWatchContext) (g8.LambdaResult, error) { timesCalled++ - return assert.AnError + return nil, assert.AnError } h := g8.CloudWatchHandler(handlerFunc, g8.HandlerConfig{ Logger: zerolog.New(ioutil.Discard), }) - err := h(context.Background(), events.CloudWatchEvent{}) + _, err := h(context.Background(), events.CloudWatchEvent{}) assert.Equal(t, assert.AnError, err) assert.Equal(t, 1, timesCalled) diff --git a/handlers.go b/handlers.go index 5775857..3940a8c 100644 --- a/handlers.go +++ b/handlers.go @@ -24,6 +24,8 @@ type HandlerConfig struct { NewRelicApp newrelic.Application } +type LambdaResult interface{} + type Validatable interface { Validate() error }