From 4610f429df4be440d00193a0479416962c01ff01 Mon Sep 17 00:00:00 2001
From: Lauri Tulmin <ltulmin@splunk.com>
Date: Mon, 6 May 2024 16:24:32 +0300
Subject: [PATCH] run callbacks in the context of the parent span

---
 .../v2_4/InfluxDbImplInstrumentation.java     | 13 ++++++
 .../influxdb/v2_4/InfluxDbObjetWrapper.java   | 44 +++++++++++++++++++
 .../influxdb/v2_4/InfluxDbClientTest.java     | 12 ++---
 3 files changed, 61 insertions(+), 8 deletions(-)
 create mode 100644 instrumentation/influxdb-2.4/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbObjetWrapper.java

diff --git a/instrumentation/influxdb-2.4/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbImplInstrumentation.java b/instrumentation/influxdb-2.4/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbImplInstrumentation.java
index 256202acc80a..2633e1d82433 100644
--- a/instrumentation/influxdb-2.4/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbImplInstrumentation.java
+++ b/instrumentation/influxdb-2.4/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbImplInstrumentation.java
@@ -24,6 +24,7 @@
 import io.opentelemetry.javaagent.extension.instrumentation.TypeTransformer;
 import net.bytebuddy.asm.Advice;
 import net.bytebuddy.description.type.TypeDescription;
+import net.bytebuddy.implementation.bytecode.assign.Assigner;
 import net.bytebuddy.matcher.ElementMatcher;
 import okhttp3.HttpUrl;
 import org.influxdb.dto.BatchPoints;
@@ -74,6 +75,7 @@ public static class InfluxDbQueryAdvice {
     @Advice.OnMethodEnter(suppress = Throwable.class)
     public static void onEnter(
         @Advice.Argument(0) Query query,
+        @Advice.AllArguments(readOnly = false, typing = Assigner.Typing.DYNAMIC) Object[] arguments,
         @Advice.FieldValue(value = "retrofit") Retrofit retrofit,
         @Advice.Local("otelCallDepth") CallDepth callDepth,
         @Advice.Local("otelRequest") InfluxDbRequest influxDbRequest,
@@ -98,6 +100,17 @@ public static void onEnter(
         return;
       }
 
+      // wrap callbacks so they'd run in the context of the parent span
+      Object[] newArguments = new Object[arguments.length];
+      boolean hasChangedArgument = false;
+      for (int i = 0; i < arguments.length; i++) {
+        newArguments[i] = InfluxDbObjetWrapper.wrap(arguments[i], parentContext);
+        hasChangedArgument |= newArguments[i] != arguments[i];
+      }
+      if (hasChangedArgument) {
+        arguments = newArguments;
+      }
+
       context = instrumenter().start(parentContext, influxDbRequest);
       scope = context.makeCurrent();
     }
diff --git a/instrumentation/influxdb-2.4/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbObjetWrapper.java b/instrumentation/influxdb-2.4/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbObjetWrapper.java
new file mode 100644
index 000000000000..7be6efe42aed
--- /dev/null
+++ b/instrumentation/influxdb-2.4/javaagent/src/main/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbObjetWrapper.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright The OpenTelemetry Authors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package io.opentelemetry.javaagent.instrumentation.influxdb.v2_4;
+
+import io.opentelemetry.context.Context;
+import io.opentelemetry.context.Scope;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+
+public final class InfluxDbObjetWrapper {
+
+  @SuppressWarnings("unchecked")
+  public static Object wrap(Object object, Context parentContext) {
+    if (object instanceof Consumer) {
+      return (Consumer<Object>)
+          o -> {
+            try (Scope ignore = parentContext.makeCurrent()) {
+              ((Consumer<Object>) object).accept(o);
+            }
+          };
+    } else if (object instanceof BiConsumer) {
+      return (BiConsumer<Object, Object>)
+          (o1, o2) -> {
+            try (Scope ignore = parentContext.makeCurrent()) {
+              ((BiConsumer<Object, Object>) object).accept(o1, o2);
+            }
+          };
+    } else if (object instanceof Runnable) {
+      return (Runnable)
+          () -> {
+            try (Scope ignore = parentContext.makeCurrent()) {
+              ((Runnable) object).run();
+            }
+          };
+    }
+
+    return object;
+  }
+
+  private InfluxDbObjetWrapper() {}
+}
diff --git a/instrumentation/influxdb-2.4/javaagent/src/test/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbClientTest.java b/instrumentation/influxdb-2.4/javaagent/src/test/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbClientTest.java
index 8c9c87fa44fb..d239606bfd37 100644
--- a/instrumentation/influxdb-2.4/javaagent/src/test/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbClientTest.java
+++ b/instrumentation/influxdb-2.4/javaagent/src/test/java/io/opentelemetry/javaagent/instrumentation/influxdb/v2_4/InfluxDbClientTest.java
@@ -211,12 +211,8 @@ void testQueryWithFiveArguments() throws InterruptedException {
           influxDb.query(
               query,
               10,
-              (cancellable, queryResult) -> {
-                countDownLatch.countDown();
-              },
-              () -> {
-                testing.runWithSpan("child", () -> {});
-              },
+              (cancellable, queryResult) -> countDownLatch.countDown(),
+              () -> testing.runWithSpan("child", () -> {}),
               throwable -> {});
         });
     assertThat(countDownLatch.await(10, TimeUnit.SECONDS)).isTrue();
@@ -235,7 +231,7 @@ void testQueryWithFiveArguments() throws InterruptedException {
                                 "SELECT",
                                 databaseName)),
                 span ->
-                    span.hasName("child").hasKind(SpanKind.INTERNAL).hasParent(trace.getSpan(1))));
+                    span.hasName("child").hasKind(SpanKind.INTERNAL).hasParent(trace.getSpan(0))));
   }
 
   @Test
@@ -269,7 +265,7 @@ void testQueryFailedWithFiveArguments() throws InterruptedException {
                             attributeAssertions(
                                 "SELECT MEAN(water_level) FROM;", "SELECT", databaseName)),
                 span ->
-                    span.hasName("child").hasKind(SpanKind.INTERNAL).hasParent(trace.getSpan(1))));
+                    span.hasName("child").hasKind(SpanKind.INTERNAL).hasParent(trace.getSpan(0))));
   }
 
   @Test