addressing comments:
1. clean up prompt calculations to be deduced from last streaming chunk
2. save correct span name
3. remove recording exceptions and setting status to ok
4. remove saving stream chunks in events
alizenhom committed Aug 13, 2024
1 parent e15d443 commit 1efdfcd
Showing 1 changed file with 33 additions and 95 deletions.
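Change 1 relies on the OpenAI API reporting real token usage on the final streaming chunk, instead of estimating tokens client-side. A minimal sketch of that behavior — not part of this commit, assuming a recent OpenAI Python SDK that supports stream_options (the model name is a hypothetical example):

    from openai import OpenAI

    client = OpenAI()
    stream = client.chat.completions.create(
        model="gpt-4o-mini",  # hypothetical example model
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},  # ask the API to attach usage to the last chunk
    )

    prompt_tokens = completion_tokens = None
    for chunk in stream:
        # Every chunk carries usage=None except the final one.
        if getattr(chunk, "usage", None):
            prompt_tokens = chunk.usage.prompt_tokens
            completion_tokens = chunk.usage.completion_tokens

This is why StreamWrapper in the diff below can drop its calculate_prompt_tokens / estimate_tokens calls and simply read chunk.usage in process_chunk.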
@@ -13,21 +13,19 @@
 # limitations under the License.
 
 import json
+from typing import Optional, Union
 from opentelemetry import trace
 from opentelemetry.trace import SpanKind, Span
 from opentelemetry.trace.status import Status, StatusCode
 from opentelemetry.trace.propagation import set_span_in_context
-from openai._types import NOT_GIVEN
-from span_attributes import SpanAttributes, LLMSpanAttributes, Event
-from utils import (
-    estimate_tokens,
-    silently_fail,
-    extract_content,
-    calculate_prompt_tokens,
-)
+from openai import NOT_GIVEN
+from .span_attributes import LLMSpanAttributes, SpanAttributes
 
+from .utils import silently_fail, extract_content
+from opentelemetry.trace import Tracer
 
-def chat_completions_create(original_method, version, tracer):
+
+def chat_completions_create(original_method, version, tracer: Tracer):
     """Wrap the `create` method of the `ChatCompletion` class to trace it."""
 
     def traced_method(wrapped, instance, args, kwargs):
@@ -69,8 +67,10 @@ def traced_method(wrapped, instance, args, kwargs):
 
         attributes = LLMSpanAttributes(**span_attributes)
 
+        span_name = f"{attributes.gen_ai_operation_name} {attributes.gen_ai_request_model}"
+
         span = tracer.start_span(
-            "openai.completion",
+            name=span_name,
             kind=SpanKind.CLIENT,
             context=set_span_in_context(trace.get_current_span()),
         )
@@ -79,36 +79,18 @@ def traced_method(wrapped, instance, args, kwargs):
         try:
             result = wrapped(*args, **kwargs)
             if is_streaming(kwargs):
-                prompt_tokens = 0
-                for message in kwargs.get("messages", {}):
-                    prompt_tokens += calculate_prompt_tokens(
-                        json.dumps(str(message)), kwargs.get("model")
-                    )
-
-                if (
-                    kwargs.get("functions") is not None
-                    and kwargs.get("functions") != NOT_GIVEN
-                ):
-                    for function in kwargs.get("functions"):
-                        prompt_tokens += calculate_prompt_tokens(
-                            json.dumps(function), kwargs.get("model")
-                        )
-
                 return StreamWrapper(
                     result,
                     span,
-                    prompt_tokens,
                     function_call=kwargs.get("functions") is not None,
                     tool_calls=kwargs.get("tools") is not None,
                 )
             else:
                 _set_response_attributes(span, kwargs, result)
-                span.set_status(StatusCode.OK)
                 span.end()
                 return result
 
         except Exception as error:
-            span.record_exception(error)
             span.set_status(Status(StatusCode.ERROR, str(error)))
             span.end()
             raise
@@ -118,21 +100,14 @@ def traced_method(wrapped, instance, args, kwargs):
 
 def get_tool_calls(item):
     if isinstance(item, dict):
-        if "tool_calls" in item and item["tool_calls"] is not None:
-            return item["tool_calls"]
-        return None
-
+        return item.get("tool_calls")
     else:
-        if hasattr(item, "tool_calls") and item.tool_calls is not None:
-            return item.tool_calls
-        return None
+        return getattr(item, "tool_calls", None)
 
 
 @silently_fail
-def _set_input_attributes(span, kwargs, attributes):
+def _set_input_attributes(span, kwargs, attributes: LLMSpanAttributes):
     tools = []
-    for field, value in attributes.model_dump(by_alias=True).items():
-        set_span_attribute(span, field, value)
 
     if (
         kwargs.get("functions") is not None
@@ -149,6 +124,9 @@ def _set_input_attributes(span, kwargs, attributes):
     if tools:
         set_span_attribute(span, SpanAttributes.LLM_TOOLS, json.dumps(tools))
 
+    for field, value in attributes.model_dump(by_alias=True).items():
+        set_span_attribute(span, field, value)
+
 
 @silently_fail
 def _set_response_attributes(span, kwargs, result):
@@ -230,15 +208,6 @@ def set_event_completion(span: Span, result_content):
     )
 
 
-def set_event_completion_chunk(span: Span, chunk):
-    span.add_event(
-        name=SpanAttributes.LLM_CONTENT_COMPLETION_CHUNK,
-        attributes={
-            SpanAttributes.LLM_CONTENT_COMPLETION_CHUNK: json.dumps(chunk),
-        },
-    )
-
-
 def set_span_attribute(span: Span, name, value):
     if value is not None:
         if value != "" or value != NOT_GIVEN:
@@ -250,33 +219,33 @@ def set_span_attribute(span: Span, name, value):
 
 
 def is_streaming(kwargs):
-    return not (
-        kwargs.get("stream") is False
-        or kwargs.get("stream") is None
-        or kwargs.get("stream") == NOT_GIVEN
-    )
+    return non_numerical_value_is_set(kwargs.get("stream"))
 
 
+def non_numerical_value_is_set(value: Optional[Union[bool, str]]):
+    return bool(value) and value != NOT_GIVEN
+
+
 def get_llm_request_attributes(
     kwargs, prompts=None, model=None, operation_name="chat"
 ):
 
-    user = kwargs.get("user", None)
+    user = kwargs.get("user")
     if prompts is None:
         prompts = (
             [{"role": user or "user", "content": kwargs.get("prompt")}]
             if "prompt" in kwargs
             else None
         )
     top_k = (
-        kwargs.get("n", None)
-        or kwargs.get("k", None)
-        or kwargs.get("top_k", None)
-        or kwargs.get("top_n", None)
+        kwargs.get("n")
+        or kwargs.get("k")
+        or kwargs.get("top_k")
+        or kwargs.get("top_n")
     )
 
-    top_p = kwargs.get("p", None) or kwargs.get("top_p", None)
-    tools = kwargs.get("tools", None)
+    top_p = kwargs.get("p") or kwargs.get("top_p")
+    tools = kwargs.get("tools")
     return {
         SpanAttributes.LLM_OPERATION_NAME: operation_name,
         SpanAttributes.LLM_REQUEST_MODEL: model or kwargs.get("model"),
@@ -308,7 +277,7 @@ def __init__(
         self,
         stream,
         span,
-        prompt_tokens,
+        prompt_tokens=None,
         function_call=False,
         tool_calls=False,
     ):
@@ -324,12 +293,10 @@ def __init__(
 
     def setup(self):
         if not self._span_started:
-            self.span.add_event(Event.STREAM_START.value)
             self._span_started = True
 
     def cleanup(self):
        if self._span_started:
-            self.span.add_event(Event.STREAM_END.value)
             set_span_attribute(
                 self.span,
                 SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
@@ -391,8 +358,6 @@ def process_chunk(self, chunk):
         if not self.function_call and not self.tool_calls:
             for choice in chunk.choices:
                 if choice.delta and choice.delta.content is not None:
-                    token_counts = estimate_tokens(choice.delta.content)
-                    self.completion_tokens += token_counts
                     content = [choice.delta.content]
         elif self.function_call:
             for choice in chunk.choices:
@@ -401,10 +366,6 @@
                     and choice.delta.function_call is not None
                     and choice.delta.function_call.arguments is not None
                 ):
-                    token_counts = estimate_tokens(
-                        choice.delta.function_call.arguments
-                    )
-                    self.completion_tokens += token_counts
                     content = [choice.delta.function_call.arguments]
         elif self.tool_calls:
             for choice in chunk.choices:
@@ -417,40 +378,17 @@
                         and tool_call.function is not None
                         and tool_call.function.arguments is not None
                     ):
-                        token_counts = estimate_tokens(
-                            tool_call.function.arguments
-                        )
-                        self.completion_tokens += token_counts
                         content.append(tool_call.function.arguments)
-        set_event_completion_chunk(
-            self.span,
-            (
-                "".join(content)
-                if len(content) > 0 and content[0] is not None
-                else ""
-            ),
-        )
 
         if content:
             self.result_content.append(content[0])
 
         if hasattr(chunk, "text"):
-            token_counts = estimate_tokens(chunk.text)
-            self.completion_tokens += token_counts
             content = [chunk.text]
-            set_event_completion_chunk(
-                self.span,
-                (
-                    "".join(content)
-                    if len(content) > 0 and content[0] is not None
-                    else ""
-                ),
-            )
 
             if content:
                 self.result_content.append(content[0])
 
-        if hasattr(chunk, "usage_metadata"):
-            self.completion_tokens = (
-                chunk.usage_metadata.candidates_token_count
-            )
-            self.prompt_tokens = chunk.usage_metadata.prompt_token_count
+
+        if getattr(chunk, "usage"):
+            self.completion_tokens = chunk.usage.completion_tokens
+            self.prompt_tokens = chunk.usage.prompt_tokens