Skip to content

Commit

Permalink
aws#1084: decode body if base64 is enable
Browse files Browse the repository at this point in the history
  • Loading branch information
npeters committed Oct 19, 2024
1 parent e6abc1f commit 8fef6b5
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
package com.amazonaws.serverless.proxy.spring;

import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.nio.charset.UnsupportedCharsetException;
import java.util.Base64;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.apache.commons.io.Charsets;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.function.serverless.web.ServerlessHttpServletRequest;
import org.springframework.cloud.function.serverless.web.ServerlessMVC;
import org.springframework.http.HttpHeaders;
import org.springframework.util.CollectionUtils;
import org.springframework.util.FileCopyUtils;
import org.springframework.util.MultiValueMapAdapter;
import org.springframework.util.StringUtils;

import com.amazonaws.serverless.proxy.AsyncInitializationWrapper;
import com.amazonaws.serverless.proxy.AwsHttpApiV2SecurityContextWriter;
import com.amazonaws.serverless.proxy.AwsProxySecurityContextWriter;
import com.amazonaws.serverless.proxy.RequestReader;
Expand Down Expand Up @@ -120,10 +122,14 @@ private static HttpServletRequest generateRequest1(String request, Context lambd
MultiValueMapAdapter headers = new MultiValueMapAdapter(v1Request.getMultiValueHeaders());
httpRequest.setHeaders(headers);
}
if (StringUtils.hasText(v1Request.getBody())) {
httpRequest.setContentType("application/json");
httpRequest.setContent(v1Request.getBody().getBytes(StandardCharsets.UTF_8));
}
if (StringUtils.hasText(v1Request.getBody())) {
if (v1Request.isBase64Encoded()) {
httpRequest.setContent(Base64.getMimeDecoder().decode(v1Request.getBody()));
} else {
Charset charseEncoding = parseCharacterEncoding(v1Request.getHeaders().get(HttpHeaders.CONTENT_TYPE));
httpRequest.setContent(v1Request.getBody().getBytes(charseEncoding));
}
}
if (v1Request.getRequestContext() != null) {
httpRequest.setAttribute(RequestReader.API_GATEWAY_CONTEXT_PROPERTY, v1Request.getRequestContext());
httpRequest.setAttribute(RequestReader.ALB_CONTEXT_PROPERTY, v1Request.getRequestContext().getElb());
Expand All @@ -149,11 +155,15 @@ private static HttpServletRequest generateRequest2(String request, Context lambd
populateQueryStringparameters(v2Request.getQueryStringParameters(), httpRequest);

v2Request.getHeaders().forEach(httpRequest::setHeader);

if (StringUtils.hasText(v2Request.getBody())) {
httpRequest.setContentType("application/json");
httpRequest.setContent(v2Request.getBody().getBytes(StandardCharsets.UTF_8));
}

if (StringUtils.hasText(v2Request.getBody())) {
if (v2Request.isBase64Encoded()) {
httpRequest.setContent(Base64.getMimeDecoder().decode(v2Request.getBody()));
} else {
Charset charseEncoding = parseCharacterEncoding(v2Request.getHeaders().get(HttpHeaders.CONTENT_TYPE));
httpRequest.setContent(v2Request.getBody().getBytes(charseEncoding));
}
}
httpRequest.setAttribute(RequestReader.HTTP_API_CONTEXT_PROPERTY, v2Request.getRequestContext());
httpRequest.setAttribute(RequestReader.HTTP_API_STAGE_VARS_PROPERTY, v2Request.getStageVariables());
httpRequest.setAttribute(RequestReader.HTTP_API_EVENT_PROPERTY, v2Request);
Expand All @@ -180,4 +190,36 @@ private static <T> T readValue(String json, Class<T> clazz, ObjectMapper mapper)
}
}

