diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
index 411f70ad..a1e2ed98 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -32,6 +32,7 @@
 )
 from requests.models import Response
 
+import langchain_nvidia_ai_endpoints.utils as utils
 from langchain_nvidia_ai_endpoints._statics import MODEL_TABLE, Model, determine_model
 
 logger = logging.getLogger(__name__)
@@ -447,6 +448,7 @@ class _NVIDIAClient(BaseModel):
 
     model: Optional[str] = Field(..., description="Name of the model to invoke")
     is_hosted: bool = Field(True)
+    cls: str = Field(..., description="Class Name")
 
     ####################################################################################
 
@@ -504,56 +506,10 @@ def _preprocess_args(cls, values: Any) -> Any:
 
     @root_validator
     def _postprocess_args(cls, values: Any) -> Any:
-        name = values.get("model")
         if values["is_hosted"]:
-            if not values["client"].api_key:
-                warnings.warn(
-                    "An API key is required for the hosted NIM. "
-                    "This will become an error in the future.",
-                    UserWarning,
-                )
-            if model := determine_model(name):
-                values["model"] = model.id
-                # not all models are on https://integrate.api.nvidia.com/v1,
-                # those that are not are served from their own endpoints
-                if model.endpoint:
-                    # we override the infer_path to use the custom endpoint
-                    values["client"].infer_path = model.endpoint
-            else:
-                if not (client := values.get("client")):
-                    warnings.warn(f"Unable to determine validity of {name}")
-                else:
-                    if any(model.id == name for model in client.available_models):
-                        warnings.warn(
-                            f"Found {name} in available_models, but type is "
-                            "unknown and inference may fail."
-                        )
-                    else:
-                        raise ValueError(
-                            f"Model {name} is unknown, check `available_models`"
-                        )
+            utils._process_hosted_model(values)
         else:
-            # set default model
-            if not name:
-                if not (client := values.get("client")):
-                    warnings.warn(f"Unable to determine validity of {name}")
-                else:
-                    valid_models = [
-                        model.id
-                        for model in client.available_models
-                        if not model.base_model or model.base_model == model.id
-                    ]
-                    name = next(iter(valid_models), None)
-                    if name:
-                        warnings.warn(
-                            f"Default model is set as: {name}. \n"
-                            "Set model using model parameter. \n"
\n" - "To get available models use available_models property.", - UserWarning, - ) - values["model"] = name - else: - raise ValueError("No locally hosted model was found.") + utils._process_locally_hosted_model(values) return values @classmethod diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index 224a19bb..e7c18884 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -245,6 +245,7 @@ def __init__(self, **kwargs: Any): default_model=self._default_model, api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)), infer_path="{base_url}/chat/completions", + cls=self.__class__.__name__, ) # todo: only store the model in one place # the model may be updated to a newer name during initialization diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py index 244c7bed..68871ded 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py @@ -95,6 +95,7 @@ def __init__(self, **kwargs: Any): default_model=self._default_model, api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)), infer_path="{base_url}/embeddings", + cls=self.__class__.__name__, ) # todo: only store the model in one place # the model may be updated to a newer name during initialization diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py index 9f9292ed..d6ab54e1 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py @@ -87,6 +87,7 @@ def __init__(self, **kwargs: Any): default_model=self._default_model_name, api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)), infer_path="{base_url}/ranking", + cls=self.__class__.__name__, ) # todo: only store the model in one place # the model may be updated to a newer name during initialization diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/utils.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/utils.py new file mode 100644 index 00000000..ad83bbb7 --- /dev/null +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/utils.py @@ -0,0 +1,160 @@ +import warnings +from typing import Any, Dict, Optional + +from langchain_nvidia_ai_endpoints._statics import Model, determine_model + + +def _process_hosted_model(values: Dict) -> None: + """ + Process logic for a hosted model. Validates compatibility, sets model ID, + and adjusts client's infer path if necessary. + + Raises: + ValueError: If the model is incompatible with the client or is unknown. + """ + name = values["model"] + cls_name = values["cls"] + client = values.get("client") + + if client and hasattr(client, "api_key") and not client.api_key: + warnings.warn( + "An API key is required for the hosted NIM. 
" + "This will become an error in the future.", + UserWarning, + ) + + model = determine_model(name) + if model: + _validate_hosted_model_compatibility(name, cls_name, model) + values["model"] = model.id + if model.endpoint: + values["client"].infer_path = model.endpoint + else: + _handle_unknown_hosted_model(name, client) + + +def _validate_hosted_model_compatibility( + name: str, cls_name: Optional[str], model: Model +) -> None: + """ + Validates compatibility of the hosted model with the client. + + Args: + name (str): The name of the model. + cls_name (str): The name of the client class. + model (Any): The model object. + Raises: + ValueError: If the model is incompatible with the client. + """ + if not model.client: + warnings.warn(f"Unable to determine validity of {name}") + elif model.client != cls_name: + raise ValueError( + f"Model {name} is incompatible with client {cls_name}. " + f"Please check `{cls_name}.get_available_models()`." + ) + + +def _handle_unknown_hosted_model(name: str, client: Any) -> None: + """ + Handles scenarios where the hosted model is unknown or its type is unclear. + Raises: + ValueError: If the model is unknown. + """ + if not client: + warnings.warn(f"Unable to determine validity of {name}") + elif any(model.id == name for model in client.available_models): + warnings.warn( + f"Found {name} in available_models, but type is " + "unknown and inference may fail." + ) + else: + raise ValueError(f"Model {name} is unknown, check `available_models`.") + + +def _process_locally_hosted_model(values: Dict) -> None: + """ + Process logic for a locally hosted model. + Validates compatibility and sets default model. + + Raises: + ValueError: If the model is incompatible with the client or is unknown. + """ + name = values["model"] + cls_name = values["cls"] + client = values.get("client") + + if name and isinstance(name, str): + model = determine_model(name) + if model: + _validate_locally_hosted_model_compatibility(name, cls_name, model, client) + else: + _handle_unknown_locally_hosted_model(name, client) + else: + _set_default_model(values, client) + + +def _validate_locally_hosted_model_compatibility( + model_name: str, cls_name: str, model: Model, client: Any +) -> None: + """ + Validates compatibility of the locally hosted model with the client. + + Args: + model_name (str): The name of the model. + cls_name (str): The name of the client class. + model (Any): The model object. + client (Any): The client object. + + Raises: + ValueError: If the model is incompatible with the client or is unknown. + """ + if model.client != cls_name: + raise ValueError( + f"Model {model_name} is incompatible with client {cls_name}. " + f"Please check `{cls_name}.get_available_models()`." + ) + + if model_name not in [model.id for model in client.available_models]: + raise ValueError( + f"Locally hosted {model_name} model was found, check `available_models`." + ) + + +def _handle_unknown_locally_hosted_model(model_name: str, client: Any) -> None: + """ + Handles scenarios where the locally hosted model is unknown. + + Raises: + ValueError: If the model is unknown. + """ + if model_name not in [model.id for model in client.available_models]: + raise ValueError(f"Model {model_name} is unknown, check `available_models`.") + + +def _set_default_model(values: Dict, client: Any) -> None: + """ + Sets a default model based on client's available models. + + Raises: + ValueError: If no locally hosted model was found. 
+    """
+    values["model"] = next(
+        iter(
+            [
+                model.id
+                for model in client.available_models
+                if not model.base_model or model.base_model == model.id
+            ]
+        ),
+        None,
+    )
+    if values["model"]:
+        warnings.warn(
+            f'Default model is set as: {values["model"]}. \n'
+            "Set model using model parameter. \n"
+            "To get available models use available_models property.",
+            UserWarning,
+        )
+    else:
+        raise ValueError("No locally hosted model was found.")
diff --git a/libs/ai-endpoints/tests/integration_tests/test_base_url.py b/libs/ai-endpoints/tests/integration_tests/test_base_url.py
index 81cb0119..3879d4ee 100644
--- a/libs/ai-endpoints/tests/integration_tests/test_base_url.py
+++ b/libs/ai-endpoints/tests/integration_tests/test_base_url.py
@@ -1,3 +1,4 @@
+import re
 from typing import Any
 
 import pytest
@@ -7,19 +8,37 @@
 
 
 # Fixture setup /v1/chat/completions endpoints
 @pytest.fixture()
-def mock_endpoints(requests_mock: Mocker, base_url: str) -> None:
-    for endpoint in ["/v1/embeddings", "/v1/chat/completions", "/v1/ranking"]:
+def mock_endpoints(requests_mock: Mocker) -> None:
+    for endpoint in [
+        "/v1/models",
+        "/v1/embeddings",
+        "/v1/chat/completions",
+        "/v1/ranking",
+    ]:
         requests_mock.post(
-            f"{base_url}{endpoint}",
+            re.compile(f".*{endpoint}"),
             exc=ConnectionError(f"Mocked ConnectionError for {endpoint}"),
         )
+    requests_mock.get(
+        re.compile(".*/v1/models"),
+        json={
+            "data": [
+                {
+                    "id": "not-a-model",
+                    "object": "model",
+                    "created": 1234567890,
+                    "owned_by": "OWNER",
+                },
+            ]
+        },
+    )
 
 
 # Test function using the mock_endpoints fixture
 @pytest.mark.parametrize(
     "base_url",
     [
-        "http://localhost:12321",
+        "http://localhost:12321/v1",
     ],
 )
 def test_endpoint_unavailable(
@@ -31,5 +50,9 @@ def test_endpoint_unavailable(
     # we test this with a bogus model because users should supply
     # a model when using their own base_url
     client = public_class(model="not-a-model", base_url=base_url)
-    with pytest.raises(ConnectionError):
+    with pytest.raises(ConnectionError) as e:
         contact_service(client)
+    assert "Mocked ConnectionError for" in str(e.value)
+
+
+# todo: move this to be a unit test
diff --git a/libs/ai-endpoints/tests/integration_tests/test_chat_models.py b/libs/ai-endpoints/tests/integration_tests/test_chat_models.py
index 260998aa..25018187 100644
--- a/libs/ai-endpoints/tests/integration_tests/test_chat_models.py
+++ b/libs/ai-endpoints/tests/integration_tests/test_chat_models.py
@@ -41,11 +41,6 @@ def test_unknown_model() -> None:
         ChatNVIDIA(model="unknown_model")
 
 
-def test_base_url_unknown_model() -> None:
-    llm = ChatNVIDIA(model="unknown_model", base_url="http://localhost:88888/v1")
-    assert llm.model == "unknown_model"
-
-
 def test_chat_ai_endpoints_system_message(chat_model: str, mode: dict) -> None:
     """Test wrapper with system message."""
     # mamba_chat only supports 'user' or 'assistant' messages -
diff --git a/libs/ai-endpoints/tests/integration_tests/test_register_model.py b/libs/ai-endpoints/tests/integration_tests/test_register_model.py
index 6a08c2fe..d43d4ba3 100644
--- a/libs/ai-endpoints/tests/integration_tests/test_register_model.py
+++ b/libs/ai-endpoints/tests/integration_tests/test_register_model.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any
 
 import pytest
@@ -39,15 +40,14 @@ def test_registered_model_functional(
     client: type, id: str, endpoint: str, contact_service: Any
 ) -> None:
     model = Model(id=id, endpoint=endpoint)
-    with pytest.warns(
-        UserWarning
-    ) as record:  # warns because we're overriding known models
-        register_model(model)
-        contact_service(client(model=id))
-    assert len(record) == 1
-    assert isinstance(record[0].message, UserWarning)
-    assert "already registered" in str(record[0].message)
-    assert "Overriding" in str(record[0].message)
+    warnings.filterwarnings(
+        "ignore", r".*is already registered.*"
+    )  # intentionally overriding known models
+    warnings.filterwarnings(
+        "ignore", r".*Unable to determine validity of.*"
+    )  # we aren't passing client & type to Model()
+    register_model(model)
+    contact_service(client(model=id))
 
 
 def test_registered_model_is_available() -> None:
diff --git a/libs/ai-endpoints/tests/unit_tests/test_api_key.py b/libs/ai-endpoints/tests/unit_tests/test_api_key.py
index 9878d973..2f39dfb9 100644
--- a/libs/ai-endpoints/tests/unit_tests/test_api_key.py
+++ b/libs/ai-endpoints/tests/unit_tests/test_api_key.py
@@ -21,6 +21,24 @@ def no_env_var(var: str) -> Generator[None, None, None]:
             del os.environ[var]
 
 
+@pytest.fixture(autouse=True)
+def mock_endpoint_models(requests_mock: Mocker) -> None:
+    requests_mock.get(
+        "https://integrate.api.nvidia.com/v1/models",
+        json={
+            "data": [
+                {
+                    "id": "meta/llama3-8b-instruct",
+                    "object": "model",
+                    "created": 1234567890,
+                    "owned_by": "OWNER",
+                    "root": "model1",
+                },
+            ]
+        },
+    )
+
+
 @pytest.fixture(autouse=True)
 def mock_v1_local_models(requests_mock: Mocker) -> None:
     requests_mock.get(
@@ -28,11 +46,11 @@ def mock_v1_local_models(requests_mock: Mocker) -> None:
         json={
             "data": [
                 {
-                    "id": "model1",
+                    "id": "model",
                     "object": "model",
                     "created": 1234567890,
                     "owned_by": "OWNER",
-                    "root": "model1",
+                    "root": "model",
                 },
             ]
         },
diff --git a/libs/ai-endpoints/tests/unit_tests/test_chat_models.py b/libs/ai-endpoints/tests/unit_tests/test_chat_models.py
index 1c6ffa72..5e84bf18 100644
--- a/libs/ai-endpoints/tests/unit_tests/test_chat_models.py
+++ b/libs/ai-endpoints/tests/unit_tests/test_chat_models.py
@@ -2,10 +2,34 @@
 
 import pytest
 
+from requests_mock import Mocker
 from langchain_nvidia_ai_endpoints.chat_models import ChatNVIDIA
 
 
+@pytest.fixture
+def mock_local_models(requests_mock: Mocker) -> None:
+    requests_mock.get(
+        "http://localhost:8888/v1/models",
+        json={
+            "data": [
+                {
+                    "id": "unknown_model",
+                    "object": "model",
+                    "created": 1234567890,
+                    "owned_by": "OWNER",
+                    "root": "unknown_model",
+                },
+            ]
+        },
+    )
+
+
+def test_base_url_unknown_model(mock_local_models: None) -> None:
+    llm = ChatNVIDIA(model="unknown_model", base_url="http://localhost:8888/v1")
+    assert llm.model == "unknown_model"
+
+
 def test_integration_initialization() -> None:
     """Test chat model initialization."""
     ChatNVIDIA(
diff --git a/libs/ai-endpoints/tests/unit_tests/test_model.py b/libs/ai-endpoints/tests/unit_tests/test_model.py
index 4f06417b..5e2bcae4 100644
--- a/libs/ai-endpoints/tests/unit_tests/test_model.py
+++ b/libs/ai-endpoints/tests/unit_tests/test_model.py
@@ -1,7 +1,18 @@
+from itertools import chain
+from typing import Any
+
 import pytest
 from requests_mock import Mocker
 
-from langchain_nvidia_ai_endpoints._statics import MODEL_TABLE
+from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank
+from langchain_nvidia_ai_endpoints._statics import (
+    CHAT_MODEL_TABLE,
+    EMBEDDING_MODEL_TABLE,
+    MODEL_TABLE,
+    QA_MODEL_TABLE,
+    RANKING_MODEL_TABLE,
+    VLM_MODEL_TABLE,
+)
 
 
 @pytest.fixture
@@ -52,21 +63,39 @@ def mock_v1_local_models(requests_mock: Mocker, known_unknown: str) -> None:
 
 
 @pytest.mark.parametrize(
-    "alias",
+    "alias, client",
     [
-        alias
-        for model in MODEL_TABLE.values()
+        (alias, ChatNVIDIA)
+        for model in list(
+            chain(
+                CHAT_MODEL_TABLE.values(),
+                VLM_MODEL_TABLE.values(),
+                QA_MODEL_TABLE.values(),
+            )
+        )
+        if model.aliases is not None
+        for alias in model.aliases
+    ]
+    + [
+        (alias, NVIDIAEmbeddings)
+        for model in EMBEDDING_MODEL_TABLE.values()
+        if model.aliases is not None
+        for alias in model.aliases
+    ]
+    + [
+        (alias, NVIDIARerank)
+        for model in RANKING_MODEL_TABLE.values()
         if model.aliases is not None
         for alias in model.aliases
     ],
 )
-def test_aliases(public_class: type, alias: str) -> None:
+def test_aliases(alias: str, client: Any) -> None:
     """
     Test that the aliases for each model in the model table are accepted
     with a warning about deprecation of the alias.
     """
     with pytest.warns(UserWarning) as record:
-        x = public_class(model=alias, nvidia_api_key="a-bogus-key")
+        x = client(model=alias, nvidia_api_key="a-bogus-key")
         assert x.model == x._client.model
     assert isinstance(record[0].message, Warning)
     assert "deprecated" in record[0].message.args[0]
@@ -100,7 +129,7 @@ def test_known_unknown(public_class: type, known_unknown: str) -> None:
         assert "unknown" in record[0].message.args[0]
 
 
-def test_unknown_unknown(public_class: type) -> None:
+def test_unknown_unknown(public_class: type, empty_v1_models: None) -> None:
     """
     Test that a model not in /v1/models and not in known model table will be
     rejected.
@@ -128,3 +157,54 @@ def test_default_lora(public_class: type) -> None:
     # find a model that matches the public_class under test
     x = public_class(base_url="http://localhost:8000/v1", model="lora1")
     assert x.model == "lora1"
+
+
+@pytest.mark.parametrize(
+    "model, client",
+    [(model.id, model.client) for model in MODEL_TABLE.values()],
+)
+def test_hosted_all_incompatible(public_class: type, model: str, client: str) -> None:
+    """
+    Test that a hosted model from the model table raises a ValueError when it
+    is used with an incompatible client class.
+    """
+    msg = (
+        "Model {model_name} is incompatible with client {cls_name}. "
+        "Please check `{cls_name}.get_available_models()`."
+    )
+
+    if client != public_class.__name__:
+        with pytest.raises(ValueError) as err_msg:
+            public_class(model=model, nvidia_api_key="a-bogus-key")
+
+        assert msg.format(model_name=model, cls_name=public_class.__name__) in str(
+            err_msg.value
+        )
+
+
+@pytest.mark.parametrize(
+    "model, client",
+    [(model.id, model.client) for model in MODEL_TABLE.values()],
+)
+def test_locally_hosted_all_incompatible(
+    public_class: type, model: str, client: str
+) -> None:
+    """
+    Test that a locally hosted model from the model table raises a ValueError
+    when it is used with an incompatible client class.
+    """
+    msg = (
+        "Model {model_name} is incompatible with client {cls_name}. "
+        "Please check `{cls_name}.get_available_models()`."
+    )
+    if client != public_class.__name__:
+        with pytest.raises(ValueError) as err_msg:
+            public_class(
+                base_url="http://localhost:8000/v1",
+                model=model,
+                nvidia_api_key="a-bogus-key",
+            )
+        assert msg.format(model_name=model, cls_name=public_class.__name__) in str(err_msg.value)
+    else:
+        cls = public_class(model=model, nvidia_api_key="a-bogus-key")
+        assert cls.model == model
diff --git a/libs/ai-endpoints/tests/unit_tests/test_register_model.py b/libs/ai-endpoints/tests/unit_tests/test_register_model.py
index 4c7c0d38..3cfb788a 100644
--- a/libs/ai-endpoints/tests/unit_tests/test_register_model.py
+++ b/libs/ai-endpoints/tests/unit_tests/test_register_model.py
@@ -66,9 +66,11 @@ def test_registered_model_without_client_usable(public_class: type) -> None:
     id = f"test/no-client-{public_class.__name__}"
     model = Model(id=id, endpoint="BOGUS")
     register_model(model)
-    # todo: this should warn that the model is known but type is not
-    #       and therefore inference may not work
-    public_class(model=id, nvidia_api_key="a-bogus-key")
+    with pytest.warns(UserWarning) as record:
+        public_class(model=id, nvidia_api_key="a-bogus-key")
+    assert len(record) == 1
+    assert isinstance(record[0].message, UserWarning)
+    assert "Unable to determine validity" in str(record[0].message)
 
 
 def test_missing_endpoint() -> None:
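Note (not part of the patch): a minimal usage sketch of the client/model compatibility check introduced in utils.py above. It assumes a build with this change installed and NVIDIA_API_KEY set in the environment; "meta/llama3-8b-instruct" is used only as an illustrative chat-model id from the hosted catalog.

# Sketch only: a chat model paired with ChatNVIDIA works as before, while the same
# model paired with a different client now fails fast with a ValueError from the
# compatibility check instead of failing later at inference time.
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

llm = ChatNVIDIA(model="meta/llama3-8b-instruct")  # compatible client

try:
    NVIDIAEmbeddings(model="meta/llama3-8b-instruct")  # wrong client for a chat model
except ValueError as err:
    # e.g. "Model meta/llama3-8b-instruct is incompatible with client NVIDIAEmbeddings. ..."
    print(err)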