Skip to content

Commit

Permalink
Merge pull request MLSysOps#214 from leeeizhang/lei/prompt-tracking
Browse files Browse the repository at this point in the history
[MRG] add langfuse
  • Loading branch information
huangyz0918 authored Sep 17, 2024
2 parents 06dca25 + 08b0b41 commit a4bbca4
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 9 deletions.
52 changes: 44 additions & 8 deletions mle/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,22 +534,58 @@ def stream(self, chat_history, **kwargs):
else:
yield chunk.choices[0].delta.content

def load_model(project_dir: str, model_name: Optional[str]=None):

class ObservableModel:
"""
A class that wraps a model to make it trackable by the metric platform (e.g., Langfuse).
"""

try:
from mle.utils import get_langfuse_observer
_observe = get_langfuse_observer()
except Exception as e:
# If importing fails, set _observe to a lambda function that does nothing.
_observe = lambda fn: fn

def __init__(self, model: Model):
"""
Initialize the ObservableModel.
Args:
model: The model to be wrapped and made observable.
"""
self.model = model

@_observe
def query(self, *args, **kwargs):
return self.model.query(*args, **kwargs)

@_observe
def stream(self, *args, **kwargs):
return self.model.query(*args, **kwargs)


def load_model(project_dir: str, model_name: Optional[str] = None, observable=True):
    """
    load_model: load the model based on the configuration.

    Args:
        project_dir (str): The project directory.
        model_name (str): The model name.
        observable (bool): Whether to wrap the model so its calls are
            tracked by the metric platform (e.g., Langfuse).

    Returns:
        The loaded model (wrapped in ObservableModel when `observable`
        is True), or None if the configured platform is not recognized.
    """
    config = get_config(project_dir)
    model = None

    # Map the configured platform to its concrete model implementation.
    # The branches are mutually exclusive, so use an elif chain.
    if config['platform'] == MODEL_OLLAMA:
        model = OllamaModel(model=model_name)
    elif config['platform'] == MODEL_OPENAI:
        model = OpenAIModel(api_key=config['api_key'], model=model_name)
    elif config['platform'] == MODEL_CLAUDE:
        model = ClaudeModel(api_key=config['api_key'], model=model_name)
    elif config['platform'] == MODEL_MISTRAL:
        model = MistralModel(api_key=config['api_key'], model=model_name)
    elif config['platform'] == MODEL_DEEPSEEK:
        model = DeepSeekModel(api_key=config['api_key'], model=model_name)

    # Bug fix: an unknown platform must yield None, not an
    # ObservableModel wrapped around None.
    if observable and model is not None:
        return ObservableModel(model)
    return model
1 change: 1 addition & 0 deletions mle/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import json


def clean_json_string(input_string):
Expand Down
109 changes: 108 additions & 1 deletion mle/utils/system.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os
import re
import uuid
import yaml
import base64
import shutil
import requests
import platform
import subprocess
from typing import Dict, Any, Optional
import importlib.util
from typing import Dict, Any, Optional, Callable
from rich.panel import Panel
from rich.prompt import Prompt
from rich.console import Console
Expand Down Expand Up @@ -257,3 +260,107 @@ def startup_web(host: str = "0.0.0.0", port: int = 3000):
"Please install `npm` and `nodejs` before starting the web applications.\n"
"Refer to: https://nodejs.org/en/download/package-manager"
)


def get_user_id():
    """
    Return a stable, anonymised identifier for the current machine user.

    The id is a UUID5 derived from "<hostname>-<username>", so the same
    user on the same host always maps to the same value. Returns None
    when the hostname cannot be determined.
    """
    if platform.system() == "Windows":
        username = os.getenv('USERNAME', 'root')
        hostname = os.getenv('COMPUTERNAME')
    else:
        username = os.getenv('USER', 'root')
        try:
            hostname = os.uname().nodename
        except AttributeError:
            # os.uname is unavailable on some platforms; use socket instead.
            import socket
            hostname = socket.gethostname()

    if not (username and hostname):
        return None
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{hostname}-{username}"))


