Skip to content

Commit

Permalink
feat: adhoc: Instrument pyramid/request with Opentelemetry
Browse files Browse the repository at this point in the history
  • Loading branch information
Xaelias committed Dec 22, 2023
1 parent c7fbb13 commit 6f85007
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 9 deletions.
4 changes: 4 additions & 0 deletions baseplate/clients/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from advocate import AddrValidator
from advocate import ValidatingHTTPAdapter
from opentelemetry.instrumentation.requests import RequestsInstrumentor
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
Expand All @@ -26,6 +27,9 @@
from baseplate.lib.prometheus_metrics import getHTTPSuccessLabel


RequestsInstrumentor().instrument()


def http_adapter_from_config(
app_config: config.RawConfig, prefix: str, **kwargs: Any
) -> HTTPAdapter:
Expand Down
31 changes: 23 additions & 8 deletions baseplate/frameworks/pyramid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import pyramid.tweens
import webob.request

from opentelemetry import trace
from opentelemetry.instrumentation.pyramid import PyramidInstrumentor
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
Expand All @@ -35,10 +37,12 @@
from baseplate.thrift.ttypes import IsHealthyProbe


PyramidInstrumentor().instrument()

logger = logging.getLogger(__name__)


class SpanFinishingAppIterWrapper:
class SpanFinishingAppIterWrapper(Iterable):
"""Wrapper for Response.app_iter that finishes the span when the iterator is done.
The WSGI spec expects applications to return an iterable object. In the
Expand All @@ -53,7 +57,7 @@ class SpanFinishingAppIterWrapper:
"""

def __init__(self, span: Span, app_iter: Iterable[bytes]) -> None:
def __init__(self, app_iter: Iterator[bytes], span: Optional[Span] = None) -> None:
self.span = span
self.app_iter = iter(app_iter)

Expand All @@ -64,10 +68,15 @@ def __next__(self) -> bytes:
try:
return next(self.app_iter)
except StopIteration:
self.span.finish()
trace.get_current_span().set_status(trace.status.StatusCode.OK)
if self.span:
self.span.finish()
raise
except: # noqa: E722
self.span.finish(exc_info=sys.exc_info())
except Exception as e: # noqa: E722
trace.get_current_span().set_status(trace.status.StatusCode.ERROR)
trace.get_current_span().record_exception(e)
if self.span:
self.span.finish(exc_info=sys.exc_info())
raise

def close(self) -> None:
Expand Down Expand Up @@ -129,15 +138,21 @@ def baseplate_tween(request: Request) -> Response:
response = handler(request)
if request.span:
request.span.set_tag("http.response_length", response.content_length)
except: # noqa: E722
except Exception as e: # noqa: E722
trace.get_current_span().set_status(trace.status.StatusCode.ERROR)
trace.get_current_span().record_exception(e)
if hasattr(request, "span") and request.span:
request.span.finish(exc_info=sys.exc_info())
raise
else:
trace.get_current_span().set_status(trace.status.StatusCode.OK)
content_length = response.content_length
if request.span:
request.span.set_tag("http.status_code", response.status_code)
content_length = response.content_length
response.app_iter = SpanFinishingAppIterWrapper(request.span, response.app_iter)
response.app_iter = SpanFinishingAppIterWrapper(response.app_iter, request.span)
response.content_length = content_length
else:
response.app_iter = SpanFinishingAppIterWrapper(response.app_iter)
response.content_length = content_length
finally:
manually_close_request_metrics(request, response)
Expand Down
231 changes: 231 additions & 0 deletions tests/integration/otel_pyramid_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import unittest

from unittest import mock

from opentelemetry import propagate
from opentelemetry import trace
from opentelemetry.propagators.composite import CompositePropagator
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pyramid.response import Response

from baseplate import Baseplate
from baseplate.lib.propagator_redditb3 import RedditB3Format

from . import FakeEdgeContextFactory


