Skip to content

Commit

Permalink
Add Flow run OTEL instrumentation (#16010)
Browse files Browse the repository at this point in the history
Co-authored-by: Chris Pickett <[email protected]>
  • Loading branch information
collincchoy and bunchesofdonald authored Nov 14, 2024
1 parent 7cdf151 commit 377505b
Show file tree
Hide file tree
Showing 10 changed files with 435 additions and 5 deletions.
1 change: 1 addition & 0 deletions requirements-client.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ httpx[http2] >= 0.23, != 0.23.2
importlib_metadata >= 4.4; python_version < '3.10'
jsonpatch >= 1.32, < 2.0
jsonschema >= 4.0.0, < 5.0.0
opentelemetry-api >= 1.27.0, < 2.0.0
orjson >= 3.7, < 4.0
packaging >= 21.3, < 24.3
pathspec >= 0.8.0
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ graphviz >= 0.20.1
jinja2 >= 3.0.0, < 4.0.0
jinja2-humanize-extension >= 0.4.0
humanize >= 4.9.0, < 5.0.0
opentelemetry-api >= 1.27.0, < 2.0.0
pytz >= 2021.1, < 2025
readchar >= 4.0.0, < 5.0.0
sqlalchemy[asyncio] >= 2.0, < 3.0.0
Expand Down
5 changes: 5 additions & 0 deletions src/prefect/client/schemas/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,11 @@ class FlowRun(ObjectBaseModel):
description="A list of tags on the flow run",
examples=[["tag-1", "tag-2"]],
)
labels: KeyValueLabels = Field(
default_factory=dict,
description="Prefect Cloud: A dictionary of key-value labels. Values can be strings, numbers, or booleans.",
examples=[{"key": "value1", "key2": 42}],
)
parent_task_run_id: Optional[UUID] = Field(
default=None,
description=(
Expand Down
57 changes: 56 additions & 1 deletion src/prefect/flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
)
from uuid import UUID

from opentelemetry import trace
from opentelemetry.trace import Tracer, get_tracer
from typing_extensions import ParamSpec

import prefect
from prefect import Task
from prefect.client.orchestration import SyncPrefectClient, get_client
from prefect.client.schemas import FlowRun, TaskRun
Expand Down Expand Up @@ -124,6 +127,10 @@ class FlowRunEngine(Generic[P, R]):
_client: Optional[SyncPrefectClient] = None
short_circuit: bool = False
_flow_run_name_set: bool = False
_tracer: Tracer = field(
default_factory=lambda: get_tracer("prefect", prefect.__version__)
)
_span: Optional[trace.Span] = None

def __post_init__(self):
if self.flow is None and self.flow_run_id is None:
Expand Down Expand Up @@ -233,6 +240,17 @@ def set_state(self, state: State, force: bool = False) -> State:
self.flow_run.state = state # type: ignore
self.flow_run.state_name = state.name # type: ignore
self.flow_run.state_type = state.type # type: ignore

if self._span:
self._span.add_event(
state.name,
{
"prefect.state.message": state.message or "",
"prefect.state.type": state.type,
"prefect.state.name": state.name or state.type,
"prefect.state.id": str(state.id),
},
)
return state

def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
Expand Down Expand Up @@ -281,6 +299,9 @@ def handle_success(self, result: R) -> R:
)
self.set_state(terminal_state)
self._return_value = resolved_result

self._end_span_on_success()

return result

