Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the content type retrieval #7013

Merged
merged 9 commits into from
Oct 15, 2021
51 changes: 39 additions & 12 deletions airbyte-server/src/main/java/io/airbyte/server/RequestLogger.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.annotations.VisibleForTesting;
import io.airbyte.commons.json.Jsons;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand Down Expand Up @@ -42,6 +43,12 @@ public RequestLogger(Map<String, String> mdc) {
this.mdc = mdc;
}

@VisibleForTesting
RequestLogger(Map<String, String> mdc, HttpServletRequest servletRequest) {
this.mdc = mdc;
this.servletRequest = servletRequest;
}

@Override
public void filter(ContainerRequestContext requestContext) throws IOException {
if (requestContext.getMethod().equals("POST")) {
Expand All @@ -63,21 +70,20 @@ public void filter(ContainerRequestContext requestContext, ContainerResponseCont
String remoteAddr = servletRequest.getRemoteAddr();
String method = servletRequest.getMethod();
String url = servletRequest.getRequestURI();
boolean isContentTypeGzip =
servletRequest.getHeader("Content-Type") != null && servletRequest.getHeader("Content-Type").toLowerCase().contains("application/x-gzip");

boolean isPrintable = servletRequest.getHeader("Content-Type") != null &&
servletRequest.getHeader("Content-Type").toLowerCase().contains("application/json") &&
isValidJson(requestBody);

int status = responseContext.getStatus();

StringBuilder logBuilder = new StringBuilder()
.append("REQ ")
.append(remoteAddr)
.append(" ")
.append(method)
.append(" ")
.append(status)
.append(" ")
.append(url);
StringBuilder logBuilder = createLogPrefix(
remoteAddr,
method,
status,
url);

if (method.equals("POST") && requestBody != null && !requestBody.equals("") && !isContentTypeGzip) {
if (method.equals("POST") && requestBody != null && !requestBody.equals("") && isPrintable) {
logBuilder
.append(" - ")
.append(redactSensitiveInfo(requestBody));
Expand All @@ -90,6 +96,23 @@ public void filter(ContainerRequestContext requestContext, ContainerResponseCont
}
}

@VisibleForTesting
static StringBuilder createLogPrefix(
String remoteAddr,
String method,
int status,
String url) {
return new StringBuilder()
.append("REQ ")
.append(remoteAddr)
.append(" ")
.append(method)
.append(" ")
.append(status)
.append(" ")
.append(url);
}

private static final Set<String> TOP_LEVEL_SENSITIVE_FIELDS = Set.of(
"connectionConfiguration");

Expand All @@ -116,4 +139,8 @@ private static String redactSensitiveInfo(String requestBody) {
return requestBody;
}

private static boolean isValidJson(String json) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're using Jackson for JSON manipulation almost everywhere and we have a few helpers for it. Can we use Jsons.tryDeserialize instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

return Jsons.tryDeserialize(json).isPresent();
}

}
125 changes: 125 additions & 0 deletions airbyte-server/src/test/java/io/airbyte/server/RequestLoggerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright (c) 2021 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.server;

import io.airbyte.commons.io.IOs;
import io.airbyte.config.helpers.LogClientSingleton;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.stream.Stream;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerResponseContext;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import org.slf4j.MDC;

@ExtendWith(MockitoExtension.class)
public class RequestLoggerTest {

private static final String VALID_JSON_OBJECT = "{\"valid\":1}";
private static final String INVALID_JSON_OBJECT = "invalid";
private static final String ACCEPTED_CONTENT_TYPE = "application/json";
private static final String NON_ACCEPTED_CONTENT_TYPE = "application/gzip";

private static final String METHOD = "POST";
private static final String REMOTE_ADDR = "123.456.789.101";
private static final String URL = "/api/v1/test";

@Mock
private HttpServletRequest mServletRequest;

@Mock
private ContainerRequestContext mRequestContext;

@Mock
private ContainerResponseContext mResponseContext;

private RequestLogger requestLogger;

@BeforeEach
public void init() throws Exception {
Mockito.when(mRequestContext.getMethod())
.thenReturn(METHOD);

Mockito.when(mServletRequest.getMethod())
.thenReturn(METHOD);
Mockito.when(mServletRequest.getRemoteAddr())
.thenReturn(REMOTE_ADDR);
Mockito.when(mServletRequest.getRequestURI())
.thenReturn(URL);
}

private static final int ERROR_CODE = 401;
private static final int SUCCESS_CODE = 200;

private static final String errorPrefix = RequestLogger
.createLogPrefix(REMOTE_ADDR, METHOD, ERROR_CODE, URL)
.toString();

private static final String successPrefix = RequestLogger
.createLogPrefix(REMOTE_ADDR, METHOD, SUCCESS_CODE, URL)
.toString();

static Stream<Arguments> logScenarios() {
return Stream.of(
Arguments.of(INVALID_JSON_OBJECT, NON_ACCEPTED_CONTENT_TYPE, ERROR_CODE, errorPrefix),
Arguments.of(INVALID_JSON_OBJECT, ACCEPTED_CONTENT_TYPE, ERROR_CODE, errorPrefix),
Arguments.of(VALID_JSON_OBJECT, NON_ACCEPTED_CONTENT_TYPE, ERROR_CODE, errorPrefix),
Arguments.of(VALID_JSON_OBJECT, ACCEPTED_CONTENT_TYPE, ERROR_CODE, errorPrefix + " - " + VALID_JSON_OBJECT),
Arguments.of(INVALID_JSON_OBJECT, NON_ACCEPTED_CONTENT_TYPE, SUCCESS_CODE, successPrefix),
Arguments.of(INVALID_JSON_OBJECT, ACCEPTED_CONTENT_TYPE, SUCCESS_CODE, successPrefix),
Arguments.of(VALID_JSON_OBJECT, NON_ACCEPTED_CONTENT_TYPE, SUCCESS_CODE, successPrefix),
Arguments.of(VALID_JSON_OBJECT, ACCEPTED_CONTENT_TYPE, SUCCESS_CODE, successPrefix + " - " + VALID_JSON_OBJECT));
}

@ParameterizedTest
@MethodSource("logScenarios")
@DisplayName("Check that the proper log is produced based on the scenario")
public void test(String inputByteBuffer, String contentType, int status, String expectedLog) throws IOException {
// set up the mdc so that actually log to a file, so that we can verify that file logging captures
// threads.
final Path jobRoot = Files.createTempDirectory(Path.of("/tmp"), "mdc_test");
LogClientSingleton.setJobMdc(jobRoot);

// We have to instanciate the logger here, because the MDC config has been changed to log in a
// temporary file.
requestLogger = new RequestLogger(MDC.getCopyOfContextMap(), mServletRequest);

Mockito.when(mRequestContext.getEntityStream())
.thenReturn(new ByteArrayInputStream(inputByteBuffer.getBytes()));

Mockito.when(mResponseContext.getStatus())
.thenReturn(status);

Mockito.when(mServletRequest.getHeader("Content-Type"))
.thenReturn(contentType);

// This is call to set the requestBody variable in the RequestLogger
requestLogger.filter(mRequestContext);
requestLogger.filter(mRequestContext, mResponseContext);

String expectedLogLevel = status == SUCCESS_CODE ? "INFO" : "ERROR";

final Path logPath = jobRoot.resolve(LogClientSingleton.LOG_FILENAME);
final String logs = IOs.readFile(logPath);
final Stream<String> matchingLines = logs.lines()
.filter(line -> line.endsWith(expectedLog))
.filter(line -> line.contains(expectedLogLevel));

Assertions.assertThat(matchingLines).hasSize(1);
}

}
41 changes: 21 additions & 20 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,24 @@ def createJavaLicenseWith = { license ->
// monorepo setup and it doesn't actually exclude directories reliably. This code makes the behavior predictable.
def createSpotlessTarget = { pattern ->
def excludes = [
'.gradle',
'node_modules',
'.eggs',
'.mypy_cache',
'.venv',
'*.egg-info',
'build',
'dbt-project-template',
'dbt-project-template-mssql',
'dbt-project-template-mysql',
'dbt-project-template-oracle',
'dbt_data_tests',
'dbt_data_tests_tmp',
'dbt_schema_tests',
'normalization_test_output',
'tools',
'secrets',
'charts' // Helm charts often have injected template strings that will fail general linting. Helm linting is done separately.
'.gradle',
'node_modules',
'.eggs',
'.mypy_cache',
'.venv',
'*.egg-info',
'build',
'dbt-project-template',
'dbt-project-template-mssql',
'dbt-project-template-mysql',
'dbt-project-template-oracle',
'dbt_data_tests',
'dbt_data_tests_tmp',
'dbt_schema_tests',
'normalization_test_output',
'tools',
'secrets',
'charts' // Helm charts often have injected template strings that will fail general linting. Helm linting is done separately.
]

if (System.getenv().containsKey("SUB_BUILD")) {
Expand Down Expand Up @@ -235,6 +235,7 @@ subprojects {
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.7.1'
testImplementation 'org.junit.jupiter:junit-jupiter-params:5.7.1'
testImplementation 'org.mockito:mockito-junit-jupiter:3.9.0'
testImplementation 'org.assertj:assertj-core:3.21.0'
benmoriceau marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down Expand Up @@ -268,7 +269,7 @@ task('format') {
subprojects {
def pythonFormatTask = project.tasks.findByName('blackFormat')

if(pythonFormatTask != null) {
if (pythonFormatTask != null) {
apply plugin: "com.github.hierynomus.license"
task licenseFormatPython(type: com.hierynomus.gradle.license.tasks.LicenseFormat) {
header = createPythonLicenseWith(rootProject.file('LICENSE_SHORT'))
Expand All @@ -284,7 +285,7 @@ subprojects {
flakeCheck.dependsOn licenseTask

def generateFilesTask = project.tasks.findByName('generateProtocolClassFiles')
if(generateFilesTask != null) {
if (generateFilesTask != null) {
licenseTask.dependsOn generateFilesTask
}
}
Expand Down