From 854e320fde543f82c5d778435fbc97827ea9d1e5 Mon Sep 17 00:00:00 2001
From: Karthik Kalyanaraman <105607645+karthikscale3@users.noreply.github.com>
Date: Wed, 25 Sep 2024 17:26:51 -0700
Subject: [PATCH 1/3] DSPy enhancements (#362)

* fix dspy issue

* DSPy enhancements v1

* Add support for LiteLLM

* update readme

* DSPy enhancements

* bump version
---
 README.md                                     |  11 +-
 pyproject.toml                                |   2 +
 .../optimizers/bootstrap_fewshot.py           |  89 +++
 .../openai_example/chat_completion.py         |  35 +-
 .../constants/instrumentation/common.py       |   1 +
 .../constants/instrumentation/litellm.py      |  18 +
 .../instrumentation/__init__.py               |   2 +
 .../instrumentation/dspy/patch.py             |  26 +-
 .../instrumentation/litellm/__init__.py       |   5 +
 .../litellm/instrumentation.py                |  87 +++
 .../instrumentation/litellm/patch.py          | 651 ++++++++++++++++++
 .../instrumentation/litellm/types.py          | 170 +++++
 src/langtrace_python_sdk/langtrace.py         |   2 +
 src/langtrace_python_sdk/version.py           |   2 +-
 14 files changed, 1075 insertions(+), 26 deletions(-)
 create mode 100644 src/examples/dspy_example/optimizers/bootstrap_fewshot.py
 create mode 100644 src/langtrace_python_sdk/constants/instrumentation/litellm.py
 create mode 100644 src/langtrace_python_sdk/instrumentation/litellm/__init__.py
 create mode 100644 src/langtrace_python_sdk/instrumentation/litellm/instrumentation.py
 create mode 100644 src/langtrace_python_sdk/instrumentation/litellm/patch.py
 create mode 100644 src/langtrace_python_sdk/instrumentation/litellm/types.py

diff --git a/README.md b/README.md
index 46a530d3..711f1c27 100644
--- a/README.md
+++ b/README.md
@@ -238,6 +238,14 @@ By default, prompt and completion data are captured. If you would like to opt ou
 
 `TRACE_PROMPT_COMPLETION_DATA=false`
 
+### Enable/Disable checkpoint tracing for DSPy
+
+By default, checkpoints are traced for DSPy pipelines. If you would like to disable checkpoint tracing, set the following env var:
+
+`TRACE_DSPY_CHECKPOINT=false`
+
+Note: Checkpoint tracing increases execution latency because the pipeline state is serialized on every run. Please disable it in production.
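+
+For example, a minimal sketch of toggling this from Python rather than the shell (the variable is read at run time, so set it before the DSPy pipeline executes):
+
+```python
+import os
+
+# Disable DSPy checkpoint tracing before running the pipeline
+os.environ["TRACE_DSPY_CHECKPOINT"] = "false"
+
+from langtrace_python_sdk import langtrace
+
+langtrace.init()
+```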
+
 ## Supported integrations
 
 Langtrace automatically captures traces from the following vendors:
 
@@ -253,8 +261,9 @@ Langtrace automatically captures traces from the following vendors:
 | Gemini | LLM | :x: | :white_check_mark: |
 | Mistral | LLM | :x: | :white_check_mark: |
 | Langchain | Framework | :x: | :white_check_mark: |
-| LlamaIndex | Framework | :white_check_mark: | :white_check_mark: |
 | Langgraph | Framework | :x: | :white_check_mark: |
+| LlamaIndex | Framework | :white_check_mark: | :white_check_mark: |
+| LiteLLM | Framework | :x: | :white_check_mark: |
 | DSPy | Framework | :x: | :white_check_mark: |
 | CrewAI | Framework | :x: | :white_check_mark: |
 | Ollama | Framework | :x: | :white_check_mark: |
diff --git a/pyproject.toml b/pyproject.toml
index 4875e93b..0999ca5d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
   'fsspec>=2024.6.0',
   "transformers>=4.11.3",
   "sentry-sdk>=2.14.0",
+  "ujson>=5.10.0",
 ]
 
 requires-python = ">=3.9"
@@ -47,6 +48,7 @@ dev = [
   "langchain-community",
   "langchain-openai",
   "langchain-openai",
+  "litellm",
   "chromadb",
   "cohere",
   "qdrant_client",
diff --git a/src/examples/dspy_example/optimizers/bootstrap_fewshot.py b/src/examples/dspy_example/optimizers/bootstrap_fewshot.py
new file mode 100644
index 00000000..1d05632c
--- /dev/null
+++ b/src/examples/dspy_example/optimizers/bootstrap_fewshot.py
@@ -0,0 +1,89 @@
+import dspy
+from dotenv import find_dotenv, load_dotenv
+from dspy.datasets import HotPotQA
+from dspy.teleprompt import BootstrapFewShot
+
+from langtrace_python_sdk import inject_additional_attributes, langtrace
+
+_ = load_dotenv(find_dotenv())
+
+langtrace.init()
+
+turbo = dspy.LM('openai/gpt-4o-mini')
+colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
+
+dspy.settings.configure(lm=turbo, rm=colbertv2_wiki17_abstracts)
+
+
+# Load the dataset.
+dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0)
+
+# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
+trainset = [x.with_inputs('question') for x in dataset.train]
+devset = [x.with_inputs('question') for x in dataset.dev]
+
+
+class GenerateAnswer(dspy.Signature):
+    """Answer questions with short factoid answers."""
+
+    context = dspy.InputField(desc="may contain relevant facts")
+    question = dspy.InputField()
+    answer = dspy.OutputField(desc="often between 1 and 5 words")
+
+
+class RAG(dspy.Module):
+    def __init__(self, num_passages=3):
+        super().__init__()
+
+        self.retrieve = dspy.Retrieve(k=num_passages)
+        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
+
+    def forward(self, question):
+        context = self.retrieve(question).passages
+        prediction = self.generate_answer(context=context, question=question)
+        return dspy.Prediction(context=context, answer=prediction.answer)
+
+
+# Validation logic: check that the predicted answer is correct.
+# Also check that the retrieved context does actually contain that answer.
+def validate_context_and_answer(example, prediction, trace=None):
+    answer_em = dspy.evaluate.answer_exact_match(example, prediction)
+    answer_pm = dspy.evaluate.answer_passage_match(example, prediction)
+    return answer_em and answer_pm
+
+
+# Set up a basic optimizer, which will compile our RAG program.
+optimizer = BootstrapFewShot(metric=validate_context_and_answer)
+
+# Compile!
+compiled_rag = optimizer.compile(RAG(), trainset=trainset)
+
+# Ask any question you like to this simple RAG program.
+my_question = "Who was the hero of the movie peraanmai?"
+
+# Get the prediction. This contains `pred.context` and `pred.answer`.
+# pred = compiled_rag(my_question)
+pred = inject_additional_attributes(lambda: compiled_rag(my_question), {'experiment': 'experiment 6', 'description': 'trying additional stuff', 'run_id': 'run_1'})
+# compiled_rag.save('compiled_rag_v1.json')
+
+# Print the contexts and the answer.
+print(f"Question: {my_question}")
+print(f"Predicted Answer: {pred.answer}")
+print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")
+
+# print("Inspecting the history of the optimizer:")
+# turbo.inspect_history(n=1)
+
+from dspy.evaluate import Evaluate
+
+
+def validate_answer(example, pred, trace=None):
+    return True
+
+
+# Set up the evaluator, which can be used multiple times.
+evaluate = Evaluate(devset=devset, metric=validate_answer, num_threads=4, display_progress=True, display_table=0)
+
+
+# Evaluate our compiled `compiled_rag` program.
+evaluate(compiled_rag)
diff --git a/src/examples/openai_example/chat_completion.py b/src/examples/openai_example/chat_completion.py
index 95a7ab1b..56179561 100644
--- a/src/examples/openai_example/chat_completion.py
+++ b/src/examples/openai_example/chat_completion.py
@@ -9,19 +9,19 @@
 
 _ = load_dotenv(find_dotenv())
 
-langtrace.init(write_spans_to_console=True)
+langtrace.init()
 client = OpenAI()
 
 
 def api():
     response = client.chat.completions.create(
-        model="gpt-4",
+        model="o1-mini",
         messages=[
-            {"role": "system", "content": "Talk like a pirate"},
-            {"role": "user", "content": "Tell me a story in 3 sentences or less."},
+            # {"role": "system", "content": "Talk like a pirate"},
+            {"role": "user", "content": "How many r's are in strawberry?"},
         ],
-        stream=True,
-        # stream=False,
+        # stream=True,
+        stream=False,
     )
     return response
 
@@ -31,14 +31,17 @@ def chat_completion():
     response = api()
     # print(response)
     # Uncomment this for streaming
-    result = []
-    for chunk in response:
-        if chunk.choices[0].delta.content is not None:
-            content = [
-                choice.delta.content if choice.delta and choice.delta.content else ""
-                for choice in chunk.choices
-            ]
-            result.append(content[0] if len(content) > 0 else "")
-
-    # print("".join(result))
+    # result = []
+    # for chunk in response:
+    #     if chunk.choices[0].delta.content is not None:
+    #         content = [
+    #             choice.delta.content if choice.delta and choice.delta.content else ""
+    #             for choice in chunk.choices
+    #         ]
+    #         result.append(content[0] if len(content) > 0 else "")
+
+    # # print("".join(result))
+    print(response)
     return response
+
+chat_completion()
\ No newline at end of file
diff --git a/src/langtrace_python_sdk/constants/instrumentation/common.py b/src/langtrace_python_sdk/constants/instrumentation/common.py
index 70d92a1b..4c4ec63f 100644
--- a/src/langtrace_python_sdk/constants/instrumentation/common.py
+++ b/src/langtrace_python_sdk/constants/instrumentation/common.py
@@ -19,6 +19,7 @@
     "LANGCHAIN_COMMUNITY": "Langchain Community",
     "LANGCHAIN_CORE": "Langchain Core",
     "LANGGRAPH": "Langgraph",
+    "LITELLM": "Litellm",
     "LLAMAINDEX": "LlamaIndex",
     "OPENAI": "OpenAI",
     "PINECONE": "Pinecone",
diff --git a/src/langtrace_python_sdk/constants/instrumentation/litellm.py b/src/langtrace_python_sdk/constants/instrumentation/litellm.py
new file mode 100644
index 00000000..10020b22
--- /dev/null
+++ b/src/langtrace_python_sdk/constants/instrumentation/litellm.py
@@ -0,0 +1,18 @@
+APIS = {
+    "CHAT_COMPLETION": {
+        "METHOD": "chat.completions.create",
+        "ENDPOINT": "/chat/completions",
+    },
"IMAGES_GENERATION": { + "METHOD": "images.generate", + "ENDPOINT": "/images/generations", + }, + "IMAGES_EDIT": { + "METHOD": "images.edit", + "ENDPOINT": "/images/edits", + }, + "EMBEDDINGS_CREATE": { + "METHOD": "embeddings.create", + "ENDPOINT": "/embeddings", + }, +} diff --git a/src/langtrace_python_sdk/instrumentation/__init__.py b/src/langtrace_python_sdk/instrumentation/__init__.py index 984541dc..369c5ac9 100644 --- a/src/langtrace_python_sdk/instrumentation/__init__.py +++ b/src/langtrace_python_sdk/instrumentation/__init__.py @@ -19,6 +19,7 @@ from .gemini import GeminiInstrumentation from .mistral import MistralInstrumentation from .embedchain import EmbedchainInstrumentation +from .litellm import LiteLLMInstrumentation __all__ = [ "AnthropicInstrumentation", @@ -31,6 +32,7 @@ "LangchainCommunityInstrumentation", "LangchainCoreInstrumentation", "LanggraphInstrumentation", + "LiteLLMInstrumentation", "LlamaindexInstrumentation", "OpenAIInstrumentation", "PineconeInstrumentation", diff --git a/src/langtrace_python_sdk/instrumentation/dspy/patch.py b/src/langtrace_python_sdk/instrumentation/dspy/patch.py index 1690df96..8a02afcb 100644 --- a/src/langtrace_python_sdk/instrumentation/dspy/patch.py +++ b/src/langtrace_python_sdk/instrumentation/dspy/patch.py @@ -1,6 +1,19 @@ import json +import os + +import ujson +from colorama import Fore from importlib_metadata import version as v +from langtrace.trace_attributes import FrameworkSpanAttributes +from opentelemetry import baggage +from opentelemetry.trace import SpanKind +from opentelemetry.trace.status import Status, StatusCode + from langtrace_python_sdk.constants import LANGTRACE_SDK_NAME +from langtrace_python_sdk.constants.instrumentation.common import ( + LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY, + SERVICE_PROVIDERS, +) from langtrace_python_sdk.utils import set_span_attribute from langtrace_python_sdk.utils.llm import ( get_extra_attributes, @@ -9,14 +22,6 @@ set_span_attributes, ) from langtrace_python_sdk.utils.silently_fail import silently_fail -from langtrace_python_sdk.constants.instrumentation.common import ( - LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY, - SERVICE_PROVIDERS, -) -from opentelemetry import baggage -from langtrace.trace_attributes import FrameworkSpanAttributes -from opentelemetry.trace import SpanKind -from opentelemetry.trace.status import Status, StatusCode def patch_bootstrapfewshot_optimizer(operation_name, version, tracer): @@ -115,6 +120,8 @@ def traced_method(wrapped, instance, args, kwargs): **get_extra_attributes(), } + trace_checkpoint = os.environ.get("TRACE_DSPY_CHECKPOINT", "true").lower() + if instance.__class__.__name__: span_attributes["dspy.signature.name"] = instance.__class__.__name__ span_attributes["dspy.signature"] = str(instance.signature) @@ -136,6 +143,9 @@ def traced_method(wrapped, instance, args, kwargs): "dspy.signature.result", json.dumps(result.toDict()), ) + if trace_checkpoint == "true": + print(Fore.RED + "Note: DSPy checkpoint tracing is enabled in Langtrace. 
+                print(Fore.RED + "Note: DSPy checkpoint tracing is enabled in Langtrace. To disable it, set the env var TRACE_DSPY_CHECKPOINT to false" + Fore.RESET)
+                set_span_attribute(span, "dspy.checkpoint", ujson.dumps(instance.dump_state(False), indent=2))
 
             span.set_status(Status(StatusCode.OK))
             span.end()
diff --git a/src/langtrace_python_sdk/instrumentation/litellm/__init__.py b/src/langtrace_python_sdk/instrumentation/litellm/__init__.py
new file mode 100644
index 00000000..997e9832
--- /dev/null
+++ b/src/langtrace_python_sdk/instrumentation/litellm/__init__.py
@@ -0,0 +1,5 @@
+from .instrumentation import LiteLLMInstrumentation
+
+__all__ = [
+    "LiteLLMInstrumentation",
+]
diff --git a/src/langtrace_python_sdk/instrumentation/litellm/instrumentation.py b/src/langtrace_python_sdk/instrumentation/litellm/instrumentation.py
new file mode 100644
index 00000000..d57de9d8
--- /dev/null
+++ b/src/langtrace_python_sdk/instrumentation/litellm/instrumentation.py
@@ -0,0 +1,87 @@
+"""
+Copyright (c) 2024 Scale3 Labs
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Collection, Optional, Any
+import importlib.metadata
+import logging
+
+from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
+from opentelemetry.trace import get_tracer, TracerProvider
+from wrapt import wrap_function_wrapper
+
+from langtrace_python_sdk.instrumentation.litellm.patch import (
+    async_chat_completions_create,
+    async_embeddings_create,
+    async_images_generate,
+    chat_completions_create,
+    embeddings_create,
+    images_generate,
+)
+
+logging.basicConfig(level=logging.FATAL)
+
+
+class LiteLLMInstrumentation(BaseInstrumentor):  # type: ignore
+
+    def instrumentation_dependencies(self) -> Collection[str]:
+        return ["litellm >= 1.48.0", "trace-attributes >= 4.0.5"]
+
+    def _instrument(self, **kwargs: Any) -> None:
+        tracer_provider: Optional[TracerProvider] = kwargs.get("tracer_provider")
+        tracer = get_tracer(__name__, "", tracer_provider)
+        version: str = importlib.metadata.version("openai")
+
+        wrap_function_wrapper(
+            "litellm",
+            "completion",
+            chat_completions_create(version, tracer),
+        )
+
+        wrap_function_wrapper(
+            "litellm",
+            "text_completion",
+            chat_completions_create(version, tracer),
+        )
+
+        wrap_function_wrapper(
+            "litellm.main",
+            "acompletion",
+            async_chat_completions_create(version, tracer),
+        )
+
+        wrap_function_wrapper(
+            "litellm.main",
+            "image_generation",
+            images_generate(version, tracer),
+        )
+
+        wrap_function_wrapper(
+            "litellm.main",
+            "aimage_generation",
+            async_images_generate(version, tracer),
+        )
+
+        wrap_function_wrapper(
+            "litellm.main",
+            "embedding",
+            embeddings_create(version, tracer),
+        )
+
+        wrap_function_wrapper(
+            "litellm.main",
+            "aembedding",
+            async_embeddings_create(version, tracer),
+        )
+
+    def _uninstrument(self, **kwargs: Any) -> None:
+        pass
diff --git a/src/langtrace_python_sdk/instrumentation/litellm/patch.py b/src/langtrace_python_sdk/instrumentation/litellm/patch.py
new file mode 100644
index 00000000..09c77477
--- /dev/null
+++ b/src/langtrace_python_sdk/instrumentation/litellm/patch.py
@@ -0,0 +1,651 @@
+import json
+from typing import Any, Dict, List, Optional, Callable, Awaitable, Union
+from langtrace.trace_attributes import (
+    LLMSpanAttributes,
+    SpanAttributes,
+)
+from langtrace_python_sdk.utils import set_span_attribute
+from langtrace_python_sdk.utils.silently_fail import silently_fail
+from opentelemetry import trace
+from opentelemetry.trace import SpanKind, Tracer, Span
+from opentelemetry.trace.status import Status, StatusCode
+from opentelemetry.trace.propagation import set_span_in_context
+from langtrace_python_sdk.constants.instrumentation.common import (
+    SERVICE_PROVIDERS,
+)
+from langtrace_python_sdk.constants.instrumentation.litellm import APIS
+from langtrace_python_sdk.utils.llm import (
+    calculate_prompt_tokens,
+    get_base_url,
+    get_extra_attributes,
+    get_langtrace_attributes,
+    get_llm_request_attributes,
+    get_span_name,
+    get_tool_calls,
+    is_streaming,
+    set_event_completion,
+    StreamWrapper,
+    set_span_attributes,
+)
+from langtrace_python_sdk.types import NOT_GIVEN
+
+from langtrace_python_sdk.instrumentation.openai.types import (
+    ImagesGenerateKwargs,
+    ChatCompletionsCreateKwargs,
+    EmbeddingsCreateKwargs,
+    ImagesEditKwargs,
+    ResultType,
+    ContentItem,
+)
+
+
+def filter_valid_attributes(attributes):
+    """Filter attributes where value is not None, not an empty string."""
+    return {
+        key: value
+        for key, value in attributes.items()
+        if value is not None and value != ""
+    }
+
+
+def images_generate(version: str, tracer: Tracer) -> Callable:
+    """
+    Wrap the `generate` method of the `Images` class to trace it.
+    """
+
+    def traced_method(
+        wrapped: Callable, instance: Any, args: List[Any], kwargs: ImagesGenerateKwargs
+    ) -> Any:
+        service_provider = SERVICE_PROVIDERS["LITELLM"]
+        span_attributes = {
+            **get_langtrace_attributes(version, service_provider, vendor_type="llm"),
+            **get_llm_request_attributes(kwargs, operation_name="images_generate"),
+            SpanAttributes.LLM_URL: "not available",
+            SpanAttributes.LLM_PATH: APIS["IMAGES_GENERATION"]["ENDPOINT"],
+            **get_extra_attributes(),  # type: ignore
+        }
+
+        attributes = LLMSpanAttributes(**filter_valid_attributes(span_attributes))
+
+        with tracer.start_as_current_span(
+            name=get_span_name(APIS["IMAGES_GENERATION"]["METHOD"]),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        ) as span:
+            set_span_attributes(span, attributes)
+            try:
+                # Attempt to call the original method
+                result = wrapped(*args, **kwargs)
+                if not is_streaming(kwargs):
+                    data: Optional[ContentItem] = (
+                        result.data[0]
+                        if hasattr(result, "data") and len(result.data) > 0
+                        else None
+                    )
+                    response = [
+                        {
+                            "role": "assistant",
+                            "content": {
+                                "url": getattr(data, "url", ""),
+                                "revised_prompt": getattr(data, "revised_prompt", ""),
+                            },
+                        }
+                    ]
+                    set_event_completion(span, response)
+
+                span.set_status(StatusCode.OK)
+                return result
+            except Exception as err:
+                # Record the exception in the span
+                span.record_exception(err)
+
+                # Set the span status to indicate an error
+                span.set_status(Status(StatusCode.ERROR, str(err)))
+
+                # Reraise the exception to ensure it's not swallowed
+                raise
+
+    return traced_method
+
+
+def async_images_generate(version: str, tracer: Tracer) -> Callable:
+    """
+    Wrap the `generate` method of the `Images` class to trace it.
+    """
+
+    async def traced_method(
+        wrapped: Callable, instance: Any, args: List[Any], kwargs: ImagesGenerateKwargs
+    ) -> Awaitable[Any]:
+        service_provider = SERVICE_PROVIDERS["LITELLM"]
+
+        span_attributes = {
+            **get_langtrace_attributes(version, service_provider, vendor_type="llm"),
+            **get_llm_request_attributes(kwargs, operation_name="images_generate"),
+            SpanAttributes.LLM_URL: "not available",
+            SpanAttributes.LLM_PATH: APIS["IMAGES_GENERATION"]["ENDPOINT"],
+            **get_extra_attributes(),  # type: ignore
+        }
+
+        attributes = LLMSpanAttributes(**filter_valid_attributes(span_attributes))
+
+        with tracer.start_as_current_span(
+            name=get_span_name(APIS["IMAGES_GENERATION"]["METHOD"]),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        ) as span:
+            set_span_attributes(span, attributes)
+            try:
+                # Attempt to call the original method
+                result = await wrapped(*args, **kwargs)
+                if not is_streaming(kwargs):
+                    data: Optional[ContentItem] = (
+                        result.data[0]
+                        if hasattr(result, "data") and len(result.data) > 0
+                        else None
+                    )
+                    response = [
+                        {
+                            "role": "assistant",
+                            "content": {
+                                "url": getattr(data, "url", ""),
+                                "revised_prompt": getattr(data, "revised_prompt", ""),
+                            },
+                        }
+                    ]
+                    set_event_completion(span, response)
+
+                span.set_status(StatusCode.OK)
+                return result
+            except Exception as err:
+                # Record the exception in the span
+                span.record_exception(err)
+
+                # Set the span status to indicate an error
+                span.set_status(Status(StatusCode.ERROR, str(err)))
+
+                # Reraise the exception to ensure it's not swallowed
+                raise
+
+    return traced_method
+
+
+def images_edit(version: str, tracer: Tracer) -> Callable:
+    """
+    Wrap the `edit` method of the `Images` class to trace it.
+    """
+
+    def traced_method(
+        wrapped: Callable, instance: Any, args: List[Any], kwargs: ImagesEditKwargs
+    ) -> Any:
+        service_provider = SERVICE_PROVIDERS["LITELLM"]
+
+        span_attributes = {
+            **get_langtrace_attributes(version, service_provider, vendor_type="llm"),
+            **get_llm_request_attributes(kwargs, operation_name="images_edit"),
+            SpanAttributes.LLM_URL: "not available",
+            SpanAttributes.LLM_PATH: APIS["IMAGES_EDIT"]["ENDPOINT"],
+            SpanAttributes.LLM_RESPONSE_FORMAT: kwargs.get("response_format"),
+            SpanAttributes.LLM_IMAGE_SIZE: kwargs.get("size"),
+            **get_extra_attributes(),  # type: ignore
+        }
+
+        attributes = LLMSpanAttributes(**filter_valid_attributes(span_attributes))
+
+        with tracer.start_as_current_span(
+            name=APIS["IMAGES_EDIT"]["METHOD"],
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        ) as span:
+            set_span_attributes(span, attributes)
+            try:
+                # Attempt to call the original method
+                result = wrapped(*args, **kwargs)
+
+                response = []
+                # Parse each image object
+                for each_data in result.data:
+                    response.append(
+                        {
+                            "role": "assistant",
+                            "content": {
+                                "url": each_data.url,
+                                "revised_prompt": each_data.revised_prompt,
+                                "base64": each_data.b64_json,
+                            },
+                        }
+                    )
+
+                set_event_completion(span, response)
+
+                span.set_status(StatusCode.OK)
+                return result
+            except Exception as err:
+                # Record the exception in the span
+                span.record_exception(err)
+
+                # Set the span status to indicate an error
+                span.set_status(Status(StatusCode.ERROR, str(err)))
+
+                # Reraise the exception to ensure it's not swallowed
+                raise
+
+    return traced_method
+
+
+def chat_completions_create(version: str, tracer: Tracer) -> Callable:
+    """Wrap the `create` method of the `ChatCompletion` class to trace it."""
+
+    def traced_method(
+        wrapped: Callable,
+        instance: Any,
+        args: List[Any],
+        kwargs: ChatCompletionsCreateKwargs,
+    ) -> Any:
+        service_provider = SERVICE_PROVIDERS["LITELLM"]
+        if "perplexity" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["PPLX"]
+        elif "azure" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["AZURE"]
+        elif "groq" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["GROQ"]
+        llm_prompts = []
+        for item in kwargs.get("messages", []):
+            tools = get_tool_calls(item)
+            if tools is not None:
+                tool_calls = []
+                for tool_call in tools:
+                    tool_call_dict = {
+                        "id": getattr(tool_call, "id", ""),
+                        "type": getattr(tool_call, "type", ""),
+                    }
+                    if hasattr(tool_call, "function"):
+                        tool_call_dict["function"] = {
+                            "name": getattr(tool_call.function, "name", ""),
+                            "arguments": getattr(tool_call.function, "arguments", ""),
+                        }
+                    tool_calls.append(tool_call_dict)
+                llm_prompts.append(tool_calls)
+            else:
+                llm_prompts.append(item)
+
+        span_attributes = {
+            **get_langtrace_attributes(version, service_provider, vendor_type="llm"),
+            **get_llm_request_attributes(kwargs, prompts=llm_prompts),
+            SpanAttributes.LLM_URL: "not available",
+            SpanAttributes.LLM_PATH: APIS["CHAT_COMPLETION"]["ENDPOINT"],
+            **get_extra_attributes(),  # type: ignore
+        }
+
+        attributes = LLMSpanAttributes(**filter_valid_attributes(span_attributes))
+
+        span = tracer.start_span(
+            name=get_span_name(APIS["CHAT_COMPLETION"]["METHOD"]),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        )
+        _set_input_attributes(span, kwargs, attributes)
+
+        try:
+            result = wrapped(*args, **kwargs)
+            if is_streaming(kwargs):
+                prompt_tokens = 0
+                for message in kwargs.get("messages", {}):
+                    prompt_tokens += calculate_prompt_tokens(
+                        json.dumps(str(message)), kwargs.get("model")
+                    )
+                functions = kwargs.get("functions")
+                if functions is not None and functions != NOT_GIVEN:
+                    for function in functions:
+                        prompt_tokens += calculate_prompt_tokens(
+                            json.dumps(function), kwargs.get("model")
+                        )
+
+                return StreamWrapper(
+                    result,
+                    span,
+                    prompt_tokens,
+                    function_call=kwargs.get("functions") is not None,
+                    tool_calls=kwargs.get("tools") is not None,
+                )
+            else:
+                _set_response_attributes(span, result)
+                span.set_status(StatusCode.OK)
+                span.end()
+                return result
+
+        except Exception as error:
+            span.record_exception(error)
+            span.set_status(Status(StatusCode.ERROR, str(error)))
+            span.end()
+            raise
+
+    return traced_method
+
+
+def async_chat_completions_create(version: str, tracer: Tracer) -> Callable:
+    """Wrap the `create` method of the `ChatCompletion` class to trace it."""
+
+    async def traced_method(
+        wrapped: Callable,
+        instance: Any,
+        args: List[Any],
+        kwargs: ChatCompletionsCreateKwargs,
+    ) -> Awaitable[Any]:
+        service_provider = SERVICE_PROVIDERS["LITELLM"]
+        if "perplexity" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["PPLX"]
+        elif "azure" in get_base_url(instance):
+            service_provider = SERVICE_PROVIDERS["AZURE"]
+        llm_prompts = []
+        for item in kwargs.get("messages", []):
+            tools = get_tool_calls(item)
+            if tools is not None:
+                tool_calls = []
+                for tool_call in tools:
+                    tool_call_dict = {
+                        "id": getattr(tool_call, "id", ""),
+                        "type": getattr(tool_call, "type", ""),
+                    }
+                    if hasattr(tool_call, "function"):
+                        tool_call_dict["function"] = {
+                            "name": getattr(tool_call.function, "name", ""),
+                            "arguments": getattr(tool_call.function, "arguments", ""),
+                        }
+                    tool_calls.append(json.dumps(tool_call_dict))
+                llm_prompts.append(tool_calls)
+            else:
+                llm_prompts.append(item)
+
+        span_attributes = {
+            **get_langtrace_attributes(version, service_provider, vendor_type="llm"),
+            **get_llm_request_attributes(kwargs, prompts=llm_prompts),
+            SpanAttributes.LLM_URL: "not available",
+            SpanAttributes.LLM_PATH: APIS["CHAT_COMPLETION"]["ENDPOINT"],
+            **get_extra_attributes(),  # type: ignore
+        }
+
+        attributes = LLMSpanAttributes(**filter_valid_attributes(span_attributes))
+
+        span = tracer.start_span(
+            name=get_span_name(APIS["CHAT_COMPLETION"]["METHOD"]),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        )
+        _set_input_attributes(span, kwargs, attributes)
+
+        try:
+            result = await wrapped(*args, **kwargs)
+            if is_streaming(kwargs):
+                prompt_tokens = 0
+                for message in kwargs.get("messages", {}):
+                    prompt_tokens += calculate_prompt_tokens(
+                        json.dumps(str(message)), kwargs.get("model")
+                    )
+
+                functions = kwargs.get("functions")
+                if functions is not None and functions != NOT_GIVEN:
+                    for function in functions:
+                        prompt_tokens += calculate_prompt_tokens(
+                            json.dumps(function), kwargs.get("model")
+                        )
+
+                return StreamWrapper(
+                    result,
+                    span,
+                    prompt_tokens,
+                    function_call=kwargs.get("functions") is not None,
+                    tool_calls=kwargs.get("tools") is not None,
+                )  # type: ignore
+            else:
+                _set_response_attributes(span, result)
+                span.set_status(StatusCode.OK)
+                span.end()
+                return result
+
+        except Exception as error:
+            span.record_exception(error)
+            span.set_status(Status(StatusCode.ERROR, str(error)))
+            span.end()
+            raise
+
+    return traced_method
+
+
+def embeddings_create(version: str, tracer: Tracer) -> Callable:
+    """
+    Wrap the `create` method of the `Embeddings` class to trace it.
+    """
+
+    def traced_method(
+        wrapped: Callable,
+        instance: Any,
+        args: List[Any],
+        kwargs: EmbeddingsCreateKwargs,
+    ) -> Any:
+        service_provider = SERVICE_PROVIDERS["LITELLM"]
+
+        span_attributes = {
+            **get_langtrace_attributes(version, service_provider, vendor_type="llm"),
+            **get_llm_request_attributes(kwargs, operation_name="embed"),
+            SpanAttributes.LLM_URL: "not available",
+            SpanAttributes.LLM_PATH: APIS["EMBEDDINGS_CREATE"]["ENDPOINT"],
+            SpanAttributes.LLM_REQUEST_DIMENSIONS: kwargs.get("dimensions"),
+            **get_extra_attributes(),  # type: ignore
+        }
+
+        encoding_format = kwargs.get("encoding_format")
+        if encoding_format is not None:
+            if not isinstance(encoding_format, list):
+                encoding_format = [encoding_format]
+            span_attributes[SpanAttributes.LLM_REQUEST_ENCODING_FORMATS] = (
+                encoding_format
+            )
+
+        if kwargs.get("input") is not None:
+            span_attributes[SpanAttributes.LLM_REQUEST_EMBEDDING_INPUTS] = json.dumps(
+                [kwargs.get("input", "")]
+            )
+
+        attributes = LLMSpanAttributes(**filter_valid_attributes(span_attributes))
+
+        with tracer.start_as_current_span(
+            name=get_span_name(APIS["EMBEDDINGS_CREATE"]["METHOD"]),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        ) as span:
+
+            set_span_attributes(span, attributes)
+            try:
+                # Attempt to call the original method
+                result = wrapped(*args, **kwargs)
+                span.set_status(StatusCode.OK)
+                return result
+            except Exception as err:
+                # Record the exception in the span
+                span.record_exception(err)
+
+                # Set the span status to indicate an error
+                span.set_status(Status(StatusCode.ERROR, str(err)))
+
+                # Reraise the exception to ensure it's not swallowed
+                raise
+
+    return traced_method
+
+
+def async_embeddings_create(version: str, tracer: Tracer) -> Callable:
+    """
+    Wrap the `create` method of the `Embeddings` class to trace it.
+    """
+
+    async def traced_method(
+        wrapped: Callable,
+        instance: Any,
+        args: List[Any],
+        kwargs: EmbeddingsCreateKwargs,
+    ) -> Awaitable[Any]:
+
+        service_provider = SERVICE_PROVIDERS["LITELLM"]
+
+        span_attributes = {
+            **get_langtrace_attributes(version, service_provider, vendor_type="llm"),
+            **get_llm_request_attributes(kwargs, operation_name="embed"),
+            SpanAttributes.LLM_PATH: APIS["EMBEDDINGS_CREATE"]["ENDPOINT"],
+            SpanAttributes.LLM_REQUEST_DIMENSIONS: kwargs.get("dimensions"),
+            **get_extra_attributes(),  # type: ignore
+        }
+
+        encoding_format = kwargs.get("encoding_format")
+        if encoding_format is not None:
+            if not isinstance(encoding_format, list):
+                encoding_format = [encoding_format]
+            span_attributes[SpanAttributes.LLM_REQUEST_ENCODING_FORMATS] = (
+                encoding_format
+            )
+
+        if kwargs.get("input") is not None:
+            span_attributes[SpanAttributes.LLM_REQUEST_EMBEDDING_INPUTS] = json.dumps(
+                [kwargs.get("input", "")]
+            )
+
+        # Build the attributes only after all request metadata has been collected,
+        # so the encoding format and embedding inputs are included in the span
+        attributes = LLMSpanAttributes(**filter_valid_attributes(span_attributes))
+
+        with tracer.start_as_current_span(
+            name=get_span_name(APIS["EMBEDDINGS_CREATE"]["METHOD"]),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        ) as span:
+
+            set_span_attributes(span, attributes)
+            try:
+                # Attempt to call the original method
+                result = await wrapped(*args, **kwargs)
+                span.set_status(StatusCode.OK)
+                return result
+            except Exception as err:
+                # Record the exception in the span
+                span.record_exception(err)
+
+                # Set the span status to indicate an error
+                span.set_status(Status(StatusCode.ERROR, str(err)))
+
+                # Reraise the exception to ensure it's not swallowed
+                raise
+
+    return traced_method
+
+
+def extract_content(choice: Any) -> Union[str, List[Dict[str, Any]], Dict[str, Any]]:
+    # Check if choice.message exists and has a content attribute
+    if (
+        hasattr(choice, "message")
+        and hasattr(choice.message, "content")
+        and choice.message.content is not None
+    ):
+        return choice.message.content
+
+    # Check if choice.message has tool_calls and extract information accordingly
+    elif (
+        hasattr(choice, "message")
+        and hasattr(choice.message, "tool_calls")
+        and choice.message.tool_calls is not None
+    ):
+        result = [
+            {
+                "id": tool_call.id,
+                "type": tool_call.type,
+                "function": {
+                    "name": tool_call.function.name,
+                    "arguments": tool_call.function.arguments,
+                },
+            }
+            for tool_call in choice.message.tool_calls
+        ]
+        return result
+
+    # Check if choice.message has a function_call and extract information accordingly
+    elif (
+        hasattr(choice, "message")
+        and hasattr(choice.message, "function_call")
+        and choice.message.function_call is not None
+    ):
+        return {
+            "name": choice.message.function_call.name,
+            "arguments": choice.message.function_call.arguments,
+        }
+
+    # Return an empty string if none of the above conditions are met
+    else:
+        return ""
+
+
+@silently_fail
+def _set_input_attributes(
+    span: Span, kwargs: ChatCompletionsCreateKwargs, attributes: LLMSpanAttributes
+) -> None:
+    tools = []
+    for field, value in attributes.model_dump(by_alias=True).items():
+        set_span_attribute(span, field, value)
+    functions = kwargs.get("functions")
+    if functions is not None and functions != NOT_GIVEN:
+        for function in functions:
+            tools.append(json.dumps({"type": "function", "function": function}))
+
+    if kwargs.get("tools") is not None and kwargs.get("tools") != NOT_GIVEN:
+        tools.append(json.dumps(kwargs.get("tools")))
+
+    if tools:
+        set_span_attribute(span, SpanAttributes.LLM_TOOLS, json.dumps(tools))
+
+
+@silently_fail
+def _set_response_attributes(span: Span, result: ResultType) -> None:
+    set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, result.model)
+    if hasattr(result, "choices") and result.choices is not None:
+        responses = [
+            {
+                "role": (
+                    choice.message.role
+                    if choice.message and choice.message.role
+                    else "assistant"
+                ),
+                "content": extract_content(choice),
+                **(
+                    {"content_filter_results": choice.content_filter_results}
+                    if hasattr(choice, "content_filter_results")
+                    else {}
+                ),
+            }
+            for choice in result.choices
+        ]
+        set_event_completion(span, responses)
+
+    if (
+        hasattr(result, "system_fingerprint")
+        and result.system_fingerprint is not None
+        and result.system_fingerprint != NOT_GIVEN
+    ):
+        set_span_attribute(
+            span,
+            SpanAttributes.LLM_SYSTEM_FINGERPRINT,
+            result.system_fingerprint,
+        )
+    # Get the usage
+    if hasattr(result, "usage") and result.usage is not None:
+        usage = result.usage
+        if usage is not None:
+            set_span_attribute(
+                span,
+                SpanAttributes.LLM_USAGE_PROMPT_TOKENS,
+                result.usage.prompt_tokens,
+            )
+            set_span_attribute(
+                span,
+                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS,
+                result.usage.completion_tokens,
+            )
+            set_span_attribute(
+                span,
+                SpanAttributes.LLM_USAGE_TOTAL_TOKENS,
+                result.usage.total_tokens,
+            )
diff --git a/src/langtrace_python_sdk/instrumentation/litellm/types.py b/src/langtrace_python_sdk/instrumentation/litellm/types.py
new file mode 100644
index 00000000..64b6ab14
--- /dev/null
+++ b/src/langtrace_python_sdk/instrumentation/litellm/types.py
@@ -0,0 +1,170 @@
+"""
+Copyright (c) 2024 Scale3 Labs
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+""" + +from typing import Any, Dict, List, Union, Optional, TypedDict + + +class ContentItem: + url: str + revised_prompt: str + base64: Optional[str] + + def __init__( + self, + url: str, + revised_prompt: str, + base64: Optional[str], + ): + self.url = url + self.revised_prompt = revised_prompt + self.base64 = base64 + + +class ToolFunction: + name: str + arguments: str + + def __init__( + self, + name: str, + arguments: str, + ): + self.name = name + self.arguments = arguments + + +class ToolCall: + id: str + type: str + function: ToolFunction + + def __init__( + self, + id: str, + type: str, + function: ToolFunction, + ): + self.id = id + self.type = type + self.function = function + + +class Message: + role: str + content: Union[str, List[ContentItem], Dict[str, Any]] + tool_calls: Optional[List[ToolCall]] + + def __init__( + self, + role: str, + content: Union[str, List[ContentItem], Dict[str, Any]], + content_filter_results: Optional[Any], + ): + self.role = role + self.content = content + self.content_filter_results = content_filter_results + + +class Usage: + prompt_tokens: int + completion_tokens: int + total_tokens: int + + def __init__( + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + ): + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + self.total_tokens = total_tokens + + +class Choice: + message: Message + content_filter_results: Optional[Any] + + def __init__( + self, + message: Message, + content_filter_results: Optional[Any], + ): + self.message = message + self.content_filter_results = content_filter_results + + +class ResultType: + model: Optional[str] + content: List[ContentItem] + system_fingerprint: Optional[str] + usage: Optional[Usage] + choices: Optional[List[Choice]] + response_format: Optional[str] + size: Optional[str] + encoding_format: Optional[str] + + def __init__( + self, + model: Optional[str], + role: Optional[str], + content: List[ContentItem], + system_fingerprint: Optional[str], + usage: Optional[Usage], + functions: Optional[List[ToolCall]], + tools: Optional[List[ToolCall]], + choices: Optional[List[Choice]], + response_format: Optional[str], + size: Optional[str], + encoding_format: Optional[str], + ): + self.model = model + self.role = role + self.content = content + self.system_fingerprint = system_fingerprint + self.usage = usage + self.functions = functions + self.tools = tools + self.choices = choices + self.response_format = response_format + self.size = size + self.encoding_format = encoding_format + + +class ImagesGenerateKwargs(TypedDict, total=False): + operation_name: str + model: Optional[str] + messages: Optional[List[Message]] + functions: Optional[List[ToolCall]] + tools: Optional[List[ToolCall]] + response_format: Optional[str] + size: Optional[str] + encoding_format: Optional[str] + + +class ImagesEditKwargs(TypedDict, total=False): + response_format: Optional[str] + size: Optional[str] + + +class ChatCompletionsCreateKwargs(TypedDict, total=False): + model: Optional[str] + messages: List[Message] + functions: Optional[List[ToolCall]] + tools: Optional[List[ToolCall]] + + +class EmbeddingsCreateKwargs(TypedDict, total=False): + dimensions: Optional[str] + input: Union[str, List[str], None] + encoding_format: Optional[Union[List[str], str]] diff --git a/src/langtrace_python_sdk/langtrace.py b/src/langtrace_python_sdk/langtrace.py index 60f3516d..84852797 100644 --- a/src/langtrace_python_sdk/langtrace.py +++ b/src/langtrace_python_sdk/langtrace.py @@ -46,6 +46,7 @@ 
     LangchainCoreInstrumentation,
     LangchainInstrumentation,
     LanggraphInstrumentation,
+    LiteLLMInstrumentation,
     LlamaindexInstrumentation,
     MistralInstrumentation,
     OllamaInstrumentor,
@@ -137,6 +138,7 @@ def init(
         "langchain-core": LangchainCoreInstrumentation(),
         "langchain-community": LangchainCommunityInstrumentation(),
         "langgraph": LanggraphInstrumentation(),
+        "litellm": LiteLLMInstrumentation(),
         "anthropic": AnthropicInstrumentation(),
         "cohere": CohereInstrumentation(),
         "weaviate-client": WeaviateInstrumentation(),
diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py
index 61e5ea59..31af30a2 100644
--- a/src/langtrace_python_sdk/version.py
+++ b/src/langtrace_python_sdk/version.py
@@ -1 +1 @@
-__version__ = "2.3.21"
+__version__ = "2.3.22"

From 8eec52e25e113f1c8e30d70e9dd58b72e37904ca Mon Sep 17 00:00:00 2001
From: Ali Waleed
Date: Tue, 1 Oct 2024 15:03:51 +0300
Subject: [PATCH 2/3] skip instrumentations and avoid blocking other vendors

---
 src/langtrace_python_sdk/langtrace.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/langtrace_python_sdk/langtrace.py b/src/langtrace_python_sdk/langtrace.py
index 84852797..6d44f2a0 100644
--- a/src/langtrace_python_sdk/langtrace.py
+++ b/src/langtrace_python_sdk/langtrace.py
@@ -222,7 +222,10 @@ def init_instrumentations(
     if disable_instrumentations is None:
         for name, v in all_instrumentations.items():
             if is_package_installed(name):
-                v.instrument()
+                try:
+                    v.instrument()
+                except Exception as e:
+                    print(f"Skipping {name} due to error while instrumenting: {e}")
 
     else:
@@ -244,4 +247,7 @@ def init_instrumentations(
         for name, v in filtered_dict.items():
             if is_package_installed(name):
-                v.instrument()
+                try:
+                    v.instrument()
+                except Exception as e:
+                    print(f"Skipping {name} due to error while instrumenting: {e}")

From ba6ec56f0109ef6f5f7be3cea3811a55049b1ffc Mon Sep 17 00:00:00 2001
From: Ali Waleed
Date: Tue, 1 Oct 2024 15:04:09 +0300
Subject: [PATCH 3/3] bump version

---
 src/langtrace_python_sdk/version.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py
index 31af30a2..6b639a44 100644
--- a/src/langtrace_python_sdk/version.py
+++ b/src/langtrace_python_sdk/version.py
@@ -1 +1 @@
-__version__ = "2.3.22"
+__version__ = "2.3.23"