From 891764bb894bd4ba677d94966837ded95f50cb4f Mon Sep 17 00:00:00 2001
From: Povilas Versockas
Date: Wed, 20 Sep 2023 13:21:48 +0300
Subject: [PATCH] add trace id to sqs SendMessage message attributes (#3)
---
.../instrumentation/aws_lambda/__init__.py | 42 ++++++++++-
.../instrumentation/botocore/__init__.py | 75 +++++++++++++------
2 files changed, 93 insertions(+), 24 deletions(-)
diff --git a/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py b/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py
index afe0fdcb7a..e7ea681abd 100644
--- a/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py
+++ b/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py
@@ -87,18 +87,22 @@ def custom_event_context_extractor(lambda_event):
TRACE_HEADER_KEY,
AwsXRayPropagator,
)
+from opentelemetry.propagators import textmap
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import (
Span,
SpanKind,
+ Link,
TracerProvider,
+ get_current_span,
get_tracer,
get_tracer_provider,
set_span_in_context
)
from opentelemetry.trace.propagation import get_current_span
import json
+import typing
#import traceback
#import tracemalloc
@@ -420,8 +424,15 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
if lambda_event["Records"][0]["eventSource"] in {
"aws:sqs",
}:
+ links = []
+ for record in lambda_event["Records"]:
+ attributes = record.get("messageAttributes")
+ if attributes is not None:
+ ctx = get_global_textmap().extract(carrier=attributes, getter=SQSGetter())
+ links.append(Link(get_current_span(ctx).get_span_context()))
+
span_name = orig_handler_name
- sqsTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER)
+ sqsTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER, links=links)
sqsTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "pubsub")
sqsTriggerSpan.set_attribute("faas.trigger.type", "SQS")
@@ -431,7 +442,8 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
sqsTriggerSpan.set_attribute(
"rpc.request.body",
lambda_event["Records"][0].get("body"),
- )
+ )
+
except Exception as ex:
pass
@@ -786,3 +798,29 @@ def _uninstrument(self, **kwargs):
import_module(self._wrapped_module_name),
self._wrapped_function_name,
)
+
+
+class SQSGetter():
+ def get(
+ self, carrier: typing.Mapping[str, textmap.CarrierValT], key: str
+ ) -> typing.Optional[typing.List[str]]:
+ """Getter implementation to retrieve a value from a dictionary.
+
+ Args:
+ carrier: dictionary in which to get value
+ key: the key used to get the value
+ Returns:
+ A list with a single string with the value if it exists, else None.
+ """
+ val = carrier.get(key, None)
+ if val is None:
+ return None
+ if val.get("stringValue") is not None:
+ return [val.get("stringValue")]
+ return None
+
+ def keys(
+ self, carrier: typing.Mapping[str, textmap.CarrierValT]
+ ) -> typing.List[str]:
+ """Keys implementation that returns all keys from a dictionary."""
+ return list(carrier.keys())
diff --git a/instrumentation/opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore/__init__.py b/instrumentation/opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore/__init__.py
index 24a28b11ea..f970a501b4 100644
--- a/instrumentation/opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore/__init__.py
+++ b/instrumentation/opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore/__init__.py
@@ -105,13 +105,13 @@ def response_hook(span, service_name, operation_name, result):
_SUPPRESS_INSTRUMENTATION_KEY,
unwrap,
)
-from opentelemetry.propagators.aws.aws_xray_propagator import AwsXRayPropagator
+from opentelemetry.propagate import inject
+from opentelemetry.propagators import textmap
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import get_tracer
from opentelemetry.trace.span import Span
-import copy
import base64
-import traceback
+import typing
logger = logging.getLogger(__name__)
@@ -148,7 +148,7 @@ def _instrument(self, **kwargs):
self.request_hook = kwargs.get("request_hook")
self.response_hook = kwargs.get("response_hook")
try:
- self.payload_size_limit = int(os.environ.get("OTEL_PAYLOAD_SIZE_LIMIT", 204800))
+ self.payload_size_limit = int(os.environ.get("OTEL_PAYLOAD_SIZE_LIMIT", 51200))
except ValueError:
logger.error(
"OTEL_PAYLOAD_SIZE_LIMIT is not a number"
@@ -200,11 +200,6 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
if call_context is None:
return original_func(*args, **kwargs)
- #print("parsing context")
- #print(call_context.service)
- #print(call_context.operation)
- #print(args[1].get("ClientContext"))
-
extension = _find_extension(call_context)
if not extension.should_trace_service_call():
return original_func(*args, **kwargs)
@@ -223,21 +218,21 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
elif call_context.operation == "PutObject":
body = call_context.params.get("Body")
if body is not None:
- attributes["rpc.request.payload"] = body.decode('ascii')
+ attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, body.decode('ascii'))
elif call_context.operation == "PutItem":
body = call_context.params.get("Item")
if body is not None:
- attributes["rpc.request.payload"] = json.dumps(body, default=str)
+ attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, json.dumps(body, default=str))
elif call_context.operation == "GetItem":
body = call_context.params.get("Key")
if body is not None:
- attributes["rpc.request.payload"] = json.dumps(body, default=str)
+ attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit,json.dumps(body, default=str))
elif call_context.operation == "Publish":
body = call_context.params.get("Message")
if body is not None:
- attributes["rpc.request.payload"] = json.dumps(body, default=str)
+ attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit,json.dumps(body, default=str))
else:
- attributes["rpc.request.payload"] = json.dumps(call_context.params, default=str)
+ attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, json.dumps(call_context.params, default=str))
except Exception as ex:
pass
@@ -266,19 +261,34 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
jctx = json.dumps(ctx)
args[1]['ClientContext'] = base64.b64encode(jctx.encode('ascii')).decode('ascii')
else:
- #ctx = {'custom': {'traceContext':{}}}
- #inject(ctx['custom']['traceContext'])
ctx = {'custom': {}}
inject(ctx['custom'])
jctx = json.dumps(ctx)
args[1]['ClientContext'] = base64.b64encode(jctx.encode('ascii')).decode('ascii')
except Exception as ex:
- #print(traceback.format_exc())
- #print("exception")
- #print(ex)
pass
+ try:
+ if call_context.service == "sqs" and call_context.operation == "SendMessage":
+ if args[1].get("MessageAttributes") is not None:
+ inject(carrier = args[1].get("MessageAttributes"), setter=SQSSetter())
+ else:
+ args[1]['MessageAttributes'] = {}
+ inject(carrier = args[1].get("MessageAttributes"), setter=SQSSetter())
+
+ if call_context.service == "sqs" and call_context.operation == "SendMessageBatch":
+ if args[1].get("Entries") is not None:
+ for entry in args[1].get("Entries"):
+ if entry.get("MessageAttributes") is not None:
+ inject(carrier = entry.get("MessageAttributes"), setter=SQSSetter())
+ else:
+ entry['MessageAttributes'] = {}
+ inject(carrier = entry.get("MessageAttributes"), setter=SQSSetter())
+
+ except Exception as ex:
+ pass
+
result = None
try:
#print("calling original func")
@@ -405,9 +415,6 @@ def _apply_response_attributes(span: Span, result, payload_size_limit):
span.set_attribute(
"rpc.response.payload", json.dumps(result, default=str))
except Exception as ex:
- #print(traceback.format_exc())
- #print("exception")
- #print(ex)
pass
@@ -441,3 +448,27 @@ def _safe_invoke(function: Callable, *args):
logger.error(
"Error when invoking function '%s'", function_name, exc_info=ex
)
+
+class SQSSetter():
+ def set(
+ self,
+ carrier: typing.MutableMapping[str, textmap.CarrierValT],
+ key: str,
+ value: textmap.CarrierValT,
+ ) -> None:
+ """Setter implementation to set a value into a dictionary.
+
+ Args:
+ carrier: dictionary in which to set value
+ key: the key used to set the value
+ value: the value to set
+ """
+ val = {"DataType": "String", "StringValue": value}
+ carrier[key] = val
+
+def limit_string_size(s: str, max_size: int) -> str:
+ if len(s) > max_size:
+ return s[:max_size]
+ else:
+ return s
+