make connectors use user-provided base_url when the provided model is also hosted

Previously, if a model was known to be hosted on a custom endpoint and a user served the same model locally, the connectors would favor the hosted version. For instance, NVIDIAEmbeddings(model="NV-Embed-QA", base_url="http://localhost/v1") would contact the hosted NV-Embed-QA. Likewise, ChatNVIDIA(model="mistralai/mixtral-8x7b-instruct-v0.1", base_url="http://localhost/v1") would contact the hosted mistralai/mixtral-8x7b-instruct-v0.1. With this change, the connectors use the user-provided base_url instead.
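A minimal usage sketch of the fixed behavior (assuming a local OpenAI-compatible server at http://localhost/v1 serving these models):

    from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

    # with this change the user-provided base_url wins, even though both
    # models are also known hosted models with custom endpoints
    embedder = NVIDIAEmbeddings(model="NV-Embed-QA", base_url="http://localhost/v1")
    llm = ChatNVIDIA(model="mistralai/mixtral-8x7b-instruct-v0.1", base_url="http://localhost/v1")

    # both requests now go to http://localhost/v1, not the hosted endpoints
    vector = embedder.embed_query("hello world")
    reply = llm.invoke("hello world")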

fixes #31
mattf committed Jun 3, 2024
1 parent 6c70ed2 commit f369509
Showing 4 changed files with 8 additions and 21 deletions.
5 changes: 5 additions & 0 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -524,6 +524,11 @@ def _postprocess_args(cls, values: Any) -> Any:
         name = values.get("model")
         if model := determine_model(name):
             values["model"] = model.id
+            # not all models are on https://integrate.api.nvidia.com/v1,
+            # those that are not are served from their own endpoints
+            if model.endpoint:
+                # we override the infer_path to use the custom endpoint
+                values["client"].infer_path = model.endpoint
         else:
             if not (client := values.get("client")):
                 warnings.warn(f"Unable to determine validity of {name}")
8 changes: 1 addition & 7 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -168,17 +168,11 @@ def __init__(self, **kwargs: Any):
             environment variable.
         """
         super().__init__(**kwargs)
-        infer_path = "{base_url}/chat/completions"
-        # not all chat models are on https://integrate.api.nvidia.com/v1,
-        # those that are not are served from their own endpoints
-        if model := determine_model(self.model):
-            if model.endpoint:  # some models have custom endpoints
-                infer_path = model.endpoint
         self._client = _NVIDIAClient(
             base_url=self.base_url,
             model=self.model,
             api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
-            infer_path=infer_path,
+            infer_path="{base_url}/chat/completions",
         )
         # todo: only store the model in one place
         # the model may be updated to a newer name during initialization
8 changes: 1 addition & 7 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
@@ -69,17 +69,11 @@ def __init__(self, **kwargs: Any):
             environment variable.
         """
         super().__init__(**kwargs)
-        infer_path = "{base_url}/embeddings"
-        # not all embedding models are on https://integrate.api.nvidia.com/v1,
-        # those that are not are served from their own endpoints
-        if model := determine_model(self.model):
-            if model.endpoint:  # some models have custom endpoints
-                infer_path = model.endpoint
         self._client = _NVIDIAClient(
             base_url=self.base_url,
             model=self.model,
             api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
-            infer_path=infer_path,
+            infer_path="{base_url}/embeddings",
         )
         # todo: only store the model in one place
         # the model may be updated to a newer name during initialization
8 changes: 1 addition & 7 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py
@@ -62,17 +62,11 @@ def __init__(self, **kwargs: Any):
             environment variable.
         """
         super().__init__(**kwargs)
-        infer_path = "{base_url}/ranking"
-        # not all models are on https://integrate.api.nvidia.com/v1,
-        # those that are not are served from their own endpoints
-        if model := determine_model(self.model):
-            if model.endpoint:  # some models have custom endpoints
-                infer_path = model.endpoint
         self._client = _NVIDIAClient(
             base_url=self.base_url,
             model=self.model,
             api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
-            infer_path=infer_path,
+            infer_path="{base_url}/ranking",
         )
         # todo: only store the model in one place
         # the model may be updated to a newer name during initialization
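The reranking connector gets the same treatment; a minimal sketch (assuming a local server that exposes the /ranking route for the default reranking model):

    from langchain_core.documents import Document
    from langchain_nvidia_ai_endpoints import NVIDIARerank

    # hypothetical local deployment; the user-provided base_url is honored here too
    reranker = NVIDIARerank(base_url="http://localhost/v1")
    docs = [Document(page_content="first"), Document(page_content="second")]
    ranked = reranker.compress_documents(documents=docs, query="which is first?")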
