Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CryptographyClient can decrypt and sign locally #13772

Merged
merged 3 commits into from
Sep 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# Licensed under the MIT License.
# ------------------------------------
from ._models import DecryptResult, EncryptResult, SignResult, WrapResult, VerifyResult, UnwrapResult
from ._client import CryptographyClient
from ._enums import EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm
from ._client import CryptographyClient


__all__ = [
Expand Down
278 changes: 103 additions & 175 deletions sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,65 +2,27 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from datetime import datetime, timedelta, tzinfo
import logging
from typing import TYPE_CHECKING

import six
from azure.core.exceptions import AzureError, HttpResponseError
from azure.core.exceptions import HttpResponseError
from azure.core.tracing.decorator import distributed_trace

from . import DecryptResult, EncryptResult, SignResult, VerifyResult, UnwrapResult, WrapResult
from ._internal import EllipticCurveKey, RsaKey, SymmetricKey
from ._key_validity import raise_if_time_invalid
from ._providers import get_local_cryptography_provider, NoLocalCryptography
from .. import KeyOperation
from .._models import KeyVaultKey
from .._shared import KeyVaultClientBase, parse_vault_id

try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False

if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Optional, Union
from azure.core.credentials import TokenCredential
from . import EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm
from ._internal import Key as _Key


class _UTC_TZ(tzinfo):
"""from https://docs.python.org/2/library/datetime.html#tzinfo-objects"""

ZERO = timedelta(0)

def utcoffset(self, dt):
return self.ZERO

def tzname(self, dt):
return "UTC"

def dst(self, dt):
return self.ZERO


_UTC = _UTC_TZ()


def _enforce_nbf_exp(key):
# type: (KeyVaultKey) -> None
try:
nbf = key.properties.not_before
exp = key.properties.expires_on
except AttributeError:
# we consider the key valid because a user must have deliberately created it
# (if it came from Key Vault, it would have those attributes)
return

now = datetime.now(_UTC)
if (nbf and exp) and not nbf <= now <= exp:
raise ValueError("This client's key is useable only between {} and {} (UTC)".format(nbf, exp))
if nbf and nbf >= now:
raise ValueError("This client's key is not useable until {} (UTC)".format(nbf))
if exp and exp <= now:
raise ValueError("This client's key expired at {} (UTC)".format(exp))
_LOGGER = logging.getLogger(__name__)


class CryptographyClient(KeyVaultClientBase):
Expand Down Expand Up @@ -103,21 +65,18 @@ def __init__(self, key, credential, **kwargs):
if isinstance(key, KeyVaultKey):
self._key = key
self._key_id = parse_vault_id(key.id)
self._allowed_ops = frozenset(self._key.key_operations)
elif isinstance(key, six.string_types):
self._key = None
self._key_id = parse_vault_id(key)
self._keys_get_forbidden = None # type: Optional[bool]

# will be replaced with actual permissions before any local operations are attempted, if we can get the key
self._allowed_ops = frozenset()
else:
raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version")

if not self._key_id.version:
raise ValueError("'key' must include a version")

self._internal_key = None # type: Optional[_Key]
self._local_provider = NoLocalCryptography()
self._initialized = False

super(CryptographyClient, self).__init__(vault_url=self._key_id.vault_url, credential=credential, **kwargs)

Expand All @@ -131,48 +90,29 @@ def key_id(self):
return "/".join(self._key_id)

@distributed_trace
def _get_key(self, **kwargs):
# type: (**Any) -> Optional[KeyVaultKey]
"""Get the client's :class:`~azure.keyvault.keys.KeyVaultKey`.

Can be ``None``, if the client lacks keys/get permission.

:rtype: :class:`~azure.keyvault.keys.KeyVaultKey` or ``None``
"""
def _initialize(self, **kwargs):
# type: (**Any) -> None
if self._initialized:
return

# try to get the key material, if we don't have it and aren't forbidden to do so
if not (self._key or self._keys_get_forbidden):
try:
self._key = self._client.get_key(
self._key_id.vault_url, self._key_id.name, self._key_id.version, **kwargs
)
self._allowed_ops = frozenset(self._key.key_operations)
except HttpResponseError as ex:
# if we got a 403, we don't have keys/get permission and won't try to get the key again
# (other errors may be transient)
self._keys_get_forbidden = ex.status_code == 403
return self._key

def _get_local_key(self, **kwargs):
# type: (**Any) -> Optional[_Key]
"""Gets an object implementing local operations. Will be ``None``, if the client was instantiated with a key
id and lacks keys/get permission."""

if not self._internal_key:
key = self._get_key(**kwargs)
if not key:
return None

kty = key.key_type.lower()
if kty.startswith("ec"):
self._internal_key = EllipticCurveKey.from_jwk(key.key)
elif kty.startswith("rsa"):
self._internal_key = RsaKey.from_jwk(key.key)
elif kty == "oct":
self._internal_key = SymmetricKey.from_jwk(key.key)
else:
raise ValueError("Unsupported key type '{}'".format(key.key_type))

return self._internal_key

# if we have the key material, create a local crypto provider with it
if self._key:
self._local_provider = get_local_cryptography_provider(self._key)
self._initialized = True
else:
# try to get the key again next time unless we know we're forbidden to do so
self._initialized = self._keys_get_forbidden

@distributed_trace
def encrypt(self, algorithm, plaintext, **kwargs):
Expand All @@ -199,28 +139,23 @@ def encrypt(self, algorithm, plaintext, **kwargs):
print(result.algorithm)

"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.encrypt, algorithm):
raise_if_time_invalid(self._key)
try:
return self._local_provider.encrypt(algorithm, plaintext)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local encrypt operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))

