diff --git a/CHANGELOG.md b/CHANGELOG.md index d79a9a2a1..ee839cb01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ * `run_dataset` now has a flag `trace_examples_individually` to create `Tracer`s for each example. Defaults to True. ### Fixes -... + - ControlModels throw a warning instead of an error in case a non-recommended model is selected. ### Deprecations ... diff --git a/src/intelligence_layer/core/model.py b/src/intelligence_layer/core/model.py index f101967a5..dfbac7844 100644 --- a/src/intelligence_layer/core/model.py +++ b/src/intelligence_layer/core/model.py @@ -1,6 +1,7 @@ +import warnings from abc import ABC, abstractmethod from functools import lru_cache -from typing import Literal, Optional +from typing import Optional from aleph_alpha_client import ( CompletionRequest, @@ -153,8 +154,9 @@ def __init__( limited_concurrency_client_from_env() if client is None else client ) if name not in [model["name"] for model in self._client.models()]: - raise ValueError( - f"Could not find model: {name}. Either model name is invalid or model is currently down." + warnings.warn( + f"Could not find model: {name}. " + "Either the model name is invalid or the model is currently down." + ) self._complete: Task[CompleteInput, CompleteOutput] = _Complete( self._client, name ) @@ -195,13 +197,16 @@ def tokenize(self, text: str) -> Encoding: class ControlModel(ABC, AlephAlphaModel): - AllowedModel: Literal[""] + RECOMMENDED_MODELS = [""] def __init__( self, name: str, client: AlephAlphaClientProtocol | None = None ) -> None: - if name not in self.AllowedModel.__args__ or name == "": # type: ignore - raise ValueError(f"Invalid model name: {name}") + if name not in self.RECOMMENDED_MODELS or name == "": + warnings.warn( + "The provided model is not a recommended model for this model class. " 
+ "Make sure that the model you have selected is suited to be used for the prompt template used in this model class." + ) super().__init__(name, client) @abstractmethod @@ -232,7 +237,7 @@ class LuminousControlModel(ControlModel): ### Response:{{response_prefix}}""" ) - AllowedModel = Literal[ + RECOMMENDED_MODELS = [ "luminous-base-control-20230501", "luminous-extended-control-20230501", "luminous-supreme-control-20230501", @@ -246,7 +251,7 @@ class LuminousControlModel(ControlModel): def __init__( self, - name: AllowedModel = "luminous-base-control", + name: str = "luminous-base-control", client: Optional[AlephAlphaClientProtocol] = None, ) -> None: super().__init__(name, client) @@ -282,7 +287,7 @@ class Llama2InstructModel(ControlModel): {{response_prefix}}{% endif %}""") - AllowedModel = Literal[ + RECOMMENDED_MODELS = [ "llama-2-7b-chat", "llama-2-13b-chat", "llama-2-70b-chat", @@ -290,7 +295,7 @@ class Llama2InstructModel(ControlModel): def __init__( self, - name: AllowedModel = "llama-2-13b-chat", + name: str = "llama-2-13b-chat", client: Optional[AlephAlphaClientProtocol] = None, ) -> None: super().__init__(name, client) @@ -327,14 +332,14 @@ class Llama3InstructModel(ControlModel): ) EOT_TOKEN = "<|eot_id|>" - AllowedModel = Literal[ + RECOMMENDED_MODELS = [ "llama-3-8b-instruct", "llama-3-70b-instruct", ] def __init__( self, - name: AllowedModel = "llama-3-8b-instruct", + name: str = "llama-3-8b-instruct", client: Optional[AlephAlphaClientProtocol] = None, ) -> None: super().__init__(name, client) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index d59bc36ff..b54035fec 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -97,17 +97,17 @@ def test_models_know_their_context_size(client: AlephAlphaClientProtocol) -> Non ) -def test_models_are_strict_about_instantiation( +def test_models_warn_about_non_recommended_models( client: AlephAlphaClientProtocol, ) -> None: - with pytest.raises(ValueError): + with 
pytest.warns(UserWarning): assert LuminousControlModel(client=client, name="llama-2-7b-chat") # type: ignore - with pytest.raises(ValueError): + with pytest.warns(UserWarning): assert Llama2InstructModel(client=client, name="luminous-base") # type: ignore - with pytest.raises(ValueError): + with pytest.warns(UserWarning): assert Llama3InstructModel(client=client, name="llama-2-7b-chat") # type: ignore - with pytest.raises(ValueError): + with pytest.warns(UserWarning): assert AlephAlphaModel(client=client, name="No model") # type: ignore