diff --git a/README.md b/README.md index 25f40759..18304f42 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,8 @@ There are a few modules associated with this project: functionality. * [wingtips-java8](wingtips-java8/README.md) - Provides several Java 8 helpers, particularly around helping tracing and MDC information to hop threads in asynchronous/non-blocking use cases. -* [wingtips-servlet-api](wingtips-servlet-api/README.md) - A plugin for Servlet 3+ based applications for integrating -distributed tracing with a simple Servlet Filter. -* [wingtips-old-servlet-api](wingtips-old-servlet-api/README.md) - A plugin for Servlet 2.x based applications for -integrating distributed tracing with a simple Servlet Filter. +* [wingtips-servlet-api](wingtips-servlet-api/README.md) - A plugin for Servlet based applications for integrating +distributed tracing with a simple Servlet Filter. Supports Servlet 2.x and Servlet 3 (async request) environments. * [wingtips-zipkin](wingtips-zipkin/README.md) - A plugin providing easy Zipkin integration by converting Wingtips spans to Zipkin spans and sending them to a Zipkin server. @@ -118,10 +116,9 @@ The `extractParentSpanFromRequest()` method is potentially different for differe #### Is your application running in a Servlet-based framework? -If your application is running in a Servlet environment (e.g. Spring MVC, Jersey, raw Servlets, etc) then this entire -lifecycle can be handled by a Servlet `Filter`. We've created one for you that's ready to drop in and go - see the -[wingtips-servlet-api](wingtips-servlet-api/README.md) Wingtips plugin module library for details if you're in a -Servlet 3+ environment. For Servlet 2.x see [wingtips-old-servlet-api](wingtips-old-servlet-api/README.md). That plugin +If your application is running in a Servlet environment (e.g. Spring Boot, Spring MVC, Jersey, raw Servlets, etc) then +this entire lifecycle can be handled by a Servlet `Filter`. We've created one for you that's ready to drop in and go - +see the [wingtips-servlet-api](wingtips-servlet-api/README.md) Wingtips plugin module library for details. That plugin module is also a good resource to see how the code for a production-ready implementation of this library might look. diff --git a/build.gradle b/build.gradle index 0be56a9b..cdc3ed3a 100644 --- a/build.gradle +++ b/build.gradle @@ -87,7 +87,7 @@ ext { // like try-with-resources that generate many many branches in the bytecode that are realistically impossible to get coverage for. // The combination of those issues mean we get artificially low coverage numbers even though it's clean correct code, so we just // have to visually verify it. - configure(subprojects.findAll { !it.name.contains("wingtips-zipkin") && !it.name.startsWith("sample")}) { + configure(subprojects.findAll { !it.name.contains("wingtips-zipkin") && !it.name.startsWith("sample") && !it.name.startsWith("testonly")}) { jacocoCoverage { // Enforce minimum code coverage. See https://github.com/palantir/gradle-jacoco-coverage for the full list of options. reportThreshold 0.95, INSTRUCTION diff --git a/gradle/bintrayPublishing.gradle b/gradle/bintrayPublishing.gradle index 0477a05a..cafffee8 100644 --- a/gradle/bintrayPublishing.gradle +++ b/gradle/bintrayPublishing.gradle @@ -1,5 +1,5 @@ configure(subprojects.findAll { - return !it.name.startsWith("sample") + return !it.name.startsWith("sample") && !it.name.startsWith("testonly") }) { apply plugin: 'maven' apply plugin: 'maven-publish' diff --git a/gradle/jacoco.gradle b/gradle/jacoco.gradle index e1c519e9..3f1065ab 100644 --- a/gradle/jacoco.gradle +++ b/gradle/jacoco.gradle @@ -32,8 +32,8 @@ subprojects { def subprojectsToIncludeForJacocoComboReport(Set origSubprojects) { Set projectsToInclude = new HashSet<>() for (Project subproj : origSubprojects) { - // For this project we'll include everything that's not a sample - if (!subproj.getName().startsWith("sample")) { + // For this project we'll include everything that's not a sample or a testonly module + if (!subproj.getName().startsWith("sample") && !subproj.getName().startsWith("testonly")) { projectsToInclude.add(subproj) } } diff --git a/settings.gradle b/settings.gradle index 3676a9e2..4cde7478 100644 --- a/settings.gradle +++ b/settings.gradle @@ -2,10 +2,11 @@ rootProject.name = 'wingtips' // Published-artifact modules include "wingtips-core", - "wingtips-old-servlet-api", "wingtips-servlet-api", "wingtips-zipkin", "wingtips-java8", + // Test-only modules (not published) + "testonly:testonly-old-servlet", // Sample modules (not published) "samples:sample-jersey1", "samples:sample-jersey2", diff --git a/testonly/testonly-old-servlet/README.md b/testonly/testonly-old-servlet/README.md new file mode 100644 index 00000000..2fed0a2f --- /dev/null +++ b/testonly/testonly-old-servlet/README.md @@ -0,0 +1,17 @@ +# Wingtips - testonly-old-servlet + +Wingtips is a distributed tracing solution for Java 7 and greater based on the +[Google Dapper paper](http://static.googleusercontent.com/media/research.google.com/en/us/pubs/archive/36356.pdf). + +This submodule contains tests to verify that the [wingtips-servlet-api](../../wingtips-servlet-api) module's +functionality works as expected in old Servlet 2.x environments. We need a separate tests-only module for this because +we need to limit the dependencies at test runtime to force a Servlet 2.x-only environment. + +## More Info + +See the [base project README.md](../../README.md) and Wingtips repository source code and javadocs for general Wingtips +information. + +## License + +Wingtips is released under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0) \ No newline at end of file diff --git a/wingtips-old-servlet-api/build.gradle b/testonly/testonly-old-servlet/build.gradle similarity index 88% rename from wingtips-old-servlet-api/build.gradle rename to testonly/testonly-old-servlet/build.gradle index 5dc6f5e2..89688e68 100644 --- a/wingtips-old-servlet-api/build.gradle +++ b/testonly/testonly-old-servlet/build.gradle @@ -6,19 +6,15 @@ ext { } dependencies { - compile( - project(":wingtips-core"), - ) - compileOnly( - "javax.servlet:servlet-api:$oldServletApiTargetVersion" - ) testCompile( + project(":wingtips-servlet-api"), "junit:junit-dep:$junitVersion", "org.mockito:mockito-core:$mockitoVersion", "ch.qos.logback:logback-classic:$logbackVersion", "org.assertj:assertj-core:$assertJVersion", "com.tngtech.java:junit-dataprovider:$junitDataproviderVersion", "io.rest-assured:rest-assured:$restAssuredVersion", + "javax.servlet:servlet-api:$oldServletApiTargetVersion", "org.mortbay.jetty:jetty:$oldMortbayJettyVersion" ) } diff --git a/wingtips-old-servlet-api/src/test/java/com/nike/wingtips/componenttest/RequestTracingFilterOldServletComponentTest.java b/testonly/testonly-old-servlet/src/test/java/com/nike/wingtips/componenttest/RequestTracingFilterOldServletComponentTest.java similarity index 87% rename from wingtips-old-servlet-api/src/test/java/com/nike/wingtips/componenttest/RequestTracingFilterOldServletComponentTest.java rename to testonly/testonly-old-servlet/src/test/java/com/nike/wingtips/componenttest/RequestTracingFilterOldServletComponentTest.java index b263db3e..85be44a2 100644 --- a/wingtips-old-servlet-api/src/test/java/com/nike/wingtips/componenttest/RequestTracingFilterOldServletComponentTest.java +++ b/testonly/testonly-old-servlet/src/test/java/com/nike/wingtips/componenttest/RequestTracingFilterOldServletComponentTest.java @@ -6,7 +6,7 @@ import com.nike.wingtips.TraceHeaders; import com.nike.wingtips.Tracer; import com.nike.wingtips.lifecyclelistener.SpanLifecycleListener; -import com.nike.wingtips.servlet.RequestTracingFilterOldServlet; +import com.nike.wingtips.servlet.RequestTracingFilter; import com.tngtech.java.junit.dataprovider.DataProvider; import com.tngtech.java.junit.dataprovider.DataProviderRunner; @@ -32,6 +32,7 @@ import java.util.concurrent.TimeUnit; import javax.servlet.ServletException; +import javax.servlet.ServletRequest; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -40,9 +41,11 @@ import static io.restassured.RestAssured.given; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; /** - * Component test to verify that {@link RequestTracingFilterOldServlet} works as expected when deployed to a real running server. + * Component test to verify that {@link RequestTracingFilter} works as expected when deployed to a real running server + * that *only* supports Servlet 2.x (no Servlet 3 API available on the classpath). * * @author Nic Munroe */ @@ -55,7 +58,30 @@ public class RequestTracingFilterOldServletComponentTest { private SpanRecorder spanRecorder; @BeforeClass + @SuppressWarnings("JavaReflectionMemberAccess") public static void beforeClass() throws Exception { + try { + ServletRequest.class.getMethod("getAsyncContext"); + fail( + "Expected this test to run in an environment that does *NOT* support Servlet 3 API, " + + "however ServletRequest.getAsyncContext() method was found." + ); + } + catch(NoSuchMethodException ex) { + // Expected - do nothing + } + + try { + Class.forName("javax.servlet.AsyncListener"); + fail( + "Expected this test to run in an environment that does *NOT* support Servlet 3 API, " + + "however javax.servlet.AsyncListener class was found." + ); + } + catch(ClassNotFoundException ex) { + // Expected - do nothing + } + port = findFreePort(); server = new Server(port); server.setHandler(generateServletContextHandler()); @@ -187,7 +213,7 @@ private void verifySingleSpanCompletedAndReturnedInResponse(ExtractableResponse public static class SpanRecorder implements SpanLifecycleListener { - public final List completedSpans = new ArrayList<>(); + final List completedSpans = new ArrayList<>(); @Override public void spanStarted(Span span) { } @@ -242,7 +268,7 @@ private static Handler generateServletContextHandler() throws IOException { servletHandler.addServletWithMapping(BlockingServlet.class, BLOCKING_PATH); servletHandler.addServletWithMapping(BlockingForwardServlet.class, BLOCKING_FORWARD_PATH); - servletHandler.addFilterWithMapping(RequestTracingFilterOldServlet.class.getName(), "/*", Handler.ALL); + servletHandler.addFilterWithMapping(RequestTracingFilter.class.getName(), "/*", Handler.ALL); Context context = new Context(null, null, null, servletHandler, null); context.setContextPath("/"); diff --git a/wingtips-old-servlet-api/src/test/resources/logback.xml b/testonly/testonly-old-servlet/src/test/resources/logback.xml similarity index 100% rename from wingtips-old-servlet-api/src/test/resources/logback.xml rename to testonly/testonly-old-servlet/src/test/resources/logback.xml diff --git a/wingtips-old-servlet-api/README.md b/wingtips-old-servlet-api/README.md deleted file mode 100644 index bc6e6244..00000000 --- a/wingtips-old-servlet-api/README.md +++ /dev/null @@ -1,84 +0,0 @@ -# Wingtips - wingtips-old-servlet-api - -Wingtips is a distributed tracing solution for Java based on the [Google Dapper paper](http://static.googleusercontent.com/media/research.google.com/en/us/pubs/archive/36356.pdf). - -This module is a plugin extension module of the core Wingtips library and contains support for distributed tracing in a Java **Servlet 2.x API** environment (for Servlet 3+ environments please refer to the [wingtips-servlet-api](../wingtips-servlet-api) module). The features it provides are: - -* **HttpSpanFactory** - Utility class that extracts span information from incoming `HttpServletRequest` requests. -* **RequestTracingFilterOldServlet** - A Servlet Filter that handles all of the work for enabling a new span when a request comes in and completing it when the request finishes. This filter automatically uses `HttpSpanFactory` to extract parent span information from the incoming request headers for the new span if available. Sets the `X-B3-TraceId` response header to the Trace ID for each request. You can set the `user-id-header-keys-list` servlet filter param if you expect any request headers that represent a user ID (if you don't have any user ID headers then this can be ignored). - -Please make sure you have read the [base project README.md](../README.md). This readme assumes you understand the principles and usage instructions described there. - -## Usage Example - -The following example shows how you might setup the tracing Servlet Filter when the service expects one of two possible header keys that represent the user ID of the user making the call: `userid` or `altuserid`. - -**Add the following to web.xml** - -``` xml - - traceFilter - com.nike.wingtips.servlet.RequestTracingFilterOldServlet - - user-id-header-keys-list - userid,altuserid - - - - - traceFilter - /* - -``` - -If your service does not have any user ID headers you can remove the `` element entirely or set the `` to be empty. - -That's it for incoming requests. This Filter will do the right thing and start a root span or child span for incoming requests (depending on whether or not the caller included tracing headers), add the trace ID to the response as a response header, and guarantees completion of the overall request span right before the response is sent. - -**Embedded environments** - -For embedded Servlet container environments where you may not be using a `web.xml` file to setup Servlet components -you'll need to register `RequestTracingFilterOldServlet` in whatever way your Servlet container allows or requires you -to register Servlet Filters. For example the `Main` classes in the `samples/sample-*` sample projects show how to -register `RequestTracingFilter` with embedded Jetty. - -### Propagating the Tracing Information to Downstream Systems - -This Filter takes care of setting up the overall request span for incoming requests, but propagating the tracing information to downstream systems is still your responsibility. When you call another system you must grab the current span via `Tracer.getInstance().getCurrentSpan()` and put its field values into the downstream call's request headers using the constants in `TraceHeaders` as the header keys. For example: - -``` java -Span currentSpan = Tracer.getInstance().getCurrentSpan(); - -otherSystemRequest.setHeader(TraceHeaders.TRACE_ID, currentSpan.getTraceId()); -otherSystemRequest.setHeader(TraceHeaders.SPAN_ID, currentSpan.getSpanId()); -otherSystemRequest.setHeader(TraceHeaders.TRACE_SAMPLED, (currentSpan.isSampleable()) ? "1" : "0"); -if (currentSpan.getParentSpanId() != null) - otherSystemRequest.setHeader(TraceHeaders.PARENT_SPAN_ID, currentSpan.getParentSpanId()); -if (shouldSendSpanName) - otherSystemRequest.setHeader(TraceHeaders.SPAN_NAME, currentSpan.getSpanName()); - -executeOtherSystemCall(otherSystemRequest); -``` - -Propagating trace ID and span ID is required. Propagating parent span ID (if non-null) and sampleable value is optional -but recommended. - -The `TraceHeaders.SPAN_NAME` header propagation is optional, and you may wish to intentionally include or exclude it -depending on whether you want downstream systems to have access to that info. For services you control it may be good -to include it for extra debugging info, and for services outside your control you may wish to exclude it to prevent -unintentional information leakage. - -See the [base project readme's section on propagation](../README.md#propagating_traces) for further details on -propagating tracing information. You may also want to consider -[wrapping downstream calls in a subspan](../README.md#sub_spans_for_downstream_calls). - -## NOTE - Servlet API 2.3 or later dependency required at runtime - -This `wingtips-old-servlet-api` module has a minimum Servlet 2.3 requirement, but does not export any transitive -Servlet API dependencies to prevent runtime version conflicts with whatever Servlet environment you deploy to. - -This should not affect most users since this library is likely to be used in a Servlet environment where the Servlet -API is on the classpath at runtime, however if you receive `NoClassDefFoundError`s related to Servlet API classes then -you'll need to pull a Servlet API dependency into your project that supports a minimum Servlet 2.3 version. For -reference, `wingtips-old-servlet-api` uses the compile-only Servlet API dependency -[`javax.servlet:servlet-api:2.3`](http://search.maven.org/#artifactdetails%7Cjavax.servlet%7Cservlet-api%7C2.3%7Cjar). diff --git a/wingtips-old-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilterOldServlet.java b/wingtips-old-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilterOldServlet.java deleted file mode 100644 index 699bd04b..00000000 --- a/wingtips-old-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilterOldServlet.java +++ /dev/null @@ -1,247 +0,0 @@ -package com.nike.wingtips.servlet; - -import com.nike.wingtips.Span; -import com.nike.wingtips.TraceHeaders; -import com.nike.wingtips.Tracer; -import com.nike.wingtips.util.TracingState; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import javax.servlet.Filter; -import javax.servlet.FilterChain; -import javax.servlet.FilterConfig; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import static com.nike.wingtips.util.AsyncWingtipsHelperJava7.unlinkTracingFromCurrentThread; - -/** - * Makes sure distributed tracing is handled for each request. Sets up the span for incoming requests (either an - * entirely new root span or one with a parent, depending on what is in the incoming request's headers), and also sets - * the {@link TraceHeaders#TRACE_ID} on the response. This is designed to only run once per request. - * - *