local_key = self._get_local_key(**kwargs)
if local_key:
_enforce_nbf_exp(self._key)
if "encrypt" not in self._allowed_ops:
raise AzureError("This client doesn't have 'keys/encrypt' permission")
result = local_key.encrypt(plaintext, algorithm=algorithm.value)
else:

parameters = self._models.KeyOperationsParameters(
algorithm=algorithm,
value=plaintext
)
operation_result = self._client.encrypt(
vault_base_url=self._key_id.vault_url,
key_name=self._key_id.name,
key_version=self._key_id.version,
parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=plaintext),
**kwargs
)

result = self._client.encrypt(
vault_base_url=self._key_id.vault_url,
key_name=self._key_id.name,
key_version=self._key_id.version,
parameters=parameters,
**kwargs
).result
return EncryptResult(key_id=self.key_id, algorithm=algorithm, ciphertext=result)
return EncryptResult(key_id=self.key_id, algorithm=algorithm, ciphertext=operation_result.result)

@distributed_trace
def decrypt(self, algorithm, ciphertext, **kwargs):
Expand All @@ -244,19 +179,22 @@ def decrypt(self, algorithm, ciphertext, **kwargs):
print(result.plaintext)

"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.decrypt, algorithm):
try:
return self._local_provider.decrypt(algorithm, ciphertext)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local decrypt operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))

parameters = self._models.KeyOperationsParameters(
algorithm=algorithm,
value=ciphertext
)
result = self._client.decrypt(
operation_result = self._client.decrypt(
vault_base_url=self._key_id.vault_url,
key_name=self._key_id.name,
key_version=self._key_id.version,
parameters=parameters,
parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=ciphertext),
**kwargs
)
return DecryptResult(key_id=self.key_id, algorithm=algorithm, plaintext=result.result)

return DecryptResult(key_id=self.key_id, algorithm=algorithm, plaintext=operation_result.result)

@distributed_trace
def wrap_key(self, algorithm, key, **kwargs):
Expand All @@ -281,26 +219,23 @@ def wrap_key(self, algorithm, key, **kwargs):
print(result.algorithm)

"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.wrap_key, algorithm):
raise_if_time_invalid(self._key)
try:
return self._local_provider.wrap_key(algorithm, key)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local wrap operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))

local_key = self._get_local_key(**kwargs)
if local_key:
_enforce_nbf_exp(self._key)
if "wrapKey" not in self._allowed_ops:
raise AzureError("This client doesn't have 'keys/wrapKey' permission")
result = local_key.wrap_key(key, algorithm=algorithm.value)
else:
parameters = self._models.KeyOperationsParameters(
algorithm=algorithm,
value=key,
)
result = self._client.wrap_key(
self._key_id.vault_url,
self._key_id.name,
self._key_id.version,
parameters=parameters
).result