def handle_exception(
Expand Down Expand Up @@ -311,6 +332,9 @@ def handle_exception(
)
state = self.set_state(Running())
self._raised = exc

self._end_span_on_error(exc, state.message)

return state

def handle_timeout(self, exc: TimeoutError) -> None:
Expand All @@ -329,13 +353,32 @@ def handle_timeout(self, exc: TimeoutError) -> None:
self.set_state(state)
self._raised = exc

self._end_span_on_error(exc, message)

def handle_crash(self, exc: BaseException) -> None:
state = run_coro_as_sync(exception_to_crashed_state(exc))
self.logger.error(f"Crash detected! {state.message}")
self.logger.debug("Crash details:", exc_info=exc)
self.set_state(state, force=True)
self._raised = exc

self._end_span_on_error(exc, state.message)

def _end_span_on_success(self):
if not self._span:
return
self._span.set_status(trace.Status(trace.StatusCode.OK))
self._span.end(time.time_ns())
self._span = None

def _end_span_on_error(self, exc: BaseException, description: Optional[str]):
if not self._span:
return
self._span.record_exception(exc)
self._span.set_status(trace.Status(trace.StatusCode.ERROR, description))
self._span.end(time.time_ns())
self._span = None

def load_subflow_run(
self,
parent_task_run: TaskRun,
Expand Down Expand Up @@ -578,6 +621,18 @@ def initialize_run(self):
flow_version=self.flow.version,
empirical_policy=self.flow_run.empirical_policy,
)

self._span = self._tracer.start_span(
name=self.flow_run.name,
attributes={
**self.flow_run.labels,
"prefect.run.type": "flow",
"prefect.run.id": str(self.flow_run.id),
"prefect.tags": self.flow_run.tags,
"prefect.flow.name": self.flow.name,
},
)

try:
yield self

Expand Down Expand Up @@ -632,7 +687,7 @@ def cancel_all_tasks(self):

@contextmanager
def start(self) -> Generator[None, None, None]:
with self.initialize_run():
with self.initialize_run(), trace.use_span(self._span):
self.begin_run()

if self.state.is_running():
Expand Down
1 change: 1 addition & 0 deletions src/prefect/telemetry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

4 changes: 2 additions & 2 deletions src/prefect/telemetry/processors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
from threading import Event, Lock, Thread
from typing import Optional
from typing import Dict, Optional

from opentelemetry.context import Context
from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor
Expand All @@ -10,7 +10,7 @@
class InFlightSpanProcessor(SpanProcessor):
def __init__(self, span_exporter: SpanExporter):
self.span_exporter = span_exporter
self._in_flight = {}
self._in_flight: Dict[int, Span] = {}
self._lock = Lock()
self._stop_event = Event()
self._export_thread = Thread(target=self._export_periodically, daemon=True)
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from .fixtures.events import *
from .fixtures.logging import *
from .fixtures.storage import *
from .fixtures.telemetry import *
from .fixtures.time import *


Expand Down
9 changes: 9 additions & 0 deletions tests/fixtures/telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest
from tests.telemetry.instrumentation_tester import InstrumentationTester


@pytest.fixture
def instrumentation():
instrumentation_tester = InstrumentationTester()
yield instrumentation_tester
instrumentation_tester.reset()
105 changes: 105 additions & 0 deletions tests/telemetry/instrumentation_tester.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import Any, Dict, Protocol, Tuple, Union

from opentelemetry import metrics as metrics_api
from opentelemetry import trace as trace_api
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader
from opentelemetry.sdk.trace import ReadableSpan, Span, TracerProvider, export
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.test.globals_test import (
reset_metrics_globals,
reset_trace_globals,
)
from opentelemetry.util.types import Attributes


def create_tracer_provider(**kwargs) -> Tuple[TracerProvider, InMemorySpanExporter]:
"""Helper to create a configured tracer provider.
Creates and configures a `TracerProvider` with a
`SimpleSpanProcessor` and a `InMemorySpanExporter`.
All the parameters passed are forwarded to the TracerProvider
constructor.
Returns:
A list with the tracer provider in the first element and the
in-memory span exporter in the second.
"""
tracer_provider = TracerProvider(**kwargs)
memory_exporter = InMemorySpanExporter()
span_processor = export.SimpleSpanProcessor(memory_exporter)
tracer_provider.add_span_processor(span_processor)

return tracer_provider, memory_exporter


def create_meter_provider(**kwargs) -> Tuple[MeterProvider, InMemoryMetricReader]:
"""Helper to create a configured meter provider
Creates a `MeterProvider` and an `InMemoryMetricReader`.
Returns:
A tuple with the meter provider in the first element and the
in-memory metrics exporter in the second
"""
memory_reader = InMemoryMetricReader()
metric_readers = kwargs.get("metric_readers", [])
metric_readers.append(memory_reader)
kwargs["metric_readers"] = metric_readers
meter_provider = MeterProvider(**kwargs)
return meter_provider, memory_reader


class HasAttributesViaProperty(Protocol):
@property
def attributes(self) -> Attributes:
...


class HasAttributesViaAttr(Protocol):
attributes: Attributes


HasAttributes = Union[HasAttributesViaProperty, HasAttributesViaAttr]


class InstrumentationTester:
tracer_provider: TracerProvider
memory_exporter: InMemorySpanExporter
meter_provider: MeterProvider
memory_metrics_reader: InMemoryMetricReader

def __init__(self):
self.tracer_provider, self.memory_exporter = create_tracer_provider()
# This is done because set_tracer_provider cannot override the
# current tracer provider.
reset_trace_globals()
trace_api.set_tracer_provider(self.tracer_provider)

self.memory_exporter.clear()
# This is done because set_meter_provider cannot override the
# current meter provider.
reset_metrics_globals()

self.meter_provider, self.memory_metrics_reader = create_meter_provider()
metrics_api.set_meter_provider(self.meter_provider)

def reset(self):
reset_trace_globals()
reset_metrics_globals()

def get_finished_spans(self):
return self.memory_exporter.get_finished_spans()

@staticmethod
def assert_has_attributes(obj: HasAttributes, attributes: Dict[str, Any]):
assert obj.attributes is not None
for key, val in attributes.items():
assert key in obj.attributes
assert obj.attributes[key] == val

@staticmethod
def assert_span_instrumented_for(span: Union[Span, ReadableSpan], module):
assert span.instrumentation_scope is not None
assert span.instrumentation_scope.name == module.__name__
assert span.instrumentation_scope.version == module.__version__
Loading

0 comments on commit 377505b

Please sign in to comment.