diff --git a/src/instructlab/configuration.py b/src/instructlab/configuration.py
index a5d0ba1966..fdf824e9d4 100644
--- a/src/instructlab/configuration.py
+++ b/src/instructlab/configuration.py
@@ -172,7 +172,9 @@ class _serve(BaseModel):
 
     # additional fields with defaults
     host_port: StrictStr = "127.0.0.1:8000"
-    backend: str = ""  # we don't set a default value here since it's auto-detected
+    backend: Optional[str] = (
+        None  # we don't set a default value here since it's auto-detected
+    )
 
     def api_base(self):
         """Returns server API URL, based on the configured host and port"""
diff --git a/src/instructlab/model/backends/backends.py b/src/instructlab/model/backends/backends.py
index f72004fcfd..98a62dd433 100644
--- a/src/instructlab/model/backends/backends.py
+++ b/src/instructlab/model/backends/backends.py
@@ -113,21 +113,6 @@ def is_model_gguf(model_path: pathlib.Path) -> bool:
     return first_four_bytes_int == GGUF_MAGIC
 
 
-def validate_backend(backend: str) -> None:
-    """
-    Validate the backend.
-    Args:
-        backend (str): The backend to validate.
-    Raises:
-        ValueError: If the backend is not supported.
-    """
-    # lowercase backend for comparison in case of user input like 'Llama'
-    if backend.lower() not in SUPPORTED_BACKENDS:
-        raise ValueError(
-            f"Backend '{backend}' is not supported. Supported: {', '.join(SUPPORTED_BACKENDS)}"
-        )
-
-
 def determine_backend(model_path: pathlib.Path) -> str:
     """
     Determine the backend to use based on the model file properties.
@@ -173,7 +158,7 @@ def get(logger: logging.Logger, model_path: pathlib.Path, backend: str) -> str:
     # When the backend is not set using the --backend flag, determine the backend automatically
-    # 'backend' is optional so we still check for None or empty string in case 'config.yaml' hasn't
-    # been updated via 'ilab config init'
-    if backend is None or backend == "":
+    # 'backend' is optional so we still check for None in case 'config.yaml' hasn't
+    # been updated via 'ilab config init'
+    if backend is None:
         logger.debug(
             f"Backend is not set using auto-detected value: {auto_detected_backend}"
         )
@@ -181,7 +166,6 @@ def get(logger: logging.Logger, model_path: pathlib.Path, backend: str) -> str:
     # If the backend was set using the --backend flag, validate it.
     else:
         logger.debug(f"Validating '{backend}' backend")
-        validate_backend(backend)
         # TODO: keep this code logic and implement a `--force` flag to override the auto-detected backend
         # If the backend was set explicitly, but we detected the model should use a different backend, raise an error
         # if backend != auto_detected_backend:
diff --git a/src/instructlab/model/serve.py b/src/instructlab/model/serve.py
index 7986b1735a..f3ff47bcd4 100644
--- a/src/instructlab/model/serve.py
+++ b/src/instructlab/model/serve.py
@@ -11,6 +11,7 @@
 # First Party
 from instructlab import configuration as config
 from instructlab import log, utils
+from instructlab.model.backends import backends
 from instructlab.model.backends.backends import ServerException
 
 logger = logging.getLogger(__name__)
@@ -47,11 +48,10 @@
 )
 @click.option(
     "--backend",
-    type=click.STRING,  # purposely not using click.Choice to allow for auto-detection
+    type=click.Choice(backends.SUPPORTED_BACKENDS),
     help=(
-        "The backend to use for serving the model."
-        "Automatically detected based on the model file properties."
-        "Supported: 'llama-cpp'."
+ "The backend to use for serving the model.\n" + "Automatically detected based on the model file properties.\n" ), ) @click.option( @@ -79,7 +79,7 @@ def serve( ): """Start a local server""" # First Party - from instructlab.model.backends import backends, llama_cpp, vllm + from instructlab.model.backends import llama_cpp, vllm host, port = utils.split_hostport(ctx.obj.config.serve.host_port) diff --git a/tests/test_backends.py b/tests/test_backends.py index c081c96c24..cd91cdff2b 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -17,31 +17,6 @@ def mock_supported_backends(monkeypatch): ) -@pytest.mark.usefixtures("mock_supported_backends") -class TestValidateBackend: - def test_validate_backend_valid(self): - # Test with a valid backend - try: - backends.validate_backend("llama-cpp") - except ValueError: - pytest.fail("validate_backend raised ValueError unexpectedly!") - # Test with a valid backend - try: - backends.validate_backend("LLAMA-CPP") - except ValueError: - pytest.fail("validate_backend raised ValueError unexpectedly!") - # Test with a valid backend - try: - backends.validate_backend("vllm") - except ValueError: - pytest.fail("validate_backend raised ValueError unexpectedly!") - - def test_validate_backend_invalid(self): - # Test with an invalid backend - with pytest.raises(ValueError): - backends.validate_backend("foo") - - def test_free_port(): host = "localhost" port = backends.free_tcp_ipv4_port(host) diff --git a/tests/test_config.py b/tests/test_config.py index 7e16101697..42360db28e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -48,7 +48,7 @@ def _assert_defaults(self, cfg): assert cfg.serve.vllm is not None assert cfg.serve.vllm.vllm_args == [] assert cfg.serve.host_port == "127.0.0.1:8000" - assert cfg.serve.backend == "" + assert cfg.serve.backend is None assert cfg.evaluate is not None @@ -119,7 +119,7 @@ def test_full_config(self): taxonomy_path: taxonomy chunk_word_count: 1000 serve: - backend: '' + backend: null host_port: 127.0.0.1:8000 llama_cpp: gpu_layers: -1