return WrapResult(key_id=self.key_id, algorithm=algorithm, encrypted_key=result)
operation_result = self._client.wrap_key(
vault_base_url=self._key_id.vault_url,
key_name=self._key_id.name,
key_version=self._key_id.version,
parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=key),
**kwargs
)

return WrapResult(key_id=self.key_id, algorithm=algorithm, encrypted_key=operation_result.result)

@distributed_trace
def unwrap_key(self, algorithm, encrypted_key, **kwargs):
Expand All @@ -322,26 +257,21 @@ def unwrap_key(self, algorithm, encrypted_key, **kwargs):
key = result.key

"""
local_key = self._get_local_key(**kwargs)
if local_key and local_key.is_private_key():
if "unwrapKey" not in self._allowed_ops:
raise AzureError("This client doesn't have 'keys/unwrapKey' permission")
result = local_key.unwrap_key(encrypted_key, **kwargs)
else:

parameters = self._models.KeyOperationsParameters(
algorithm=algorithm,
value=encrypted_key
)
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.unwrap_key, algorithm):
try:
return self._local_provider.unwrap_key(algorithm, encrypted_key)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local unwrap operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))

result = self._client.unwrap_key(
vault_base_url=self._key_id.vault_url,
key_name=self._key_id.name,
key_version=self._key_id.version,
parameters=parameters,
**kwargs
).result
return UnwrapResult(key_id=self._key_id, algorithm=algorithm, key=result)
operation_result = self._client.unwrap_key(
vault_base_url=self._key_id.vault_url,
key_name=self._key_id.name,
key_version=self._key_id.version,
parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=encrypted_key),
**kwargs
)
return UnwrapResult(key_id=self._key_id, algorithm=algorithm, key=operation_result.result)

@distributed_trace
def sign(self, algorithm, digest, **kwargs):
Expand Down Expand Up @@ -371,20 +301,23 @@ def sign(self, algorithm, digest, **kwargs):
print(result.algorithm)

"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.sign, algorithm):
raise_if_time_invalid(self._key)
try:
return self._local_provider.sign(algorithm, digest)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local sign operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))

parameters = self._models.KeySignParameters(
algorithm=algorithm,
value=digest
)

result = self._client.sign(
operation_result = self._client.sign(
vault_base_url=self._key_id.vault_url,
key_name=self._key_id.name,
key_version=self._key_id.version,
parameters=parameters,
parameters=self._models.KeySignParameters(algorithm=algorithm, value=digest),
**kwargs
)
return SignResult(key_id=self.key_id, algorithm=algorithm, signature=result.result)

return SignResult(key_id=self.key_id, algorithm=algorithm, signature=operation_result.result)

@distributed_trace
def verify(self, algorithm, digest, signature, **kwargs):
Expand All @@ -408,24 +341,19 @@ def verify(self, algorithm, digest, signature, **kwargs):
assert verified.is_valid

"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.verify, algorithm):
try:
return self._local_provider.verify(algorithm, digest, signature)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local verify operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))

local_key = self._get_local_key(**kwargs)
if local_key:
if "verify" not in self._allowed_ops:
raise AzureError("This client doesn't have 'keys/verify' permission")
result = local_key.verify(digest, signature, algorithm=algorithm.value)
else:
parameters = self._models.KeyVerifyParameters(
algorithm=algorithm,
digest=digest,
signature=signature
)

result = self._client.verify(
vault_base_url=self._key_id.vault_url,
key_name=self._key_id.name,
key_version=self._key_id.version,
parameters=parameters,
**kwargs
).value
return VerifyResult(key_id=self.key_id, algorithm=algorithm, is_valid=result)
operation_result = self._client.verify(
vault_base_url=self._key_id.vault_url,
key_name=self._key_id.name,
key_version=self._key_id.version,
parameters=self._models.KeyVerifyParameters(algorithm=algorithm, digest=digest, signature=signature),
**kwargs
)

return VerifyResult(key_id=self.key_id, algorithm=algorithm, is_valid=operation_result.value)
Loading