Merge pull request instructlab#1561 from leseb/fix-supported-backend
fix: add list of supported backends in the CLI help
mergify[bot] authored Jul 3, 2024
2 parents efbd310 + c71b27c commit 2e54e54
Showing 5 changed files with 11 additions and 50 deletions.
4 changes: 3 additions & 1 deletion src/instructlab/configuration.py
@@ -172,7 +172,9 @@ class _serve(BaseModel):

     # additional fields with defaults
     host_port: StrictStr = "127.0.0.1:8000"
-    backend: str = ""  # we don't set a default value here since it's auto-detected
+    backend: Optional[str] = (
+        None  # we don't set a default value here since it's auto-detected
+    )

     def api_base(self):
         """Returns server API URL, based on the configured host and port"""
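The schema change above replaces the empty-string sentinel with `Optional[str] = None`, so "not configured" is distinguishable from any real backend name. A minimal standalone sketch of the pattern, assuming pydantic is installed (this is not the project's actual configuration module):

# Minimal sketch of the field change above; field names mirror the diff,
# but this model is illustrative, not the real configuration module.
from typing import Optional

from pydantic import BaseModel, StrictStr


class Serve(BaseModel):
    host_port: StrictStr = "127.0.0.1:8000"
    # None means "not configured"; auto-detection fills it in later.
    backend: Optional[str] = None


cfg = Serve()
assert cfg.backend is None  # unset, so the server will auto-detect
cfg = Serve(backend="llama-cpp")
assert cfg.backend == "llama-cpp"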
18 changes: 1 addition & 17 deletions src/instructlab/model/backends/backends.py
@@ -113,21 +113,6 @@ def is_model_gguf(model_path: pathlib.Path) -> bool:
     return first_four_bytes_int == GGUF_MAGIC


-def validate_backend(backend: str) -> None:
-    """
-    Validate the backend.
-    Args:
-        backend (str): The backend to validate.
-    Raises:
-        ValueError: If the backend is not supported.
-    """
-    # lowercase backend for comparison in case of user input like 'Llama'
-    if backend.lower() not in SUPPORTED_BACKENDS:
-        raise ValueError(
-            f"Backend '{backend}' is not supported. Supported: {', '.join(SUPPORTED_BACKENDS)}"
-        )
-
-
 def determine_backend(model_path: pathlib.Path) -> str:
     """
     Determine the backend to use based on the model file properties.

@@ -173,15 +158,14 @@ def get(logger: logging.Logger, model_path: pathlib.Path, backend: str) -> str:
     # When the backend is not set using the --backend flag, determine the backend automatically
     # 'backend' is optional so we still check for None or empty string in case 'config.yaml' hasn't
     # been updated via 'ilab config init'
-    if backend is None or backend == "":
+    if backend is None:
         logger.debug(
             f"Backend is not set using auto-detected value: {auto_detected_backend}"
         )
         backend = auto_detected_backend
     # If the backend was set using the --backend flag, validate it.
     else:
         logger.debug(f"Validating '{backend}' backend")
-        validate_backend(backend)
     # TODO: keep this code logic and implement a `--force` flag to override the auto-detected backend
     # If the backend was set explicitly, but we detected the model should use a different backend, raise an error
     # if backend != auto_detected_backend:
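With validate_backend() removed, get() only needs to fill in a missing value; an explicit --backend value is already guaranteed valid by click. A rough standalone sketch of the remaining control flow, with determine_backend() as a stand-in for the real GGUF-sniffing logic:

# Rough sketch of the simplified selection flow in get(); determine_backend
# here is a stand-in, not the project's real file-inspection implementation.
import logging
import pathlib
from typing import Optional

logger = logging.getLogger(__name__)


def determine_backend(model_path: pathlib.Path) -> str:
    # Stand-in: the real function inspects the model file (e.g. GGUF magic bytes).
    return "llama-cpp" if model_path.suffix == ".gguf" else "vllm"


def get(model_path: pathlib.Path, backend: Optional[str]) -> str:
    auto_detected_backend = determine_backend(model_path)
    if backend is None:
        # Not set via --backend or config.yaml: fall back to auto-detection.
        logger.debug(
            "Backend is not set, using auto-detected value: %s",
            auto_detected_backend,
        )
        backend = auto_detected_backend
    # No validate_backend() call here: click.Choice already rejected bad input.
    return backend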
10 changes: 5 additions & 5 deletions src/instructlab/model/serve.py
@@ -11,6 +11,7 @@
 # First Party
 from instructlab import configuration as config
 from instructlab import log, utils
+from instructlab.model.backends import backends
 from instructlab.model.backends.backends import ServerException

 logger = logging.getLogger(__name__)

@@ -47,11 +48,10 @@
 )
 @click.option(
     "--backend",
-    type=click.STRING,  # purposely not using click.Choice to allow for auto-detection
+    type=click.Choice(backends.SUPPORTED_BACKENDS),
     help=(
-        "The backend to use for serving the model."
-        "Automatically detected based on the model file properties."
-        "Supported: 'llama-cpp'."
+        "The backend to use for serving the model.\n"
+        "Automatically detected based on the model file properties.\n"
     ),
 )
 @click.option(

@@ -79,7 +79,7 @@ def serve(
 ):
     """Start a local server"""
     # First Party
-    from instructlab.model.backends import backends, llama_cpp, vllm
+    from instructlab.model.backends import llama_cpp, vllm

     host, port = utils.split_hostport(ctx.obj.config.serve.host_port)
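Switching the option type to click.Choice lets click generate the list of supported backends in the help text and reject anything else before serve() runs. A standalone sketch of that behavior, with the backend list hard-coded for illustration ('llama-cpp' and 'vllm' match the values exercised in the project's tests):

# Standalone sketch of the click.Choice behavior introduced above; the
# command and backend tuple are illustrative, not the project's real CLI.
import click

SUPPORTED_BACKENDS = ("llama-cpp", "vllm")


@click.command()
@click.option(
    "--backend",
    type=click.Choice(SUPPORTED_BACKENDS),
    help=(
        "The backend to use for serving the model.\n"
        "Automatically detected based on the model file properties.\n"
    ),
)
def serve(backend):
    click.echo(f"backend={backend}")


if __name__ == "__main__":
    serve()

Running the sketch with --help renders the option as `--backend [llama-cpp|vllm]`, and an unsupported value exits with click's usual "Invalid value" error instead of a ValueError raised deep inside the backend code.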
25 changes: 0 additions & 25 deletions tests/test_backends.py
@@ -17,31 +17,6 @@ def mock_supported_backends(monkeypatch):
     )


-@pytest.mark.usefixtures("mock_supported_backends")
-class TestValidateBackend:
-    def test_validate_backend_valid(self):
-        # Test with a valid backend
-        try:
-            backends.validate_backend("llama-cpp")
-        except ValueError:
-            pytest.fail("validate_backend raised ValueError unexpectedly!")
-        # Test with a valid backend
-        try:
-            backends.validate_backend("LLAMA-CPP")
-        except ValueError:
-            pytest.fail("validate_backend raised ValueError unexpectedly!")
-        # Test with a valid backend
-        try:
-            backends.validate_backend("vllm")
-        except ValueError:
-            pytest.fail("validate_backend raised ValueError unexpectedly!")
-
-    def test_validate_backend_invalid(self):
-        # Test with an invalid backend
-        with pytest.raises(ValueError):
-            backends.validate_backend("foo")
-
-
 def test_free_port():
     host = "localhost"
     port = backends.free_tcp_ipv4_port(host)
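The deleted TestValidateBackend cases are superseded by click's own validation. If equivalent coverage were still wanted, it could be written against the CLI instead; a hypothetical sketch (not part of this commit) using click's test runner:

# Hypothetical replacement coverage (not in this commit): drive the option
# through click's test runner and let click.Choice do the rejecting.
import click
from click.testing import CliRunner


@click.command()
@click.option("--backend", type=click.Choice(("llama-cpp", "vllm")))
def serve(backend):
    click.echo(backend or "auto-detect")


def test_backend_choice():
    runner = CliRunner()
    assert runner.invoke(serve, ["--backend", "vllm"]).exit_code == 0
    result = runner.invoke(serve, ["--backend", "foo"])
    assert result.exit_code != 0  # click rejects unsupported values
    assert "Invalid value" in result.output

One behavioral nuance: the old validator lower-cased its input (the deleted test accepted "LLAMA-CPP"), whereas click.Choice is case-sensitive unless constructed with case_sensitive=False.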
4 changes: 2 additions & 2 deletions tests/test_config.py
@@ -48,7 +48,7 @@ def _assert_defaults(self, cfg):
         assert cfg.serve.vllm is not None
         assert cfg.serve.vllm.vllm_args == []
         assert cfg.serve.host_port == "127.0.0.1:8000"
-        assert cfg.serve.backend == ""
+        assert cfg.serve.backend is None

         assert cfg.evaluate is not None

@@ -119,7 +119,7 @@ def test_full_config(self):
           taxonomy_path: taxonomy
           chunk_word_count: 1000
         serve:
-          backend: ''
+          backend: null
           host_port: 127.0.0.1:8000
           llama_cpp:
             gpu_layers: -1
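The serialized default changes in step: YAML null deserializes to Python None, matching the new Optional[str] field. A quick check, assuming PyYAML is available:

# YAML "null" loads as Python None, so the serialized config round-trips
# to the new Optional[str] default; the old sentinel was the empty string.
import yaml

assert yaml.safe_load("backend: null") == {"backend": None}
assert yaml.safe_load("backend: ''") == {"backend": ""}  # the old sentinel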
