diff --git a/Makefile b/Makefile index d7f92e1..6d956ea 100644 --- a/Makefile +++ b/Makefile @@ -2,4 +2,5 @@ test: bash -c 'diff -u <(echo -n) <(gofmt -s -d .)' go vet ./... go test ./... -v -covermode=atomic -coverprofile=coverage.out + .PHONY: test diff --git a/README.md b/README.md index 32e86f0..b68c0d4 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,32 @@ handler := func(c *g8.APIGatewayProxyContext) error { } ``` +## API Gateway Lambda Authorizer Handlers + +You are able to define handlers for [Lambda Authorizer](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-use-lambda-authorizer.html) +(previously known as custom authorizers) using g8. Here is an example: + +```go + + handler := g8.APIGatewayCustomAuthorizerHandlerWithNewRelic( + func(c *APIGatewayCustomAuthorizerContext) error{ + c.Response.SetPrincipalID("some-principal-ID") + + c.Response.AllowAllMethods() + // other examples: + // c.Response.DenyAllMethods() + // c.Response.AllowMethod(Post, "/pets/*") + return nil + }, + g8.HandlerConfig{ + ... + }, + ) + + lambda.StartHandler(handler) + +``` + ## Response writing There are several methods provided to simplify writing HTTP responses. diff --git a/apigw_ca_handler.go b/apigw_ca_handler.go new file mode 100644 index 0000000..b156080 --- /dev/null +++ b/apigw_ca_handler.go @@ -0,0 +1,144 @@ +package g8 + +import ( + "context" + "errors" + "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-lambda-go/lambda" + newrelic "github.com/newrelic/go-agent" + "github.com/newrelic/go-agent/_integrations/nrlambda" + "github.com/rs/zerolog" +) + +// APIGatewayCustomAuthorizerContext the context for a request for Custom Authorizer +type APIGatewayCustomAuthorizerContext struct { + Context context.Context + Request events.APIGatewayCustomAuthorizerRequestTypeRequest + Response events.APIGatewayCustomAuthorizerResponse + Logger zerolog.Logger + NewRelicTx newrelic.Transaction + CorrelationID string + methodArnParts methodARN + hasAtLeastOneAllowedMethod bool +} + +// APIGatewayCustomAuthorizerHandlerFunc to populate +type APIGatewayCustomAuthorizerHandlerFunc func(c *APIGatewayCustomAuthorizerContext) error + +// APIGatewayCustomAuthorizerHandler fd +func APIGatewayCustomAuthorizerHandler( + h APIGatewayCustomAuthorizerHandlerFunc, + conf HandlerConfig, +) func(context.Context, events.APIGatewayCustomAuthorizerRequestTypeRequest) (events.APIGatewayCustomAuthorizerResponse, error) { + + return func(ctx context.Context, r events.APIGatewayCustomAuthorizerRequestTypeRequest) (events.APIGatewayCustomAuthorizerResponse, error) { + if len(r.MethodArn) == 0 { + return events.APIGatewayCustomAuthorizerResponse{}, errors.New("MethodArn is not set") + } + + correlationID := getCorrelationIDAPIGW(r.Headers) + + logger := configureLogger(conf). + Str("route", r.RequestContext.ResourcePath). + Str("correlation_id", correlationID). + Str("application", conf.AppName). + Str("function_name", conf.FunctionName). + Str("env", conf.EnvName). + Str("build_version", conf.BuildVersion). + Logger() + + c := &APIGatewayCustomAuthorizerContext{ + Context: ctx, + Request: r, + Response: NewAuthorizerResponse(), + Logger: logger, + NewRelicTx: newrelic.FromContext(ctx), + CorrelationID: correlationID, + methodArnParts: parseFromMethodARN(r.MethodArn), + hasAtLeastOneAllowedMethod: false, + } + + if err := h(c); err != nil { + logger.Err(err).Msg("Error while calling user-defined function") + return events.APIGatewayCustomAuthorizerResponse{}, err + } + + // sanity check + if !c.hasAtLeastOneAllowedMethod { + logger.Warn().Msg("Warning! No method were allowed! That means no requests will pass this " + + "authorizer! Please double check the policy.") + } + if len(c.Response.PrincipalID) == 0 { + logger.Warn().Msg("Warning! The PrincipalID was not defined! Please set it using c.Response.SetPrincipalID() function") + } + + c.Response.Context = map[string]interface{}{ + "customer-id": c.Response.PrincipalID, + } + + c.AddNewRelicAttribute("functionName", conf.FunctionName) + c.AddNewRelicAttribute("route", r.RequestContext.ResourcePath) + c.AddNewRelicAttribute("correlationID", correlationID) + c.AddNewRelicAttribute("buildVersion", conf.BuildVersion) + + logger.Debug(). + Str("principal_id", c.Response.PrincipalID). + Str("account_aws", c.methodArnParts.AccountID). + Msg("G8 Custom Authorizer successful") + + return c.Response, nil + } +} + +func APIGatewayCustomAuthorizerHandlerWithNewRelic(h APIGatewayCustomAuthorizerHandlerFunc, conf HandlerConfig) lambda.Handler { + return nrlambda.Wrap(APIGatewayCustomAuthorizerHandler(h, conf), conf.NewRelicApp) +} + +func (c *APIGatewayCustomAuthorizerContext) AddNewRelicAttribute(key string, val interface{}) { + if c.NewRelicTx == nil { + return + } + if err := c.NewRelicTx.AddAttribute(key, val); err != nil { + c.Logger.Error().Msgf("failed to add attr '%s' to new relic tx: %+v", key, err) + } +} + +func NewAuthorizerResponse() events.APIGatewayCustomAuthorizerResponse { + return events.APIGatewayCustomAuthorizerResponse{ + PolicyDocument: events.APIGatewayCustomAuthorizerPolicy{ + Version: "2012-10-17", + }, + } +} + +func (c *APIGatewayCustomAuthorizerContext) addMethod(effect Effect, verb, resource string) { + s := events.IAMPolicyStatement{ + Effect: effect.String(), + Action: []string{"execute-api:Invoke"}, + Resource: []string{c.methodArnParts.buildResourceARN(verb, resource)}, + } + + c.Response.PolicyDocument.Statement = append(c.Response.PolicyDocument.Statement, s) +} + +func (c *APIGatewayCustomAuthorizerContext) SetPrincipalID(principalID string) { + c.Response.PrincipalID = principalID +} + +func (c *APIGatewayCustomAuthorizerContext) AllowAllMethods() { + c.hasAtLeastOneAllowedMethod = true + c.addMethod(Allow, All, "*") +} + +func (c *APIGatewayCustomAuthorizerContext) DenyAllMethods() { + c.addMethod(Deny, All, "*") +} + +func (c *APIGatewayCustomAuthorizerContext) AllowMethod(verb, resource string) { + c.hasAtLeastOneAllowedMethod = true + c.addMethod(Allow, verb, resource) +} + +func (c *APIGatewayCustomAuthorizerContext) DenyMethod(verb, resource string) { + c.addMethod(Deny, verb, resource) +} diff --git a/apigw_handler.go b/apigw_handler.go index 2c93a28..49b6e82 100644 --- a/apigw_handler.go +++ b/apigw_handler.go @@ -32,7 +32,7 @@ func APIGatewayProxyHandler( conf HandlerConfig, ) func(context.Context, events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { return func(ctx context.Context, r events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { - correlationID := getCorrelationIDAPIGW(r) + correlationID := getCorrelationIDAPIGW(r.Headers) logger := configureLogger(conf). Str("route", r.RequestContext.ResourcePath). @@ -156,8 +156,8 @@ func (c *APIGatewayProxyContext) GetHeader(name string) string { return canonicalHeaders.Get(name) } -func getCorrelationIDAPIGW(r events.APIGatewayProxyRequest) string { - correlationID := r.Headers[headerCorrelationID] +func getCorrelationIDAPIGW(headers map[string]string) string { + correlationID := headers[headerCorrelationID] if correlationID != "" { return correlationID } diff --git a/go.mod b/go.mod index 47f7eb5..e4314fc 100644 --- a/go.mod +++ b/go.mod @@ -11,5 +11,5 @@ require ( github.com/rotisserie/eris v0.2.0 github.com/rs/zerolog v1.17.2 github.com/steinfletcher/apitest v1.4.0 - github.com/stretchr/testify v1.4.0 + github.com/stretchr/testify v1.5.1 ) diff --git a/go.sum b/go.sum index bd5292f..20d8ce2 100644 --- a/go.sum +++ b/go.sum @@ -28,9 +28,12 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/steinfletcher/apitest v1.4.0 h1:NfKf/kOTtzj/Y/T42570hNGlyVfS1lWPYTNAs6BonFw= github.com/steinfletcher/apitest v1.4.0/go.mod h1:pCHKMM2TcH1pezw/xbmilaCdK9/dGsoCZBafwaqJ2sY= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/policy.go b/policy.go new file mode 100644 index 0000000..d0c4e0f --- /dev/null +++ b/policy.go @@ -0,0 +1,72 @@ +package g8 + +import ( + "strings" +) + +const All = "*" + +type Effect int + +const ( + Allow Effect = iota + Deny +) + +func (e Effect) String() string { + switch e { + case Allow: + return "Allow" + case Deny: + return "Deny" + } + return "" +} + +type methodARN struct { + + // The region where the API is deployed. By default this is set to '*' + Region string + + // The AWS account id the policy will be generated for. This is used to create the method ARNs. + AccountID string + + // The API Gateway API id. By default this is set to '*' + APIID string + + // The name of the stage used in the policy. By default this is set to '*' + Stage string +} + +func parseFromMethodARN(rawArn string) methodARN { + + tmp := strings.Split(rawArn, ":") + apiGatewayArnTmp := strings.Split(tmp[5], "/") + awsAccountID := tmp[4] + + return methodARN{ + AccountID: awsAccountID, + Region: tmp[3], + APIID: apiGatewayArnTmp[0], + Stage: apiGatewayArnTmp[1], + } +} + +func (r *methodARN) buildResourceARN(verb, resource string) string { + var str strings.Builder + + str.WriteString("arn:aws:execute-api:") + str.WriteString(r.Region) + str.WriteString(":") + str.WriteString(r.AccountID) + str.WriteString(":") + str.WriteString(r.APIID) + str.WriteString("/") + str.WriteString(r.Stage) + str.WriteString("/") + str.WriteString(verb) + str.WriteString("/") + str.WriteString(strings.TrimLeft(resource, "/")) + + return str.String() +} diff --git a/policy_test.go b/policy_test.go new file mode 100644 index 0000000..0ac6d45 --- /dev/null +++ b/policy_test.go @@ -0,0 +1,122 @@ +package g8 + +import ( + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestHasMethodsEmpty(t *testing.T) { + + // Given: + c := APIGatewayCustomAuthorizerContext{} + + // Then: + assert.False(t, c.hasAtLeastOneAllowedMethod) // <-- default value is false +} + +func TestHasMethodsNonEmptyButContainsAllDenies(t *testing.T) { + + // Given: + c := APIGatewayCustomAuthorizerContext{} + + // When + c.DenyAllMethods() + + // Then: + assert.False(t, c.hasAtLeastOneAllowedMethod) +} + +func TestHasMethodsAllowsAllMethods(t *testing.T) { + + // Given: + c := APIGatewayCustomAuthorizerContext{} + + // When + c.AllowAllMethods() + + // Then: + assert.True(t, c.hasAtLeastOneAllowedMethod) +} + +func TestHasMethodsHasMixedAllowAndDenyMethods(t *testing.T) { + + // Given: + c := APIGatewayCustomAuthorizerContext{} + + // When + c.DenyMethod(http.MethodPost, "/pets/*") + c.DenyMethod(http.MethodDelete, "/cars/*") + c.AllowMethod(http.MethodGet, "/users/*") // <-- !!! + c.DenyMethod(http.MethodPost, "/picture/update") + c.DenyMethod(http.MethodPost, "/picture/assign") + c.DenyMethod(http.MethodPut, "/users/new") + + // Then: + assert.True(t, c.hasAtLeastOneAllowedMethod) +} + +func TestHasMethodsAllDenyMethods(t *testing.T) { + + // Given: + c := APIGatewayCustomAuthorizerContext{} + + // When + c.DenyMethod(http.MethodPost, "/pets/*") + c.DenyMethod(http.MethodDelete, "/cars/*") + c.DenyMethod(http.MethodPost, "/picture/update") + c.DenyMethod(http.MethodPost, "/picture/assign") + c.DenyMethod(http.MethodPut, "/users/new") + + // Then: + assert.False(t, c.hasAtLeastOneAllowedMethod, "No methods allowed") +} + +func TestBuildResourceArn(t *testing.T) { + + // Given: + m := methodARN{ + Region: "eu-west-1", + AccountID: "aws-account-id", + APIID: "*", + Stage: "*", + } + + // When + resourceARN := m.buildResourceARN(http.MethodPost, "/pets/*") + + // Then: + assert.Equal(t, "arn:aws:execute-api:eu-west-1:aws-account-id:*/*/POST/pets/*", resourceARN) +} + +func TestBuildResourceArnAllowAll(t *testing.T) { + + // Given: + m := methodARN{ + Region: "*", + AccountID: "aws-account-id", + APIID: "*", + Stage: "*", + } + + // When + resourceARN := m.buildResourceARN(All, "*") + + // Then: + assert.Equal(t, "arn:aws:execute-api:*:aws-account-id:*/*/*/*", resourceARN) +} + +func TestParseMethodARN(t *testing.T) { + + // Given: + strMethodARN := "arn:aws:execute-api:eu-west-1:123456789012:oy1e34abcd/main/GET/test-endpoint" + + // When: + methodARN := parseFromMethodARN(strMethodARN) + + // Then: + assert.Equal(t, "eu-west-1", methodARN.Region) + assert.Equal(t, "123456789012", methodARN.AccountID) + assert.Equal(t, "oy1e34abcd", methodARN.APIID) + assert.Equal(t, "main", methodARN.Stage) +}