Commit 7e20764

Add OpenAI sync embedding instrumentation (#938)
* Add sync instrumentation for OpenAI embeddings.

* Remove comments.

* Clean up embedding event dictionary.

* Update response_time to duration.

* Linting fixes.

* [Mega-Linter] Apply linters fixes

* Trigger tests

---------

Co-authored-by: umaannamalai <[email protected]>
Co-authored-by: Hannah Stepanek <[email protected]>
3 people committed Oct 17, 2023
1 parent ffabe9f commit 7e20764
Showing 5 changed files with 360 additions and 194 deletions.
5 changes: 5 additions & 0 deletions newrelic/config.py
@@ -2037,6 +2037,11 @@ def _process_trace_cache_import_hooks():
 
 
 def _process_module_builtin_defaults():
+    _process_module_definition(
+        "openai.api_resources.embedding",
+        "newrelic.hooks.mlmodel_openai",
+        "instrument_openai_api_resources_embedding",
+    )
     _process_module_definition(
         "openai.api_resources.chat_completion",
         "newrelic.hooks.mlmodel_openai",
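Registration note: _process_module_definition wires up a deferred import hook, so instrument_openai_api_resources_embedding runs only if and when openai.api_resources.embedding is actually imported. A minimal sketch of that pattern follows, as a simplified stand-in for the agent's real machinery in newrelic.api.import_hook, not the actual implementation:

    # Simplified sketch of deferred import-hook registration; the agent's
    # actual implementation (newrelic.api.import_hook) differs in detail.
    import importlib
    import sys

    _hooks = {}  # module name -> callbacks to run once it is imported

    def register_import_hook(name, callback):
        if name in sys.modules:
            callback(sys.modules[name])  # already imported: instrument now
        else:
            _hooks.setdefault(name, []).append(callback)

    def import_and_notify(name):
        # Import the module, then fire any registered instrumentation,
        # e.g. instrument_openai_api_resources_embedding(module).
        module = importlib.import_module(name)
        for callback in _hooks.pop(name, []):
            callback(module)
        return module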
148 changes: 127 additions & 21 deletions newrelic/hooks/mlmodel_openai.py
@@ -1,4 +1,3 @@
-
 # Copyright 2010 New Relic, Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,15 +11,83 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import openai
 
+import uuid
+
+import openai
+
 from newrelic.api.function_trace import FunctionTrace
-from newrelic.common.object_wrapper import wrap_function_wrapper
-from newrelic.api.transaction import current_transaction
-from newrelic.core.config import global_settings
+from newrelic.api.time_trace import get_trace_linking_metadata
+from newrelic.api.transaction import current_transaction
+from newrelic.common.object_names import callable_name
+from newrelic.core.attribute import MAX_LOG_MESSAGE_LENGTH
+from newrelic.common.object_wrapper import wrap_function_wrapper
+from newrelic.core.config import global_settings


+def wrap_embedding_create(wrapped, instance, args, kwargs):
+    transaction = current_transaction()
+    if not transaction:
+        return
+
+    ft_name = callable_name(wrapped)
+    with FunctionTrace(ft_name) as ft:
+        response = wrapped(*args, **kwargs)
+
+    if not response:
+        return
+
+    available_metadata = get_trace_linking_metadata()
+    span_id = available_metadata.get("span.id", "")
+    trace_id = available_metadata.get("trace.id", "")
+    embedding_id = str(uuid.uuid4())
+
+    response_headers = getattr(response, "_nr_response_headers", None)
+    request_id = response_headers.get("x-request-id", "") if response_headers else ""
+    response_model = response.get("model", "")
+    response_usage = response.get("usage", {})
+
+    settings = transaction.settings if transaction.settings is not None else global_settings()
+
+    embedding_dict = {
+        "id": embedding_id,
+        "appName": settings.app_name,
+        "span_id": span_id,
+        "trace_id": trace_id,
+        "request_id": request_id,
+        "transaction_id": transaction._transaction_id,
+        "input": kwargs.get("input", ""),
+        "api_key_last_four_digits": f"sk-{response.api_key[-4:]}",
+        "duration": ft.duration,
+        "request.model": kwargs.get("model") or kwargs.get("engine") or "",
+        "response.model": response_model,
+        "response.organization": response.organization,
+        "response.api_type": response.api_type,
+        "response.usage.total_tokens": response_usage.get("total_tokens", "") if any(response_usage) else "",
+        "response.usage.prompt_tokens": response_usage.get("prompt_tokens", "") if any(response_usage) else "",
+        "response.headers.llmVersion": response_headers.get("openai-version", ""),
+        "response.headers.ratelimitLimitRequests": check_rate_limit_header(
+            response_headers, "x-ratelimit-limit-requests", True
+        ),
+        "response.headers.ratelimitLimitTokens": check_rate_limit_header(
+            response_headers, "x-ratelimit-limit-tokens", True
+        ),
+        "response.headers.ratelimitResetTokens": check_rate_limit_header(
+            response_headers, "x-ratelimit-reset-tokens", False
+        ),
+        "response.headers.ratelimitResetRequests": check_rate_limit_header(
+            response_headers, "x-ratelimit-reset-requests", False
+        ),
+        "response.headers.ratelimitRemainingTokens": check_rate_limit_header(
+            response_headers, "x-ratelimit-remaining-tokens", True
+        ),
+        "response.headers.ratelimitRemainingRequests": check_rate_limit_header(
+            response_headers, "x-ratelimit-remaining-requests", True
+        ),
+        "vendor": "openAI",
+    }
+
+    transaction.record_ml_event("LlmEmbedding", embedding_dict)
+    return response
 
 
 def wrap_chat_completion_create(wrapped, instance, args, kwargs):
@@ -61,9 +128,9 @@ def wrap_chat_completion_create(wrapped, instance, args, kwargs):
         "request_id": request_id,
         "api_key_last_four_digits": f"sk-{response.api_key[-4:]}",
         "duration": ft.duration,
-        "request.model": kwargs.get("model") or kwargs.get("engine"),
+        "request.model": kwargs.get("model") or kwargs.get("engine") or "",
         "response.model": response_model,
-        "response.organization": response.organization,
+        "response.organization": response.organization,
         "response.usage.completion_tokens": response_usage.get("completion_tokens", "") if any(response_usage) else "",
         "response.usage.total_tokens": response_usage.get("total_tokens", "") if any(response_usage) else "",
         "response.usage.prompt_tokens": response_usage.get("prompt_tokens", "") if any(response_usage) else "",
@@ -72,20 +139,43 @@ def wrap_chat_completion_create(wrapped, instance, args, kwargs):
         "response.choices.finish_reason": response.choices[0].finish_reason,
         "response.api_type": response.api_type,
         "response.headers.llmVersion": response_headers.get("openai-version", ""),
-        "response.headers.ratelimitLimitRequests": check_rate_limit_header(response_headers, "x-ratelimit-limit-requests", True),
-        "response.headers.ratelimitLimitTokens": check_rate_limit_header(response_headers, "x-ratelimit-limit-tokens", True),
-        "response.headers.ratelimitResetTokens": check_rate_limit_header(response_headers, "x-ratelimit-reset-tokens", False),
-        "response.headers.ratelimitResetRequests": check_rate_limit_header(response_headers, "x-ratelimit-reset-requests", False),
-        "response.headers.ratelimitRemainingTokens": check_rate_limit_header(response_headers, "x-ratelimit-remaining-tokens", True),
-        "response.headers.ratelimitRemainingRequests": check_rate_limit_header(response_headers, "x-ratelimit-remaining-requests", True),
+        "response.headers.ratelimitLimitRequests": check_rate_limit_header(
+            response_headers, "x-ratelimit-limit-requests", True
+        ),
+        "response.headers.ratelimitLimitTokens": check_rate_limit_header(
+            response_headers, "x-ratelimit-limit-tokens", True
+        ),
+        "response.headers.ratelimitResetTokens": check_rate_limit_header(
+            response_headers, "x-ratelimit-reset-tokens", False
+        ),
+        "response.headers.ratelimitResetRequests": check_rate_limit_header(
+            response_headers, "x-ratelimit-reset-requests", False
+        ),
+        "response.headers.ratelimitRemainingTokens": check_rate_limit_header(
+            response_headers, "x-ratelimit-remaining-tokens", True
+        ),
+        "response.headers.ratelimitRemainingRequests": check_rate_limit_header(
+            response_headers, "x-ratelimit-remaining-requests", True
+        ),
         "vendor": "openAI",
         "response.number_of_messages": len(kwargs.get("messages", [])) + len(response.choices),
     }
 
     transaction.record_ml_event("LlmChatCompletionSummary", chat_completion_summary_dict)
     message_list = list(kwargs.get("messages", [])) + [response.choices[0].message]
 
-    create_chat_completion_message_event(transaction, settings.app_name, message_list, chat_completion_id, span_id, trace_id, response_model, response_id, request_id, conversation_id)
+    create_chat_completion_message_event(
+        transaction,
+        settings.app_name,
+        message_list,
+        chat_completion_id,
+        span_id,
+        trace_id,
+        response_model,
+        response_id,
+        request_id,
+        conversation_id,
+    )
 
     return response
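Every response.headers.ratelimit* attribute above funnels through check_rate_limit_header, whose body is collapsed in this view (only its trailing return "" is visible below). A plausible reconstruction, inferred from the call sites' (response_headers, header_name, is_int) arguments rather than taken from the agent source:

    # Plausible sketch of check_rate_limit_header, inferred from its call
    # sites above; the actual body is collapsed in this diff view.
    def check_rate_limit_header(response_headers, header_name, is_int):
        if not response_headers:
            return ""
        header_value = response_headers.get(header_name, "")
        if header_value and is_int:
            try:
                # Request/token counts are integral ...
                return int(header_value)
            except ValueError:
                # ... but fall back to the raw string if parsing fails.
                return header_value
        # Reset windows (duration strings such as "6m0s") stay as strings.
        return header_value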

@@ -106,7 +196,18 @@ def check_rate_limit_header(response_headers, header_name, is_int):
     return ""
 
 
-def create_chat_completion_message_event(transaction, app_name, message_list, chat_completion_id, span_id, trace_id, response_model, response_id, request_id, conversation_id):
+def create_chat_completion_message_event(
+    transaction,
+    app_name,
+    message_list,
+    chat_completion_id,
+    span_id,
+    trace_id,
+    response_model,
+    response_id,
+    request_id,
+    conversation_id,
+):
     if not transaction:
         return

@@ -119,7 +220,7 @@ def create_chat_completion_message_event(
             "span_id": span_id,
             "trace_id": trace_id,
             "transaction_id": transaction._transaction_id,
-            "content": message.get("content", "")[:MAX_LOG_MESSAGE_LENGTH],
+            "content": message.get("content", ""),
             "role": message.get("role", ""),
             "completion_id": chat_completion_id,
             "sequence": index,
@@ -139,10 +240,15 @@ def wrap_convert_to_openai_object(wrapped, instance, args, kwargs):
     return returned_response
 
 
-def instrument_openai_util(module):
-    wrap_function_wrapper(module, "convert_to_openai_object", wrap_convert_to_openai_object)
-
-
+def instrument_openai_api_resources_embedding(module):
+    if hasattr(module.Embedding, "create"):
+        wrap_function_wrapper(module, "Embedding.create", wrap_embedding_create)
+
+
 def instrument_openai_api_resources_chat_completion(module):
     if hasattr(module.ChatCompletion, "create"):
         wrap_function_wrapper(module, "ChatCompletion.create", wrap_chat_completion_create)
+
+
+def instrument_openai_util(module):
+    wrap_function_wrapper(module, "convert_to_openai_object", wrap_convert_to_openai_object)
3 changes: 1 addition & 2 deletions tests/mlmodel_openai/conftest.py
@@ -36,8 +36,7 @@
     "transaction_tracer.stack_trace_threshold": 0.0,
     "debug.log_data_collector_payloads": True,
     "debug.record_transaction_failure": True,
-    "machine_learning.enabled": True,
-    "ml_insights_events.enabled": True
+    "ml_insights_events.enabled": True,
 }
 
 collector_agent_registration = collector_agent_registration_fixture(
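The LlmEmbedding and LlmChatCompletionSummary events ride the agent's ML-event pipeline, which is also exposed publicly; reading the conftest change, ml_insights_events.enabled appears to be the switch that admits them (an inference from the setting name, not something this diff confirms). A sketch of the public API:

    # Sketch: recording a custom ML event via the public agent API -- the
    # same pipeline transaction.record_ml_event feeds in the hooks above.
    # That ml_insights_events.enabled gates delivery is an assumption here.
    import newrelic.agent

    newrelic.agent.record_ml_event("MyModelScore", {"score": 0.87, "model": "demo"})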
