diff --git a/src/intelligence_layer/core/task.py b/src/intelligence_layer/core/task.py index 401d5e69d..a0f6636da 100644 --- a/src/intelligence_layer/core/task.py +++ b/src/intelligence_layer/core/task.py @@ -74,9 +74,7 @@ def run( Returns: Generic output defined by the task implementation. """ - with tracer.task_span( - type(self).__name__, input, trace_id=trace_id - ) as task_span: + with tracer.task_span(type(self).__name__, input) as task_span: output = self.do_run(input, task_span) task_span.record_output(output) return output diff --git a/src/intelligence_layer/core/tracer/in_memory_tracer.py b/src/intelligence_layer/core/tracer/in_memory_tracer.py index 09a8d23d0..9ba9882c2 100644 --- a/src/intelligence_layer/core/tracer/in_memory_tracer.py +++ b/src/intelligence_layer/core/tracer/in_memory_tracer.py @@ -1,38 +1,31 @@ -import json import os from datetime import datetime from typing import Optional, Sequence, Union -from uuid import UUID import requests import rich -from pydantic import BaseModel, Field, SerializeAsAny +from pydantic import BaseModel, SerializeAsAny from requests import HTTPError from rich.tree import Tree from intelligence_layer.core.tracer.tracer import ( Context, - EndSpan, - EndTask, Event, ExportedSpan, + ExportedSpanList, LogEntry, - LogLine, - PlainEntry, PydanticSerializable, Span, SpanAttributes, - SpanStatus, - StartSpan, - StartTask, TaskSpan, + TaskSpanAttributes, Tracer, _render_log_value, utc_now, ) -class InMemoryTracer(BaseModel, Tracer): +class InMemoryTracer(Tracer): """Collects log entries in a nested structure, and keeps them in memory. If desired, the structure is serializable with Pydantic, so you can write out the JSON @@ -44,7 +37,9 @@ class InMemoryTracer(BaseModel, Tracer): log entries. """ - entries: list[Union[LogEntry, "InMemoryTaskSpan", "InMemorySpan"]] = [] + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.entries: list[Union[LogEntry, "InMemoryTaskSpan", "InMemorySpan"]] = [] def span( self, @@ -95,7 +90,8 @@ def submit_to_trace_viewer(self) -> bool: trace_viewer_trace_upload = f"{trace_viewer_url}/trace" try: res = requests.post( - trace_viewer_trace_upload, json=json.loads(self.model_dump_json()) + trace_viewer_trace_upload, + json=ExportedSpanList(self.export_for_viewing()).model_dump_json(), ) if res.status_code != 200: raise HTTPError(res.status_code) @@ -122,12 +118,19 @@ def export_for_viewing(self) -> Sequence[ExportedSpan]: class InMemorySpan(InMemoryTracer, Span): - name: str - start_timestamp: datetime = Field(default_factory=datetime.utcnow) - end_timestamp: Optional[datetime] = None - - def id(self) -> str: - return self.trace_id + def __init__( + self, + name: str, + context: Optional[Context] = None, + start_timestamp: Optional[datetime] = None, + ) -> None: + super().__init__(context=context) + self.parent_id = None if context is None else context.span_id + self.name = name + self.start_timestamp = ( + start_timestamp if start_timestamp is not None else utc_now() + ) + self.end_timestamp: datetime | None = None def log( self, @@ -140,7 +143,7 @@ def log( message=message, value=value, timestamp=timestamp or utc_now(), - trace_id=self.id(), + trace_id=self.context.span_id, ) ) @@ -157,6 +160,9 @@ def _rich_render_(self) -> Tree: return tree + def _span_attributes(self) -> SpanAttributes: + return SpanAttributes() + def export_for_viewing(self) -> Sequence[ExportedSpan]: logs: list[LogEntry] = [] exported_spans: list[ExportedSpan] = [] @@ -167,12 +173,12 @@ def export_for_viewing(self) -> Sequence[ExportedSpan]: exported_spans.extend(entry.export_for_viewing()) exported_spans.append( ExportedSpan( - context=Context(trace_id=self.id(), span_id="?"), + context=self.context, name=self.name, parent_id=self.parent_id, start_time=self.start_timestamp, end_time=self.end_timestamp, - attributes=SpanAttributes(), + attributes=self._span_attributes(), events=[ Event( name="log", @@ -182,19 +188,30 @@ def export_for_viewing(self) -> Sequence[ExportedSpan]: ) for log in logs ], - status=SpanStatus.OK, + status=self.status_code, ) ) return exported_spans class InMemoryTaskSpan(InMemorySpan, TaskSpan): - input: SerializeAsAny[PydanticSerializable] - output: Optional[SerializeAsAny[PydanticSerializable]] = None + def __init__( + self, + name: str, + input: SerializeAsAny[PydanticSerializable], + context: Optional[Context] = None, + start_timestamp: Optional[datetime] = None, + ) -> None: + super().__init__(name=name, context=context, start_timestamp=start_timestamp) + self.input = input + self.output: SerializeAsAny[PydanticSerializable] | None = None def record_output(self, output: PydanticSerializable) -> None: self.output = output + def _span_attributes(self) -> SpanAttributes: + return TaskSpanAttributes(input=self.input, output=self.output) + def _rich_render_(self) -> Tree: """Renders the trace via classes in the `rich` package""" tree = Tree(label=self.name) @@ -210,56 +227,52 @@ def _rich_render_(self) -> Tree: class TreeBuilder(BaseModel): - root: InMemoryTracer = InMemoryTracer() - tracers: dict[UUID, InMemoryTracer] = Field(default_factory=dict) - tasks: dict[UUID, InMemoryTaskSpan] = Field(default_factory=dict) - spans: dict[UUID, InMemorySpan] = Field(default_factory=dict) - - def start_task(self, log_line: LogLine) -> None: - start_task = StartTask.model_validate(log_line.entry) - child = InMemoryTaskSpan( - name=start_task.name, - input=start_task.input, - start_timestamp=start_task.start, - trace_id=start_task.trace_id, - ) - self.tracers[start_task.uuid] = child - self.tasks[start_task.uuid] = child - self.tracers.get(start_task.parent, self.root).entries.append(child) - - def end_task(self, log_line: LogLine) -> None: - end_task = EndTask.model_validate(log_line.entry) - task_span = self.tasks[end_task.uuid] - task_span.end_timestamp = end_task.end - task_span.record_output(end_task.output) - - def start_span(self, log_line: LogLine) -> None: - start_span = StartSpan.model_validate(log_line.entry) - child = InMemorySpan( - name=start_span.name, - start_timestamp=start_span.start, - trace_id=start_span.trace_id, - ) - self.tracers[start_span.uuid] = child - self.spans[start_span.uuid] = child - self.tracers.get(start_span.parent, self.root).entries.append(child) - - def end_span(self, log_line: LogLine) -> None: - end_span = EndSpan.model_validate(log_line.entry) - span = self.spans[end_span.uuid] - span.end_timestamp = end_span.end - - def plain_entry(self, log_line: LogLine) -> None: - plain_entry = PlainEntry.model_validate(log_line.entry) - entry = LogEntry( - message=plain_entry.message, - value=plain_entry.value, - timestamp=plain_entry.timestamp, - trace_id=plain_entry.trace_id, - ) - self.tracers[plain_entry.parent].entries.append(entry) - - -# Required for sphinx, see also: https://docs.pydantic.dev/2.4/errors/usage_errors/#class-not-fully-defined -InMemorySpan.model_rebuild() -InMemoryTracer.model_rebuild() + pass + # root: InMemoryTracer = InMemoryTracer() + # tracers: dict[UUID, InMemoryTracer] = Field(default_factory=dict) + # tasks: dict[UUID, InMemoryTaskSpan] = Field(default_factory=dict) + # spans: dict[UUID, InMemorySpan] = Field(default_factory=dict) + + # def start_task(self, log_line: LogLine) -> None: + # start_task = StartTask.model_validate(log_line.entry) + # child = InMemoryTaskSpan( + # name=start_task.name, + # input=start_task.input, + # start_timestamp=start_task.start, + # trace_id=start_task.trace_id, + # ) + # self.tracers[start_task.uuid] = child + # self.tasks[start_task.uuid] = child + # self.tracers.get(start_task.parent, self.root).entries.append(child) + + # def end_task(self, log_line: LogLine) -> None: + # end_task = EndTask.model_validate(log_line.entry) + # task_span = self.tasks[end_task.uuid] + # task_span.end_timestamp = end_task.end + # task_span.record_output(end_task.output) + + # def start_span(self, log_line: LogLine) -> None: + # start_span = StartSpan.model_validate(log_line.entry) + # child = InMemorySpan( + # name=start_span.name, + # start_timestamp=start_span.start, + # trace_id=start_span.trace_id, + # ) + # self.tracers[start_span.uuid] = child + # self.spans[start_span.uuid] = child + # self.tracers.get(start_span.parent, self.root).entries.append(child) + + # def end_span(self, log_line: LogLine) -> None: + # end_span = EndSpan.model_validate(log_line.entry) + # span = self.spans[end_span.uuid] + # span.end_timestamp = end_span.end + + # def plain_entry(self, log_line: LogLine) -> None: + # plain_entry = PlainEntry.model_validate(log_line.entry) + # entry = LogEntry( + # message=plain_entry.message, + # value=plain_entry.value, + # timestamp=plain_entry.timestamp, + # trace_id=plain_entry.trace_id, + # ) + # self.tracers[plain_entry.parent].entries.append(entry) diff --git a/src/intelligence_layer/core/tracer/tracer.py b/src/intelligence_layer/core/tracer/tracer.py index 764389fad..f859296a9 100644 --- a/src/intelligence_layer/core/tracer/tracer.py +++ b/src/intelligence_layer/core/tracer/tracer.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Mapping, Optional, Sequence +from typing import TYPE_CHECKING, List, Mapping, Optional, Sequence from uuid import UUID, uuid4 from pydantic import BaseModel, Field, RootModel, SerializeAsAny @@ -51,7 +51,7 @@ def utc_now() -> datetime: return datetime.now(timezone.utc) -class Event: +class Event(BaseModel): name: str message: str body: SerializeAsAny[PydanticSerializable] @@ -83,7 +83,7 @@ class Context(BaseModel): span_id: str -class ExportedSpan: +class ExportedSpan(BaseModel): context: Context name: str | None parent_id: str | None @@ -95,6 +95,10 @@ class ExportedSpan: # we ignore the links concept +class ExportedSpanList(RootModel): + root: List[ExportedSpan] + + class Tracer(ABC): """Provides a consistent way to instrument a :class:`Task` with logging for each step of the workflow. @@ -200,12 +204,14 @@ class Span(Tracer, AbstractContextManager["Span"]): """ def __init__(self, context: Optional[Context] = None): + super().__init__() + span_id = str(uuid4()) if context is None: - trace_id = str(uuid4()) + trace_id = span_id else: - trace_id = self.context.trace_id - span_id = str(uuid4()) + trace_id = context.trace_id self.context = Context(trace_id=trace_id, span_id=span_id) + self.status_code = SpanStatus.OK def __enter__(self) -> Self: return self @@ -264,6 +270,7 @@ def __exit__( stack_trace=str(traceback.format_exc()), ) self.log(error_value.message, error_value) + self.status_code = SpanStatus.ERROR self.end() @@ -329,7 +336,6 @@ def span( self, name: str, timestamp: Optional[datetime] = None, - trace_id: Optional[str] = None, ) -> "NoOpTracer": return self @@ -338,7 +344,6 @@ def task_span( task_name: str, input: PydanticSerializable, timestamp: Optional[datetime] = None, - trace_id: Optional[str] = None, ) -> "NoOpTracer": return self @@ -383,9 +388,6 @@ class LogEntry(BaseModel): timestamp: datetime = Field(default_factory=datetime.utcnow) trace_id: str - def id(self) -> str: - return self.trace_id - def _rich_render_(self) -> Panel: """Renders the trace via classes in the `rich` package""" return _render_log_value(self.value, self.message) diff --git a/tests/core/tracer/test_composite_tracer.py b/tests/core/tracer/test_composite_tracer.py index 34ce47346..4a33de925 100644 --- a/tests/core/tracer/test_composite_tracer.py +++ b/tests/core/tracer/test_composite_tracer.py @@ -7,6 +7,14 @@ ) +def test_composite_tracer(test_task: Task[str, str]) -> None: + tracer1 = InMemoryTracer() + tracer2 = InMemoryTracer() + test_task.run(input=input, tracer=CompositeTracer([tracer1, tracer2])) + + assert tracer1 == tracer2 + + def test_composite_tracer_id_consistent_across_children( file_tracer: FileTracer, test_task: Task[str, str] ) -> None: diff --git a/tests/core/tracer/test_in_memory_tracer.py b/tests/core/tracer/test_in_memory_tracer.py index 5f15428b7..afebd2760 100644 --- a/tests/core/tracer/test_in_memory_tracer.py +++ b/tests/core/tracer/test_in_memory_tracer.py @@ -3,12 +3,8 @@ from typing import Iterator import pytest -from aleph_alpha_client import Prompt from intelligence_layer.core import ( - CompleteInput, - CompleteOutput, - CompositeTracer, InMemorySpan, InMemoryTaskSpan, InMemoryTracer, @@ -19,31 +15,40 @@ from intelligence_layer.core.tracer.tracer import ErrorValue -def test_tracer_id_exists_for_all_children_of_task_span() -> None: +def test_trace_id_exists_for_all_children_of_task_span() -> None: tracer = InMemoryTracer() - parent_span = tracer.task_span("child", "input", trace_id="ID") + parent_span = tracer.task_span("child", "input") parent_span.span("child2") assert isinstance(tracer.entries[0], InMemorySpan) - assert tracer.entries[0].id() == "ID" - - assert tracer.entries[0].entries[0].id() == tracer.entries[0].id() + assert ( + tracer.entries[0].entries[0].context.trace_id + == tracer.entries[0].context.trace_id + ) parent_span.task_span("child3", "input") - assert tracer.entries[0].entries[1].id() == tracer.entries[0].id() + assert ( + tracer.entries[0].entries[1].context.trace_id + == tracer.entries[0].context.trace_id + ) -def test_tracer_id_exists_for_all_children_of_span() -> None: +def test_trace_id_exists_for_all_children_of_span() -> None: tracer = InMemoryTracer() - parent_span = tracer.span("child", trace_id="ID") + parent_span = tracer.span("child") parent_span.span("child2") assert isinstance(tracer.entries[0], InMemorySpan) - assert tracer.entries[0].id() == "ID" - assert tracer.entries[0].entries[0].id() == tracer.entries[0].id() + assert ( + tracer.entries[0].entries[0].context.trace_id + == tracer.entries[0].context.trace_id + ) parent_span.task_span("child3", "input") - assert tracer.entries[0].entries[1].id() == tracer.entries[0].id() + assert ( + tracer.entries[0].entries[1].context.trace_id + == tracer.entries[0].context.trace_id + ) def test_can_add_child_tracer() -> None: @@ -99,16 +104,15 @@ def test_task_span_records_error_value() -> None: def test_task_automatically_logs_input_and_output( - complete: Task[CompleteInput, CompleteOutput], + test_task: Task[str, str], ) -> None: tracer = InMemoryTracer() - input = CompleteInput(prompt=Prompt.from_text("test")) - output = complete.run(input=input, tracer=tracer) + output = test_task.run(input=input, tracer=tracer) assert len(tracer.entries) == 1 task_span = tracer.entries[0] assert isinstance(task_span, InMemoryTaskSpan) - assert task_span.name == type(complete).__name__ + assert task_span.name == type(test_task).__name__ assert task_span.input == input assert task_span.output == output assert task_span.start_timestamp and task_span.end_timestamp @@ -157,15 +161,6 @@ def test_span_only_updates_end_timestamp_once() -> None: assert span.end_timestamp == end -def test_composite_tracer(complete: Task[CompleteInput, CompleteOutput]) -> None: - tracer1 = InMemoryTracer() - tracer2 = InMemoryTracer() - input = CompleteInput(prompt=Prompt.from_text("test")) - complete.run(input=input, tracer=CompositeTracer([tracer1, tracer2])) - - assert tracer1 == tracer2 - - # take from and modified: https://stackoverflow.com/questions/2059482/temporarily-modify-the-current-processs-environment @contextlib.contextmanager def set_env(name: str, value: str | None) -> Iterator[None]: @@ -182,7 +177,9 @@ def set_env(name: str, value: str | None) -> Iterator[None]: os.environ.update(old_environ) -def test_in_memory_tracer_trace_viewer_doesnt_crash_if_it_cant_reach() -> None: +def test_in_memory_tracer_trace_viewer_doesnt_crash_if_it_cant_reach_document_index() -> ( + None +): # note that this test sets the environment variable, which might # become a problem with multi-worker tests ENV_VARIABLE_NAME = "TRACE_VIEWER_URL" diff --git a/tests/core/tracer/test_tracer.py b/tests/core/tracer/test_tracer.py index 42563a1f0..2de0b794c 100644 --- a/tests/core/tracer/test_tracer.py +++ b/tests/core/tracer/test_tracer.py @@ -123,11 +123,12 @@ def test_tracer_export_nests_correctly( parent, child = child, parent assert parent.name == "name" assert parent.parent_id is None - assert parent.end_time > child.end_time - assert parent.start_time < child.start_time + assert parent.end_time >= child.end_time + assert parent.start_time <= child.start_time assert child.name == "name-2" - assert child.parent_id == parent.id - assert len(child.events) == 0 + assert child.parent_id == parent.context.span_id + assert len(child.events) == 1 + assert len(parent.events) == 0 assert child.context.trace_id == parent.context.trace_id assert child.context.span_id != parent.context.span_id