Skip to content

Commit

Permalink
feat(backend): add e2e useful tracing
Browse files Browse the repository at this point in the history
Enabled tracing component inside git agent in order to have end to end
spans for a single request.
Added spans at most useful places such as message bus' execute_message
and database execute_query.

Signed-off-by: Fatih Acar <[email protected]>
  • Loading branch information
fatih-acar committed Jan 26, 2024
1 parent 09ce0cd commit 333d166
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 106 deletions.
12 changes: 11 additions & 1 deletion backend/infrahub/cli/git_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from prometheus_client import start_http_server
from rich.logging import RichHandler

from infrahub import config
from infrahub import __version__, config
from infrahub.components import ComponentType
from infrahub.core.initialization import initialization
from infrahub.database import InfrahubDatabase, get_db
Expand All @@ -20,6 +20,7 @@
from infrahub.services import InfrahubServices
from infrahub.services.adapters.cache.redis import RedisCache
from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
from infrahub.trace import configure_trace

app = typer.Typer()

Expand Down Expand Up @@ -65,6 +66,15 @@ async def _start(debug: bool, port: int) -> None:
client = await InfrahubClient.init(address=config.SETTINGS.main.internal_address, retry_on_failure=True, log=log)
await client.branch.all()

# Initialize trace
if config.SETTINGS.trace.enable:
configure_trace(
service="infrahub-git-agent",
version=__version__,
exporter_endpoint=config.SETTINGS.trace.trace_endpoint,
exporter_protocol=config.SETTINGS.trace.exporter_protocol,
)

# Initialize the lock
initialize_lock()

Expand Down
30 changes: 17 additions & 13 deletions backend/infrahub/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# from contextlib import asynccontextmanager
from neo4j.exceptions import ClientError, ServiceUnavailable
from otel_extensions import get_tracer

from infrahub import config
from infrahub.exceptions import DatabaseError
Expand Down Expand Up @@ -192,19 +193,22 @@ async def close(self):
async def execute_query(
self, query: str, params: Optional[Dict[str, Any]] = None, name: Optional[str] = "undefined"
) -> List[Record]:
with QUERY_EXECUTION_METRICS.labels(str(self._session_mode), name).time():
if self.is_transaction:
execution_method = await self.transaction()
else:
execution_method = await self.session()

try:
response = await execution_method.run(query=query, parameters=params)
except ServiceUnavailable as exc:
log.error("Database Service unavailable", error=str(exc))
raise DatabaseError(message="Unable to connect to the database") from exc

return [item async for item in response]
with get_tracer(__name__).start_as_current_span("execute_db_query") as span:
span.set_attribute("query", query)

with QUERY_EXECUTION_METRICS.labels(str(self._session_mode), name).time():
if self.is_transaction:
execution_method = await self.transaction()
else:
execution_method = await self.session()

try:
response = await execution_method.run(query=query, parameters=params)
except ServiceUnavailable as exc:
log.error("Database Service unavailable", error=str(exc))
raise DatabaseError(message="Unable to connect to the database") from exc

return [item async for item in response]

def render_list_comprehension(self, items: str, item_name: str) -> str:
if self.db_type == DatabaseType.MEMGRAPH:
Expand Down
2 changes: 2 additions & 0 deletions backend/infrahub/graphql/mutations/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from graphene import Boolean, Field, InputObjectType, List, Mutation, String
from graphql import GraphQLResolveInfo
from infrahub_sdk.utils import extract_fields
from otel_extensions import instrumented

from infrahub import config, lock
from infrahub.core import registry
Expand Down Expand Up @@ -44,6 +45,7 @@ class Arguments:
object = Field(BranchType)

@classmethod
@instrumented
async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: BranchCreateInput, background_execution=False):
db: InfrahubDatabase = info.context.get("infrahub_database")

Expand Down
37 changes: 21 additions & 16 deletions backend/infrahub/message_bus/operations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json

from otel_extensions import get_tracer

from infrahub.log import get_logger
from infrahub.message_bus import InfrahubResponse, messages
from infrahub.message_bus.operations import check, event, finalize, git, refresh, requests, send, transform, trigger
Expand Down Expand Up @@ -53,19 +55,22 @@


async def execute_message(routing_key: str, message_body: bytes, service: InfrahubServices):
message_data = json.loads(message_body)
message = messages.MESSAGE_MAP[routing_key](**message_data)
message.set_log_data(routing_key=routing_key)
try:
await COMMAND_MAP[routing_key](message=message, service=service)
except Exception as exc: # pylint: disable=broad-except
if message.reply_requested:
response = InfrahubResponse(passed=False, response_class="rpc_error", response_data={"error": str(exc)})
await service.reply(message=response, initiator=message)
return
if message.reached_max_retries:
log.error("Message failed after maximum number of retries", error=str(exc))
await set_check_status(message, conclusion="failure", service=service)
return
message.increase_retry_count()
await service.send(message, delay=MessageTTL.FIVE)
with get_tracer(__name__).start_as_current_span("execute_message") as span:
span.set_attribute("routing_key", routing_key)

