Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Storage] Remove client-side encryption code from shared #24931

Merged
merged 11 commits into from
Jun 25, 2022
52 changes: 25 additions & 27 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_blob_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,31 @@
# license information.
# --------------------------------------------------------------------------
# pylint: disable=too-many-lines,no-self-use

from functools import partial
from io import BytesIO
from typing import ( # pylint: disable=unused-import
Union, Optional, Any, IO, Iterable, AnyStr, Dict, List, Tuple,
TYPE_CHECKING,
TypeVar, Type)
from typing import (
Any, AnyStr, Dict, IO, Iterable, List, Optional, Tuple, Type, TypeVar, Union,
TYPE_CHECKING
)
from urllib.parse import urlparse, quote, unquote
import warnings

try:
from urllib.parse import urlparse, quote, unquote
except ImportError:
from urlparse import urlparse # type: ignore
from urllib2 import quote, unquote # type: ignore
import six
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError, ResourceExistsError
from azure.core.paging import ItemPaged
from azure.core.pipeline import Pipeline
from azure.core.tracing.decorator import distributed_trace
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError, ResourceExistsError

from ._shared import encode_base64
from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query, TransportWrapper
from ._shared.encryption import generate_blob_encryption_data
from ._shared.uploads import IterStreamer
from ._shared.request_handlers import (
add_metadata_headers, get_length, read_length,
validate_and_format_range_headers)
from ._shared.response_handlers import return_response_headers, process_storage_error, return_headers_and_deserialized
from ._generated import AzureBlobStorage
from ._generated.models import ( # pylint: disable=unused-import
from ._generated.models import (
DeleteSnapshotsOptionType,
BlobHTTPHeaders,
BlockLookupList,
Expand All @@ -49,22 +45,30 @@
serialize_blob_tags,
serialize_query_format, get_access_conditions
)
from ._deserialize import get_page_ranges_result, deserialize_blob_properties, deserialize_blob_stream, parse_tags, \
from ._deserialize import (
get_page_ranges_result,
deserialize_blob_properties,
deserialize_blob_stream,
parse_tags,
deserialize_pipeline_response_into_cls
)
from ._download import StorageStreamDownloader
from ._encryption import StorageEncryptionMixin
from ._lease import BlobLeaseClient
from ._models import BlobType, BlobBlock, BlobProperties, BlobQueryError, QuickQueryDialect, \
DelimitedJsonDialect, DelimitedTextDialect, PageRangePaged, PageRange
from ._quick_query_helper import BlobQueryReader
from ._upload_helpers import (
upload_block_blob,
upload_append_blob,
upload_page_blob, _any_conditions)
from ._models import BlobType, BlobBlock, BlobProperties, BlobQueryError, QuickQueryDialect, \
DelimitedJsonDialect, DelimitedTextDialect, PageRangePaged, PageRange
from ._download import StorageStreamDownloader
from ._lease import BlobLeaseClient
upload_page_blob,
_any_conditions
)

