Skip to content

Commit

Permalink
bugfix of contextmanager of generator
Browse files Browse the repository at this point in the history
  • Loading branch information
Corvin Lasogga committed Jul 22, 2022
1 parent 97d89d5 commit b3b9ded
Showing 1 changed file with 139 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def __init__(self, meter, tracer):
self._duration_histogram = self._meter.create_histogram(
name="rpc.server.duration",
unit="ms",
description="measures the duration of the inbound rpc",
description="measures duration of inbound RPC",
)
self._request_size_histogram = self._meter.create_histogram(
name="rpc.server.request.size",
Expand Down Expand Up @@ -339,70 +339,16 @@ def _create_attributes(self, context, full_method):
def intercept_service(self, continuation, handler_call_details):
def telemetry_wrapper(behavior, request_streaming, response_streaming):
def telemetry_interceptor(request_or_iterator, context):
with self._set_remote_context(context):
attributes = self._create_attributes(context, handler_call_details.method)

with self._tracer.start_as_current_span(
name=handler_call_details.method,
kind=trace.SpanKind.SERVER,
attributes=attributes,
end_on_exit=False,
record_exception=False,
set_status_on_exception=False
) as span:

try:
# wrap the context
context = _OpenTelemetryServicerContext(context, span)

# wrap / log the request (iterator)
if request_streaming:
request_or_iterator = self._log_stream_requests(
request_or_iterator, span, attributes
)
else:
self._log_unary_request(
request_or_iterator, span, attributes
)

# call the actual RPC and track the duration
with self._record_duration(attributes, context):
response_or_iterator = behavior(request_or_iterator, context)

# wrap / log the response (iterator)
if response_streaming:
response_or_iterator = self._log_stream_responses(
response_or_iterator, span, attributes, context
)
else:
self._log_unary_response(
response_or_iterator, span, attributes, context
)

return response_or_iterator

except Exception as exc:
# Bare exceptions are likely to be gRPC aborts, which
# we handle in our context wrapper.
# Here, we're interested in uncaught exceptions.
# pylint:disable=unidiomatic-typecheck
if type(exc) != Exception:
span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE,
grpc.StatusCode.UNKNOWN.value[0]
)
span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{type(exc).__name__}: {exc}",
)
)
span.record_exception(exc)
raise exc

finally:
if not response_streaming:
span.end()
if response_streaming:
return self._intercept_streaming_response(
behavior, request_or_iterator, context,
request_streaming, handler_call_details.method
)

return self._intercept_unary_response(
behavior, request_or_iterator, context,
request_streaming, handler_call_details.method
)

return telemetry_interceptor

Expand All @@ -411,6 +357,118 @@ def telemetry_interceptor(request_or_iterator, context):
telemetry_wrapper
)

def _intercept_unary_response(self, behavior, request_or_iterator, context, request_streaming, full_method):
with self._set_remote_context(context):
attributes = self._create_attributes(context, full_method)

with self._tracer.start_as_current_span(
name=full_method,
kind=trace.SpanKind.SERVER,
attributes=attributes,
end_on_exit=True,
record_exception=False,
set_status_on_exception=False
) as span:

try:
# wrap the context
context = _OpenTelemetryServicerContext(context, span)

# wrap / log the request (iterator)
if request_streaming:
request_or_iterator = self._log_streaming_request(
request_or_iterator, span, attributes
)
else:
self._log_unary_request(
request_or_iterator, span, attributes
)

# call the actual RPC and track the duration
with self._record_duration(attributes, context):
response_or_iterator = behavior(request_or_iterator, context)

# log the response (iterator)
self._log_unary_response(
response_or_iterator, span, attributes, context
)

return response_or_iterator

except Exception as exc:
# Bare exceptions are likely to be gRPC aborts, which
# we handle in our context wrapper.
# Here, we're interested in uncaught exceptions.
# pylint:disable=unidiomatic-typecheck
if type(exc) != Exception:
span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE,
grpc.StatusCode.UNKNOWN.value[0]
)
span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{type(exc).__name__}: {exc}",
)
)
span.record_exception(exc)
raise exc