This specific class is targeted at Servlet 2.x environments, where async requests are not supported. For - * Servlet 3+ environments you should use {@code RequestTracingFilter} from the {@code wingtips-servlet-api} module. - * - *

NOTE: You can override {@link #getUserIdHeaderKeys()} if your service is expecting user ID header(s) and you can't - * (or don't want to) set up those headers via the {@link #USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME} init parameter. - * - *

ALSO NOTE: Subclasses that support async servlet requests (like {@code RequestTracingFilter} from the - * {@code wingtips-servlet-api} module) should override {@link #isAsyncRequest(HttpServletRequest)} and {@link - * #setupTracingCompletionWhenAsyncRequestCompletes(HttpServletRequest, TracingState)}. - * - * @author Nic Munroe - */ -@SuppressWarnings("WeakerAccess") -public class RequestTracingFilterOldServlet implements Filter { - private final Logger logger = LoggerFactory.getLogger(this.getClass()); - - /** - * This attribute key will be set to a value of true via {@link ServletRequest#setAttribute(String, Object)} the - * first time this filter's distributed tracing logic is run for any given request. This filter will then see this - * attribute on any subsequent executions for the same request and continue the filter chain without executing the - * distributed tracing logic again to make sure this filter's logic is only executed once per request. - * - *

If you want to prevent this filter from executing on specific requests then you can override {@link - * #skipDispatch(HttpServletRequest)} to return true for any requests where you don't want distributed tracing to - * occur. - */ - public static final String FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE = "RequestTracingFilterAlreadyFiltered"; - - /** - * The param name for the "list of user ID header keys" init param for this filter. The value of this init param - * will be parsed for the list of user ID header keys to use when calling {@link - * HttpSpanFactory#fromHttpServletRequest(HttpServletRequest, List)} or {@link - * HttpSpanFactory#getUserIdFromHttpServletRequest(HttpServletRequest, List)}. The value for this init param is - * expected to be a comma-delimited list. - */ - public static final String USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME = "user-id-header-keys-list"; - - protected List userIdHeaderKeysFromInitParam; - - @Override - public void init(FilterConfig filterConfig) throws ServletException { - String userIdHeaderKeysListString = filterConfig.getInitParameter(USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME); - if (userIdHeaderKeysListString != null) { - List parsedList = new ArrayList<>(); - for (String headerKey : userIdHeaderKeysListString.split(",")) { - String trimmedHeaderKey = headerKey.trim(); - if (trimmedHeaderKey.length() > 0) - parsedList.add(trimmedHeaderKey); - } - userIdHeaderKeysFromInitParam = Collections.unmodifiableList(parsedList); - } - } - - @Override - public void destroy() { - // Nothing to do - } - - /** - * Wrapper around {@link #doFilterInternal(HttpServletRequest, HttpServletResponse, FilterChain)} to make sure this - * filter's logic is only executed once per request. - */ - @Override - public void doFilter(ServletRequest request, - ServletResponse response, - FilterChain filterChain) throws IOException, ServletException { - if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) { - throw new ServletException(this.getClass().getName() + " only supports HTTP requests"); - } - HttpServletRequest httpRequest = (HttpServletRequest) request; - HttpServletResponse httpResponse = (HttpServletResponse) response; - - boolean filterHasAlreadyExecuted = request.getAttribute(FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE) != null; - if (filterHasAlreadyExecuted || skipDispatch(httpRequest)) { - - // Already executed or we're supposed to skip, so continue the filter chain without doing the - // distributed tracing work. - filterChain.doFilter(request, response); - } - else { - // Time to execute the distributed tracing logic. - request.setAttribute(FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE, Boolean.TRUE); - doFilterInternal(httpRequest, httpResponse, filterChain); - } - } - - /** - * Performs the distributed tracing work for each request's overall span. Guaranteed to only be called once per - * request. - */ - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, - FilterChain filterChain) throws ServletException, IOException { - // Surround the tracing filter logic with a try/finally that guarantees the original tracing and MDC info found - // on the current thread at the beginning of this method is restored to this thread before this method - // returns, even if the request ends up being an async request. Otherwise there's the possibility of - // incorrect tracing information sticking around on this thread and potentially polluting other requests. - TracingState originalThreadInfo = TracingState.getCurrentThreadTracingState(); - try { - // See if there's trace info in the incoming request's headers. If so it becomes the parent trace. - Tracer tracer = Tracer.getInstance(); - final Span parentSpan = HttpSpanFactory.fromHttpServletRequest(request, getUserIdHeaderKeys()); - Span newSpan; - - if (parentSpan != null) { - logger.debug("Found parent Span {}", parentSpan); - newSpan = tracer.startRequestWithChildSpan(parentSpan, HttpSpanFactory.getSpanName(request)); - } else { - newSpan = tracer.startRequestWithRootSpan( - HttpSpanFactory.getSpanName(request), - HttpSpanFactory.getUserIdFromHttpServletRequest(request, getUserIdHeaderKeys()) - ); - logger.debug("Parent span not found, starting a new span {}", newSpan); - } - - // Put the new span's trace info into the request attributes. - request.setAttribute(TraceHeaders.TRACE_SAMPLED, newSpan.isSampleable()); - request.setAttribute(TraceHeaders.TRACE_ID, newSpan.getTraceId()); - request.setAttribute(TraceHeaders.SPAN_ID, newSpan.getSpanId()); - request.setAttribute(TraceHeaders.PARENT_SPAN_ID, newSpan.getParentSpanId()); - request.setAttribute(TraceHeaders.SPAN_NAME, newSpan.getSpanName()); - request.setAttribute(Span.class.getName(), newSpan); - - // Make sure we set the trace ID on the response header now before the response is committed (if we wait - // until after the filter chain then the response might already be committed, silently preventing us - // from setting the response header) - response.setHeader(TraceHeaders.TRACE_ID, newSpan.getTraceId()); - - TracingState originalRequestTracingState = TracingState.getCurrentThreadTracingState(); - try { - filterChain.doFilter(request, response); - } finally { - if (isAsyncRequest(request)) { - // Async, so we need to attach a listener to complete the original tracing state when the async - // servlet request finishes. - setupTracingCompletionWhenAsyncRequestCompletes(request, originalRequestTracingState); - } - else { - // Not async, so we need to complete the request span now. - tracer.completeRequestSpan(); - } - } - } - finally { - //noinspection deprecation - unlinkTracingFromCurrentThread(originalThreadInfo); - } - } - - /** - * @return true if {@link #doFilterInternal(HttpServletRequest, HttpServletResponse, FilterChain)} should be - * skipped (and therefore prevent distributed tracing logic from starting), false otherwise. This defaults to - * returning false so the first execution of this filter will always trigger distributed tracing, so if you have a - * need to skip distributed tracing for a request you can override this method and have whatever logic you need. - */ - protected boolean skipDispatch(HttpServletRequest request) { - return false; - } - - /** - * The list of header keys that will be used to search the request headers for a user ID to set on the {@link Span} - * for the request. The user ID header keys will be searched in list order, and the first non-empty user ID header - * value found will be used as the {@link Span#getUserId()}. You can safely return null or an empty list for this - * method if there is no user ID to extract; if you return null/empty then the request span's {@link - * Span#getUserId()} will be null. - * - *

By default this method will return the list specified via the {@link - * #USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME} init param, or null if that init param does not exist. - * - * @return The list of header keys that will be used to search the request headers for a user ID to set on the - * {@link Span} for the request. This method may return null or an empty list if there are no user IDs to search - * for. - */ - protected List getUserIdHeaderKeys() { - return userIdHeaderKeysFromInitParam; - } - - /** - * This impl will always return false since Servlet 2.x does not support async requests, but for - * subclasses that do support async requests this method should be overridden to return the value of {@code - * request.isAsyncStarted()}. - * - *

See {@code RequestTracingFilter} from the {@code wingtips-servlet-api} module for an impl that supports async - * servlet requests. - * - * @param request The request to inspect to see if it's part of an async servlet request or not. - * @return false - */ - protected boolean isAsyncRequest(HttpServletRequest request) { - return false; - } - - /** - * This impl does nothing by default (it will never be called since {@link #isAsyncRequest(HttpServletRequest)} - * always returns false), but for subclasses that do support async requests this method should be overridden to - * provide a listener that will complete the given {@link TracingState} when the given async request finishes. - * The code would look something like: - * - *

-     *      AsyncListener spanCompletingAsyncListener = ...;
-     *      asyncRequest.getAsyncContext().addListener(spanCompletingAsyncListener);
-     * 
- * - *

See {@code RequestTracingFilter} from the {@code wingtips-servlet-api} module for an impl that supports async - * servlet requests. - * - * @param asyncRequest The async servlet request (guaranteed to be async since this method will only be called when - * {@link #isAsyncRequest(HttpServletRequest)} returns true). - * @param originalRequestTracingState The {@link TracingState} that was generated when this request started, and - * which should be completed when the given async servlet request finishes. - */ - protected void setupTracingCompletionWhenAsyncRequestCompletes(HttpServletRequest asyncRequest, - TracingState originalRequestTracingState) { - // Do nothing - } -} diff --git a/wingtips-old-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestTracingFilterOldServletTest.java b/wingtips-old-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestTracingFilterOldServletTest.java deleted file mode 100644 index 8127b9f4..00000000 --- a/wingtips-old-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestTracingFilterOldServletTest.java +++ /dev/null @@ -1,653 +0,0 @@ -package com.nike.wingtips.servlet; - -import com.nike.wingtips.Span; -import com.nike.wingtips.Span.SpanPurpose; -import com.nike.wingtips.TraceAndSpanIdGenerator; -import com.nike.wingtips.TraceHeaders; -import com.nike.wingtips.Tracer; -import com.nike.wingtips.util.TracingState; - -import com.tngtech.java.junit.dataprovider.DataProvider; -import com.tngtech.java.junit.dataprovider.DataProviderRunner; -import com.tngtech.java.junit.dataprovider.UseDataProvider; - -import org.assertj.core.api.ThrowableAssert; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.slf4j.MDC; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import javax.servlet.FilterChain; -import javax.servlet.FilterConfig; -import javax.servlet.ServletException; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.catchThrowable; -import static org.assertj.core.api.Fail.fail; -import static org.mockito.BDDMockito.given; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; - -/** - * Tests the functionality of {@link RequestTracingFilterOldServlet} - */ -@RunWith(DataProviderRunner.class) -public class RequestTracingFilterOldServletTest { - - private HttpServletRequest requestMock; - private HttpServletResponse responseMock; - private FilterChain filterChainMock; - private SpanCapturingFilterChain spanCapturingFilterChain; - private FilterConfig filterConfigMock; - - private static final String USER_ID_HEADER_KEY = "userId"; - private static final String ALT_USER_ID_HEADER_KEY = "altUserId"; - private static final List USER_ID_HEADER_KEYS = Arrays.asList(USER_ID_HEADER_KEY, ALT_USER_ID_HEADER_KEY); - private static final String USER_ID_HEADER_KEYS_INIT_PARAM_VALUE_STRING = USER_ID_HEADER_KEYS.toString().replace("[", "").replace("]", ""); - - private RequestTracingFilterOldServlet getBasicFilter() { - RequestTracingFilterOldServlet filter = new RequestTracingFilterOldServlet(); - - try { - filter.init(filterConfigMock); - } catch (ServletException e) { - throw new RuntimeException(e); - } - - return filter; - } - - private RequestTracingFilterOldServlet getFilterWithSkipDispatchOverride(final boolean overrideVal) { - RequestTracingFilterOldServlet filter = new RequestTracingFilterOldServlet() { - - @Override - protected boolean skipDispatch(HttpServletRequest request) { - return overrideVal; - } - }; - - try { - filter.init(filterConfigMock); - } catch (ServletException e) { - throw new RuntimeException(e); - } - - return filter; - } - - @Before - public void setupMethod() { - requestMock = mock(HttpServletRequest.class); - responseMock = mock(HttpServletResponse.class); - filterChainMock = mock(FilterChain.class); - spanCapturingFilterChain = new SpanCapturingFilterChain(); - - filterConfigMock = mock(FilterConfig.class); - doReturn(USER_ID_HEADER_KEYS_INIT_PARAM_VALUE_STRING) - .when(filterConfigMock) - .getInitParameter(RequestTracingFilterOldServlet.USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME); - - resetTracing(); - } - - @After - public void afterMethod() { - resetTracing(); - } - - private void resetTracing() { - MDC.clear(); - Tracer.getInstance().unregisterFromThread(); - } - - // VERIFY filter init, getUserIdHeaderKeys, and destroy ======================= - - @DataProvider - public static Object[][] userIdHeaderKeysInitParamDataProvider() { - - return new Object[][] { - { null, null }, - { "", Collections.emptyList() }, - { " \t \n ", Collections.emptyList() }, - { "asdf", Collections.singletonList("asdf") }, - { " , \n\t, asdf , \t\n ", Collections.singletonList("asdf") }, - { "ASDF,QWER", Arrays.asList("ASDF", "QWER") }, - { "ASDF, QWER, ZXCV", Arrays.asList("ASDF", "QWER", "ZXCV") } - }; - } - - @Test - @UseDataProvider("userIdHeaderKeysInitParamDataProvider") - public void init_method_gets_user_id_header_key_list_from_init_params_and_getUserIdHeaderKeys_exposes_them( - String userIdHeaderKeysInitParamValue, List expectedUserIdHeaderKeysList) throws ServletException { - // given - RequestTracingFilterOldServlet filter = new RequestTracingFilterOldServlet(); - FilterConfig filterConfig = mock(FilterConfig.class); - doReturn(userIdHeaderKeysInitParamValue) - .when(filterConfig) - .getInitParameter(RequestTracingFilterOldServlet.USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME); - filter.init(filterConfig); - - // when - List actualUserIdHeaderKeysList = filter.getUserIdHeaderKeys(); - - // then - assertThat(actualUserIdHeaderKeysList).isEqualTo(expectedUserIdHeaderKeysList); - if (actualUserIdHeaderKeysList != null) { - Exception caughtEx = null; - try { - actualUserIdHeaderKeysList.add("foo"); - } catch (Exception ex) { - caughtEx = ex; - } - assertThat(caughtEx).isNotNull(); - assertThat(caughtEx).isInstanceOf(UnsupportedOperationException.class); - } - } - - @Test - public void destroy_does_not_explode() { - // expect - getBasicFilter().destroy(); - // No explosion no problem - } - - // VERIFY doFilter =================================== - - @Test(expected = ServletException.class) - public void doFilter_should_explode_if_request_is_not_HttpServletRequest() throws IOException, ServletException { - // expect - getBasicFilter().doFilter(mock(ServletRequest.class), mock(HttpServletResponse.class), mock(FilterChain.class)); - fail("Expected ServletException but no exception was thrown"); - } - - @Test(expected = ServletException.class) - public void doFilter_should_explode_if_response_is_not_HttpServletResponse() throws IOException, ServletException { - // expect - getBasicFilter().doFilter(mock(HttpServletRequest.class), mock(ServletResponse.class), mock(FilterChain.class)); - fail("Expected ServletException but no exception was thrown"); - } - - @Test - public void doFilter_should_not_explode_if_request_and_response_are_HttpServletRequests_and_HttpServletResponses() throws IOException, ServletException { - // expect - getBasicFilter().doFilter(mock(HttpServletRequest.class), mock(HttpServletResponse.class), mock(FilterChain.class)); - // No explosion no problem - } - - @Test - public void doFilter_should_call_doFilterInternal_and_set_ALREADY_FILTERED_ATTRIBUTE_KEY_if_not_already_filtered_and_skipDispatch_returns_false() - throws IOException, ServletException { - // given: filter that returns false for skipDispatch and request that returns null for already-filtered attribute - RequestTracingFilterOldServlet spyFilter = spy(getFilterWithSkipDispatchOverride(false)); - given(requestMock.getAttribute( - RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE)).willReturn(null); - - // when: doFilter() is called - spyFilter.doFilter(requestMock, responseMock, filterChainMock); - - // then: doFilterInternal should be called and ALREADY_FILTERED_ATTRIBUTE_KEY should be set on the request - verify(spyFilter).doFilterInternal(requestMock, responseMock, filterChainMock); - verify(requestMock).setAttribute(RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE, Boolean.TRUE); - } - - @Test - public void doFilter_should_not_unset_ALREADY_FILTERED_ATTRIBUTE_KEY_after_running_doFilterInternal() throws IOException, ServletException { - // given: filter that will run doFilterInternal and a FilterChain we can use to verify state when called - final RequestTracingFilterOldServlet spyFilter = spy(getFilterWithSkipDispatchOverride(false)); - given(requestMock.getAttribute( - RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE)).willReturn(null); - final List ifObjectAddedThenSmartFilterChainCalled = new ArrayList<>(); - FilterChain smartFilterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { - // Verify that when the filter chain is called we're in doFilterInternal, and that the request has ALREADY_FILTERED_ATTRIBUTE_KEY set - verify(spyFilter).doFilterInternal(requestMock, responseMock, this); - verify(requestMock).setAttribute(RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE, Boolean.TRUE); - verify(requestMock, times(0)).removeAttribute(RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE); - ifObjectAddedThenSmartFilterChainCalled.add(true); - } - }; - - // when: doFilter() is called - spyFilter.doFilter(requestMock, responseMock, smartFilterChain); - - // then: smartFilterChain's doFilter should have been called and ALREADY_FILTERED_ATTRIBUTE_KEY should still be set on the request - assertThat(ifObjectAddedThenSmartFilterChainCalled).hasSize(1); - verify(requestMock, never()).removeAttribute(RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE); - } - - @Test - public void doFilter_should_not_unset_ALREADY_FILTERED_ATTRIBUTE_KEY_even_if_filter_chain_explodes() throws IOException, ServletException { - // given: filter that will run doFilterInternal and a FilterChain we can use to verify state when called and then explodes - final RequestTracingFilterOldServlet spyFilter = spy(getFilterWithSkipDispatchOverride(false)); - given(requestMock.getAttribute( - RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE)).willReturn(null); - final List ifObjectAddedThenSmartFilterChainCalled = new ArrayList<>(); - FilterChain smartFilterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { - // Verify that when the filter chain is called we're in doFilterInternal, and that the request has ALREADY_FILTERED_ATTRIBUTE_KEY set - verify(spyFilter).doFilterInternal(requestMock, responseMock, this); - verify(requestMock).setAttribute(RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE, Boolean.TRUE); - verify(requestMock, times(0)).removeAttribute(RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE); - ifObjectAddedThenSmartFilterChainCalled.add(true); - throw new IllegalStateException("boom"); - } - }; - - // when: doFilter() is called - boolean filterChainExploded = false; - try { - spyFilter.doFilter(requestMock, responseMock, smartFilterChain); - } - catch(IllegalStateException ex) { - if ("boom".equals(ex.getMessage())) - filterChainExploded = true; - } - - // then: smartFilterChain's doFilter should have been called, it should have exploded, and ALREADY_FILTERED_ATTRIBUTE_KEY should still be set on the request - assertThat(ifObjectAddedThenSmartFilterChainCalled).hasSize(1); - assertThat(filterChainExploded).isTrue(); - verify(requestMock, never()).removeAttribute(RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE); - } - - @Test - public void doFilter_should_not_call_doFilterInternal_if_already_filtered() throws IOException, ServletException { - // given: filter that returns false for skipDispatch but request that returns non-null for already-filtered attribute - RequestTracingFilterOldServlet spyFilter = spy(getFilterWithSkipDispatchOverride(false)); - given(requestMock.getAttribute( - RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE)).willReturn(Boolean.TRUE); - - // when: doFilter() is called - spyFilter.doFilter(requestMock, responseMock, filterChainMock); - - // then: doFilterInternal should not be called - verify(spyFilter, times(0)).doFilterInternal(requestMock, responseMock, filterChainMock); - } - - @Test - public void doFilter_should_not_call_doFilterInternal_if_not_already_filtered_but_skipDispatch_returns_true() throws IOException, ServletException { - // given: request that returns null for already-filtered attribute but filter that returns true for skipDispatch - RequestTracingFilterOldServlet spyFilter = spy(getFilterWithSkipDispatchOverride(true)); - given(requestMock.getAttribute( - RequestTracingFilterOldServlet.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE)).willReturn(null); - - // when: doFilter() is called - spyFilter.doFilter(requestMock, responseMock, filterChainMock); - - // then: doFilterInternal should not be called - verify(spyFilter, times(0)).doFilterInternal(requestMock, responseMock, filterChainMock); - } - - // VERIFY doFilterInternal =================================== - - @Test - public void doFilterInternal_should_create_new_sampleable_span_if_no_parent_in_request_and_it_should_be_completed() throws ServletException, IOException { - // given: filter - RequestTracingFilterOldServlet filter = getFilterWithSkipDispatchOverride(false); - - // when: doFilterInternal is called with a request that does not have a parent span - filter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - - // then: a new valid sampleable span should be created and completed - Span span = spanCapturingFilterChain.capturedSpan; - assertThat(span).isNotNull(); - assertThat(span.getTraceId()).isNotNull(); - assertThat(span.getSpanId()).isNotNull(); - assertThat(span.getSpanName()).isNotNull(); - assertThat(span.getParentSpanId()).isNull(); - assertThat(span.isSampleable()).isTrue(); - assertThat(span.isCompleted()).isTrue(); - } - - @Test - public void doFilterInternal_should_not_complete_span_until_after_filter_chain_runs() throws ServletException, IOException { - // given: filter and filter chain that can tell us whether or not the span is complete at the time it is called - RequestTracingFilterOldServlet filter = getFilterWithSkipDispatchOverride(false); - final List spanCompletedHolder = new ArrayList<>(); - final List spanHolder = new ArrayList<>(); - FilterChain smartFilterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { - Span span = Tracer.getInstance().getCurrentSpan(); - spanHolder.add(span); - if (span != null) - spanCompletedHolder.add(span.isCompleted()); - } - }; - - // when: doFilterInternal is called - filter.doFilterInternal(requestMock, responseMock, smartFilterChain); - - // then: we should be able to validate that the smartFilterChain was called, and when it was called the span had not yet been completed, - // and after doFilterInternal finished it was completed. - assertThat(spanHolder).hasSize(1); - assertThat(spanCompletedHolder).hasSize(1); - assertThat(spanCompletedHolder.get(0)).isFalse(); - assertThat(spanHolder.get(0).isCompleted()).isTrue(); - } - - @DataProvider(value = { - "true", - "false" - }) - @Test - public void doFilterInternal_should_complete_span_even_if_filter_chain_explodes( - boolean isAsyncRequest - ) throws ServletException, IOException { - // given: filter and filter chain that will explode when called - RequestTracingFilterOldServlet filterSpy = spy(getFilterWithSkipDispatchOverride(false)); - final List spanContextHolder = new ArrayList<>(); - FilterChain explodingFilterChain = new FilterChain() { - @Override - public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { - // Verify that the span is not yet completed, keep track of it for later, then explode - Span span = Tracer.getInstance().getCurrentSpan(); - assertThat(span).isNotNull(); - assertThat(span.isCompleted()).isFalse(); - spanContextHolder.add(span); - throw new IllegalStateException("boom"); - } - }; - - if (isAsyncRequest) { - doReturn(true).when(filterSpy).isAsyncRequest(any(HttpServletRequest.class)); - } - - // when: doFilterInternal is called - boolean filterChainExploded = false; - try { - filterSpy.doFilterInternal(requestMock, responseMock, explodingFilterChain); - } - catch(IllegalStateException ex) { - if ("boom".equals(ex.getMessage())) - filterChainExploded = true; - } - - // then: we should be able to validate that the filter chain exploded and the span is still completed, - // or setup for completion in the case of an async request - if (isAsyncRequest) { - assertThat(filterChainExploded).isTrue(); - verify(filterSpy).isAsyncRequest(requestMock); - verify(filterSpy).setupTracingCompletionWhenAsyncRequestCompletes(eq(requestMock), any(TracingState.class)); - assertThat(spanContextHolder).hasSize(1); - // The span should not be *completed* for an async request, but the - // setupTracingCompletionWhenAsyncRequestCompletes verification above represents the equivalent for - // async requests. - assertThat(spanContextHolder.get(0).isCompleted()).isFalse(); - } - else { - assertThat(filterChainExploded).isTrue(); - assertThat(spanContextHolder).hasSize(1); - assertThat(spanContextHolder.get(0).isCompleted()).isTrue(); - } - } - - @Test - public void doFilterInternal_should_set_request_attributes_to_new_span_info_with_user_id() throws ServletException, IOException { - // given: filter - RequestTracingFilterOldServlet spyFilter = spy(getFilterWithSkipDispatchOverride(false)); - given(requestMock.getHeader(USER_ID_HEADER_KEY)).willReturn("testUserId"); - - // when: doFilterInternal is called - spyFilter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - - // then: request attributes should be set with the new span's info - assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); - Span newSpan = spanCapturingFilterChain.capturedSpan; - - assertThat(newSpan.getUserId()).isEqualTo("testUserId"); - } - - @Test - public void doFilterInternal_should_set_request_attributes_to_new_span_info_with_alt_user_id() throws ServletException, IOException { - // given: filter - RequestTracingFilterOldServlet spyFilter = spy(getFilterWithSkipDispatchOverride(false)); - given(requestMock.getHeader(ALT_USER_ID_HEADER_KEY)).willReturn("testUserId"); - - // when: doFilterInternal is called - spyFilter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - - // then: request attributes should be set with the new span's info - assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); - Span newSpan = spanCapturingFilterChain.capturedSpan; - - assertThat(newSpan.getUserId()).isEqualTo("testUserId"); - } - - @Test - public void doFilterInternal_should_set_request_attributes_to_new_span_info() throws ServletException, IOException { - // given: filter - RequestTracingFilterOldServlet filter = getFilterWithSkipDispatchOverride(false); - - // when: doFilterInternal is called - filter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - - // then: request attributes should be set with the new span's info - assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); - Span newSpan = spanCapturingFilterChain.capturedSpan; - - verify(requestMock).setAttribute(TraceHeaders.TRACE_SAMPLED, newSpan.isSampleable()); - verify(requestMock).setAttribute(TraceHeaders.TRACE_ID, newSpan.getTraceId()); - verify(requestMock).setAttribute(TraceHeaders.SPAN_ID, newSpan.getSpanId()); - verify(requestMock).setAttribute(TraceHeaders.PARENT_SPAN_ID, newSpan.getParentSpanId()); - verify(requestMock).setAttribute(TraceHeaders.SPAN_NAME, newSpan.getSpanName()); - verify(requestMock).setAttribute(Span.class.getName(), newSpan); - } - - @Test - public void doFilterInternal_should_set_trace_id_in_response_header() throws ServletException, IOException { - // given: filter - RequestTracingFilterOldServlet filter = getFilterWithSkipDispatchOverride(false); - - // when: doFilterInternal is called - filter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - - // then: response header should be set with the span's trace ID - assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); - verify(responseMock).setHeader(TraceHeaders.TRACE_ID, spanCapturingFilterChain.capturedSpan.getTraceId()); - } - - @Test - public void doFilterInternal_should_use_parent_span_info_if_present_in_request_headers() throws ServletException, IOException { - // given: filter and request that has parent span info - RequestTracingFilterOldServlet filter = getFilterWithSkipDispatchOverride(false); - Span parentSpan = Span.newBuilder("someParentSpan", null).withParentSpanId(TraceAndSpanIdGenerator.generateId()).withSampleable(false).withUserId("someUser").build(); - given(requestMock.getHeader(TraceHeaders.TRACE_ID)).willReturn(parentSpan.getTraceId()); - given(requestMock.getHeader(TraceHeaders.SPAN_ID)).willReturn(parentSpan.getSpanId()); - given(requestMock.getHeader(TraceHeaders.PARENT_SPAN_ID)).willReturn(parentSpan.getParentSpanId()); - given(requestMock.getHeader(TraceHeaders.SPAN_NAME)).willReturn(parentSpan.getSpanName()); - given(requestMock.getHeader(TraceHeaders.TRACE_SAMPLED)).willReturn(String.valueOf(parentSpan.isSampleable())); - given(requestMock.getServletPath()).willReturn("/some/path"); - given(requestMock.getMethod()).willReturn("GET"); - - // when: doFilterInternal is called - filter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - - // then: the span that is created should use the parent span info as its parent - assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); - Span newSpan = spanCapturingFilterChain.capturedSpan; - assertThat(newSpan.getTraceId()).isEqualTo(parentSpan.getTraceId()); - assertThat(newSpan.getSpanId()).isNotEqualTo(parentSpan.getSpanId()); - assertThat(newSpan.getParentSpanId()).isEqualTo(parentSpan.getSpanId()); - assertThat(newSpan.getSpanName()).isEqualTo(HttpSpanFactory.getSpanName(requestMock)); - assertThat(newSpan.isSampleable()).isEqualTo(parentSpan.isSampleable()); - assertThat(newSpan.getSpanPurpose()).isEqualTo(SpanPurpose.SERVER); - } - - @Test - public void doFilterInternal_should_use_user_id_from_parent_span_info_if_present_in_request_headers() throws ServletException, IOException { - // given: filter and request that has parent span info - RequestTracingFilterOldServlet spyFilter = spy(getFilterWithSkipDispatchOverride(false)); - given(requestMock.getHeader(ALT_USER_ID_HEADER_KEY)).willReturn("testUserId"); - - Span parentSpan = Span.newBuilder("someParentSpan", null).withParentSpanId(TraceAndSpanIdGenerator.generateId()).withSampleable(false).withUserId("someUser").build(); - given(requestMock.getHeader(TraceHeaders.TRACE_ID)).willReturn(parentSpan.getTraceId()); - given(requestMock.getHeader(TraceHeaders.SPAN_ID)).willReturn(parentSpan.getSpanId()); - given(requestMock.getHeader(TraceHeaders.PARENT_SPAN_ID)).willReturn(parentSpan.getParentSpanId()); - given(requestMock.getHeader(TraceHeaders.SPAN_NAME)).willReturn(parentSpan.getSpanName()); - given(requestMock.getHeader(TraceHeaders.TRACE_SAMPLED)).willReturn(String.valueOf(parentSpan.isSampleable())); - given(requestMock.getServletPath()).willReturn("/some/path"); - given(requestMock.getMethod()).willReturn("GET"); - - // when: doFilterInternal is called - spyFilter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - - // then: the span that is created should use the parent span info as its parent - assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); - Span newSpan = spanCapturingFilterChain.capturedSpan; - - assertThat(newSpan.getUserId()).isEqualTo("testUserId"); - - } - - @DataProvider(value = { - "true", - "false", - }, splitBy = "\\|") - @Test - public void doFilterInternal_should_reset_tracing_info_to_whatever_was_on_the_thread_originally( - boolean throwExceptionInInnerFinallyBlock - ) throws ServletException, IOException { - // given - final RequestTracingFilterOldServlet filterSpy = spy(getBasicFilter()); - RuntimeException exToThrowInInnerFinallyBlock = null; - if (throwExceptionInInnerFinallyBlock) { - exToThrowInInnerFinallyBlock = new RuntimeException("kaboom"); - doThrow(exToThrowInInnerFinallyBlock).when(filterSpy).isAsyncRequest(any(HttpServletRequest.class)); - } - Tracer.getInstance().startRequestWithRootSpan("someOutsideSpan"); - TracingState originalTracingState = TracingState.getCurrentThreadTracingState(); - - // when - Throwable ex = catchThrowable(new ThrowableAssert.ThrowingCallable() { - @Override - public void call() throws Throwable { - filterSpy.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - } - }); - - // then - if (throwExceptionInInnerFinallyBlock) { - assertThat(ex).isSameAs(exToThrowInInnerFinallyBlock); - } - assertThat(TracingState.getCurrentThreadTracingState()).isEqualTo(originalTracingState); - assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); - // The original tracing state was replaced on the thread before returning, but the span used by the filter chain - // should *not* come from the original tracing state - it should have come from the incoming headers or - // a new one generated. - assertThat(spanCapturingFilterChain.capturedSpan.getTraceId()) - .isNotEqualTo(originalTracingState.spanStack.peek().getTraceId()); - } - - @Test - public void doFilterInternal_should_call_setupTracingCompletionWhenAsyncRequestCompletes_when_isAsyncRequest_returns_true( - ) throws ServletException, IOException { - // given - RequestTracingFilterOldServlet filterSpy = spy(getBasicFilter()); - doReturn(true).when(filterSpy).isAsyncRequest(any(HttpServletRequest.class)); - - // when - filterSpy.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - - // then - assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); - assertThat(spanCapturingFilterChain.capturedSpan.isCompleted()).isFalse(); - verify(filterSpy).setupTracingCompletionWhenAsyncRequestCompletes(eq(requestMock), any(TracingState.class)); - } - - @Test - public void doFilterInternal_should_not_call_setupTracingCompletionWhenAsyncRequestCompletes_when_isAsyncRequest_returns_false( - ) throws ServletException, IOException { - // given - RequestTracingFilterOldServlet filterSpy = spy(getBasicFilter()); - doReturn(false).when(filterSpy).isAsyncRequest(any(HttpServletRequest.class)); - - // when - filterSpy.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - - // then - assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); - assertThat(spanCapturingFilterChain.capturedSpan.isCompleted()).isTrue(); - verify(filterSpy, never()).setupTracingCompletionWhenAsyncRequestCompletes( - any(HttpServletRequest.class), any(TracingState.class) - ); - } - - // VERIFY skipDispatch ============================== - - @Test - public void skipDispatch_should_return_false() { - // given: filter - RequestTracingFilterOldServlet filter = getBasicFilter(); - - // when: skipDispatchIsCalled - boolean result = filter.skipDispatch(requestMock); - - // then: the result should be false - assertThat(result).isFalse(); - } - - private static class SpanCapturingFilterChain implements FilterChain { - - public Span capturedSpan; - - @Override - public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { - capturedSpan = Tracer.getInstance().getCurrentSpan(); - } - } - - // VERIFY isAsyncRequest ============================== - - @Test - public void isAsyncRequest_should_return_false() { - // given - RequestTracingFilterOldServlet filterSpy = spy(getBasicFilter()); - - // when - boolean result = filterSpy.isAsyncRequest(requestMock); - - // then - assertThat(result).isFalse(); - verify(filterSpy).isAsyncRequest(requestMock); - verifyNoMoreInteractions(filterSpy); - } - - // VERIFY setupTracingCompletionWhenAsyncRequestCompletes ============ - - @Test - public void setupTracingCompletionWhenAsyncRequestCompletes_should_do_nothing( - ) throws ServletException, IOException { - // given - RequestTracingFilterOldServlet filterSpy = spy(getBasicFilter()); - TracingState tracingStateMock = mock(TracingState.class); - - // when - filterSpy.setupTracingCompletionWhenAsyncRequestCompletes(requestMock, tracingStateMock); - - // then - verify(filterSpy).setupTracingCompletionWhenAsyncRequestCompletes(requestMock, tracingStateMock); - verifyNoMoreInteractions(filterSpy, requestMock, tracingStateMock); - } -} diff --git a/wingtips-servlet-api/README.md b/wingtips-servlet-api/README.md index 05deed13..28323c05 100644 --- a/wingtips-servlet-api/README.md +++ b/wingtips-servlet-api/README.md @@ -1,17 +1,26 @@ # Wingtips - wingtips-servlet-api -Wingtips is a distributed tracing solution for Java based on the [Google Dapper paper](http://static.googleusercontent.com/media/research.google.com/en/us/pubs/archive/36356.pdf). +Wingtips is a distributed tracing solution for Java based on the +[Google Dapper paper](http://static.googleusercontent.com/media/research.google.com/en/us/pubs/archive/36356.pdf). -This module is a plugin extension module of the core Wingtips library and contains support for distributed tracing in a Java **Servlet 3 API** environment (for Servlet 2.x environments please refer to the [wingtips-old-servlet-api](../wingtips-old-servlet-api) module). The features it provides are: +This module is a plugin extension module of the core Wingtips library and contains support for distributed tracing in a +Java Servlet environment. The features it provides are: * **HttpSpanFactory** - Utility class that extracts span information from incoming `HttpServletRequest` requests. -* **RequestTracingFilter** - A Servlet Filter that handles all of the work for enabling a new span when a request comes in and completing it when the request finishes. This filter automatically uses `HttpSpanFactory` to extract parent span information from the incoming request headers for the new span if available. Sets the `X-B3-TraceId` response header to the Trace ID for each request. Supports Servlet 3.0+ asynchronous request processing. You can set the `user-id-header-keys-list` servlet filter param if you expect any request headers that represent a user ID (if you don't have any user ID headers then this can be ignored). +* **RequestTracingFilter** - A Servlet Filter that handles all of the work for enabling a new span when a request comes +in and completing it when the request finishes. This filter automatically uses `HttpSpanFactory` to extract parent span +information from the incoming request headers for the new span if available. Sets the `X-B3-TraceId` response header to +the Trace ID for each request. Supports Servlet 3 environments (including asynchronous requests) as well as Servlet 2.x +environments. You can set the `user-id-header-keys-list` servlet filter param if you expect your service to receive any +request headers that represent a user ID (if you don't have any user ID headers then this can be ignored). -Please make sure you have read the [base project README.md](../README.md). This readme assumes you understand the principles and usage instructions described there. +Please make sure you have read the [base project README.md](../README.md). This readme assumes you understand the +principles and usage instructions described there. ## Usage Example -The following example shows how you might setup the tracing Servlet Filter when the service expects one of two possible header keys that represent the user ID of the user making the call: `userid` or `altuserid`. +The following example shows how you might setup the tracing Servlet Filter when the service expects one of two possible +header keys that represent the user ID of the user making the call: `userid` or `altuserid`. **Add the following to web.xml** @@ -31,9 +40,12 @@ The following example shows how you might setup the tracing Servlet Filter when ``` -If your service does not have any user ID headers you can remove the `` element entirely or set the `` to be empty. +If your service does not have any user ID headers you can remove the `` element entirely or set the +`` to be empty. -That's it for incoming requests. This Filter will do the right thing and start a root span or child span for incoming requests (depending on whether or not the caller included tracing headers), add the trace ID to the response as a response header, and guarantees completion of the overall request span right before the response is sent. +That's it for incoming requests. This Filter will do the right thing and start a root span or child span for incoming +requests (depending on whether or not the caller included tracing headers), add the trace ID to the response as a +response header, and guarantees completion of the overall request span right before the response is sent. **Embedded environments** @@ -44,7 +56,10 @@ register `RequestTracingFilter` with embedded Jetty. ### Propagating the Tracing Information to Downstream Systems -This Filter takes care of setting up the overall request span for incoming requests, but propagating the tracing information to downstream systems is still your responsibility. When you call another system you must grab the current span via `Tracer.getInstance().getCurrentSpan()` and put its field values into the downstream call's request headers using the constants in `TraceHeaders` as the header keys. For example: +This Filter takes care of setting up the overall request span for incoming requests, but propagating the tracing +information to downstream systems is still your responsibility. When you call another system you must grab the current +span via `Tracer.getInstance().getCurrentSpan()` and put its field values into the downstream call's request headers +using the constants in `TraceHeaders` as the header keys. For example: ``` java Span currentSpan = Tracer.getInstance().getCurrentSpan(); @@ -72,13 +87,16 @@ See the [base project readme's section on propagation](../README.md#propagating_ propagating tracing information. You may also want to consider [wrapping downstream calls in a subspan](../README.md#sub_spans_for_downstream_calls). -## NOTE - Servlet API 3.0.1 or later dependency required at runtime +## NOTE - Servlet API dependency required at runtime -This `wingtips-servlet-api` module has a minimum Servlet 3.0.1 requirement, but does not export any transitive Servlet -API dependencies to prevent runtime version conflicts with whatever Servlet environment you deploy to. +This `wingtips-servlet-api` module does not export any transitive Servlet API dependencies to prevent runtime version +conflicts with whatever Servlet environment you deploy to. This should not affect most users since this library is likely to be used in a Servlet environment where the Servlet -API is on the classpath at runtime, however if you receive `NoClassDefFoundError`s related to Servlet API classes then -you'll need to pull a Servlet API dependency into your project that supports a minimum Servlet 3.0.1 version. For -reference, `wingtips-servlet-api` uses the compile-only Servlet API dependency -[`javax.servlet:javax.servlet-api:3.0.1`](http://search.maven.org/#artifactdetails%7Cjavax.servlet%7Cjavax.servlet-api%7C3.0.1%7Cjar). +API is on the classpath at runtime, however if you receive class-not-found errors related to Servlet API classes then +you'll need to pull a Servlet API dependency into your project. Library authors who wish to build on functionality in +this module might need to do this. Which Servlet API dependency you pull in depends on the type of Servlet environment +you want to support (Servlet 2.x or Servlet 3+). For example: + +* Servlet 3 API dependency: [`javax.servlet:javax.servlet-api:[servlet-3-api-version]`](http://search.maven.org/#search%7Cgav%7C1%7Cg%3A%22javax.servlet%22%20AND%20a%3A%22javax.servlet-api%22) +* Servlet 2 API dependency: [`javax.servlet:servlet-api:[servlet-2-api-version]`](http://search.maven.org/#search%7Cgav%7C1%7Cg%3A%22javax.servlet%22%20AND%20a%3A%22servlet-api%22) diff --git a/wingtips-servlet-api/build.gradle b/wingtips-servlet-api/build.gradle index 78827e53..12835fce 100644 --- a/wingtips-servlet-api/build.gradle +++ b/wingtips-servlet-api/build.gradle @@ -1,9 +1,14 @@ evaluationDependsOn(':') +compileTestJava { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 +} + dependencies { compile( - project(":wingtips-old-servlet-api") + project(":wingtips-core") ) compileOnly( "javax.servlet:javax.servlet-api:$servletApiVersion" diff --git a/wingtips-old-servlet-api/src/main/java/com/nike/wingtips/servlet/HttpSpanFactory.java b/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/HttpSpanFactory.java similarity index 100% rename from wingtips-old-servlet-api/src/main/java/com/nike/wingtips/servlet/HttpSpanFactory.java rename to wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/HttpSpanFactory.java diff --git a/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilter.java b/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilter.java index f1c2b90b..0fd2f7cd 100644 --- a/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilter.java +++ b/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilter.java @@ -1,121 +1,199 @@ package com.nike.wingtips.servlet; +import com.nike.wingtips.Span; import com.nike.wingtips.TraceHeaders; +import com.nike.wingtips.Tracer; import com.nike.wingtips.util.TracingState; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.lang.reflect.Method; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; -import javax.servlet.AsyncListener; -import javax.servlet.DispatcherType; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletException; import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import static com.nike.wingtips.servlet.ServletRuntime.ASYNC_LISTENER_CLASSNAME; +import static com.nike.wingtips.util.AsyncWingtipsHelperJava7.unlinkTracingFromCurrentThread; /** * Makes sure distributed tracing is handled for each request. Sets up the span for incoming requests (either an * entirely new root span or one with a parent, depending on what is in the incoming request's headers), and also sets * the {@link TraceHeaders#TRACE_ID} on the response. This is designed to only run once per request. * - *

This class supports Servlet 3 async requests. For Servlet 2.x environments where async requests are not supported - * please use {@link RequestTracingFilterOldServlet} instead. - * *

NOTE: You can override {@link #getUserIdHeaderKeys()} if your service is expecting user ID header(s) and you can't * (or don't want to) set up those headers via the {@link #USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME} init parameter. * - * Extension of {@link RequestTracingFilterOldServlet} that adds Servlet 3 support, specifically around async requests. + *

This class supports Servlet 3 async requests when running in a Servlet 3+ environment. It also supports running + * in a Servlet 2.x environment. * * @author Nic Munroe */ -public class RequestTracingFilter extends RequestTracingFilterOldServlet { - +@SuppressWarnings("WeakerAccess") +public class RequestTracingFilter implements Filter { private final Logger logger = LoggerFactory.getLogger(this.getClass()); - protected Boolean containerSupportsAsyncRequests = null; - /** - * Determines whether the current servlet container supports Servlet 3 async requests by using reflection to - * inspect the given {@link ServletRequest} implementation class to see if it contains an async-related method - * introduced with the Servlet 3 API. + * This attribute key will be set to a value of true via {@link ServletRequest#setAttribute(String, Object)} the + * first time this filter's distributed tracing logic is run for any given request. This filter will then see this + * attribute on any subsequent executions for the same request and continue the filter chain without executing the + * distributed tracing logic again to make sure this filter's logic is only executed once per request. * - * @param servletRequestClass ServletRequest implementation class - * @return true if the given servletRequest implementation class supports the getAsyncContext() method, - * otherwise false + *

If you want to prevent this filter from executing on specific requests then you can override {@link + * #skipDispatch(HttpServletRequest)} to return true for any requests where you don't want distributed tracing to + * occur. */ - protected boolean supportsAsyncRequests(Class servletRequestClass) { - Method asyncContextMethod = null; - try { - asyncContextMethod = servletRequestClass.getMethod("getAsyncContext"); - } catch (Exception ex) { - logger.warn( - "Servlet 3 async requests are not supported on the current container. " - + "This filter will default to blocking request behavior", ex - ); - } - - return asyncContextMethod != null; - } + public static final String FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE = "RequestTracingFilterAlreadyFiltered"; /** - * Helper method for determining (and then caching) whether the given {@link ServletRequest} concrete - * implementation supports Servlet 3 async requests. This could be necessary if a WAR-based project is built - * using Servlet 3 APIs and then deployed to a Servlet 2.x-only environment. + * Corresponds to {@link javax.servlet.RequestDispatcher#ERROR_REQUEST_URI}. This will be populated in the request + * attributes during an error dispatch. * - * @param servletRequest The concrete {@link ServletRequest} implementation to check. - * @return true if the given concrete {@link ServletRequest} implementation supports Servlet 3 async requests, - * false otherwise. + * @deprecated This is no longer being used and will be removed in a future update. */ - protected boolean containerSupportsAsyncRequests(ServletRequest servletRequest) { - // It's ok that this isn't synchronized - this logic is idempotent and if it's executed a few extra times - // due to concurrent requests when the service first starts up it won't hurt anything. - if (containerSupportsAsyncRequests == null) { - containerSupportsAsyncRequests = supportsAsyncRequests(servletRequest.getClass()); + @Deprecated + public static final String ERROR_REQUEST_URI_ATTRIBUTE = "javax.servlet.error.request_uri"; + + /** + * The param name for the "list of user ID header keys" init param for this filter. The value of this init param + * will be parsed for the list of user ID header keys to use when calling {@link + * HttpSpanFactory#fromHttpServletRequest(HttpServletRequest, List)} or {@link + * HttpSpanFactory#getUserIdFromHttpServletRequest(HttpServletRequest, List)}. The value for this init param is + * expected to be a comma-delimited list. + */ + public static final String USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME = "user-id-header-keys-list"; + + protected ServletRuntime servletRuntime; + protected List userIdHeaderKeysFromInitParam; + + @Override + public void init(FilterConfig filterConfig) throws ServletException { + String userIdHeaderKeysListString = filterConfig.getInitParameter(USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME); + if (userIdHeaderKeysListString != null) { + List parsedList = new ArrayList<>(); + for (String headerKey : userIdHeaderKeysListString.split(",")) { + String trimmedHeaderKey = headerKey.trim(); + if (trimmedHeaderKey.length() > 0) + parsedList.add(trimmedHeaderKey); + } + userIdHeaderKeysFromInitParam = Collections.unmodifiableList(parsedList); } + } - return containerSupportsAsyncRequests; + @Override + public void destroy() { + // Nothing to do } /** - * The result of calling {@link HttpServletRequest#isAsyncStarted()} on the given request, assuming the runtime - * environment's {@link ServletRequest} implementation supports Servlet 3 async requests (necessary for the case - * where a WAR built with Servlet 3 support is deployed to a Servlet 2.x-only container). - * - * @param request The request to inspect to see if it's part of an async servlet request or not. - * @return The result of calling {@link HttpServletRequest#isAsyncStarted()} on the given request. + * Wrapper around {@link #doFilterInternal(HttpServletRequest, HttpServletResponse, FilterChain)} to make sure this + * filter's logic is only executed once per request. */ @Override - protected boolean isAsyncRequest(HttpServletRequest request) { - return containerSupportsAsyncRequests(request) && request.isAsyncStarted(); + public void doFilter(ServletRequest request, + ServletResponse response, + FilterChain filterChain) throws IOException, ServletException { + if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) { + throw new ServletException(this.getClass().getName() + " only supports HTTP requests"); + } + HttpServletRequest httpRequest = (HttpServletRequest) request; + HttpServletResponse httpResponse = (HttpServletResponse) response; + + boolean filterHasAlreadyExecuted = request.getAttribute(FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE) != null; + if (filterHasAlreadyExecuted || skipDispatch(httpRequest)) { + // Already executed or we're supposed to skip, so continue the filter chain without doing the + // distributed tracing work. + filterChain.doFilter(request, response); + } + else { + // Time to execute the distributed tracing logic. + request.setAttribute(FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE, Boolean.TRUE); + doFilterInternal(httpRequest, httpResponse, filterChain); + } } /** - * Adds a {@link AsyncListener} to the given request's {@link HttpServletRequest#getAsyncContext()} so that the - * given {@link TracingState} will be completed appropriately when this async servlet request completes. - * - * @param asyncRequest The async servlet request (guaranteed to be async since this method will only be called when - * {@link #isAsyncRequest(HttpServletRequest)} returns true). - * @param originalRequestTracingState The {@link TracingState} that was generated when this request started, and - * which should be completed when the given async servlet request finishes. + * Performs the distributed tracing work for each request's overall span. Guaranteed to only be called once per + * request. */ - @Override - protected void setupTracingCompletionWhenAsyncRequestCompletes(HttpServletRequest asyncRequest, - TracingState originalRequestTracingState) { - // Async processing was started, so we have to complete it with a listener. - asyncRequest.getAsyncContext().addListener( - new WingtipsRequestSpanCompletionAsyncListener(originalRequestTracingState) - ); + protected void doFilterInternal(HttpServletRequest request, + HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { + // Surround the tracing filter logic with a try/finally that guarantees the original tracing and MDC info found + // on the current thread at the beginning of this method is restored to this thread before this method + // returns, even if the request ends up being an async request. Otherwise there's the possibility of + // incorrect tracing information sticking around on this thread and potentially polluting other requests. + TracingState originalThreadInfo = TracingState.getCurrentThreadTracingState(); + try { + // See if there's trace info in the incoming request's headers. If so it becomes the parent trace. + Tracer tracer = Tracer.getInstance(); + final Span parentSpan = HttpSpanFactory.fromHttpServletRequest(request, getUserIdHeaderKeys()); + Span newSpan; + + if (parentSpan != null) { + logger.debug("Found parent Span {}", parentSpan); + newSpan = tracer.startRequestWithChildSpan(parentSpan, HttpSpanFactory.getSpanName(request)); + } else { + newSpan = tracer.startRequestWithRootSpan( + HttpSpanFactory.getSpanName(request), + HttpSpanFactory.getUserIdFromHttpServletRequest(request, getUserIdHeaderKeys()) + ); + logger.debug("Parent span not found, starting a new span {}", newSpan); + } + + // Put the new span's trace info into the request attributes. + request.setAttribute(TraceHeaders.TRACE_SAMPLED, newSpan.isSampleable()); + request.setAttribute(TraceHeaders.TRACE_ID, newSpan.getTraceId()); + request.setAttribute(TraceHeaders.SPAN_ID, newSpan.getSpanId()); + request.setAttribute(TraceHeaders.PARENT_SPAN_ID, newSpan.getParentSpanId()); + request.setAttribute(TraceHeaders.SPAN_NAME, newSpan.getSpanName()); + request.setAttribute(Span.class.getName(), newSpan); + + // Make sure we set the trace ID on the response header now before the response is committed (if we wait + // until after the filter chain then the response might already be committed, silently preventing us + // from setting the response header) + response.setHeader(TraceHeaders.TRACE_ID, newSpan.getTraceId()); + + TracingState originalRequestTracingState = TracingState.getCurrentThreadTracingState(); + try { + filterChain.doFilter(request, response); + } finally { + if (isAsyncRequest(request)) { + // Async, so we need to attach a listener to complete the original tracing state when the async + // servlet request finishes. + setupTracingCompletionWhenAsyncRequestCompletes(request, originalRequestTracingState); + } + else { + // Not async, so we need to complete the request span now. + tracer.completeRequestSpan(); + } + } + } + finally { + //noinspection deprecation + unlinkTracingFromCurrentThread(originalThreadInfo); + } } /** - * Corresponds to {@link javax.servlet.RequestDispatcher#ERROR_REQUEST_URI}. This will be populated in the request - * attributes during an error dispatch. - * - * @deprecated This is no longer being used and will be removed in a future update. + * @return true if {@link #doFilterInternal(HttpServletRequest, HttpServletResponse, FilterChain)} should be + * skipped (and therefore prevent distributed tracing logic from starting), false otherwise. This defaults to + * returning false so the first execution of this filter will always trigger distributed tracing, so if you have a + * need to skip distributed tracing for a request you can override this method and have whatever logic you need. */ - @Deprecated - public static final String ERROR_REQUEST_URI_ATTRIBUTE = "javax.servlet.error.request_uri"; + protected boolean skipDispatch(HttpServletRequest request) { + return false; + } /** * The dispatcher type {@code javax.servlet.DispatcherType.ASYNC} introduced in Servlet 3.0 means a filter can be @@ -129,7 +207,77 @@ protected void setupTracingCompletionWhenAsyncRequestCompletes(HttpServletReques */ @Deprecated protected boolean isAsyncDispatch(HttpServletRequest request) { - return DispatcherType.ASYNC.equals(request.getDispatcherType()); + return getServletRuntime(request).isAsyncDispatch(request); + } + + /** + * The list of header keys that will be used to search the request headers for a user ID to set on the {@link Span} + * for the request. The user ID header keys will be searched in list order, and the first non-empty user ID header + * value found will be used as the {@link Span#getUserId()}. You can safely return null or an empty list for this + * method if there is no user ID to extract; if you return null/empty then the request span's {@link + * Span#getUserId()} will be null. + * + *

By default this method will return the list specified via the {@link + * #USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME} init param, or null if that init param does not exist. + * + * @return The list of header keys that will be used to search the request headers for a user ID to set on the + * {@link Span} for the request. This method may return null or an empty list if there are no user IDs to search + * for. + */ + protected List getUserIdHeaderKeys() { + return userIdHeaderKeysFromInitParam; + } + + /** + * Helper method for determining (and then caching) the {@link ServletRuntime} implementation appropriate for + * the current Servlet runtime environment. If the current Servlet runtime environment supports the Servlet 3 API + * (i.e. async requests) then a Servlet-3-async-request-capable implementation will be returned, otherwise a + * Servlet-2-blocking-requests-only implementation will be returned. The first time this method is called the + * result will be cached, and the cached value returned for subsequent calls. + * + * @param request The concrete {@link ServletRequest} implementation use to determine the Servlet runtime + * environment. + * @return The {@link ServletRuntime} implementation appropriate for the current Servlet runtime environment. + */ + protected ServletRuntime getServletRuntime(ServletRequest request) { + // It's ok that this isn't synchronized - this logic is idempotent and if it's executed a few extra times + // due to concurrent requests when the service first starts up it won't hurt anything. + if (servletRuntime == null) { + servletRuntime = ServletRuntime.determineServletRuntime(request.getClass(), ASYNC_LISTENER_CLASSNAME); + } + + return servletRuntime; + } + + /** + * Returns the value of calling {@link ServletRuntime#isAsyncRequest(HttpServletRequest)} on the {@link + * ServletRuntime} returned by {@link #getServletRuntime(ServletRequest)}. This method is here to allow + * easy overriding by subclasses if needed, where {@link ServletRuntime} is not in scope. + * + * @param request The request to inspect to see if it's part of an async servlet request or not. + * @return the value of calling {@link ServletRuntime#isAsyncRequest(HttpServletRequest)} on the {@link + * ServletRuntime} returned by {@link #getServletRuntime(ServletRequest)}. + */ + protected boolean isAsyncRequest(HttpServletRequest request) { + return getServletRuntime(request).isAsyncRequest(request); + } + + /** + * Delegates to {@link + * ServletRuntime#setupTracingCompletionWhenAsyncRequestCompletes(HttpServletRequest, TracingState)}, with the + * {@link ServletRuntime} retrieved via {@link #getServletRuntime(ServletRequest)}. This method is here to + * allow easy overriding by subclasses if needed, where {@link ServletRuntime} is not in scope. + * + * @param asyncRequest The async servlet request (guaranteed to be async since this method will only be called when + * {@link #isAsyncRequest(HttpServletRequest)} returns true). + * @param originalRequestTracingState The {@link TracingState} that was generated when this request started, and + * which should be completed when the given async servlet request finishes. + */ + protected void setupTracingCompletionWhenAsyncRequestCompletes(HttpServletRequest asyncRequest, + TracingState originalRequestTracingState) { + getServletRuntime(asyncRequest).setupTracingCompletionWhenAsyncRequestCompletes( + asyncRequest, originalRequestTracingState + ); } } diff --git a/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilterNoAsync.java b/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilterNoAsync.java index a68a5a32..450e2aae 100644 --- a/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilterNoAsync.java +++ b/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestTracingFilterNoAsync.java @@ -7,8 +7,7 @@ * * @deprecated This class is no longer needed - the super {@link RequestTracingFilter} class is no longer abstract * and does not need subclasses to tell it whether the request is async. You should move to using {@link - * RequestTracingFilter} directly, or {@link RequestTracingFilterOldServlet} if you're in a Servlet 2.x environment. - * This class will be deleted in a future update. + * RequestTracingFilter} directly. This class will be deleted in a future update. * * @author Nic Munroe */ diff --git a/wingtips-old-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestWithHeadersServletAdapter.java b/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestWithHeadersServletAdapter.java similarity index 100% rename from wingtips-old-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestWithHeadersServletAdapter.java rename to wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/RequestWithHeadersServletAdapter.java diff --git a/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/ServletRuntime.java b/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/ServletRuntime.java new file mode 100644 index 00000000..bf252e99 --- /dev/null +++ b/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/ServletRuntime.java @@ -0,0 +1,153 @@ +package com.nike.wingtips.servlet; + +import com.nike.wingtips.util.TracingState; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.servlet.ServletRequest; +import javax.servlet.http.HttpServletRequest; + +/** + * A class for abstracting out bits of the Servlet API that are version-dependent, e.g. async request support + * doesn't show up until Servlet 3.0 API. This abstraction allows us to have one servlet filter that works in + * pre-Servlet-3 or post-Servlet-3 environments without receiving class-not-found or method-not-found exceptions. + * + *

This class was derived from + * + * Brave's ServletRuntime, which was itself derived from OkHttp's {@code okhttp3.internal.platform.Platform}. + * + *

You should not need to worry about this class - it is an internal implementation detail for + * {@link RequestTracingFilter}. + * + * @author Nic Munroe + */ +abstract class ServletRuntime { + + private static final Logger logger = LoggerFactory.getLogger(ServletRuntime.class); + + /** + * The classname for {@code AsyncListener}. Pass this in to {@link #determineServletRuntime(Class, String)} along + * with the servlet request class to determine which servlet runtime environment you're running in. + */ + static final String ASYNC_LISTENER_CLASSNAME = "javax.servlet.AsyncListener"; + + /** + * @param request The request to inspect to see if it's async or not. + * @return true if the given request represents an async request, false otherwise. + */ + abstract boolean isAsyncRequest(HttpServletRequest request); + + /** + * This method should be overridden to do the right thing depending on the Servlet runtime environment - Servlet + * 2.x environments should do nothing or throw an exception since they do not support async requests (and this + * method should never be called since {@link #isAsyncRequest(HttpServletRequest)} should return false), but + * Servlet 3+ environments should setup a listener that will complete the given {@link TracingState} when the given + * async request finishes. The listener code would look something like: + * + *

+     *      AsyncListener spanCompletingAsyncListener = ...;
+     *      asyncRequest.getAsyncContext().addListener(spanCompletingAsyncListener);
+     * 
+ * + * @param asyncRequest The async servlet request (guaranteed to be async since this method will only be called when + * {@link #isAsyncRequest(HttpServletRequest)} returns true). + * @param originalRequestTracingState The {@link TracingState} that was generated when this request started, and + * which should be completed when the given async servlet request finishes. + */ + abstract void setupTracingCompletionWhenAsyncRequestCompletes(HttpServletRequest asyncRequest, + TracingState originalRequestTracingState); + + /** + * The dispatcher type {@code javax.servlet.DispatcherType.ASYNC} introduced in Servlet 3.0 means a filter can be + * invoked in more than one thread over the course of a single request. This method should return {@code true} if + * the filter is currently executing within an asynchronous dispatch. + * + * @param request the current request + * + * @deprecated This method is no longer used to determine whether the servlet filter should execute, and will be + * removed in a future update. It is here to support {@link + * RequestTracingFilter#isAsyncDispatch(HttpServletRequest)}, which only remains to prevent breaking impls that + * overrode the method. + */ + @Deprecated + abstract boolean isAsyncDispatch(HttpServletRequest request); + + /** + * Determines whether the current servlet container supports Servlet 3 async requests by using reflection to + * inspect the given {@link ServletRequest} implementation class to see if it contains an async-related method + * introduced with the Servlet 3 API, and attempts to use the given string to load the class for + * {@code javax.servlet.AsyncListener}. If both of those checks pass without error then {@link Servlet3Runtime} + * will be returned, otherwise a Servlet 2.x environment is assumed and {@link Servlet2Runtime} will be returned. + * + * @param servletRequestClass The {@link ServletRequest} implementation class to check. + * @param asyncListenerClassname This should be "javax.servlet.AsyncListener" at runtime (use the {@link + * #ASYNC_LISTENER_CLASSNAME} constant). It is passed in as an argument to facilitate testing scenarios. + * @return true if the given {@link ServletRequest} implementation class supports the getAsyncContext() method + * and the given {@code javax.servlet.AsyncListener} classname could be loaded, otherwise false. + */ + static ServletRuntime determineServletRuntime(Class servletRequestClass, String asyncListenerClassname) { + try { + servletRequestClass.getMethod("getAsyncContext"); + Class.forName(asyncListenerClassname); + // No exceptions were thrown, so we're running in a Servlet 3+ environment. + return new Servlet3Runtime(); + } catch (Exception ex) { + logger.warn( + "Servlet 3 async requests are not supported on the current container. " + + "RequestTracingFilter will default to blocking request behavior (Servlet 2.x). " + + "Exception message indicating a Servlet 2.x environment: {}", ex.toString() + ); + return new Servlet2Runtime(); + } + } + + /** + * Implementation of {@link ServletRuntime} for Servlet 2.x environments. + */ + static class Servlet2Runtime extends ServletRuntime { + + @Override + public boolean isAsyncRequest(HttpServletRequest request) { + return false; + } + + @Override + public void setupTracingCompletionWhenAsyncRequestCompletes(HttpServletRequest asyncRequest, + TracingState originalRequestTracingState) { + throw new IllegalStateException("This method should never be called in a pre-Servlet-3.0 environment."); + } + + @Override + boolean isAsyncDispatch(HttpServletRequest request) { + return false; + } + } + + /** + * Implementation of {@link ServletRuntime} for Servlet 3+ environments that supports async requests. + */ + static class Servlet3Runtime extends ServletRuntime { + + @Override + public boolean isAsyncRequest(HttpServletRequest request) { + return request.isAsyncStarted(); + } + + @Override + public void setupTracingCompletionWhenAsyncRequestCompletes(HttpServletRequest asyncRequest, + TracingState originalRequestTracingState) { + // Async processing was started, so we have to complete it with a listener. + asyncRequest.getAsyncContext().addListener( + new WingtipsRequestSpanCompletionAsyncListener(originalRequestTracingState) + ); + } + + @Override + boolean isAsyncDispatch(HttpServletRequest request) { + // Do a string comparison to avoid pulling in the DispatcherType import. + return "ASYNC".equals(request.getDispatcherType().name()); + } + } + +} diff --git a/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/WingtipsRequestSpanCompletionAsyncListener.java b/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/WingtipsRequestSpanCompletionAsyncListener.java index 11c72c10..8abd7a6c 100644 --- a/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/WingtipsRequestSpanCompletionAsyncListener.java +++ b/wingtips-servlet-api/src/main/java/com/nike/wingtips/servlet/WingtipsRequestSpanCompletionAsyncListener.java @@ -1,6 +1,7 @@ package com.nike.wingtips.servlet; import com.nike.wingtips.Tracer; +import com.nike.wingtips.servlet.ServletRuntime.Servlet3Runtime; import com.nike.wingtips.util.TracingState; import java.io.IOException; @@ -12,9 +13,9 @@ import static com.nike.wingtips.util.AsyncWingtipsHelperJava7.runnableWithTracing; /** - * Helper class for {@link RequestTracingFilter} that implements {@link AsyncListener}, whose job is to complete the + * Helper class for {@link Servlet3Runtime} that implements {@link AsyncListener}, whose job is to complete the * overall request span when an async servlet request finishes. You should not need to worry about this class - it - * is an internal implementation detail for {@link RequestTracingFilter}. + * is an internal implementation detail for {@link Servlet3Runtime}. * * @author Nic Munroe */ diff --git a/wingtips-old-servlet-api/src/test/java/com/nike/wingtips/servlet/HttpSpanFactoryTest.java b/wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/HttpSpanFactoryTest.java similarity index 100% rename from wingtips-old-servlet-api/src/test/java/com/nike/wingtips/servlet/HttpSpanFactoryTest.java rename to wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/HttpSpanFactoryTest.java diff --git a/wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestTracingFilterTest.java b/wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestTracingFilterTest.java index fa95c6fa..2e0fde45 100644 --- a/wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestTracingFilterTest.java +++ b/wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestTracingFilterTest.java @@ -1,28 +1,30 @@ package com.nike.wingtips.servlet; import com.nike.wingtips.Span; +import com.nike.wingtips.Span.SpanPurpose; +import com.nike.wingtips.TraceAndSpanIdGenerator; +import com.nike.wingtips.TraceHeaders; import com.nike.wingtips.Tracer; import com.nike.wingtips.util.TracingState; import com.tngtech.java.junit.dataprovider.DataProvider; import com.tngtech.java.junit.dataprovider.DataProviderRunner; +import com.tngtech.java.junit.dataprovider.UseDataProvider; -import org.assertj.core.api.ThrowableAssert; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.slf4j.MDC; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import javax.servlet.AsyncContext; import javax.servlet.AsyncListener; -import javax.servlet.DispatcherType; import javax.servlet.FilterChain; import javax.servlet.FilterConfig; import javax.servlet.ServletException; @@ -31,8 +33,11 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import static com.nike.wingtips.servlet.ServletRuntime.ASYNC_LISTENER_CLASSNAME; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.catchThrowable; +import static org.assertj.core.api.Fail.fail; +import static org.mockito.BDDMockito.given; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; @@ -41,6 +46,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -52,10 +58,19 @@ public class RequestTracingFilterTest { private HttpServletRequest requestMock; private HttpServletResponse responseMock; + private FilterChain filterChainMock; private SpanCapturingFilterChain spanCapturingFilterChain; + @SuppressWarnings("FieldCanBeLocal") private AsyncContext listenerCapturingAsyncContext; private List capturedAsyncListeners; private FilterConfig filterConfigMock; + private ServletRuntime servletRuntimeMock; + + private static final String USER_ID_HEADER_KEY = "userId"; + private static final String ALT_USER_ID_HEADER_KEY = "altUserId"; + private static final List USER_ID_HEADER_KEYS = Arrays.asList(USER_ID_HEADER_KEY, ALT_USER_ID_HEADER_KEY); + private static final String USER_ID_HEADER_KEYS_INIT_PARAM_VALUE_STRING = + USER_ID_HEADER_KEYS.toString().replace("[", "").replace("]", ""); private RequestTracingFilter getBasicFilter() { RequestTracingFilter filter = new RequestTracingFilter(); @@ -76,12 +91,9 @@ private void setupAsyncContextWorkflow() { doReturn(listenerCapturingAsyncContext).when(requestMock).getAsyncContext(); doReturn(true).when(requestMock).isAsyncStarted(); - doAnswer(new Answer() { - @Override - public Object answer(InvocationOnMock invocation) throws Throwable { - capturedAsyncListeners.add((AsyncListener) invocation.getArguments()[0]); - return null; - } + doAnswer(invocation -> { + capturedAsyncListeners.add((AsyncListener) invocation.getArguments()[0]); + return null; }).when(listenerCapturingAsyncContext).addListener(any(AsyncListener.class)); } @@ -89,9 +101,15 @@ public Object answer(InvocationOnMock invocation) throws Throwable { public void setupMethod() { requestMock = mock(HttpServletRequest.class); responseMock = mock(HttpServletResponse.class); + filterChainMock = mock(FilterChain.class); spanCapturingFilterChain = new SpanCapturingFilterChain(); filterConfigMock = mock(FilterConfig.class); + doReturn(USER_ID_HEADER_KEYS_INIT_PARAM_VALUE_STRING) + .when(filterConfigMock) + .getInitParameter(RequestTracingFilter.USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME); + + servletRuntimeMock = mock(ServletRuntime.class); resetTracing(); } @@ -106,8 +124,414 @@ private void resetTracing() { Tracer.getInstance().unregisterFromThread(); } + private static class SpanCapturingFilterChain implements FilterChain { + + Span capturedSpan; + + @Override + public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { + capturedSpan = Tracer.getInstance().getCurrentSpan(); + } + } + + // VERIFY filter init, getUserIdHeaderKeys, and destroy ======================= + + @DataProvider + public static Object[][] userIdHeaderKeysInitParamDataProvider() { + + return new Object[][] { + { null, null }, + { "", Collections.emptyList() }, + { " \t \n ", Collections.emptyList() }, + { "asdf", Collections.singletonList("asdf") }, + { " , \n\t, asdf , \t\n ", Collections.singletonList("asdf") }, + { "ASDF,QWER", Arrays.asList("ASDF", "QWER") }, + { "ASDF, QWER, ZXCV", Arrays.asList("ASDF", "QWER", "ZXCV") } + }; + } + + @Test + @UseDataProvider("userIdHeaderKeysInitParamDataProvider") + public void init_method_gets_user_id_header_key_list_from_init_params_and_getUserIdHeaderKeys_exposes_them( + String userIdHeaderKeysInitParamValue, List expectedUserIdHeaderKeysList) throws ServletException { + // given + RequestTracingFilter filter = new RequestTracingFilter(); + FilterConfig filterConfig = mock(FilterConfig.class); + doReturn(userIdHeaderKeysInitParamValue) + .when(filterConfig) + .getInitParameter(RequestTracingFilter.USER_ID_HEADER_KEYS_LIST_INIT_PARAM_NAME); + filter.init(filterConfig); + + // when + List actualUserIdHeaderKeysList = filter.getUserIdHeaderKeys(); + + // then + assertThat(actualUserIdHeaderKeysList).isEqualTo(expectedUserIdHeaderKeysList); + if (actualUserIdHeaderKeysList != null) { + Exception caughtEx = null; + try { + actualUserIdHeaderKeysList.add("foo"); + } catch (Exception ex) { + caughtEx = ex; + } + assertThat(caughtEx).isNotNull(); + assertThat(caughtEx).isInstanceOf(UnsupportedOperationException.class); + } + } + + @Test + public void destroy_does_not_explode() { + // expect + getBasicFilter().destroy(); + // No explosion no problem + } + + // VERIFY doFilter =================================== + + @Test(expected = ServletException.class) + public void doFilter_should_explode_if_request_is_not_HttpServletRequest() throws IOException, ServletException { + // expect + getBasicFilter().doFilter(mock(ServletRequest.class), mock(HttpServletResponse.class), mock(FilterChain.class)); + fail("Expected ServletException but no exception was thrown"); + } + + @Test(expected = ServletException.class) + public void doFilter_should_explode_if_response_is_not_HttpServletResponse() throws IOException, ServletException { + // expect + getBasicFilter().doFilter(mock(HttpServletRequest.class), mock(ServletResponse.class), mock(FilterChain.class)); + fail("Expected ServletException but no exception was thrown"); + } + + @Test + public void doFilter_should_not_explode_if_request_and_response_are_HttpServletRequests_and_HttpServletResponses() throws IOException, ServletException { + // expect + getBasicFilter().doFilter(mock(HttpServletRequest.class), mock(HttpServletResponse.class), mock(FilterChain.class)); + // No explosion no problem + } + + @Test + public void doFilter_should_call_doFilterInternal_and_set_ALREADY_FILTERED_ATTRIBUTE_KEY_if_not_already_filtered_and_skipDispatch_returns_false() + throws IOException, ServletException { + // given: filter that returns false for skipDispatch and request that returns null for already-filtered attribute + RequestTracingFilter spyFilter = spy(getBasicFilter()); + given(requestMock.getAttribute( + RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE)).willReturn(null); + + // when: doFilter() is called + spyFilter.doFilter(requestMock, responseMock, filterChainMock); + + // then: doFilterInternal should be called and ALREADY_FILTERED_ATTRIBUTE_KEY should be set on the request + verify(spyFilter).doFilterInternal(requestMock, responseMock, filterChainMock); + verify(requestMock).setAttribute(RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE, Boolean.TRUE); + } + + @Test + public void doFilter_should_not_unset_ALREADY_FILTERED_ATTRIBUTE_KEY_after_running_doFilterInternal() throws IOException, ServletException { + // given: filter that will run doFilterInternal and a FilterChain we can use to verify state when called + final RequestTracingFilter spyFilter = spy(getBasicFilter()); + given(requestMock.getAttribute( + RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE)).willReturn(null); + final List ifObjectAddedThenSmartFilterChainCalled = new ArrayList<>(); + FilterChain smartFilterChain = new FilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { + // Verify that when the filter chain is called we're in doFilterInternal, and that the request has ALREADY_FILTERED_ATTRIBUTE_KEY set + verify(spyFilter).doFilterInternal(requestMock, responseMock, this); + verify(requestMock).setAttribute(RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE, Boolean.TRUE); + verify(requestMock, times(0)).removeAttribute(RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE); + ifObjectAddedThenSmartFilterChainCalled.add(true); + } + }; + + // when: doFilter() is called + spyFilter.doFilter(requestMock, responseMock, smartFilterChain); + + // then: smartFilterChain's doFilter should have been called and ALREADY_FILTERED_ATTRIBUTE_KEY should still be set on the request + assertThat(ifObjectAddedThenSmartFilterChainCalled).hasSize(1); + verify(requestMock, never()).removeAttribute(RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE); + } + + @Test + public void doFilter_should_not_unset_ALREADY_FILTERED_ATTRIBUTE_KEY_even_if_filter_chain_explodes() throws IOException, ServletException { + // given: filter that will run doFilterInternal and a FilterChain we can use to verify state when called and then explodes + final RequestTracingFilter spyFilter = spy(getBasicFilter()); + given(requestMock.getAttribute( + RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE)).willReturn(null); + final List ifObjectAddedThenSmartFilterChainCalled = new ArrayList<>(); + FilterChain smartFilterChain = new FilterChain() { + @Override + public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { + // Verify that when the filter chain is called we're in doFilterInternal, and that the request has ALREADY_FILTERED_ATTRIBUTE_KEY set + verify(spyFilter).doFilterInternal(requestMock, responseMock, this); + verify(requestMock).setAttribute(RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE, Boolean.TRUE); + verify(requestMock, times(0)).removeAttribute(RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE); + ifObjectAddedThenSmartFilterChainCalled.add(true); + throw new IllegalStateException("boom"); + } + }; + + // when: doFilter() is called + boolean filterChainExploded = false; + try { + spyFilter.doFilter(requestMock, responseMock, smartFilterChain); + } + catch(IllegalStateException ex) { + if ("boom".equals(ex.getMessage())) + filterChainExploded = true; + } + + // then: smartFilterChain's doFilter should have been called, it should have exploded, and ALREADY_FILTERED_ATTRIBUTE_KEY should still be set on the request + assertThat(ifObjectAddedThenSmartFilterChainCalled).hasSize(1); + assertThat(filterChainExploded).isTrue(); + verify(requestMock, never()).removeAttribute(RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE); + } + + @Test + public void doFilter_should_not_call_doFilterInternal_if_already_filtered() throws IOException, ServletException { + // given: filter that returns false for skipDispatch but request that returns non-null for already-filtered attribute + RequestTracingFilter spyFilter = spy(getBasicFilter()); + given(requestMock.getAttribute( + RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE)).willReturn(Boolean.TRUE); + + // when: doFilter() is called + spyFilter.doFilter(requestMock, responseMock, filterChainMock); + + // then: doFilterInternal should not be called + verify(spyFilter, times(0)).doFilterInternal(requestMock, responseMock, filterChainMock); + } + + @Test + public void doFilter_should_not_call_doFilterInternal_if_not_already_filtered_but_skipDispatch_returns_true() throws IOException, ServletException { + // given: request that returns null for already-filtered attribute but filter that returns true for skipDispatch + RequestTracingFilter spyFilter = spy(getBasicFilter()); + doReturn(true).when(spyFilter).skipDispatch(any(HttpServletRequest.class)); + given(requestMock.getAttribute( + RequestTracingFilter.FILTER_HAS_ALREADY_EXECUTED_ATTRIBUTE)).willReturn(null); + + // when: doFilter() is called + spyFilter.doFilter(requestMock, responseMock, filterChainMock); + + // then: doFilterInternal should not be called + verify(spyFilter, times(0)).doFilterInternal(requestMock, responseMock, filterChainMock); + verify(spyFilter).skipDispatch(requestMock); + } + // VERIFY doFilterInternal =================================== + @Test + public void doFilterInternal_should_create_new_sampleable_span_if_no_parent_in_request_and_it_should_be_completed() throws ServletException, IOException { + // given: filter + RequestTracingFilter filter = getBasicFilter(); + + // when: doFilterInternal is called with a request that does not have a parent span + filter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); + + // then: a new valid sampleable span should be created and completed + Span span = spanCapturingFilterChain.capturedSpan; + assertThat(span).isNotNull(); + assertThat(span.getTraceId()).isNotNull(); + assertThat(span.getSpanId()).isNotNull(); + assertThat(span.getSpanName()).isNotNull(); + assertThat(span.getParentSpanId()).isNull(); + assertThat(span.isSampleable()).isTrue(); + assertThat(span.isCompleted()).isTrue(); + } + + @Test + public void doFilterInternal_should_not_complete_span_until_after_filter_chain_runs() throws ServletException, IOException { + // given: filter and filter chain that can tell us whether or not the span is complete at the time it is called + RequestTracingFilter filter = getBasicFilter(); + final List spanCompletedHolder = new ArrayList<>(); + final List spanHolder = new ArrayList<>(); + FilterChain smartFilterChain = (request, response) -> { + Span span = Tracer.getInstance().getCurrentSpan(); + spanHolder.add(span); + if (span != null) + spanCompletedHolder.add(span.isCompleted()); + }; + + // when: doFilterInternal is called + filter.doFilterInternal(requestMock, responseMock, smartFilterChain); + + // then: we should be able to validate that the smartFilterChain was called, and when it was called the span had not yet been completed, + // and after doFilterInternal finished it was completed. + assertThat(spanHolder).hasSize(1); + assertThat(spanCompletedHolder).hasSize(1); + assertThat(spanCompletedHolder.get(0)).isFalse(); + assertThat(spanHolder.get(0).isCompleted()).isTrue(); + } + + @DataProvider(value = { + "true", + "false" + }) + @Test + public void doFilterInternal_should_complete_span_even_if_filter_chain_explodes( + boolean isAsyncRequest + ) throws ServletException, IOException { + // given: filter and filter chain that will explode when called + RequestTracingFilter filterSpy = spy(getBasicFilter()); + final List spanContextHolder = new ArrayList<>(); + FilterChain explodingFilterChain = (request, response) -> { + // Verify that the span is not yet completed, keep track of it for later, then explode + Span span = Tracer.getInstance().getCurrentSpan(); + assertThat(span).isNotNull(); + assertThat(span.isCompleted()).isFalse(); + spanContextHolder.add(span); + throw new IllegalStateException("boom"); + }; + + if (isAsyncRequest) { + setupAsyncContextWorkflow(); + } + + // when: doFilterInternal is called + boolean filterChainExploded = false; + try { + filterSpy.doFilterInternal(requestMock, responseMock, explodingFilterChain); + } + catch(IllegalStateException ex) { + if ("boom".equals(ex.getMessage())) + filterChainExploded = true; + } + + // then: we should be able to validate that the filter chain exploded and the span is still completed, + // or setup for completion in the case of an async request + if (isAsyncRequest) { + assertThat(filterChainExploded).isTrue(); + verify(filterSpy).isAsyncRequest(requestMock); + verify(filterSpy).setupTracingCompletionWhenAsyncRequestCompletes(eq(requestMock), any(TracingState.class)); + assertThat(spanContextHolder).hasSize(1); + // The span should not be *completed* for an async request, but the + // setupTracingCompletionWhenAsyncRequestCompletes verification above represents the equivalent for + // async requests. + assertThat(spanContextHolder.get(0).isCompleted()).isFalse(); + } + else { + assertThat(filterChainExploded).isTrue(); + assertThat(spanContextHolder).hasSize(1); + assertThat(spanContextHolder.get(0).isCompleted()).isTrue(); + } + } + + @Test + public void doFilterInternal_should_set_request_attributes_to_new_span_info_with_user_id() throws ServletException, IOException { + // given: filter + RequestTracingFilter spyFilter = spy(getBasicFilter()); + given(requestMock.getHeader(USER_ID_HEADER_KEY)).willReturn("testUserId"); + + // when: doFilterInternal is called + spyFilter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); + + // then: request attributes should be set with the new span's info + assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); + Span newSpan = spanCapturingFilterChain.capturedSpan; + + assertThat(newSpan.getUserId()).isEqualTo("testUserId"); + } + + @Test + public void doFilterInternal_should_set_request_attributes_to_new_span_info_with_alt_user_id() throws ServletException, IOException { + // given: filter + RequestTracingFilter spyFilter = spy(getBasicFilter()); + given(requestMock.getHeader(ALT_USER_ID_HEADER_KEY)).willReturn("testUserId"); + + // when: doFilterInternal is called + spyFilter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); + + // then: request attributes should be set with the new span's info + assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); + Span newSpan = spanCapturingFilterChain.capturedSpan; + + assertThat(newSpan.getUserId()).isEqualTo("testUserId"); + } + + @Test + public void doFilterInternal_should_set_request_attributes_to_new_span_info() throws ServletException, IOException { + // given: filter + RequestTracingFilter filter = getBasicFilter(); + + // when: doFilterInternal is called + filter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); + + // then: request attributes should be set with the new span's info + assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); + Span newSpan = spanCapturingFilterChain.capturedSpan; + + verify(requestMock).setAttribute(TraceHeaders.TRACE_SAMPLED, newSpan.isSampleable()); + verify(requestMock).setAttribute(TraceHeaders.TRACE_ID, newSpan.getTraceId()); + verify(requestMock).setAttribute(TraceHeaders.SPAN_ID, newSpan.getSpanId()); + verify(requestMock).setAttribute(TraceHeaders.PARENT_SPAN_ID, newSpan.getParentSpanId()); + verify(requestMock).setAttribute(TraceHeaders.SPAN_NAME, newSpan.getSpanName()); + verify(requestMock).setAttribute(Span.class.getName(), newSpan); + } + + @Test + public void doFilterInternal_should_set_trace_id_in_response_header() throws ServletException, IOException { + // given: filter + RequestTracingFilter filter = getBasicFilter(); + + // when: doFilterInternal is called + filter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); + + // then: response header should be set with the span's trace ID + assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); + verify(responseMock).setHeader(TraceHeaders.TRACE_ID, spanCapturingFilterChain.capturedSpan.getTraceId()); + } + + @Test + public void doFilterInternal_should_use_parent_span_info_if_present_in_request_headers() throws ServletException, IOException { + // given: filter and request that has parent span info + RequestTracingFilter filter = getBasicFilter(); + Span parentSpan = Span.newBuilder("someParentSpan", null).withParentSpanId(TraceAndSpanIdGenerator.generateId()).withSampleable(false).withUserId("someUser").build(); + given(requestMock.getHeader(TraceHeaders.TRACE_ID)).willReturn(parentSpan.getTraceId()); + given(requestMock.getHeader(TraceHeaders.SPAN_ID)).willReturn(parentSpan.getSpanId()); + given(requestMock.getHeader(TraceHeaders.PARENT_SPAN_ID)).willReturn(parentSpan.getParentSpanId()); + given(requestMock.getHeader(TraceHeaders.SPAN_NAME)).willReturn(parentSpan.getSpanName()); + given(requestMock.getHeader(TraceHeaders.TRACE_SAMPLED)).willReturn(String.valueOf(parentSpan.isSampleable())); + given(requestMock.getServletPath()).willReturn("/some/path"); + given(requestMock.getMethod()).willReturn("GET"); + + // when: doFilterInternal is called + filter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); + + // then: the span that is created should use the parent span info as its parent + assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); + Span newSpan = spanCapturingFilterChain.capturedSpan; + assertThat(newSpan.getTraceId()).isEqualTo(parentSpan.getTraceId()); + assertThat(newSpan.getSpanId()).isNotEqualTo(parentSpan.getSpanId()); + assertThat(newSpan.getParentSpanId()).isEqualTo(parentSpan.getSpanId()); + assertThat(newSpan.getSpanName()).isEqualTo(HttpSpanFactory.getSpanName(requestMock)); + assertThat(newSpan.isSampleable()).isEqualTo(parentSpan.isSampleable()); + assertThat(newSpan.getSpanPurpose()).isEqualTo(SpanPurpose.SERVER); + } + + @Test + public void doFilterInternal_should_use_user_id_from_parent_span_info_if_present_in_request_headers() throws ServletException, IOException { + // given: filter and request that has parent span info + RequestTracingFilter spyFilter = spy(getBasicFilter()); + given(requestMock.getHeader(ALT_USER_ID_HEADER_KEY)).willReturn("testUserId"); + + Span parentSpan = Span.newBuilder("someParentSpan", null).withParentSpanId(TraceAndSpanIdGenerator.generateId()).withSampleable(false).withUserId("someUser").build(); + given(requestMock.getHeader(TraceHeaders.TRACE_ID)).willReturn(parentSpan.getTraceId()); + given(requestMock.getHeader(TraceHeaders.SPAN_ID)).willReturn(parentSpan.getSpanId()); + given(requestMock.getHeader(TraceHeaders.PARENT_SPAN_ID)).willReturn(parentSpan.getParentSpanId()); + given(requestMock.getHeader(TraceHeaders.SPAN_NAME)).willReturn(parentSpan.getSpanName()); + given(requestMock.getHeader(TraceHeaders.TRACE_SAMPLED)).willReturn(String.valueOf(parentSpan.isSampleable())); + given(requestMock.getServletPath()).willReturn("/some/path"); + given(requestMock.getMethod()).willReturn("GET"); + + // when: doFilterInternal is called + spyFilter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); + + // then: the span that is created should use the parent span info as its parent + assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); + Span newSpan = spanCapturingFilterChain.capturedSpan; + + assertThat(newSpan.getUserId()).isEqualTo("testUserId"); + + } + @DataProvider(value = { "true | true", "true | false", @@ -132,12 +556,9 @@ public void doFilterInternal_should_reset_tracing_info_to_whatever_was_on_the_th TracingState originalTracingState = TracingState.getCurrentThreadTracingState(); // when - Throwable ex = catchThrowable(new ThrowableAssert.ThrowingCallable() { - @Override - public void call() throws Throwable { - filter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - } - }); + Throwable ex = catchThrowable( + () -> filter.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain) + ); // then if (throwExceptionInInnerFinallyBlock) { @@ -153,11 +574,12 @@ public void call() throws Throwable { } @Test - public void doFilterInternal_should_add_async_listener_but_not_complete_span_when_async_request_is_detected( + public void doFilterInternal_should_call_setupTracingCompletionWhenAsyncRequestCompletes_when_isAsyncRequest_returns_true( ) throws ServletException, IOException { // given RequestTracingFilter filterSpy = spy(getBasicFilter()); setupAsyncContextWorkflow(); + doReturn(true).when(filterSpy).isAsyncRequest(any(HttpServletRequest.class)); // when filterSpy.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); @@ -165,18 +587,15 @@ public void doFilterInternal_should_add_async_listener_but_not_complete_span_whe // then assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); assertThat(spanCapturingFilterChain.capturedSpan.isCompleted()).isFalse(); - assertThat(capturedAsyncListeners).hasSize(1); - assertThat(capturedAsyncListeners.get(0)).isInstanceOf(WingtipsRequestSpanCompletionAsyncListener.class); verify(filterSpy).setupTracingCompletionWhenAsyncRequestCompletes(eq(requestMock), any(TracingState.class)); } @Test - public void doFilterInternal_should_not_add_async_listener_when_isAsyncRequest_returns_false( + public void doFilterInternal_should_not_call_setupTracingCompletionWhenAsyncRequestCompletes_when_isAsyncRequest_returns_false( ) throws ServletException, IOException { // given RequestTracingFilter filterSpy = spy(getBasicFilter()); doReturn(false).when(filterSpy).isAsyncRequest(any(HttpServletRequest.class)); - setupAsyncContextWorkflow(); // when filterSpy.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); @@ -184,178 +603,158 @@ public void doFilterInternal_should_not_add_async_listener_when_isAsyncRequest_r // then assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); assertThat(spanCapturingFilterChain.capturedSpan.isCompleted()).isTrue(); - assertThat(capturedAsyncListeners).hasSize(0); verify(filterSpy, never()).setupTracingCompletionWhenAsyncRequestCompletes( any(HttpServletRequest.class), any(TracingState.class) ); } - private static class SpanCapturingFilterChain implements FilterChain { + @Test + public void doFilterInternal_should_add_async_listener_but_not_complete_span_when_async_request_is_detected( + ) throws ServletException, IOException { + // given + RequestTracingFilter filterSpy = spy(getBasicFilter()); + setupAsyncContextWorkflow(); - public Span capturedSpan; + // when + filterSpy.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); - @Override - public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException { - capturedSpan = Tracer.getInstance().getCurrentSpan(); - } + // then + assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); + assertThat(spanCapturingFilterChain.capturedSpan.isCompleted()).isFalse(); + assertThat(capturedAsyncListeners).hasSize(1); + assertThat(capturedAsyncListeners.get(0)).isInstanceOf(WingtipsRequestSpanCompletionAsyncListener.class); + verify(filterSpy).setupTracingCompletionWhenAsyncRequestCompletes(eq(requestMock), any(TracingState.class)); } - // VERIFY isAsyncDispatch =========================== - - @DataProvider(value = { - "FORWARD | false", - "INCLUDE | false", - "REQUEST | false", - "ASYNC | true", - "ERROR | false" - }, splitBy = "\\|") @Test - public void isAsyncDispatch_returns_result_based_on_request_dispatcher_type( - DispatcherType dispatcherType, boolean expectedResult - ) { + public void doFilterInternal_should_not_add_async_listener_when_isAsyncRequest_returns_false( + ) throws ServletException, IOException { // given - doReturn(dispatcherType).when(requestMock).getDispatcherType(); - RequestTracingFilter filter = getBasicFilter(); + RequestTracingFilter filterSpy = spy(getBasicFilter()); + doReturn(false).when(filterSpy).isAsyncRequest(any(HttpServletRequest.class)); + setupAsyncContextWorkflow(); // when - boolean result = filter.isAsyncDispatch(requestMock); + filterSpy.doFilterInternal(requestMock, responseMock, spanCapturingFilterChain); // then - assertThat(result).isEqualTo(expectedResult); + assertThat(spanCapturingFilterChain.capturedSpan).isNotNull(); + assertThat(spanCapturingFilterChain.capturedSpan.isCompleted()).isTrue(); + assertThat(capturedAsyncListeners).hasSize(0); + verify(filterSpy, never()).setupTracingCompletionWhenAsyncRequestCompletes( + any(HttpServletRequest.class), any(TracingState.class) + ); } - // VERIFY isAsyncRequest ============================== + // VERIFY getServletRuntime ========================= - @DataProvider(value = { - "true | true | true", - "true | false | false", - "false | true | false", - "false | false | false", - }, splitBy = "\\|") @Test - public void isAsyncRequest_should_return_the_value_of_request_isAsyncStarted_unless_containerSupportsAsyncRequests_is_false( - boolean containerSupportsAsyncRequests, boolean isAsyncStarted, boolean expectedResult - ) { + public void getServletRuntime_returns_value_of_ServletRuntime_determineServletRuntime_method_and_caches_result() { // given - RequestTracingFilter filterSpy = spy(getBasicFilter()); - doReturn(isAsyncStarted).when(requestMock).isAsyncStarted(); - doReturn(containerSupportsAsyncRequests).when(filterSpy).containerSupportsAsyncRequests(requestMock); + Class expectedServletRuntimeClass = + ServletRuntime.determineServletRuntime(requestMock.getClass(), ASYNC_LISTENER_CLASSNAME).getClass(); + + RequestTracingFilter filter = getBasicFilter(); + assertThat(filter.servletRuntime).isNull(); // when - boolean result = filterSpy.isAsyncRequest(requestMock); + ServletRuntime result = filter.getServletRuntime(requestMock); // then - assertThat(result).isEqualTo(expectedResult); - if (containerSupportsAsyncRequests) { - verify(requestMock).isAsyncStarted(); - } - else { - verify(requestMock, never()).isAsyncStarted(); - } - verify(filterSpy).containerSupportsAsyncRequests(requestMock); - verify(filterSpy).isAsyncRequest(requestMock); - verifyNoMoreInteractions(filterSpy); + assertThat(result.getClass()).isEqualTo(expectedServletRuntimeClass); + assertThat(filter.servletRuntime).isSameAs(result); } - // VERIFY setupTracingCompletionWhenAsyncRequestCompletes ============ - @Test - public void setupTracingCompletionWhenAsyncRequestCompletes_should_add_WingtipsRequestSpanCompletionAsyncListener( - ) throws ServletException, IOException { + public void getServletRuntime_uses_cached_value_if_possible() { // given - RequestTracingFilter filter = getBasicFilter(); - setupAsyncContextWorkflow(); - TracingState tracingStateMock = mock(TracingState.class); + RequestTracingFilter filterSpy = spy(getBasicFilter()); + ServletRuntime servletRuntimeMock = mock(ServletRuntime.class); + filterSpy.servletRuntime = servletRuntimeMock; // when - filter.setupTracingCompletionWhenAsyncRequestCompletes(requestMock, tracingStateMock); + ServletRuntime result = filterSpy.getServletRuntime(mock(HttpServletRequest.class)); // then - assertThat(capturedAsyncListeners).hasSize(1); - assertThat(capturedAsyncListeners.get(0)).isInstanceOf(WingtipsRequestSpanCompletionAsyncListener.class); - WingtipsRequestSpanCompletionAsyncListener listener = - (WingtipsRequestSpanCompletionAsyncListener)capturedAsyncListeners.get(0); - assertThat(listener.originalRequestTracingState).isSameAs(tracingStateMock); + assertThat(result).isSameAs(servletRuntimeMock); } - // VERIFY supportsAsyncRequests ==================== + // VERIFY isAsyncRequest ============================== @DataProvider(value = { "true", "false" - }) + }, splitBy = "\\|") @Test - public void supportsAsyncRequests_returns_true_if_class_contains_getAsyncContext_otherwise_false( - boolean useClassWithExpectedMethod - ) { + public void isAsyncRequest_delegates_to_ServletRuntime(boolean servletRuntimeResult) { // given - Class servletRequestClass = (useClassWithExpectedMethod) - ? GoodFakeServletRequest.class - : BadFakeServletRequest.class; - RequestTracingFilter filter = getBasicFilter(); + RequestTracingFilter filterSpy = spy(getBasicFilter()); + doReturn(servletRuntimeMock).when(filterSpy).getServletRuntime(any(HttpServletRequest.class)); + doReturn(servletRuntimeResult).when(servletRuntimeMock).isAsyncRequest(any(HttpServletRequest.class)); // when - boolean result = filter.supportsAsyncRequests(servletRequestClass); + boolean result = filterSpy.isAsyncRequest(requestMock); // then - assertThat(result).isEqualTo(useClassWithExpectedMethod); + assertThat(result).isEqualTo(servletRuntimeResult); + verify(filterSpy).getServletRuntime(requestMock); + verify(servletRuntimeMock).isAsyncRequest(requestMock); } - /** - * Dummy class that has a good getAsyncContext function - */ - private static final class GoodFakeServletRequest { - public Object getAsyncContext() { - return Boolean.TRUE; - } - } - - /** - * Dummy class that does NOT have a getAsyncContext function - */ - private static final class BadFakeServletRequest { - } - - // VERIFY containerSupportsAsyncRequests ==================== + // VERIFY setupTracingCompletionWhenAsyncRequestCompletes ============ - @DataProvider(value = { - "true", - "false" - }) @Test - public void containerSupportsAsyncRequests_returns_value_of_supportsAsyncRequests_method_and_caches_result( - boolean supportsAsyncRequestsValue - ) { + public void setupTracingCompletionWhenAsyncRequestCompletes_delegates_to_ServletRuntime() { // given RequestTracingFilter filterSpy = spy(getBasicFilter()); - doReturn(supportsAsyncRequestsValue).when(filterSpy).supportsAsyncRequests(any(Class.class)); - assertThat(filterSpy.containerSupportsAsyncRequests).isNull(); + doReturn(servletRuntimeMock).when(filterSpy).getServletRuntime(any(HttpServletRequest.class)); + TracingState tracingStateMock = mock(TracingState.class); // when - boolean result = filterSpy.containerSupportsAsyncRequests(requestMock); + filterSpy.setupTracingCompletionWhenAsyncRequestCompletes(requestMock, tracingStateMock); // then - assertThat(result).isEqualTo(supportsAsyncRequestsValue); - verify(filterSpy).supportsAsyncRequests(requestMock.getClass()); - assertThat(filterSpy.containerSupportsAsyncRequests).isEqualTo(result); - } + verify(filterSpy).setupTracingCompletionWhenAsyncRequestCompletes(requestMock, tracingStateMock); + verify(filterSpy).getServletRuntime(requestMock); + verify(servletRuntimeMock).setupTracingCompletionWhenAsyncRequestCompletes(requestMock, tracingStateMock); + verifyNoMoreInteractions(filterSpy, servletRuntimeMock, requestMock, tracingStateMock); + } + + // VERIFY isAsyncDispatch =========================== @DataProvider(value = { "true", "false" }) @Test - public void containerSupportsAsyncRequests_uses_cached_value_if_possible(boolean cachedValue) { + @SuppressWarnings("deprecation") + public void isAsyncDispatch_delegates_to_ServletRuntime(boolean servletRuntimeResult) { // given RequestTracingFilter filterSpy = spy(getBasicFilter()); - filterSpy.containerSupportsAsyncRequests = cachedValue; + doReturn(servletRuntimeMock).when(filterSpy).getServletRuntime(any(HttpServletRequest.class)); + doReturn(servletRuntimeResult).when(servletRuntimeMock).isAsyncDispatch(any(HttpServletRequest.class)); // when - boolean result = filterSpy.containerSupportsAsyncRequests(mock(ServletRequest.class)); + boolean result = filterSpy.isAsyncDispatch(requestMock); // then - assertThat(result).isEqualTo(cachedValue); - verify(filterSpy, never()).supportsAsyncRequests(any(Class.class)); + assertThat(result).isEqualTo(servletRuntimeResult); + verify(filterSpy).getServletRuntime(requestMock); + verify(servletRuntimeMock).isAsyncDispatch(requestMock); + } + + // VERIFY skipDispatch ============================== + + @Test + public void skipDispatch_should_return_false() { + // given: filter + RequestTracingFilter filter = getBasicFilter(); + + // when: skipDispatchIsCalled + boolean result = filter.skipDispatch(requestMock); + + // then: the result should be false + assertThat(result).isFalse(); } } diff --git a/wingtips-old-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestWithHeadersServletAdapterTest.java b/wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestWithHeadersServletAdapterTest.java similarity index 100% rename from wingtips-old-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestWithHeadersServletAdapterTest.java rename to wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/RequestWithHeadersServletAdapterTest.java diff --git a/wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/ServletRuntimeTest.java b/wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/ServletRuntimeTest.java new file mode 100644 index 00000000..fed8f75e --- /dev/null +++ b/wingtips-servlet-api/src/test/java/com/nike/wingtips/servlet/ServletRuntimeTest.java @@ -0,0 +1,209 @@ +package com.nike.wingtips.servlet; + +import com.nike.wingtips.servlet.ServletRuntime.Servlet2Runtime; +import com.nike.wingtips.servlet.ServletRuntime.Servlet3Runtime; +import com.nike.wingtips.util.TracingState; + +import com.tngtech.java.junit.dataprovider.DataProvider; +import com.tngtech.java.junit.dataprovider.DataProviderRunner; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.util.List; +import java.util.UUID; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncListener; +import javax.servlet.DispatcherType; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.catchThrowable; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests the functionality of {@link ServletRuntime}. + * + * @author Nic Munroe + */ +@RunWith(DataProviderRunner.class) +public class ServletRuntimeTest { + + private Servlet2Runtime servlet2Runtime; + private Servlet3Runtime servlet3Runtime; + + private HttpServletRequest requestMock; + + @Before + public void beforeMethod() { + servlet2Runtime = new Servlet2Runtime(); + servlet3Runtime = new Servlet3Runtime(); + + requestMock = mock(HttpServletRequest.class); + } + + @DataProvider(value = { + "true | true | true", + "true | false | false", + "false | true | false", + "false | false | false" + }, splitBy = "\\|") + @Test + public void determineServletRuntime_returns_ServletRuntime_based_on_arguments( + boolean classHasServlet3Method, boolean useAsyncListenerClassThatExists, boolean expectServlet3Runtime + ) { + // given + Class servletRequestClass = (classHasServlet3Method) + ? GoodFakeServletRequest.class + : BadFakeServletRequest.class; + + String asyncListenerClassname = (useAsyncListenerClassThatExists) + ? AsyncListener.class.getName() + : "does.not.exist.AsyncListener" + UUID.randomUUID().toString(); + + // when + ServletRuntime result = ServletRuntime.determineServletRuntime(servletRequestClass, asyncListenerClassname); + + // then + if (expectServlet3Runtime) { + assertThat(result).isInstanceOf(Servlet3Runtime.class); + } + else { + assertThat(result).isInstanceOf(Servlet2Runtime.class); + } + } + + /** + * Dummy class that has a good getAsyncContext function + */ + private static final class GoodFakeServletRequest { + public Object getAsyncContext() { + return Boolean.TRUE; + } + } + + /** + * Dummy class that does NOT have a getAsyncContext function + */ + private static final class BadFakeServletRequest { + } + + // Servlet2Runtime tests ======================================= + + @Test + public void servlet2_isAsyncRequest_should_return_false() { + // given + Servlet2Runtime implSpy = spy(servlet2Runtime); + + // when + boolean result = implSpy.isAsyncRequest(requestMock); + + // then + assertThat(result).isFalse(); + verify(implSpy).isAsyncRequest(requestMock); + verifyNoMoreInteractions(implSpy); + } + + @Test + public void servlet2_setupTracingCompletionWhenAsyncRequestCompletes_should_throw_IllegalStateException( + ) throws ServletException, IOException { + // when + Throwable ex = catchThrowable( + () -> servlet2Runtime.setupTracingCompletionWhenAsyncRequestCompletes(requestMock, mock(TracingState.class)) + ); + + // then + assertThat(ex) + .isInstanceOf(IllegalStateException.class) + .hasMessage("This method should never be called in a pre-Servlet-3.0 environment."); + } + + @Test + public void servlet2_isAsyncDispatch_should_return_false() { + // given + Servlet2Runtime implSpy = spy(servlet2Runtime); + + // when + boolean result = implSpy.isAsyncDispatch(requestMock); + + // then + assertThat(result).isFalse(); + verify(implSpy).isAsyncDispatch(requestMock); + verifyNoMoreInteractions(implSpy); + } + + // Servlet3Runtime tests ======================================= + + @DataProvider(value = { + "true", + "false" + }, splitBy = "\\|") + @Test + public void isAsyncRequest_should_return_the_value_of_request_isAsyncStarted(boolean requestIsAsyncStarted) { + // given + Servlet3Runtime implSpy = spy(servlet3Runtime); + doReturn(requestIsAsyncStarted).when(requestMock).isAsyncStarted(); + + // when + boolean result = implSpy.isAsyncRequest(requestMock); + + // then + assertThat(result).isEqualTo(requestIsAsyncStarted); + verify(requestMock).isAsyncStarted(); + verify(implSpy).isAsyncRequest(requestMock); + verifyNoMoreInteractions(implSpy); + } + + @Test + public void setupTracingCompletionWhenAsyncRequestCompletes_should_add_WingtipsRequestSpanCompletionAsyncListener( + ) throws ServletException, IOException { + // given + AsyncContext asyncContextMock = mock(AsyncContext.class); + doReturn(asyncContextMock).when(requestMock).getAsyncContext(); + TracingState tracingStateMock = mock(TracingState.class); + + ArgumentCaptor listenerCaptor = ArgumentCaptor.forClass(AsyncListener.class); + + // when + servlet3Runtime.setupTracingCompletionWhenAsyncRequestCompletes(requestMock, tracingStateMock); + + // then + verify(asyncContextMock).addListener(listenerCaptor.capture()); + List addedListeners = listenerCaptor.getAllValues(); + assertThat(addedListeners).hasSize(1); + assertThat(addedListeners.get(0)).isInstanceOf(WingtipsRequestSpanCompletionAsyncListener.class); + WingtipsRequestSpanCompletionAsyncListener listener = + (WingtipsRequestSpanCompletionAsyncListener)addedListeners.get(0); + assertThat(listener.originalRequestTracingState).isSameAs(tracingStateMock); + } + + @DataProvider(value = { + "FORWARD | false", + "INCLUDE | false", + "REQUEST | false", + "ASYNC | true", + "ERROR | false" + }, splitBy = "\\|") + @Test + public void servlet3_isAsyncDispatch_returns_result_based_on_request_dispatcher_type( + DispatcherType dispatcherType, boolean expectedResult + ) { + // given + doReturn(dispatcherType).when(requestMock).getDispatcherType(); + + // when + boolean result = servlet3Runtime.isAsyncDispatch(requestMock); + + // then + assertThat(result).isEqualTo(expectedResult); + } +} \ No newline at end of file