Skip to content

Commit

Permalink
Merge pull request #49 from langchain-ai/mattf/default-to-non-hosted-…
Browse files Browse the repository at this point in the history
…when-base_url-provided

Make connectors use the user-provided base_url even when the provided model is also hosted
  • Loading branch information
mattf authored Jun 4, 2024
2 parents 6c70ed2 + c24e149 commit bf7026c
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 24 deletions.
5 changes: 5 additions & 0 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,11 @@ def _postprocess_args(cls, values: Any) -> Any:
name = values.get("model")
if model := determine_model(name):
values["model"] = model.id
# not all models are on https://integrate.api.nvidia.com/v1,
# those that are not are served from their own endpoints
if model.endpoint:
# we override the infer_path to use the custom endpoint
values["client"].infer_path = model.endpoint
else:
if not (client := values.get("client")):
warnings.warn(f"Unable to determine validity of {name}")
Expand Down
10 changes: 2 additions & 8 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from langchain_core.tools import BaseTool

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model, determine_model
from langchain_nvidia_ai_endpoints._statics import Model

_CallbackManager = Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
_DictOrPydanticClass = Union[Dict[str, Any], Type[BaseModel]]
Expand Down Expand Up @@ -168,17 +168,11 @@ def __init__(self, **kwargs: Any):
environment variable.
"""
super().__init__(**kwargs)
infer_path = "{base_url}/chat/completions"
# not all chat models are on https://integrate.api.nvidia.com/v1,
# those that are not are served from their own endpoints
if model := determine_model(self.model):
if model.endpoint: # some models have custom endpoints
infer_path = model.endpoint
self._client = _NVIDIAClient(
base_url=self.base_url,
model=self.model,
api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
infer_path=infer_path,
infer_path="{base_url}/chat/completions",
)
# todo: only store the model in one place
# the model may be updated to a newer name during initialization
Expand Down
10 changes: 2 additions & 8 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr, validator

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model, determine_model
from langchain_nvidia_ai_endpoints._statics import Model
from langchain_nvidia_ai_endpoints.callbacks import usage_callback_var


Expand Down Expand Up @@ -69,17 +69,11 @@ def __init__(self, **kwargs: Any):
environment variable.
"""
super().__init__(**kwargs)
infer_path = "{base_url}/embeddings"
# not all embedding models are on https://integrate.api.nvidia.com/v1,
# those that are not are served from their own endpoints
if model := determine_model(self.model):
if model.endpoint: # some models have custom endpoints
infer_path = model.endpoint
self._client = _NVIDIAClient(
base_url=self.base_url,
model=self.model,
api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
infer_path=infer_path,
infer_path="{base_url}/embeddings",
)
# todo: only store the model in one place
# the model may be updated to a newer name during initialization
Expand Down
10 changes: 2 additions & 8 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr

from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
from langchain_nvidia_ai_endpoints._statics import Model, determine_model
from langchain_nvidia_ai_endpoints._statics import Model


class Ranking(BaseModel):
Expand Down Expand Up @@ -62,17 +62,11 @@ def __init__(self, **kwargs: Any):
environment variable.
"""
super().__init__(**kwargs)
infer_path = "{base_url}/ranking"
# not all models are on https://integrate.api.nvidia.com/v1,
# those that are not are served from their own endpoints
if model := determine_model(self.model):
if model.endpoint: # some models have custom endpoints
infer_path = model.endpoint
self._client = _NVIDIAClient(
base_url=self.base_url,
model=self.model,
api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
infer_path=infer_path,
infer_path="{base_url}/ranking",
)
# todo: only store the model in one place
# the model may be updated to a newer name during initialization
Expand Down

0 comments on commit bf7026c

Please sign in to comment.