Skip to content

Commit

Permalink
Merge pull request #1963 from osok/main
Browse files Browse the repository at this point in the history
added llm_base_url to llm.client.__init__.py
  • Loading branch information
kevinmessiaen authored Sep 2, 2024
2 parents cf74a50 + bd561b1 commit 485c77a
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions giskard/llm/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,41 @@
_default_client = None
_default_llm_api: Optional[str] = None
_default_llm_model = os.getenv("GSK_LLM_MODEL", "gpt-4")
_default_llm_base_url = os.getenv("GSK_LLM_BASE_URL", None)


def set_default_client(client: LLMClient):
global _default_client
_default_client = client


def _unset_default_client():
global _default_client
_default_client = None


def set_llm_api(llm_api: str):
if llm_api.lower() not in {"azure", "openai"}:
raise ValueError("Giskard LLM-based evaluators is only working with `azure` and `openai`")

global _default_llm_api
_default_llm_api = llm_api.lower()
# If the API is set, we unset the default client
global _default_client
_default_client = None
_unset_default_client()


def set_llm_base_url(llm_base_url: Optional[str]):
global _default_llm_base_url
_default_llm_base_url = llm_base_url
# If the model is set, we unset the default client
_unset_default_client()


def set_llm_model(llm_model: str):
global _default_llm_model
_default_llm_model = llm_model
# If the model is set, we unset the default client
global _default_client
_default_client = None
_unset_default_client()


def get_default_llm_api() -> str:
Expand Down Expand Up @@ -66,7 +77,7 @@ def get_default_client() -> LLMClient:
# For openai>=1.0.0
from openai import AzureOpenAI, OpenAI

client = AzureOpenAI() if default_llm_api == "azure" else OpenAI()
client = AzureOpenAI() if default_llm_api == "azure" else OpenAI(base_url=_default_llm_base_url)

_default_client = OpenAIClient(model=_default_llm_model, client=client)
except ImportError:
Expand All @@ -83,4 +94,5 @@ def get_default_client() -> LLMClient:
"set_llm_model",
"set_llm_api",
"set_default_client",
"set_llm_base_url",
]

0 comments on commit 485c77a

Please sign in to comment.