Client rejects incompatible models #81

Merged: 11 commits, Jul 31, 2024
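
The behavior this PR introduces is easiest to see from a constructor call. The sketch below is illustrative only: it assumes NV-Embed-QA remains registered as an NVIDIAEmbeddings-only model in the package's model table (as in the register-model test further down), the API key is a placeholder, and the error text follows the new utils helpers in this diff.

from langchain_nvidia_ai_endpoints.chat_models import ChatNVIDIA

try:
    # NV-Embed-QA belongs to NVIDIAEmbeddings, so the constructor now rejects it
    # instead of deferring the failure to inference time.
    ChatNVIDIA(model="NV-Embed-QA", nvidia_api_key="nvapi-placeholder")
except ValueError as err:
    print(err)  # Model NV-Embed-QA is incompatible with client ChatNVIDIA. ...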
53 changes: 4 additions & 49 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -4,7 +4,6 @@
import logging
import os
import time
import warnings
from copy import deepcopy
from typing import (
Any,
@@ -32,6 +31,7 @@
)
from requests.models import Response

import langchain_nvidia_ai_endpoints.utils as utils
from langchain_nvidia_ai_endpoints._statics import MODEL_TABLE, Model, determine_model

logger = logging.getLogger(__name__)
@@ -448,6 +448,7 @@ class _NVIDIAClient(BaseModel):

model: Optional[str] = Field(..., description="Name of the model to invoke")
is_hosted: bool = Field(True)
cls: str = Field(..., description="Class Name")

####################################################################################

@@ -469,56 +470,10 @@ def _preprocess_args(cls, values: Any) -> Any:

@root_validator
def _postprocess_args(cls, values: Any) -> Any:
name = values.get("model")
if values["is_hosted"]:
if not values["client"].api_key:
warnings.warn(
"An API key is required for the hosted NIM. "
"This will become an error in the future.",
UserWarning,
)
if model := determine_model(name):
values["model"] = model.id
# not all models are on https://integrate.api.nvidia.com/v1,
# those that are not are served from their own endpoints
if model.endpoint:
# we override the infer_path to use the custom endpoint
values["client"].infer_path = model.endpoint
else:
if not (client := values.get("client")):
warnings.warn(f"Unable to determine validity of {name}")
else:
if any(model.id == name for model in client.available_models):
warnings.warn(
f"Found {name} in available_models, but type is "
"unknown and inference may fail."
)
else:
raise ValueError(
f"Model {name} is unknown, check `available_models`"
)
utils._process_hosted_model(values)
else:
# set default model
if not name:
if not (client := values.get("client")):
warnings.warn(f"Unable to determine validity of {name}")
else:
valid_models = [
model.id
for model in client.available_models
if not model.base_model or model.base_model == model.id
]
name = next(iter(valid_models), None)
if name:
warnings.warn(
f"Default model is set as: {name}. \n"
"Set model using model parameter. \n"
"To get available models use available_models property.",
UserWarning,
)
values["model"] = name
else:
raise ValueError("No locally hosted model was found.")
utils._process_locally_hosted_model(values)
return values

@classmethod
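Pieced together from the added lines above (not a verbatim copy of the final file), the validator now reduces to a thin dispatcher over the two new helpers:

    @root_validator
    def _postprocess_args(cls, values: Any) -> Any:
        if values["is_hosted"]:
            # API-key warning, compatibility check, model-id normalization,
            # and per-model endpoint override
            utils._process_hosted_model(values)
        else:
            # compatibility check against the NIM's model listing, or
            # default-model selection when no model was given
            utils._process_locally_hosted_model(values)
        return values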
libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -224,6 +224,7 @@ def __init__(self, **kwargs: Any):
default_model=self._default_model,
api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
infer_path="{base_url}/chat/completions",
cls=self.__class__.__name__,
)
# todo: only store the model in one place
# the model may be updated to a newer name during initialization
libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
@@ -75,6 +75,7 @@ def __init__(self, **kwargs: Any):
default_model=self._default_model,
api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
infer_path="{base_url}/embeddings",
cls=self.__class__.__name__,
)
# todo: only store the model in one place
# the model may be updated to a newer name during initialization
libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py
@@ -75,6 +75,7 @@ def __init__(self, **kwargs: Any):
default_model=self._default_model_name,
api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
infer_path="{base_url}/ranking",
cls=self.__class__.__name__,
)
# todo: only store the model in one place
# the model may be updated to a newer name during initialization
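All three public classes now feed their own class name into _NVIDIAClient, which is what the utils helpers compare against Model.client. A condensed sketch of the shared constructor pattern follows (a fragment, not runnable on its own; the self._client attribute name and the model= keyword are assumptions, since the hunks above start below those arguments):

        self._client = _NVIDIAClient(
            model=self.model,  # assumption: earlier arguments are elided in the hunks above
            default_model=self._default_model,
            api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
            infer_path="{base_url}/chat/completions",  # "/embeddings" and "/ranking" in the other two classes
            cls=self.__class__.__name__,  # e.g. "ChatNVIDIA"; checked against Model.client in utils
        )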
158 changes: 158 additions & 0 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/utils.py
@@ -0,0 +1,158 @@
import warnings
from typing import Any, Dict, Optional

from langchain_nvidia_ai_endpoints._statics import Model, determine_model


def _process_hosted_model(values: Dict) -> None:
"""
Process logic for a hosted model. Validates compatibility, sets model ID,
and adjusts client's infer path if necessary.

Raises:
ValueError: If the model is incompatible with the client or is unknown.
"""
name = values["model"]
cls_name = values["cls"]
client = values.get("client")

if client and hasattr(client, "api_key") and not client.api_key:
warnings.warn(
"An API key is required for the hosted NIM. "
"This will become an error in the future.",
UserWarning,
)

model = determine_model(name)
if model:
_validate_hosted_model_compatibility(name, cls_name, model)
values["model"] = model.id
if model.endpoint:
values["client"].infer_path = model.endpoint
else:
_handle_unknown_hosted_model(name, client)


def _validate_hosted_model_compatibility(
name: str, cls_name: Optional[str], model: Model
) -> None:
"""
Validates compatibility of the hosted model with the client.

Args:
name (str): The name of the model.
cls_name (str): The name of the client class.
model (Any): The model object.
Raises:
ValueError: If the model is incompatible with the client.
"""
if model.client != cls_name:
raise ValueError(
f"Model {name} is incompatible with client {cls_name}. "
"Please check `available_models`."
)


def _handle_unknown_hosted_model(name: str, client: Any) -> None:
"""
Handles scenarios where the hosted model is unknown or its type is unclear.
Raises:
ValueError: If the model is unknown.
"""
if not client:
warnings.warn(f"Unable to determine validity of {name}")
elif any(model.id == name for model in client.available_models):
warnings.warn(
f"Found {name} in available_models, but type is "
"unknown and inference may fail."
)
else:
raise ValueError(f"Model {name} is unknown, check `available_models`.")


def _process_locally_hosted_model(values: Dict) -> None:
"""
Process logic for a locally hosted model.
Validates compatibility and sets default model.

Raises:
ValueError: If the model is incompatible with the client or is unknown.
"""
name = values["model"]
cls_name = values["cls"]
client = values.get("client")

if name and isinstance(name, str):
model = determine_model(name)
if model:
_validate_locally_hosted_model_compatibility(name, cls_name, model, client)
else:
_handle_unknown_locally_hosted_model(name, client)
else:
_set_default_model(values, client)


def _validate_locally_hosted_model_compatibility(
model_name: str, cls_name: str, model: Model, client: Any
) -> None:
"""
Validates compatibility of the locally hosted model with the client.

Args:
model_name (str): The name of the model.
cls_name (str): The name of the client class.
model (Any): The model object.
client (Any): The client object.

Raises:
ValueError: If the model is incompatible with the client or is unknown.
"""
if model.client != cls_name:
raise ValueError(
f"Model {model_name} is incompatible with client {cls_name}. "
"Please check `available_models`."
)

if model_name not in [model.id for model in client.available_models]:
raise ValueError(
f"Locally hosted {model_name} model was found, check `available_models`."
)


def _handle_unknown_locally_hosted_model(model_name: str, client: Any) -> None:
"""
Handles scenarios where the locally hosted model is unknown.

Raises:
ValueError: If the model is unknown.
"""
if model_name not in [model.id for model in client.available_models]:
raise ValueError(f"Model {model_name} is unknown, check `available_models`.")


def _set_default_model(values: Dict, client: Any) -> None:
"""
Sets a default model based on client's available models.

Raises:
ValueError: If no locally hosted model was found.
"""
values["model"] = next(
iter(
[
model.id
for model in client.available_models
if not model.base_model or model.base_model == model.id
]
),
None,
)
if values["model"]:
warnings.warn(
f'Default model is set as: {values["model"]}. \n'
"Set model using model parameter. \n"
"To get available models use available_models property.",
UserWarning,
)
else:
raise ValueError("No locally hosted model was found.")
2 changes: 1 addition & 1 deletion libs/ai-endpoints/tests/integration_tests/test_base_url.py
@@ -17,6 +17,6 @@ def test_endpoint_unavailable(
) -> None:
# we test this with a bogus model because users should supply
# a model when using their own base_url
client = public_class(model="not-a-model", base_url=base_url)
with pytest.raises(ConnectionError):
client = public_class(model="not-a-model", base_url=base_url)
contact_service(client)
23 changes: 21 additions & 2 deletions libs/ai-endpoints/tests/integration_tests/test_chat_models.py
@@ -11,6 +11,7 @@
HumanMessage,
SystemMessage,
)
from requests_mock import Mocker

from langchain_nvidia_ai_endpoints.chat_models import ChatNVIDIA

@@ -23,6 +24,24 @@
#


@pytest.fixture
def mock_local_models(requests_mock: Mocker) -> None:
requests_mock.get(
"http://localhost:8888/v1/models",
json={
"data": [
{
"id": "unknown_model",
"object": "model",
"created": 1234567890,
"owned_by": "OWNER",
"root": "unknown_model",
},
]
},
)


def test_chat_ai_endpoints(chat_model: str, mode: dict) -> None:
"""Test ChatNVIDIA wrapper."""
chat = ChatNVIDIA(model=chat_model, temperature=0.7, **mode)
@@ -41,8 +60,8 @@ def test_unknown_model() -> None:
ChatNVIDIA(model="unknown_model")


def test_base_url_unknown_model() -> None:
llm = ChatNVIDIA(model="unknown_model", base_url="http://localhost:88888/v1")
def test_base_url_unknown_model(mock_local_models: None) -> None:
llm = ChatNVIDIA(model="unknown_model", base_url="http://localhost:8888/v1")
assert llm.model == "unknown_model"


11 changes: 8 additions & 3 deletions libs/ai-endpoints/tests/integration_tests/test_register_model.py
@@ -16,29 +16,34 @@
# you will have to find the new ones from https://api.nvcf.nvidia.com/v2/nvcf/functions
#
@pytest.mark.parametrize(
"client, id, endpoint",
"client, id, endpoint, model_type",
[
(
ChatNVIDIA,
"meta/llama3-8b-instruct",
"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/a5a3ad64-ec2c-4bfc-8ef7-5636f26630fe",
"chat",
),
(
NVIDIAEmbeddings,
"NV-Embed-QA",
"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/09c64e32-2b65-4892-a285-2f585408d118",
"embedding",
),
(
NVIDIARerank,
"nv-rerank-qa-mistral-4b:1",
"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/0bf77f50-5c35-4488-8e7a-f49bb1974af6",
"ranking",
),
],
)
def test_registered_model_functional(
client: type, id: str, endpoint: str, contact_service: Any
client: type, id: str, endpoint: str, model_type: str, contact_service: Any
) -> None:
model = Model(id=id, endpoint=endpoint)
model = Model(
id=id, endpoint=endpoint, client=client.__name__, model_type=model_type
)
with pytest.warns(
UserWarning
) as record: # warns because we're overriding known models
22 changes: 20 additions & 2 deletions libs/ai-endpoints/tests/unit_tests/test_api_key.py
@@ -18,18 +18,36 @@ def no_env_var(var: str) -> Generator[None, None, None]:
os.environ[var] = val


@pytest.fixture(autouse=True)
def mock_endpoint_models(requests_mock: Mocker) -> None:
requests_mock.get(
"https://integrate.api.nvidia.com/v1/models",
json={
"data": [
{
"id": "meta/llama3-8b-instruct",
"object": "model",
"created": 1234567890,
"owned_by": "OWNER",
"root": "model1",
},
]
},
)


@pytest.fixture(autouse=True)
def mock_v1_local_models(requests_mock: Mocker) -> None:
requests_mock.get(
"https://test_url/v1/models",
json={
"data": [
{
"id": "model1",
"id": "model",
"object": "model",
"created": 1234567890,
"owned_by": "OWNER",
"root": "model1",
"root": "model",
},
]
},