def _intercept_streaming_response(self, behavior, request_or_iterator, context, request_streaming, full_method):
with self._set_remote_context(context):
attributes = self._create_attributes(context, full_method)

with self._tracer.start_as_current_span(
name=full_method,
kind=trace.SpanKind.SERVER,
attributes=attributes,
end_on_exit=True,
record_exception=False,
set_status_on_exception=False
) as span:

try:
# wrap the context
context = _OpenTelemetryServicerContext(context, span)

# wrap / log the request (iterator)
if request_streaming:
request_or_iterator = self._log_streaming_request(
request_or_iterator, span, attributes
)
else:
self._log_unary_request(
request_or_iterator, span, attributes
)

# call the actual RPC and track the duration
with self._record_duration(attributes, context):
response_or_iterator = behavior(request_or_iterator, context)

# log the response (iterator)
yield from self._log_streaming_response(
response_or_iterator, span, attributes, context
)

except Exception as exc:
# Bare exceptions are likely to be gRPC aborts, which
# we handle in our context wrapper.
# Here, we're interested in uncaught exceptions.
# pylint:disable=unidiomatic-typecheck
if type(exc) != Exception:
span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE,
grpc.StatusCode.UNKNOWN.value[0]
)
span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{type(exc).__name__}: {exc}",
)
)
span.record_exception(exc)
raise exc

def _log_unary_request(self, request, active_span, attributes):
message_size_by = request.ByteSize()
_add_message_event(
Expand All @@ -429,7 +487,7 @@ def _log_unary_response(self, response, active_span, attributes, context):
self._response_size_histogram.record(message_size_by, attributes)
self._responses_per_rpc_histogram.record(1, attributes)

def _log_stream_requests(self, request_iterator, active_span, attributes):
def _log_streaming_request(self, request_iterator, active_span, attributes):
req_id = 1
for req_id, msg in enumerate(request_iterator, start=1):
message_size_by = msg.ByteSize()
Expand All @@ -441,52 +499,28 @@ def _log_stream_requests(self, request_iterator, active_span, attributes):

self._requests_per_rpc_histogram.record(req_id, attributes)

def _log_stream_responses(self, response_iterator, active_span, attributes, context):
with trace.use_span(
active_span,
end_on_exit=True,
record_exception=False,
set_status_on_exception=False
):
try:
res_id = 1
for res_id, msg in enumerate(response_iterator, start=1):
message_size_by = msg.ByteSize()
_add_message_event(
active_span, MessageTypeValues.SENT.value, message_size_by, message_id=res_id
)
self._response_size_histogram.record(message_size_by, attributes)
yield msg
except Exception as exc:
# Bare exceptions are likely to be gRPC aborts, which
# we handle in our context wrapper.
# Here, we're interested in uncaught exceptions.
# pylint:disable=unidiomatic-typecheck
if type(exc) != Exception:
active_span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE,
grpc.StatusCode.UNKNOWN.value[0]
)
active_span.set_status(
Status(
status_code=StatusCode.ERROR,
description=f"{type(exc).__name__}: {exc}",
)
)
active_span.record_exception(exc)
raise exc
finally:
if context._code != grpc.StatusCode.OK:
attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0]
self._responses_per_rpc_histogram.record(res_id, attributes)
def _log_streaming_response(self, response_iterator, active_span, attributes, context):
try:
res_id = 1
for res_id, msg in enumerate(response_iterator, start=1):
message_size_by = msg.ByteSize()
_add_message_event(
active_span, MessageTypeValues.SENT.value, message_size_by, message_id=res_id
)
self._response_size_histogram.record(message_size_by, attributes)
yield msg
finally:
if context._code != grpc.StatusCode.OK:
attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0]
self._responses_per_rpc_histogram.record(res_id, attributes)

@contextmanager
def _record_duration(self, attributes, context):
start = _time_ns()
try:
yield
finally:
duration = max(round((_time_ns() - start) * 1000), 0)
duration = max(round((_time_ns() - start) / 1000), 0)
if context._code != grpc.StatusCode.OK:
attributes[SpanAttributes.RPC_GRPC_STATUS_CODE] = context._code.value[0]
self._duration_histogram.record(duration, attributes)

0 comments on commit b3b9ded

Please sign in to comment.