message_data = json.loads(message_body)
message = messages.MESSAGE_MAP[routing_key](**message_data)
message.set_log_data(routing_key=routing_key)
try:
await COMMAND_MAP[routing_key](message=message, service=service)
except Exception as exc: # pylint: disable=broad-except
if message.reply_requested:
response = InfrahubResponse(passed=False, response_class="rpc_error", response_data={"error": str(exc)})
await service.reply(message=response, initiator=message)
return
if message.reached_max_retries:
log.error("Message failed after maximum number of retries", error=str(exc))
await set_check_status(message, conclusion="failure", service=service)
return
message.increase_retry_count()
await service.send(message, delay=MessageTTL.FIVE)
35 changes: 21 additions & 14 deletions backend/infrahub/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from infrahub_sdk.timestamp import TimestampFormatError
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor, Span
from pydantic import ValidationError
from starlette_exporter import PrometheusMiddleware, handle_metrics

Expand All @@ -32,7 +32,7 @@
from infrahub.services import InfrahubServices, services
from infrahub.services.adapters.cache.redis import RedisCache
from infrahub.services.adapters.message_bus.rabbitmq import RabbitMQMessageBus
from infrahub.trace import add_span_exception, configure_trace, get_traceid, get_tracer
from infrahub.trace import add_span_exception, configure_trace, get_traceid
from infrahub.worker import WORKER_IDENTITY


Expand All @@ -46,8 +46,8 @@ async def app_initialization(application: FastAPI) -> None:
# Initialize trace
if config.SETTINGS.trace.enable:
configure_trace(
service="infrahub-server",
version=__version__,
exporter_type=config.SETTINGS.trace.exporter_type,
exporter_endpoint=config.SETTINGS.trace.trace_endpoint,
exporter_protocol=config.SETTINGS.trace.exporter_protocol,
)
Expand Down Expand Up @@ -101,8 +101,13 @@ async def lifespan(application: FastAPI):
redoc_url="/api/redoc",
)

FastAPIInstrumentor().instrument_app(app, excluded_urls=".*/metrics")
tracer = get_tracer()

def server_request_hook(span: Span, scope: dict): # pylint: disable=unused-argument
if span and span.is_recording():
span.set_attribute("worker", WORKER_IDENTITY)


FastAPIInstrumentor().instrument_app(app, excluded_urls=".*/metrics", server_request_hook=server_request_hook)

FRONTEND_DIRECTORY = os.environ.get("INFRAHUB_FRONTEND_DIRECTORY", os.path.abspath("frontend"))
FRONTEND_ASSET_DIRECTORY = f"{FRONTEND_DIRECTORY}/dist/assets"
Expand All @@ -121,15 +126,17 @@ async def lifespan(application: FastAPI):
async def logging_middleware(request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response:
clear_log_context()
request_id = correlation_id.get()
with tracer.start_as_current_span("processing request " + request_id):
trace_id = get_traceid()
set_log_data(key="request_id", value=request_id)
set_log_data(key="app", value="infrahub.api")
set_log_data(key="worker", value=WORKER_IDENTITY)
if trace_id:
set_log_data(key="trace_id", value=trace_id)
response = await call_next(request)
return response

set_log_data(key="request_id", value=request_id)
set_log_data(key="app", value="infrahub.api")
set_log_data(key="worker", value=WORKER_IDENTITY)

trace_id = get_traceid()
if trace_id:
set_log_data(key="trace_id", value=trace_id)

response = await call_next(request)
return response


@app.middleware("http")
Expand Down
61 changes: 51 additions & 10 deletions backend/infrahub/services/adapters/message_bus/rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from typing import TYPE_CHECKING, Awaitable, Callable, List, MutableMapping, Optional

import aio_pika
import opentelemetry.instrumentation.aio_pika.span_builder
from infrahub_sdk import UUIDT
from opentelemetry import context, propagate
from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
from opentelemetry.semconv.trace import SpanAttributes

from infrahub import config
from infrahub.components import ComponentType
Expand All @@ -24,12 +28,36 @@
AbstractQueue,
AbstractRobustConnection,
)
from opentelemetry.instrumentation.aio_pika.span_builder import SpanBuilder

from infrahub.services import InfrahubServices

MessageFunction = Callable[[InfrahubMessage], Awaitable[None]]


AioPikaInstrumentor().instrument()


# TODO: remove this once https://github.com/open-telemetry/opentelemetry-python-contrib/issues/1835 is resolved
def patch_spanbuilder_set_channel() -> None:
"""
The default SpanBuilder.set_channel does not work with aio_pika 9.1 and the refactored connection
attribute
"""

def set_channel(self: SpanBuilder, channel: AbstractChannel) -> None:
if hasattr(channel, "_connection"):
url = channel._connection.url
self._attributes.update(
{
SpanAttributes.NET_PEER_NAME: url.host,
SpanAttributes.NET_PEER_PORT: url.port,
}
)

