diff --git a/charts/kserve-resources/README.md b/charts/kserve-resources/README.md index b233db64e0..167d75f626 100644 --- a/charts/kserve-resources/README.md +++ b/charts/kserve-resources/README.md @@ -72,7 +72,7 @@ $ helm install kserve oci://ghcr.io/kserve/charts/kserve --version v0.13.0 | kserve.servingruntime.lgbserver.tag | string | `"v0.13.0"` | | | kserve.servingruntime.mlserver.image | string | `"docker.io/seldonio/mlserver"` | | | kserve.servingruntime.mlserver.modelClassPlaceholder | string | `"{{.Labels.modelClass}}"` | | -| kserve.servingruntime.mlserver.tag | string | `"1.3.2"` | | +| kserve.servingruntime.mlserver.tag | string | `"1.5.0"` | | | kserve.servingruntime.modelNamePlaceholder | string | `"{{.Name}}"` | | | kserve.servingruntime.paddleserver.image | string | `"kserve/paddleserver"` | | | kserve.servingruntime.paddleserver.tag | string | `"v0.13.0"` | | diff --git a/charts/kserve-resources/templates/clusterservingruntimes.yaml b/charts/kserve-resources/templates/clusterservingruntimes.yaml index bfe46c45ef..6bef93d8cb 100644 --- a/charts/kserve-resources/templates/clusterservingruntimes.yaml +++ b/charts/kserve-resources/templates/clusterservingruntimes.yaml @@ -53,14 +53,26 @@ spec: version: "1" autoSelect: true priority: 2 + - name: xgboost + version: "2" + autoSelect: true + priority: 2 - name: lightgbm version: "3" autoSelect: true priority: 2 + - name: lightgbm + version: "4" + autoSelect: true + priority: 2 - name: mlflow version: "1" autoSelect: true priority: 1 + - name: mlflow + version: "2" + autoSelect: true + priority: 1 protocolVersions: - v2 containers: diff --git a/charts/kserve-resources/values.yaml b/charts/kserve-resources/values.yaml index 9188cf62ee..82fbb7197a 100644 --- a/charts/kserve-resources/values.yaml +++ b/charts/kserve-resources/values.yaml @@ -95,7 +95,7 @@ kserve: tag: 2.6.2 mlserver: image: docker.io/seldonio/mlserver - tag: 1.3.2 + tag: 1.5.0 modelClassPlaceholder: "{{.Labels.modelClass}}" sklearnserver: image: kserve/sklearnserver diff --git a/config/runtimes/kserve-mlserver.yaml b/config/runtimes/kserve-mlserver.yaml index 542184e809..d373a2791b 100644 --- a/config/runtimes/kserve-mlserver.yaml +++ b/config/runtimes/kserve-mlserver.yaml @@ -20,14 +20,26 @@ spec: version: "1" autoSelect: true priority: 2 + - name: xgboost + version: "2" + autoSelect: true + priority: 2 - name: lightgbm version: "3" autoSelect: true priority: 2 + - name: lightgbm + version: "4" + autoSelect: true + priority: 2 - name: mlflow version: "1" autoSelect: true priority: 1 + - name: mlflow + version: "2" + autoSelect: true + priority: 1 protocolVersions: - v2 containers: diff --git a/config/runtimes/kustomization.yaml b/config/runtimes/kustomization.yaml index 48e6877b07..92dedae53d 100644 --- a/config/runtimes/kustomization.yaml +++ b/config/runtimes/kustomization.yaml @@ -25,7 +25,7 @@ images: - name: mlserver newName: docker.io/seldonio/mlserver - newTag: 1.3.2 + newTag: 1.5.0 - name: kserve-xgbserver newName: kserve/xgbserver diff --git a/pkg/controller/v1beta1/inferenceservice/controller.go b/pkg/controller/v1beta1/inferenceservice/controller.go index 5086387cc3..c3510d76df 100644 --- a/pkg/controller/v1beta1/inferenceservice/controller.go +++ b/pkg/controller/v1beta1/inferenceservice/controller.go @@ -168,7 +168,7 @@ func (r *InferenceServiceReconciler) Reconcile(ctx context.Context, req ctrl.Req // Abort early if the resolved deployment mode is Serverless, but Knative Services are not available if deploymentMode == constants.Serverless { ksvcAvailable, checkKsvcErr := utils.IsCrdAvailable(r.ClientConfig, knservingv1.SchemeGroupVersion.String(), constants.KnativeServiceKind) - if err != nil { + if checkKsvcErr != nil { return reconcile.Result{}, checkKsvcErr } diff --git a/python/huggingface_server.Dockerfile b/python/huggingface_server.Dockerfile index e4d66734a7..561c906f38 100644 --- a/python/huggingface_server.Dockerfile +++ b/python/huggingface_server.Dockerfile @@ -1,4 +1,4 @@ -ARG BASE_IMAGE=nvidia/cuda:12.1.0-devel-ubuntu22.04 +ARG BASE_IMAGE=nvidia/cuda:12.4.1-devel-ubuntu22.04 ARG VENV_PATH=/prod_venv FROM ${BASE_IMAGE} as builder @@ -9,7 +9,7 @@ ARG POETRY_HOME=/opt/poetry ARG POETRY_VERSION=1.7.1 # Install vllm -ARG VLLM_VERSION=0.4.2 +ARG VLLM_VERSION=0.4.3 RUN apt-get update -y && apt-get install gcc python3.10-venv python3-dev -y && apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -34,7 +34,7 @@ RUN cd huggingfaceserver && poetry install --no-interaction --no-cache RUN pip3 install vllm==${VLLM_VERSION} -FROM nvidia/cuda:12.1.0-base-ubuntu22.04 as prod +FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04 as prod RUN apt-get update -y && apt-get install python3.10-venv -y && apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -58,6 +58,8 @@ ENV HF_HOME="/tmp/huggingface" ENV SAFETENSORS_FAST_GPU="1" # https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhubdisabletelemetry ENV HF_HUB_DISABLE_TELEMETRY="1" +# NCCL Lib path for vLLM. https://github.com/vllm-project/vllm/blob/ec784b2526219cd96159a52074ab8cd4e684410a/vllm/utils.py#L598-L602 +ENV VLLM_NCCL_SO_PATH="/lib/x86_64-linux-gnu/libnccl.so.2" USER 1000 ENTRYPOINT ["python3", "-m", "huggingfaceserver"] diff --git a/python/huggingfaceserver/huggingfaceserver/vllm/vllm_completions.py b/python/huggingfaceserver/huggingfaceserver/vllm/vllm_completions.py index 36cf120e8a..8c1a033a68 100644 --- a/python/huggingfaceserver/huggingfaceserver/vllm/vllm_completions.py +++ b/python/huggingfaceserver/huggingfaceserver/vllm/vllm_completions.py @@ -144,10 +144,9 @@ async def create_completion(self, completion_request: CompletionRequest): generators.append( self.engine.generate( - prompt, + {"prompt": prompt, "prompt_token_ids": input_ids}, sampling_params, f"{request_id}-{i}", - prompt_token_ids=input_ids, ) ) except Exception as e: @@ -175,7 +174,7 @@ async def create_completion(self, completion_request: CompletionRequest): ) # Non-streaming response - final_res_batch: RequestOutput = [None] * len(prompts) + final_res_batch: List[RequestOutput] = [None] * len(prompts) try: async for i, res in result_generator: final_res_batch[i] = res diff --git a/python/huggingfaceserver/poetry.lock b/python/huggingfaceserver/poetry.lock index c421c51f70..fe98c29ac3 100644 --- a/python/huggingfaceserver/poetry.lock +++ b/python/huggingfaceserver/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "accelerate" @@ -1652,13 +1652,13 @@ files = [ [[package]] name = "lm-format-enforcer" -version = "0.9.8" +version = "0.10.1" description = "Enforce the output format (JSON Schema, Regex etc) of a language model" optional = true python-versions = "<4.0,>=3.8" files = [ - {file = "lm_format_enforcer-0.9.8-py3-none-any.whl", hash = "sha256:6525906bb11a538afb4c6ea4fdac7c693b7b43c6f3a361ec1a7179385e5682b1"}, - {file = "lm_format_enforcer-0.9.8.tar.gz", hash = "sha256:b3433c50674045b0e46cd53e5016aa3aa1ba2740411802f57441a4c80d3f9606"}, + {file = "lm_format_enforcer-0.10.1-py3-none-any.whl", hash = "sha256:5520004af248d787930327ead052aeff75e21fad595f388e5eade9f062ffddda"}, + {file = "lm_format_enforcer-0.10.1.tar.gz", hash = "sha256:23e65a4199714fca348063e8c906838622619f905a673c4d6d428eee7e7d2095"}, ] [package.dependencies] @@ -4286,23 +4286,24 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "vllm" -version = "0.4.2" +version = "0.4.3" description = "A high-throughput and memory-efficient inference and serving engine for LLMs" optional = true python-versions = ">=3.8" files = [ - {file = "vllm-0.4.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:26ad388c1c6ec0348c19495318c1e10911d5fd4c67aa53c868c7ee773ecf1163"}, - {file = "vllm-0.4.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:94afce6af9a2adb1e72081a245b116325073a9266e88044169fbaf6707d847c0"}, - {file = "vllm-0.4.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:27b395dff451c10e2ce699980a047c52d7aa8737b59fbd8e333e824c94401a16"}, - {file = "vllm-0.4.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:0fc6220591f40b6f715bd040a293b9e0732fdad100bd25c5a02eb4b461680223"}, - {file = "vllm-0.4.2.tar.gz", hash = "sha256:0b857f3084b507cbdd3bfcbaae19d171c55df9955eb3ac41c9c711768e852772"}, + {file = "vllm-0.4.3-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:908089b6075deca4b74e887494b8e201ca5f7fade0c63fb0805be1a40839854c"}, + {file = "vllm-0.4.3-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:921fb85d4823c00f242a3fa3c4e724ea56d82375b9bedbbb4e32c4d53fa6a1b2"}, + {file = "vllm-0.4.3-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:e180385874f1ad1f8ecf96effa399515bbddada0a7e6ef687f417eddf84885ec"}, + {file = "vllm-0.4.3-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:73884900feced23d6134174b37069361a1bad31eb963f105c583c0c1fee1fb16"}, + {file = "vllm-0.4.3.tar.gz", hash = "sha256:fe547a463d4de939d094aa96779047424e287fc66ab5eec1d00a7a01e9944626"}, ] [package.dependencies] +aiohttp = "*" cmake = ">=3.21" fastapi = "*" filelock = ">=3.10.4" -lm-format-enforcer = "0.9.8" +lm-format-enforcer = "0.10.1" ninja = "*" numpy = "*" nvidia-ml-py = "*" @@ -4316,28 +4317,34 @@ pydantic = ">=2.0" ray = ">=2.9" requests = "*" sentencepiece = "*" -tiktoken = "0.6.0" +tiktoken = ">=0.6.0" tokenizers = ">=0.19.1" torch = "2.3.0" transformers = ">=4.40.0" typing-extensions = "*" uvicorn = {version = "*", extras = ["standard"]} -vllm-nccl-cu12 = ">=2.18,<2.19" +vllm-flash-attn = "2.5.8.post2" xformers = "0.0.26.post1" [package.extras] -tensorizer = ["tensorizer (==2.9.0)"] +tensorizer = ["tensorizer (>=2.9.0)"] [[package]] -name = "vllm-nccl-cu12" -version = "2.18.1.0.4.0" -description = "" +name = "vllm-flash-attn" +version = "2.5.8.post2" +description = "Forward-only flash-attn" optional = true -python-versions = "*" +python-versions = ">=3.8" files = [ - {file = "vllm_nccl_cu12-2.18.1.0.4.0.tar.gz", hash = "sha256:d56535da1b893ac49c1f40be9245f999e543c3fc95b4839642b70dd1d72760c0"}, + {file = "vllm_flash_attn-2.5.8.post2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:f465bda12f943958bdec09d4957f4e00d0c512e42606cbe60e16db394ca77291"}, + {file = "vllm_flash_attn-2.5.8.post2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:fbc424eb068e1b458c2d44171b159b0c0e3c5560106a32b00410b305736ff034"}, + {file = "vllm_flash_attn-2.5.8.post2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:febb0676b1cf6051328dcdb0fdf381f5f96899899071fd96ab7a0dc842d2b8d3"}, + {file = "vllm_flash_attn-2.5.8.post2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:754e5c2c29f023294af7a8113dba2cb715059cb9fafcf954bd2debac79b89702"}, ] +[package.dependencies] +torch = "2.3.0" + [[package]] name = "watchfiles" version = "0.21.0" @@ -4732,4 +4739,4 @@ vllm = ["vllm"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "94c99c1ea2e16448d8957b43a2d074c2954c8418cc9a54dff677aa0fcf2e79a3" +content-hash = "915f38bf2721a2b447effcfb2d28bc375f55279ef3e935b6df969057fa811afc" diff --git a/python/huggingfaceserver/pyproject.toml b/python/huggingfaceserver/pyproject.toml index ef946f10ef..b51aa68549 100644 --- a/python/huggingfaceserver/pyproject.toml +++ b/python/huggingfaceserver/pyproject.toml @@ -15,7 +15,7 @@ kserve = { path = "../kserve", extras = ["storage"], develop = true } transformers = "~4.40.2" accelerate = "~0.30.0" torch = "~2.3.0" -vllm = { version = "^0.4.2", optional = true } +vllm = { version = "^0.4.3", optional = true } [tool.poetry.extras] vllm = [ diff --git a/python/kserve/kserve/protocol/rest/openai/openai_proxy_model.py b/python/kserve/kserve/protocol/rest/openai/openai_proxy_model.py index d6c7c195a3..59ec4ac9c2 100644 --- a/python/kserve/kserve/protocol/rest/openai/openai_proxy_model.py +++ b/python/kserve/kserve/protocol/rest/openai/openai_proxy_model.py @@ -243,8 +243,8 @@ async def create_completion( self, request: CompletionRequest ) -> Union[Completion, AsyncIterator[Completion]]: self.preprocess_completion_request(request) - req = self._build_request(self._completions_endpoint, request) if request.params.stream: + req = self._build_request(self._completions_endpoint, request) r = await self._http_client.send(req, stream=True) r.raise_for_status() it = AsyncMappingIterator( @@ -254,23 +254,28 @@ async def create_completion( ) return it else: - response = await self._http_client.send(req) - response.raise_for_status() - if self.skip_upstream_validation: - obj = response.json() - completion = Completion.model_construct(**obj) - else: - completion = Completion.model_validate_json(response.content) + completion = await self.generate_completion(request) self.postprocess_completion(completion, request) return completion + async def generate_completion(self, request: CompletionRequest) -> Completion: + req = self._build_request(self._completions_endpoint, request) + response = await self._http_client.send(req) + response.raise_for_status() + if self.skip_upstream_validation: + obj = response.json() + completion = Completion.model_construct(**obj) + else: + completion = Completion.model_validate_json(response.content) + return completion + @error_handler async def create_chat_completion( self, request: ChatCompletionRequest ) -> Union[ChatCompletion, AsyncIterator[ChatCompletionChunk]]: self.preprocess_chat_completion_request(request) - req = self._build_request(self._chat_completions_endpoint, request) if request.params.stream: + req = self._build_request(self._chat_completions_endpoint, request) r = await self._http_client.send(req, stream=True) r.raise_for_status() it = AsyncMappingIterator( @@ -280,12 +285,19 @@ async def create_chat_completion( ) return it else: - response = await self._http_client.send(req) - response.raise_for_status() - if self.skip_upstream_validation: - obj = response.json() - chat_completion = ChatCompletion.model_construct(**obj) - else: - chat_completion = ChatCompletion.model_validate_json(response.content) + chat_completion = await self.generate_chat_completion(request) self.postprocess_chat_completion(chat_completion, request) return chat_completion + + async def generate_chat_completion( + self, request: ChatCompletionRequest + ) -> ChatCompletion: + req = self._build_request(self._chat_completions_endpoint, request) + response = await self._http_client.send(req) + response.raise_for_status() + if self.skip_upstream_validation: + obj = response.json() + chat_completion = ChatCompletion.model_construct(**obj) + else: + chat_completion = ChatCompletion.model_validate_json(response.content) + return chat_completion diff --git a/python/kserve/kserve/storage/storage.py b/python/kserve/kserve/storage/storage.py index 92afffaa13..50fed2727d 100644 --- a/python/kserve/kserve/storage/storage.py +++ b/python/kserve/kserve/storage/storage.py @@ -68,47 +68,48 @@ def download(uri: str, out_dir: str = None) -> str: if uri.startswith(_PVC_PREFIX) and not os.path.exists(uri): raise Exception(f"Cannot locate source uri {uri} for PVC") - is_local = False - if uri.startswith(_LOCAL_PREFIX) or os.path.exists(uri): - is_local = True - - if out_dir is None: - if is_local: + is_local = uri.startswith(_LOCAL_PREFIX) or os.path.exists(uri) + if is_local: + if out_dir is None: # noop if out_dir is not set and the path is local - return Storage._download_local(uri) - out_dir = tempfile.mkdtemp() - elif not os.path.exists(out_dir): - os.mkdir(out_dir) - - if uri.startswith(_GCS_PREFIX): - Storage._download_gcs(uri, out_dir) - elif uri.startswith(_S3_PREFIX): - Storage._download_s3(uri, out_dir) - elif uri.startswith(_HDFS_PREFIX) or uri.startswith(_WEBHDFS_PREFIX): - Storage._download_hdfs(uri, out_dir) - elif re.search(_AZURE_BLOB_RE, uri): - Storage._download_azure_blob(uri, out_dir) - elif re.search(_AZURE_FILE_RE, uri): - Storage._download_azure_file_share(uri, out_dir) - elif is_local: - return Storage._download_local(uri, out_dir) - elif re.search(_URI_RE, uri): - return Storage._download_from_uri(uri, out_dir) - elif uri.startswith(MODEL_MOUNT_DIRS): - # Don't need to download models if this InferenceService is running in the multi-model - # serving mode. The model agent will download models. - return out_dir + model_dir = Storage._download_local(uri) + else: + if not os.path.exists(out_dir): + os.mkdir(out_dir) + model_dir = Storage._download_local(uri, out_dir) else: - raise Exception( - "Cannot recognize storage type for " - + uri - + "\n'%s', '%s', '%s', and '%s' are the current available storage type." - % (_GCS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX) - ) + if out_dir is None: + out_dir = tempfile.mkdtemp() + elif not os.path.exists(out_dir): + os.mkdir(out_dir) + + if uri.startswith(MODEL_MOUNT_DIRS): + # Don't need to download models if this InferenceService is running in the multi-model + # serving mode. The model agent will download models. + model_dir = out_dir + elif uri.startswith(_GCS_PREFIX): + model_dir = Storage._download_gcs(uri, out_dir) + elif uri.startswith(_S3_PREFIX): + model_dir = Storage._download_s3(uri, out_dir) + elif uri.startswith(_HDFS_PREFIX) or uri.startswith(_WEBHDFS_PREFIX): + model_dir = Storage._download_hdfs(uri, out_dir) + elif re.search(_AZURE_BLOB_RE, uri): + model_dir = Storage._download_azure_blob(uri, out_dir) + elif re.search(_AZURE_FILE_RE, uri): + model_dir = Storage._download_azure_file_share(uri, out_dir) + elif re.search(_URI_RE, uri): + model_dir = Storage._download_from_uri(uri, out_dir) + else: + raise Exception( + "Cannot recognize storage type for " + + uri + + "\n'%s', '%s', '%s', and '%s' are the current available storage type." + % (_GCS_PREFIX, _S3_PREFIX, _LOCAL_PREFIX, _HTTP_PREFIX) + ) logger.info("Successfully copied %s to %s", uri, out_dir) logger.info(f"Model downloaded in {time.monotonic() - start} seconds.") - return out_dir + return model_dir @staticmethod def _update_with_storage_spec(): @@ -175,7 +176,7 @@ def get_S3_config(): return c @staticmethod - def _download_s3(uri, temp_dir: str): + def _download_s3(uri, temp_dir: str) -> str: # Boto3 looks at various configuration locations until it finds configuration values. # lookup order: # 1. Config object passed in as the config parameter when creating S3 resource @@ -274,10 +275,11 @@ def _download_s3(uri, temp_dir: str): if file_count == 1: mimetype, _ = mimetypes.guess_type(target) if mimetype in ["application/x-tar", "application/zip"]: - Storage._unpack_archive_file(target, mimetype, temp_dir) + temp_dir = Storage._unpack_archive_file(target, mimetype, temp_dir) + return temp_dir @staticmethod - def _download_gcs(uri, temp_dir: str): + def _download_gcs(uri, temp_dir: str) -> str: try: storage_client = storage.Client() except exceptions.DefaultCredentialsError: @@ -314,7 +316,8 @@ def _download_gcs(uri, temp_dir: str): if file_count == 1: mimetype, _ = mimetypes.guess_type(blob.name) if mimetype in ["application/x-tar", "application/zip"]: - Storage._unpack_archive_file(dest_path, mimetype, temp_dir) + temp_dir = Storage._unpack_archive_file(dest_path, mimetype, temp_dir) + return temp_dir @staticmethod def _load_hdfs_configuration() -> Dict: @@ -352,7 +355,7 @@ def _load_hdfs_configuration() -> Dict: return config @staticmethod - def _download_hdfs(uri, out_dir: str): + def _download_hdfs(uri, out_dir: str) -> str: from krbcontext.context import krbContext from hdfs.ext.kerberos import Client, KerberosClient @@ -428,10 +431,15 @@ def _download_hdfs(uri, out_dir: str): if file_count == 1: mimetype, _ = mimetypes.guess_type(dest_file_path) if mimetype in ["application/x-tar", "application/zip"]: - Storage._unpack_archive_file(dest_file_path, mimetype, out_dir) + out_dir = Storage._unpack_archive_file( + dest_file_path, mimetype, out_dir + ) + return out_dir @staticmethod - def _download_azure_blob(uri, out_dir: str): # pylint: disable=too-many-locals + def _download_azure_blob( + uri, out_dir: str + ) -> str: # pylint: disable=too-many-locals account_name, account_url, container_name, prefix = Storage._parse_azure_uri( uri ) @@ -485,12 +493,13 @@ def _download_azure_blob(uri, out_dir: str): # pylint: disable=too-many-locals if file_count == 1: mimetype, _ = mimetypes.guess_type(dest_path) if mimetype in ["application/x-tar", "application/zip"]: - Storage._unpack_archive_file(dest_path, mimetype, out_dir) + out_dir = Storage._unpack_archive_file(dest_path, mimetype, out_dir) + return out_dir @staticmethod def _download_azure_file_share( uri, out_dir: str - ): # pylint: disable=too-many-locals + ) -> str: # pylint: disable=too-many-locals account_name, account_url, share_name, prefix = Storage._parse_azure_uri(uri) logger.info( "Connecting to file share account: [%s], container: [%s], prefix: [%s]", @@ -542,7 +551,8 @@ def _download_azure_file_share( if file_count == 1: mimetype, _ = mimetypes.guess_type(dest_path) if mimetype in ["application/x-tar", "application/zip"]: - Storage._unpack_archive_file(dest_path, mimetype, out_dir) + out_dir = Storage._unpack_archive_file(dest_path, mimetype, out_dir) + return out_dir @staticmethod def _parse_azure_uri(uri): # pylint: disable=too-many-locals @@ -587,7 +597,7 @@ def _get_azure_storage_access_key(): return os.getenv("AZURE_STORAGE_ACCESS_KEY") @staticmethod - def _download_local(uri, out_dir=None): + def _download_local(uri, out_dir=None) -> str: local_path = uri.replace(_LOCAL_PREFIX, "", 1) if not os.path.exists(local_path): raise RuntimeError("Local path %s does not exist." % (uri)) @@ -616,12 +626,11 @@ def _download_local(uri, out_dir=None): if file_count == 1: mimetype, _ = mimetypes.guess_type(dest_path) if mimetype in ["application/x-tar", "application/zip"]: - Storage._unpack_archive_file(dest_path, mimetype, out_dir) - + out_dir = Storage._unpack_archive_file(dest_path, mimetype, out_dir) return out_dir @staticmethod - def _download_from_uri(uri, out_dir=None): + def _download_from_uri(uri, out_dir=None) -> str: url = urlparse(uri) filename = os.path.basename(url.path) # Determine if the symbol '?' exists in the path @@ -692,12 +701,11 @@ def _download_from_uri(uri, out_dir=None): shutil.copyfileobj(stream, out) if mimetype in ["application/x-tar", "application/zip"]: - Storage._unpack_archive_file(local_path, mimetype, out_dir) - + out_dir = Storage._unpack_archive_file(local_path, mimetype, out_dir) return out_dir @staticmethod - def _unpack_archive_file(file_path, mimetype, target_dir=None): + def _unpack_archive_file(file_path, mimetype, target_dir=None) -> str: if not target_dir: target_dir = os.path.dirname(file_path) @@ -715,3 +723,4 @@ def _unpack_archive_file(file_path, mimetype, target_dir=None): The file format is not valid." ) os.remove(file_path) + return target_dir diff --git a/python/kserve/kserve/storage/test/test_storage.py b/python/kserve/kserve/storage/test/test_storage.py index e28d787cef..24e9ac784c 100644 --- a/python/kserve/kserve/storage/test/test_storage.py +++ b/python/kserve/kserve/storage/test/test_storage.py @@ -65,6 +65,18 @@ def test_no_prefix_local_path(): assert Storage.download(relative_path) == relative_path +def test_local_path_with_out_dir_exist(): + abs_path = "file:///tmp" + out_dir = "/tmp" + assert Storage.download(abs_path, out_dir=out_dir) == out_dir + + +def test_local_path_with_out_dir_not_exist(): + abs_path = "file:///tmp" + out_dir = "/tmp/test-abc" + assert Storage.download(abs_path, out_dir=out_dir) == out_dir + + class MockHttpResponse(object): def __init__(self, status_code=404, raw=b"", content_type=""): self.status_code = status_code