Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Download sparse blob #7555

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from .container_client import ContainerClient
from .blob_service_client import BlobServiceClient
from .lease import LeaseClient
from .download import StorageStreamDownloader
from ._shared.policies import ExponentialRetry, LinearRetry, NoRetry
from ._shared.downloads import StorageStreamDownloader
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we want to expose it to user I will add it back in the next commit

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this probably should be exposed - it is the return type of the download operations, and if users want to do type-checking they will need to be able to import this type.
I also think it's exposed in the Files SDK; whichever way we go, we should be consistent between the two.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it was exposed in Files, but I assumed we want to expose it.

from ._shared.models import(
LocationMode,
ResourceTypes,
Expand Down
11 changes: 11 additions & 0 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,14 @@ def deserialize_container_properties(response, obj, headers):
**headers
)
return container_properties


def get_page_ranges_result(ranges):
    # type: (PageList) -> Tuple[List[Dict[str, int]], List[Dict[str, int]]]
    """Split a PageList response into page ranges and clear ranges.

    :param ranges: The PageList returned by the service's get_page_ranges call;
        exposes ``page_range`` and ``clear_range`` sequences of objects with
        integer ``start``/``end`` attributes (either may be empty or None).
    :returns: A two-tuple ``(page_range, clear_range)`` of lists of
        ``{'start': int, 'end': int}`` dicts.
    """
    page_range = []  # type: List[Dict[str, int]]
    clear_range = []  # type: List[Dict[str, int]]
    if ranges.page_range:
        page_range = [{'start': b.start, 'end': b.end} for b in ranges.page_range]
    if ranges.clear_range:
        clear_range = [{'start': b.start, 'end': b.end} for b in ranges.clear_range]
    return page_range, clear_range
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# --------------------------------------------------------------------------

from .._shared.policies_async import ExponentialRetry, LinearRetry, NoRetry
from .._shared.downloads_async import StorageStreamDownloader
from .._shared.models import(
LocationMode,
ResourceTypes,
Expand Down Expand Up @@ -40,6 +39,7 @@
BlobPropertiesPaged,
BlobPrefix
)
from .download_async import StorageStreamDownloader
from .blob_client_async import BlobClient
from .container_client_async import ContainerClient
from .blob_service_client_async import BlobServiceClient
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
)

from azure.core.tracing.decorator_async import distributed_trace_async
from azure.storage.blob._generated.models import CpkInfo

