Skip to content

Commit

Permalink
Add support for Custom Authorizer support (#8)
Browse files Browse the repository at this point in the history
You can define a custom authorizer in the API Gateway. To do that, you need to make a lambda that expects specific request and returns data in a defined format. This commit adds this functionality to a framework
  • Loading branch information
w32blaster authored Apr 16, 2020
1 parent e31766b commit 28c7d47
Show file tree
Hide file tree
Showing 8 changed files with 372 additions and 4 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ test:
bash -c 'diff -u <(echo -n) <(gofmt -s -d .)'
go vet ./...
go test ./... -v -covermode=atomic -coverprofile=coverage.out

.PHONY: test
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,32 @@ handler := func(c *g8.APIGatewayProxyContext) error {
}
```

## API Gateway Lambda Authorizer Handlers

You are able to define handlers for [Lambda Authorizer](https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-use-lambda-authorizer.html)
(previously known as custom authorizers) using g8. Here is an example:

```go

handler := g8.APIGatewayCustomAuthorizerHandlerWithNewRelic(
func(c *APIGatewayCustomAuthorizerContext) error{
c.Response.SetPrincipalID("some-principal-ID")

c.Response.AllowAllMethods()
// other examples:
// c.Response.DenyAllMethods()
// c.Response.AllowMethod(Post, "/pets/*")
return nil
},
g8.HandlerConfig{
...
},
)

lambda.StartHandler(handler)

```

## Response writing

There are several methods provided to simplify writing HTTP responses.
Expand Down
144 changes: 144 additions & 0 deletions apigw_ca_handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package g8

import (
"context"
"errors"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda"
newrelic "github.com/newrelic/go-agent"
"github.com/newrelic/go-agent/_integrations/nrlambda"
"github.com/rs/zerolog"
)

// APIGatewayCustomAuthorizerContext the context for a request for Custom Authorizer
type APIGatewayCustomAuthorizerContext struct {
Context context.Context
Request events.APIGatewayCustomAuthorizerRequestTypeRequest
Response events.APIGatewayCustomAuthorizerResponse
Logger zerolog.Logger
NewRelicTx newrelic.Transaction
CorrelationID string
methodArnParts methodARN
hasAtLeastOneAllowedMethod bool
}

// APIGatewayCustomAuthorizerHandlerFunc to populate
type APIGatewayCustomAuthorizerHandlerFunc func(c *APIGatewayCustomAuthorizerContext) error

// APIGatewayCustomAuthorizerHandler fd
func APIGatewayCustomAuthorizerHandler(
h APIGatewayCustomAuthorizerHandlerFunc,
conf HandlerConfig,
) func(context.Context, events.APIGatewayCustomAuthorizerRequestTypeRequest) (events.APIGatewayCustomAuthorizerResponse, error) {

return func(ctx context.Context, r events.APIGatewayCustomAuthorizerRequestTypeRequest) (events.APIGatewayCustomAuthorizerResponse, error) {
if len(r.MethodArn) == 0 {
return events.APIGatewayCustomAuthorizerResponse{}, errors.New("MethodArn is not set")
}

correlationID := getCorrelationIDAPIGW(r.Headers)

logger := configureLogger(conf).
Str("route", r.RequestContext.ResourcePath).
Str("correlation_id", correlationID).
Str("application", conf.AppName).
Str("function_name", conf.FunctionName).
Str("env", conf.EnvName).
Str("build_version", conf.BuildVersion).
Logger()

c := &APIGatewayCustomAuthorizerContext{
Context: ctx,
Request: r,
Response: NewAuthorizerResponse(),
Logger: logger,
NewRelicTx: newrelic.FromContext(ctx),
CorrelationID: correlationID,
methodArnParts: parseFromMethodARN(r.MethodArn),
hasAtLeastOneAllowedMethod: false,
}

if err := h(c); err != nil {
logger.Err(err).Msg("Error while calling user-defined function")
return events.APIGatewayCustomAuthorizerResponse{}, err
}

// sanity check
if !c.hasAtLeastOneAllowedMethod {
logger.Warn().Msg("Warning! No method were allowed! That means no requests will pass this " +
"authorizer! Please double check the policy.")
}
if len(c.Response.PrincipalID) == 0 {
logger.Warn().Msg("Warning! The PrincipalID was not defined! Please set it using c.Response.SetPrincipalID() function")
}

c.Response.Context = map[string]interface{}{
"customer-id": c.Response.PrincipalID,
}

c.AddNewRelicAttribute("functionName", conf.FunctionName)
c.AddNewRelicAttribute("route", r.RequestContext.ResourcePath)
c.AddNewRelicAttribute("correlationID", correlationID)
c.AddNewRelicAttribute("buildVersion", conf.BuildVersion)

logger.Debug().
Str("principal_id", c.Response.PrincipalID).
Str("account_aws", c.methodArnParts.AccountID).
Msg("G8 Custom Authorizer successful")

return c.Response, nil
}
}

func APIGatewayCustomAuthorizerHandlerWithNewRelic(h APIGatewayCustomAuthorizerHandlerFunc, conf HandlerConfig) lambda.Handler {
return nrlambda.Wrap(APIGatewayCustomAuthorizerHandler(h, conf), conf.NewRelicApp)
}

func (c *APIGatewayCustomAuthorizerContext) AddNewRelicAttribute(key string, val interface{}) {
if c.NewRelicTx == nil {
return
}
if err := c.NewRelicTx.AddAttribute(key, val); err != nil {
c.Logger.Error().Msgf("failed to add attr '%s' to new relic tx: %+v", key, err)
}
}

func NewAuthorizerResponse() events.APIGatewayCustomAuthorizerResponse {
return events.APIGatewayCustomAuthorizerResponse{
PolicyDocument: events.APIGatewayCustomAuthorizerPolicy{
Version: "2012-10-17",
},
}
}

func (c *APIGatewayCustomAuthorizerContext) addMethod(effect Effect, verb, resource string) {
s := events.IAMPolicyStatement{
Effect: effect.String(),
Action: []string{"execute-api:Invoke"},
Resource: []string{c.methodArnParts.buildResourceARN(verb, resource)},
}

c.Response.PolicyDocument.Statement = append(c.Response.PolicyDocument.Statement, s)
}

func (c *APIGatewayCustomAuthorizerContext) SetPrincipalID(principalID string) {
c.Response.PrincipalID = principalID
}

func (c *APIGatewayCustomAuthorizerContext) AllowAllMethods() {
c.hasAtLeastOneAllowedMethod = true
c.addMethod(Allow, All, "*")
}

func (c *APIGatewayCustomAuthorizerContext) DenyAllMethods() {
c.addMethod(Deny, All, "*")
}

func (c *APIGatewayCustomAuthorizerContext) AllowMethod(verb, resource string) {
c.hasAtLeastOneAllowedMethod = true
c.addMethod(Allow, verb, resource)
}

func (c *APIGatewayCustomAuthorizerContext) DenyMethod(verb, resource string) {
c.addMethod(Deny, verb, resource)
}
6 changes: 3 additions & 3 deletions apigw_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func APIGatewayProxyHandler(
conf HandlerConfig,
) func(context.Context, events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
return func(ctx context.Context, r events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
correlationID := getCorrelationIDAPIGW(r)
correlationID := getCorrelationIDAPIGW(r.Headers)

logger := configureLogger(conf).
Str("route", r.RequestContext.ResourcePath).
Expand Down Expand Up @@ -156,8 +156,8 @@ func (c *APIGatewayProxyContext) GetHeader(name string) string {
return canonicalHeaders.Get(name)
}

func getCorrelationIDAPIGW(r events.APIGatewayProxyRequest) string {
correlationID := r.Headers[headerCorrelationID]
func getCorrelationIDAPIGW(headers map[string]string) string {
correlationID := headers[headerCorrelationID]
if correlationID != "" {
return correlationID
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ require (
github.com/rotisserie/eris v0.2.0
github.com/rs/zerolog v1.17.2
github.com/steinfletcher/apitest v1.4.0
github.com/stretchr/testify v1.4.0
github.com/stretchr/testify v1.5.1
)
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/steinfletcher/apitest v1.4.0 h1:NfKf/kOTtzj/Y/T42570hNGlyVfS1lWPYTNAs6BonFw=
github.com/steinfletcher/apitest v1.4.0/go.mod h1:pCHKMM2TcH1pezw/xbmilaCdK9/dGsoCZBafwaqJ2sY=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down
72 changes: 72 additions & 0 deletions policy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package g8

import (
"strings"
)

const All = "*"

type Effect int

const (
Allow Effect = iota
Deny
)

func (e Effect) String() string {
switch e {
case Allow:
return "Allow"
case Deny:
return "Deny"
}
return ""
}

type methodARN struct {

// The region where the API is deployed. By default this is set to '*'
Region string

// The AWS account id the policy will be generated for. This is used to create the method ARNs.
AccountID string

// The API Gateway API id. By default this is set to '*'
APIID string

// The name of the stage used in the policy. By default this is set to '*'
Stage string
}

func parseFromMethodARN(rawArn string) methodARN {

tmp := strings.Split(rawArn, ":")
apiGatewayArnTmp := strings.Split(tmp[5], "/")
awsAccountID := tmp[4]

return methodARN{
AccountID: awsAccountID,
Region: tmp[3],
APIID: apiGatewayArnTmp[0],
Stage: apiGatewayArnTmp[1],
}
}

func (r *methodARN) buildResourceARN(verb, resource string) string {
var str strings.Builder

str.WriteString("arn:aws:execute-api:")
str.WriteString(r.Region)
str.WriteString(":")
str.WriteString(r.AccountID)
str.WriteString(":")
str.WriteString(r.APIID)
str.WriteString("/")
str.WriteString(r.Stage)
str.WriteString("/")
str.WriteString(verb)
str.WriteString("/")
str.WriteString(strings.TrimLeft(resource, "/"))

return str.String()
}
Loading

0 comments on commit 28c7d47

Please sign in to comment.