Skip to content

Commit

Permalink
fix: in memory tracer export tests
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasKoehneckeAA committed May 16, 2024
1 parent 6c8af65 commit 6e3ca61
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 125 deletions.
4 changes: 1 addition & 3 deletions src/intelligence_layer/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def run(
Returns:
Generic output defined by the task implementation.
"""
with tracer.task_span(
type(self).__name__, input, trace_id=trace_id
) as task_span:
with tracer.task_span(type(self).__name__, input) as task_span:
output = self.do_run(input, task_span)
task_span.record_output(output)
return output
Expand Down
169 changes: 91 additions & 78 deletions src/intelligence_layer/core/tracer/in_memory_tracer.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,31 @@
import json
import os
from datetime import datetime
from typing import Optional, Sequence, Union
from uuid import UUID

import requests
import rich
from pydantic import BaseModel, Field, SerializeAsAny
from pydantic import BaseModel, SerializeAsAny
from requests import HTTPError
from rich.tree import Tree

from intelligence_layer.core.tracer.tracer import (
Context,
EndSpan,
EndTask,
Event,
ExportedSpan,
ExportedSpanList,
LogEntry,
LogLine,
PlainEntry,
PydanticSerializable,
Span,
SpanAttributes,
SpanStatus,
StartSpan,
StartTask,
TaskSpan,
TaskSpanAttributes,
Tracer,
_render_log_value,
utc_now,
)


class InMemoryTracer(BaseModel, Tracer):
class InMemoryTracer(Tracer):
"""Collects log entries in a nested structure, and keeps them in memory.
If desired, the structure is serializable with Pydantic, so you can write out the JSON
Expand All @@ -44,7 +37,9 @@ class InMemoryTracer(BaseModel, Tracer):
log entries.
"""

entries: list[Union[LogEntry, "InMemoryTaskSpan", "InMemorySpan"]] = []
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.entries: list[Union[LogEntry, "InMemoryTaskSpan", "InMemorySpan"]] = []

def span(
self,
Expand Down Expand Up @@ -95,7 +90,8 @@ def submit_to_trace_viewer(self) -> bool:
trace_viewer_trace_upload = f"{trace_viewer_url}/trace"
try:
res = requests.post(
trace_viewer_trace_upload, json=json.loads(self.model_dump_json())
trace_viewer_trace_upload,
json=ExportedSpanList(self.export_for_viewing()).model_dump_json(),
)
if res.status_code != 200:
raise HTTPError(res.status_code)
Expand All @@ -122,12 +118,19 @@ def export_for_viewing(self) -> Sequence[ExportedSpan]:


class InMemorySpan(InMemoryTracer, Span):
name: str
start_timestamp: datetime = Field(default_factory=datetime.utcnow)
end_timestamp: Optional[datetime] = None

def id(self) -> str:
return self.trace_id
def __init__(
self,
name: str,
context: Optional[Context] = None,
start_timestamp: Optional[datetime] = None,
) -> None:
super().__init__(context=context)
self.parent_id = None if context is None else context.span_id
self.name = name
self.start_timestamp = (
start_timestamp if start_timestamp is not None else utc_now()
)
self.end_timestamp: datetime | None = None

def log(
self,
Expand All @@ -140,7 +143,7 @@ def log(
message=message,
value=value,
timestamp=timestamp or utc_now(),
trace_id=self.id(),
trace_id=self.context.span_id,
)
)

Expand All @@ -157,6 +160,9 @@ def _rich_render_(self) -> Tree:

return tree

def _span_attributes(self) -> SpanAttributes:
return SpanAttributes()

def export_for_viewing(self) -> Sequence[ExportedSpan]:
logs: list[LogEntry] = []
exported_spans: list[ExportedSpan] = []
Expand All @@ -167,12 +173,12 @@ def export_for_viewing(self) -> Sequence[ExportedSpan]:
exported_spans.extend(entry.export_for_viewing())
exported_spans.append(
ExportedSpan(
context=Context(trace_id=self.id(), span_id="?"),
context=self.context,
name=self.name,
parent_id=self.parent_id,
start_time=self.start_timestamp,
end_time=self.end_timestamp,
attributes=SpanAttributes(),
attributes=self._span_attributes(),
events=[
Event(
name="log",
Expand All @@ -182,19 +188,30 @@ def export_for_viewing(self) -> Sequence[ExportedSpan]:
)
for log in logs
],
status=SpanStatus.OK,
status=self.status_code,
)
)
return exported_spans


class InMemoryTaskSpan(InMemorySpan, TaskSpan):
input: SerializeAsAny[PydanticSerializable]
output: Optional[SerializeAsAny[PydanticSerializable]] = None
def __init__(
self,
name: str,
input: SerializeAsAny[PydanticSerializable],
context: Optional[Context] = None,
start_timestamp: Optional[datetime] = None,
) -> None:
super().__init__(name=name, context=context, start_timestamp=start_timestamp)
self.input = input
self.output: SerializeAsAny[PydanticSerializable] | None = None

def record_output(self, output: PydanticSerializable) -> None:
self.output = output

def _span_attributes(self) -> SpanAttributes:
return TaskSpanAttributes(input=self.input, output=self.output)

def _rich_render_(self) -> Tree:
"""Renders the trace via classes in the `rich` package"""
tree = Tree(label=self.name)
Expand All @@ -210,56 +227,52 @@ def _rich_render_(self) -> Tree:


class TreeBuilder(BaseModel):
root: InMemoryTracer = InMemoryTracer()
tracers: dict[UUID, InMemoryTracer] = Field(default_factory=dict)
tasks: dict[UUID, InMemoryTaskSpan] = Field(default_factory=dict)
spans: dict[UUID, InMemorySpan] = Field(default_factory=dict)

def start_task(self, log_line: LogLine) -> None:
start_task = StartTask.model_validate(log_line.entry)
child = InMemoryTaskSpan(
name=start_task.name,
input=start_task.input,
start_timestamp=start_task.start,
trace_id=start_task.trace_id,
)
self.tracers[start_task.uuid] = child
self.tasks[start_task.uuid] = child
self.tracers.get(start_task.parent, self.root).entries.append(child)

def end_task(self, log_line: LogLine) -> None:
end_task = EndTask.model_validate(log_line.entry)
task_span = self.tasks[end_task.uuid]
task_span.end_timestamp = end_task.end
task_span.record_output(end_task.output)

def start_span(self, log_line: LogLine) -> None:
start_span = StartSpan.model_validate(log_line.entry)
child = InMemorySpan(
name=start_span.name,
start_timestamp=start_span.start,
trace_id=start_span.trace_id,
)
self.tracers[start_span.uuid] = child
self.spans[start_span.uuid] = child
self.tracers.get(start_span.parent, self.root).entries.append(child)

def end_span(self, log_line: LogLine) -> None:
end_span = EndSpan.model_validate(log_line.entry)
span = self.spans[end_span.uuid]
span.end_timestamp = end_span.end

def plain_entry(self, log_line: LogLine) -> None:
plain_entry = PlainEntry.model_validate(log_line.entry)
entry = LogEntry(
message=plain_entry.message,
value=plain_entry.value,
timestamp=plain_entry.timestamp,
trace_id=plain_entry.trace_id,
)
self.tracers[plain_entry.parent].entries.append(entry)


# Required for sphinx, see also: https://docs.pydantic.dev/2.4/errors/usage_errors/#class-not-fully-defined
InMemorySpan.model_rebuild()
InMemoryTracer.model_rebuild()
pass
# root: InMemoryTracer = InMemoryTracer()
# tracers: dict[UUID, InMemoryTracer] = Field(default_factory=dict)
# tasks: dict[UUID, InMemoryTaskSpan] = Field(default_factory=dict)
# spans: dict[UUID, InMemorySpan] = Field(default_factory=dict)

# def start_task(self, log_line: LogLine) -> None:
# start_task = StartTask.model_validate(log_line.entry)
# child = InMemoryTaskSpan(
# name=start_task.name,
# input=start_task.input,
# start_timestamp=start_task.start,
# trace_id=start_task.trace_id,
# )
# self.tracers[start_task.uuid] = child
# self.tasks[start_task.uuid] = child
# self.tracers.get(start_task.parent, self.root).entries.append(child)

# def end_task(self, log_line: LogLine) -> None:
# end_task = EndTask.model_validate(log_line.entry)
# task_span = self.tasks[end_task.uuid]
# task_span.end_timestamp = end_task.end
# task_span.record_output(end_task.output)

# def start_span(self, log_line: LogLine) -> None:
# start_span = StartSpan.model_validate(log_line.entry)
# child = InMemorySpan(
# name=start_span.name,
# start_timestamp=start_span.start,
# trace_id=start_span.trace_id,
# )
# self.tracers[start_span.uuid] = child
# self.spans[start_span.uuid] = child
# self.tracers.get(start_span.parent, self.root).entries.append(child)

# def end_span(self, log_line: LogLine) -> None:
# end_span = EndSpan.model_validate(log_line.entry)
# span = self.spans[end_span.uuid]
# span.end_timestamp = end_span.end

# def plain_entry(self, log_line: LogLine) -> None:
# plain_entry = PlainEntry.model_validate(log_line.entry)
# entry = LogEntry(
# message=plain_entry.message,
# value=plain_entry.value,
# timestamp=plain_entry.timestamp,
# trace_id=plain_entry.trace_id,
# )
# self.tracers[plain_entry.parent].entries.append(entry)
24 changes: 13 additions & 11 deletions src/intelligence_layer/core/tracer/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime, timezone
from enum import Enum
from types import TracebackType
from typing import TYPE_CHECKING, Mapping, Optional, Sequence
from typing import TYPE_CHECKING, List, Mapping, Optional, Sequence
from uuid import UUID, uuid4

from pydantic import BaseModel, Field, RootModel, SerializeAsAny
Expand Down Expand Up @@ -51,7 +51,7 @@ def utc_now() -> datetime:
return datetime.now(timezone.utc)


class Event:
class Event(BaseModel):
name: str
message: str
body: SerializeAsAny[PydanticSerializable]
Expand Down Expand Up @@ -83,7 +83,7 @@ class Context(BaseModel):
span_id: str


class ExportedSpan:
class ExportedSpan(BaseModel):
context: Context
name: str | None
parent_id: str | None
Expand All @@ -95,6 +95,10 @@ class ExportedSpan:
# we ignore the links concept


class ExportedSpanList(RootModel):
root: List[ExportedSpan]


class Tracer(ABC):
"""Provides a consistent way to instrument a :class:`Task` with logging for each step of the
workflow.
Expand Down Expand Up @@ -200,12 +204,14 @@ class Span(Tracer, AbstractContextManager["Span"]):
"""

def __init__(self, context: Optional[Context] = None):
super().__init__()
span_id = str(uuid4())
if context is None:
trace_id = str(uuid4())
trace_id = span_id
else:
trace_id = self.context.trace_id
span_id = str(uuid4())
trace_id = context.trace_id
self.context = Context(trace_id=trace_id, span_id=span_id)
self.status_code = SpanStatus.OK

def __enter__(self) -> Self:
return self
Expand Down Expand Up @@ -264,6 +270,7 @@ def __exit__(
stack_trace=str(traceback.format_exc()),
)
self.log(error_value.message, error_value)
self.status_code = SpanStatus.ERROR
self.end()


Expand Down Expand Up @@ -329,7 +336,6 @@ def span(
self,
name: str,
timestamp: Optional[datetime] = None,
trace_id: Optional[str] = None,
) -> "NoOpTracer":
return self

Expand All @@ -338,7 +344,6 @@ def task_span(
task_name: str,
input: PydanticSerializable,
timestamp: Optional[datetime] = None,
trace_id: Optional[str] = None,
) -> "NoOpTracer":
return self

Expand Down Expand Up @@ -383,9 +388,6 @@ class LogEntry(BaseModel):
timestamp: datetime = Field(default_factory=datetime.utcnow)
trace_id: str

def id(self) -> str:
return self.trace_id

def _rich_render_(self) -> Panel:
"""Renders the trace via classes in the `rich` package"""
return _render_log_value(self.value, self.message)
Expand Down
8 changes: 8 additions & 0 deletions tests/core/tracer/test_composite_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
)


def test_composite_tracer(test_task: Task[str, str]) -> None:
tracer1 = InMemoryTracer()
tracer2 = InMemoryTracer()
test_task.run(input=input, tracer=CompositeTracer([tracer1, tracer2]))

assert tracer1 == tracer2


def test_composite_tracer_id_consistent_across_children(
file_tracer: FileTracer, test_task: Task[str, str]
) -> None:
Expand Down
Loading

0 comments on commit 6e3ca61

Please sign in to comment.