Skip to content

Commit

Permalink
Check whether the storage initilizer really download something (kubef…
Browse files Browse the repository at this point in the history
…low#482)

Updated test case too since now we just need check there's runtime error
  • Loading branch information
pugangxa authored and k8s-ci-robot committed Oct 29, 2019
1 parent 24e4f91 commit 673a533
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
18 changes: 16 additions & 2 deletions python/kfserving/kfserving/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def _download_s3(uri, temp_dir: str):
bucket_name = bucket_args[0]
bucket_path = bucket_args[1] if len(bucket_args) > 1 else ""
objects = client.list_objects(bucket_name, prefix=bucket_path, recursive=True)
count = 0
for obj in objects:
# Replace any prefix from the object key with temp_dir
subdir_object_key = obj.object_name.replace(bucket_path, "", 1).strip("/")
Expand All @@ -77,6 +78,10 @@ def _download_s3(uri, temp_dir: str):
subdir_object_key = obj.object_name
client.fget_object(bucket_name, obj.object_name,
os.path.join(temp_dir, subdir_object_key))
count = count + 1
if count == 0:
raise RuntimeError("Failed to fetch model. \
The path or model %s does not exist." % (uri))

@staticmethod
def _download_gcs(uri, temp_dir: str):
Expand All @@ -92,6 +97,7 @@ def _download_gcs(uri, temp_dir: str):
if not prefix.endswith("/"):
prefix = prefix + "/"
blobs = bucket.list_blobs(prefix=prefix)
count = 0
for blob in blobs:
# Replace any prefix from the object key with temp_dir
subdir_object_key = blob.name.replace(bucket_path, "", 1).strip("/")
Expand All @@ -105,6 +111,10 @@ def _download_gcs(uri, temp_dir: str):
dest_path = os.path.join(temp_dir, subdir_object_key)
logging.info("Downloading: %s", dest_path)
blob.download_to_filename(dest_path)
count = count + 1
if count == 0:
raise RuntimeError("Failed to fetch model. \
The path or model %s does not exist." % (uri))

@staticmethod
def _download_blob(uri, out_dir: str): # pylint: disable=too-many-locals
Expand All @@ -126,7 +136,7 @@ def _download_blob(uri, out_dir: str): # pylint: disable=too-many-locals
logging.warning("Azure credentials not found, retrying anonymous access")
block_blob_service = BlockBlobService(account_name=account_name, token_credential=token)
blobs = block_blob_service.list_blobs(container_name, prefix=prefix)

count = 0
for blob in blobs:
dest_path = os.path.join(out_dir, blob.name)
if "/" in blob.name:
Expand All @@ -142,6 +152,10 @@ def _download_blob(uri, out_dir: str): # pylint: disable=too-many-locals

logging.info("Downloading: %s to %s", blob.name, dest_path)
block_blob_service.get_blob_to_path(container_name, blob.name, dest_path)
count = count + 1
if count == 0:
raise RuntimeError("Failed to fetch model. \
The path or model %s does not exist." % (uri))

@staticmethod
def _get_azure_storage_token():
Expand Down Expand Up @@ -176,7 +190,7 @@ def _get_azure_storage_token():
def _download_local(uri, out_dir=None):
local_path = uri.replace(_LOCAL_PREFIX, "", 1)
if not os.path.exists(local_path):
raise Exception("Local path %s does not exist." % (uri))
raise RuntimeError("Local path %s does not exist." % (uri))

if out_dir is None:
return local_path
Expand Down
18 changes: 6 additions & 12 deletions python/kfserving/test/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,28 +49,22 @@ def test_mock_gcs(mock_storage):
mock_storage.Client().bucket().list_blobs().__iter__.return_value = [mock_obj]
assert kfserving.Storage.download(gcs_path)

@mock.patch(STORAGE_MODULE + '.BlockBlobService')
def test_mock_blob(mock_storage):
def test_storage_blob_exception():
blob_path = 'https://accountname.blob.core.windows.net/container/some/blob/'
mock_obj = mock.MagicMock()
mock_obj.name = 'mock.object'
mock_storage.list_blobs.__iter__.return_value = [mock_obj]
assert kfserving.Storage.download(blob_path)
with pytest.raises(Exception):
kfserving.Storage.download(blob_path)

@mock.patch('urllib3.PoolManager')
@mock.patch(STORAGE_MODULE + '.Minio')
def test_mock_minio(mock_connection, mock_minio):
def test_storage_s3_exception(mock_connection, mock_minio):
minio_path = 's3://foo/bar'
# Create mock connection
mock_server = mock.MagicMock()
mock_connection.return_value = mock_server
# Create mock client
mock_minio.return_value = Minio("s3.us.cloud-object-storage.appdomain.cloud", secure=True)
mock_obj = mock.MagicMock()
mock_obj.object_name = 'mock.object'
mock_minio.list_objects().__iter__.return_value = [mock_obj]
assert kfserving.Storage.download(minio_path)

with pytest.raises(Exception):
kfserving.Storage.download(minio_path)

@mock.patch('urllib3.PoolManager')
@mock.patch(STORAGE_MODULE + '.Minio')
Expand Down

0 comments on commit 673a533

Please sign in to comment.