Skip to content

Commit

Permalink
add trace id to sqs SendMessage message attributes (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
povilasv committed Sep 25, 2023
1 parent 8be8a50 commit 891764b
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,22 @@ def custom_event_context_extractor(lambda_event):
TRACE_HEADER_KEY,
AwsXRayPropagator,
)
from opentelemetry.propagators import textmap
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import (
Span,
SpanKind,
Link,
TracerProvider,
get_current_span,
get_tracer,
get_tracer_provider,
set_span_in_context
)
from opentelemetry.trace.propagation import get_current_span
import json
import typing
#import traceback
#import tracemalloc

Expand Down Expand Up @@ -420,8 +424,15 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
if lambda_event["Records"][0]["eventSource"] in {
"aws:sqs",
}:
links = []
for record in lambda_event["Records"]:
attributes = record.get("messageAttributes")
if attributes is not None:
ctx = get_global_textmap().extract(carrier=attributes, getter=SQSGetter())
links.append(Link(get_current_span(ctx).get_span_context()))

span_name = orig_handler_name
sqsTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER)
sqsTriggerSpan = tracer.start_span(span_name, context=parent_context, kind=SpanKind.PRODUCER, links=links)
sqsTriggerSpan.set_attribute(SpanAttributes.FAAS_TRIGGER, "pubsub")
sqsTriggerSpan.set_attribute("faas.trigger.type", "SQS")

Expand All @@ -431,7 +442,8 @@ def _instrumented_lambda_handler_call( # noqa pylint: disable=too-many-branches
sqsTriggerSpan.set_attribute(
"rpc.request.body",
lambda_event["Records"][0].get("body"),
)
)

except Exception as ex:
pass

Expand Down Expand Up @@ -786,3 +798,29 @@ def _uninstrument(self, **kwargs):
import_module(self._wrapped_module_name),
self._wrapped_function_name,
)


class SQSGetter():
def get(
self, carrier: typing.Mapping[str, textmap.CarrierValT], key: str
) -> typing.Optional[typing.List[str]]:
"""Getter implementation to retrieve a value from a dictionary.
Args:
carrier: dictionary in which to get value
key: the key used to get the value
Returns:
A list with a single string with the value if it exists, else None.
"""
val = carrier.get(key, None)
if val is None:
return None
if val.get("stringValue") is not None:
return [val.get("stringValue")]
return None

def keys(
self, carrier: typing.Mapping[str, textmap.CarrierValT]
) -> typing.List[str]:
"""Keys implementation that returns all keys from a dictionary."""
return list(carrier.keys())
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ def response_hook(span, service_name, operation_name, result):
_SUPPRESS_INSTRUMENTATION_KEY,
unwrap,
)
from opentelemetry.propagators.aws.aws_xray_propagator import AwsXRayPropagator
from opentelemetry.propagate import inject
from opentelemetry.propagators import textmap
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import get_tracer
from opentelemetry.trace.span import Span
import copy
import base64
import traceback
import typing
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -148,7 +148,7 @@ def _instrument(self, **kwargs):
self.request_hook = kwargs.get("request_hook")
self.response_hook = kwargs.get("response_hook")
try:
self.payload_size_limit = int(os.environ.get("OTEL_PAYLOAD_SIZE_LIMIT", 204800))
self.payload_size_limit = int(os.environ.get("OTEL_PAYLOAD_SIZE_LIMIT", 51200))
except ValueError:
logger.error(
"OTEL_PAYLOAD_SIZE_LIMIT is not a number"
Expand Down Expand Up @@ -200,11 +200,6 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
if call_context is None:
return original_func(*args, **kwargs)

#print("parsing context")
#print(call_context.service)
#print(call_context.operation)
#print(args[1].get("ClientContext"))

extension = _find_extension(call_context)
if not extension.should_trace_service_call():
return original_func(*args, **kwargs)
Expand All @@ -223,21 +218,21 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
elif call_context.operation == "PutObject":
body = call_context.params.get("Body")
if body is not None:
attributes["rpc.request.payload"] = body.decode('ascii')
attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, body.decode('ascii'))
elif call_context.operation == "PutItem":
body = call_context.params.get("Item")
if body is not None:
attributes["rpc.request.payload"] = json.dumps(body, default=str)
attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, json.dumps(body, default=str))
elif call_context.operation == "GetItem":
body = call_context.params.get("Key")
if body is not None:
attributes["rpc.request.payload"] = json.dumps(body, default=str)
attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit,json.dumps(body, default=str))
elif call_context.operation == "Publish":
body = call_context.params.get("Message")
if body is not None:
attributes["rpc.request.payload"] = json.dumps(body, default=str)
attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit,json.dumps(body, default=str))
else:
attributes["rpc.request.payload"] = json.dumps(call_context.params, default=str)
attributes["rpc.request.payload"] = limit_string_size(self.payload_size_limit, json.dumps(call_context.params, default=str))
except Exception as ex:
pass

Expand Down Expand Up @@ -266,19 +261,34 @@ def _patched_api_call(self, original_func, instance, args, kwargs):
jctx = json.dumps(ctx)
args[1]['ClientContext'] = base64.b64encode(jctx.encode('ascii')).decode('ascii')
else:
#ctx = {'custom': {'traceContext':{}}}
#inject(ctx['custom']['traceContext'])
ctx = {'custom': {}}
inject(ctx['custom'])
jctx = json.dumps(ctx)
args[1]['ClientContext'] = base64.b64encode(jctx.encode('ascii')).decode('ascii')

except Exception as ex:
#print(traceback.format_exc())
#print("exception")
#print(ex)
pass

try:
if call_context.service == "sqs" and call_context.operation == "SendMessage":
if args[1].get("MessageAttributes") is not None:
inject(carrier = args[1].get("MessageAttributes"), setter=SQSSetter())
else:
args[1]['MessageAttributes'] = {}
inject(carrier = args[1].get("MessageAttributes"), setter=SQSSetter())

if call_context.service == "sqs" and call_context.operation == "SendMessageBatch":
if args[1].get("Entries") is not None:
for entry in args[1].get("Entries"):
if entry.get("MessageAttributes") is not None:
inject(carrier = entry.get("MessageAttributes"), setter=SQSSetter())
else:
entry['MessageAttributes'] = {}
inject(carrier = entry.get("MessageAttributes"), setter=SQSSetter())

except Exception as ex:
pass

result = None
try:
#print("calling original func")
Expand Down Expand Up @@ -405,9 +415,6 @@ def _apply_response_attributes(span: Span, result, payload_size_limit):
span.set_attribute(
"rpc.response.payload", json.dumps(result, default=str))
except Exception as ex:
#print(traceback.format_exc())
#print("exception")
#print(ex)
pass


Expand Down Expand Up @@ -441,3 +448,27 @@ def _safe_invoke(function: Callable, *args):
logger.error(
"Error when invoking function '%s'", function_name, exc_info=ex
)

class SQSSetter():
def set(
self,
carrier: typing.MutableMapping[str, textmap.CarrierValT],
key: str,
value: textmap.CarrierValT,
) -> None:
"""Setter implementation to set a value into a dictionary.
Args:
carrier: dictionary in which to set value
key: the key used to set the value
value: the value to set
"""
val = {"DataType": "String", "StringValue": value}
carrier[key] = val

def limit_string_size(s: str, max_size: int) -> str:
if len(s) > max_size:
return s[:max_size]
else:
return s

0 comments on commit 891764b

Please sign in to comment.