diff --git a/core/requestALB.go b/core/requestALB.go new file mode 100644 index 0000000..69eb023 --- /dev/null +++ b/core/requestALB.go @@ -0,0 +1,210 @@ +// Package core provides utility methods that help convert ALB events +// into an http.Request and http.ResponseWriter +package core + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "net/url" + "strings" + + "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-lambda-go/lambdacontext" +) + +const ( + // ALBContextHeader is the custom header key used to store the + // ALB ELB context. To access the Context properties use the + // GetALBContext method of the RequestAccessorALB object. + ALBContextHeader = "X-GoLambdaProxy-ALB-Context" +) + +// RequestAccessorALB objects give access to custom ALB Target Group properties +// in the request. +type RequestAccessorALB struct { + stripBasePath string +} + +// GetALBContext extracts the ALB context object from a request's custom header. +// Returns a populated events.ALBTargetGroupRequestContext object from the request. +func (r *RequestAccessorALB) GetContextALB(req *http.Request) (events.ALBTargetGroupRequestContext, error) { + if req.Header.Get(ALBContextHeader) == "" { + return events.ALBTargetGroupRequestContext{}, errors.New("no context header in request") + } + context := events.ALBTargetGroupRequestContext{} + err := json.Unmarshal([]byte(req.Header.Get(ALBContextHeader)), &context) + if err != nil { + log.Println("Error while unmarshalling context") + log.Println(err) + return events.ALBTargetGroupRequestContext{}, err + } + return context, nil +} + +// StripBasePath instructs the RequestAccessor object that the given base +// path should be removed from the request path before sending it to the +// framework for routing. This is used when API Gateway is configured with +// base path mappings in custom domain names. +func (r *RequestAccessorALB) StripBasePath(basePath string) string { + if strings.Trim(basePath, " ") == "" { + r.stripBasePath = "" + return "" + } + + newBasePath := basePath + if !strings.HasPrefix(newBasePath, "/") { + newBasePath = "/" + newBasePath + } + + if strings.HasSuffix(newBasePath, "/") { + newBasePath = newBasePath[:len(newBasePath)-1] + } + + r.stripBasePath = newBasePath + + return newBasePath +} + +// ProxyEventToHTTPRequest converts an ALB Target Group Request event into a http.Request object. +// Returns the populated http request with additional custom header for the ALB context. +// To access these properties use the GetALBContext method of the RequestAccessorALB object. +func (r *RequestAccessorALB) ProxyEventToHTTPRequest(req events.ALBTargetGroupRequest) (*http.Request, error) { + httpRequest, err := r.EventToRequest(req) + if err != nil { + log.Println(err) + return nil, err + } + return addToHeaderALB(httpRequest, req) +} + +// EventToRequestWithContext converts an ALB Target Group Request event and context into an http.Request object. +// Returns the populated http request with lambda context, ALB TargetGroup RequestContext as part of its context. +func (r *RequestAccessorALB) EventToRequestWithContext(ctx context.Context, req events.ALBTargetGroupRequest) (*http.Request, error) { + httpRequest, err := r.EventToRequest(req) + if err != nil { + log.Println(err) + return nil, err + } + return addToContextALB(ctx, httpRequest, req), nil +} + +// EventToRequest converts an ALB TargetGroup event into an http.Request object. +// Returns the populated request maintaining headers +func (r *RequestAccessorALB) EventToRequest(req events.ALBTargetGroupRequest) (*http.Request, error) { + decodedBody := []byte(req.Body) + if req.IsBase64Encoded { + base64Body, err := base64.StdEncoding.DecodeString(req.Body) + if err != nil { + return nil, err + } + decodedBody = base64Body + } + + path := req.Path + if r.stripBasePath != "" && len(r.stripBasePath) > 1 { + if strings.HasPrefix(path, r.stripBasePath) { + path = strings.Replace(path, r.stripBasePath, "", 1) + } + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + serverAddress := "https://" + req.Headers["host"] + // if customAddress, ok := os.LookupEnv(CustomHostVariable); ok { + // serverAddress = customAddress + // } + path = serverAddress + path + + if len(req.MultiValueQueryStringParameters) > 0 { + queryString := "" + for q, l := range req.MultiValueQueryStringParameters { + for _, v := range l { + if queryString != "" { + queryString += "&" + } + queryString += url.QueryEscape(q) + "=" + url.QueryEscape(v) + } + } + path += "?" + queryString + } else if len(req.QueryStringParameters) > 0 { + // Support `QueryStringParameters` for backward compatibility. + // https://github.com/awslabs/aws-lambda-go-api-proxy/issues/37 + queryString := "" + for q := range req.QueryStringParameters { + if queryString != "" { + queryString += "&" + } + queryString += url.QueryEscape(q) + "=" + url.QueryEscape(req.QueryStringParameters[q]) + } + path += "?" + queryString + } + + httpRequest, err := http.NewRequest( + strings.ToUpper(req.HTTPMethod), + path, + bytes.NewReader(decodedBody), + ) + + if err != nil { + fmt.Printf("Could not convert request %s:%s to http.Request\n", req.HTTPMethod, req.Path) + log.Println(err) + return nil, err + } + + if req.MultiValueHeaders != nil { + for k, values := range req.MultiValueHeaders { + for _, value := range values { + httpRequest.Header.Add(k, value) + } + } + } else { + for h := range req.Headers { + httpRequest.Header.Add(h, req.Headers[h]) + } + } + + httpRequest.RequestURI = httpRequest.URL.RequestURI() + + return httpRequest, nil +} + +func addToHeaderALB(req *http.Request, albRequest events.ALBTargetGroupRequest) (*http.Request, error) { + albContext, err := json.Marshal(albRequest.RequestContext) + if err != nil { + log.Println("Could not Marshal ALB context for custom header") + return req, err + } + req.Header.Set(ALBContextHeader, string(albContext)) + return req, nil +} + +// adds context data to http request so we can pass +func addToContextALB(ctx context.Context, req *http.Request, albRequest events.ALBTargetGroupRequest) *http.Request { + lc, _ := lambdacontext.FromContext(ctx) + rc := requestContextALB{lambdaContext: lc, albContext: albRequest.RequestContext} + ctx = context.WithValue(ctx, ctxKey{}, rc) + return req.WithContext(ctx) +} + +// GetALBTargetGroupRequestFromContext retrieve ALBTargetGroupt from context.Context +func GetTargetGroupRequetFromContextALB(ctx context.Context) (events.ALBTargetGroupRequestContext, bool) { + v, ok := ctx.Value(ctxKey{}).(requestContextALB) + return v.albContext, ok +} + +// GetRuntimeContextFromContext retrieve Lambda Runtime Context from context.Context +func GetRuntimeContextFromContextALB(ctx context.Context) (*lambdacontext.LambdaContext, bool) { + v, ok := ctx.Value(ctxKey{}).(requestContextALB) + return v.lambdaContext, ok +} + +type requestContextALB struct { + lambdaContext *lambdacontext.LambdaContext + albContext events.ALBTargetGroupRequestContext +} diff --git a/core/requestALB_test.go b/core/requestALB_test.go new file mode 100644 index 0000000..876ea19 --- /dev/null +++ b/core/requestALB_test.go @@ -0,0 +1,283 @@ +package core_test + +import ( + "context" + "encoding/base64" + "math/rand" + "strings" + + "github.com/awslabs/aws-lambda-go-api-proxy/core" + + "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-lambda-go/lambdacontext" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("RequestAccessorALB tests", func() { + Context("ALB event conversion", func() { + accessor := core.RequestAccessorALB{} + qs := make(map[string]string) + mvh := make(map[string][]string) + mvqs := make(map[string][]string) + hdr := make(map[string]string) + qs["UniqueId"] = "12345" + mvh["accept"] = []string{"test", "one"} + mvh["connection"] = []string{"keep-alive"} + mvh["host"] = []string{"lambda-test-alb-1234567.us-east-1.elb.amazonaws.com"} + hdr["header1"] = "Testhdr1" + hdr["header2"] = "Testhdr2" + //multivalue querystrings + mvqs["k1"] = []string{"t1"} + mvqs["k2"] = []string{"t2"} + bdy := "Test BODY" + basicRequest := getALBProxyRequest("/hello", "GET", getALBRequestContext(), false, hdr, bdy, qs, mvh, nil) + + It("Correctly converts a basic event", func() { + httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest) + Expect(err).To(BeNil()) + Expect("/hello").To(Equal(httpReq.URL.Path)) + Expect("/hello?UniqueId=12345").To(Equal(httpReq.RequestURI)) + Expect("GET").To(Equal(httpReq.Method)) + headers := basicRequest.Headers + Expect(2).To(Equal(len(headers))) + mvhs := basicRequest.MultiValueHeaders + Expect(3).To(Equal(len(mvhs))) + mvqs := basicRequest.MultiValueQueryStringParameters + Expect(0).To(Equal(len(mvqs))) + + }) + + binaryBody := make([]byte, 256) + _, err := rand.Read(binaryBody) + if err != nil { + Fail("Could not generate random binary body") + } + + encodedBody := base64.StdEncoding.EncodeToString(binaryBody) + + binaryRequest := getALBProxyRequest("/hello", "POST", getALBRequestContext(), true, hdr, bdy, qs, mvh, nil) + binaryRequest.Body = encodedBody + binaryRequest.IsBase64Encoded = true + + It("Decodes a base64 encoded body", func() { + httpReq, err := accessor.EventToRequestWithContext(context.Background(), binaryRequest) + Expect(err).To(BeNil()) + Expect("/hello").To(Equal(httpReq.URL.Path)) + Expect("/hello?UniqueId=12345").To(Equal(httpReq.RequestURI)) + Expect("POST").To(Equal(httpReq.Method)) + + Expect(err).To(BeNil()) + + }) + + mqsRequest := getALBProxyRequest("/hello", "GET", getALBRequestContext(), false, hdr, bdy, qs, mvh, nil) + mqsRequest.QueryStringParameters = map[string]string{ + "hello": "1", + "world": "2", + } + It("Populates multiple value query string correctly", func() { + httpReq, err := accessor.EventToRequestWithContext(context.Background(), mqsRequest) + Expect(err).To(BeNil()) + Expect("/hello").To(Equal(httpReq.URL.Path)) + Expect(httpReq.RequestURI).To(ContainSubstring("hello=1")) + Expect(httpReq.RequestURI).To(ContainSubstring("world=2")) + Expect("GET").To(Equal(httpReq.Method)) + + query := httpReq.URL.Query() + Expect(2).To(Equal(len(query))) + Expect(query["hello"]).ToNot(BeNil()) + Expect(query["world"]).ToNot(BeNil()) + Expect(1).To(Equal(len(query["hello"]))) + Expect("1").To(Equal(query["hello"][0])) + Expect("2").To(Equal(query["world"][0])) + + }) + + qsRequest := getALBProxyRequest("/hello", "GET", getALBRequestContext(), false, hdr, bdy, qs, mvh, nil) + qsRequest.QueryStringParameters = map[string]string{ + "hello": "1", + "world": "2", + } + It("Populates query string correctly", func() { + httpReq, err := accessor.EventToRequestWithContext(context.Background(), qsRequest) + Expect(err).To(BeNil()) + Expect("/hello").To(Equal(httpReq.URL.Path)) + Expect(httpReq.RequestURI).To(ContainSubstring("hello=1")) + Expect(httpReq.RequestURI).To(ContainSubstring("world=2")) + Expect("GET").To(Equal(httpReq.Method)) + + query := httpReq.URL.Query() + Expect(2).To(Equal(len(query))) + Expect(query["hello"]).ToNot(BeNil()) + Expect(query["world"]).ToNot(BeNil()) + Expect(1).To(Equal(len(query["hello"]))) + Expect(1).To(Equal(len(query["world"]))) + Expect("1").To(Equal(query["hello"][0])) + Expect("2").To(Equal(query["world"][0])) + }) + + // If multivaluehaders are set then it only passes the multivalue headers to the http.Request + mvhRequest := getALBProxyRequest("/hello", "GET", getALBRequestContext(), false, hdr, bdy, qs, nil, mvqs) + mvhRequest.MultiValueHeaders = map[string][]string{ + "accept": {"test", "one"}, + "connection": {"keep-alive"}, + "host": {"lambda-test-alb-1234567.us-east-1.elb.amazonaws.com"}, + } + It("Populates multiple value headers correctly", func() { + httpReq, err := accessor.EventToRequestWithContext(context.Background(), mvhRequest) + Expect(err).To(BeNil()) + Expect("/hello").To(Equal(httpReq.URL.Path)) + Expect("GET").To(Equal(httpReq.Method)) + + headers := httpReq.Header + Expect(3).To(Equal(len(headers))) + + for k, value := range headers { + Expect(value).To(Equal(mvhRequest.MultiValueHeaders[strings.ToLower(k)])) + } + + }) + // If multivaluehaders are set then it only passes the multivalue headers to the http.Request + svhRequest := getALBProxyRequest("/hello", "GET", getALBRequestContext(), false, hdr, bdy, qs, mvh, mvqs) + svhRequest.Headers = map[string]string{ + "header1": "Testhdr1", + "header2": "Testhdr2"} + + It("Populates single value headers correctly", func() { + httpReq, err := accessor.EventToRequestWithContext(context.Background(), svhRequest) + Expect(err).To(BeNil()) + Expect("/hello").To(Equal(httpReq.URL.Path)) + Expect("GET").To(Equal(httpReq.Method)) + + headers := httpReq.Header + Expect(3).To(Equal(len(headers))) + + for k, value := range headers { + Expect(value).To(Equal(mvhRequest.MultiValueHeaders[strings.ToLower(k)])) + } + + }) + + basePathRequest := getALBProxyRequest("/app1/orders", "GET", getALBRequestContext(), false, hdr, bdy, qs, mvh, nil) + + It("Stips the base path correct", func() { + accessor.StripBasePath("app1") + httpReq, err := accessor.EventToRequestWithContext(context.Background(), basePathRequest) + + Expect(err).To(BeNil()) + Expect("/orders").To(Equal(httpReq.URL.Path)) + Expect("/orders?UniqueId=12345").To(Equal(httpReq.RequestURI)) + }) + + contextRequest := getALBProxyRequest("orders", "GET", getALBRequestContext(), false, hdr, bdy, qs, mvh, mvqs) + contextRequest.RequestContext = getALBRequestContext() + + It("Populates context header correctly", func() { + // calling old method to verify reverse compatibility + httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest) + Expect(err).To(BeNil()) + Expect(4).To(Equal(len(httpReq.Header))) + Expect(httpReq.Header.Get(core.ALBContextHeader)).ToNot(BeNil()) + }) + }) + + Context("StripBasePath tests", func() { + accessor := core.RequestAccessorALB{} + It("Adds prefix slash", func() { + basePath := accessor.StripBasePath("app1") + Expect("/app1").To(Equal(basePath)) + }) + + It("Removes trailing slash", func() { + basePath := accessor.StripBasePath("/app1/") + Expect("/app1").To(Equal(basePath)) + }) + + It("Ignores blank strings", func() { + basePath := accessor.StripBasePath(" ") + Expect("").To(Equal(basePath)) + }) + }) + + Context("Retrieves ALB Target Group Request context", func() { + It("Returns a correctly unmarshalled object", func() { + qs := make(map[string]string) + mvh := make(map[string][]string) + hdr := make(map[string]string) + mvqs := make(map[string][]string) + qs["UniqueId"] = "12345" + mvh["accept"] = []string{"*/*", "/"} + mvh["connection"] = []string{"keep-alive"} + mvh["host"] = []string{"lambda-test-alb-1234567.us-east-1.elb.amazonaws.com"} + mvqs["key1"] = []string{"Test1"} + mvqs["key2"] = []string{"test2"} + hdr["header1"] = "Testhdr1" + bdy := "Test BODY2" + + contextRequest := getALBProxyRequest("/orders", "GET", getALBRequestContext(), false, hdr, bdy, qs, mvh, mvqs) + contextRequest.RequestContext = getALBRequestContext() + + accessor := core.RequestAccessorALB{} + // calling old method to verify reverse compatibility + httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest) + Expect(err).To(BeNil()) + + headerContext, err := accessor.GetContextALB(httpReq) + Expect(err).To(BeNil()) + Expect(headerContext).ToNot(BeNil()) + Expect("arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-target/abcdefgh").To(Equal(headerContext.ELB.TargetGroupArn)) + proxyContext, ok := core.GetTargetGroupRequetFromContextALB(httpReq.Context()) + // should fail because using header proxy method + Expect(ok).To(BeFalse()) + + httpReq, err = accessor.EventToRequestWithContext(context.Background(), contextRequest) + Expect(err).To(BeNil()) + proxyContext, ok = core.GetTargetGroupRequetFromContextALB(httpReq.Context()) + Expect(ok).To(BeTrue()) + Expect("arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-target/abcdefgh").To(Equal(proxyContext.ELB.TargetGroupArn)) + runtimeContext, ok := core.GetRuntimeContextFromContextALB(httpReq.Context()) + Expect(ok).To(BeTrue()) + Expect(runtimeContext).To(BeNil()) + + lambdaContext := lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{AwsRequestID: "abc123"}) + httpReq, err = accessor.EventToRequestWithContext(lambdaContext, contextRequest) + Expect(err).To(BeNil()) + + headerContext, err = accessor.GetContextALB(httpReq) + // should fail as new context method doesn't populate headers + Expect(err).ToNot(BeNil()) + proxyContext, ok = core.GetTargetGroupRequetFromContextALB(httpReq.Context()) + Expect(ok).To(BeTrue()) + Expect("arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-target/abcdefgh").To(Equal(proxyContext.ELB.TargetGroupArn)) + runtimeContext, ok = core.GetRuntimeContextFromContextALB(httpReq.Context()) + Expect(ok).To(BeTrue()) + Expect(runtimeContext).ToNot(BeNil()) + + }) + }) +}) + +func getALBProxyRequest(path string, method string, requestCtx events.ALBTargetGroupRequestContext, + is64 bool, header map[string]string, body string, qs map[string]string, mvh map[string][]string, mvqsp map[string][]string) events.ALBTargetGroupRequest { + return events.ALBTargetGroupRequest{ + HTTPMethod: method, + Path: path, + QueryStringParameters: qs, + MultiValueQueryStringParameters: mvqsp, + Headers: header, + MultiValueHeaders: mvh, + RequestContext: requestCtx, + IsBase64Encoded: is64, + Body: body, + } +} + +func getALBRequestContext() events.ALBTargetGroupRequestContext { + return events.ALBTargetGroupRequestContext{ + ELB: events.ELBContext{ + TargetGroupArn: "arn:aws:elasticloadbalancing:us-east-1:123456789012:targetgroup/lambda-target/abcdefgh", + }, + } +} diff --git a/core/responseALB.go b/core/responseALB.go new file mode 100644 index 0000000..b5869d4 --- /dev/null +++ b/core/responseALB.go @@ -0,0 +1,112 @@ +// Package core provides utility methods that help convert proxy events +// into an http.Request and http.ResponseWriter +package core + +import ( + "bytes" + "encoding/base64" + "errors" + "net/http" + "unicode/utf8" + + "github.com/aws/aws-lambda-go/events" +) + +// ProxyResponseWriter implements http.ResponseWriter and adds the method +// necessary to return an events.ALBTargetGroupResponse object +type ProxyResponseWriterALB struct { + headers http.Header + body bytes.Buffer + status int + statusText string + observers []chan<- bool +} + +// NewProxyResponseWriter returns a new ProxyResponseWriter object. +// The object is initialized with an empty map of headers and a +// status code of -1 +func NewProxyResponseWriterALB() *ProxyResponseWriterALB { + return &ProxyResponseWriterALB{ + headers: make(http.Header), + status: defaultStatusCode, + statusText: http.StatusText(defaultStatusCode), + observers: make([]chan<- bool, 0), + } + +} + +func (r *ProxyResponseWriterALB) CloseNotify() <-chan bool { + ch := make(chan bool, 1) + + r.observers = append(r.observers, ch) + + return ch +} + +func (r *ProxyResponseWriterALB) notifyClosed() { + for _, v := range r.observers { + v <- true + } +} + +// Header implementation from the http.ResponseWriter interface. +func (r *ProxyResponseWriterALB) Header() http.Header { + return r.headers +} + +// Write sets the response body in the object. If no status code +// was set before with the WriteHeader method it sets the status +// for the response to 200 OK. +func (r *ProxyResponseWriterALB) Write(body []byte) (int, error) { + if r.status == defaultStatusCode { + r.status = http.StatusOK + } + + // if the content type header is not set when we write the body we try to + // detect one and set it by default. If the content type cannot be detected + // it is automatically set to "application/octet-stream" by the + // DetectContentType method + if r.Header().Get(contentTypeHeaderKey) == "" { + r.Header().Add(contentTypeHeaderKey, http.DetectContentType(body)) + } + + return (&r.body).Write(body) +} + +// WriteHeader sets a status code for the response. This method is used +// for error responses. +func (r *ProxyResponseWriterALB) WriteHeader(status int) { + r.status = status +} + +// GetProxyResponse converts the data passed to the response writer into +// an events.ALBTargetGroupResponse object. +// Returns a populated proxy response object. If the response is invalid, for example +// has no headers or an invalid status code returns an error. +func (r *ProxyResponseWriterALB) GetProxyResponse() (events.ALBTargetGroupResponse, error) { + r.notifyClosed() + + if r.status == defaultStatusCode { + return events.ALBTargetGroupResponse{}, errors.New("status code not set on response") + } + + var output string + isBase64 := false + + bb := (&r.body).Bytes() + + if utf8.Valid(bb) { + output = string(bb) + } else { + output = base64.StdEncoding.EncodeToString(bb) + isBase64 = true + } + + return events.ALBTargetGroupResponse{ + StatusCode: r.status, + StatusDescription: http.StatusText(r.status), + MultiValueHeaders: http.Header(r.headers), + Body: output, + IsBase64Encoded: isBase64, + }, nil +} diff --git a/core/responseALB_test.go b/core/responseALB_test.go new file mode 100644 index 0000000..24ace2b --- /dev/null +++ b/core/responseALB_test.go @@ -0,0 +1,180 @@ +package core + +import ( + "encoding/base64" + "math/rand" + "net/http" + "strings" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("ResponseWriterALB tests", func() { + Context("ALB writing to response object", func() { + response := NewProxyResponseWriterALB() + + It("Sets the correct default status", func() { + Expect(defaultStatusCode).To(Equal(response.status)) + }) + + It("Initializes the headers map", func() { + Expect(response.headers).ToNot(BeNil()) + Expect(0).To(Equal(len(response.headers))) + }) + + It("Writes headers correctly", func() { + response.Header().Add("Content-Type", "application/json") + + Expect(1).To(Equal(len(response.headers))) + Expect("application/json").To(Equal(response.headers["Content-Type"][0])) + }) + + It("Writes body content correctly", func() { + binaryBody := make([]byte, 256) + _, err := rand.Read(binaryBody) + Expect(err).To(BeNil()) + + written, err := response.Write(binaryBody) + Expect(err).To(BeNil()) + Expect(len(binaryBody)).To(Equal(written)) + }) + + It("Automatically set the status code to 200", func() { + Expect(http.StatusOK).To(Equal(response.status)) + }) + + It("Forces the status to a new code", func() { + response.WriteHeader(http.StatusAccepted) + Expect(http.StatusAccepted).To(Equal(response.status)) + }) + }) + + Context("Automatically set response content type", func() { + xmlBodyContent := "ToveJaniReminderDon't forget me this weekend!" + htmlBodyContent := " Title of the documentContent of the document......" + It("Does not set the content type if it's already set", func() { + resp := NewProxyResponseWriterALB() + resp.Header().Add("Content-Type", "application/json") + + resp.Write([]byte(xmlBodyContent)) + + Expect("application/json").To(Equal(resp.Header().Get("Content-Type"))) + proxyResp, err := resp.GetProxyResponse() + Expect(err).To(BeNil()) + Expect(1).To(Equal(len(proxyResp.MultiValueHeaders))) + Expect("application/json").To(Equal(proxyResp.MultiValueHeaders["Content-Type"][0])) + Expect(xmlBodyContent).To(Equal(proxyResp.Body)) + }) + + It("Sets the content type to text/xml given the body", func() { + resp := NewProxyResponseWriterALB() + resp.Write([]byte(xmlBodyContent)) + + Expect("").ToNot(Equal(resp.Header().Get("Content-Type"))) + Expect(true).To(Equal(strings.HasPrefix(resp.Header().Get("Content-Type"), "text/xml;"))) + proxyResp, err := resp.GetProxyResponse() + Expect(err).To(BeNil()) + Expect(1).To(Equal(len(proxyResp.MultiValueHeaders))) + Expect(true).To(Equal(strings.HasPrefix(proxyResp.MultiValueHeaders["Content-Type"][0], "text/xml;"))) + Expect(xmlBodyContent).To(Equal(proxyResp.Body)) + }) + + It("Sets the content type to text/html given the body", func() { + resp := NewProxyResponseWriterALB() + resp.Write([]byte(htmlBodyContent)) + + Expect("").ToNot(Equal(resp.Header().Get("Content-Type"))) + Expect(true).To(Equal(strings.HasPrefix(resp.Header().Get("Content-Type"), "text/html;"))) + proxyResp, err := resp.GetProxyResponse() + Expect(err).To(BeNil()) + Expect(1).To(Equal(len(proxyResp.MultiValueHeaders))) + Expect(true).To(Equal(strings.HasPrefix(proxyResp.MultiValueHeaders["Content-Type"][0], "text/html;"))) + Expect(htmlBodyContent).To(Equal(proxyResp.Body)) + }) + }) + + Context("Export ALB Target Group Response", func() { + emtpyResponse := NewProxyResponseWriterALB() + emtpyResponse.Header().Add("Content-Type", "application/json") + + It("Refuses empty responses with default status code", func() { + _, err := emtpyResponse.GetProxyResponse() + Expect(err).ToNot(BeNil()) + Expect("status code not set on response").To(Equal(err.Error())) + }) + + simpleResponse := NewProxyResponseWriterALB() + simpleResponse.Write([]byte("hello")) + simpleResponse.Header().Add("Content-Type", "text/plain") + It("Writes text body correctly", func() { + proxyResponse, err := simpleResponse.GetProxyResponse() + Expect(err).To(BeNil()) + Expect(proxyResponse).ToNot(BeNil()) + + Expect("hello").To(Equal(proxyResponse.Body)) + Expect(http.StatusOK).To(Equal(proxyResponse.StatusCode)) + Expect(1).To(Equal(len(proxyResponse.MultiValueHeaders))) + Expect(true).To(Equal(strings.HasPrefix(proxyResponse.MultiValueHeaders["Content-Type"][0], "text/plain"))) + Expect(proxyResponse.IsBase64Encoded).To(BeFalse()) + }) + + binaryResponse := NewProxyResponseWriterALB() + binaryResponse.Header().Add("Content-Type", "application/octet-stream") + binaryBody := make([]byte, 256) + _, err := rand.Read(binaryBody) + if err != nil { + Fail("Could not generate random binary body") + } + binaryResponse.Write(binaryBody) + binaryResponse.WriteHeader(http.StatusAccepted) + + It("Encodes binary responses correctly", func() { + proxyResponse, err := binaryResponse.GetProxyResponse() + Expect(err).To(BeNil()) + Expect(proxyResponse).ToNot(BeNil()) + + Expect(proxyResponse.IsBase64Encoded).To(BeTrue()) + Expect(base64.StdEncoding.EncodedLen(len(binaryBody))).To(Equal(len(proxyResponse.Body))) + + Expect(base64.StdEncoding.EncodeToString(binaryBody)).To(Equal(proxyResponse.Body)) + Expect(1).To(Equal(len(proxyResponse.MultiValueHeaders))) + Expect("application/octet-stream").To(Equal(proxyResponse.MultiValueHeaders["Content-Type"][0])) + Expect(http.StatusAccepted).To(Equal(proxyResponse.StatusCode)) + }) + }) + + Context("Handle multi-value headers", func() { + + It("Writes single-value headers correctly", func() { + response := NewProxyResponseWriterALB() + response.Header().Add("Content-Type", "application/json") + response.Write([]byte("hello")) + proxyResponse, err := response.GetProxyResponse() + Expect(err).To(BeNil()) + + // Headers are not also written to `Headers` field + Expect(0).To(Equal(len(proxyResponse.Headers))) + Expect(1).To(Equal(len(proxyResponse.MultiValueHeaders["Content-Type"]))) + Expect("application/json").To(Equal(proxyResponse.MultiValueHeaders["Content-Type"][0])) + }) + + It("Writes multi-value headers correctly", func() { + response := NewProxyResponseWriterALB() + response.Header().Add("Set-Cookie", "csrftoken=foobar") + response.Header().Add("Set-Cookie", "session_id=barfoo") + response.Write([]byte("hello")) + proxyResponse, err := response.GetProxyResponse() + Expect(err).To(BeNil()) + + // Headers are not also written to `Headers` field + Expect(0).To(Equal(len(proxyResponse.Headers))) + + // There are two headers here because Content-Type is always written implicitly + Expect(2).To(Equal(len(proxyResponse.MultiValueHeaders["Set-Cookie"]))) + Expect("csrftoken=foobar").To(Equal(proxyResponse.MultiValueHeaders["Set-Cookie"][0])) + Expect("session_id=barfoo").To(Equal(proxyResponse.MultiValueHeaders["Set-Cookie"][1])) + }) + }) + +}) diff --git a/core/typesALB.go b/core/typesALB.go new file mode 100644 index 0000000..bc5cf94 --- /dev/null +++ b/core/typesALB.go @@ -0,0 +1,11 @@ +package core + +import ( + "net/http" + + "github.com/aws/aws-lambda-go/events" +) + +func GatewayTimeoutALB() events.ALBTargetGroupResponse { + return events.ALBTargetGroupResponse{StatusCode: http.StatusGatewayTimeout} +} diff --git a/echo/adapterALB.go b/echo/adapterALB.go new file mode 100644 index 0000000..6c4c368 --- /dev/null +++ b/echo/adapterALB.go @@ -0,0 +1,59 @@ +package echoadapter + +import ( + "context" + "net/http" + + "github.com/aws/aws-lambda-go/events" + "github.com/awslabs/aws-lambda-go-api-proxy/core" + "github.com/labstack/echo/v4" +) + +// EchoLambdaALB makes it easy to send ALB proxy events to a echo.Echo. +// The library transforms the proxy event into an HTTP request and then +// creates a proxy response object from the http.ResponseWriter +type EchoLambdaALB struct { + core.RequestAccessorALB + + Echo *echo.Echo +} + +// NewAPI creates a new instance of the EchoLambdaAPI object. +// Receives an initialized *echo.Echo object - normally created with echo.New(). +// It returns the initialized instance of the EchoLambdaALB object. +func NewALB(e *echo.Echo) *EchoLambdaALB { + return &EchoLambdaALB{Echo: e} +} + +// Proxy receives an ALB event, transforms it into an http.Request +// object, and sends it to the echo.Echo for routing. +// It returns a proxy response object generated from the http.ResponseWriter. +func (e *EchoLambdaALB) Proxy(req events.ALBTargetGroupRequest) (events.ALBTargetGroupResponse, error) { + echoRequest, err := e.ProxyEventToHTTPRequest(req) + return e.proxyInternal(echoRequest, err) +} + +// ProxyWithContext receives context and an ALB event, +// transforms them into an http.Request object, and sends it to the echo.Echo for routing. +// It returns a proxy response object generated from the http.ResponseWriter. +func (e *EchoLambdaALB) ProxyWithContext(ctx context.Context, req events.ALBTargetGroupRequest) (events.ALBTargetGroupResponse, error) { + echoRequest, err := e.EventToRequestWithContext(ctx, req) + return e.proxyInternal(echoRequest, err) +} + +func (e *EchoLambdaALB) proxyInternal(req *http.Request, err error) (events.ALBTargetGroupResponse, error) { + + if err != nil { + return core.GatewayTimeoutALB(), core.NewLoggedError("Could not convert proxy event to request: %v", err) + } + + respWriter := core.NewProxyResponseWriterALB() + e.Echo.ServeHTTP(http.ResponseWriter(respWriter), req) + + proxyResponse, err := respWriter.GetProxyResponse() + if err != nil { + return core.GatewayTimeoutALB(), core.NewLoggedError("Error while generating proxy response: %v", err) + } + + return proxyResponse, nil +} diff --git a/echo/echolambda_test.go b/echo/echolambda_test.go index 877760b..08cf954 100644 --- a/echo/echolambda_test.go +++ b/echo/echolambda_test.go @@ -64,3 +64,30 @@ var _ = Describe("EchoLambdaV2 tests", func() { }) }) }) + +var _ = Describe("EchoLambdaALB tests", func() { + Context("Simple ping request", func() { + It("Proxies the event correctly", func() { + log.Println("Starting test") + e := echo.New() + e.GET("/ping", func(c echo.Context) error { + log.Println("Handler!!") + return c.String(200, "pong") + }) + + adapter := echoadapter.NewALB(e) + + req := events.ALBTargetGroupRequest{ + HTTPMethod: "GET", + Path: "/ping", + RequestContext: events.ALBTargetGroupRequestContext{ + ELB: events.ELBContext{TargetGroupArn: " ad"}, + }} + + resp, err := adapter.Proxy(req) + + Expect(err).To(BeNil()) + Expect(resp.StatusCode).To(Equal(200)) + }) + }) +}) diff --git a/gin/adapterALB.go b/gin/adapterALB.go new file mode 100644 index 0000000..c6ea5b1 --- /dev/null +++ b/gin/adapterALB.go @@ -0,0 +1,62 @@ +// Package ginadapter adds Gin support for the aws-severless-go-api library. +// Uses the core package behind the scenes and exposes the New and NewV2 and ALB methods to +// get a new instance and Proxy method to send request to the Gin engine. +package ginadapter + +import ( + "context" + "net/http" + + "github.com/aws/aws-lambda-go/events" + "github.com/awslabs/aws-lambda-go-api-proxy/core" + "github.com/gin-gonic/gin" +) + +// GinLambdaALB makes it easy to send ALB proxy events to a Gin +// Engine. The library transforms the proxy event into an HTTP request and then +// creates a proxy response object from the http.ResponseWriter +type GinLambdaALB struct { + core.RequestAccessorALB + + ginEngine *gin.Engine +} + +// New creates a new instance of the GinLambdaALB object. +// Receives an initialized *gin.Engine object - normally created with gin.Default(). +// It returns the initialized instance of the GinLambdaALB object. +func NewALB(gin *gin.Engine) *GinLambdaALB { + return &GinLambdaALB{ginEngine: gin} +} + +// Proxy receives an ALB proxy event, transforms it into an http.Request +// object, and sends it to the gin.Engine for routing. +// It returns a proxy response object generated from the http.ResponseWriter. +func (g *GinLambdaALB) Proxy(req events.ALBTargetGroupRequest) (events.ALBTargetGroupResponse, error) { + ginRequest, err := g.ProxyEventToHTTPRequest(req) + return g.proxyInternal(ginRequest, err) +} + +// ProxyWithContext receives context and an ALB proxy event, +// transforms them into an http.Request object, and sends it to the gin.Engine for routing. +// It returns a proxy response object generated from the http.ResponseWriter. +func (g *GinLambdaALB) ProxyWithContext(ctx context.Context, req events.ALBTargetGroupRequest) (events.ALBTargetGroupResponse, error) { + ginRequest, err := g.EventToRequestWithContext(ctx, req) + return g.proxyInternal(ginRequest, err) +} + +func (g *GinLambdaALB) proxyInternal(req *http.Request, err error) (events.ALBTargetGroupResponse, error) { + + if err != nil { + return core.GatewayTimeoutALB(), core.NewLoggedError("Could not convert proxy event to request: %v", err) + } + + respWriter := core.NewProxyResponseWriterALB() + g.ginEngine.ServeHTTP(http.ResponseWriter(respWriter), req) + + proxyResponse, err := respWriter.GetProxyResponse() + if err != nil { + return core.GatewayTimeoutALB(), core.NewLoggedError("Error while generating proxy response: %v", err) + } + + return proxyResponse, nil +} diff --git a/gin/ginlambda_test.go b/gin/ginlambda_test.go index da30404..e17610b 100644 --- a/gin/ginlambda_test.go +++ b/gin/ginlambda_test.go @@ -79,3 +79,37 @@ var _ = Describe("GinLambdaV2 tests", func() { }) }) }) + +var _ = Describe("GinLambdaALB tests", func() { + Context("Simple ping request", func() { + It("Proxies the event correctly", func() { + log.Println("Starting test") + r := gin.Default() + r.GET("/ping", func(c *gin.Context) { + log.Println("Handler!!") + c.JSON(200, gin.H{ + "message": "pong", + }) + }) + + adapter := ginadapter.NewALB(r) + + req := events.ALBTargetGroupRequest{ + HTTPMethod: "GET", + Path: "/ping", + RequestContext: events.ALBTargetGroupRequestContext{ + ELB: events.ELBContext{TargetGroupArn: " ad"}, + }} + + resp, err := adapter.Proxy(req) + + Expect(err).To(BeNil()) + Expect(resp.StatusCode).To(Equal(200)) + + resp, err = adapter.Proxy(req) + + Expect(err).To(BeNil()) + Expect(resp.StatusCode).To(Equal(200)) + }) + }) +}) diff --git a/gorillamux/adapterALB.go b/gorillamux/adapterALB.go new file mode 100644 index 0000000..6fdd60e --- /dev/null +++ b/gorillamux/adapterALB.go @@ -0,0 +1,53 @@ +package gorillamux + +import ( + "context" + "net/http" + + "github.com/aws/aws-lambda-go/events" + "github.com/awslabs/aws-lambda-go-api-proxy/core" + "github.com/gorilla/mux" +) + +type GorillaMuxAdapterALB struct { + core.RequestAccessorALB + router *mux.Router +} + +func NewALB(router *mux.Router) *GorillaMuxAdapterALB { + return &GorillaMuxAdapterALB{ + router: router, + } +} + +// Proxy receives an API Gateway proxy event, transforms it into an http.Request +// object, and sends it to the mux.Router for routing. +// It returns a proxy response object generated from the http.ResponseWriter. +func (h *GorillaMuxAdapterALB) Proxy(event events.ALBTargetGroupRequest) (events.ALBTargetGroupResponse, error) { + req, err := h.ProxyEventToHTTPRequest(event) + return h.proxyInternal(req, err) +} + +// ProxyWithContext receives context and an API Gateway proxy event, +// transforms them into an http.Request object, and sends it to the mux.Router for routing. +// It returns a proxy response object generated from the http.ResponseWriter. +func (h *GorillaMuxAdapterALB) ProxyWithContext(ctx context.Context, event events.ALBTargetGroupRequest) (events.ALBTargetGroupResponse, error) { + req, err := h.EventToRequestWithContext(ctx, event) + return h.proxyInternal(req, err) +} + +func (h *GorillaMuxAdapterALB) proxyInternal(req *http.Request, err error) (events.ALBTargetGroupResponse, error) { + if err != nil { + return core.GatewayTimeoutALB(), core.NewLoggedError("Could not convert proxy event to request: %v", err) + } + + w := core.NewProxyResponseWriterALB() + h.router.ServeHTTP(http.ResponseWriter(w), req) + + resp, err := w.GetProxyResponse() + if err != nil { + return core.GatewayTimeoutALB(), core.NewLoggedError("Error while generating proxy response: %v", err) + } + + return resp, nil +} diff --git a/gorillamux/adapterALB_test.go b/gorillamux/adapterALB_test.go new file mode 100644 index 0000000..aa79154 --- /dev/null +++ b/gorillamux/adapterALB_test.go @@ -0,0 +1,62 @@ +package gorillamux_test + +import ( + "context" + "fmt" + "net/http" + + "github.com/aws/aws-lambda-go/events" + "github.com/awslabs/aws-lambda-go-api-proxy/gorillamux" + "github.com/gorilla/mux" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("GorillaMuxAdapterALB tests", func() { + Context("Simple ping request", func() { + It("Proxies the event correctly", func() { + homeHandler := func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("unfortunately-required-header", "") + fmt.Fprintf(w, "Home Page") + } + + productsHandler := func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("unfortunately-required-header", "") + fmt.Fprintf(w, "Products Page") + } + + r := mux.NewRouter() + r.HandleFunc("/", homeHandler) + r.HandleFunc("/products", productsHandler) + + adapter := gorillamux.NewALB(r) + + homePageReq := events.ALBTargetGroupRequest{ + HTTPMethod: http.MethodGet, + Path: "/", + RequestContext: events.ALBTargetGroupRequestContext{ + ELB: events.ELBContext{TargetGroupArn: " ad"}, + }} + + homePageResp, homePageReqErr := adapter.ProxyWithContext(context.Background(), homePageReq) + + Expect(homePageReqErr).To(BeNil()) + Expect(homePageResp.StatusCode).To(Equal(200)) + Expect(homePageResp.Body).To(Equal("Home Page")) + + productsPageReq := events.ALBTargetGroupRequest{ + HTTPMethod: http.MethodGet, + Path: "/products", + RequestContext: events.ALBTargetGroupRequestContext{ + ELB: events.ELBContext{TargetGroupArn: " ad"}, + }} + + productsPageResp, productsPageReqErr := adapter.Proxy(productsPageReq) + + Expect(productsPageReqErr).To(BeNil()) + Expect(productsPageResp.StatusCode).To(Equal(200)) + Expect(productsPageResp.Body).To(Equal("Products Page")) + }) + }) +}) diff --git a/handlerfunc/adapterALB.go b/handlerfunc/adapterALB.go new file mode 100644 index 0000000..4384b39 --- /dev/null +++ b/handlerfunc/adapterALB.go @@ -0,0 +1,13 @@ +package handlerfunc + +import ( + "net/http" + + "github.com/awslabs/aws-lambda-go-api-proxy/httpadapter" +) + +type HandlerFuncAdapterALB = httpadapter.HandlerAdapterALB + +func NewALB(handlerFunc http.HandlerFunc) *HandlerFuncAdapterALB { + return httpadapter.NewALB(handlerFunc) +} diff --git a/handlerfunc/adapterALB_test.go b/handlerfunc/adapterALB_test.go new file mode 100644 index 0000000..d4738bd --- /dev/null +++ b/handlerfunc/adapterALB_test.go @@ -0,0 +1,46 @@ +package handlerfunc_test + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/aws/aws-lambda-go/events" + "github.com/awslabs/aws-lambda-go-api-proxy/handlerfunc" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("HandlerFuncAdapter ALB tests", func() { + Context("Simple ping request", func() { + It("Proxies the event correctly", func() { + log.Println("Starting test") + + handler := func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("unfortunately-required-header", "") + fmt.Fprintf(w, "Go Lambda!!") + } + + adapter := handlerfunc.NewALB(handler) + + req := events.ALBTargetGroupRequest{ + HTTPMethod: http.MethodGet, + Path: "/", + RequestContext: events.ALBTargetGroupRequestContext{ + ELB: events.ELBContext{TargetGroupArn: " ad"}, + }} + + resp, err := adapter.ProxyWithContext(context.Background(), req) + + Expect(err).To(BeNil()) + Expect(resp.StatusCode).To(Equal(200)) + + resp, err = adapter.Proxy(req) + + Expect(err).To(BeNil()) + Expect(resp.StatusCode).To(Equal(200)) + }) + }) +}) diff --git a/httpadapter/adapterALB.go b/httpadapter/adapterALB.go new file mode 100644 index 0000000..2242632 --- /dev/null +++ b/httpadapter/adapterALB.go @@ -0,0 +1,52 @@ +package httpadapter + +import ( + "context" + "net/http" + + "github.com/aws/aws-lambda-go/events" + "github.com/awslabs/aws-lambda-go-api-proxy/core" +) + +type HandlerAdapterALB struct { + core.RequestAccessorALB + handler http.Handler +} + +func NewALB(handler http.Handler) *HandlerAdapterALB { + return &HandlerAdapterALB{ + handler: handler, + } +} + +// Proxy receives an ALB Target Group proxy event, transforms it into an http.Request +// object, and sends it to the http.HandlerFunc for routing. +// It returns a proxy response object generated from the http.ResponseWriter. +func (h *HandlerAdapterALB) Proxy(event events.ALBTargetGroupRequest) (events.ALBTargetGroupResponse, error) { + req, err := h.ProxyEventToHTTPRequest(event) + return h.proxyInternal(req, err) +} + +// ProxyWithContext receives context and an ALB proxy event, +// transforms them into an http.Request object, and sends it to the http.Handler for routing. +// It returns a proxy response object generated from the http.ResponseWriter. +func (h *HandlerAdapterALB) ProxyWithContext(ctx context.Context, event events.ALBTargetGroupRequest) (events.ALBTargetGroupResponse, error) { + req, err := h.EventToRequestWithContext(ctx, event) + return h.proxyInternal(req, err) +} + +func (h *HandlerAdapterALB) proxyInternal(req *http.Request, err error) (events.ALBTargetGroupResponse, error) { + if err != nil { + return core.GatewayTimeoutALB(), core.NewLoggedError("Could not convert proxy event to request: %v", err) + } + + w := core.NewProxyResponseWriterALB() + h.handler.ServeHTTP(http.ResponseWriter(w), req) + + resp, err := w.GetProxyResponse() + if err != nil { + return core.GatewayTimeoutALB(), core.NewLoggedError("Error while generating proxy response: %v", err) + } + + return resp, nil +} diff --git a/httpadapter/adapterALB_test.go b/httpadapter/adapterALB_test.go new file mode 100644 index 0000000..a0b89d9 --- /dev/null +++ b/httpadapter/adapterALB_test.go @@ -0,0 +1,48 @@ +package httpadapter_test + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/aws/aws-lambda-go/events" + "github.com/awslabs/aws-lambda-go-api-proxy/httpadapter" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("HandlerFuncAdapter tests", func() { + Context("Simple ping request", func() { + It("Proxies the event correctly", func() { + log.Println("Starting test") + + var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("unfortunately-required-header", "") + fmt.Fprintf(w, "Go Lambda!!") + }) + + adapter := httpadapter.NewV2(handler) + + req := events.APIGatewayV2HTTPRequest{ + RequestContext: events.APIGatewayV2HTTPRequestContext{ + HTTP: events.APIGatewayV2HTTPRequestContextHTTPDescription{ + Method: http.MethodGet, + Path: "/ping", + }, + }, + } + + resp, err := adapter.ProxyWithContext(context.Background(), req) + + Expect(err).To(BeNil()) + Expect(resp.StatusCode).To(Equal(200)) + + resp, err = adapter.Proxy(req) + + Expect(err).To(BeNil()) + Expect(resp.StatusCode).To(Equal(200)) + }) + }) +})