from .._shared.base_client_async import AsyncStorageAccountHostsMixin
from .._shared.policies_async import ExponentialRetry
from .._shared.downloads_async import StorageStreamDownloader
from .._shared.response_handlers import return_response_headers, process_storage_error
from .._deserialize import get_page_ranges_result
from .._generated.aio import AzureBlobStorage
from .._generated.models import ModifiedAccessConditions, StorageErrorException
from .._generated.models import ModifiedAccessConditions, StorageErrorException, CpkInfo
from .._deserialize import deserialize_blob_properties
from ..blob_client import BlobClient as BlobClientBase
from ._upload_helpers import (
Expand All @@ -28,6 +27,7 @@
from ..models import BlobType, BlobBlock
from ..lease import get_access_conditions
from .lease_async import LeaseClient
from .download_async import StorageStreamDownloader

if TYPE_CHECKING:
from datetime import datetime
Expand Down Expand Up @@ -1377,7 +1377,7 @@ async def get_page_ranges( # type: ignore
ranges = await self._client.page_blob.get_page_ranges(**options)
except StorageErrorException as error:
process_storage_error(error)
return self._get_page_ranges_result(ranges)
return get_page_ranges_result(ranges)

@distributed_trace_async
async def set_sequence_number( # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
# license information.
# --------------------------------------------------------------------------

import sys
import asyncio
import sys
from io import BytesIO
from itertools import islice

from azure.core.exceptions import HttpResponseError

from .request_handlers import validate_and_format_range_headers
from .response_handlers import process_storage_error, parse_length_from_content_range
from .encryption import decrypt_blob
from .downloads import process_range_and_offset
from azure.core import HttpResponseError
from .._shared.encryption import decrypt_blob
from .._shared.request_handlers import validate_and_format_range_headers
from .._shared.response_handlers import process_storage_error, parse_length_from_content_range
from .._deserialize import get_page_ranges_result
from ..download import process_range_and_offset


async def process_content(data, start_offset, end_offset, encryption):
Expand Down Expand Up @@ -42,7 +42,8 @@ async def process_content(data, start_offset, end_offset, encryption):
class _AsyncChunkDownloader(object): # pylint: disable=too-many-instance-attributes

def __init__(
self, service=None,
self, client=None,
non_empty_ranges=None,
total_size=None,
chunk_size=None,
current_progress=None,
Expand All @@ -54,7 +55,9 @@ def __init__(
encryption_options=None,
**kwargs):

self.service = service
self.client = client

self.non_empty_ranges = non_empty_ranges

# information on the download range/chunk size
self.chunk_size = chunk_size
Expand Down Expand Up @@ -121,34 +124,52 @@ async def _write_to_stream(self, chunk_data, chunk_start):
else:
self.stream.write(chunk_data)

def _do_optimize(self, given_range_start, given_range_end):
if self.non_empty_ranges is None:
return False

for source_range in self.non_empty_ranges:
if given_range_end < source_range['start']: # pylint:disable=no-else-return
return True
elif source_range['end'] < given_range_start:
pass
else:
return False

return True

async def _download_chunk(self, chunk_start, chunk_end):
    """Download one chunk of the blob and return its bytes.

    :param int chunk_start: Inclusive start offset of the chunk.
    :param int chunk_end: Exclusive end offset of the chunk.
    :returns: The chunk content as bytes.
    """
    download_range, offset = process_range_and_offset(
        chunk_start, chunk_end, chunk_end, self.encryption_options)

    # Sparse page-blob optimization: a chunk that overlaps no non-empty
    # page range is known to be all zeros, so synthesize it locally
    # instead of issuing a download request.
    if self._do_optimize(download_range[0], download_range[1] - 1):
        chunk_data = b"\x00" * self.chunk_size
    else:
        range_header, range_validation = validate_and_format_range_headers(
            download_range[0],
            download_range[1] - 1,
            check_content_md5=self.validate_content)
        try:
            _, response = await self.client.download(
                range=range_header,
                range_get_content_md5=range_validation,
                validate_content=self.validate_content,
                data_stream_total=self.total_size,
                download_stream_current=self.progress_total,
                **self.request_options)
        except HttpResponseError as error:
            process_storage_error(error)

        chunk_data = await process_content(response, offset[0], offset[1], self.encryption_options)

        # This makes sure that if_match is set so that we can validate
        # that subsequent downloads are to an unmodified blob
        if self.request_options.get('modified_access_conditions'):
            self.request_options['modified_access_conditions'].if_match = response.properties.etag

    return chunk_data


class StorageStreamDownloader(object): # pylint: disable=too-many-instance-attributes
"""A streaming object to download from Azure Storage.

Expand All @@ -157,14 +178,15 @@ class StorageStreamDownloader(object): # pylint: disable=too-many-instance-attr
"""

def __init__(
self, service=None,
self,
clients=None,
config=None,
offset=None,
length=None,
validate_content=None,
encryption_options=None,
**kwargs):
self.service = service
self.clients = clients
self.config = config
self.offset = offset
self.length = length
Expand Down Expand Up @@ -192,6 +214,7 @@ def __init__(
initial_request_start, initial_request_end, self.length, self.encryption_options)
self.download_size = None
self.file_size = None
self.non_empty_ranges = None
self.response = None
self.properties = None

Expand All @@ -218,7 +241,8 @@ async def __anext__(self):
# Use the length unless it is over the end of the file
data_end = min(self.file_size, self.length + 1)
self._iter_downloader = _AsyncChunkDownloader(
service=self.service,
client=self.clients.blob,
non_empty_ranges=self.non_empty_ranges,
total_size=self.download_size,
chunk_size=self.config.max_chunk_get_size,
current_progress=self.first_get_size,
Expand Down Expand Up @@ -274,7 +298,7 @@ async def _initial_request(self):
check_content_md5=self.validate_content)

try:
location_mode, response = await self.service.download(
location_mode, response = await self.clients.blob.download(
range=range_header,
range_get_content_md5=range_validation,
validate_content=self.validate_content,
Expand Down Expand Up @@ -303,7 +327,7 @@ async def _initial_request(self):
# request a range, do a regular get request in order to get
# any properties.
try:
_, response = await self.service.download(
_, response = await self.clients.blob.download(
validate_content=self.validate_content,
data_stream_total=0,
download_stream_current=0,
Expand All @@ -317,6 +341,14 @@ async def _initial_request(self):
else:
process_storage_error(error)

# get page ranges to optimize downloading sparse page blob
if response.properties.blob_type == 'PageBlob':
try:
page_ranges = await self.clients.page_blob.get_page_ranges()
self.non_empty_ranges = get_page_ranges_result(page_ranges)[0]
except HttpResponseError:
pass

# If the file is small, the download is complete at this point.
# If file size is large, download the rest of the file in chunks.
if response.properties.size != self.download_size:
Expand All @@ -328,32 +360,32 @@ async def _initial_request(self):
self._download_complete = True
return response

async def content_as_bytes(self, max_concurrency=1):
    """Download the contents of this file.

    This operation is blocking until all data is downloaded.

    :param int max_concurrency:
        The number of parallel connections with which to download.
    :rtype: bytes
    """
    # Buffer everything in memory, then hand back the accumulated bytes.
    stream = BytesIO()
    await self.download_to_stream(stream, max_concurrency=max_concurrency)
    return stream.getvalue()

async def content_as_text(self, max_concurrency=1, encoding='UTF-8'):
    """Download the contents of this file, and decode as text.

    This operation is blocking until all data is downloaded.

    :param int max_concurrency:
        The number of parallel connections with which to download.
    :param str encoding: The text encoding used to decode the bytes (default UTF-8).
    :rtype: str
    """
    content = await self.content_as_bytes(max_concurrency=max_concurrency)
    return content.decode(encoding)

async def download_to_stream(self, stream, max_connections=1):
async def download_to_stream(self, stream, max_concurrency=1):
"""Download the contents of this file to a stream.

:param stream:
Expand All @@ -367,7 +399,7 @@ async def download_to_stream(self, stream, max_connections=1):
raise ValueError("Stream is currently being iterated.")

# the stream must be seekable if parallel download is required
parallel = max_connections > 1
parallel = max_concurrency > 1
if parallel:
error_message = "Target stream handle must be seekable."
if sys.version_info >= (3,) and not stream.seekable():
Expand Down Expand Up @@ -396,7 +428,8 @@ async def download_to_stream(self, stream, max_connections=1):
data_end = min(self.file_size, self.length + 1)

downloader = _AsyncChunkDownloader(
service=self.service,
client=self.clients.blob,
non_empty_ranges=self.non_empty_ranges,
total_size=self.download_size,
chunk_size=self.config.max_chunk_get_size,
current_progress=self.first_get_size,
Expand All @@ -412,7 +445,7 @@ async def download_to_stream(self, stream, max_connections=1):
dl_tasks = downloader.get_chunk_offsets()
running_futures = [
asyncio.ensure_future(downloader.process_chunk(d))
for d in islice(dl_tasks, 0, max_connections)
for d in islice(dl_tasks, 0, max_concurrency)
]
while running_futures:
# Wait for some download to finish before adding a new one
Expand Down
20 changes: 5 additions & 15 deletions sdk/storage/azure-storage-blob/azure/storage/blob/blob_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query
from ._shared.encryption import generate_blob_encryption_data
from ._shared.uploads import IterStreamer
from ._shared.downloads import StorageStreamDownloader
from ._shared.request_handlers import (
add_metadata_headers, get_length, read_length,
validate_and_format_range_headers)
from ._shared.response_handlers import return_response_headers, process_storage_error
from ._deserialize import get_page_ranges_result
from ._generated import AzureBlobStorage
from ._generated.models import ( # pylint: disable=unused-import
DeleteSnapshotsOptionType,
Expand All @@ -47,13 +47,13 @@
upload_append_blob,
upload_page_blob)
from .models import BlobType, BlobBlock
from .download import StorageStreamDownloader
from .lease import LeaseClient, get_access_conditions
from ._shared_access_signature import BlobSharedAccessSignature

if TYPE_CHECKING:
from datetime import datetime
from azure.core.pipeline.policies import HTTPPolicy
from ._generated.models import BlockList, PageList
from ._generated.models import BlockList
from .models import ( # pylint: disable=unused-import
ContainerProperties,
BlobProperties,
Expand Down Expand Up @@ -580,7 +580,7 @@ def _download_blob_options(self, offset=None, length=None, validate_content=Fals
encryption_algorithm=cpk.algorithm)

options = {
'service': self._client.blob,
'clients': self._client,
xiafu-msft marked this conversation as resolved.
Show resolved Hide resolved
'config': self._config,
'offset': offset,
'length': length,
Expand Down Expand Up @@ -2100,16 +2100,6 @@ def _get_page_ranges_options( # type: ignore
options.update(kwargs)
return options

def _get_page_ranges_result(self, ranges):
# type: (PageList) -> Tuple(List[Dict[str, int]], List[Dict[str, int]])
page_range = [] # type: ignore
clear_range = [] # type: List
if ranges.page_range:
page_range = [{'start': b.start, 'end': b.end} for b in ranges.page_range] # type: ignore
if ranges.clear_range:
clear_range = [{'start': b.start, 'end': b.end} for b in ranges.clear_range]
return page_range, clear_range # type: ignore

@distributed_trace
def get_page_ranges( # type: ignore
self, start_range=None, # type: Optional[int]
Expand Down Expand Up @@ -2183,7 +2173,7 @@ def get_page_ranges( # type: ignore
ranges = self._client.page_blob.get_page_ranges(**options)
except StorageErrorException as error:
process_storage_error(error)
return self._get_page_ranges_result(ranges)
return get_page_ranges_result(ranges)

def _set_sequence_number_options(self, sequence_number_action, sequence_number=None, **kwargs):
# type: (Union[str, SequenceNumberAction], Optional[str], **Any) -> Dict[str, Any]
Expand Down
Loading