From 891764bb894bd4ba677d94966837ded95f50cb4f Mon Sep 17 00:00:00 2001 From: Povilas Versockas Date: Wed, 20 Sep 2023 13:21:48 +0300 Subject: [PATCH] add trace id to sqs SendMessage message attributes (#3) --- .../instrumentation/aws_lambda/__init__.py | 42 ++++++++++- .../instrumentation/botocore/__init__.py | 75 +++++++++++++------ 2 files changed, 93 insertions(+), 24 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py b/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py index afe0fdcb7a..e7ea681abd 100644 --- a/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py @@ -87,18 +87,22 @@ def custom_event_context_extractor(lambda_event): TRACE_HEADER_KEY, AwsXRayPropagator, ) +from opentelemetry.propagators import textmap from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import ( Span, SpanKind, + Link, TracerProvider, + get_current_span, get_tracer, get_tracer_provider, set_span_in_context ) from opentelemetry.trace.propagation import get_current_span import json +import typing #import traceback #import tracemalloc @@ -420,8 +424,15 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches if lambda_event["Records"][0]["eventSource"] in { "aws:sqs", }: + links = [] + for record in lambda_event["Records"]: + attributes = record.get("messageAttributes") + if attributes is not None: + ctx = get_global_textmap().extract(carrier=attributes, getter=SQSGetter()) + links.append(Link(get_current_span(ctx).get_span_context())) + span_name = orig_handler_name - sqsTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER) + sqsTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER, links=links) sqsTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "pubsub") sqsTriggerSpan.set_attribute("faas.trigger.type", "SQS") @@ -431,7 +442,8 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches sqsTriggerSpan.set_attribute( "rpc.request.body", lambda_event["Records"][0].get("body"), - ) + ) + except Exception as ex: pass @@ -786,3 +798,29 @@ def _uninstrument(self, **kwargs): import_module(self._wrapped_module_name), self._wrapped_function_name, ) + + +class SQSGetter(): + def get( + self, carrier: typing.Mapping[str, textmap.CarrierValT], key: str + ) -> typing.Optional[typing.List[str]]: + """Getter implementation to retrieve a value from a dictionary. + + Args: + carrier: dictionary in which to get value + key: the key used to get the value + Returns: + A list with a single string with the value if it exists, else None. + """ + val = carrier.get(key, None) + if val is None: + return None + if val.get("stringValue") is not None: + return [val.get("stringValue")] + return None + + def keys( + self, carrier: typing.Mapping[str, textmap.CarrierValT] + ) -> typing.List[str]: + """Keys implementation that returns all keys from a dictionary.""" + return list(carrier.keys()) diff --git a/instrumentation/opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore/__init__.py b/instrumentation/opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore/__init__.py index 24a28b11ea..f970a501b4 100644 --- a/instrumentation/opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-botocore/src/opentelemetry/instrumentation/botocore/__init__.py @@ -105,13 +105,13 @@ def response_hook(span, service_name, operation_name, result): _SUPPRESS_INSTRUMENTATION_KEY, unwrap, ) -from opentelemetry.propagators.aws.aws_xray_propagator import AwsXRayPropagator +from opentelemetry.propagate import inject +from opentelemetry.propagators import textmap from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import get_tracer from opentelemetry.trace.span import Span -import copy import base64 -import traceback +import typing logger = logging.getLogger(__name__) @@ -148,7 +148,7 @@ def _instrument(self, **kwargs): self.request_hook = kwargs.get("request_hook") self.response_hook = kwargs.get("response_hook") try: - self.payload_size_limit = int(os.environ.get("OTEL_PAYLOAD_SIZE_LIMIT", 204800)) + self.payload_size_limit = int(os.environ.get("OTEL_PAYLOAD_SIZE_LIMIT", 51200)) except ValueError: logger.error( "OTEL_PAYLOAD_SIZE_LIMIT is not a number" @@ -200,11 +200,6 @@ def _patched_api_call(self, original_func, instance, args, kwargs): if call_context is None: return original_func(*args, **kwargs) - #print("parsing context") - #print(call_context.service) - #print(call_context.operation) - #print(args[1].get("ClientContext")) - extension = _find_extension(call_context) if not extension.should_trace_service_call(): return original_func(*args, **kwargs) @@ -223,21 +218,21 @@ def _patched_api_call(self, original_func, instance, args, kwargs): elif call_context.operation == "PutObject": body = call_context.params.get("Body") if body is not None: - attributes["rpc.request.payload"] = body.decode('ascii') + attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, body.decode('ascii')) elif call_context.operation == "PutItem": body = call_context.params.get("Item") if body is not None: - attributes["rpc.request.payload"] = json.dumps(body, default=str) + attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, json.dumps(body, default=str)) elif call_context.operation == "GetItem": body = call_context.params.get("Key") if body is not None: - attributes["rpc.request.payload"] = json.dumps(body, default=str) + attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit,json.dumps(body, default=str)) elif call_context.operation == "Publish": body = call_context.params.get("Message") if body is not None: - attributes["rpc.request.payload"] = json.dumps(body, default=str) + attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit,json.dumps(body, default=str)) else: - attributes["rpc.request.payload"] = json.dumps(call_context.params, default=str) + attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, json.dumps(call_context.params, default=str)) except Exception as ex: pass @@ -266,19 +261,34 @@ def _patched_api_call(self, original_func, instance, args, kwargs): jctx = json.dumps(ctx) args[1]['ClientContext'] = base64.b64encode(jctx.encode('ascii')).decode('ascii') else: - #ctx = {'custom': {'traceContext':{}}} - #inject(ctx['custom']['traceContext']) ctx = {'custom': {}} inject(ctx['custom']) jctx = json.dumps(ctx) args[1]['ClientContext'] = base64.b64encode(jctx.encode('ascii')).decode('ascii') except Exception as ex: - #print(traceback.format_exc()) - #print("exception") - #print(ex) pass + try: + if call_context.service == "sqs" and call_context.operation == "SendMessage": + if args[1].get("MessageAttributes") is not None: + inject(carrier = args[1].get("MessageAttributes"), setter=SQSSetter()) + else: + args[1]['MessageAttributes'] = {} + inject(carrier = args[1].get("MessageAttributes"), setter=SQSSetter()) + + if call_context.service == "sqs" and call_context.operation == "SendMessageBatch": + if args[1].get("Entries") is not None: + for entry in args[1].get("Entries"): + if entry.get("MessageAttributes") is not None: + inject(carrier = entry.get("MessageAttributes"), setter=SQSSetter()) + else: + entry['MessageAttributes'] = {} + inject(carrier = entry.get("MessageAttributes"), setter=SQSSetter()) + + except Exception as ex: + pass + result = None try: #print("calling original func") @@ -405,9 +415,6 @@ def _apply_response_attributes(span: Span, result, payload_size_limit): span.set_attribute( "rpc.response.payload", json.dumps(result, default=str)) except Exception as ex: - #print(traceback.format_exc()) - #print("exception") - #print(ex) pass @@ -441,3 +448,27 @@ def _safe_invoke(function: Callable, *args): logger.error( "Error when invoking function '%s'", function_name, exc_info=ex ) + +class SQSSetter(): + def set( + self, + carrier: typing.MutableMapping[str, textmap.CarrierValT], + key: str, + value: textmap.CarrierValT, + ) -> None: + """Setter implementation to set a value into a dictionary. + + Args: + carrier: dictionary in which to set value + key: the key used to set the value + value: the value to set + """ + val = {"DataType": "String", "StringValue": value} + carrier[key] = val + +def limit_string_size(s: str, max_size: int) -> str: + if len(s) > max_size: + return s[:max_size] + else: + return s +