opentelemetry.instrumentation.aio_pika.span_builder.SpanBuilder.set_channel = set_channel # type: ignore


async def _add_request_id(message: InfrahubMessage) -> None:
log_data = get_log_data()
message.meta.request_id = log_data.get("request_id", "")
Expand All @@ -53,6 +81,8 @@ def __init__(
self.futures: MutableMapping[str, asyncio.Future] = {}

async def initialize(self, service: InfrahubServices) -> None:
patch_spanbuilder_set_channel()

self.service = service
if self.service.component_type == ComponentType.API_SERVER:
await self._initialize_api_server()
Expand Down Expand Up @@ -183,17 +213,28 @@ async def subscribe(self) -> None:
async for message in qiterator:
try:
async with message.process(requeue=False):
# auto instrumentation not supported yet for RPCs, do it ourselves...
token = None
headers = message.headers or {}
ctx = propagate.extract(headers)
if ctx is not None:
token = context.attach(ctx)

clear_log_context()
if message.routing_key in messages.MESSAGE_MAP:
await execute_message(
routing_key=message.routing_key, message_body=message.body, service=self.service
)
else:
self.service.log.error(
"Unhandled routing key for message",
routing_key=message.routing_key,
message=message.body,
)
try:
if message.routing_key in messages.MESSAGE_MAP:
await execute_message(
routing_key=message.routing_key, message_body=message.body, service=self.service
)
else:
self.service.log.error(
"Unhandled routing key for message",
routing_key=message.routing_key,
message=message.body,
)
finally:
if token is not None:
context.detach(token)

except Exception: # pylint: disable=broad-except
self.service.log.exception("Processing error for message %r" % message)
Expand Down
56 changes: 7 additions & 49 deletions backend/infrahub/trace.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,6 @@
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter as GRPCSpanExporter,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as HTTPSpanExporter,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
from opentelemetry.trace import StatusCode


def get_tracer(name: str = "infrahub") -> trace.Tracer:
return trace.get_tracer(name)
from otel_extensions import TelemetryOptions, init_telemetry_provider


def get_current_span_with_context() -> trace.Span:
Expand Down Expand Up @@ -54,42 +42,12 @@ def add_span_exception(exception: Exception) -> None:
current_span.record_exception(exception)


def create_tracer_provider(
version: str, exporter_type: str, exporter_endpoint: str = None, exporter_protocol: str = None
) -> TracerProvider:
# Create a BatchSpanProcessor exporter based on the type
if exporter_type == "console":
exporter = ConsoleSpanExporter()
elif exporter_type == "otlp":
if not exporter_endpoint:
raise ValueError("Exporter type is set to otlp but endpoint is not set")
if exporter_protocol == "http/protobuf":
exporter = HTTPSpanExporter(endpoint=exporter_endpoint)
elif exporter_protocol == "grpc":
exporter = GRPCSpanExporter(endpoint=exporter_endpoint)
else:
raise ValueError("Exporter type unsupported by Infrahub")

# Resource can be required for some backends, e.g. Jaeger
resource = Resource(attributes={"service.name": "infrahub", "service.version": version})
span_processor = BatchSpanProcessor(exporter)
tracer_provider = TracerProvider(resource=resource)
tracer_provider.add_span_processor(span_processor)

return tracer_provider


def configure_trace(
version: str, exporter_type: str, exporter_endpoint: str = None, exporter_protocol: str = None
service: str, version: str, exporter_endpoint: str | None = None, exporter_protocol: str = None
) -> None:
# Create a trace provider with the exporter
tracer_provider = create_tracer_provider(
version=version,
exporter_type=exporter_type,
exporter_endpoint=exporter_endpoint,
exporter_protocol=exporter_protocol,
options = TelemetryOptions(
OTEL_SERVICE_NAME=service,
OTEL_EXPORTER_OTLP_ENDPOINT=exporter_endpoint,
OTEL_EXPORTER_OTLP_PROTOCOL=exporter_protocol,
)
tracer_provider.get_tracer(__name__)

# Register the trace provider
trace.set_tracer_provider(tracer_provider)
init_telemetry_provider(options, **{"service.version": version})
9 changes: 8 additions & 1 deletion development/docker-compose.override.yml.tmp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ version: "3.4"
services:
# --------------------------------------------------------------------------------
# - Prometheus to collect all metrics endpoints
# - Tempo to receive traces
# - Tempo or Jaeger to receive traces
# - Grafana to visualize these metrics
# - Loki to receive logs from promtail
# - Promtail to parse logs from different source
Expand Down Expand Up @@ -43,6 +43,13 @@ services:
ports:
- "3200:3200"

# jaeger:
# image: jaegertracing/all-in-one:1.53
# environment:
# COLLECTOR_ZIPKIN_HOST_PORT: ":9411"
# ports:
# - "16686:16686"

prometheus:
image: prom/prometheus:latest
volumes:
Expand Down
Loading

0 comments on commit 333d166

Please sign in to comment.