def get_session_id():
    """
    Return an identifier for the current process session.

    The id is a UUID5 seeded with the process id combined with this
    module file's ctime (0 when the file is missing), so repeated calls
    within the same process agree.
    """
    file_stamp = os.stat(__file__).st_ctime if os.path.exists(__file__) else 0
    seed = f"{os.getpid()}-{file_stamp}"
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, seed))


def get_langfuse_observer(
    secret_key: Optional[str] = None,
    public_key: Optional[str] = None,
    user_id: Optional[str] = None,
    session_id: Optional[str] = None,
    host: Optional[str] = None,
):
    """
    Build a Langfuse-backed observer decorator for model calls.

    :param secret_key: Langfuse secret key; falls back to the
        LANGFUSE_SECRET_KEY environment variable (KeyError if unset).
    :param public_key: Langfuse public key; falls back to the
        LANGFUSE_PUBLIC_KEY environment variable (KeyError if unset).
    :param user_id: Stable user identifier; defaults to get_user_id().
    :param session_id: Session identifier; defaults to get_session_id().
    :param host: Optional host address, defaulting to the LANGFUSE_HOST
        environment variable or 'https://us.cloud.langfuse.com'.
    :return: A decorator that wraps a bound-method-style callable so
        each call is recorded as a Langfuse generation.
    :raises ImportError: If the `langfuse` package is not installed.
    """
    # Probe for the optional dependency before configuring anything.
    spec = importlib.util.find_spec("langfuse")
    if spec is None:
        raise ImportError(
            "It seems you didn't install langfuse. In order to enable the observer, "
            "please make sure `langfuse` Python package has been installed. "
            "More information, please refer to: https://python.reference.langfuse.com/langfuse"
        )

    # Resolve credentials and identifiers, preferring explicit arguments
    # over environment variables / machine-derived defaults.
    if secret_key is None:
        secret_key = os.environ["LANGFUSE_SECRET_KEY"]
    if public_key is None:
        public_key = os.environ["LANGFUSE_PUBLIC_KEY"]
    if user_id is None:
        user_id = get_user_id()
    if session_id is None:
        session_id = get_session_id()
    if host is None:
        host = os.getenv("LANGFUSE_HOST", "https://us.cloud.langfuse.com")

    # Import lazily so this module loads even when langfuse is absent.
    langfuse = importlib.import_module("langfuse.decorators")
    langfuse.langfuse_context.configure(
        secret_key=secret_key,
        public_key=public_key,
        host=host,
        enabled=True,
    )

    def _observe(fn: Callable):
        # Inner wrapper: records the call as a "generation" observation.
        # `cls` is the bound instance (an ObservableModel-like wrapper)
        # whose `.model` attribute holds the underlying model, if any.
        @langfuse.observe(as_type="generation")
        def _fn(cls, *args, **kwargs):
            model = getattr(cls.model, "model", None)
            # Prefer the model's chat history as the logged input; fall
            # back to the raw call arguments when it is unavailable.
            messages = getattr(cls.model, "chat_history", (args, kwargs))
            response = fn(cls, *args, **kwargs)
            langfuse.langfuse_context.update_current_observation(
                model=model,
                input=messages,
                output=response,
                # NOTE(review): usage is approximated via string length,
                # not real token counts, despite the "TOKENS" unit.
                usage={
                    "input": len(str(messages)),
                    "output": len(str(response)),
                    "unit": "TOKENS",
                }
            )
            return response

        # Outer wrapper: tags the enclosing trace with the user/session
        # ids before delegating to the generation-level wrapper above.
        @langfuse.observe()
        def query(*args, **kwargs):
            langfuse.langfuse_context.update_current_trace(
                user_id=user_id,
                session_id=session_id,
            )
            return _fn(*args, **kwargs)
        return query

    return _observe
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ questionary
pandas~=2.2.2
tavily-python
instructor
langfuse
setuptools
numexpr~=2.10.1
bottleneck~=1.4.0
Expand Down

0 comments on commit a4bbca4

Please sign in to comment.