Skip to content

Commit

Permalink
fix: Throw a warning instead of an error in case a non-recommended mod…
Browse files Browse the repository at this point in the history
…el is selected for a ControlModel. (#892)

Task: IL-546
  • Loading branch information
FlorianSchepersAA authored Jun 6, 2024
1 parent 1f033bf commit fcd9908
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* `run_dataset` now has a flag `trace_examples_individually` to create `Tracer`s for each example. Defaults to True.

### Fixes
...
- ControlModels throw a warning instead of an error in case a non-recommended model is selected.
### Deprecations
...

Expand Down
29 changes: 17 additions & 12 deletions src/intelligence_layer/core/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Literal, Optional
from typing import Optional

from aleph_alpha_client import (
CompletionRequest,
Expand Down Expand Up @@ -153,8 +154,9 @@ def __init__(
limited_concurrency_client_from_env() if client is None else client
)
if name not in [model["name"] for model in self._client.models()]:
raise ValueError(
f"Could not find model: {name}. Either model name is invalid or model is currently down."
warnings.warn(
"The provided model is not a recommended model for this model class."
"Make sure that the model you have selected is suited to be use for the prompt template used in this model class."
)
self._complete: Task[CompleteInput, CompleteOutput] = _Complete(
self._client, name
Expand Down Expand Up @@ -195,13 +197,16 @@ def tokenize(self, text: str) -> Encoding:


class ControlModel(ABC, AlephAlphaModel):
AllowedModel: Literal[""]
RECOMMENDED_MODELS = [""]

def __init__(
self, name: str, client: AlephAlphaClientProtocol | None = None
) -> None:
if name not in self.AllowedModel.__args__ or name == "": # type: ignore
raise ValueError(f"Invalid model name: {name}")
if name not in self.RECOMMENDED_MODELS or name == "":
warnings.warn(
"The provided model is not a recommended model for this model class."
"Make sure that the model you have selected is suited to be use for the prompt template used in this model class."
)
super().__init__(name, client)

@abstractmethod
Expand Down Expand Up @@ -232,7 +237,7 @@ class LuminousControlModel(ControlModel):
### Response:{{response_prefix}}"""
)

AllowedModel = Literal[
RECOMMENDED_MODELS = [
"luminous-base-control-20230501",
"luminous-extended-control-20230501",
"luminous-supreme-control-20230501",
Expand All @@ -246,7 +251,7 @@ class LuminousControlModel(ControlModel):

def __init__(
self,
name: AllowedModel = "luminous-base-control",
name: str = "luminous-base-control",
client: Optional[AlephAlphaClientProtocol] = None,
) -> None:
super().__init__(name, client)
Expand Down Expand Up @@ -282,15 +287,15 @@ class Llama2InstructModel(ControlModel):
{{response_prefix}}{% endif %}""")

AllowedModel = Literal[
RECOMMENDED_MODELS = [
"llama-2-7b-chat",
"llama-2-13b-chat",
"llama-2-70b-chat",
]

def __init__(
self,
name: AllowedModel = "llama-2-13b-chat",
name: str = "llama-2-13b-chat",
client: Optional[AlephAlphaClientProtocol] = None,
) -> None:
super().__init__(name, client)
Expand Down Expand Up @@ -327,14 +332,14 @@ class Llama3InstructModel(ControlModel):
)
EOT_TOKEN = "<|eot_id|>"

AllowedModel = Literal[
RECOMMENDED_MODELS = [
"llama-3-8b-instruct",
"llama-3-70b-instruct",
]

def __init__(
self,
name: AllowedModel = "llama-3-8b-instruct",
name: str = "llama-3-8b-instruct",
client: Optional[AlephAlphaClientProtocol] = None,
) -> None:
super().__init__(name, client)
Expand Down
10 changes: 5 additions & 5 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,17 @@ def test_models_know_their_context_size(client: AlephAlphaClientProtocol) -> Non
)


def test_models_are_strict_about_instantiation(
def test_models_warn_about_non_recommended_models(
client: AlephAlphaClientProtocol,
) -> None:
with pytest.raises(ValueError):
with pytest.warns(UserWarning):
assert LuminousControlModel(client=client, name="llama-2-7b-chat") # type: ignore

with pytest.raises(ValueError):
with pytest.warns(UserWarning):
assert Llama2InstructModel(client=client, name="luminous-base") # type: ignore

with pytest.raises(ValueError):
with pytest.warns(UserWarning):
assert Llama3InstructModel(client=client, name="llama-2-7b-chat") # type: ignore

with pytest.raises(ValueError):
with pytest.warns(UserWarning):
assert AlephAlphaModel(client=client, name="No model") # type: ignore

0 comments on commit fcd9908

Please sign in to comment.