diff --git a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java index adf431c561e79..67662894ce907 100644 --- a/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java +++ b/server/src/main/java/org/elasticsearch/common/util/concurrent/ThreadContext.java @@ -19,8 +19,10 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Tuple; import org.elasticsearch.http.HttpTransportSettings; +import org.elasticsearch.tasks.Task; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -97,16 +99,17 @@ public ThreadContext(Settings settings) { /** * Removes the current context and resets a default context. The removed context can be * restored by closing the returned {@link StoredContext}. + * @return a stored context that will restore the current context to its state at the point this method was called */ public StoredContext stashContext() { final ThreadContextStruct context = threadLocal.get(); - /** + + /* * X-Opaque-ID should be preserved in a threadContext in order to propagate this across threads. * This is needed so the DeprecationLogger in another thread can see the value of X-Opaque-ID provided by a user. * The same is applied to Task.TRACE_ID. - * Otherwise when context is stash, it should be empty. + * Otherwise when context is stashed, it should be empty. */ - boolean hasHeadersToCopy = false; if (context.requestHeaders.isEmpty() == false) { for (String header : HEADERS_TO_COPY) { @@ -116,13 +119,22 @@ public StoredContext stashContext() { } } } + + boolean hasTransientHeadersToCopy = context.transientHeaders.containsKey(Task.APM_TRACE_CONTEXT); + + ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT; if (hasHeadersToCopy) { - Map map = headers(context); - ThreadContextStruct threadContextStruct = DEFAULT_CONTEXT.putHeaders(map); - threadLocal.set(threadContextStruct); - } else { - threadLocal.set(DEFAULT_CONTEXT); + Map copiedHeaders = getHeadersToCopy(context); + threadContextStruct = DEFAULT_CONTEXT.putHeaders(copiedHeaders); } + if (hasTransientHeadersToCopy) { + threadContextStruct = threadContextStruct.putTransient( + Task.APM_TRACE_CONTEXT, + context.transientHeaders.get(Task.APM_TRACE_CONTEXT) + ); + } + threadLocal.set(threadContextStruct); + return () -> { // If the node and thus the threadLocal get closed while this task // is still executing, we don't want this runnable to fail with an @@ -131,9 +143,98 @@ public StoredContext stashContext() { }; } - private static Map headers(ThreadContextStruct context) { - Map map = Maps.newMapWithExpectedSize(org.elasticsearch.tasks.Task.HEADERS_TO_COPY.size()); - for (String header : org.elasticsearch.tasks.Task.HEADERS_TO_COPY) { + /** + * When using a {@link org.elasticsearch.tracing.Tracer} to capture activity in Elasticsearch, when a parent span is already + * in progress, it is necessary to start a new context before beginning a child span. This method creates a context, + * moving tracing-related fields to different names so that a new child span can be started. This child span will pick up + * the moved fields and use them to establish the parent-child relationship. + * + * @return a stored context, which can be restored when this context is no longer needed. + */ + public StoredContext newTraceContext() { + final ThreadContextStruct originalContext = threadLocal.get(); + final Map newRequestHeaders = new HashMap<>(originalContext.requestHeaders); + final Map newTransientHeaders = new HashMap<>(originalContext.transientHeaders); + + final String previousTraceParent = newRequestHeaders.remove(Task.TRACE_PARENT_HTTP_HEADER); + if (previousTraceParent != null) { + newTransientHeaders.put("parent_" + Task.TRACE_PARENT_HTTP_HEADER, previousTraceParent); + } + + final String previousTraceState = newRequestHeaders.remove(Task.TRACE_STATE); + if (previousTraceState != null) { + newTransientHeaders.put("parent_" + Task.TRACE_STATE, previousTraceState); + } + + final Object previousTraceContext = newTransientHeaders.remove(Task.APM_TRACE_CONTEXT); + if (previousTraceContext != null) { + newTransientHeaders.put("parent_" + Task.APM_TRACE_CONTEXT, previousTraceContext); + } + + threadLocal.set( + new ThreadContextStruct( + newRequestHeaders, + originalContext.responseHeaders, + newTransientHeaders, + originalContext.isSystemContext, + originalContext.warningHeadersSize + ) + ); + // this is the context when this method returns + final ThreadContextStruct newContext = threadLocal.get(); + return () -> { + if (threadLocal.get() != newContext) { + // Tracing shouldn't interrupt the propagation of response headers, so in the same as #newStoredContext(...), + // pass on any potential changes to the response headers. + threadLocal.set(originalContext.putResponseHeaders(threadLocal.get().responseHeaders)); + } else { + threadLocal.set(originalContext); + } + }; + } + + public boolean hasTraceContext() { + final ThreadContextStruct context = threadLocal.get(); + return context.requestHeaders.containsKey(Task.TRACE_PARENT_HTTP_HEADER) + || context.requestHeaders.containsKey(Task.TRACE_STATE) + || context.transientHeaders.containsKey(Task.APM_TRACE_CONTEXT); + } + + /** + * When using a {@link org.elasticsearch.tracing.Tracer}, sometimes you need to start a span completely unrelated + * to any current span. In order to avoid any parent/child relationship being created, this method creates a new + * context that clears all the tracing fields. + * + * @return a stored context, which can be restored when this context is no longer needed. + */ + public StoredContext clearTraceContext() { + final ThreadContextStruct context = threadLocal.get(); + final Map newRequestHeaders = new HashMap<>(context.requestHeaders); + final Map newTransientHeaders = new HashMap<>(context.transientHeaders); + + newRequestHeaders.remove(Task.TRACE_PARENT_HTTP_HEADER); + newRequestHeaders.remove(Task.TRACE_STATE); + + newTransientHeaders.remove("parent_" + Task.TRACE_PARENT_HTTP_HEADER); + newTransientHeaders.remove("parent_" + Task.TRACE_STATE); + newTransientHeaders.remove(Task.APM_TRACE_CONTEXT); + newTransientHeaders.remove("parent_" + Task.APM_TRACE_CONTEXT); + + threadLocal.set( + new ThreadContextStruct( + newRequestHeaders, + context.responseHeaders, + newTransientHeaders, + context.isSystemContext, + context.warningHeadersSize + ) + ); + return () -> threadLocal.set(context); + } + + private static Map getHeadersToCopy(ThreadContextStruct context) { + Map map = Maps.newMapWithExpectedSize(HEADERS_TO_COPY.size()); + for (String header : HEADERS_TO_COPY) { final String value = context.requestHeaders.get(header); if (value != null) { map.put(header, value); @@ -476,10 +577,7 @@ public boolean isSystemContext() { } @FunctionalInterface - public interface StoredContext extends AutoCloseable { - @Override - void close(); - + public interface StoredContext extends AutoCloseable, Releasable { default void restore() { close(); } diff --git a/server/src/main/java/org/elasticsearch/tasks/Task.java b/server/src/main/java/org/elasticsearch/tasks/Task.java index d2b0a14a6c5a7..de281fb4ae54d 100644 --- a/server/src/main/java/org/elasticsearch/tasks/Task.java +++ b/server/src/main/java/org/elasticsearch/tasks/Task.java @@ -32,6 +32,7 @@ public class Task { * The request header which is contained in HTTP request. We parse trace.id from it and store it in thread context. * TRACE_PARENT once parsed in RestController.tryAllHandler is not preserved * has to be declared as a header copied over from http request. + * May also be used internally when APM is enabled. */ public static final String TRACE_PARENT_HTTP_HEADER = "traceparent"; @@ -43,6 +44,11 @@ public class Task { public static final String TRACE_STATE = "tracestate"; + /** + * Used internally to pass the apm trace context between the nodes + */ + public static final String APM_TRACE_CONTEXT = "apm.local.context"; + /** * Parsed part of traceparent. It is stored in thread context and emitted in logs. * Has to be declared as a header copied over for tasks. diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java index e6a166ca67f28..ff3ca13dc9ce9 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -124,6 +124,9 @@ public Task register(String type, String action, TaskAwareRequest request) { long headerSize = 0; long maxSize = maxHeaderSize.getBytes(); ThreadContext threadContext = threadPool.getThreadContext(); + + assert threadContext.hasTraceContext() == false : "Expected threadContext to have no traceContext fields"; + for (String key : taskHeaders) { String httpHeader = threadContext.getHeader(key); if (httpHeader != null) { @@ -175,43 +178,45 @@ public Task reg } else { unregisterChildNode = null; } - final Task task; - try { - task = register(type, action.actionName, request); - } catch (TaskCancelledException e) { - Releasables.close(unregisterChildNode); - throw e; - } - // NOTE: ActionListener cannot infer Response, see https://bugs.openjdk.java.net/browse/JDK-8203195 - action.execute(task, request, new ActionListener() { - @Override - public void onResponse(Response response) { - try { - release(); - } finally { - taskListener.onResponse(response); - } + + try (var ignored = threadPool.getThreadContext().newTraceContext()) { + final Task task; + try { + task = register(type, action.actionName, request); + } catch (TaskCancelledException e) { + Releasables.close(unregisterChildNode); + throw e; } + action.execute(task, request, new ActionListener<>() { + @Override + public void onResponse(Response response) { + try { + release(); + } finally { + taskListener.onResponse(response); + } + } - @Override - public void onFailure(Exception e) { - try { - release(); - } finally { - taskListener.onFailure(e); + @Override + public void onFailure(Exception e) { + try { + release(); + } finally { + taskListener.onFailure(e); + } } - } - @Override - public String toString() { - return this.getClass().getName() + "{" + taskListener + "}{" + task + "}"; - } + @Override + public String toString() { + return this.getClass().getName() + "{" + taskListener + "}{" + task + "}"; + } - private void release() { - Releasables.close(unregisterChildNode, () -> unregister(task)); - } - }); - return task; + private void release() { + Releasables.close(unregisterChildNode, () -> unregister(task)); + } + }); + return task; + } } private void registerCancellableTask(Task task) {