propagate.set_global_textmap(
CompositePropagator([RedditB3Format(), TraceContextTextMapPropagator()])
)


try:
import webtest

from baseplate.frameworks.pyramid import BaseplateConfigurator
from baseplate.frameworks.pyramid import ServerSpanInitialized
from baseplate.frameworks.pyramid import StaticTrustHandler
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPInternalServerError
except ImportError:
raise unittest.SkipTest("pyramid/webtest is not installed")


class TestException(Exception):
pass


class ControlFlowException(Exception):
pass


class ControlFlowException2(Exception):
pass


class ExceptionViewException(Exception):
pass


def example_application(request):
if "error" in request.params:
raise TestException("this is a test")

if "control_flow_exception" in request.params:
raise ControlFlowException()

if "exception_view_exception" in request.params:
raise ControlFlowException2()

if "stream" in request.params:

def make_iter():
yield b"foo"
yield b"bar"

return Response(status_code=200, app_iter=make_iter())

return {"test": "success"}


def render_exception_view(request):
return HTTPInternalServerError(title="a fancy title", body="a fancy explanation")


def render_bad_exception_view(request):
raise ExceptionViewException()


def local_tracing_within_context(request):
tracer = trace.get_tracer("in-context")
with tracer.start_as_current_span("local-req"):
pass
return {"trace": "success"}


class ConfiguratorTests(TestBase):
def setUp(self):
super().setUp()
configurator = Configurator()
configurator.add_route("example", "/example", request_method="GET")
configurator.add_route("route", "/route/{hello}/world", request_method="GET")
configurator.add_route("trace_context", "/trace_context", request_method="GET")

configurator.add_view(example_application, route_name="example", renderer="json")
configurator.add_view(example_application, route_name="route", renderer="json")

configurator.add_view(
local_tracing_within_context, route_name="trace_context", renderer="json"
)

configurator.add_view(render_exception_view, context=ControlFlowException, renderer="json")

configurator.add_view(
render_bad_exception_view, context=ControlFlowException2, renderer="json"
)

self.baseplate = Baseplate()
self.baseplate_configurator = BaseplateConfigurator(
self.baseplate,
edge_context_factory=FakeEdgeContextFactory(),
header_trust_handler=StaticTrustHandler(trust_headers=True),
)
configurator.include(self.baseplate_configurator.includeme)
self.context_init_event_subscriber = mock.Mock()
configurator.add_subscriber(self.context_init_event_subscriber, ServerSpanInitialized)
app = configurator.make_wsgi_app()
self.test_app = webtest.TestApp(app)

@mock.patch("random.getrandbits")
def test_no_trace_headers(self, getrandbits):
getrandbits.return_value = 1234
self.test_app.get("/example")

finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertIsNone(finished_spans[0].parent)

def test_trace_headers(self):
self.test_app.get(
"/example",
headers={
"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
},
)

finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertEqual(finished_spans[0].parent.span_id, 0x00F067AA0BA902B7)
self.assertEqual(finished_spans[0].context.trace_id, 0x4BF92F3577B34DA6A3CE929D0E0E4736)

def test_bp_trace_headers(self):
self.test_app.get(
"/example",
headers={
"X-Trace": "4BF92F3577B34DA6A3CE929D0E0E4736",
"X-Parent": "00F067AA0BA902B7",
"X-Span": "00F067AA0BA902B8",
"X-Sampled": "1",
},
)
finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertEqual(finished_spans[0].context.trace_id, 0x4BF92F3577B34DA6A3CE929D0E0E4736)
self.assertEqual(finished_spans[0].parent.span_id, 0x00F067AA0BA902B8)

def test_bp_short_trace_headers(self):
self.test_app.get(
"/example",
headers={
"X-Trace": "20d294c28becf34d",
"X-Parent": "a1bf4d567fc497a4",
"X-Span": "a1bf4d567fc497a5",
"X-Sampled": "1",
},
)
finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertEqual(finished_spans[0].context.trace_id, 0x20D294C28BECF34D)
self.assertEqual(finished_spans[0].parent.span_id, 0xA1BF4D567FC497A5)