if TYPE_CHECKING:
from datetime import datetime
from ._generated.models import BlockList
from ._models import ( # pylint: disable=unused-import
from ._models import (
ContentSettings,
ImmutabilityPolicy,
PremiumPageBlobTier,
Expand All @@ -79,7 +83,7 @@
ClassType = TypeVar("ClassType")


class BlobClient(StorageAccountHostsMixin): # pylint: disable=too-many-public-methods
class BlobClient(StorageAccountHostsMixin, StorageEncryptionMixin): # pylint: disable=too-many-public-methods
"""A client to interact with a specific blob, although that blob may not yet exist.

For more optional configuration, please click
Expand Down Expand Up @@ -181,6 +185,7 @@ def __init__(
super(BlobClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, base_url=self.url, pipeline=self._pipeline)
self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access
self.configure_encryption(kwargs)

def _format_url(self, hostname):
container_name = self.container_name
Expand Down Expand Up @@ -359,13 +364,6 @@ def _upload_blob_options( # pylint:disable=too-many-statements
'key': self.key_encryption_key,
'resolver': self.key_resolver_function,
}
if self.key_encryption_key is not None:
jalauzon-msft marked this conversation as resolved.
Show resolved Hide resolved
cek, iv, encryption_data = generate_blob_encryption_data(
self.key_encryption_key,
self.encryption_version)
encryption_options['cek'] = cek
encryption_options['vector'] = iv
encryption_options['data'] = encryption_data

encoding = kwargs.pop('encoding', 'UTF-8')
if isinstance(data, six.text_type):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,33 @@
import functools
import warnings
from typing import ( # pylint: disable=unused-import
Union, Optional, Any, Iterable, Dict, List,
TYPE_CHECKING,
TypeVar)
Any, Dict, List, Optional, TypeVar, Union,
TYPE_CHECKING
)
from urllib.parse import urlparse


try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse # type: ignore

from azure.core.paging import ItemPaged
from azure.core.exceptions import HttpResponseError
from azure.core.paging import ItemPaged
from azure.core.pipeline import Pipeline
from azure.core.tracing.decorator import distributed_trace

from ._shared.models import LocationMode
from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query
from ._shared.models import LocationMode
from ._shared.parser import _to_utc_datetime
from ._shared.response_handlers import return_response_headers, process_storage_error, \
from ._shared.response_handlers import (
return_response_headers,
process_storage_error,
parse_to_internal_user_delegation_key
)
from ._generated import AzureBlobStorage
from ._generated.models import StorageServiceProperties, KeyInfo
from ._container_client import ContainerClient
from ._blob_client import BlobClient
from ._models import ContainerPropertiesPaged
from ._deserialize import service_stats_deserialize, service_properties_deserialize
from ._encryption import StorageEncryptionMixin
from ._list_blobs_helper import FilteredBlobPaged
from ._models import ContainerPropertiesPaged
from ._serialize import get_api_version
from ._deserialize import service_stats_deserialize, service_properties_deserialize

if TYPE_CHECKING:
from datetime import datetime
Expand All @@ -55,7 +54,7 @@
ClassType = TypeVar("ClassType")


class BlobServiceClient(StorageAccountHostsMixin):
class BlobServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin):
"""A client to interact with the Blob Service at the account level.

This client provides operations to retrieve and configure the account properties
Expand Down Expand Up @@ -137,6 +136,7 @@ def __init__(
super(BlobServiceClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, base_url=self.url, pipeline=self._pipeline)
self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access
self.configure_encryption(kwargs)

def _format_url(self, hostname):
"""Format the endpoint URL according to the current location
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,53 +7,47 @@

import functools
from typing import ( # pylint: disable=unused-import
Union, Optional, Any, Iterable, AnyStr, Dict, List, Tuple, IO, Iterator,
TYPE_CHECKING,
TypeVar)


try:
from urllib.parse import urlparse, quote, unquote
except ImportError:
from urlparse import urlparse # type: ignore
from urllib2 import quote, unquote # type: ignore
Any, AnyStr, Dict, List, IO, Iterable, Iterator, Optional, TypeVar, Union,
TYPE_CHECKING
)
from urllib.parse import urlparse, quote, unquote

import six

from azure.core import MatchConditions
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.core.paging import ItemPaged
from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import HttpRequest
from azure.core.tracing.decorator import distributed_trace

from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query
from ._shared.request_handlers import add_metadata_headers, serialize_iso
from ._shared.response_handlers import (
process_storage_error,
return_response_headers,
return_headers_and_deserialized)
return_headers_and_deserialized
)
from ._generated import AzureBlobStorage
from ._generated.models import SignedIdentifier
from ._blob_client import BlobClient
from ._deserialize import deserialize_container_properties
from ._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions
from ._models import ( # pylint: disable=unused-import
from ._encryption import StorageEncryptionMixin
from ._lease import BlobLeaseClient
from ._list_blobs_helper import BlobPrefix, BlobPropertiesPaged, FilteredBlobPaged
from ._models import (
ContainerProperties,
BlobProperties,
BlobType,
FilteredBlob)
from ._list_blobs_helper import BlobPrefix, BlobPropertiesPaged, FilteredBlobPaged
from ._lease import BlobLeaseClient
from ._blob_client import BlobClient
FilteredBlob
)
from ._serialize import get_modify_conditions, get_container_cpk_scope_info, get_api_version, get_access_conditions

if TYPE_CHECKING:
from azure.core.pipeline.transport import HttpTransport, HttpResponse # pylint: disable=ungrouped-imports
from azure.core.pipeline.policies import HTTPPolicy # pylint: disable=ungrouped-imports
from azure.core.pipeline.transport import HttpResponse # pylint: disable=ungrouped-imports
from datetime import datetime
from ._models import ( # pylint: disable=unused-import
PublicAccess,
AccessPolicy,
ContentSettings,
StandardBlobTier,
PremiumPageBlobTier)

Expand All @@ -73,7 +67,7 @@ def _get_blob_name(blob):
ClassType = TypeVar("ClassType")


class ContainerClient(StorageAccountHostsMixin): # pylint: disable=too-many-public-methods
class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): # pylint: disable=too-many-public-methods
"""A client to interact with a specific container, although that container
may not yet exist.

Expand Down Expand Up @@ -161,6 +155,7 @@ def __init__(
super(ContainerClient, self).__init__(parsed_url, service='blob', credential=credential, **kwargs)
self._client = AzureBlobStorage(self.url, base_url=self.url, pipeline=self._pipeline)
self._client._config.version = get_api_version(kwargs) # pylint: disable=protected-access
self.configure_encryption(kwargs)

def _format_url(self, hostname):
container_name = self.container_name
Expand Down
11 changes: 5 additions & 6 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,24 @@
import sys
import threading
import time

import warnings
from io import BytesIO
from typing import Iterator, Union

import requests
from azure.core.exceptions import HttpResponseError, ServiceResponseError

from azure.core.tracing.common import with_current_context
from ._shared.encryption import (

from ._shared.request_handlers import validate_and_format_range_headers
from ._shared.response_handlers import process_storage_error, parse_length_from_content_range
from ._deserialize import deserialize_blob_properties, get_page_ranges_result
from ._encryption import (
adjust_blob_size_for_encryption,
decrypt_blob,
get_adjusted_download_range_and_offset,
is_encryption_v2,
parse_encryption_data
)
from ._shared.request_handlers import validate_and_format_range_headers
from ._shared.response_handlers import process_storage_error, parse_length_from_content_range
from ._deserialize import deserialize_blob_properties, get_page_ranges_result


def process_range_and_offset(start_range, end_range, length, encryption_options, encryption_data):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import math
import sys
import warnings
from collections import OrderedDict
from io import BytesIO
from json import (
Expand All @@ -24,8 +25,8 @@

from azure.core.exceptions import HttpResponseError

from .._version import VERSION
from . import encode_base64, decode_base64_to_bytes
from ._version import VERSION
from ._shared import encode_base64, decode_base64_to_bytes


_ENCRYPTION_PROTOCOL_V1 = '1.0'
Expand Down Expand Up @@ -53,6 +54,19 @@ def _validate_key_encryption_key_wrap(kek):
raise AttributeError(_ERROR_OBJECT_INVALID.format('key encryption key', 'get_key_wrap_algorithm'))


class StorageEncryptionMixin(object):
def configure_encryption(self, kwargs):
self.require_encryption = kwargs.get("require_encryption", False)
self.encryption_version = kwargs.get("encryption_version", "1.0")
self.key_encryption_key = kwargs.get("key_encryption_key")
self.key_resolver_function = kwargs.get("key_resolver_function")
if self.key_encryption_key and self.encryption_version == '1.0':
warnings.warn("This client has been configured to use encryption with version 1.0. " +
"Version 1.0 is deprecated and no longer considered secure. It is highly " +
"recommended that you switch to using version 2.0. The version can be " +
"specified using the 'encryption_version' keyword.")


class _EncryptionAlgorithm(object):
'''
Specifies which client encryption algorithm is used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# --------------------------------------------------------------------------
import logging
import uuid
import warnings
from typing import ( # pylint: disable=unused-import
Optional,
Any,
Expand Down Expand Up @@ -105,16 +104,6 @@ def __init__(
primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/')
self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname}

self.require_encryption = kwargs.get("require_encryption", False)
self.encryption_version = kwargs.get("encryption_version", "1.0")
self.key_encryption_key = kwargs.get("key_encryption_key")
self.key_resolver_function = kwargs.get("key_resolver_function")
if self.key_encryption_key and self.encryption_version == '1.0':
warnings.warn("This client has been configured to use encryption with version 1.0. \
Version 1.0 is deprecated and no longer considered secure. It is highly \
recommended that you switch to using version 2.0. The version can be \
specified using the 'encryption_version' keyword.")

self._config, self._pipeline = self._create_pipeline(self.credential, storage_sdk=service, **kwargs)

def __enter__(self):
Expand Down
Loading