static final String HEADER_KEY_VALUE_SEPARATOR = "=";
static final String HEADER_VALUE_SEPARATOR = ";";
static final String ENCODING_VALUE_KEY = "charset";
static protected Charset parseCharacterEncoding(String contentTypeHeader) {
// we only look at content-type because content-encoding should only be used for
// "binary" requests such as gzip/deflate.
Charset defaultCharset = StandardCharsets.UTF_8;
if (contentTypeHeader == null) {
return defaultCharset;
}

String[] contentTypeValues = contentTypeHeader.split(HEADER_VALUE_SEPARATOR);
if (contentTypeValues.length <= 1) {
return defaultCharset;
}

for (String contentTypeValue : contentTypeValues) {
if (contentTypeValue.trim().startsWith(ENCODING_VALUE_KEY)) {
String[] encodingValues = contentTypeValue.split(HEADER_KEY_VALUE_SEPARATOR);
if (encodingValues.length <= 1) {
return defaultCharset;
}
try {
return Charsets.toCharset(encodingValues[1]);
} catch (UnsupportedCharsetException ex) {
return defaultCharset;
}
}
}
return defaultCharset;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.*;

import com.amazonaws.serverless.exceptions.ContainerInitializationException;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.springframework.cloud.function.serverless.web.ServerlessServletContext;
import org.springframework.util.CollectionUtils;

import com.amazonaws.serverless.proxy.spring.servletapp.MessageData;
Expand Down Expand Up @@ -214,7 +210,7 @@ public static Collection<String> data() {
public void validateComplesrequest(String jsonEvent) throws Exception {
initServletAppTest();
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST",
"/foo/male/list/24", "{\"name\":\"bob\"}", null));
"/foo/male/list/24", "{\"name\":\"bob\"}", false,null));
ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
Expand All @@ -229,7 +225,7 @@ public void validateComplesrequest(String jsonEvent) throws Exception {
@ParameterizedTest
public void testAsyncPost(String jsonEvent) throws Exception {
initServletAppTest();
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/async", "{\"name\":\"bob\"}", null));
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/async", "{\"name\":\"bob\"}",false, null));
ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
Expand All @@ -242,7 +238,7 @@ public void testAsyncPost(String jsonEvent) throws Exception {
public void testValidate400(String jsonEvent) throws Exception {
initServletAppTest();
UserData ud = new UserData();
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud), null));
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud),false, null));
ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
Expand All @@ -258,27 +254,48 @@ public void testValidate200(String jsonEvent) throws Exception {
ud.setFirstName("bob");
ud.setLastName("smith");
ud.setEmail("[email protected]");
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud), null));
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate", mapper.writeValueAsString(ud),false, null));
ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
assertEquals(200, result.get("statusCode"));
assertEquals("VALID", result.get("body"));
}

@MethodSource("data")
@ParameterizedTest
public void testValidate200Base64(String jsonEvent) throws Exception {
initServletAppTest();
UserData ud = new UserData();
ud.setFirstName("bob");
ud.setLastName("smith");
ud.setEmail("[email protected]");
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/validate",
Base64.getMimeEncoder().encodeToString(mapper.writeValueAsString(ud).getBytes()),true, null));

ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
assertEquals(200, result.get("statusCode"));
assertEquals("VALID", result.get("body"));
}


@MethodSource("data")
@ParameterizedTest
public void messageObject_parsesObject_returnsCorrectMessage(String jsonEvent) throws Exception {
initServletAppTest();
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/message",
mapper.writeValueAsString(new MessageData("test message")), null));
mapper.writeValueAsString(new MessageData("test message")),false, null));
ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
assertEquals(200, result.get("statusCode"));
assertEquals("test message", result.get("body"));
}



@SuppressWarnings({"unchecked" })
@MethodSource("data")
@ParameterizedTest
Expand All @@ -289,40 +306,42 @@ void messageObject_propertiesInContentType_returnsCorrectMessage(String jsonEven
headers.put(HttpHeaders.CONTENT_TYPE, "application/json;v=1");
headers.put(HttpHeaders.ACCEPT, "application/json;v=1");
InputStream targetStream = new ByteArrayInputStream(this.generateHttpRequest(jsonEvent, "POST", "/message",
mapper.writeValueAsString(new MessageData("test message")), headers));
mapper.writeValueAsString(new MessageData("test message")),false, headers));

ByteArrayOutputStream output = new ByteArrayOutputStream();
handler.handleRequest(targetStream, output, null);
Map result = mapper.readValue(output.toString(StandardCharsets.UTF_8), Map.class);
assertEquals("test message", result.get("body"));
}

private byte[] generateHttpRequest(String jsonEvent, String method, String path, String body, Map headers) throws Exception {
private byte[] generateHttpRequest(String jsonEvent, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception {
Map requestMap = mapper.readValue(jsonEvent, Map.class);
if (requestMap.get("version").equals("2.0")) {
return generateHttpRequest2(requestMap, method, path, body, headers);
return generateHttpRequest2(requestMap, method, path, body, isBase64Encoded,headers);
}
return generateHttpRequest(requestMap, method, path, body, headers);
return generateHttpRequest(requestMap, method, path, body,isBase64Encoded, headers);
}

@SuppressWarnings({ "unchecked"})
private byte[] generateHttpRequest(Map requestMap, String method, String path, String body, Map headers) throws Exception {
private byte[] generateHttpRequest(Map requestMap, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception {
requestMap.put("path", path);
requestMap.put("httpMethod", method);
requestMap.put("body", body);
requestMap.put("isBase64Encoded", isBase64Encoded);
if (!CollectionUtils.isEmpty(headers)) {
requestMap.put("headers", headers);
}
return mapper.writeValueAsBytes(requestMap);
}

@SuppressWarnings({ "unchecked"})
private byte[] generateHttpRequest2(Map requestMap, String method, String path, String body, Map headers) throws Exception {
private byte[] generateHttpRequest2(Map requestMap, String method, String path, String body,boolean isBase64Encoded, Map headers) throws Exception {
Map map = mapper.readValue(API_GATEWAY_EVENT_V2, Map.class);
Map http = (Map) ((Map) map.get("requestContext")).get("http");
http.put("path", path);
http.put("method", method);
map.put("body", body);
map.put("isBase64Encoded", isBase64Encoded);
if (!CollectionUtils.isEmpty(headers)) {
map.put("headers", headers);
}
Expand Down

0 comments on commit 8fef6b5

Please sign in to comment.