def test_both_trace_headers(self):
self.test_app.get(
"/example",
headers={
"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
"X-Trace": "20d294c28becf34d", # should get discarded
"X-Parent": "a1bf4d567fc497a4", # should get discarded
"X-Span": "a1bf4d567fc497a5",
"X-Sampled": "1",
},
)

finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertEqual(finished_spans[0].parent.span_id, 0x00F067AA0BA902B7)
self.assertEqual(finished_spans[0].context.trace_id, 0x4BF92F3577B34DA6A3CE929D0E0E4736)

def test_not_found(self):
self.test_app.get("/nope", status=404)

finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertIsNone(finished_spans[0].parent)

def test_exception_caught(self):
with self.assertRaises(TestException):
self.test_app.get("/example?error")

finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertFalse(finished_spans[0].status.is_ok)
self.assertGreater(len(finished_spans[0].events), 0)
self.assertEqual(finished_spans[0].events[0].name, "exception")

def test_control_flow_exception_not_caught(self):
self.test_app.get("/example?control_flow_exception", status=500)

finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertTrue(finished_spans[0].status.is_ok)

def test_exception_in_exception_view_caught(self):
with self.assertRaises(ExceptionViewException):
self.test_app.get("/example?exception_view_exception")

finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertFalse(finished_spans[0].status.is_ok)

def test_local_trace_in_context(self):
self.test_app.get("/trace_context")

finished_spans = self.get_finished_spans()
self.assertGreater(len(finished_spans), 1)
self.assertEqual(finished_spans[0].kind, trace.SpanKind.INTERNAL)

# self.assertEqual(self.server_observer.on_child_span_created.call_count, 1)
# child_span = self.server_observer.on_child_span_created.call_args[0][0]
# context, server_span = self.observer.on_server_span_created.call_args[0]
# self.assertNotEqual(child_span.context, context)
14 changes: 13 additions & 1 deletion tests/integration/pyramid_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from unittest import mock

from opentelemetry.test.test_base import TestBase
from pyramid.response import Response

from baseplate import Baseplate
Expand Down Expand Up @@ -75,8 +76,9 @@ def local_tracing_within_context(request):
return {"trace": "success"}


class ConfiguratorTests(unittest.TestCase):
class ConfiguratorTests(TestBase):
def setUp(self):
super().setUp()
configurator = Configurator()
configurator.add_route("example", "/example", request_method="GET")
configurator.add_route("route", "/route/{hello}/world", request_method="GET")
Expand Down Expand Up @@ -132,6 +134,10 @@ def test_no_trace_headers(self, getrandbits):
self.assertTrue(self.server_observer.on_finish.called)
self.assertTrue(self.context_init_event_subscriber.called)

finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertIsNone(finished_spans[0].parent)

def test_trace_headers(self):
self.test_app.get(
"/example",
Expand All @@ -142,6 +148,7 @@ def test_trace_headers(self):
"X-Span": "3456",
"X-Sampled": "1",
"X-Flags": "1",
"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
},
)

Expand All @@ -158,6 +165,11 @@ def test_trace_headers(self):
self.assertTrue(self.server_observer.on_finish.called)
self.assertTrue(self.context_init_event_subscriber.called)

finished_spans = self.get_finished_spans()
self.assertEqual(len(finished_spans), 1)
self.assertIsNotNone(finished_spans[0].parent)
self.assertEqual(finished_spans[0].context.trace_id, 0x4BF92F3577B34DA6A3CE929D0E0E4736)

def test_edge_request_headers(self):
self.test_app.get(
"/example",
Expand Down

0 comments on commit 6f85007

Please sign in to comment.